[정리] Optimizing Neural Networks with Kronecker-factored Approximate Curvature (JMLR, 2015)

Author: James Martens∗ and Roger Grosse Paper Link: https://arxiv.org/abs/1503.05671

요약

0 Abstract

  • 뉴랄넷의 자연 경사 하강을 위해 Kronecker product 로 Fisher 행렬을 근사하는 K-FAC 을 제안하며, 이는 대각 또는 low rank 행렬 근사와는 다름
  • Fisher 행렬의 대각 블락을 근사하며, 더 작은 두 개의 행렬들의 Kronecker product 로 각 블락을 표현하게 됨
  • 예닐곱 배의 계산량이 추가되지만, 확률적 경사 하강 방법보다 최적화에 성능이 높아 수렴이 더 빠름
  • Hessian-free 같은 full matrix 근사 자연 경사하강 또는 Newton 방법에 비해 확률적 방법에 적합한데, 근사 행렬 및 역행렬 계산이 입력 데이터 수에 무관하고 저장하는 행렬이 작기 때문

1 Introduction

  • Hessian-free (HF) 와 같은 local curvature 를 고려한 방법들은 빠르게 학습이 진행되지만 (1) conjugate gradient (CG) 과정에 의한 계산량이 증가하고 (2) 적은 수의 입력만이 허용될 수 있음
  • HF 에서 사용되는 CG 로 인해 때로 확률적 경사 하강법 (SGD) 대비 이점이 적어지므로, CG 를 사용하지 않은 2차 최적화 방법이 필요함
  • Parameter 들은 각 레이어에 해당하는 그룹으로 나뉘고, Fisher 는 Kronecker 블락들로 근사되는데 이 과정은 gradient 에 대한 특정 가정 과 같음
  • 블락 대각 행렬 또는 띠행렬이 역행렬이 가짐을 가정하며, 분산 행렬의 역, 트리 구조의 그래프 모델, 그리고 선형 회귀와의 관계를 통해 근사를 검증할 것

2 Background and notation

2.1 Neural Networks

  • $i$-th 선형 모듈 레이어 및 활성 함수: $s_i=W\bar{a}_{i-1}, a_i=\phi{}_i(s_i)$
  • 벡터화된 모든 Parameters: $\theta{}=[\mathsf{vec}(W_1)^\text{T}\, \mathsf{vec}(W_2)^\text{T}\, …\, \mathsf{vec}(W_l)^\text{T}]$
  • 뉴랄넷 출럭 $z$ 와 타겟 $y$ 에 대한 손실함수: $L(y, z)=-\log{r(y z)}$ (assumed)
  • Partial gradient operator: $\mathcal{D}v=-\frac{\text{d}\log{p(y x,\theta)}}{\text{d}v}$
  • 레이어 출력 $s_i$ 의 gradient: $g_i=\mathcal{D}s_i$

