pandas.DataFrameをGroupByでグルーピングし統計量を算出

pandas.DataFrame, pandas.Seriesのgroupby()メソッドでデータをグルーピング(グループ分け)できます。グループごとにデータを集約して、それぞれの平均、最小値、最大値、合計などの統計量を算出したり、任意の関数で処理したりすることが可能。 マルチインデックスを設定することでも同様の処理ができます。以下の記事を参照。

また、pandas.pivot_table(), pandas.crosstab()という関数を用いてカテゴリごとの統計量やサンプル数を算出することもできます。この方法が一番シンプルかも知れない。

ここでは以下の内容について説明します。

irisデータセット groupby()でグルーピング 平均、最小値、最大値、合計などを算出 任意の処理を適用して集約: agg() 主要な統計量を一括算出: describe() グラフをプロット

irisデータセット

例としてirisデータセットを使用します。 irisデータセットについては以下の記事を参照。

ここではseabornにサンプルとして含まれているデータを使います。

import pandas as pd
import seaborn as sns
import numpy as np

df = sns.load_dataset("iris")
print(df.shape)
# (150, 5)

print(df.head(5))
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 3           4.6          3.1           1.5          0.2  setosa
# 4           5.0          3.6           1.4          0.2  setosa

スペースを削減するために省略した列名に変更しておく。

df.columns = ['sl', 'sw', 'pl', 'pw', 'species']
print(df.head(5))
#     sl   sw   pl   pw species
# 0  5.1  3.5  1.4  0.2  setosa
# 1  4.9  3.0  1.4  0.2  setosa
# 2  4.7  3.2  1.3  0.2  setosa
# 3  4.6  3.1  1.5  0.2  setosa
# 4  5.0  3.6  1.4  0.2  setosa

groupby()でグルーピング

pandas.DataFrameのgroupby()メソッドでグルーピング(グループ分け)します。

pandas.DataFrame.groupby — pandas 0.23.1 documentation

引数に列名を指定するとその列の値ごとにグルーピングされる。 返されるのはGroupByオブジェクトでそれ自体をprint()で出力しても中身は表示されない。

grouped = df.groupby('species')
print(grouped)
# <pandas.core.groupby.groupby.DataFrameGroupBy object at 0x10c69f6a0>

print(type(grouped))
# <class 'pandas.core.groupby.groupby.DataFrameGroupBy'>

pandas.Seriesにも同様にgroupby()メソッドが用意されています。

pandas.Series.groupby — pandas 0.23.1 documentation

GroupByオブジェクトのメソッド一覧は以下の公式ドキュメント参照。

API Reference — pandas 0.23.1 documentation

size()メソッドでそれぞれのグループごとのサンプル数が確認できます。

print(grouped.size())
# species
# setosa        50
# versicolor    50
# virginica     50
# dtype: int64

平均、最小値、最大値、合計などを算出

GroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用すると、グループごとの平均、最小値、最大値、合計などの統計量を算出できます。

print(grouped.mean())
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

print(grouped.min())
#              sl   sw   pl   pw
# species                       
# setosa      4.3  2.3  1.0  0.1
# versicolor  4.9  2.0  3.0  1.0
# virginica   4.9  2.2  4.5  1.4

print(grouped.max())
#              sl   sw   pl   pw
# species                       
# setosa      5.8  4.4  1.9  0.6
# versicolor  7.0  3.4  5.1  1.8
# virginica   7.9  3.8  6.9  2.5

print(grouped.sum())
#                sl     sw     pl     pw
# species                               
# setosa      250.3  171.4   73.1   12.3
# versicolor  296.8  138.5  213.0   66.3
# virginica   329.4  148.7  277.6  101.3

そのほか標準偏差std()、分散var()などもあります。 いずれのメソッドも新たなpandas.DataFrameを返す。

print(type(grouped.mean()))
# <class 'pandas.core.frame.DataFrame'>

任意の処理を適用して集約: agg()

GroupByオブジェクトのagg()メソッドで任意の処理を適用することができます。

pandas.core.groupby.DataFrameGroupBy.agg — pandas 0.23.1 documentation

引数に適用したい関数を指定します。関数などの呼び出し可能オブジェクト(callable)または関数名の文字列で指定可能。

print(grouped.agg(min))
#              sl   sw   pl   pw
# species                       
# setosa      4.3  2.3  1.0  0.1
# versicolor  4.9  2.0  3.0  1.0
# virginica   4.9  2.2  4.5  1.4

print(grouped.agg('max'))
#              sl   sw   pl   pw
# species                       
# setosa      5.8  4.4  1.9  0.6
# versicolor  7.0  3.4  5.1  1.8
# virginica   7.9  3.8  6.9  2.5

