使用torch.chunk 進行tensor切塊

torch。chunk(input, chunks, dim=0) → List of Tensors

功能:將輸入tensor切分為特定數量的塊(chunk),每個chunk是輸入tensor的一個

檢視

(view)。如果dim對應維度大小無法被chunks整除,實際分片數可能小於chunks指定的數量,或者後面分片的大小會小一些。具體可以參考下面提到的不能整除的例子。

返回:切分後tensor構成的列表。所以可以透過訪問列表的方式來獲取結果相應的屬性,如

len(output), output[0]等。

引數說明:

input

Tensor

) –輸入tensor

chunks (int) –

期望

切分的塊數,

注意這裡只能是整數

dim

int

) – tensor從哪個維度進行切分

可以整除

>>> import torch>>> input = torch。tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12], [13,14,15,16]])>>> inputtensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]])>>> output = torch。chunk(input, 4, 0)>>> output(tensor([[1, 2, 3, 4]]), tensor([[5, 6, 7, 8]]), tensor([[ 9, 10, 11, 12]]), tensor([[13, 14, 15, 16]]))

不能整除

不能整除時會做向上取整的操作,該例中4/3 則設定為2。

>>> import torch>>> input = torch。tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]])>>> inputtensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]])>>> output = torch。chunk(input, 3, 0)>>> output(tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), tensor([[ 9, 10, 11, 12], [13, 14, 15, 16]]))>>> len(output)2

該例中輸入維度為(4,4) 在維度0切分為3塊。切分後的結果為維度(2,4) (2,4)的2個tensor構成的列表。

>>> input = torch。randn(8,4)>>> torch。chunk(input,7, 0)(tensor([[ 0。8091, 0。7134, -0。1981, -0。2185], [ 1。0662, -0。7637, -1。8087, -0。6494]]), tensor([[-0。0293, -1。0161, 1。1005, 0。3810], [ 0。5001, -0。5396, -0。3986, 0。1753]]), tensor([[ 1。6392, -0。1882, -0。1103, 0。0731], [-0。4641, 0。0861, -0。7095, -0。5457]]), tensor([[-0。1193, 1。2701, -1。2506, -0。1760], [ 0。0256, -0。2543, -1。8225, 0。1837]]))>>> d = torch。chunk(input,7, 0)>>> len(d)4

這個例子針對維度0進行切分,8/7 向上取整使用2做為chunk_size進行分塊,只能分為4塊。