Python, pandas, seabornでヒートマップを作成

Pythonのビジュアライゼーションライブラリseabornを使うと、二次元データを可視化するヒートマップが簡単に作成できます。 seaborn.heatmap()関数を使います。

seaborn.heatmap — seaborn 0.8.1 documentation

pandasは必須ではないが、二次元データとしてpandas.DataFrameを用いると行名・列名がそのままx軸・y軸のラベルとして表示されて便利。 ここでは、以下の内容について説明します。

seaborn.heatmap()関数の基本的な使い方 オブジェクトとして操作

seaborn.heatmap()関数の主な引数 数値を表示: 引数annot カラーバー表示・非表示: 引数cbar 正方形で表示: 引数square 最大値、最小値、中央値を指定: 引数vmax, vmin, center 色(カラーマップ)を指定: 引数cmap サイズを指定

活用例: 多数の特徴量を持つデータの相関係数を可視化

以下のライブラリをimportしておく。

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Jupyter Notebookでグラフをインラインで表示したい場合は%matplotlib inlineを実行しておく。

seaborn.heatmap()関数の基本的な使い方

seaborn.heatmap()の第一引数dataには可視化したい二次元配列を指定します。 Pythonのリストの二次元配列(リストのリスト)、numpy.ndarray、pandas.DataFrameを指定できます。 Pythonのリストの二次元配列(リストのリスト)の場合。

list_2d = [[0, 1, 2], [3, 4, 5]]

Jupyter Notebookの場合は%matplotlib inlineを実行してからseaborn.heatmap()を実行するとグラフがインラインで表示される。

sns.heatmap(list_2d)

画像ファイルとして保存する場合はplt.savefig()、ファイル保存ではなくOSの画像表示プログラムで表示する場合はplt.show()を使います。 繰り返しグラフを作成する場合はplt.figure()で新たなFigureを生成するかplt.clf()で初期化しておかないと前の描画結果が残ることがあるので注意。Jupyter Notebookでインライン表示する場合は特に初期化の必要はない。 さらに、複数(初期値では20以上)のFigureを生成すると警告が出る。plt.savefig()またはplt.show()のあとでplt.close('all')を実行しておけば問題ありません。

plt.figure()
sns.heatmap(list_2d)
plt.savefig('data/dst/seaborn_heatmap_list.png')
plt.close('all')

numpy.ndarrayの場合。

arr_2d = np.arange(-8, 8).reshape((4, 4))
print(arr_2d)
# [[-8 -7 -6 -5]
#  [-4 -3 -2 -1]
#  [ 0  1  2  3]
#  [ 4  5  6  7]]

plt.figure()
sns.heatmap(arr_2d)
plt.savefig('data/dst/seaborn_heatmap_ndarray.png')

pandas.DataFrameの場合。pandas.DataFrameだと行名index、列名columnsがx軸・y軸のラベルとして表示される。

df = pd.DataFrame(data=arr_2d, index=['a', 'b', 'c', 'd'], columns=['A', 'B', 'C', 'D'])
print(df)
#    A  B  C  D
# a -8 -7 -6 -5
# b -4 -3 -2 -1
# c  0  1  2  3
# d  4  5  6  7

plt.figure()
sns.heatmap(df)
plt.savefig('data/dst/seaborn_heatmap_dataframe.png')

オブジェクトとして操作

seaborn.heatmap()が返すのはMatplotlibのAxesSubplotオブジェクト。

Axes class — Matplotlib 2.2.2 documentation

print(type(sns.heatmap(list_2d)))
# <class 'matplotlib.axes._subplots.AxesSubplot'>

デフォルトではアクティブなサブプロットに描画されるが、seaborn.heatmap()の引数axで任意のサブプロットを指定して描画できます。

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
sns.heatmap(list_2d, ax=ax)
fig.savefig('data/dst/seaborn_heatmap_list.png')


fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))
sns.heatmap(list_2d, ax=axes[0, 0])
sns.heatmap(arr_2d, ax=axes[1, 2])
fig.savefig('data/dst/seaborn_heatmap_list_sub.png')

seaborn.heatmap()関数の主な引数

seaborn.heatmap()で指定できる主な引数を示す。 ここで挙げるもの以外もあります。公式サイトを参照。

seaborn.heatmap — seaborn 0.8.1 documentation

数値を表示: 引数annot

ヒートマップ上に数値を表示する場合はannot=Trueとします。

sns.heatmap(df, annot=True)

カラーバー表示・非表示: 引数cbar

カラーバーを非表示にするにはcbar=Falseとします。

sns.heatmap(df, cbar=False)

正方形で表示: 引数square

square=Trueとするとヒートマップが正方形で描画される。

sns.heatmap(df, square=True)

最大値、最小値、中央値を指定: 引数vmax, vmin, center

ヒートマップの最大値、最小値、中央値はそれぞれvmax, vmin, centerで指定します。

sns.heatmap(df, vmax=10, vmin=-10, center=0)

色(カラーマップ)を指定: 引数cmap

色はcmapで指定します。Matplotlibで使えるカラーマップがそのまま使える。 以下のMatplotlibの公式サイトにカラーマップが挙げられている。

Choosing Colormaps — Matplotlib 2.0.2 documentation

sns.heatmap(df, cmap='hot')

カラーマップの文字列に_rを追加すると色の順番が逆になる。

sns.heatmap(df, cmap='Blues')


sns.heatmap(df, cmap='Blues_r')

サイズを指定

