pytorch的张量数据结构以及各种操作函数的底层原理

张开发
2026/4/8 9:19:00 15 分钟阅读

分享文章

pytorch的张量数据结构以及各种操作函数的底层原理
Tensor 的两个组成部分Storage (存储)实际的数据被保存为一个连续的一维数组通常是在 CPU 或 GPU 内存中。View (视图/元数据)描述如何解释这些物理数据。它包含shape (形状)张量维度stride (步长)想要跳到下一个维度在 Storage 中需要跳过多少个元素。storage_offset (偏移量)第一个元素的storage中的位置。多维张量的 stride步长解读stride是一个元组表示在某个维度上移动 1 个位置需要在底层一维 storage 中跳过多少个元素。核心公式对于一个 n 维张量元素[i₁, i₂, ..., iₙ]在 storage 中的位置textoffset storage_offset i₁ × stride[0] i₂ × stride[1] ... iₙ × stride[n-1]import torch # 创建一个 3x4 的连续张量 a torch.arange(12).reshape(3, 4) print(a) # tensor([[ 0, 1, 2, 3], # [ 4, 5, 6, 7], # [ 8, 9, 10, 11]]) print(a.shape) # (3, 4) print(a.stride) # (4, 1) print(a.storage_offset) # 0解读 stride (4, 1)维度stride 值含义维度0行4向下移动 1 行storage 索引跳过 4 个元素维度1列1向右移动 1 列storage 索引跳过 1 个元素行优先storage 索引: 0 1 2 3 4 5 6 7 8 9 10 11 storage 数据: [0] [1] [2] [3] [4] [5] [6] [7] [8] [9][10][11] └───── row0 ─────┘ └───── row1 ─────┘ └───── row2 ─────┘b a.t() # 转置变成 4x3 print(b.shape) # (4, 3) print(b.stride) # (1, 4) ← stride 交换了 print(b) # tensor([[ 0, 4, 8], # [ 1, 5, 9], # [ 2, 6, 10], # [ 3, 7, 11]])解读 stride (1, 4)维度stride含义维度0行1向下移动 1 行storage 索引跳过 1 个元素维度1列4向右移动 1 列storage 索引跳过 4 个元素三维c torch.arange(24).reshape(2, 3, 4) print(c.shape) # (2, 3, 4) print(c.stride) # (12, 4, 1)解读 stride (12, 4, 1)维度stride含义维度0深度/批次12移动 1 个深度跳过 12 个元素3×4维度1行4移动 1 行跳过 4 个元素1行的长度维度2列1移动 1 列跳过 1 个元素storage 布局可视化text深度0第1个 3x4 矩阵: 行0: [0, 1, 2, 3] ← 连续 行1: [4, 5, 6, 7] 行2: [8, 9, 10, 11] 深度1第2个 3x4 矩阵: 行0: [12, 13, 14, 15] 行1: [16, 17, 18, 19] 行2: [20, 21, 22, 23] storage 索引: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 231.切片改变offsetimport torch # 创建一个长度为 6 的张量 a torch.arange(6) # storage: [0, 1, 2, 3, 4, 5] # a.shape (6,) # a.storage_offset 0 # a.stride (1,) # 对 a 进行切片取索引 2 到 4不包括 4 b a[2:4] # b 逻辑上是 [2, 3] # 查看 b 的元数据 print(b.shape) # (2,) print(b.storage_offset) # 2 ← 关键b 从 storage 的第 2 个元素开始 print(b.stride) # (1,) print(b.data_ptr() a.data_ptr()) # True指向同一块内存Storage (a 和 b 共享): 索引: 0 1 2 3 4 5 数据: [0] [1] [2] [3] [4] [5] ↑ └─ b.storage_offset 2 b 从这里开始长度 2 b 逻辑视图: [2, 3]切片本质上是改变storge_offset2.连续性与view(stride[i]stride[i1]×shape[i1])判断张量的连续性就是判断stride1.最后一个数字是否为12.从后往前每个维度是否等于3.size0空特判def is_contiguous(tensor): shape tensor.shape stride tensor.stride() # 1. 处理标量或空张量 if tensor.ndim 0: return True if 0 in shape: return True expected_stride 1 # 从后往前遍历 for i in range(len(shape) - 1, -1, -1): # 关键点如果维度长度为 1这个维度的 stride 是多少都无所谓 # 它不会改变物理内存的连续分布 if shape[i] ! 1: if stride[i] ! expected_stride: return False expected_stride * shape[i] return Truec torch.arange(24).reshape(2, 3, 4)print(c.shape) # (2, 3, 4)print(c.stride) # (12, 4, 1)要判断一个张量是否连续可以简单判断对于shape[i] 1的维度i进行判断stride[i] stride[i1]*shape[i1]c: [c[0]-[ [1,3,4,2],[3,3,2,1],[5,4,2,1] ]c[1]-[ [1,3,4,2],[3,3,2,1],[5,4,2,1] ]]比如在判断stride[0] ? stride[1]*shape[1]可以如此理解stride[0]:跨到0维度的下一个元素需要跳过12个比如从c[0]跳到c[1]stride[1]:跨到1维度的下一个元素需要跳过4个比如从c[0][1]到c[0][2],shape[1]:1维度上的元素数量是3每个c[i]中有三个元素若是连续两者势必相等3.transpose与permute本质修改原数据中的shape与stride。1.tranposeimport torch # 创建一个连续张量 x torch.arange(12).reshape(3, 4) print(f原始张量 x:) print(x) print(fshape: {x.shape}) # (3, 4) print(fstride: {x.stride()}) # (4, 1) print(f连续? {x.is_contiguous()}) # True print(fstorage 地址: {x.data_ptr()}) print() # 执行 transpose交换两个维度 y x.transpose(0, 1) print(f转置后 y:) print(y) print(fshape: {y.shape}) # (4, 3) - shape 也交换了 print(fstride: {y.stride()}) # (1, 4) - stride 交换了 print(f连续? {y.is_contiguous()}) # False print(fstorage 地址: {y.data_ptr()}) # 和 x 相同 print(f是否共享 storage: {y.data_ptr() x.data_ptr()}) # True原始张量 x:tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])shape: (3, 4)stride: (4, 1)连续? Truestorage 地址: 140234567890123转置后 y:tensor([[ 0, 4, 8],[ 1, 5, 9],[ 2, 6, 10],[ 3, 7, 11]])shape: (4, 3)stride: (1, 4) ← 交换了连续? Falsestorage 地址: 140234567890123 ← 相同地址是否共享 storage: Truepermute# 3D 张量示例 x torch.arange(24).reshape(2, 3, 4) print(f原始 shape: {x.shape}) # (2, 3, 4) print(f原始 stride: {x.stride()}) # (12, 4, 1) print(f连续? {x.is_contiguous()}) # True # 重排维度把维度 (0,1,2) 变成 (2,0,1) y x.permute(2, 0, 1) print(f\npermute 后 shape: {y.shape}) # (4, 2, 3) print(fpermute 后 stride: {y.stride()}) # (1, 12, 4) ← 按 permute 规则重排 print(f连续? {y.is_contiguous()}) # False print(f共享 storage: {y.data_ptr() x.data_ptr()}) # True4.catstackcat是在现有轨道上“接火车”而stack是给火车“加盖一层新轨道”。以下是结合底层原理和实际应用的深度解析1. 核心区别维度与内存视角特性torch.cat (拼接)torch.stack (堆叠)核心逻辑沿现有维度连接沿新维度连接维度变化维度数不变指定维度的长度增加维度数1形状要求非拼接维度必须完全一致所有维度必须完全一致底层操作数据拷贝与连续化增加维度信息重新索引2. torch.cat底层的“内存搬运工”torch.cat的本质是将多个张量在物理内存上逻辑上首尾相连。底层原理 它不会改变张量的本质结构而是沿着你指定的轴dim将输入张量的数据块像“接龙”一样拼在一起。内存视角假设你有两个形状为(2, 3)的张量。如果你沿dim0拼接PyTorch 会申请一块新的连续内存大小为(4, 3)然后把第一个张量的数据复制进去紧接着复制第二个张量的数据。约束来源为什么要求非拼接维度必须一致因为如果维度不对齐比如一个是 3 列一个是 4 列它们在内存中就无法形成整齐的矩形块破坏了张量的规则结构。代码直观理解import torch a torch.tensor([, ]) # 形状 (2, 2) b torch.tensor([]) # 形状 (1, 2) # 沿第0维行拼接就像把 b 贴在 a 的下面 result torch.cat([a, b], dim0) # 结果形状: (3, 2) - 维度没变行数变多了3. torch.stack底层的“维度升维器”torch.stack的本质是创造一个新的维度用来索引这些张量。底层原理 它不仅仅是数据的组合更是**元数据Metadata**的重构。内存视角stack操作会在张量的形状信息Shape/Stride中插入一个新的维度。它相当于把多个张量“打包”进一个新的容器里。数据布局虽然底层数据在内存中可能依然是连续存储的但在逻辑上PyTorch 增加了一个“层”的概念。比如将两个(3, 4)的张量堆叠结果变成(2, 3, 4)。那个新增的2就是新维度的长度代表“你有2个这样的张量”。约束来源因为要把它们整齐地码放在这个新维度里所以所有输入张量的形状必须一模一样否则无法对齐。代码直观理解import torch a torch.tensor() # 形状 (2,) b torch.tensor() # 形状 (2,) # 沿新维度堆叠相当于把它们叠罗汉 result torch.stack([a, b], dim0) # 结果形状: (2, 2) - 维度从1维变成了2维 # 结果: tensor([, # ])4. 深度对比与避坑指南为了帮你彻底搞懂我用一个表格总结它们的“脾气”场景应该用谁为什么合并数据集cat比如你有 100 条数据和另外 50 条数据你想合并成一个 150 条的大列表。构建批次 (Batch)stack比如你有 3 张大小为(3, 224, 224)的图片你想组成一个 Batch 输入模型变成(3, 3, 224, 224)。特征融合cat在神经网络层之间经常把不同通道的特征图拼在一起如 Inception 结构。常见报错RuntimeErrorCat 报错通常是因为除了拼接维度外其他维度大小不一样比如想拼两个矩阵但一个宽3一个宽4。Stack 报错通常是因为输入的两个张量形状不完全相同。总结如果你想**“变长”**让数据更多用cat。如果你想**“变厚”**让结构更复杂增加层级用stack。理解这一点你在处理 PyTorch 张量形状变换Reshape/View时就会清晰很多不再容易报size mismatch的错误了。

更多文章