张量操作

前言

在使用 pytorch 的过程中,遇到了较多张量操作的相关问题,这篇 blog 主要是为了对相关的操作进行整合、梳理,以求在写代码时少遇到几个 bug

pytorch 张量广播机制

PyTorch 张量的广播(Broadcasting)是一种机制,它允许在不进行显示复制数据的情况下,对形状不同但满足一定条件的张量进行元素级的运算。这使得你可以在不显式地扩展张量的维度的情况下,执行一些操作,从而简化代码并提高效率。

广播的原则是:如果两个张量在某一维度上的形状大小相同,或其中一个张量的大小为 1,那么它们可以在该维度上进行广播。广播会自动扩展大小为 1 的维度,使得两个张量的维度相同,然后执行元素级操作。

注意到,广播一般是某一维为 1,或者某一个量是标量时才可以进行。

从定义而言不难理解,但实际应用中需要注意的地方挺多:
基础用法: 假设 pts:[bs, N, 3], offset[3],即我们需要对所有的点进行 x, y, z 统一的一个平移,那么这段代码最不容易被误解的形式可以写成下面这样:
pts = pts + offset[None, None]
在这个过程中,offset 的 shaep 由[3]首先被扩展为[1, 1, 3], 接着就能够直接由广播机制与 pts 相加,达到偏差的效果

pytorch 张量乘法

pytorch 官方定义的乘法具体可以写为以下几类:

  1. *的乘法
  2. 矩阵 or 张量乘法

逐个元素相乘 *

这种乘法是最直观的一种乘法类型,它能够将两个形状相同(或能够通过广播机制扩展到相同形状)的张量,按照位置逐个元素相乘,得到的结果依然是同样的形状。
由于这种方法与 pytorch 经常进行的矩阵运算形式不太相同,所以在实际应用中,除基础的标量放缩外,我们很少使用这种乘法方式。
这里举一个向量点乘的例子:
假设 A:[3] = (0.1, 0.2, 0.3), B:[3] = (0.4, 0.5, 0.6)
则 A 与 B 的点积 dot(A, B) = A * B = (0.04, 0.1, 0.18)

矩阵乘法 @

torch.mul 用的比较少,用于两个同维度矩阵逐像素点乘(等价于 *)
torch.mm 表示二维矩阵乘法, 只支持 (l, m) 与 (m, n) 相乘,得到维数为 (l, n)的这种类似的二维矩阵运算。
torch.bmm 则是在 torch.mm 的基础上对 batch 做了拓展,支持了 (b, l, m) 与 (b, m, n)相乘,得到维数为(b, l, n)的矩阵运算。
torch.mv 则用于矩阵乘向量的形式 (l, m) 与 (m) 得到 (l) ,实际上类似于将后面的向量直接扩展为(m, 1)然后进行矩阵运算的结果,因此这个函数也比较少用。

比较重要的是 torch.matmul 函数,这个函数囊括了上面的除 torch.mul 之外的所有函数。因此这个函数值得着重介绍。
torch.matmul 中内含了广播机制,因此对于各种张量的运算都能够类似的解决
如类似于 torch.mm 的二维矩阵乘法: (l, m) (m, n) -> (l, n)
类似于 torch.bmm 的扩展 batch 的矩阵乘法: (b, l, m) (b, m, n) -> (b, l, n)
以及加上广播机制的乘法: (l, m) (b, m, n) -> (b, l, n) 或 (l, m) (m) -> (l) 或 (b, c, l, m) (b, c, m, n) -> (b, c, l, n)等形式

@ 表示常规的数学上定义的矩阵相乘, 其作用与 torch.matmal 类似。

einsum 记法

torch.matmal 的功能虽然强大,但有时也会容易引起广播上的歧义或由于使用失误导致运算上的问题出现,因此,本文作者认为 einsum 函数是解决这些复杂张量运算的有效工具。
API: torch.einsum(equation, *operands)
求和:result = torch.einsum(“ij->”, a)
矩阵乘法: result = torch.einsum(“ij,jk->ik”, a, b)
批量矩阵乘法: result = torch.einsum(“bij,bjk->bik”, a, b)
梯度计算: result = torch.einsum(“i,i->”, a, b)

值得注意的是, einsum 在角标的选择上十分灵活,支持如:
torch.einsum(“bij,jk->bik”, a, b)
torch.einsum(“bjk, ij-> bik”, a, b)
等等形式。

pytorch 视图机制

PyTorch 的视图机制是一种在不复制底层数据的情况下创建张量的方式,允许您以不同的方式查看相同的底层数据。这对于节省内存并提高计算效率非常有用。PyTorch 的视图机制包括以下几个函数和属性:

  • view() 方法用于创建一个具有相同数据但形状不同的张量。它适用于原始张量的连续子序列,并且不能改变张量的总元素数。view() 方法仅用于连续内存块。
  • reshape() 方法类似于 view(),但它可以处理非连续内存块。它试图返回一个新的张量,该张量与原始张量共享数据,但形状可能不同。如果原始张量的内存布局不允许重新形状,reshape() 将返回一个副本。
  • squeeze() 用于删除张量中大小为 1 的维度。unsqueeze() 用于在指定位置插入大小为 1 的维度。
  • expand() 方法用于在指定维度上扩展张量的形状,但不会复制数据。它使用广播机制来扩展形状。

