이전 글에서 이어진다.
[Mamba 이해하기 1-1] HiPPO의 A 행렬의 중요성 | HiPPO: Recurrent Memory with Optimal Polynomial Projections, Gu et al.
다루고자 하는 문제: 시계열 모델링1차원 연속 신호 x(t)가 주어졌을 때, 우리가 원하는 종류의 신호 y(t)로 바꿔주는 함수(또는 시스템) f(t):R+→R을 찾는 문제이다. x(t)
haawron.tistory.com
HiPPO: Recurrent Memory with Optimal Polynomial Projections, Gu et al., NeurIPS 2020
HiPPO: Recurrent Memory with Optimal Polynomial Projections
A central problem in learning from sequential data is representing cumulative history in an incremental fashion as more data is processed. We introduce a general framework (HiPPO) for the online compression of continuous signals and discrete time series by
arxiv.org
Task: Online Function Approximation

Online function approximation은 매 time step t마다 함수 f의 x≤t 부분을 완전히 기억하는 것을 목표로 한다. 정의역이 x≤t로 제한된 함수 f를 fx≤t 또는 f≤t로 표기한다.
HiPPO는 f≤t를 고차 다항함수의 합으로 표현하려고 한다. 정확히는 N−1차 다항함수 N개 p(t)0,p(t)1,…,p(t)N−1를 미리 정의해놓고 f≤t를 가장 잘 표현하는 다항 함수의 선형결합의 계수들 cn(t)를 구하려고 한다. 이를
f≤t≈N−1∑n=0cn(t)p(t)n
와 같이 표현 할 수 있다. 따라서 HiPPO는 임의의 실함수 f를 고차 다항함수(High-order Polynomial)로 재표현(Projection)하는 연산자(Operators) 또는 그 프레임워크를 일컫는 말이다.
내가 생각하는 Contributions
1. 장기 기억 모델 HiPPO 제시: 기존 RNN, LSTM, GRU 등의 시계열 모델 연구의 핵심은 explicit한 기억과 망각이다. 기억, 망각 회로를 직접적으로 집어넣어 모델이 시계열 데이터의 특정 부분을 얼마나 기억하거나 잊게 할지 결정하도록 한다. 한편, 다항함수는 계수 N개만으로도 함수 전체를 복원할 수 있다. 따라서 유한 개의 스칼라로 거의 완벽한 암기를 할 수 있다. 기억 회로에 HiPPO를 바꿔 끼워 넣음으로서 시간 거리가 먼 곳도 참조(long-term dependency)할 수 있게 된다.
2. 실시간 문제를 빠르게 푸는 법 제시: 실시간이 아닌 다항회귀는 이미 연구가 많이 된 분야이지만 실시간 피팅으로 넘어와서는 재밌는 목표를 세운다. 바로 이전 time step의 결과를 재활용 할 수 있지 않을까 하는 것이다. 맞추고자 하는 함수는 변하지 않으므로 충분히 가능한 목표이다. 실제로 K개의 점이 주어졌을 때, N차 다항회귀 방법의 시간복잡도는 O(K⋅N2+N3)이다. 역함수 연산에 따라 달라지긴 하지만 K는 매우 크다. 따라서 함수 차원 N은 심각한 제한을 받는다. 하지만 HiPPO는 이전 결과를 재활용하여 빠르게 푸는 법을 제시하면서 매우 큰 N 값을 가져갈 수 있다.
3. HiPPO가 LSTM, GRU 등의 generalization임을 수학적으로 보임: 특정한 다항함수 집합을 활용하면 LSTM, GRU를 도출할 수 있다.
4. HiPPO의 4가지 특징 발굴: 2번을 포함하여, 시간 스케일에 영향 안 받음, RNN에 삽입되면 gradient vanishing 문제 해결 가능, 에러값 상계의 4가지 특징이 존재하고 appendix에 증명이 돼있다. 2번과 시간 스케일에 영향을 안 받는다는 내용이 가장 중요하다. 10Hz로 학습한 모델에게 20Hz 인풋을 넣어줘도 똑같은 성능이 나온다는 뜻이다.
논문 읽다가 막히는 부분들, 이 글의 목표
논문의 본문에는 최대한 수학적인 요소를 빼려고 한 것 같다. 그래서 처음 읽었을 때 반도 이해 못한 느낌이 들었다. 질문들이 굉장히 많았다. 예를 들면
- μ는 내가 맘대로 설정할 수 없나?
- 왜 Time-scale에 invariant 한가?
- 왜 빠른가?
- 대체 A는 어디서 나왔나?
- HiPPO를 응용하려고 할 때 건드리면 안 되는 부분은 어디인가?
- Interpretable 한가?
등의 왜 질문들과
- GBT
- Measure, dot product
- SSMs
- c는 뭐고 c(t)는 뭐고 cn(t)는 뭐고, 얘를 t로 미분...?
등의 개념적인 질문들이 있었다. 이 글에선 위 질문들에 대한 나의 답을 공유하려고 한다.
HiPPO: High-order Polynomial Projection Operators

