[정리] Optimizing Neural Networks with Kronecker-factored Approximate Curvature (JMLR, 2015)

Author: James Martens∗ and Roger Grosse Paper Link: https://arxiv.org/abs/1503.05671

요약

0 Abstract

  • 뉴랄넷의 자연 경사 하강을 위해 Kronecker product 로 Fisher 행렬을 근사하는 K-FAC 을 제안하며, 이는 대각 또는 low rank 행렬 근사와는 다름
  • Fisher 행렬의 대각 블락을 근사하며, 더 작은 두 개의 행렬들의 Kronecker product 로 각 블락을 표현하게 됨
  • 예닐곱 배의 계산량이 추가되지만, 확률적 경사 하강 방법보다 최적화에 성능이 높아 수렴이 더 빠름
  • Hessian-free 같은 full matrix 근사 자연 경사하강 또는 Newton 방법에 비해 확률적 방법에 적합한데, 근사 행렬 및 역행렬 계산이 입력 데이터 수에 무관하고 저장하는 행렬이 작기 때문

1 Introduction

  • Hessian-free (HF) 와 같은 local curvature 를 고려한 방법들은 빠르게 학습이 진행되지만 (1) conjugate gradient (CG) 과정에 의한 계산량이 증가하고 (2) 적은 수의 입력만이 허용될 수 있음
  • HF 에서 사용되는 CG 로 인해 때로 확률적 경사 하강법 (SGD) 대비 이점이 적어지므로, CG 를 사용하지 않은 2차 최적화 방법이 필요함
  • Parameter 들은 각 레이어에 해당하는 그룹으로 나뉘고, Fisher 는 Kronecker 블락들로 근사되는데 이 과정은 gradient 에 대한 특정 가정 과 같음
  • 블락 대각 행렬 또는 띠행렬이 역행렬이 가짐을 가정하며, 분산 행렬의 역, 트리 구조의 그래프 모델, 그리고 선형 회귀와의 관계를 통해 근사를 검증할 것

2 Background and notation

2.1 Neural Networks

  • $i$-th 선형 모듈 레이어 및 활성 함수: $s_i=W\bar{a}_{i-1}, a_i=\phi{}_i(s_i)$
  • 벡터화된 모든 Parameters: $\theta{}=[\mathsf{vec}(W_1)^\text{T}\, \mathsf{vec}(W_2)^\text{T}\, …\, \mathsf{vec}(W_l)^\text{T}]$
  • 뉴랄넷 출럭 $z$ 와 타겟 $y$ 에 대한 손실함수: $L(y, z)=-\log{r(y z)}$ (assumed)
  • Partial gradient operator: $\mathcal{D}v=-\frac{\text{d}\log{p(y x,\theta)}}{\text{d}v}$
  • 레이어 출력 $s_i$ 의 gradient: $g_i=\mathcal{D}s_i$

$l$ linear layers 의 forward/backward path

  • Forward path: $s_i=W_i\bar{a}_{i-1},a_i=\phi{}(s_i)$
  • Loss derivative: $\mathcal{D}a_l=\frac{\partial{L(y,z)}}{\partial{z}} _{z=a_l}$
  • Backward path: \(\begin{aligned} \mathcal{D}a_i&=\mathcal{D}s_i\odot{}\phi{}'(s_i) \\ \mathcal{D}W_i&=g_i\bar{a}_{i-1}^\text{T} \\ \mathcal{D}a_{i-1}&=W_i^\text{T}g_i \end{aligned}\)

2.2 Natural gradient

  • Fisher 행렬은 아래와 같이 정의되며, 데이터 분포 $Q_x$ 와 학습된 모델의 분포 $p(y|x,\theta)$ 에 대한 기댓값이나, 학습 데이터에 대한 분포 $\hat{Q}_x$ 를 이용하여 계산함 \(F=\mathbf{E}[\frac{\text{d}\log{p(y|x,\theta)}}{\text{d}\theta}\frac{\text{d}\log{p(y|x,\theta)}}{\text{d}\theta}^\text{T}]\)
  • 자연 경사 하강은 정해진 KL-divergence 변화 기준 목표 함수를 최대화하는 gradient 를 말하며, 기본 경사 하강의 경우 Euclidean norm 변하를 기준으로 함
  • Fisher 는 $p(y x,\theta)$ 이 exponential family 일 때 Hessian 행렬의 positive semi-definite (PSD) 근사인 Gauss-Newton 행렬 (GGN) 과 같음

3 A Block-wise Kronecker-factored Fisher Approximation

Figure 2 Figure 2. MNIST 숫자 인식 문제에서 중간 4 개 레이어의 완전한 Fisher $F$, 블락 근사 $\tilde{F}$, 그리고 $F$ 와 $\tilde{F}$ 의 차이

  • $l$ 레이어들을 가진 뉴랄넷의 Fisher 는 $l$-by-$l$ 블락 행렬로 구성되며, $(i,j)$-th 블락 $F_{i,j}$ 은 아래와 같으며, 선형 모듈의 경우 입력 $\bar{a}$ 과 레이어의 gradient $g$ 로 표현 가능함 \(\begin{aligned} F_{i,j}&=\mathbf{E}[\mathsf{vec}(\mathcal{D}W_i)\mathsf{vec}(\mathcal{D}W_j)^\text{T}] \\ &=\mathbf{E}[\mathsf{vec}(g_i\bar{a}_{i-1}^\text{T})\mathsf{vec}(g_j\bar{a}_{j-1}^\text{T})^\text{T}] \\ &=\mathbf{E}[(\bar{a}_{i-1}\otimes{}g_i)(\bar{a}_{j-1}^\text{T}\otimes{}g_j^{\text{T}})] \\ &=\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}\otimes{}g_ig_j^\text{T}] \end{aligned}\)
  • 첫 번째로 각 $\bar{a}{i-1}\bar{a}{j-1}^\text{T}, g_ig_j^\text{T}$ 에 대한 Kronecker-product $\tilde{F}$ 로 $F$ 를 근사함 (Khatri-Rao 곱이 됨) \(\begin{aligned} F_{i,j}&=\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}\otimes{}g_ig_j^\text{T}] \\ &\approx{}\mathbf{E}[\bar{a}_{i-1}\bar{a}_{j-1}^\text{T}]\otimes \mathbf{E}[g_ig_j^\text{T}] \\ &=\bar{A}_{i,j}\otimes{}G_{i,j} \\ &=\tilde{F}_{i,j} \end{aligned}\)
  • $\tilde{F}$ 는 주요 근사로 현실적인 가정이나 극한 상황 아래 Fisher 로 수렴하기 어려우나, 사용 시에는 거친 구조 (coarse structure) 를 반영함 (Fig. 2)

