오늘 소개해드릴 논문은 Long-Context에서 효과적인 방법을 위한 새로운 메커니즘 infini-attention에 관한 내용입니다.
https://arxiv.org/abs/2404.07143
Abstract
이 연구에서는 트랜스포머 기반의 대규모 언어 모델(LLMs)을 무한히 긴 입력에 대해 제한된 메모리와 계산으로 확장할 수 있는 효율적인 방법을 소개합니다. 우리가 제안하는 접근 방식에서 핵심 요소는 Infini-attention이라고 불리는 새로운 주의 기법입니다. Infini-attention 은 기존의 주의 메커니즘에 압축 메모리를 통합하고 하나의 transformer 블록에서 masked-local-attention 과 long-term linear attention 메커니즘을 모두 구축합니다. 우리는 Long-Context 언어 모델링 벤치마크, 100만 시퀀스 길이의 패스키 컨텍스트 블록 검색, 그리고 50만 길이의 책 요약 작업에서 우리의 접근 방식의 효과를 입증합니다. 이는 10억과 80억 LLMs를 사용합니다. 우리의 접근법은 최소한의 제한된 메모리 매개변수를 도입하고 LLMs에 대한 빠른 스트리밍 추론을 가능하게 합니다.
Transformer의 특징인 attention 구조의 메커니즘을 새롭게 해서 Long context에서 효율적인 추론이 가능하다는 것
1 Introduction
기억력은 특정 맥락에 맞춰 효율적인 계산을 가능하게 하는 지능의 기초입니다. 그러나 Transformer 및 Transformer 기반의 대규모 언어 모델들은 attention 메커니즘의 특성 때문에 맥락에 따라 제한된 메모리를 가지고 있습니다. 이 논문에서 우리는 Transformer LLMs가 제한된 메모리와 계산만을 사용하여 무한히 긴 입력을 효과적으로 처리할 수 있는 새로운 접근 방식을 소개합니다. 우리가 제안하는 접근 방식에서 핵심 구성 요소는 Infini-attention이라는 새로운 주의 기법입니다. Infini-attention 은 기존 주의 메커니즘에 Compressive memory 도입하고 한 개의 Transformer 블록 내에서 masked-local attention 및 Long-term linear attention 구축합니다.
우리의 Infini-attention은 표준 주의 계산의 모든 키, 값, 질의 상태를 장기 기억 합성과 검색에 재사용합니다. 우리는 주의 메커니즘에서와 같이 오래된 KV 상태를 버리는 대신 Compressive memory 에 저장합니다. 그런 다음 후속 시퀀스를 처리할 때 attention Query 상태를 사용하여 메모리에서 값들을 검색합니다. 최종 맥락 출력을 계산하기 위해, Infini-attention은 장기 기억에서 검색된 값들과 로컬 주의 맥락들을 통합합니다.
우리의 실험에서, 우리의 접근 방식이 긴 맥락 언어 모델링 벤치마크에서 기준 모델을 능가하며, 메모리 크기 측면에서 114배의 이해 비율을 보여준다는 것을 보여줍니다. 모델은 10만 시퀀스 길이로 훈련될 때 더 나은 혼란도를 달성합니다. 10억 LLM은 자연스럽게 100만 시퀀스 길이로 확장되고 Infini-attention을 주입할 때 패스키 검색 작업을 해결합니다. 마지막으로, 우리는 Infini-attention을 장착한 80억 모델이 지속적인 사전 훈련 및 작업 미세 조정 후 50만 길이 책 요약 작업에서 새로운 SOTA(State of the Art) 결과를 달성한다는 것을 보여줍니다.
요약하면, 저희의 작업은 다음과 같은 기여를 합니다:
1. 실용적이면서도 강력한 주의 메커니즘인 Infini-attention을 도입했습니다. 장거리 및 단거리 콘텍스트 종속성을 효율적으로 모델링하기 위한 장기 압축 메모리와 로컬 인과적 주의 메커니즘을 도입했습니다.
2. Infini-attention은 표준 스케일 dot product attention 에 최소한의 변화를 도입하고 플러그 앤 플레이 방식의 지속적인 사전 학습과 장거리 콘텍스트 적응을 지원합니다.
3. 유니티의 접근 방식을 통해 트랜스포머 LLM은 무한히 긴 컨텍스트에 맞게 확장할 수 있습니다. 매우 긴 입력을 스트리밍 방식으로 처리하여 메모리 및 컴퓨팅 리소스가 스트리밍 방식으로 처리합니다.
Compressive memory & Linear attention을 그림에도 가지고 있어서 Query를 전달해 주고 KV 쌍의 이전 값들이 업데이트가 같이 되면서 Concat 하고 기존의 Dot-product-attention과 합쳐지는 형태인듯하다. 압축메모리와 리니어 어텐션으로 정보보존하고, 메모리효율, 스트리밍을 처리하게끔 하는 게 이 메커니즘의 특징 같다.
2 Method
그림 2는 우리의 모델인 인피니-트랜스포머(Infini-Transformer)와 트랜스포머-XL을 비교합니다. 트랜스포머-XL과 유사하게, 인피니-트랜스포머는 여러 개의 세그먼트로 구성된 시퀀스에서 작동합니다. 우리는 각 세그먼트 내에서 표준 인과적 점곱셈 주의 맥락을 계산합니다. 따라서 점곱셈 주의 계산은 현재 세그먼트의 토큰들을 총 N개까지만 고려하는 로컬 한 성격을 갖습니다(여기서 N은 세그먼트의 길이입니다).
그러나, Local attention 는 이전 세그먼트의 주의 상태를 다음 세그먼트를 처리할 때 폐기합니다. 인피니-트랜스포머에서는, 오래된 키-값(KV) 주의 상태를 제외시키는 대신, 압축 메모리를 사용하여 전체 맥락의 역사를 유지하기 위해 이들을 재사용할 것을 제안합니다. 따라서 인피니-트랜스포머의 각 주의 레이어는 글로벌 압축 상태와 로컬 세밀한 상태를 모두 갖고 있습니다. 우리는 이러한 효율적인 주의 메커니즘을 인피니-어텐션(Infini-attention)이라 부릅니다, 이는 그림 1에서 보이고 있으며, 다음 섹션에서 정식으로 설명됩니다.
2.1.1 Scaled Dot-product Attention
Multi-head Scaled Dot-product Attention 은 특히 self-attention 변형을 통해, 대규모 언어 모델(LLMs)에서 핵심 구성 블록으로 자리 잡았습니다. MHA(Multi-Head Attention)는 맥락에 의존하는 동적 계산을 모델링하는 강력한 능력과, 시간적 마스킹(temporal masking)의 편리함이 생성 모델에서 널리 활용되고 있습니다.
바닐라 MHA(Vanilla Multi-Head Attention)의 단일 헤드는 입력 세그먼트 X ∈ R^(Nxd_model) 시퀀스로부터 주의 맥락 A_dot ∈ R^(Nxd_value)을 계산합니다. 먼저, 주의 질의(attention query), 키(key), 밸류(value) 상태를 계산합니다:
여기서 W_k ∈ R^(d_model x d_key), W_v ∈ R^(d_model x d_value) 및 W_q ∈ R^(d_model x d_key)는 훈련 가능한 투영(projection) 행렬입니다. 그런 다음, 주의 맥락은 다른 모든 값들의 가중 평균으로 계산됩니다:
MHA의 경우, 각 시퀀스 요소에 대한 H 개의 주의 맥락 벡터를 병렬로 계산하고, 이들을 두 번째 차원을 따라 연결(concatenate)한 후, 최종적으로 모델 공간으로 투영하여 주의 결과를 얻습니다.
2.1.2 Compressive Memory
Infini-attention 에서는 새로운 메모리 항목을 계산하는 대신, 점곱셈 주의 계산에서 얻은 질의(query, Q), 키(key, K) 및 값(value, V) 상태를 재사용합니다. 이러한 상태 공유와 재사용은 플러그 앤 플레이(plug-and-play) 긴 맥락 적응뿐만 아니라 훈련과 추론 속도를 향상하는 데에도 효과적입니다. 우리의 목표는 압축 메모리에 키와 값의 상태를 저장하고, 질의 벡터를 사용하여 검색하는 것입니다.
문헌에서 제안된 압축 메모리의 다양한 형태들이 있지만, 이 작업에서는 연관 행렬(associative matrix)을 사용하여 메모리를 매개변수화하여 간단함과 계산 효율성을 추구합니다. 이 접근법은 또한 메모리 업데이트와 검색 프로세스를 선형 주의 메커니즘으로 해석하고 관련 메서드의 안정적인 훈련 기술을 활용하게 합니다. 특히, 그 간단함과 경쟁력 있는 성능 때문에 Katharopoulos 등(2020)의 업데이트 규칙과 검색 메커니즘을 채택했습니다.
메모리검색은 Q를 사용해서, 메모리Ms-1에서 새로운 내용인 Amem을 검색합니다. 여기서 σ와 Zs-1은 비선형 활성화 함수와 정규화 항. 훈련 안정성을 위해 중요한 비선형성과 정규화 방법을 선택하며, Katharopoulos 을 따라 모든 키에 대한 합계를 정규화 항으로 기록하고 활성화 함수로 요소별 ELU + 1을 사용합니다.
검색이 완료되면 새로운 KV 항목으로 메모리와 정규화 항을 업데이트하고 다음 상태를 얻습니다. 새로운 메모리 상태 Ms와 zs는 각 주의 레이어에서 재발성을 구축하며 다음 세그먼트 S + 1로 전달됩니다. 수식 (4)의 우변 항 σ(K)TV는 연관 결합 연산자로 알려져 있습니다.
우리는 인피니-어텐션에 델타 규칙(delta rule)을 통합하였습니다. 델타 규칙은 기존의 값 항목을 먼저 검색하고 새로운 값들에서 빼기 전에 연관 결합을 새로운 업데이트로 적용하여 메모리 업데이트를 약간 개선하려고 시도합니다.
로컬 주의 상태 Adot과 메모리에서 검색된 내용 Amem을 배운 게이팅 스칼라 β를 통해 통합합니다. 이는 헤드 당 단일 스칼라 값을 훈련 매개변수로 추가하면서 모델 내의 장기 및 로컬 정보 흐름 사이의 학습 가능한 균형을 허용합니다. 표준 MHA와 마찬가지로 Multi-head infini attention의 경우 병렬로 H개의 맥랑 상태를 계산하고 최종 attention 출력을 위해 이들을 연결하고 투영함. 여기서 Wo는 훈련 가능한 가중치입니다.
요약하면 인피니-어텐션은 기존의 점곱셈 주의 메커니즘에서 Q, K, V 상태를 재사용하여 압축 메모리에 저장하고 질의 벡터를 통해 검색합니다. 단순성과 계산 효율성을 위해 연관 행렬을 사용하여 메모리를 매개변수 화하며, 선형 주의 메커니즘으로 메모리 업데이트와 검색 프로세스를 해석합니다. 또한, 새로운 델타 규칙을 통해 메모리 업데이트를 약간 개선하고, 장기 맥락 주입을 통해 로컬 주의 상태와 메모리에서 검색된 내용을 통합합니다. 이러한 방법은 모델에서 장기 및 로컬 정보 흐름 사이의 균형을 학습할 수 있게 하며, 최종 주의 출력을 위해 다중 헤드 주의 맥락 상태를 병렬로 계산하고 통합합니다.
2.1.2 Memory and Effective Context Window
infini-transformer 는 제한된 메모리를 가지면서 무한한 맥락 창을 가능하게 합니다. 이를 설명하기 위해 표 1은 모델의 맥락-메모리 크기와 실질적인 맥락 길이를 모델 매개변수와 입력 세그먼트 길이에 따라 정의하여 이전 세그먼트 레벨 메모리 모델들을 나열하고 있습니다. 인피니-트랜스포머는 각 헤드의 단일 레이어에서 압축된 맥락을 저장하기 위해 Ms와 zs에 대한 상수 메모리 복잡성인 dkey × dvalue + dkey를 가집니다. 반면 다른 모델들은 시퀀스 차원과 함께 복잡성이 증가합니다 - 메모리 복잡성은 Transformer-XL, Compressive Transformer 및 Memorizing Transformers의 캐시 크기나 RMT 및 AutoCompressors의 소프트-프롬프트 크기에 따라 달라집니다.
Transformer-XL은 마지막 세그먼트에서 캐시 된 KV 상태를 현재 상태와 함께 계산하여 주의를 계산합니다. 이 작업이 각 레이어에서 수행되기 때문에 Transformer-XL은 추가 메모리 크기 (dkey + dvalue) × H × N × l로 N에서 N × l 토큰까지 맥락 창을 확장합니다. Compressive Transformer는 Transformer-XL에 두 번째 캐시를 추가하고 과거 세그먼트 활성화의 압축된 표현을 저장합니다. 그래서 Transformer-XL의 맥락 창을 c × r × l만큼 확장하지만 여전히 큰 맥락-메모리 복잡성을 가집니다. Memorizing Transformers는 입력 시퀀스의 맥락으로 전체 KV 상태를 저장하려고 합니다. 저장이 비실용적으로 비싸기 때문에, 그들은 맥락 계산을 단일 레이어로 제한합니다. 빠른 kNN 검색기를 사용하여, Memorizing Transformers는 저장 비용을 높이면서 길이 N × S의 전체 시퀀스 역사를 커버하는 맥락 창을 구축합니다. 우리의 실험은 인피니-트랜스포머 LM이 Memorizing Transformers보다 100배 이상의 압축률을 달성하면서도 혼란도(perplexity) 점수를 더 개선할 수 있음을 보여줍니다.
RMT와 AutoCompressors는 입력을 요약 벡터로 압축한 다음, 이를 후속 세그먼트에 대한 추가적인 소프트-프롬프트 입력으로 전달하기 때문에 잠재적으로 무한한 맥락 길이를 허용합니다. 그러나 실제로 이러한 기술의 성공은 소프트-프롬프트 벡터의 크기에 크게 의존합니다. 즉, AutoCompressors의 성능을 향상시키기 위해 소프트-프롬프트(요약) 벡터의 수를 늘려야 하며, 그에 따라 메모리와 계산 복잡성이 빠르게 증가하여 효율성이 저하됩니다. AutoCompressors에서는 효율적인 압축 목표가 그러한 프롬프트 압축 기술을 훈련하는 데 필요하다는 것도 관찰되었습니다.
3. Experiments
우리는 극도로 긴 입력 시퀀스를 포함하는 벤치마크에서 인피니-트랜스포머 모델을 평가했습니다: 긴 맥락 언어 모델링, 100만 길이 패스키 맥락 블록 검색 및 50만 길이 책 요약 작업. 언어 모델링 벤치마크의 경우, 우리는 모델을 처음부터 훈련시키고, 패스키와 책 요약 작업의 경우, 기존 LLM들을 지속적으로 사전 훈련시켜 우리 접근법의 플러그 앤 플레이 긴 맥락 적응 능력을 강조합니다.
3.1 Long-context Language Modeling
우리는 PG19와 Arxiv-math벤치마크에서 작은 infini-transformer 모델을 훈련하고 평가했습니다. 우리의 설정은 Memorizing Transformers(와 매우 유사합니다. 즉, 모든 모델은 각각 128의 차원을 가진 12개의 레이어와 8개의 주의 헤드 및 4096의 은닉층을 가진 FFN을 가지고 있습니다. 모든 attention layer 에 대해 infini-attention Segement 길이 N을 2048로 설정하고, 훈련을 위해 입력 시퀀스 길이를 32768로 설정합니다. 이를 통해 infini-attention 은 압축 메모리 상태에 대해 16 단계로 펼쳐질 수 있습니다. RMT 베이스라인의 경우, 요약 프롬프트 길이 50, 100, 150과 시퀀스 길이 4096, 8192, 32768에 대해 여러 번 실행했습니다. 100개의 요약 벡터를 가진 RMT는 8192 길이의 시퀀스에서 훈련할 때 최고의 결과를 내었습니다. 언어 모델링 실험에서 얻은 주요 결과는 표 2에 요약되어 있습니다. 우리의 인피니-트랜스포머는 Transformer-XL과 Memorizing Transformers 베이스라인을 능가하면서, 9번째 레이어에서 65K 길이의 KV 메모리를 기반으로 하는 Memorizing Transformer 모델보다 114배 적은 메모리 매개변수를 유지합니다.
100K 길이 훈련: 우리는 훈련 시퀀스 길이를 32K에서 100K로 늘려 Arxiv-math 데이터셋에서 모델을 훈련시켰습니다. 100K 훈련은 선형 및 선형 + 델타 모델에 대해 혼란도 점수를 2.21과 2.20으로 더 감소시켰습니다.
게이팅 점수 시각화: 그림 3은 각 레이어의 모든 주의 헤드에 대한 압축 메모리의 게이팅 점수, sigmoid(β)를 시각화합니다. 훈련 후 인피니-어텐션에 두 종류의 헤드가 나타났습니다: 게이팅 점수가 0이나 1에 가까운 전문화된 헤드와 점수가 0.5에 가까운 믹서 헤드입니다. 전문화된 헤드는 로컬 주의 계산을 통해 맥락 정보를 처리하거나 압축 메모리에서 검색하는 반면, 믹서 헤드는 현재 맥락 정보와 장기 메모리 내용을 함께 단일 출력으로 집약합니다. 흥미롭게도, 각 레이어는 적어도 하나의 단기 범위 헤드를 가지고 있어, 입력 신호를 출력 레이어까지 전파할 수 있습니다. 우리는 또한 전방 계산을 통해 장기 및 단기 내용 검색이 교차하는 것을 관찰했습니다.
3.2 LLM Continual Pre-training
기존 대규모 언어 모델들(LLMs)의 긴 맥락 적응을 위해 경량화된 지속적인 사전 훈련을 수행했습니다. 사전 훈련 데이터에는 4K 토큰 이상 길이의 PG19 및 Arxiv-math 코퍼스와 C4 텍스트(Raffel 등, 2020)가 포함되었습니다. 실험 전반에 걸쳐 세그먼트 길이 N은 2K로 설정되었습니다.
1M 패스키 검색 벤치마크: 우리는 10억 개의 LLM에서 바닐라 MHA를 인피니-어텐션으로 교체하고 4K 길이의 입력으로 사전 훈련을 계속했습니다. 모델은 64의 배치 크기로 30K 스텝 훈련한 뒤 패스키 검색 작업(Mohtashami & Jaggi, 2024)에 대한 미세 조정을 진행했습니다.
패스키 작업은 임의의 숫자를 긴 텍스트에 숨기고 모델 출력에서 이를 다시 묻습니다. 주의를 산만하게 하는 텍스트의 길이는 텍스트 덩어리를 여러 번 반복함으로써 변화됩니다. 이전 연구(Chen 등, 2023a)는 8B LLaMA 모델이 같은 32K 길이의 입력으로 미세 조정할 때 32K 길이까지 작업을 해결할 수 있음을 보여주었습니다. 우리는 이 도전을 더 나아가 5K 길이의 입력으로만 미세 조정하여 1M 길이 체제에서 테스트합니다.
500K 길이 책 요약(BookSum): 우리는 8K 입력 길이로 8B LLM 모델을 지속적으로 사전 훈련하며 30K 스텝으로 접근 방식을 확장했습니다. 그런 다음 책 요약 작업인 BookSum(Kryści ´ nski 등, 2021)에 대해 미세 조정을 진행했습니다. 여기서 목표는 전체 책 텍스트의 요약을 생성하는 것입니다.
미세 조정을 위해 입력 길이를 32K로 설정하고 평가를 위해 500K로 늘렸습니다. 우리는 생성 온도를 0.5, topp를 0.95로 설정하고 각 책의 요약을 생성하기 위해 디코딩 단계 수를 1024로 설정했습니다.
표 4는 특히 요약 작업(Lewis 등, 2019; Xiao 등, 2021)을 위해 구축된 인코더-디코더 모델들과 그들의 검색 기반 긴 맥락 확장(Bertsch 등, 2024)과 우리 모델을 비교합니다. 우리 모델은 이전 최고의 결과를 능가하고 전체 책 텍스트를 처리함으로써 BookSum에서 새로운 SOTA를 달성합니다. 우리는 또한 BookSum 데이터의 검증 분할에서 전체 Rouge 점수를 그림 4에 표시했습니다. 책에서 제공되는 텍스트가 많을수록 우리 인피니-트랜스포머는 요약 성능 메트릭을 향상하는 명확한 경향을 보여줍니다.
인피니-트랜스포머 모델은 지속적인 사전 훈련을 통해 긴 맥락에 적응하도록 설계되었으며, 긴 텍스트의 패스키 검색 작업과 책 요약 작업에서 탁월한 성능을 보여주었습니다. 1M 길이의 패스키 검색에서는 5K 길이의 입력에 대한 미세 조정 후 1M 맥락 길이에서 작업을 해결했습니다. 500K 길이의 책 요약에서는 모델이 BookSum 벤치마크에서 SOTA를 달성하며, 책에서 더 많은 텍스트를 입력으로 제공할수록 요약 성능이 향상되는 경향을 보였습니다.
4. Related Work
압축 메모리, 긴 맥락 지속적인 사전 훈련, 효율적인 주의 기술은 인피니-트랜스포머와 관련된 주요 연구 주제들입니다. 압축 메모리는 고정된 매개변수 수를 유지하며 효율적인 계산을 가능하게 하는 반면, 주의 메커니즘은 긴 입력 시퀀스를 처리할 때 맥락 창을 제한할 수 있습니다. 지속적인 사전 훈련은 긴 맥락을 다루기 위해 주의 층을 확장하고 LLM을 계속해서 훈련시키는 방법을 탐색합니다. 효율적인 주의 기술은 점곱셈 주의를 근사화하거나 시스템 수준에서 최적화하여 계산 비용을 줄이고자 합니다. 이러한 기술은 긴 시퀀스를 처리할 수 있게 하면서도 하드웨어 아키텍처를 최대한 활용하여 성능을 개선하려는 연구들로 이어지고 있습니다.
5. Conclusion
이 연구는 대규모 언어 모델이 제한된 메모리와 계산 자원을 사용하여 무한히 긴 맥락을 처리할 수 있도록 기존의 점곱셈 주의 계층에 압축 메모리 모듈을 통합합니다. 우리의 접근법은 긴 맥락 언어 모델링과 책 요약 작업에서 기존의 기준 모델들을 능가하며, 입력 시퀀스의 백만 길이 체제까지 자연스럽게 확장 가능함을 보여줍니다. 또한, 단 5K 길이의 시퀀스로 미세 조정된 10억 크기의 모델이 100만 길이의 문제를 해결할 수 있음을 보여주며, 이는 우리 모델의 길이 일반화 능력이 유망함을 나타냅니다.
총평 : infini attention 분명 Long Context 에 강하다. 압축메모리와 리니어 어텐션으로 이전기억을 저장하고 선형적인 로드를 통해 합치는 구조인데 모델의 복잡성이랑 오버헤드는 올라갔을 것이다. 트레이드오프관련해서도 문제가 될 것이다 Compressive memory에 과연 정보손실이 없을지도, 써봐야 알겠지만 얼마나 이전 Dot product attn 만큼의 일반화적인 성능이 나올지 궁금하긴 하다. 또 결국 이 구조가 Long context에 유리하기 때문에, 짧은 데이터에 대한 학습의존도는 어떻게 할지 궁금하다.