HiPPO 프레임워크는 3부분으로 나뉜다. projt, coeft, discretize이다. 시간에 따른 다항함수 계수들의 변화를 SSM으로 표현하여 재활용하는 구조다. SSM과 measure에 대해 잠깐 알아보고 각 연산자를 하나씩 알아볼 것이다.
SSMs; State-Space Models
공학의 많은 문제에서 미분방정식(ODE; Ordinary Differential Equation)이 등장하게 된다. 보통 시간에 따른 인풋 함수 fi(t)와 상태함수 fstate(t), 원하는 정보인 아웃풋 함수 fo(t)로 구성되어있다. 미분방정식은 보통 여러 개가 동시에 등장한다. 이들을 잘 정리하면 인풋, 아웃풋, 상태의 관계를 행렬로 표현할 수 있다. ODE의 행렬 표현이 SSM인 것이다.
미분방정식 여러 개를 행렬로 표현할 수 있다는 내용과 미분방정식의 이산화(discretize)에 대한 내용을 들고 갈 것이다.

ODE까지는 많이 접해봤어도 SSM은 생소할 수 있다. 나는 SSM을 4학년 자동제어를 와서야 배웠다. 친근한(?) ODE로부터 SSM을 만들어보자.
위는 전자공학에서 등장하는 SSM의 한 예시이다. 간단한 RLC 회로이다. 기계공학에서는 질량-스프링-댐퍼 모델에 대응된다. 인풋은 시간에 따른 전압 vi이고 우리는 캐패시터에 걸리는 전압 함수 vo를 구하려고 한다. 이를 위해 전류 i를 상태로 사용한다. R,L,C는 저항, 인덕터, 캐패시터의 저항값, 인덕턴스, 캐패시턴스를 의미한다. 상수값이다.
조금만 더 설명하자면 옴의 법칙에 따라 저항에 걸리는 전류, 전압은 V=IR의 비례 관계를 가질 수 있다. 인덕터, 캐패시터의 것들은 미분방정식으로 표현된다.
v=Ldidt ,i=Cdvdt .
그리고 위 회로에서 소자들의 전류값은 모두 같다.
서론이 길었지만 이들의 관계를 아래 두 식으로 정리해서 표현할 수 있다.
didt=−RiL−vcL+viLdvcdt=iC
멀리서 보면 뭔가 vi를 인풋으로 집어넣으면 i,vc가 이를 인코딩하는 구조 같다. 그리고 선형이다! (제곱이나 나누기 같은 것들이 없다.) 따라서 행렬로 표현할 수 있다.

열심히 정리하면 위와 같이 된다. 쓸데없이 복잡하게 표현한 것 같아도 generalization 하기 편하기 때문에 분석에 다양한 아이디어를 활용할 수 있다. 식들을 유심히 보면 아래와 같이 정리할 수 있다.