なお、組み込み関数にないmean()などはmeanと指定するとエラーになる。NumPyの関数np.meanか文字列'mean'として指定します。

# print(grouped.agg(mean))
# NameError: name 'mean' is not defined

print(grouped.agg(np.mean))
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

print(grouped.agg('mean'))
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

NumPyの関数np.meanはpd.np.meanとして指定することも可能。

リストで指定すると複数の処理を同時に適用できます。この場合は結果のpandas.DataFrameのcolumnsがマルチインデックスになる。

print(grouped.agg([min, 'max']))
#              sl        sw        pl        pw     
#             min  max  min  max  min  max  min  max
# species                                           
# setosa      4.3  5.8  2.3  4.4  1.0  1.9  0.1  0.6
# versicolor  4.9  7.0  2.0  3.4  3.0  5.1  1.0  1.8
# virginica   4.9  7.9  2.2  3.8  4.5  6.9  1.4  2.5

列名をキーとした辞書(dict型オブジェクト)で列ごとに異なる処理を適用することも可能。

print(grouped.agg({'sl': min, 'sw': max, 'pl': np.mean, 'pw': 'mean'}))
#              sl   sw     pl     pw
# species                           
# setosa      4.3  4.4  1.462  0.246
# versicolor  4.9  3.4  4.260  1.326
# virginica   4.9  3.8  5.552  2.026

無名関数(ラムダ式)でも問題ありません。

print(grouped.agg(lambda x: max(x) - min(x)))
#              sl   sw   pl   pw
# species                       
# setosa      1.5  2.1  0.9  0.5
# versicolor  2.1  1.4  2.1  0.8
# virginica   3.0  1.6  2.4  1.1

ラムダ式に対しては各グループの値がpandas.Seriesとして渡される。

print(grouped.agg(lambda x: type(x))['sl'])
# species
# setosa        <class 'pandas.core.series.Series'>
# versicolor    <class 'pandas.core.series.Series'>
# virginica     <class 'pandas.core.series.Series'>
# Name: sl, dtype: object

pandas.Seriesを受け取って一つのオブジェクトを返すラムダ式でないとエラーになるので注意。

# print(grouped.agg(lambda x: x + 1))
# Exception: Must produce aggregated value

文字列の要素に対して処理した例は以下の記事の最後を参照。

複数の統計量を一括算出: describe()

describe()メソッドを使うとグループごとの主要な統計量を一括で算出できます。

pandas.core.groupby.DataFrameGroupBy.describe — pandas 0.23.1 documentation

以下の例ではsl列に対する結果のみ出力しています。

print(grouped.describe()['sl'])
#             count   mean       std  min    25%  50%  75%  max
# species                                                      
# setosa       50.0  5.006  0.352490  4.3  4.800  5.0  5.2  5.8
# versicolor   50.0  5.936  0.516171  4.9  5.600  5.9  6.3  7.0
# virginica    50.0  6.588  0.635880  4.9  6.225  6.5  6.9  7.9

各項目の意味などは以下の記事を参照。

グループごとの統計量のグラフをプロット

上述のようにGroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用すると返ってくるのはpandas.DataFrameなので、そのままplot()メソッドを使ってグラフを描画して可視化できます。

print(type(grouped.max()))
# <class 'pandas.core.frame.DataFrame'>

ax = grouped.max().plot.bar(rot=0)
fig = ax.get_figure()
fig.savefig('data/dst/iris_pandas_groupby_max.jpg')

plot()についての詳細は以下の記事を参照。

シェア

関連カテゴリー

Python pandas

pandasで特定の文字列を含む行を抽出(完全一致、部分一致) pandas.DataFrameから特定の型dtypeの列を抽出(選択) pandas.DataFrame, Seriesの行をランダムソート(シャッフル) pandasで複数条件のand, or, notから行を抽出(選択) 『Pythonデータサイエンスハンドブック』は良書(NumPy, pandasほか) pandas.DataFrame, Seriesを連結するconcat pandas.DataFrame, Seriesをソートするsort_values, sort_index pandas.DataFrameの複数の列の文字列を結合して新たな列を生成 pandasの文字列を区切り文字や正規表現で複数の列に分割 pandas.Seriesのmapメソッドで列の要素を置換 pandas.DataFrameの行名・列名の変更 pandas.DataFrameに列や行を追加(assign, appendなど) pandasの文字列から正規表現で抽出して新たな列を生成 pandasでExcelファイル(xlsx, xls)の読み込み(read_excel) pandasで窓関数を適用するrollingを使って移動平均などを算出

Last Updated: 6/26/2019, 10:34:03 PM