NumPyのeyeまたはidentityでone-hot表現に変換

one-hot表現とは 1つだけが1(high)で、それ以外は0(low)のビット列をone-hotと呼ぶ。1-of-K表現とも呼ばれる。

One-hot - Wikipedia

ちなみに、1つだけが0でそれ以外が1であるビット列をone-coldと呼ぶこともあります。らしい。 TensorFlowなどの機械学習で分類を行う際には、正解ラベルをone-hotで表現する必要があります。例えば、手書き数字(0〜9の10種類)のデータセットであるMNISTで正解となるラベルが2の場合、one-hotで表すと、[0,0,1,0,0,0,0,0,0,0]となります。 NumPyのeye関数またはidentity関数を使うと簡単にone-hot表現に変換できます。

numpy.eye()

numpy.eye()は、1が斜めに並んで、それ以外は0となる2次元のndarrayを返す関数。

numpy.eye — NumPy v1.13 Manual

e = np.eye(4)
print(type(e))
print(e)
print(e.dtype)
# <class 'numpy.ndarray'>
# [[ 1.  0.  0.  0.]
#  [ 0.  1.  0.  0.]
#  [ 0.  0.  1.  0.]
#  [ 0.  0.  0.  1.]]
# float64

デフォルトのデータ型はfloat64。引数dtypeでデータ型を指定できます。

e = np.eye(4, M=3, k=1, dtype=np.int8)
print(e)
print(e.dtype)
# [[0 1 0]
#  [0 0 1]
#  [0 0 0]
#  [0 0 0]]
# int8

引数Mで列のサイズ、kで1の始まり位置を変えられる。

numpy.identity()

numpy.identity()は名前の通り、単位行列(identity matrix)を返す関数。

numpy.identity — NumPy v1.13 Manual

i = np.identity(4)
print(i)
print(i.dtype)
# [[ 1.  0.  0.  0.]
#  [ 0.  1.  0.  0.]
#  [ 0.  0.  1.  0.]
#  [ 0.  0.  0.  1.]]
# float64

デフォルトのデータ型はfloat64。引数dtypeでデータ型を指定できます。

i = np.identity(4, dtype=np.uint8)
print(i)
print(i.dtype)
# [[1 0 0 0]
#  [0 1 0 0]
#  [0 0 1 0]
#  [0 0 0 1]]
# uint8

他の引数はない。 なぜ同じような関数が2つもあるのかと思って、ソースを見てみると、numpy.identity()は内部でnumpy.eye()を呼んでいるだけ。

    from numpy import eye
    return eye(n, dtype=dtype)

numpy.eye()でも単位行列は得られるけど、numpy.identity()というわかりやすい名前の関数も用意してある、ということだろう。 任意の対角行列を生成するための関数numpy.diag()もあります。詳細は以下の記事を参照。

one-hot表現に変換

単位行列があればone-hot表現に変換するのは簡単。 例えば変換元が10種類の場合は、10×10の単位行列を作ってインデックスに変換元の値をいれてやればいい。

a = [3, 0, 8, 1, 9]
a_one_hot = np.identity(10)[a]
print(a)
print(a_one_hot)
# [3, 0, 8, 1, 9]
# [[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
#  [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
#  [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]
#  [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
#  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]]

適当な0〜9の値をもつ配列aをone-hot表現に変換しています。 irisデータセットのように正解ラベルが3種類の場合もやり方は同じ。

a = [2, 2, 0, 1, 0]
a_one_hot = np.identity(3)[a]
print(a)
print(a_one_hot)
# [2, 2, 0, 1, 0]
# [[ 0.  0.  1.]
#  [ 0.  0.  1.]
#  [ 1.  0.  0.]
#  [ 0.  1.  0.]
#  [ 1.  0.  0.]]

データ型は引数dtypeで適宜指定すれば問題ありません。 例ではわかりやすい名前のnumpy.identity()を使っているが、numpy.eye()でも同じ。お好みで。

シェア

関連カテゴリー

Python NumPy

NumPyで条件に応じた処理を行うwhereの使い方 NumPy配列ndarrayから条件を満たす要素・行・列を抽出、削除 NumPyで全要素を同じ値で初期化した配列ndarrayを生成 NumPy配列ndarrayの次元をEllipsis(...)で省略して指定 NumPyのarange, linspaceの使い方(連番や等差数列を生成) NumPy配列ndarrayを分割(split, array_split, hsplit, vsplit, dsplit) NumPy配列ndarrayの最大値・最小値のインデックス(位置)を取得 NumPy配列ndarrayとPython標準のリストを相互に変換 Python, OpenCV, NumPyで画像のアルファブレンドとマスク処理 NumPy配列ndarrayの行・列を任意の順番に並べ替え、選択(抽出) NumPy, randomで様々な種類の乱数の配列を生成 NumPy配列ndarrayをシフト(スクロール)させるnp.roll pandas.DataFrame, SeriesとNumPy配列ndarrayを相互に変換 Python, NumPyで行列の演算(逆行列、行列式、固有値など) Pythonでメソッドチェーンを改行して書く

Last Updated: 6/26/2019, 10:34:03 PM