SSM이 완성되었다. u를 입력, x를 상태, y를 출력이라고 한다. A는 상태천이행렬(transition matrix)이다. SSM은 상태가 어떻게 바뀌는지 분석하기 위해 만들어진 모델이기 때문에 A가 가장 중요하다. A의 고윳값, 고유벡터 등을 분석하여 시스템의 특성을 미리 알 수 있다. 진동할지 발산할지 파형에 굴곡이 몇 개일지가 그 예이다. 이러한 시스템의 행동양상을 dynamics라고 한다.
자동차 시트를 예로 들면 과속방지턱을 넘을 때 우리 몸은 진동을 2번 정도 느낀 후 정지한 느낌을 받는다. SSM으로 진동 2번을 예측할 수 있다. A의 고윳값은 R=2√L/C이면 입력 전압이 갑자기 가해졌을 때(step function) 출력이 무한히 진동할 것이라고 알려준다.
요약하면, ODE가 선형적으로 표현된다면 SSM을 만들 수 있고, 여기서 나온 A는 상당히 많은 정보를 알려준다.
SSM과 ODE의 이산화(Discretization)
문제를 컴퓨터로 푼다면 함수가 아닌 함숫값들이 이산적으로 주어진다. 시간 간격 Δt>0에 대해 u[k]=u(kΔt),k∈{0,1,…}에 대한 상태 수열 x[k]를 이산화를 통해 도출할 수 있다. 미분의 정의부터 시작한다.
y′(t)=limΔt→0y(t+Δt)−y(t)Δt .
만약 시간 간격 Δt가 충분히 작다면 다음과 같이 표현할 수있다.
y′(t)≈y(t+Δt)−y(t)Δt .
그리고 정리하면,
y(t+Δt)≈y(t)+Δty′(t)
y(t+Δt)−y(t)≈Δty′(t)
y[k+1]−y[k]≈Δty′(t)
와 같이 표현된다. 보통 미분방정식은 y′(t)=f(t,y) 형태로 주어진다. 매 time step t=kΔt마다 f(t,y)는 대입만으로 쉽게 얻을 수 있으므로 차분값을 구해 미분방정식을 차분방정식으로 바꿀 수 있다. 이 과정이 이산화이며, 방금 소개한 방법은 forward Euler method라고 한다.
(TODO: Bilinear와 GBT는 나중에 소개하겠다.)
위키피디아에서 많은 부분을 참조했다.
Measure와 함수 간 내적
우리는 함수 f≤t의 가장 비슷한 다항함수 전사(projection) g(t)를 구하고싶다. 여기서 '비슷한'을 정의하기 위해 거리 개념이 필요하다. 함수를 벡터의 연장으로 생각해서 RMSE를 거리로 두면 되지 않을까? 하는 생각이 들 수도 있다. 어느 정도 맞다.
[0,∞)에서 정의된 실함수 f,g가 square integrable이라면 내적 ⟨f,g⟩L2=∫∞0f(x)g(x)dx는 힐베르트 공간을 형성할 수 있고, f의 norm은 ‖f‖L2=⟨f,f⟩1/2L2와 같다.
여기서 square integrable은 적분 ∫∞0|f(x)|2dx가 발산하지 않는다는 뜻이고, 힐베르트 공간은 함수들의 집합인데 이 공간의 함수들은 내적값이 크면 비슷하게 생겼다는 것을 알려준다. 거리는 ‖f−g‖L2로 정의할 수 있다.
한편, 우리는 함수를 다항함수로 표현하고 싶다. 만약 다항함수들이 orthogonal basis를 형성한다면 예쁘게 표현할 수 있겠다.
(잠깐 선형대수 복습: 기저(Basis)?)
우리는 선형대수에서 N차원 벡터 u를 orthogonal basis V={v0,v1,…,vN−1}이 span하는 선형공간에 전사를 시켜봤다. 다항함수간 전사를 이해하기 위해 이미 배운 내용에서 analogy를 갖고 갈 것이다.
u≈ˆu=c′0v0+c′1v1+⋯+c′nvn+⋯+c′N−1vN−1=N−1∑n=0c′nvn
이를 ˆu의 basis expansion이라고 한다. 기저가 orthogonal 하다는 뜻은 임의의 i,j∈[N]에 대해 ⟨vi,vj⟩=δij라는 뜻이다. 따라서 계수 c′n은 내적으로 쉽게 구할 수 있다.
⟨ˆu,vn⟩=⟨N−1∑i=0c′ivi , vn⟩=N−1∑i=0c′i⟨vi,vn⟩=0+0+⋯+c′n⟨vn,vn⟩+⋯+0=c′n‖vn‖2=:cn .
다시 돌아와서, 우리는 내적이 잘 정의됐다면 N개의 orthogonal한 다항함수 기저도 정의할 수 있다는 사실을 알게 됐다. (많은 게 생략됐지만 맥락은 비슷하다.) 임의의 두 다항함수를 곱해서 적분했을 때 0이 나오는 다항함수 집합을 찾으면 된다. 조금 어렵지만 이미 정의된 유명한 orthogonal한 다항함수 기저들이 많다. 라게르, 르장드르 다항함수가 그 예이다. 정확히는 내적을 위와같이 정의하면 안 되고 적절한 weight function w(x)를 같이 줘야 수직이 된다.
"적절한 w(x)를 줬다"는 것을 dμ(x)로 표현한다. 그러면 정의가 아래와 같아진다.