3.1 Interpretations of this Approximation

  • 위 근사는 $\bar{a}^{(1)}\bar{a}^{(2)}$ 과 $g^{(1)}g^{(2)}$ 사이 통계적 독립을 가정하는 것임 \(\mathcal{D}[W_i]_{k_1,k_2}=\bar{a}^{(1)}g^{(1)}, \\ \mathcal{D}[W_i]_{k_3,k_4}=\bar{a}^{(2)}g^{(2)}, \\ \bar{a}^{(1)}=[\bar{a}_{i-1}]_{k_1},\,g^{(1)}=[g_i]_{k_2}, \\ \bar{a}^{(2)}=[\bar{a}_{j-1}]_{k_3},\,g^{(2)}=[g_j]_{k_4}\) \(\begin{aligned} \mathbf{E}[\mathcal{D}[W_i]_{k_1,k_2}\mathcal{D}[W_j]_{k_3,k_4}]&=\mathbf{E}[(\bar{a}^{(1)}g^{(1)})(\bar{a}^{(2)}g^{(2)})] \\ &=\mathbf{E}[\bar{a}^{(1)}\bar{a}^{(2)}g^{(1)}g^{(2)}] \\ &\approx{}\mathbf{E}[\bar{a}^{(1)}\bar{a}^{(2)}]\mathbf{E}[g^{(1)}g^{(2)}] \end{aligned}\)
  • 근사 오차는 culmulant $\kappa{(\bullet{})}$ 을 이용해 아래같이 표현되며, culmulant 는 평균과 분산의 고차원 일반화임 \(\kappa{(\bar{a}^{(1)},\bar{a}^{(2)},g^{(1)},g^{(2)})}+ \mathbf{E}[\bar{a}^{(1)}]\kappa{(\bar{a}^{(2)},g^{(1)},g^{(2)})}+ \mathbf{E}[\bar{a}^{(2)}]\kappa{(\bar{a}^{(1)},g^{(1)},g^{(2)})}\)
  • 다변수 정규분포일 때 culmulant 는 0 이므로, $(\bar{a}^{(1)},\bar{a}^{(2)},g^{(1)},g^{(2)})$ 의 분포가 이에 가까울 수록 근사 오차가 적음

4 Additional approximations to $\tilde{F}$ and inverse computations

$\tilde{F}$ 의 역행렬을 계산하기 위해 두 특별한 구조로 근사할 것인데, 두 번째 방법은 덜 제한적이지만 복잡도가 높음

4.1 Structured inverses and connection to linear regression

Figure 3 Figure 3. Fig.2 의 $\hat{F}$, $\hat{F}^{-1}$ 그리고 블락 평균
$\tilde{F}$ 과 달리 $\tilde{F}^{-1}$ 은 대각 또는 띠 대각 블락 행렬에 가까움을 확인할 수 있음

  • 분산 행렬이 $\Sigma{}$ 인 분포에 대해, $i$-th 변수의 선형 회귀를 위한 $j$-th 변수의 상수를 $[B]{i,j}$ , $i$-th 변수의 선형 회귀 오차의 분산을 $[D]{i,i}$ 라 했을 때, 이들은 아래같이 $\Sigma{}^{-1}$ 로 표현됨 ($[B]_{i,i}=0$) \([B]_{i,j}=-\frac{[\Sigma{}^{-1}]_{i,j}}{[\Sigma{}^{-1}]_{i,i}} \quad{}\text{and}\quad [D]_{i,i}=\frac{1}{[\Sigma{}^{-1}]_{i,i}}\)
  • Precision 행렬 $\Sigma{}^{-1}$ 역시 $B, D$ 로 표현되는데, 직관적으로 $i$-th 변수 예측에 $j$-th 변수가 유용할수록 큰 $[\Sigma{}^{-1}]_{i,j}$ 값을 가짐 \(\Sigma{}^{-1}=D^{-1}(I-B)\)
  • $F$ 는 $\mathcal{D}\theta{}$ 의 분산 행렬이며, $F^{-1}$ 의 성분을 선형 회귀 예측의 상수로 봤을 때 대각 성분이 상대적으로 큼을 의미하고, 따라서 대각 블락 근사가 유용함
  • $\mathcal{D}W_i$ 와 함께 앞뒤 레이어의 $\mathcal{D}W_{i-1}, \mathcal{D}W_{i+1}$ 도 함께 고려하여 덜 제한적인 띠 대각 행렬로 근사할 수도 있음