これはseaborn.heatmap()の引数ではないが説明しておく。 生成される画像のサイズはfigsize(単位: インチ)とdpi(インチ当たりのドット数)で決定される。 figsizeはplt.figure()またはplt.subplots()の引数で、dpiはsavefig()の引数で指定します。 それぞれ以下のように確認および変更ができます。

current_figsize = mpl.rcParams['figure.figsize']
print(current_figsize)
# [6.0, 4.0]

plt.figure(figsize=(9, 6))
sns.heatmap(df, square=True)
plt.savefig('data/dst/seaborn_heatmap_big.png')


current_dpi = mpl.rcParams['figure.dpi']
print(current_dpi)
# 72.0

plt.figure()
sns.heatmap(df, square=True)
plt.savefig('data/dst/seaborn_heatmap_big_2.png', dpi=current_dpi * 1.5)

活用例: 多数の特徴量を持つデータの相関係数を可視化

具体的な活用例として、多数の特徴量を持つデータの相関係数を可視化します。 Kaggleの住宅価格を推定する問題のトレーニングデータを使用します。

House Prices: Advanced Regression Techniques | Kaggle

こちらにも置いてあります。

house_prices_train.csv

pandas.DataFrameのメソッドcorr()を使うと、pandas.DataFrameの各列の間の相関係数を算出できます。

df_house = pd.read_csv('data/src/house_prices_train.csv', index_col=0)

df_house_corr = df_house.corr()
print(df_house_corr.shape)
# (37, 37)

print(df_house_corr.head())
#              MSSubClass  LotFrontage   LotArea  OverallQual  OverallCond  \
# MSSubClass     1.000000    -0.386347 -0.139781     0.032628    -0.059316   
# LotFrontage   -0.386347     1.000000  0.426095     0.251646    -0.059213   
# LotArea       -0.139781     0.426095  1.000000     0.105806    -0.005636   
# OverallQual    0.032628     0.251646  0.105806     1.000000    -0.091932   
# OverallCond   -0.059316    -0.059213 -0.005636    -0.091932     1.000000   
#              YearBuilt  YearRemodAdd  MasVnrArea  BsmtFinSF1  BsmtFinSF2  \
# MSSubClass    0.027850      0.040581    0.022936   -0.069836   -0.065649   
# LotFrontage   0.123349      0.088866    0.193458    0.233633    0.049900   
# LotArea       0.014228      0.013788    0.104160    0.214103    0.111170   
# OverallQual   0.572323      0.550684    0.411876    0.239666   -0.059119   
# OverallCond  -0.375983      0.073741   -0.128101   -0.046231    0.040229   
#                ...      WoodDeckSF  OpenPorchSF  EnclosedPorch  3SsnPorch  \
# MSSubClass     ...       -0.012579    -0.006100      -0.012037  -0.043825   
# LotFrontage    ...        0.088521     0.151972       0.010700   0.070029   
# LotArea        ...        0.171698     0.084774      -0.018340   0.020423   
# OverallQual    ...        0.238923     0.308819      -0.113937   0.030371   
# OverallCond    ...       -0.003334    -0.032589       0.070356   0.025504   
#              ScreenPorch  PoolArea   MiscVal    MoSold    YrSold  SalePrice  
# MSSubClass     -0.026030  0.008283 -0.007683 -0.013585 -0.021407  -0.084284  
# LotFrontage     0.041383  0.206167  0.003368  0.011200  0.007450   0.351799  
# LotArea         0.043160  0.077672  0.038068  0.001205 -0.014261   0.263843  
# OverallQual     0.064886  0.065166 -0.031406  0.070815 -0.027347   0.790982  
# OverallCond     0.054811 -0.001985  0.068777 -0.003511  0.043950  -0.077856  
# [5 rows x 37 columns]

pandas.corr()は数値の列のみが対象で、欠損値NaNは除外して算出される。 本来はNaNの補完や文字列のカテゴリー変数の数値化などの必要があり、データをそのまま読み込んで使うのは乱暴ではあるが、各変数の関係性をとりあえずざっくり確認するのに非常に便利。 なお、この例のように変数が多い場合はサイズを大きくしておかないと結果が見にくいので注意。

fig, ax = plt.subplots(figsize=(12, 9))
sns.heatmap(df_house_corr, square=True, vmax=1, vmin=-1, center=0)
plt.savefig('data/dst/seaborn_heatmap_house_price.png')

シェア

関連カテゴリー

Python Matplotlib seaborn pandas 機械学習

Python, pandas, seabornでペアプロット図(散布図行列)を作成 pandasのplotメソッドでグラフを作成しデータを可視化 『Python Data Science Handbook』(英語の無料オンライン版あり) pandas.DataFrameの各列間の相関係数を算出、ヒートマップで可視化 pandasでカテゴリ変数をダミー変数に変換(get_dummies) pandas参考書『Pythonによるデータ分析入門』の注意点 pandas参考書『Python for Data Analysis, 2nd Edition』 pandas, Matplotlib(mpl_finance)でローソク足チャートを作成 『Pythonデータサイエンスハンドブック』は良書(NumPy, pandasほか) seabornでMatplotlibの見た目を良くする pandasでExcelファイル(xlsx, xls)の読み込み(read_excel) pandasの行・列をランダムサンプリング(抽出)するsample pandasで要素、行、列に関数を適用するmap, applymap, apply Pythonのグラフ描画ライブラリMatplotlibの基礎

Last Updated: 6/26/2019, 10:34:03 PM