여기서 적절한 weight function을 주는 μ를 측도(measure)라고 한다. 측도는 내맘대로 정의하면 안 된다. 선택한 다항함수 기저를 orthogonal하게 만들 수 있도록 줘야 한다. 라게르를 예를 들면 w(x)=xαe−x,x≥0을 줘야 한다. 다항함수 기저를 고르면 적절한 측도는 알아서 정해진다. (르장드르의 경우는 w(x)=1/2이다. 특별한 측도가 없어도 orthogonal이다.)

한편, HiPPO에서는 이렇게 정해진 측도를 '시간에 따른 중요도'로 해석했다. 라게르 weight function에서 α=0으로 주면 w(x)=e−x가 된다. HiPPO에서는 이를 뒤집고 translate 해서 e−(t−x),x≤t로 만들었다. x가 t에서 멀어질수록 기하급수적으로 값이 작아진다. 이를 "오래됐을수록 기하급수적으로 덜 중요하다"라고 해석하는 식이다.
논문에서는 측도를 먼저 주면 다항함수가 따라오는 것처럼 표현이 돼서 혼란을 야기한다. 실상은 적절한 다항함수 기저를 먼저 골라놓고 이를 orthogonal 하게 만들기 위해 준 측도들이 time-domain에서 reasonable 하게 해석이 된 것이다.
projt와 coeft와 계수들의 Dynamics
우리는 f≤t를 다항함수 g(t)에 전사시키고 있다. g(t)는 orthogonal한 다항함수 기저 {gn}n<N이 span하는 공간의 한 원소이다. g(t)의 basis expansion의 계수 cn(t)는 내적으로 쉽게 구할 수 있다. 논문에서는 아예 내적으로 정의해버렸다.
f≤t≈g(t)cn(t):=⟨f≤t,gn⟩μ(t) .
위가 전사projt, 아래가 계수 구하기coeft 연산이다.
한편, 구한 계수들을 실시간으로 재활용하려면 계수들의 dynamics를 활용할 수 있다. Dynamics는 SSM을 구하면 쉽게 알 수 있다. SSM을 구하려면 ODE를 먼저 closed-form으로 구해야 한다. cn(t)의 정의로부터 closed-form ODE를 얻으려면 어떻게 해야 할까? 우리가 원하는 form은
c′n(t)=[function of cn(t) and t]
이다. 공교롭게도 함수간 내적은 적분으로 정의된다. 그럼 cn(t)의 정의를 미분하면? 좌변은 우리가 원하는 폼이 되고 우변은 계산을 조금만 거치면 cn(t),cn−1(t),…와 t로 표현할 수 있다. 다만 계산이 조금 복잡해서 appendix로 갔다. 우리는 연산자 정의만 살펴보고 appendix로 넘어갈 것이다. 맛을 잠깐 먼저 보자면,

이렇게 생겼다. 계산은 저자가 다 해놨으니 우리는 차근차근 따라가기만 하면 된다.


hippo(f)의 결과는 전사된 함수의 basis expansion의 계수인 N개의 스칼라이며, 그 계수는 매 시간 t마다 바뀐다. 왜냐하면 연산자들이 매 시간 t마다 바뀌기 떄문이다. 연산자가 바뀌었다는 뜻은 같은 인풋에 대해 다른 아웃풋을 내놓는다는 것이다. 실시간으로 연산자들이 바뀌는 이유는 μ(t)가 바뀌기 때문이다! 여기부터 어지럽다. Appeindix에서 머릿속의 빈공간을 채워보자.
Appendix 톺아보기
우리는 Appendix의 C와 D를 알아볼 것이다. C는 f로부터 c를 얻는 과정을 제너럴하게 설명하고, D는 scaled and shifted 르장드르 다항함수 기저와 측도(LegS)를 대입해서 직접 계산한다. D에서는 LegS말고도 여러 기저-측도 세트에서도 계산을 했지만 우리는 LegS만 살펴볼 것이다
Appendix C는 문제 정의부터 시작한다. (−∞,t]에서 정의된 time-vaying measure μ(t), 기저 함수 수열이 만드는 공간 G=span{g(t)n}n∈[N]와 연속함수 f:R≥0→R 가 주어졌을 때, HiPPO는 f로부터 다음을 만족하는 f의 optimal projection의 계수 c:R≥0→RN를 얻는 연산자로 정의한다.

