Transpose
주어진 Tensor의 배열을 바꿔준다.
(n, c, h, w)
shape을 가지는 4차원 Tensor가 있다고 할 때, h와 w를 서로 바꿔서 (n, c, w, h)
로 만들 수 있다.
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]]
위와 같이 (2, 2, 3) Tensor가 있다고 할 때 C, H, W 차원 중 H와 W를 서로 치환한다면,
[[[ 1 4]
[ 2 5]
[ 3 6]]
[[ 7 10]
[ 8 11]
[ 9 12]]]
위와 같이 (2, 3, 2) Tensor로 변형이 된다.
마찬가지로 (2, 2, 3) Tensor의 C <-> W 차원을 치환한다면 아래와 같이 (3, 2, 2) 형태로 변형된다.
[[[ 1 7]
[ 4 10]]
[[ 2 8]
[ 5 11]]
[[ 3 9]
[ 6 12]]
변형하는 방법은 데이터를 읽는 순서를 생각해보면 쉽게 이해할 수 있다.
NCHW
format은 데이터가 W, H, C, N
순서대로 배열되어 있다.
즉, 위의 (2, 2, 3) Tensor의 예를 살펴보면 데이터는 1, 2, 3, ... , 11, 12
순서로 되어 있다.
순서대로 W 방향으로 1, 2, 3 다음으로 H 방향으로 내려와서 4, 5, 6 더 이상 H 방향으로 데이터가 없으므로 C 방향으로 이동하여 다시 같은 방법으로 7, 8, 9 마지막으로 10, 11, 12 식으로 읽어가는 방법이다. (이 예에서는 N 차원은 없으므로 생략하지만, 방법은 동일하다.)
NCWH
와 같이 H<->W 차원이 서로 치환되는 경우에도 H->W->C->N
순서대로 읽으면 된다.
H 방향으로 1, 4 다음 차원인 W 방향으로 다시 2, 5 W 방향이 아직 남아 있으므로 다시 3, 9
이제 C 방향으로 위와 동일하게 7, 10 그리고 8, 11 마지막으로 9, 12 순서이다.
이렇게 변환되면 원래 1, 2, 3, ...
순서에서 1, 4, 2, 5, ..., 9, 12
와 같이 데이터의 배열이 뒤바뀌게 된다.
흔히 Reshape과 혼동하는 경우가 있는데 Transpose와 Reshape 모두 결과적으로 shape의 모양이 바뀌게 되지만 Reshape은 데이터의 순서가 계속해서 유지되는 반면에 Transpose의 경우에는 데이터의 순서까지 뒤바뀌게 된다.
PyTorch
torch.transpose(input, dim0, dim1) → Tensor
https://pytorch.org/docs/stable/generated/torch.transpose.html
dimension 2개를 입력으로 받아서 서로 swap해준다. 예를들어 (3, 2)의 경우 (2, 3)으로 바꿔줄 수 있다. 서로 Swap하는 것이기 때문에 dim0, dim1의 순서는 의미가 없다.
TensorFlow
tf.transpose(
a, perm=None, conjugate=False, name='transpose'
)
https://www.tensorflow.org/api_docs/python/tf/transpose
TensorFlow의 경우에는 PyTorch와는 다르게 모든 차원의 순서를 바꿀 수 있다. `NCHW` 라면 순서대로 0, 1, 2, 3 index를 갖게 되고 `H<->W`를 서로 바꾸고 싶다면 perm=[0, 1, 3, 2] 와 같이 변경하고자 하는 차원의 순서를 쓰면 된다. 3이 W, 2가 H 이므로 `NCWH` 형태가 된다.
ONNX
ONNX의 경우 TensorFlow와 유사하다.
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose
예제
import tensorflow as tf
import torch
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
y = tf.transpose(x, perm=[0, 2, 1])
print(y)
x = torch.tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
y = torch.transpose(x, 2, 1)
print(y)
결과
tf.Tensor(
[[[ 1 4]
[ 2 5]
[ 3 6]]
[[ 7 10]
[ 8 11]
[ 9 12]]], shape=(2, 3, 2), dtype=int32)
tensor([[[ 1, 4],
[ 2, 5],
[ 3, 6]],
[[ 7, 10],
[ 8, 11],
[ 9, 12]]])