4.2 Approximating $\hat{F}^{-1}$ as block-diagonal

  • $\tilde{F}$ 를 대각 블락 행렬 $\check{F}$ 로 근사함으로써 $\tilde{F}^{-1}$ 를 대각 블락으로 근사함 \(\begin{aligned} \check{F}&=\text{diag}(\tilde{F}_{1,1},\tilde{F}_{2,2},...,\tilde{F}_{l,l}) \\ &=\text{diag}(\bar{A}_{0,0}\otimes{}G_{1,1},\bar{A}_{1,1}\otimes{}G_{2,2},...,\bar{A}_{l-1,l-1}\otimes{}G_{l,l}) \end{aligned}\)
  • Kronecker product identity 를 이용하면 2$l$ 개 역행렬을 계산해여 $\hat{F}^{-1}$ 를 계산할 수 있음 \(\check{F}^{-1}=\text{diag}(\bar{A}_{0,0}^{-1}\otimes{}G_{1,1}^{-1}, \bar{A}_{1,1}^{-1}\otimes{}G_{2,2}^{-1},..., \bar{A}_{l-1,l-1}^{-1}\otimes{}G_{l,l}^{-1})\)

4.3 Approximating $\hat{F}^{-1}$ as block-triagonal

Figure 4 Figure 4. $\hat{F}^{-1}$ 의 띠 블락 근사와 동등한 UGGM, DGGM

  • $\hat{F}^{-1}$ 를 띠 블락으로 근사하는 것은 $\mathcal{D}\theta{}$ 에 대해 Fig 4 와 같은 undirected Guassian graphical model, UGGM 을 가정하는 것과 같음
  • 위 UGGM 은 Fig 4 아래의 동등한 directed Gaussian graphical model, DGGM 으로 바꿔 표현할 수 있음
  • DGGM 가정 아래 $\mathsf{vec}(\mathcal{D}W_i)$ 는 다음 분포를 따름 \(\mathsf{vec}(\mathcal{D}W_i)\sim{}\mathcal{N}(\Psi{}_{i,i+1}\mathsf{vec}(\mathcal{D}W_{i+1}),\Sigma{}_{i|i+1}) \\ \text{and}\quad{}\mathsf{vec}(\mathcal{D}W_l)\sim{}\mathcal{N}(0,\Sigma{}_l)\)
  • 조건 정규 분포 법칙에 따라 $\Psi{}_{i,i+1}$ 은 다음과 같음 \(\begin{aligned} \Psi{}_{i,i+1}&=\hat{F}_{i,i+1}\hat{F}^{-1}_{i+1,i+1} \\ &=\tilde{F}_{i,i+1}\tilde{F}^{-1}_{i+1,i+1} \\ &=(\bar{A}_{i-1,i}\otimes{}G_{i,i+1})(\bar{A}_{i,i}\otimes{}G_{i+1,i+1})^{-1} \\ &=\Psi{}^{\bar{A}}_{i-1, i}\Psi{}^G_{i,i+1} \\ \text{where}\quad{}&\Psi{}^{\bar{A}}_{i-1, i}=\bar{A}_{i-1,i}\bar{A}_{i,i}^{-1}, \\ &\Psi{}^G_{i,i+1}=G_{i,i+1}G_{i+1,i+1}^{-1} \end{aligned}\)
  • 마찬가지로 분산 행렬 $\Sigma{}_{i|i+1}$ 은 다음과 같으며, 이 것의 효율적인 역행렬 계산 방법 존재함 (Appendix B) \(\begin{aligned} \Sigma{}_{i|i+1}&=\hat{F}_{i,i}-\Psi{}_{i,i+1}\hat{F}_{i+1,i+1}\Psi{}^\mathsf{T}_{i,i+1} \\ &=\tilde{F}_{i,i}-\Psi{}_{i,i+1}\tilde{F}_{i+1,i+1}\Psi{}^\mathsf{T}_{i,i+1} \\ &=\bar{A}_{i-1,i-1}\otimes{}G_{i,i}-\Psi{}^\bar{A}_{i-1,i}\bar{A}_{i,i}\Psi{}^{\bar{A}\mathsf{T}}_{i-1,i} \otimes{}\Psi{}^G_{i,i+1}G_{i+1,i+1}\Psi{}^{G\mathsf{T}}_{i,i+1} \end{aligned}\)

14 Conclusions and future directions

