본문 바로가기

머신러닝, 딥러닝

단순선형회귀 : 설명이 안되는 범위가 너무 많은데요?

처음 단순 선형 회귀(Simple Linear Regression)를 배우면서, 집 면적에 따른 집값을 예측하는 선을 

다음과 같은 코드로 그려보며 많이 부족하다고 느꼈다.

from sklearn.linear_model import LinearRegression # 1. choose model class
model  = LinearRegression() # 2. instantiate model
feature = ['feature_A']
target = ['feature_A']

#타겟 벡터 만들기
X_train = df[feature_B]
y_train = df[feature_B]
model.fit(X_train, y_train) # 3. fit the model to your data

# 전체 데이터를 모델을 통해 예측하기
X_test = [[x] for x in select_df['']]
y_pred = model.predict(X_test)

 

Train, Test 데이터를 구할 때 

X_train = df[[feature_B]] 와 같은 열벡터 형식을 왜 사용하는지 궁금하다면? 

https://aimb.tistory.com/135

 

선형회귀에서 2차원 array를 사용하는 이유

from sklearn.linear_model import LinearRegression # 1. choose model class model = LinearRegression() # 2. instantiate model feature = ['feature_A'] target = ['feature_A'] #타겟 벡터 만들기 X_train =..

aimb.tistory.com


집값을 예측하는 선의 기울기가 생각보다 낮게 나왔다. 

저 예측 선은 (예측값-관측값)을 제곱해주어 그 합을 최소화하는 선이지만, 

이 직선 하나만으론 설명이 안되는 범위가 너무 많게 보였다.(예측 모델 그래프 기준 위쪽 데이터들)

변수를 더 주어 예측 확률을 업데이트하거나 하는 식으로 커버 범위나 정확성을 더 높일 필요가 있어 보였고,

왜 굳이 직선 그래프로 해야 하는지 납득하기 어려웠다.

 

왜 기울기가 낮나 했더니, 기울기의 경사가 1보다 낮다는 건 수학적으로 "평균으로 회귀한다" 라는 의미라고 한다.

그래서 예측선이 회귀선이라 불린다. (최소제곱법과 회귀법도 어느 정도 동의어라고 한다.)

출처 

 

나중에 모델을 택한다고 해도 단순선형회귀는 돌릴 일 없지 않을까? 했는데

https://blog.insightdatascience.com/always-start-with-a-stupid-model-no-exceptions-3a22314b9aaa

 

Always start with a stupid model, no exceptions.

How to efficiently build Machine Learning powered products.

blog.insightdatascience.com

다음의 글을 읽어보고 왜 단순선형회귀를 제일 먼저 가르치는지 납득이 되었다.

예측선을 만드는 단순선형회귀 모델은

  • 복잡한 모델 대비 10%의 시간을 사용해 결과의 최대 90%까지 예측할 수 있다.
  • 모델을 빠르게 교육해서 성능에 대한 피드백을 빠르게 제공할 수 있다.
  • 최소한의 기능을 구현하는데 테스트하기 적합하다.
  • 이렇게 뽑아낸 예측 기준선이 어디에서 문제가 되는지 살펴보면, 다음엔 어떤 모델을 써야 할지, 복잡한 모델을 사용할 때도 어떤 부분에 신경 써야 할지 알 수 있다. (데이터를 더 잘 이해할 수 있다. 어떤 방향으로 모델을 다듬어야 더 데이터를 잘 다룰 수 있을지 인사이트를 얻을 수 있다.)

 

Plotly로 단순선형회귀 예측선 구현하기

import plotly.express as px
px.scatter(select_df, x = '', y = '', opacity=0.65, 
           trendline='ols', trendline_color_override='red')

trendline을 추가하면 된다. 색깔도 바꿀 수 있다.