본문 바로가기

Study doc./Deep Learning

[cs231n] 오차역전파(propagation)를 이용한 gradient 계산

지난 포스팅에서 최적화의 방법으로 두 가지 개념이 나왔습니다.

바로 numerical gradient와 analytic gradient 인데요, numerical gradient 방법은 하나하나 다 계산하기 때문에 정확하지만 시간이 너무 오래걸린다는 단점이 있다고 했습니다.

 

그래서 보통 w를 구할때 analytic gradient 방법을 사용한다고 공부했었는데,

오늘은 analytic gradient 방법의 원리에 대해서 자세히 공부해보겠습니다.

 

지난 포스팅에서 그냥 넘어갔던 중간 과정

 

 


 

순서

  1. computational graph
  2. 스칼라의 backpropagation
    1. backpropagation의 이해
    2. local gradient/ global gradient(upstream gradient)
  3. 벡터의 backpropagation

 


 

 

1. computational graph

gradient의 계산 순서를 명확하게 파악하기 위해서는 computational graph에 대한 이해가 선행되어야 합니다.

computational graph는 하나의 식으로부터 숫자와 연산자를 모두 분리한 도표인데요, 

지난 포스팅에서 정의했던 오차함수를 computational graph로 표현하면 다음과 같습니다.

 

 

먼저 input인 x와 가중치 행렬인 w를 곱해서 점수(score)을 냅니다.

그 점수를 기반으로 SVM 오차함수(hinge loss)를 이용하여 loss값(Li)을 도출했고, 여기에 규제식(regularization)을 더하여 최종 오차함수식을 정의했습니다.

이렇게 그려진 computational graph를 이용하여 차근차근 gradient 값을 구해나갈 것입니다.

어떻게요? 오차역전파(backpropagation) 방법을 통해서요!!

 

 

 

2. 스칼라의 backpropagation

2-1. backpropatation의 이해

일단 이해를 돕기 위해 쉬운 예시로 시작하겠습니다.

다음과 같은 함수를 정의하고 compuatational graph로 그렸습니다.

그리고 x, y, z 값들을 다음과 같은 값들로 가정했습니다.

 

우리가 여기서 구하려고 하는게 뭐죠?

dW값, 즉 gradient값, 즉 미분값, 즉 각 변수들이 함수 f 에 미치는 영향 정도로 생각할 수 있습니다.

따라서 df/dx(변수 x가 함수 f에 미치는 영향), df/dy, df/dz 를 구해야 합니다. 

이게 곧 dW 값들이 되고, 여기에 step size를 곱해 기존의 W에서 빼주면 새로운 W 후보가 생성되는 것이죠.

 

아무튼 다시 본론으로 돌아와서, 그럼 df/dx를 어떻게 구할수 있을까요?

여기서 드디어 backpropagation을 사용하게 됩니다.

 

 

가장 마지막 gradient 값은 본인스스로를 미분하기 때문에 항상 1 입니다.

그럼 변수 z가 함수 f에 미치는 영향(df/dz)는 얼마인가요?

f=qz 를 z에 대해서 미분하게 되면 df/dz 는 q 값(3)이라는 것을 알게 될 것입니다.

 

그렇다면 변수 y가 함수 f에 미치는 영향(df/dy)는 얼마일까요?

이건 위의 식에서 찾아볼 수 없습니다. 하지만 q와 f가 연결되어 있다는 점을 이용하면 됩니다.

이렇게 말이죠!!

 

이런 방법을 chain rule 이라고 합니다. 서로 연결된 특성을 이용해서 값을 도출하는 것입니다.

같은 방법으로 모든 gradient를 구해보면 아래 빨간색 숫자와 같은 값이 나옵니다.

 

 

이렇게 gradient를 계산할때 뒤에서부터 앞으로 전파(propagate)되었기 때문에 backpropagation 이라고 합니다.

 

 

 

2-2. local gradient/ global gradient(upstream gradient)

backpropagation을 하기 전에, 기존의 수식을 통해 구할 수 있는 gradient를 local gradient라고 합니다.

그리고 뒤에서부터 앞으로 넘어오는 gradient를 global gradient 혹은 upstream gradient라고 합니다.

이 local gradient와 global gradient를 곱하는 방식으로 새로운 gradient 값을 찾게 되는데, 이것은 chain rule에서 명칭만 지정해준 것입니다.

 

예를 들어 df/dq를 구하기 위해서는 df/df 와 df/dq를 곱해야 했고,

df/dq가 local gradient, df/df 가 global gradient 였습니다.

 

두 번째로 df/dx를 구하기 위해서는 df/dq와 dq/dx를 곱해야 하고,

dq/dx가 local gradient, df/dq 가 global gradient 입니다.

 

여기서 알 수 있듯, 뒤에서 앞으로 나아갈때 local gradient가 global gradient로 변합니다.

그리고 다시 local gradient를 만나 곱하게 되고,, 계속 반복합니다.

이런 방식으로 dW 행렬을 numerical gradient 방식보다 빠르게 찾을 수 있었습니다.

 

 

+참고)

다음과 같이 앞에서 온 gradient가 여러개라면 어떻게 할까요?? 그냥 더해주면 됩니다!!

 

 

3. 벡터의 backpropagation

지금까지는 이해를 위해 스칼라 형태의 input으로 가정했습니다.

하지만 실전에서는 대부분 벡터 형태(다변수)를 input으로 받게 되기 때문에 공부한 내용을 좀 더 심화시켜 보겠습니다.

 

위 식은 W와 x를 곱한 후, 모든 원소를 제곱해서 더하는 함수입니다.

W 와 x 행렬값은 가정했고, computational graph가 그려진 모습입니다.

여기서 우리가 알아야 할 정보는 W 행렬과 x 행렬의 각 원소별 gradient, 즉 각 원소가 함수 f 에 미친 영향입니다.

 

첫 gradient는 자기 자신을 미분한 값이기 때문에 항상 1 이라고 했습니다.

다음으로 행렬 q의 각 원소가 함수 f 에 미친 영향을 계산해봅시다. 

 

qi^2을 미분하면 2qi가 되기 때문에 df/dq 는 2q라는 사실을 알게되었고(local gradient), global gradient는 1이기 때문에 이 둘을 곱한 값인 2q, 즉 0.44와 0.52가 함수 q의 gradient가 됩니다.

이제 W 행렬의 각 원소들이 함수 f에 미치는 영향을 파악해야 합니다.

 

스칼라라고 가정했을때 W의 local gradient는 x 값입니다.

즉 W의 gradient는 x 값을 이용해서 계산됩니다.

 

위 그림의 오른쪽 아래를 보시면 W의 i행 j열 원소가 f에 미치는 영향을 계산한 식이 있습니다.

이 중 k=i 일때 1이고, k != i 이면 0인 조건이 있는데, 이는 각 행의 계산을 독립적으로 해주기 위함입니다.

q의 첫 행은 W의 첫 행과만 연관이 있기 때문입니다.

 

식을 벡터화 한다면 x를 전치시킨 후 global gradient 값인 2q 행렬을 곱해주는게 됩니다.

여기서 중요한 점은 변수 행렬과 gradient행렬의 shape는 항상 동일해야 한다는 점입니다(x를 전치시킨 이유, 행렬 연산의 특징).

같은 방법으로 x 행렬의 원소가 함수 f에 미치는 영향을 계산해보면 다음과 같습니다.