pytorch で配列を反転させる

最近NLP関連でpytorchを触る機会が増え、個人的に覚えておきたいことをメモしておきます。

github

  • jupyter notebook形式のファイルはこちら

google colaboratory

  • google colaboratory で実行する場合はこちら

筆者の環境

筆者のOSはmacOSです。LinuxやUnixのコマンドとはオプションが異なります。

!sw_vers
ProductName:	Mac OS X
ProductVersion:	10.14.6
BuildVersion:	18G103
!python -V
Python 3.8.5

基本的なライブラリをインポートしそのバージョンを確認しておきます。

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib
import matplotlib.pyplot as plt
import scipy
import numpy as np
import torch

print('matplotlib version :', matplotlib.__version__)
print('scipy version :', scipy.__version__)
print('numpy version :', np.__version__)
print('torch version :', torch.__version__)
matplotlib version : 3.3.2
scipy version : 1.3.1
numpy version : 1.19.2
torch version : 1.10.0

numpyによる反転

a = np.array([i for i in range(10)])
a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
a[::-1]
array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
a = np.array([[i * j for i in range(10)] for j in range(10)])
a
array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18],
       [ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27],
       [ 0,  4,  8, 12, 16, 20, 24, 28, 32, 36],
       [ 0,  5, 10, 15, 20, 25, 30, 35, 40, 45],
       [ 0,  6, 12, 18, 24, 30, 36, 42, 48, 54],
       [ 0,  7, 14, 21, 28, 35, 42, 49, 56, 63],
       [ 0,  8, 16, 24, 32, 40, 48, 56, 64, 72],
       [ 0,  9, 18, 27, 36, 45, 54, 63, 72, 81]])
a[:,::-1]
array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0],
       [18, 16, 14, 12, 10,  8,  6,  4,  2,  0],
       [27, 24, 21, 18, 15, 12,  9,  6,  3,  0],
       [36, 32, 28, 24, 20, 16, 12,  8,  4,  0],
       [45, 40, 35, 30, 25, 20, 15, 10,  5,  0],
       [54, 48, 42, 36, 30, 24, 18, 12,  6,  0],
       [63, 56, 49, 42, 35, 28, 21, 14,  7,  0],
       [72, 64, 56, 48, 40, 32, 24, 16,  8,  0],
       [81, 72, 63, 54, 45, 36, 27, 18,  9,  0]])

pytrochによる反転

pytorchはa[::-1]のような反転は出来ないので、別の方法で反転させる必要がある。

1次元tensorの反転

a = torch.tensor(range(12))
a
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
torch.flip(a, dims=[0])
tensor([11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0])

2次元tensorの反転

a = a.reshape(3,4)
a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

X軸で反転させる。

torch.flip(a, dims=[0])
tensor([[ 8,  9, 10, 11],
        [ 4,  5,  6,  7],
        [ 0,  1,  2,  3]])

Y軸で反転させる。

torch.flip(a, dims=[1])
tensor([[ 3,  2,  1,  0],
        [ 7,  6,  5,  4],
        [11, 10,  9,  8]])

左右で反転させる。

torch.fliplr(a)
tensor([[ 3,  2,  1,  0],
        [ 7,  6,  5,  4],
        [11, 10,  9,  8]])

たまに忘れていちいち調べることになるので、覚えておく。