$l$ linear layers 의 forward/backward path

  • Forward path: $s_i=W_i\bar{a}_{i-1},a_i=\phi{}(s_i)$
  • Loss derivative: $\mathcal{D}a_l=\frac{\partial{L(y,z)}}{\partial{z}} _{z=a_l}$
  • Backward path: \(\begin{aligned} \mathcal{D}a_i&=\mathcal{D}s_i\odot{}\phi{}'(s_i) \\ \mathcal{D}W_i&=g_i\bar{a}_{i-1}^\text{T} \\ \mathcal{D}a_{i-1}&=W_i^\text{T}g_i \end{aligned}\)

2.2 Natural gradient

  • Fisher 행렬은 아래와 같이 정의되며, 데이터 분포 $Q_x$ 와 학습된 모델의 분포 $p(y|x,\theta)$ 에 대한 기댓값이나, 학습 데이터에 대한 분포 $\hat{Q}_x$ 를 이용하여 계산함 \(F=\mathbf{E}[\frac{\text{d}\log{p(y|x,\theta)}}{\text{d}\theta}\frac{\text{d}\log{p(y|x,\theta)}}{\text{d}\theta}^\text{T}]\)
  • 자연 경사 하강은 정해진 KL-divergence 변화 기준 목표 함수를 최대화하는 gradient 를 말하며, 기본 경사 하강의 경우 Euclidean norm 변하를 기준으로 함
  • Fisher 는 $p(y x,\theta)$ 이 exponential family 일 때 Hessian 행렬의 positive semi-definite (PSD) 근사인 Gauss-Newton 행렬 (GGN) 과 같음

3 A Block-wise Kronecker-factored Fisher Approximation

Figure 2 Figure 2. MNIST 숫자 인식 문제에서 중간 4 개 레이어의 완전한 Fisher $F$, 블락 근사 $\tilde{F}$, 그리고 $F$ 와 $\tilde{F}$ 의 차이

  • $l$ 레이어들을 가진 뉴랄넷의 Fisher 는 $l$-by-$l$ 블락 행렬로 구성되며, $(i,j)$-th 블락 $F_{i,j}$ 은 아래와 같으며, 선형 모듈의 경우 입력 $\bar{a}$ 과 레이어의 gradient $g$ 로 표현 가능함 \(\begin{aligned} F_{i,j}&=\mathbf{E}[\mathsf{vec}(\mathcal{D}W_i)\mathsf{vec}(\mathcal{D}W_j)^\text{T}] \\ &=\mathbf{E}[\mathsf{vec}(g_i\bar{a}_{i-1}^\text{T})\mathsf{vec}(g_j\bar{a}_{j-1}^\text{T})^\text{T}] \\ &=\mathbf{E}[(\bar{a}_{i-1}\otimes{}g_i)(\bar{a}_{j-1}^\text{T}\otimes{}g_j^{\text{T}})] \\ &=\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}\otimes{}g_ig_j^\text{T}] \end{aligned}\)
  • 첫 번째로 각 $\bar{a}{i-1}\bar{a}{j-1}^\text{T}, g_ig_j^\text{T}$ 에 대한 Kronecker-product $\tilde{F}$ 로 $F$ 를 근사함 (Khatri-Rao 곱이 됨) \(\begin{aligned} F_{i,j}&=\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}\otimes{}g_ig_j^\text{T}] \\ &\approx{}\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}]\otimes \mathbf{E}[g_ig_j^\text{T}] \\ &=\bar{A}_{i,j}\otimes{}G_{i,j} \\ &=\tilde{F}_{i,j} \end{aligned}\)
  • $\tilde{F}$ 는 주요 근사로 현실적인 가정이나 극한 상황 아래 Fisher 로 수렴하기 어려우나, 사용 시에는 거친 구조 (coarse structure) 를 반영함 (Fig. 2)

3.1 Interpretations of this Approximation

  • 위 근사는 $\bar{a}^{(1)}\bar{a}^{(2)}$ 과 $g^{(1)}g^{(2)}$ 사이 통계적 독립을 가정하는 것임 \(\mathcal{D}[W_i]_{k_1,k_2}=\bar{a}^{(1)}g^{(1)}, \\ \mathcal{D}[W_i]_{k_3,k_4}=\bar{a}^{(2)}g^{(2)}, \\ \bar{a}^{(1)}=[\bar{a}_{i-1}]_{k_1},\,g^{(1)}=[g_i]_{k_2}, \\ \bar{a}^{(2)}=[\bar{a}_{j-1}]_{k_3},\,g^{(2)}=[g_j]_{k_4}\) \(\begin{aligned} \mathbf{E}[\mathcal{D}[W_i]_{k_1,k_2}\mathcal{D}[W_j]_{k_3,k_4}]&=\mathbf{E}[(\bar{a}^{(1)}g^{(1)})(\bar{a}^{(2)}g^{(2)})] \\ &=\mathbf{E}[\bar{a}^{(1)}\bar{a}^{(2)}g^{(1)}g^{(2)}] \\ &\approx{}\mathbf{E}[\bar{a}^{(1)}\bar{a}^{(2)}]\mathbf{E}[g^{(1)}g^{(2)}] \end{aligned}\)
  • 근사 오차는 culmulant $\kappa{(\bullet{})}$ 을 이용해 아래같이 표현되며, culmulant 는 평균과 분산의 고차원 일반화임 \(\kappa{(\bar{a}^{(1)},\bar{a}^{(2)},g^{(1)},g^{(2)})}+ \mathbf{E}[\bar{a}^{(1)}]\kappa{(\bar{a}^{(2)},g^{(1)},g^{(2)})}+ \mathbf{E}[\bar{a}^{(2)}]\kappa{(\bar{a}^{(1)},g^{(1)},g^{(2)})}\)
  • 다변수 정규분포일 때 culmulant 는 0 이므로, $(\bar{a}^{(1)},\bar{a}^{(2)},g^{(1)},g^{(2)})$ 의 분포가 이에 가까울 수록 근사 오차가 적음

