【Python】線形回帰分析の実践方法
- 2021.06.08
- Python
Pythonで様々なライブラリを用いて線形回帰をしてみたいと思います。
(本来はトレーニングデータとテストデータを分けて行うべきですが、簡単にするために下記で示す例では、モデル作成と予測に同じデータを使っています。)
Numpyを使った線形回帰
使用データはUSDJPYの為替レートを被説明変数、米国10年債利回りを説明変数として使うことにする。
まずは、必要なライブラリのインポートとpandas-datareaderを使ってFREDからのデータ取得を行う。
import numpy as np
import pandas_datareader.data as web
df = web.DataReader(['DEXJPUS','DGS10'],'fred',start='2015-01-01').dropna()
df.columns = ['JPY=','US10Y'] # わかり安いように列名を変更
必要データが用意できたところで、numpyを使って線形回帰を行う。
np.polyfit(x, y, deg=n次数)のように指定してモデルを作成。
その後は、polyval()またはpoly1d()を使ってモデルにデータをあてはめる。
model = np.polyfit(df['US10Y'],df['JPY='],deg=1)
result = np.polyval(model,df['US10Y'])
''' poly1dを使う場合
result = np.poly1d(model)(df['US10Y'])
'''
結果の可視化もしてみる。ここではcufflinksを利用してグラフを作成する。
import cufflinks
fig = df.iplot(kind='scatter',x='US10Y' ,y='JPY=',
asFigure=True, theme='solar', mode='markers', size=4.0)
# kind = 'グラフ種類を指定'
# theme = '使いたいテーマがあればそれを指定'
# mode = マーカー(点)を指定
# size = マーカーサイズの指定
fig.add_scatter(x=df['US10Y', y=result, name='result') # 回帰結果のプロット
fig.show()
結果はこのようになる。(↑はスクリーンショットですが、cufflinksで作っているため、実際には対話的にカーソルをあわせてりすると値を表示してくれます)
Scipyを使った線形回帰
次にscipyを使って線形回帰を行う。
numpyを使う場合とは結果の内容が異なり、yの予測値を返すのではなく、切片や係数を返す。
from scipy.stats import linregress
model_scipy = linregress(df['US10Y'],df['JPY=']) # モデルの作成
'''結果
LinregressResult(slope=3.443315677621587, intercept=104.19556887847025, rvalue=0.418666652240735,
pvalue=6.384388783565978e-69, stderr=0.18684130034119245, intercept_stderr=0.3938246096463683)
'''
上記の通り、結果の構成内容がnumpyの場合とは異なるため、可視化するためには、予測値を切片と係数、説明変数から計算してやる必要がある。
fig = df.iplot(kind='scatter',x='US10Y' ,y='JPY=', asFigure=True,
theme='solar', mode='markers', size=4.0)
fig.add_scatter(x=df['US10Y'],y=model.intercept+model2.slope*df['US10Y'],name='result')
# y値の計算と結果のプロット
fig.show()
↓
Scikit-learnを使った線形回帰
次にscikit-learnを使った線形回帰の実行と可視化方法です。
from sklearn.linear_model import LinearRegression
model = LinearRegression().fit(df[['US10Y']],df['JPY='])
result = model.predict(df[['US10Y']])
fig = df.iplot(kind='scatter',x='US10Y',y='JPY=',
asFigure=True,theme='solar',mode='markers',size=4.0)
fig.add_scatter(x=df['US10Y'],y=result,name='result')
fig.show()
↓
-
前の記事
ISM PMI、NMIのサブインデックスについて 2021.06.08
-
次の記事
東京都在住の独身20代後半の家計状況(2021/6) 2021.06.12