근사 자연 경사 하강법인 K-FAC 을 제안하며,

  • 뉴랄넷 레이어 Fisher 행렬의 효율적인 근사 방법과 이론 및 임상적 평가로 이것의 정당정 보임
  • Fisher 를 Hessian 행렬의 근사로 보고, damping 트릭을 이용해 최적화 알고리즘 기술함
  • 뉴랄넷의 몇몇 reparametrization 에 불변함과 같은 자연 경사 하강법의 특징을 이어받음 보임
  • Autoencoder 문제에서 모멘텀 방법과 mini-batch 크기 스케쥴링을 이용할 때 확률적 경사 하강법보다 더 좋은 성능과 빠른 수렴 확인함

[정리] EKFAC

[정리] Scalable Second Oder Optimization for Deep Learning (2021)

Author: Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
Paper Link: https://arxiv.org/abs/2002.09018
Code: https://github.com/google-research/google-research/tree/master/scalable_shampoo

Figure 0 (image from here)

요약

AdaGrad 를 확장한 preconditioner 이용 최적화 방법인 Shampoo의 계산 속도를 개선하는 방법을 제안함. 각 레이어 별 preconditioner 연산을 병렬화하는 방법과 행렬을 여러 블락들로 근사하는 방법을 이용함. Machine translation, Image classification 같은 문제들에서 큰 계산 시간 증가 없이 더 좋은 성능과 빠른 수렴을 보임.

0 Abstract

  • 1차 경사 하강법에 비해 2차 최적화 방법은 계산량, 통신량이 많고 메모리 소모가 커 잘 사용되지 않았음
  • 큰 스케일 문제에 사용될 수 있는 2차 preconditioned method 를 제안함
  • 높은 수렴성과 1차 경사 하강법 대비 wall-clock 기준 빠른 업데이트를 확인함
  • 여러 CPU 와 가속기(GPU, TPU)를 활용한 결과로, BERT, ImageNet 같은 큰 학습 문제에서 성능 확인함

1 Introduction

  • 2차 방법에서 계산량과 메모리 소모를 줄이는 것이 중요하며, 학습에 필요한 스텝 수를 줄일 수 있음
  • 확률 최적화를 위한 적응형 방법을 제안하며, AdaGrad, Adam 의 full matrix 버전이라 할 수 있고 parameter 사이 correlation 을 고려하게 됨 (이 방법들은 gradients 의 외적합을 이용하여 2차 모멘트 행렬을 계산했음)
  • 2차 방법의 문제점들을 완화하기 위해 K-FAC, K-BFGS, Shampoo 가 제안되었지만, 큰 스케일의 딥러닝을 위해 병렬화하기 어려움

1.1 Contribution

  • 많은 레이어들을 갖는 딥러닝 모델을 학습하기 위해 Shampoo 방법을 확장하였음
  • Preconditioner 행렬을 위해 계산량이 많은 SVD 을 PSD 행렬을 위한 안정적인 방법으로 대체함
  • Machine translation, Language modeling 등에서 학습 속도 (Wall-clock) 향상을 확인하였음
  • Bollapragada et al., 2018 은 노이즈가 적은 Full batch L-BFGS 와 확률 경사 사이의 방법을 제안함
  • 대부분의 이전 Preconditioner 방법은 대각 행렬 근사를 이용하였으나, 최근에 들어 full matrix 를 이용하는 방법들이 제안되었음
    • K-FAC 은 Fisher 행렬을 근사하여 preconditioner 로 사용하는 것이며, K-BFGS 는 레이어의 Hessian 을 유사한 방법으로 근사함
    • GGT 는 AdaGrad preconditioner 를 row-rank 근사하는 방식이나 많은 수의 gradients 를 저장해야므로 중간 정도 크기의 모델에 적합함
  • Ba et al., 2017 은 K-FAC 의 distributed 버전을 제안했음

2 Preliminaries

Notation

  • Frobenius norm of A: $   A   ^2F=\sum{i,j}A^2_{i,j}$
  • Element-wise product of A and B: $C=A\bullet B\iff C_{i,j}=A_{i,j}B_{i,j}$
  • Element-wise power: $(D^{\odot\alpha}){i,j}=D^{\alpha}{i,j}$
  • $A\preceq B$ iff $B-A$ is positive semi-definite (PSD)
  • For a symmentric PSD matrix $A=UDU^\text{T}$, $A^\alpha=UD^\alpha U^\text{T}$
  • Identities of the Kronecker product of A and B:
    • $(A\otimes B)^\alpha=A^\alpha\otimes B^\alpha$
    • $(A\otimes B)\mathbf{vec}(C)=\mathbf{vec}(ACB^\text{T})$

Adaptive preconditioning methods

  • Parameter $w_t\in\mathbb{R}^d$ 는 gradient $\bar{g}_t\in\mathbb{R}^d$ 와 precondition 행렬 $P_t\in\mathbb{R}^{d\times d}$ 에 대해 다음과 같이 업데이트됨 \(w_{t+1}=w_t-P_t\bar{g}_t\)
  • $P_t$ 는 2차 최적화 방법에서 Hessian 행렬이 되지만, adaptive preconditioning 방법에서는 gradients 의 correlation 과 관련됨
  • Parameter 와 gradient 를 행렬 $W, G \in\mathbb{R}^{m\times n}$ 로 각각 표현 했을 때, full matrix 일 경우의 precondition 의 저장공간은 $n^2m^2$ 만큼 필요하며, 업데이트 식의 계산량은 $m^3n^3$ 이므로 근사 없이는 계산이 불가능함