4 Additional approximations to $\tilde{F}$ and inverse computations

$\tilde{F}$ 의 역행렬을 계산하기 위해 두 특별한 구조로 근사할 것인데, 두 번째 방법은 덜 제한적이지만 복잡도가 높음

4.1 Structured inverses and connection to linear regression

Figure 3 Figure 3. Fig.2 의 $\hat{F}$, $\hat{F}^{-1}$ 그리고 블락 평균
$\tilde{F}$ 과 달리 $\tilde{F}^{-1}$ 은 대각 또는 띠 대각 블락 행렬에 가까움을 확인할 수 있음

  • 분산 행렬이 $\Sigma{}$ 인 분포에 대해, $i$-th 변수의 선형 회귀를 위한 $j$-th 변수의 상수를 $[B]{i,j}$ , $i$-th 변수의 선형 회귀 오차의 분산을 $[D]{i,i}$ 라 했을 때, 이들은 아래같이 $\Sigma{}^{-1}$ 로 표현됨 ($[B]_{i,i}=0$) \([B]_{i,j}=-\frac{[\Sigma{}^{-1}]_{i,j}}{[\Sigma{}^{-1}]_{i,i}} \quad{}\text{and}\quad [D]_{i,i}=\frac{1}{[\Sigma{}^{-1}]_{i,i}}\)
  • Precision 행렬 $\Sigma{}^{-1}$ 역시 $B, D$ 로 표현되는데, 직관적으로 $i$-th 변수 예측에 $j$-th 변수가 유용할수록 큰 $[\Sigma{}^{-1}]_{i,j}$ 값을 가짐 \(\Sigma{}^{-1}=D^{-1}(I-B)\)
  • $F$ 는 $\mathcal{D}\theta{}$ 의 분산 행렬이며, $F^{-1}$ 의 성분을 선형 회귀 예측의 상수로 봤을 때 대각 성분이 상대적으로 큼을 의미하고, 따라서 대각 블락 근사가 유용함
  • $\mathcal{D}W_i$ 와 함께 앞뒤 레이어의 $\mathcal{D}W_{i-1}, \mathcal{D}W_{i+1}$ 도 함께 고려하여 덜 제한적인 띠 대각 행렬로 근사할 수도 있음

4.2 Approximating $\hat{F}^{-1}$ as block-diagonal

  • $\tilde{F}$ 를 대각 블락 행렬 $\check{F}$ 로 근사함으로써 $\tilde{F}^{-1}$ 를 대각 블락으로 근사함 \(\begin{aligned} \check{F}&=\text{diag}(\tilde{F}_{1,1},\tilde{F}_{2,2},...,\tilde{F}_{l,l}) \\ &=\text{diag}(\bar{A}_{0,0}\otimes{}G_{1,1},\bar{A}_{1,1}\otimes{}G_{2,2},...,\bar{A}_{l-1,l-1}\otimes{}G_{l,l}) \end{aligned}\)
  • Kronecker product identity 를 이용하면 2$l$ 개 역행렬을 계산해여 $\hat{F}^{-1}$ 를 계산할 수 있음 \(\check{F}^{-1}=\text{diag}(\bar{A}_{0,0}^{-1}\otimes{}G_{1,1}^{-1}, \bar{A}_{1,1}^{-1}\otimes{}G_{2,2}^{-1},..., \bar{A}_{l-1,l-1}^{-1}\otimes{}G_{l,l}^{-1})\)