视图机制允许您通过创建不同的张量视图,以不同的方式查看底层数据,而无需复制数据。这对于在不同形状之间共享数据和减少内存消耗非常有用。但请注意,在某些情况下,例如非连续内存块或无法重塑的情况下,这些操作可能会生成副本而不是视图。

torch.expand & torch.repeat

torch.expand() 用于在现有张量的指定维度上扩展张量的形状,使其匹配目标形状。它通过在指定维度上重复元素来实现形状的扩展,但并不复制数据,只是改变了张量的视图。因此,它不会增加内存消耗。

torch.repeat() 用于在指定维度上重复复制张量的内容,从而实现形状的扩展。它会复制原始数据,因此可能会导致内存消耗增加。

总结区别:

  • torch.expand() 主要用于在指定维度上扩展张量形状,不复制数据,不增加内存消耗。
  • torch.repeat() 主要用于在指定维度上重复复制张量内容,会复制数据,可能增加内存消耗。

torch.gather 函数的使用

torch.gather 函数是对取索引非常方便的一种使用方法,它代表了从 A 中按照 B 相对应的索引来取出对应位置的元素,但由于其对函数的张量维数有较高的要求,因此在使用过程中需要额外注意。

torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor 是 PyTorch 中的一个函数,用于根据给定的索引在指定维度上从输入张量中收集元素。

函数参数:

input:输入张量,形状为 (N, *),其中 * 表示任意维度。
dim:指定收集操作的维度。
index:包含索引的张量,形状为 (N, *),其中每个索引值指定在 dim 维度上要收集的元素的位置。
out:可选参数,用于指定输出张量的位置。如果未提供,将创建一个新的张量来存储结果。
sparse_grad:一个布尔值,表示是否启用稀疏梯度。默认为 False。
返回值:
返回一个新的张量,其形状与 index 相同,其中每个元素是根据索引从 input 张量中收集的元素。

举一个具体的例子,假设有 features [bs, V, C], 并且想根据 ind [bs, N], 将这些 features 按照 ind 这个索引来取出对应的特征。
其中 ind 的取值范围在$(0, V-1]$,且是一个整数类型的 tensor。那么,我们则应该将 dim 选择为 1,即按照第一维进行选择。
由于此时 ind 和 features 的形状不同,我们首先需要进行扩展,才能取到对应的特征,最终实现的效果是这样的:

1
features_per_ind = torch.gather(input=features, dim=1, index= ind.unsqueeze(-1).expand(-1, -1, C)) # [bs, N, C]

grid_sample 函数的使用

grid_sample, 即 torch.nn.functional.grid_sample,是一个非常常用的函数,其通过对格点特征、坐标进行插值来得到每个点对应位置的特征。
grid_sample 使用非常常见,通常有 2D 和 3D 两种形式。

2D 形式

1
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

参数说明:

input:输入张量,可以是形状为 (B, C, H, W) 的 4D 张量,表示批次大小、通道数、高度和宽度。
grid:采样网格,可以是形状为 (B, H, W, 2) 的 4D 张量,其中最后一个维度的 2 个值表示 x 和 y 的采样坐标。
mode:采样模式,可以是 ‘bilinear’(双线性插值)或 ‘nearest’(最近邻插值)。
padding_mode:填充模式,用于处理采样点位于输入张量边界之外的情况。可选值包括 ‘zeros’、’border’、’reflection’。
align_corners:一个布尔值,表示是否根据像素中心对齐网格。通常在使用 ‘bilinear’ 插值时设置为 True。

请注意,输入张量的范围应该在 [0, 1] 之间,以便在采样时映射到对应的坐标。

3D 形式

原地操作(inplace)

张量的原地(in-place)操作是指直接修改输入张量的值而不创建新张量的操作。这些操作通常更内存效率,因为它们不需要为结果分配新的内存。原地操作通常通过在操作名后加一个下划线(_)来表示。例如,add_,sub_,mul_,div_ 等。
还有经常使用的 inplace 的激活函数:F.relu_() or ReLu(inplace=True)等等。
需要注意的是,原地操作在某些情况下可能会导致问题,特别是在计算梯度时。因为原地操作会直接修改张量的值,这可能会破坏计算图,从而导致无法正确地计算梯度。

有一个特别经常的 bug: pytorch 在取 mask 内的 tensor 并进行 inplace 填充的时候,可能会产生与预期不符的结果:
features[mask1] = features_valid 这种情况,并不会导致问题,其对应位置的特征会被正确填充。但这个过程如果直接进行两步,pytorch 的默认操作则是会创建一个新的内存空间,将这片新的内存进行写入,从而不会改变原先的张量:
features[mask1][mask2] = features_valid
这种情况显然与预期的结果不符,因此在使用过程中必须要避免掉这样的 inplace 操作,以防产生赋值错误。