HiPPO라는 상자에 f를 넣으면 c가 나온다!
우리는 이로부터 ODE ddtc(t)=Ac(t)+Bf(t),∈RN×N,B∈RN×1를 얻을 것이다.
C.1 각종 정의들
시간에 따른 중요도인 측도 μ가 주어졌을 때 {Pn}n∈N은 μ에 대해 orthogonal 한 다항함수 기저로 정의한다. ⟨Pi,Pj⟩μ=δij,i,j∈N이 된다는 뜻이다. 논문에서는 이렇게 써놨지만 실상은 순서가 반대다. 체비셰브, 라게르, 르장드르 같은 유명한 다항함수 기저를 먼저 채택하고 이들을 orthogonal하게 만드는 μ를 찾는 것이다.
한편 실시간 다항회귀에서는 실시간으로 정의역이 변한다. 시점 t에서 [0,t)이다. μ는 정의역 전체를 커버해야하기 때문에 μ도 같이 변한다. 이를 μ(t)로 표기한다. 다항함수는 측도랑 세트기 때문에 같이 변한다. 함수꼴이 변하는 건 아니고 간단한 scaling and shifting 정도만 한다. 이를 {P(t)n}n∈N으로 표기한다.
p(t)n는 P(t)n의 normalized 버전이다. Norm이 1이 되도록 norm으로 나눠준 것이다.
pn(t,x)=p(t)n(x)
로 정의한다. 그때그때 저자가 좀 더 명확하다고 생각하는 표현으로 바꿔 쓴다. 위키피디아 같은 곳에 정의된 P(t)n는 normalized 될 필요는 없다. 해당 논문에서는 p(t)n를 주로 사용한다.
여기에 tilting이라고 다항함수를 orthogonal로 만들어주기위한 장치가 하나 더 있는데 르장드르는 필요 없기 때문에 정의는 생략한다. 다만, 계속 등장하기 때문에 르장드르에서는 χ=1, ζ=1, gn=pn, ν=μ라고 생각해서 대입하면 된다.
그리고 ω(t,x)=ω(t)(x)는 μ가 주는 중요도 weight function이다.
C.2 전사와 계수
인풋은 실시간(online)으로 주어지는 f의 일부분(제한; restriction) f≤t(x)=f(x)x≤t이다. f:[0,∞)→R는 1번 이상 미분 가능한 함수이다.
아웃풋은 최적의 전사의 계수이다. orthogonal basis {pn}n∈[N]이 주어졌을 때,
cn(t):=⟨f≤t,g(t)n⟩μ(t)=∫fp(t)nω(t)=∫t0f(x)⋅p(t)n(x)⋅ω(t)(x)dx
로 계산한다. c의 정의이다. 아래에서 길을 잃으면 보통 얘 때문이다. c가 뭐지..? 하는 생각이 들면 여기로 다시 오면 된다.
계수를 구했으면 함수를 복원(reconstruct)할 수 있다.
f≤t≈g(t):=N−1∑n=0⟨f≤t,p(t)n⟩μ(t)p(t)n=N−1∑n=0cn(t)p(t)n
로 계산한다. 이 글 첫 움짤의 초록색이 g(t)이다. 바로 위 식이 projt의 실체이다. coeft는 여기서 c만 뽑아가는 연산자이다.
C.3 Coefficient Dynamics
위 내적으로부터 ODE를 만들어볼 것이다. t로 미분한다.
ddtcn(t)=∫f(x)(∂∂tpn(t,x))ω(t,x)dx+∫f(x)pn(t,x)(∂∂tω(t,x))dx
고등학교 때 배운 곱의 미분의 확장이다. 우변을 c와 t로 다시 표현하는 것이 핵심이다. 첫 번째 항부터 보면, 다항함수는 미분해도 다항함수고 기저에서 벗어나지 않는다. 따라서 p0,…,pn−1의 선형결합으로 표현할 수 있다. 그럼 c로 다시 표현할 수 있다.
두 번째 항이 문제이다. ω가 미분했을 때 이상한 form이 나오면 적분이 그대로 살아있을 거기 때문에 여기서 막힌다. 하지만 르장드르의 경우 ω′도 ω로 다시 표현이 가능하다. 따라서 c의 정의에 따라 c로 다시 표현이 가능하다.
따라서 ODE가 완성된다. 완성된 ODE는 D에서 확인한다. 그렇다면 ODE를 다음과 같이 정리할 수 있을 것이다.
ddtc(t)=−Ac(t)+Bf(t) .
이게 hippo 연산자의 실체이다. A앞의 −부호는 계산의 편의를 위해 해놓았다.
C.4 이산화
ddtc(t)=−Ac(t)+Bf(t)로부터 ck=−Ack−1+Bfk를 얻는다.
D. HiPPO 연산자 계산
대망의 D이다. D는 C에서 유도한 식에 각 다항함수 기저와 측도를 대입한다. 우리는 LegS만 살펴볼 것이다. 계산은 4단계이다. 1. 측도와 기저, 2. 다항함수 미분, 3. 계수 동역학, 4. 복원.
1. 측도와 기저