The Shampoo algorithm

  • Shampoo 는 Kronecker product 를 이용하여 대각행렬 근사 방법과 full matrix 방법을 잇는 방법임
  • Iteration $t$ 의 손실 함수에 대한 gradient $G_t=\nabla_Wl(f(W,x_t),y_t)$ 에 대해 $L_t\in\mathbb{R}^{m\times m}$ 와 $R_t\in\mathbb{R}^{n\times n}$ 는 아래같이 정의됨 \(L_t=\epsilon I_m+\textstyle\sum^t_{s=1}G_sG_s^{\text{T}} \quad R_t=\epsilon I_n+\textstyle\sum^t_{s=1}G_s^{\text{T}}G_s\)
  • Full matrix Adagrad preconditioner $H_t$ 는 $(L_t\otimes R_t)^{1/2}$ 로 근사되며, Adagrad 의 업데이트 방식인 $w_{t+1}=w_t-\eta H_t^{-1/2}g_t$ 을 따라 Shampoo 의 업데이트 식은 아래와 같음 \(W_{t+1}=W_t-\eta L_t^{-1/4}G_tR_t^{-1/4}\)

3 Scaling-up Second Order Optimization

적은 연산량과 메모리 소모를 보이는 1차 방법보다 Shampoo 는 (1) Preconditioner 계산 (2) 역행렬 계산 (3) 행렬 곱 계산이 추가됨 이 중에서도 역행렬 계산을 으로 인한 학습 속도 저하를 최소화하는 것이 가장 중요함

3.1 Preconditioning of large layers

Large Layers

다음 Lemma 1 에 따라 full Adagrad precondition $\hat{H}_t$ 의 근사 방법 중 $L_t^{1/p}\otimes R_t$ 을 사용함 (Theorem 3 에서 regret bound 확인)

Lemma 1. 최대 rank () $r$ 인 행렬들 $G_1, …, G_t\in\mathbb{R}^{m\times n}$, $g_s=\mathtt{vec}(G_s)$ 에 대해 $\hat{H}t=\epsilon I{mn}+\textstyle\sum^t_{s=1}g_sg_s^\text{T}$ 를 정의하자. 위에서 정의한 $L_t, R_t$ 와 $1/p+1/q=1$ 을 만족하는 $p, q$ 에 대해 $\hat{H}_t\preceq rL_t^{1/p}\otimes R_t^{1/q}$ 이다.

Preconditioning blocks from large tensors

여러 레이어들의 경우에 계산량 및 메모리 소모를 줄이기 위해 각 레이어에 해당하는 텐서 블락을 분리함

Lemma 2. 벡터 $g_1, …, g_t\in\mathbb{R}^{mk}$ 에 대해 $g_i=[g_{i,1},…,g_{i,k}]$ 이고 $g_{i,j}\in\mathbb{R}^m$ 일 대, $B^{(j)}t=\epsilon I_m +\textstyle\sum^t{s=1}g_{s,j}g_{s,j}^\text{T}$ 이도록 $B_t$ 를 정의하면, $\hat{H}_t\preceq{kB_t}$ 이다.

Delayed preconditioners

Precondition 행렬은 수 백의 스텝마다 업데이트되더라도 성능에 큰 영향을 주지 않았으며 (Fig. 4c), 이는 손실 함수의 landscape 이 꽤 평탄함을 의미하고 성능과 계산 속도가 상충 관계임을 뜻함

3.2 Roots of ill-conditioned matrices

행렬의 interse pth root, 즉 $A^{-1/p}$ 를 구할 때 계산량이 많은 SVD 보다 효율적인 coupled Newton iteration 방법을 사용할 수 있음 (Fig. 7) 또한 $L_t, R_t$ 의 조건수가 매우 크기 때문에 두 방법 모두 double precision 으로 계산될 필요가 있으나, 계산량이 매우 많아질 것임

3.3 Deploying on current ML infrastructure

Figure 1 Fig 1. 병렬화된 Shampoo 의 계산 다이어그램

Heterogeneous training hardware. 가속기 설계의 방향은 주로 낮은 precision 이지만, double precision 계산이 다수의 레이어에서 수행되어야 했고, 따라서 가속기보다 CPU 들을 활용하였음

API inflexibility. 제안하는 방법이 비표준적인 학습 과정을 따르므로 framework 수준의 변경이 필요했고, Lingvo TensorFlow framework 을 사용하엿음

4 Distributed Systen Design

  • 표준 병렬화에서는 parameter 가 각 가속기 코어에 복제되어 forward & back propagation 을 수행하고, 다시 한 곳으로 모여 배치에 대해 평균됨
  • Preconditioner 의 inverse pth root 는 double precision 이어야 하지만, 수 백 스텝마다 비동기적으로 계산되므로 CPU 를 사용할 것임
  • 각 레이어마다 Preconditioner 를 계산하므로 여러 CPU 에 계산을 분산시킴 (Fig. 1)

5 Experiment

5.1 Comparison of second order methods

