728x90

이번 포스팅에서는 주로 weight 변화에 따라서 bias에 어떠한 영향을 미치는지를 중점적으로 살펴보도록 한다.

weight, bias가 3,3인 것으로 부터 mu, sigma의 변화에 따른 영향을 살펴본다.

 

*표준편차가 1이하 일 때

(0, std)를 따를 때, std가 0에 가까울수록 뽑히는 x의 값이 0에 가까운 값이 뽑히게 될 것이다. 예를 들어 x의 값이 0.1, 0.5처럼 1보다 작은 값들이 뽑히게 된다는 것이다. 저번 포스팅에서 배웠던 것처럼 x의 값이 1보다 작게 되면 weight가 학습되는 양 자체가 적어진다. 그리고 std =1 일때는 th1,th0이 균등하게 대각선 방향으로 학습이 된다. 위 왼쪽 그림의 보락색이 이에 해당 된다.

 

th1의 loss grdient 양은 -2*x*lr(y-pred)이 된다. 이때 입력되는 x 값이 소수점이라면 gradient의 양 자체가 감소하게 된다. 그래서 th1의 학습속도가 느려지게 된다. 즉, step by step으로 학습이 매번 이뤄질때마다 학습 되는 양이 적다는 얘기다.

 

왼쪽 그림의 보락색은 초반에 띄엄 띄엄 학습되는 걸 볼수 있다. 반면 파란색은 매 학습 step마다 촘촘히 학습되고 있다. 이는 곧 "학습되는 양" 자체가 적어서 업데이트도 적게 이뤄진다는 걸 알수 있다.

오른쪽 th0의 학습 그래프를 보면, 학습 되는 속도가 std 변화에 따라서 크게 영향을 끼치지 않는다. std가 작아지든 커지든 th0의 학습에는 영향을 주지 않는 것인가? th0의 gradient를 보면 -2*lr*(y-pred) , x가 빠져 있다. 즉 x에 영향을 받지 않는 다는 것이다. 왼쪽 그림은 th1이 혼자 느려질때의 th1,th0의 그래프이고, th0만 따로 떼어 보면 오른쪽 처럼 균일한 학습속도를 보이게 되는 것이다.

 

 

*표준편차가 1이상 일때

반대로 std가 1보다 큰 값이 들어오게 된다면, th1의 gradient 값이 th0보다 많아 지기 때문에 왼쪽 그래프 처럼 th1에 지배적인 학습이 먼저 이뤄진다. gradient가 많아지게 되면 발생하는 대표적인 문제점이 하나 있다. 바로 발산이다. 그리고 발산하기 직전에는 지그재그로 왔다갔다 하게 된다. 왼쪽 그림에서 std가 가장 높은 점들을 보면 지그재그로 왔다갔다 학습이 이뤄지고 있다. 지그잭로 학습 되기 때문에 당연히 전반적인 학습 속도가 느리게 된다.

 

오른쪽 그림의 th0은 마찬가지로 gradient에서 x가 영향을 끼치지 않으므로 gradient의 양 변동이 크지 않다. 오직 th1의 gradient가 커졌다 작아졌다 할뿐이다.

가장 이성적인 학습속도는 어떤 것인가?? 바로 th1,th0이 "동시에" 학습이 완료되는 것이다. 마치 std=1일때 처럼 말이다.

 

정리를 하자면(std의 변화에 따른 결론이다)

1. std의 값이 1이하 일 때

-> th1의 gradient 양이 "작아서" 학습 속도가 느려진다.

2. std의 값이 1 이상 일 때

-> th1의 gradient 양이 커져서 "지그재그로 학습되어" 속도가 느려진다.

 

 

*mean 변화량에 따른 학습 영향

mean값이 커질수록 loss function은 수직으로 세워지게 된다.(앞서 학습 내용과 겹치므로 자세한 내용 생략) 따라서 mean이 높을수록 더욱 지그재그로 학습이 많이 이뤄진다. mean 0일 때와 5일 때를 비교해보자. th0의 학습 속도에 차이가 난다.

 

왜일까? mean =5이기 때문에 지그재그로 더 많이 학습이 되기 때문이다. 쉽게 말하자면 빨리 갈 수 있는 길을 내버려두고 돌고 돌아서 가기 때문에 목표지점에 도달하는 시간이 더 길어지는 것이다.

그런데, 오른쪽 그림을 보면 아까와는 다른 그래프 모양을 가진다. 왜일까? "지그재그"때문이다. 지그재그 할 때는 증감을 반복하기 때문이다. th0 입장에서 왼쪽 하늘색을 보면 증감의 반복을 볼수있다. 이런 현상이 왼쪽 그래프 반영된 것이다.

이러한 현상때문에 bias은 mean에 대해서 아주 큰 영향을 받게 된다. 따라서 input값들을 평균 0으로 맞춰 주는게 중요하다. 0으로 맞추면 std가 크든 작든 좌우 대칭으로 dataset이 뽑히기 때문에 위와 같은 악영향을 줄일수 있다.

 

