0%

torch.enisum

简结的矩阵运算方法 - Free Indices:出现在output的indices - Summation Indices:所有其余的indices,只出现在input没有出现在output - example:

1
2
3
a=torch.rand(5)
b=torch.rand(3)
outer=torch.einsum('i,j->ij',a,b)
>这里没有free indices,free indices只能在output出现

RULES

  • Repeating letters in different inputs means those values will be multiplied and those products will be the output. example:
    1
    M=torch.einsum('ik,kj->ij',A,B)
    ,A的第i行的元素会和B的第j列的元素逐个相乘,对应下一条规则,由于k省略了,相乘的结果会相加
  • Omitting a letter means that axis will be summed. example:
    1
    2
    x=torch.ones(3)
    sum_x=torch.einsum('i->',x)
    i在output中省略了,所以在i对应的这一个维度元素相加
  • We can return the unsummed axes in any order example:
    1
    torch.einsum('ijk->kji')
    对于这些没有相加的维度,我们可以在output中交换维度进行reshape ## 实例

1. 外积操作

1
2
3
4
a = torch.rand(5)
b = torch.rand(3)
outer = torch.einsum('i,j->ij', a, b)
print(outer.shape) # 输出: torch.Size([5, 3])
  • 描述:计算向量 ab 的外积,输出大小为 (5, 3)

2. 矩阵乘法

1
2
3
4
A = torch.rand(2, 3)
B = torch.rand(3, 4)
M = torch.einsum('ik,kj->ij', A, B)
print(M.shape) # 输出: torch.Size([2, 4])
  • 描述:计算矩阵 AB 的乘积,输出大小为 (2, 4)

3. 向量求和

1
2
3
x = torch.ones(3)
sum_x = torch.einsum('i->', x)
print(sum_x) # 输出: 3.0
  • 描述:沿着 i 维度求和,输出一个标量。

4. 张量维度交换

1
2
3
tensor = torch.rand(2, 3, 4)
swapped = torch.einsum('ijk->kji', tensor)
print(swapped.shape) # 输出: torch.Size([4, 3, 2])
  • 描述:交换张量的维度,ijk -> kji,输出大小为 (4, 3, 2)

5. 向量内积

1
2
3
4
a = torch.rand(5)
b = torch.rand(5)
inner = torch.einsum('i,i->', a, b)
print(inner) # 输出: 一个标量,表示向量 a 和 b 的内积
  • 描述:计算向量 ab 的内积,输出一个标量。

6. 张量积(高阶外积)

1
2
3
4
A = torch.rand(2, 3)
B = torch.rand(3, 4)
tensor_product = torch.einsum('ij,jk->ik', A, B)
print(tensor_product.shape) # 输出: torch.Size([2, 4])
  • 描述:计算张量 AB 的乘积,输出大小为 (2, 4)

7. 广播操作

1
2
3
4
A = torch.rand(2, 1, 3)
B = torch.rand(3, 4)
result = torch.einsum('ijk,jl->ikl', A, B)
print(result.shape) # 输出: torch.Size([2, 3, 4])
  • 描述:张量 AB 进行广播后进行矩阵乘法,输出大小为 (2, 3, 4)

8. 多个求和轴

1
2
3
4
5
A = torch.rand(2, 3)
B = torch.rand(3, 4)
C = torch.rand(4, 5)
result = torch.einsum('ij,jk,kl->il', A, B, C)
print(result.shape) # 输出: torch.Size([2, 5])
  • 描述:进行三个矩阵的乘法,依次沿 jk 轴求和,输出大小为 (2, 5)

9. 张量的迹 (Trace)

1
2
3
A = torch.rand(3, 3)
trace_A = torch.einsum('ii->', A)
print(trace_A) # 输出: 张量 A 的迹(对角线元素之和)
  • 描述:计算矩阵 A 的迹,等价于 torch.trace(A),输出一个标量。

10. 高维矩阵乘法

1
2
3
4
A = torch.rand(2, 3, 4)
B = torch.rand(4, 5, 6)
result = torch.einsum('abc,cdx->abdx', A, B)
print(result.shape) # 输出: torch.Size([2, 3, 5, 6])
  • 描述:计算两个高维张量的乘法,结果的维度为 (2, 3, 5, 6)

11. 跨维度求和

1
2
3
A = torch.rand(4, 5, 6)
sum_result = torch.einsum('ijk->ik', A)
print(sum_result.shape) # 输出: torch.Size([4, 6])
  • 描述:沿 j 维度求和,输出大小为 (4, 6)

12. 矩阵-向量乘法

1
2
3
4
A = torch.rand(3, 4)
x = torch.rand(4)
result = torch.einsum('ij,j->i', A, x)
print(result.shape) # 输出: torch.Size([3])
  • 描述:计算矩阵 A 和向量 x 的乘积,输出一个大小为 (3,) 的向量。

13. 张量元素乘法

1
2
3
4
A = torch.rand(2, 3)
B = torch.rand(2, 3)
result = torch.einsum('ij,ij->ij', A, B)
print(result.shape) # 输出: torch.Size([2, 3])
  • 描述:按元素进行矩阵 AB 的乘法,输出一个大小为 (2, 3) 的张量。

14. 多维矩阵积

1
2
3
4
A = torch.rand(2, 3, 4)
B = torch.rand(4, 5, 6)
result = torch.einsum('abc,cde->abe', A, B)
print(result.shape) # 输出: torch.Size([2, 3, 6])
  • 描述:进行多个维度的矩阵积,结果大小为 (2, 3, 6)