ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Week #3 확률적 경사 하강법
    쿠다 4기/<혼공> 머신러닝& 딥러닝 2023. 8. 15. 11:42

    생선마켓에 수산물을 공급하겠다는 곳이 너무 많아 훈련 데이터의 샘플을 어떻게 골라낼지에 대한 고민에  빠졌다.

    1. 기존 훈련 데이터에 새로운 데이터를 추가해 매일 다시 훈련할까?
    2. 새로운 데이터를 추가할 때 이전 데이터를 버려 훈련 데이터 크기를 유지할까?
    3. 훈련한 모델을 버리지 않고 새로운 데이터에 대해서만 조금씩 더 훈련할 수는 없을까?


    1,2 번의 방법은 여러 문제점들이 있고, 3번이 적합한 방법인 것 같다. 3번과 같은 훈련방식을 점진적 학습 또는 온라인 학습이라고 부른다. 대표적으로 확률적 경사 하강법이 있다.

     

    확률적 경사 하강법

    먼저 경사하강법에 대해 알아보자.

    그림1

    경사하강법은 가장 가파른 경사를 따라 원하는 지점에 도달하는 것이 목표이다. 단, 조금씩 내려와야 한다. 그 방법은 훈련세트에서 하나의 샘플을 랜덤하게 골라 가장 가파른 길을 찾는 것이다. 이렇게 훈련 세트에서 랜덤하게 하나의 샘플을 고르는 것이 확률적 경사하강법이다. 하나의 샘플을 골라 가파른 경사를 조금 내려가고, 또 다른 샘플을 랜덤하게 골라 조금 내려가는 것을 전체 샘플을 모두 사용할 때 까지 반복한다. 모든 샘플을 다 사용했지만 다 내려오지 못한 경우 다시 모든 샘플을 이용해 처음부터 시작한다. 이렇게 확률적 경사하강법에서 훈련 세트를 한 번 모두 사용하는 과정을 에포크라고 한다.

     

    미니배치 경사 하강법: 여러 개의 샘플을 이용해 경사 하강법을 수행하는 방식

    배치 경사하강법: 한 번 경사로를 따라 이동하기 위해 전체 샘플 사용하는 방식

     

    이때 배치 경사하강법의 경우 가장 안정적이지만 그만큼 컴퓨터 자원을 많이 사용한다는 단점이 있다. 위 내용을 그림으로 정리하면 아래와 같다.

    그림2

    그런데 그림1에 나타난 그래프, f(x)가 의미하는 것은 뭘까?

    이 f(x)는 손실 함수(비용 함수)라고 부른다. 손실함수는 머신러닝 알고리즘이 얼마나 틀렸는가를 나타내므로 값이 작을 수록 좋다. 따라서 값이 작은 쪽으로 내려가도록 학습하는 것이다. 

     

    손실 함수의 종류

    1. 로지스틱 손실 함수(이진 크로스엔트로피 손실함수):  양성 클래스(타깃 = 1)일 때 -log(예측확률), 음성 클래스(타깃 = 0)일 때 손실은 -log(1-예측 확률)로 계산한다.
    2. 크로스엔트로피 손실함수: 다중 분류에서 이용한다.

     

    확률적 경사 하강법을 사용한 분류모델 만들기

    저번 포스팅에서 한 것과 같은 방법으로 데이터를 준비한다.

    확률적 경사 하강법을 제공하는 분류용 클래스인 SGDClassifier을 이용해 10회 반복해 훈련한 후 테스트 세트와 훈련 세트에서 정확도 점수를 출력한다. 정확도가 매우 낮은 것을 보아, 반복 횟수가 부족한 것 같다.

    1에포크씩 이어서 훈련할 수 있는 partial_fit()메소드를 호출하고 점수를 확인했더니 조금 향상되었다. 

    무작정 반복 횟수를 늘리면 되는 것일까?
    그 기준이 있을까?

     

    에포크 횟수가 적으면 훈련 세트를 덜 학습하기 때문에 과소적합될 것이다. 반대로 너무 많은 에포크 횟수 동안 훈련한 모델은 과대적합이 될 수 있다.

     

    그림3

    그림3은 에포크가 진행됨에 따른 모델의 정확도를 나타낸 것이다. 훈련 세트는 꾸준히 증가하지만 테스트 세트는 증가하다가 어느 시점에서 부터 감소한다. 그 지점이 과대적합되기 시작하는 곳이다. 따라서 과대적합이 시작하기 전에 훈련을 멈추는 것을 조기 종료라고 한다.

    이 그래프를 직접 시각화해보겠다.

    잘 드러나지는 않지만, 에포크 수 100번이 넘어가면 훈련세트와 테스트 세트의 점수가 벌어짐을 확인할 수 있다. 그 전은 전체적으로 점수가 낮은 과소적합을 보인다.

    괜찮은 점수가 나왔다.

Designed by Tistory.