Figure 2 Fig 2. Autoencoder 문제에서 여러 2차 최적화 방법과 비교

  • Autoencoder 문제를 두고 K-FAC, K-BFGS 와 비교하였으며, 모든 알고리즘의 결과가 유사했음 (Fig. 2)

5.2 Machine translation with a transformer

Figure 3 Fig 3. 기계 번역 문제에서 비교 결과로, 모든 계산은 CPU 로 수행하였

  • 영어의 불어 번역 데이터 WMT’14 의 36.3 문장 쌍을 Transformer 구조(93.3M) 학습하고, Adagrad 와 Adam 과 비교함
  • ~ 60 % 느린 계산 속도를 보였지만 1.95x 빨리 수렴하였으며, preconditioner 계산을 위한 오버헤드는 분산 계산을 통해 꽤 낮아졌음을 확인할 수 있음

Figure 4 Fig 4. (a) 전체 또는 embedding 레이어에만 적용했을 때, (b) 여러 블락들로 근사했을 때, (c) precondtioner 업데이트 주기를 변경했을 때 성능 비교

Preconditioning of embedding and softmax Layers

$R_t, L_t$ 중 하나만 이용하여 Precondition 했을 때 ($G_tR_t^{-1/2}$ or $L_t^{-1/2}G_t$), 6 %의 계산 시간 증가(Fig. 3b)로 20 % 수렴 시간을 줄일 수 있었음 (Fig. 4a)

Reducing overhead in fully-connected Layers

FC 레이어의 preconditioner 를 2 개 그리고 4 개 블락들로 근사하였을 때, 성능 하락은 3 % 이내였음

5.3 Transformer-Big model

Figure 5 Fig 5. 번영 문제에서 최적화 방법 및 배치 크기에 따른 성능 비교

더 큰 Transformer 모델(375.4M)에 대해 비교하였을 때, 30 % 적은 계산 시간을 보였으며, 배치 크기가 클 때 이 효과가 더 두드러

5.4 Ads Click-Through Rate (CTP) Prediction

Figure 6 Fig 6. CTP 예측 문제, 언어 모델링 문제에서 Shampoo 의 성능

광고 클릭 데이터셋에 대한 딥러닝 추천 모델을 학습시킬 때 제안한 방법을 이용하였고, 0.3 % AUC 개선된 SOTA 성능을 보였고, 총 스텝 수도 39.96K 에서 30.97K 로 감소시켰음 (Fig. 6a)

5.5 Language modeling

  • Bert-Large 모델(340M)을 (a) 주변으로부터 가려진 토큰 찾기(MLM) (b) 다음 문장 예측하기 문제(NSP)에 대해 학습시켰음
  • MLM 문제에서 16 % 적은 스텝 수로 1 % 성능 향상을 보였음

5.6 Image classification

ResNet-50 모델을 이용한 ImageNet-2012 분류 문제를 해결할 때, Nesterov momentum 혹은 LARS 최적화 방법을 사용했을 때보다 적은 수의 스텝으로 75.9 % 의 정확도에 도달하였음

6 Concluding Remarks

  • 딥러닝을 위한 2차 최적화 방법을 구현 방법을 제안하였고 스텝 시간과 wall clock 에 있어 향상된 성능을 확인함
    • 기존 구현 대부분이 대칭 행렬을 이용하지만, 대칭 연산자를 이용하는 경우는 발견하지 못했는데, 이는 플롭과 메모리를 약 50 % 절약할 수 있음
    • 섞인 precision 을 사용한다면 preconditioner 계산을 더 자주 수행할 수 있을 것임

[정리] MIER: Meta-Reinforcement Learning Robust to Distributional Shift via Model Identification and Experience Relabeling (ICML Workshop, 2020)

Author: Russell Mendonca, Xinyang Geng, Chelsea Finn, Sergey Levine Paper Link: https://arxiv.org/abs/2006.07178 Talk in NeurIPS2020 Workshop: https://slideslive.com/38931350/ Code: https://github.com/russellmendonca/mier_public.git/

요약

  • Off-policy meta RL 에서 태스크가 Out-of-dist 일 때 외삽이 가능하도록 dynamics/ reward 모델을 meta 학습함
  • 시험 태스크 경험이 주어지면 context variable 을 meta 학습 방식으로 적응시키고, universal policy 에 이용함
  • 외삽 성능을 높이기 위해 학습한 모델로 가상의 경험을 생성하여 policy 학습에 이용함

0. Abstract

  • 여러 태스크를 수행하려 할 때, 다양한 스킬들을 학습하는 것은 많은 수의 샘플이 요구됨
  • Meta RL 은 선험지식을 이용해 빠르게 적응(adapt)하지만, 시험 태스크가 학습 태스크들에 얼마나 가까운지에 따라 결과가 상이함
  • On policy 처럼 많은 샘플 없이 시험 태스크를 해결하는 것, 즉 효율적인 외삽(extrapolate)을 목표로 함
  • Dynamics 모델 식별(identification)과 경험 재분류(experience relabeling) 과정으로 목표를 달성하는 방법 MIER 를 제안함
  • 모델 학습은 적은 off policy 샘플로 가능한 것에서 착안한 것으로, 시험 태스크에서 학습한 모델을 이용하여 policy 와 value 를 학습함으로써 Meta RL 없이 외삽 수행함

