ろぐれこーど

限界組み込みエンジニアの学習記録とちょっぴりポエム

(numpy)特定の軸に対して全て同じ値を入れる

めも。
numpyで、特定の軸に対して全て同じ値やarrayを入れたい時があった。例えば二次元の場合、

a = np.zeros((5, 5))
#array([[ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.]])

に対して、全ての列に[0, 1, 2, 3, 4]を入れたいとする。その場合、以下のような操作を行うことで実現できる。

b = np.arange(5)
# array([0, 1, 2, 3, 4])
a[:, :] = b[:, None]

# a
#array([[ 0.,  0.,  0.,  0.,  0.],
#       [ 1.,  1.,  1.,  1.,  1.],
#       [ 2.,  2.,  2.,  2.,  2.],
#       [ 3.,  3.,  3.,  3.,  3.],
#       [ 4.,  4.,  4.,  4.,  4.]])

# b[:, None]
#array([[0],
#       [1],
#       [2],
#       [3],
#       [4]])

bのshapeは(5,)となっており、1次元arrayである。ここにb[:, None]と書くことで、Noneの場所にaxisが追加されることになる。この場合、b.shapeは(5, 1)となる。b.reshape(-1, 1) と結果は同じである。

対象とするaxis以外のshapeが等しい場合、そのaxisに渡って全ての値が代入される(上の例ではaxis=1に対して、同じ値[0, 1, 2, 3, 4]が代入された)。3次元以上の場合も同様に書ける。

a = np.zeros((5, 5, 3))

# a[:, :, 0]
#array([[ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.],
#       [ 0.,  0.,  0.,  0.,  0.]])

b = np.arange(25).reshape(5, 5)

# b
#array([[ 0,  1,  2,  3,  4],
#       [ 5,  6,  7,  8,  9],
#       [10, 11, 12, 13, 14],
#       [15, 16, 17, 18, 19],
#       [20, 21, 22, 23, 24]])

a[:, :, :] = b[:, :, None]
# a[:, :, 0], それ以降も同様
#array([[ 0,  1,  2,  3,  4],
#       [ 5,  6,  7,  8,  9],
#       [10, 11, 12, 13, 14],
#       [15, 16, 17, 18, 19],
#       [20, 21, 22, 23, 24]])

今までわざわざ軸に関してforループを回していたが、これ一行で書けることを知った。shapeさえあっていれば任意の軸で実行可能である。

a = np.zeros((5, 5, 3))
b = np.arange(15).reshape(5, 3)

# b
#array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

a[:, :, :] = b[:, None, :]
# a[:, 0, :]
#array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

# a[:, :, 0]
#array([[  0.,   0.,   0.,   0.,   0.],
#       [  3.,   3.,   3.,   3.,   3.],
#       [  6.,   6.,   6.,   6.,   6.],
#       [  9.,   9.,   9.,   9.,   9.],
#       [ 12.,  12.,  12.,  12.,  12.]])