이해를 돕기 위해서 위 학습 내용을 벡터로 표현해본다. 단, 그래프 표현을 위해서 벡터의 크기는 1로 통일한다 (gradient의 학습 양을 벡터로 그대로 표현하면 그래프로 보기가 힘들기 때문이다)

 

 

*벡터로 표현한 학습

오른쪽 벡터의 중심점이 본래 가지고 있던 점이다. 즉, 출발점이다. 벡터가 가리키는 것이 학습의 방향성이다. 왼쪽의 학습 방향을 벡터로 표현한 것이 오른쪽이다.

 

 

 

*std를 높힌 경우

std를 높힌다는 것은 x의 절대값이 커진다는 것이고, losss function contour plot이 y축으로 기울어 진다는 것이다.

오른쪽 그림를 보면 양쪽 좌우에 벡터가 뭉쳐 있다. 반면 상대적으로 가운데 벡터는 적다. 왜그럴까? std가 커진다는건 x의 절대값이 커진다는 것이다. 예를들어 -3,-5,3,4 이러한 x가 반영된다. 1보다 큰 x값들은 y축으로 기울진 즉, 수직에 가까운 loss function을 갖게 된다. 이러한 loss function에 projection되는 방향으로 학습이 되기 때문에 벡터 좌우에 몰리게 되고, 가운데 분포는 적게 된다. 좌우로 쏠린 벡터들이 서로 상쇄가 되면서 학습된다. 즉, 평균이 0이고, std만 달라져도 서로 대칭되는 값들이 뽑히기 때문에 서로 상쇄가 되어서 학습에 큰 영향을 주지 않게 된다. 진짜 문제는 평균이 0이 아닐때 발생하게 되는것이다.

 

참고로 가운데 벡터는 x축과 평행에 가까운 loss function에 projection되는 방향이다. 즉, x의 값이 1보다 작은 값들이 나올 때 이러한 벡터 방향이 나온다.

 

*mean를 높인 경우

mean을 키우게 되면, 전반적인 x의 값들을 키운 것과 같다. mean을 키운 순간 대각선 방향으로의 th1,th0 학습을 하지 못하게 된다. 어느 한쪽으로 쏠리게 되는데, 이 경우는 양수에 치우친 x값들이 생기게 된다. 따라서 왼쪽 하단에 분포가 많이 생기게 된다.

 

또한 (1,1)이므로 0에 가까운 loss function도 많이 만들어질 것이다. 그래서 오른쪽 상단 벡터에도 분포가 많이 생기게 된다. 즉, 왼쪽 오른쪽 분포가 많으므로 지그재그 형태를 띠게 된다.

mean더 키우게 된다는 것은 y축과 더욱 평행한 loss function이 만들어진다는 것이며, 사실상 음수 loss function이 만들어지지 않는다고도 볼 수 있다.

지금까지는 데이터 하나를 가지고 loss function의 변화를 봤다면, 지금부터는 여러 개의 데이터 샘플을 사용한 cost의 변화를 보도록 하자.

가장 이상적인 모습을 보인 건 전체적인 데이터 특성을 제일 잘 나타낸 배치 사이즈 32이다. 배치 사이즈가 커지면 아웃라이어의 영향을 줄일 수 있다. 그래서 배치 2를 보면 다른 것에 비해 불규칙적으로 움직인 걸 볼 수 있다.

 

 

*std 변화에 따른 영향

std =5 일 때의 그래프이다. 배치가 2인 파란색이 확실히 변동이 있는 반면, 배치가 32인 빨간색은 안정적으로 학습됨을 볼 수 있다.

 

 

*mean 변화에 따른 영향

mean 자체가 커지면서 지그재그 파동이 많아짐을 볼 수 있다. cost라도 mean이 커지면 지그재그 현상이 없어지지 않는다.

 

 

* std 변화에 따라서 벡터의 모양을 비교해 보록 하자

배치가 2일 땐 데이터 샘플 하나를 이용했을 때와 별 차이점이 없다. 그래서 좌우로 지그재그로 학습이 되기 때문에 벡터에도 좌우에 많은 분포를 가진다.

배치가 커질수록 벡터의 얖옆이 줄어들면서 가운데에 분포를 하게 된다. 다시 말해서 좌우로 가는 벡터가 적어짐에 따라 지그재그 현상이 감소하게 된다.

 

 

* mean변화에 따라서 벡터의 모양을 비교해 보록 하자

mean의 변화에 따른 벡터의 모양을 보면 비슷함을 알 수 있다. 특히 좌우 벡터의 양이 std 변화에 따른 벡터처럼 줄어들지가 않는다. 왜일까?

 

mean 자체를 증가시켰기 때문에 std=1이든 2이든 어찌 됐든 x가 큰 값이 input으로 들어가게 된다. mean이 커지면 더 이상 표준 정규분포가 아니게 된다. 위 그래프의 학습 cost function 모양이 대략  \일 것이다.

 

아무리 32개를 뽑고 cost를 계산해봐도 \과 비슷한 cost function을 만들게 된다. 또한 x의 값이 커지면 발산의 우려가 있다.  그 전초기 증상이 바로 지그재그이다. 이를 막기 위해서 lr의 조정이 필요한 것이다.

728x90

+ Recent posts