Pytorch 的 7 種張量介紹
作者 | 秦一 @知乎(已授權)編輯 | 極市平臺
來源 | https://zhuanlan.zhihu.com/p/399350505
1 Tensor 的裁剪運算
-
對 Tensor 中的元素進行範圍過濾
-
常用於梯度裁剪(gradient clipping),即在發生梯度離散或者梯度爆炸時對梯度的處理
-
torch.clamp(input, min, max, out=None) → Tensor:將輸入
input
張量每個元素的夾緊到區間 [min,max],並返回結果到一個新張量。
2 Tensor 的索引與數據篩選
-
torch.where(codition,x,y): 按照條件從 x 和 y 中選出滿足條件的元素組成新的 tensor,輸入參數 condition:條件限制,如果滿足條件,則選擇 a,否則選擇 b 作爲輸出。
-
torch.gather(input,dim,index,out=None): 在指定維度上按照索引賦值輸出 tensor
-
torch.inex_select(input,dim,index,out=None): 按照指定索引賦值輸出 tensor
-
torch.masked_select(input,mask,out=None): 按照 mask 輸出 tensor,輸出爲向量
-
torch.take(input,indices): 將輸入看成 1D-tensor,按照索引得到輸出 tensor
-
torch.nonzero(input,out=None): 輸出非 0 元素的座標
import torch
#torch.where
a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a)
print(b)
out = torch.where(a > 0.5, a, b)
print(out)
print("torch.index_select")
a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0,
index=torch.tensor([0, 3, 2]))
#dim=0按列,index取的是行
print(out, out.shape)
print("torch.gather")
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
out = torch.gather(a, dim=0,
index=torch.tensor([[0, 1, 1, 1],
[0, 1, 2, 2],
[0, 1, 3, 3]]))
print(out)
print(out.shape)
#注:從0開始,第0列的第0個,第一列的第1個,第二列的第1個,第三列的第1個,,,以此類推
#dim=0, out[i, j, k] = input[index[i, j, k], j, k]
#dim=1, out[i, j, k] = input[i, index[i, j, k], k]
#dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
print("torch.masked_index")
a = torch.linspace(1, 16, 16).view(4, 4)
mask = torch.gt(a, 8)
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)
print("torch.take")
a = torch.linspace(1, 16, 16).view(4, 4)
b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))
print(b)
#torch.nonzero
print("torch.take")
a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
out = torch.nonzero(a)
print(out)
#稀疏表示
3 Tensor 的組合 / 拼接
-
torch.cat(seq,dim=0,out=None): 按照已經存在的維度進行拼接
-
torch.stack(seq,dim=0,out=None): 沿着一個新維度對輸入張量序列進行連接。序列中所有的張量都應該爲相同形狀。
print("torch.stack")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)
print(out[:, :, 0])
print(out[:, :, 1])
4 Tensor 的切片
-
torch.chunk(tensor,chunks,dim=0): 按照某個維度平均分塊(最後一個可能小於平均值)
-
torch.split(tensor,split_size_or_sections,dim=0): 按照某個維度依照第二個參數給出的 list 或者 int 進行分割 tensor
5 Tensor 的變形操作
-
torch().reshape(input,shape)
-
torch().t(input): 只針對 2D tensor 轉置
-
torch().transpose(input,dim0,dim1): 交換兩個維度
-
torch().squeeze(input,dim=None,out=None): 去除那些維度大小爲 1 的維度
-
torch().unbind(tensor,dim=0): 去除某個維度
-
torch().unsqueeze(input,dim,out=None): 在指定位置添加維度, dim=-1 在最後添加
-
torch().flip(input,dims): 按照給定維度翻轉張量
-
torch().rot90(input,k,dims): 按照指定維度和旋轉次數進行張量旋轉
import torch
a = torch.rand(2, 3)
print(a)
out = torch.reshape(a, (3, 2))
print(out)
print(a)
print(torch.flip(a, dims=[2, 1]))
print(a)
print(a.shape)
out = torch.rot90(a, -1, dims=[0, 2]) #順時針旋轉90°
print(out)
print(out.shape)
6 Tensor 的填充操作
- torch.full((2,3),3.14)
7 Tensor 的頻譜操作(傅里葉變換)
本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源:https://mp.weixin.qq.com/s/zCJDBUN581X3_pDsG1hCOw