르장드르 다항함수 Pn은 [−1,1]에서 정의된다. [0,t]에서 사용하려면 위와 같은 약간의 scaling을 거쳐야 한다. (2n+1)12는 Pn의 norm이다. I는 rect function이다. [0,t]에서만 1이고 나머진 전부 0이라는 뜻이다.
오른쪽 그림의 파란색, 보라색은 서로 다른 t에 대한 ω이다. 이 그림이 이해된다면 넘어가도 된다.
2. 다항함수와 중요도 함수 미분
적분꼴을 없애려면 ω′는 ω로, g′은 g로 표현하는 것이 핵심이다. (c의 정의를 다시 보고 오면 된다.)

여기서 z=2xt−1로 놓으면,

가 된다. gn,gn−1,…로 재표현 됐다.
3. 계수 동역학
C.3에서 만든 식에 2.에서 계산한 미분 식을 대입하면 된다.

−1t로 묶으면 선형결합이기에 아래와 같은 행렬 표현이 가능하다.

여기 A가 이제 계속 등장할 HiPPO-LegS A이다. A앞의 − 부호에 주의한다.
4. 복원

이산화

위 GBT 계산 식에 집어넣으면 된다. 이마저도 미리 계산되어있다.
c(k+1)=(I−1k+1α(−A))−1(I+1k(1−α)(−A))c(k)+1k(I−1k+1α(−A))−1Bfk
논문에서는 A의 −가 고려 안 된 상태로 표기되어 있어 혼동을 유발한다. 구현을 위해 식 하나로 표기했다.
여기서 핵심은 HiPPO-LegS의 이산화는 Δt가 없다는 점이다. 다 자기끼리 나눠져서 없어진다. 따라서 time-scale에 invariant 하다.
코드
르장드르 다항함수는 기본적으로 계수들이 매우 작기 때문에 sympy 기반의 Legendre
클래스를 사용해 연산을 먼저 하고 함숫값은 나중에 eval 한다.
A_bwd
를 계산할 때 1k+1을 넣으면 값이 튄다. 1k를 넣었다.
문제점
아직 HiPPO가 완성되지 않았다. HiPPO의 큰 장점 중 하나는 N에 선형인 연산 속도이다. 반면 위의 이산화 식을 보면 O(N3)을 요구하는 역함수가 껴있다. 여기서도 저자의 천재성이 드러난다. 천재적인 방법으로 불필요한 연산을 줄였다.
'AIML' 카테고리의 다른 글
[Mamba 이해하기 1-1] HiPPO의 A 행렬의 중요성 | HiPPO: Recurrent Memory with Optimal Polynomial Projections, Gu et al., NeurIPS 2020 (0) | 2024.09.03 |
---|---|
[Mamba 이해하기 0] 개요 (0) | 2024.09.03 |