1. Introduction

Figure 1 Fig 1. Meta RL 에서의 모델 식별과 경험 재분류 방법
모델 인식에서 얻은 context variable 과 시험 태스크의 샘플를 이용하여 가상의 경험을 만들고 이로부터 policy 를 학습함

  • 선험 지식을 이용하여 새 태스크를 빠르게 학습하는 Meta RL 은 대부분 많은 수의 on policy 샘플들이 요구되었음
  • Off policy 기반의 value 를 meta 학습하여 policy 를 학습시키는 것은 계산량이 많아 어려움 (Appendix D)
  • 대신 dynamics/ reward 모델과 context variable 을 meta 학습하고 이를 이용하여 policy 와 value 를 도출하는 것으로 meta RL 을 수행할 것임
  • 태스크 정보를 담고 있은 context variable 은 모델의 입력이며, gradient descent 를 이용하여 각 태스크에 적응되어 policy 입력이 됨 (Fig. 1)
  • Policy 학습은 context variable 이 state 에 추가됨을 제외하면, standard RL 과 다르지 않음
  • 시험 태스크의 context variable 이 상이하면 policy 성능에 문제될 수 있는데, gradient descent 방식으로 모델 적응시키고 이로부터 얻은 가상 경험으로 policy 학습하는 경험 재분류 방법 사용함 (Fig. 1) (시험 태스크 샘플로 context variable 을 meta learning 방식으로 적응시켰다는 말)

2. Preliminaries

  • Meta RL 은 표준 RL 에 더하여 태스크 분포 $\rho(\mathcal{T})$ 를 가지며 아래와 같은 목적함수를 최대화함
    단, 시험 태스크에 대해서 policy 의 적응을 위해 $D_{adapt}^{(\mathcal{T})}$ 를 수집함
\[\mathbb{E}_{\mathcal{T}\sim\rho(\mathcal{T}),\mathbf{s}_t,\mathbf{a}_t\sim\phi_\mathcal{T}} [\textstyle\sum_{t}\gamma^tr(\mathbf{s}_t,\mathbf{a}_t)]\]
  • Dynamics 모델의 meta 학습은 MAML을 따르며, $D_{adapt}^{(\mathcal{T})}$ 을 이용해 적응한 모델의 $D_{eval}^{(\mathcal{T})}$ 에 대한 손실함수를 이용함 (즉, 적은 데이터로 빠르게 적응해 평가 데이터에 대한 성능을 올리도록 유도함)
\[\min_{f,\mathcal{A}} [\mathcal{L}(f(X_\mathcal{T}; \mathcal{A}(\theta,\mathcal{D}^{(\mathcal{T})}_{adapt})), Y_{\mathcal{T}})]\]
  • 적응을 나타내는 $\mathcal{A}(\theta,\mathcal{D}^{(\mathcal{T})}_{adapt})$ 을 one step 버전으로 나타내면 아래와 같음
    적응의 업데이트 방식은 경사 하강법과 같기 때문에 $\rho(\mathcal{T})$ 와 무관히 모델 정확도의 향상이 이뤄짐
\[\mathcal{A}_{\text{MAML}}(\theta,\mathcal{D}^{(\mathcal{T})}_{adapt}))= \theta-\alpha\nabla_\theta\mathbb{E}_{X,Y\sim\mathcal{D}^{\mathcal{T}}_{adapt}} [\mathcal{L}(f(X;\theta),Y)]\]

3. Meta Training with Model Identification

Algorithm 1

  • Meta 태스크 식별 문제를 dynamics 와 reward 모델의 meta 학습 방법으로 해결함
  • Dynamics 모델 $\hat{p}(\mathbf{s’, r s, a};\theta,\phi)$ 은 적응 중 각 태스크의 정보를 담고 있는 latent context variable $\phi$ 로 표현됨
  • 적응 시에는 meta 학습으로 오직 context variable 만을 적응시키고 이를 universal policy 에 적용하여 meta RL 을 수행함 (Alg. 1)
  • 모델 학습 시 손실 함수는 음의 로그 확률 $-\log{\hat{p}(\mathbb{s’,r|s,a};\theta{},\phi{})}$ 이고, context variable 의 적응 과정 중 gradient 스텝은 아래와 같음 \(\begin{aligned} \phi{}_\mathcal{T}&=\mathcal{A}_\text{MAML}(\theta{},\phi{},\mathcal{D}^{(\mathcal{T})}_{adapt}) \\ &=\phi{}-\alpha{}\nabla{}_\phi{}\mathbb{E}_{\mathbb{(s,a,s',r)}\sim\mathcal{D}^{(\mathcal{T})}_{eval}} [-\log{\hat{p}(\mathbb{s',r|s,a};\theta{},\phi{})}] \end{aligned}\)
  • 모델 학습의 meta 손실 함수는 다음과 같은데, 적응된 context variable $\phi{}_\mathcal{T}$ 가 사용됨 \(\arg{}\min{}_{\theta{},\phi{}}J_{\hat{p}}(\theta{},\phi{},\mathcal{D}^{(\mathcal{T})}_{adapt}\mathcal{D}^{(\mathcal{T})}_{eval}) \\ =\arg{}\min{}_{\theta{},\phi{}}\mathbb{E}_{(\mathbb{s,a,s',r})\sim{}\mathcal{D}^{(\mathcal{T})}_{eval}} [-\log{\hat{p}(\mathbb{s',r|s,a};\theta{},\phi{}_\mathcal{T})}]\)
  • 평가 데이터에 대해 시작 context variable 은 빠르게 적응하도록, 모델 파라미터 $\theta{}$ 는 적응한 $\phi{}$ 를 받아 정확도를 높이도록 최적화가 이뤄짐
  • Off-policy Meta RL 방법인 PEARL 에서 외삽도 가능하도록 context variable $\phi{}$ 을 적응하도록 확장한 것이며, ${\phi{}}$ 가 out-of-dist. 일 때를 위해 경험 재분류를 이용함 (Section 4)
  • Policy $\pi_\psi$ 는 context variable 을 추가 state 로 받는 universal policy 이며, 표준 off-policy RL 방식인 SAC 으로 학습됨 \(J_\pi{}(\psi{},\mathcal{D},\phi{}_\mathcal{T})= -\mathbb{E}_{\mathbf{s}\sim{}\mathcal{D},\mathbf{a}\sim{}\pi{}} [Q^{\pi{}_\psi{}}(\mathbf{s,a},\phi{}_\mathcal{T})] \\ \text{where}\quad{}Q^{\pi{}_\psi{}}(\mathbf{s,a},\phi{}_\mathcal{T})= \mathbb{E}_{\mathbf{s}_t,\mathbf{a}_a\sim{}\pi{}} [\Sigma{}_t\gamma{}^tr(\mathbf{s}_t,\mathbf{a}_t)| \mathbf{s}_0=\mathbf{s},\mathbf{a}_0=\mathbf{a}]\)