4.3 Approximating $\hat{F}^{-1}$ as block-triagonal

Figure 4 Figure 4. $\hat{F}^{-1}$ 의 띠 블락 근사와 동등한 UGGM, DGGM

  • $\hat{F}^{-1}$ 를 띠 블락으로 근사하는 것은 $\mathcal{D}\theta{}$ 에 대해 Fig 4 와 같은 undirected Guassian graphical model, UGGM 을 가정하는 것과 같음
  • 위 UGGM 은 Fig 4 아래의 동등한 directed Gaussian graphical model, DGGM 으로 바꿔 표현할 수 있음
  • DGGM 가정 아래 $\mathsf{vec}(\mathcal{D}W_i)$ 는 다음 분포를 따름 \(\mathsf{vec}(\mathcal{D}W_i)\sim{}\mathcal{N}(\Psi{}_{i,i+1}\mathsf{vec}(\mathcal{D}W_{i+1}),\Sigma{}_{i|i+1}) \\ \text{and}\quad{}\mathsf{vec}(\mathcal{D}W_l)\sim{}\mathcal{N}(0,\Sigma{}_l)\)
  • 조건 정규 분포 법칙에 따라 $\Psi{}_{i,i+1}$ 은 다음과 같음 \(\begin{aligned} \Psi{}_{i,i+1}&=\hat{F}_{i,i+1}\hat{F}^{-1}_{i+1,i+1} \\ &=\tilde{F}_{i,i+1}\tilde{F}^{-1}_{i+1,i+1} \\ &=(\bar{A}_{i-1,i}\otimes{}G_{i,i+1})(\bar{A}_{i,i}\otimes{}G_{i+1,i+1})^{-1} \\ &=\Psi{}^{\bar{A}}_{i-1, i}\Psi{}^G_{i,i+1} \\ \text{where}\quad{}&\Psi{}^{\bar{A}}_{i-1, i}=\bar{A}_{i-1,i}\bar{A}_{i,i}^{-1}, \\ &\Psi{}^G_{i,i+1}=G_{i,i+1}G_{i+1,i+1}^{-1} \end{aligned}\)
  • 마찬가지로 분산 행렬 $\Sigma{}_{i|i+1}$ 은 다음과 같으며, 이 것의 효율적인 역행렬 계산 방법 존재함 (Appendix B) \(\begin{aligned} \Sigma{}_{i|i+1}&=\hat{F}_{i,i}-\Psi{}_{i,i+1}\hat{F}_{i+1,i+1}\Psi{}^\mathsf{T}_{i,i+1} \\ &=\tilde{F}_{i,i}-\Psi{}_{i,i+1}\tilde{F}_{i+1,i+1}\Psi{}^\mathsf{T}_{i,i+1} \\ &=\bar{A}_{i-1,i-1}\otimes{}G_{i,i}-\Psi{}^\bar{A}_{i-1,i}\bar{A}_{i,i}\Psi{}^{\bar{A}\mathsf{T}}_{i-1,i} \otimes{}\Psi{}^G_{i,i+1}G_{i+1,i+1}\Psi{}^{G\mathsf{T}}_{i,i+1} \end{aligned}\)

14 Conclusions and future directions

근사 자연 경사 하강법인 K-FAC 을 제안하며,

  • 뉴랄넷 레이어 Fisher 행렬의 효율적인 근사 방법과 이론 및 임상적 평가로 이것의 정당정 보임
  • Fisher 를 Hessian 행렬의 근사로 보고, damping 트릭을 이용해 최적화 알고리즘 기술함
  • 뉴랄넷의 몇몇 reparametrization 에 불변함과 같은 자연 경사 하강법의 특징을 이어받음 보임
  • Autoencoder 문제에서 모멘텀 방법과 mini-batch 크기 스케쥴링을 이용할 때 확률적 경사 하강법보다 더 좋은 성능과 빠른 수렴 확인함