Pytorch 的 7 種張量介紹

作者 | 秦一 @知乎(已授權)編輯 | 極市平臺

來源 | https://zhuanlan.zhihu.com/p/399350505

1 Tensor 的裁剪運算

2 Tensor 的索引與數據篩選

import torch
#torch.where

a = torch.rand(4, 4)
b = torch.rand(4, 4)

print(a)
print(b)

out = torch.where(a > 0.5, a, b)

print(out)

print("torch.index_select")
a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0,
                   index=torch.tensor([0, 3, 2]))
#dim=0按列,index取的是行
print(out, out.shape)

print("torch.gather")
a = torch.linspace(1, 16, 16).view(4, 4)

print(a)

out = torch.gather(a, dim=0,
             index=torch.tensor([[0, 1, 1, 1],
                                 [0, 1, 2, 2],
                                 [0, 1, 3, 3]]))
print(out)
print(out.shape)
#注:從0開始,第0列的第0個,第一列的第1個,第二列的第1個,第三列的第1個,,,以此類推
#dim=0, out[i, j, k] = input[index[i, j, k], j, k]
#dim=1, out[i, j, k] = input[i, index[i, j, k], k]
#dim=2, out[i, j, k] = input[i, j, index[i, j, k]]

print("torch.masked_index")
a = torch.linspace(1, 16, 16).view(4, 4)
mask = torch.gt(a, 8)
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)

print("torch.take")
a = torch.linspace(1, 16, 16).view(4, 4)

b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))

print(b)

#torch.nonzero
print("torch.take")
a = torch.tensor([[0, 1, 2, 0][2, 3, 0, 1]])
out = torch.nonzero(a)
print(out)
#稀疏表示

3 Tensor 的組合 / 拼接

print("torch.stack")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b)dim=2)
print(out)
print(out.shape)

print(out[:, :, 0])
print(out[:, :, 1])

4 Tensor 的切片

5 Tensor 的變形操作

import torch
a = torch.rand(2, 3)
print(a)
out = torch.reshape(a, (3, 2))
print(out)

print(a)
print(torch.flip(a, dims=[2, 1]))

print(a)
print(a.shape)
out = torch.rot90(a, -1, dims=[0, 2]) #順時針旋轉90°  
print(out)
print(out.shape)

6 Tensor 的填充操作

7 Tensor 的頻譜操作(傅里葉變換)

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/zCJDBUN581X3_pDsG1hCOw