4 Improving Out-of-Distribution Performance by Experience relabeling

Algorithm 2

  • 적응한 context variable 이 policy 에게 out-of-dist. 이면 성능이 저하되므로, dynamics/ reward 모델을 이용하여 가상의 경험을 만들고 policy 를 학습시키는 경험 재분류 방법을 이용함
  • 기존 다른 태스크들의 경험을 이용하여 가상 경험을 만드는데, time step 이 길면 에러가 누적되어 문제되므로 한 step 만 고려함 (When to Trust Your Model)
  • 경험 재분류는 기존 태스크의 경험을 Importance sampling 을 통해 이용하는 MQL 과 유사하나, 시험 태스크가 다른 분포에서 샘플되는 점이 다름 ˛˛
  • 비모델 encoder 기반 meta RL 은 적응 중 경험을 context variable 로 전환하고 이를 universal policy 에 사용하였음
    • 전환을 위해 Recurrent encoder 또는 variational inference 를 이용하였음
    • Out-of-dist. 태스크에서 적응 성능이 좋지 않음
  • 적응 중 경사 하강을 이용하는 비모델 meta RL 방법도 존재하나, On-policy 샘플을 이용하여 적응하기 때문에 샘플 효율이 낮음
  • 모델 기반 meta RL 은 적응한 모델을 model predictive control 에 이용하는 방식이나, 대체로 낮은 성능을 보였음 (Fig. 4)

6 Experimental Evaluation

Figure 2 Figure 2. 외삽 평가가 없는 표준 meta RL benchmarks 에서 여러 알고리즘의 성능

다음과 같은 질문에 대한 답을 얻고자 하며, Open AI gym 과 mujoco 시뮬레이터를 이용하였음 1) 표준 meta RL benchmarks 에서 효율적으로 학습하여 SOTA 에 견줄만한 성능을 보이는가? 2) 기존 meta RL 과 비교하여 시험 태스크에 대한 외삽 성능은 어느 정도인가? 3) 경험 재분류가 외삽 시 성능에 얼마나 영향을 미치는가?

6.1 Meta-Training Sample Efficiency on Meta-RL Benchmarks

  • 기존 meta RL 방법인 PEARL, MQL, MAML, ProMP, MAML 그리고 RL2 와 외삽이 요구되지 않는 문제에 대해 비교하였음 (Fig. 2)
  • 경험 재분류가 없을 때 MIER 의 성능도 함께 비교했을 때 (MIER-wR), 모델 식별의 meta RL 만으로도 SOTA 성능과 유사했음

6.2 Adaptation to Out-of-Distribution Tasks

Figure 3 Figure 3. Out-of-dist 태스크 설명

Figure 4 Figure 4. Out-of-dist 태스크에서 meta RL 성능 비교

  • 외삽은 reward 와 dynamics 에 대해 각각 고려되었으며, MIER 가 높은 성능을 보임
    • Dynamics 변형은 제어 신호에 대한 움직임 방향을 반대로 하는 방식을 사용함

7 Conclusion

  • 효율적인 meta RL 을 위해 모델 인식 문제로 변형하고, 외삽을 위해 경험 재분류 방법을 이용함
    • 시험 태스크에 dynamics/reward 모델을 적응하고, 만들어진 가상의 데이터를 이용하여 off-policy RL 방식으로 policy 를 학습함
    • 모델 인식은 meta RL 위해, 경험 재분류는 외삽을 위해 사용됨
    • 외삽이 필요한 meta RL 문제에서 높은 성능을 보임