본문 바로가기
Paper review

Jamba:A Hybrid Transformer-Mamba Language Model 리뷰

by AI미남홀란드 2024. 4. 2.
728x90

 

 

Jamba: A Hybrid Transformer-Mamba Language Model

We present Jamba, a new base large language model based on a novel hybrid Transformer-Mamba mixture-of-experts (MoE) architecture. Specifically, Jamba interleaves blocks of Transformer and Mamba layers, enjoying the benefits of both model families. MoE is

arxiv.org

 

 

ai21labs/Jamba-v0.1 · Hugging Face

Model Card for Jamba Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It delivers throughput gains over traditional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks. Jamba is t

huggingface.co

 

abstract

 

Jamba는 새로운 하이브리드 Transformer-Mamba 혼합 전문가(MoE) 구조를 기반으로 한 대규모 언어 모델을 제시합니다. 이 모델은 Transformer와 Mamba 레이어의 블록을 교차 배치하여 두 모델 가족의 장점을 활용합니다. 일부 레이어에 MoE를 추가해 모델 용량을 증가시키면서 활성 파라미터 사용을 효율적으로 관리합니다. 이러한 유연한 구조는 자원 및 목적에 특화된 구성을 가능하게 합니다. 구현된 특정 구성에서는 단일 80GB GPU에 적합한 강력한 모델을 완성합니다. 대규모로 구축될 때, Jamba는 기존 Transformer 모델에 비해 높은 처리량과 작은 메모리 사용량을 제공하며, 동시에 표준 언어 모델 벤치마크와 긴 문맥 평가에서 최신 성능을 달성합니다. 눈에 띄게, 모델은 최대 256K 토큰 문맥 길이에 대해 강력한 결과를 제시합니다. 다양한 아키텍처 결정을 연구하고, Transformer와 Mamba 레이어를 결합하는 방법 및 전문가를 혼합하는 방법을 고려하여, 이러한 결정 중 일부가 대규모 모델링에서 중요하다는 것을 보여줍니다. 또한, Jamba의 훈련 및 평가를 통해 밝혀진 이 아키텍처의 여러 흥미로운 특성을 설명하고, 이 새로운 구조에 대한 추가 탐색을 장려하기 위해 다양한 어블레이션 실행에서 체크포인트를 공개할 계획입니다. Jamba의 구현에 사용된 가중치를 관대한 라이선스 하에 공개할 예정입니다.

 

MoE 구조를 구축하는 것도 어렵다는 말이 있는데, Transformer 와 mamba layer를 같이 결합해서 위 모델을 만들었다는 게 놀라웠다. 오픈소스와 체크포인트를 다 공개한다는 게 기술발전 오픈소스의 선순환이구나 생각을 했다. 나는 잘 가져다가 쓰면.. 그걸로 된 거다..ㅜㅜ

 

1 Introduction

Jamba는 새로운 대규모 언어 모델로, Transformer 레이어최신 상태 공간 모델인 Mamba 레이어, 그리고 혼합 전문가(MoE) 구성요소를 결합한 혁신적인 하이브리드 아키텍처를 기반으로 합니다. 이로 인해 Jamba는 개선된 성능과 높은 처리량을 유지하면서 관리 가능한 메모리 사용량을 제공하는 두 가지 서로 다른 아키텍처 설계를 결합합니다. 공개된 7B 기반 Jamba 모델(활성 파라미터 12B, 총 사용 가능 파라미터 52B)은 단일 80GB GPU에 맞도록 설계되었지만, 하드웨어와 성능 요구에 따라 Jamba 아키텍처를 다른 설계 선택으로 지원합니다.

Jamba의 주된 혁신은 하이브리드 Transformer-Mamba 아키텍처에 있습니다. Transformer 아키텍처는 높은 메모리와 계산 요구 사항으로 인해 긴 문맥 처리에 어려움을 겪습니다. 반면, RNN 모델들은 이러한 제한이 없지만, 훈련이 비용이 많이 들고 장거리 관계를 제한적으로만 포착합니다. Mamba와 같은 최근의 상태 공간 모델(SSM)은 RNN보다 효율적으로 훈련되며 장거리 관계를 더 잘 처리하지만, 비슷한 크기의 Transformer 언어 모델의 성능에는 뒤처집니다. Jamba는 Transformer와 Mamba 레이어를 결합하여 메모리 사용량, 효율적인 훈련, 그리고 긴 문맥 처리 능력 간의 균형을 맞춥니다.

또한, Jamba는 MoE 레이어를 포함하여 모델 용량(총 사용 가능 파라미터 수)을 늘리면서도 계산 요구 사항(활성 파라미터 수)을 증가시키지 않습니다. MoE는 극도로 큰 모델을 효과적으로 훈련시키는 유연한 접근 방식을 제공합니다. Jamba의 구현에서는 MoE를 매 다른 레이어마다 적용하고, 각 토큰마다 상위 2개의 전문가를 사용합니다.

Jamba는 다양한 벤치마크에서 평가되었으며, 비슷한 파라미터 수를 가진 Mixtral-8x7B와 비교하여 비슷하거나 더 큰 Llama-2 70B 모델에 비교할 때 경쟁력 있는 성능을 보여줍니다. 특히, Jamba는 256K 토큰의 문맥 길이를 지원하며, 긴 문맥 평가에서 대부분의 데이터셋에서 Mixtral을 능가합니다. Jamba는 특히 긴 문맥에서 Mixtral-8x7B의 3배에 달하는 처리량을 가지며, 128K 토큰 이상의 문맥에서도 단일 GPU(8비트 가중치 사용)에 맞습니다.

 

장거리관계 즉 긴 콘텍스트를 문맥처리하는 어려움을 mamba로 극복을 한 듯하고, MoE 전문가 구조를 통해서 메모리 효율과 성능을 챙긴 듯하다. jamba는 현재 공개된 opensource 모델 중에 가장 큰 256k를 지원하는 듯하다(정확하진 않음). 이제 오픈소스도 RAG를 효율적으로 쓸 수 있게 될지?

 

그림 1: (a) 단일 잠바 블록. (b) 다양한 유형의 레이어. 여기에 표시된 구현은 l = 8, Attention 대 맘바 레이어의 비율은 a : m = 1 : 7이며, MoE는 모든 e = 2 레이어에 적용됩니다.

 

중요 공지: 공개된 잠바 모델은 사전 학습된 기본 모델로, 정렬이나 인스트럭션 튜닝을 거치지 않았습니다. 정렬 또는 인스트럭션 튜닝을 거치지 않았으며 중재 메커니즘이 없습니다. 프로덕션 환경이나 추가 조정 없이 최종 사용자와 함께 사용해서는 안 됩니다.

 

2 Model Architecture

Jamba는 Transformer 레이어와 Mamba 레이어(최신 상태 공간 모델, SSM), 그리고 혼합 전문가(MoE) 모듈을 결합한 하이브리드 디코더 아키텍처입니다. 이 세 요소의 결합을 'Jamba 블록'이라 부르며, 그림 1에서 이를 시각적으로 볼 수 있습니다. Transformer, Mamba, MoE 요소의 결합은 낮은 메모리 사용, 높은 처리량, 그리고 높은 품질 사이의 때때로 상충되는 목표를 균형 잡기 위한 유연성을 제공합니다.

메모리 사용 측면에서, 모델 파라미터의 총크기를 비교하는 것은 오해의 소지가 있을 수 있습니다. MoE 모델에서는 주어진 순전파 단계에서 참여하는 활성 파라미터의 수가 총 파라미터 수보다 훨씬 적을 수 있습니다. KV 캐시, 즉 문맥에서의 attention 키와 값들을 저장하기 위해 요구되는 메모리는 또 다른 중요한 고려 사항입니다. Transformer 모델을 긴 문맥으로 확장할 때, KV 캐시는 제한 요소가 됩니다. attention 레이어를 Mamba 레이어로 대체하는 것은 KV 캐시의 총 크기를 줄입니다. 우리의 아키텍처는 활성 파라미터의 작은 수뿐만 아니라 일반 Transformer에 비해 8배 작은 KV 캐시를 목표로 합니다. 표 1은 256K 토큰 문맥에서도 작은 KV 캐시를 유지하면서 Jamba의 이점을 보여줍니다.

 

표 1: 긴 컨텍스트에서 사용 가능한 총 매개변수, 활성 매개변수, KV 캐시 메모리 측면에서 잠바와 최신 오픈 모델 비교 매개변수 및 긴 컨텍스트에서의 KV 캐시 메모리 비교. Jamba는 KV 캐시 메모리 요구 사항을 크게 줄여줍니다.

 

처리량 측면에서, 짧은 시퀀스의 경우, attention 연산은 추론 및 훈련 FLOPS의 작은 부분을 차지합니다. 하지만 긴 시퀀스의 경우, attention 연산이 대부분의 계산을 차지합니다. 반면, Mamba 레이어는 더 계산 효율이 좋습니다. 따라서 Mamba 레이어의 비율을 늘리면 특히 긴 시퀀스에 대한 처리량이 개선됩니다.

주된 구성에 대한 설명은 성능과 효율성이 개선된 것을 제공합니다. 섹션 6에는 설계 선택을 지원하는 어블레이션 실험 결과가 포함되어 있습니다. 기본 구성 요소는 Jamba 블록으로, 이는 순차적으로 반복될 수 있습니다. 각 Jamba 블록은 Mamba 또는 attention 레이어의 조합이며, 각 레이어에는 attention 또는 Mamba 모듈이 뒤따르는 MLP가 포함됩니다.
그림 1(b)에서 보여주는 다른 가능한 레이어 타입은 Jamba 블록이 l 레이어를 포함하며, 이들은 a : m의 비율로 섞입니다. 즉, m Mamba 레이어마다 a attention 레이어가 있습니다.

Jamba에서는 일부 MLP가 MoE 레이어로 대체될 수 있으며, 이는 모델 용량을 늘리는 동시에 활성 파라미터 수와 따라서 계산을 작게 유지하는 데 도움이 됩니다. MoE 모듈은 매 e 레이어마다 MLP에 적용될 수 있습니다. MoE를 사용할 때, 각 레이어마다 n개의 가능한 전문가가 있으며, 라우터가 각 토큰에서 상위 K 전문가를 선택합니다. 요약하자면, Jamba 아키텍처의

다양한 자유도는 다음과 같습니다:

  • l: 레이어의 수.
  • a : m: Attention-to-Mamba 레이어의 비율.
  • e: 단일 MLP 대신 MoE를 사용하는 빈도.
  • n: 레이어 당 전문가의 총 수.
  • K: 각 토큰에서 사용되는 상위 전문가의 수.

이 설계 공간을 감안할 때, Jamba는 특정 속성을 다른 것보다 선호하는 유연성을 제공합니다. 예를 들어, m을 증가시키고 a를 감소시키는 것, 즉 Attention 레이어 대신 Mamba 레이어의 비율을 증가시키는 것은 키-값 캐시를 저장하기 위해 필요한 메모리를 줄입니다. 이는 특히 긴 시퀀스를 처리할 때 중요합니다. Mamba 레이어의 비율을 증가시키면 처리량도 특히 긴 시퀀스에서 개선됩니다. 하지만 a를 감소시키면 모델의 능력이 낮아질 수 있습니다.

또한, n, K, 그리고 e의 균형을 맞추는 것은 활성 파라미터와 총 사용 가능 파라미터 사이의 관계에 영향을 미칩니다. n이 크면 메모리 사용량을 증가시키는 대신 모델 용량이 커집니다. K가 클수록 활성 파라미터 사용과 계산 요구를 증가시킵니다. 반대로, e가 크면 모델 용량이 감소하면서 계산(특히 K>1일 때)과 메모리 요구 사항을 줄이고, 전문가 병렬 훈련 및 추론 중 메모리 전송과 GPU 간 통신을 감소시켜 통신 의존성을 줄입니다.

Jamba의 Mamba 레이어 구현은 대규모 모델 스케일에서 훈련을 안정화하는 데 도움이 되는 여러 정규화를 포함합니다. 특히, 우리는 Mamba 레이어에 RMSNorm을 적용합니다.

Mamba 레이어를 사용할 때, 위치 임베딩 또는 RoPE와 같은 메커니즘은 필요하지 않으며, 따라서 우리는 명시적인 위치 정보를 사용하지 않습니다.

아키텍처의 다른 세부 사항들은 표준을 따릅니다. 여기에는 GQA, SwiGLU 활성화 함수, MoE의 부하 균형 등이 포함됩니다. 어휘 크기는 64K이며, 토크나이저는 BPE로 훈련되고 각 숫자는 별도의 토큰입니다. 우리는 Llama와 Mistral 토크나이저에서 사용된 더미 공간을 제거하여 더 일관되고 되돌릴 수 있는 토크나이즈를 제공합니다.

 

Attention을 mamba로 대체하면 KV 캐시 때문에 메모리가 많이 절약되고, 긴 Context에서도 효과적이라고 한다. 그러나 Attention 비율에 따라 모델 성능이 좌우되는 듯하다. mamba block 단위로 모델을 구성하는 것 같은데. 조금 설명이 모호한 느낌이었다. 개별의 구조에 개선된 걸 설명하다 보니. 밑에서는 자유도에 따라 파라미터 성능이 어떻게 된다 이런 말이라 조금 헷갈렸다. 단일 블록 구조에서는 그림처럼 l layer=8 , a attention 1 : m Mamba Layer 7로 구성된 모습이고,  e MoE layer는 2번마다 적용이 되어있다. n, k의 균형에 따라 파라미터 영향을 미친다는 뜻인 것 같다.

3 Reaping the Benefits

Jamba는 단일 80GB GPU에서 최적의 성능을 발휘하도록 설계되었으며, 이는 품질과 처리량의 관점에서 최고의 성능을 의미합니다. 구현된 모델은 4개의 Jamba 블록을 연속으로 가지며, 각 Jamba 블록은 다음과 같은 구성을 가집니다:

  • l = 8: 층의 수.
  • a : m = 1 : 7: 주의 레이어와 Mamba 레이어의 비율.
  • e = 2: 단일 MLP 대신 MoE를 사용하는 빈도.
  • n = 16: 전문가의 총 수.
  • K = 2: 각 토큰에서 사용되는 상위 전문가의 수.

a : m = 1 : 7 비율은 품질 측면에서 가장 계산 효율적인 변형 중 하나로, 예비 어블레이션 실험을 통해 결정되었습니다(6절 참조).

전문가 구성은 단일 80GB GPU에서 int8 가중치를 사용하여 모델을 적합하게 하면서 입력에 대한 충분한 메모리를 포함하도록 선택되었습니다. 특히, n과 e는 층 당 평균 약 8개의 전문가를 가질 수 있도록 균형을 맞추었습니다. 또한, n, K, 그리고 e를 균형 있게 조정하여 높은 품질을 유지하면서 계산 요구 사항과 통신 의존성(메모리 전송)을 제한했습니다. 이에 따라, 모든 다른 층에서 MLP 모듈을 MoE로 대체하고 총 16개의 전문가 중 각 토큰마다 2개를 사용하기로 했습니다. 이러한 선택은 이전의 MoE 작업에서 영감을 받아 예비 실험에서 검증되었습니다.

그림2. A100 GPU 에서 컨텍스트 길이

그림 2는 단일 80GB GPU를 사용하여 Jamba 구현이 가능한 최대 문맥 길이를 Mixtral 8x7B와 Llama-2-70B와 비교하여 보여줍니다. Jamba는 Mixtral의 2배, Llama-2-70B의 7배에 달하는 문맥 길이를 제공합니다.

Jamba 구현은 최대 1M 토큰의 문맥 길이에서 성공적으로 훈련되었으며, 공개된 모델은 최대 256K 토큰의 길이를 지원합니다.

 

처리량 분석을 위해, 구체적으로 두 가지 설정에서의 처리량 결과를 제시합니다. 첫 번째 설정에서는 다양한 배치 크기, 단일 A100 80GB GPU, int8 양자화, 8K 문맥 길이, 512 토큰의 출력을 가집니다. 그림 3a에 나타나 있듯이, Jamba는 큰 배치 처리를 가능하게 하여, 비슷한 활성 파라미터 수를 가진 Mixtral보다 3배 높은 처리량을 달성합니다.

두 번째 설정에서는 단일 배치, 4개의 A100 GPU, 양자화 없음, 다양한 문맥 길이, 512 토큰의 출력을 가집니다. 그림 3b에 나타나 있듯이, 짧은 문맥 길이에서 모든 모델이 비슷한 처리량을 가지지만, Jamba는 긴 문맥에서 두각을 나타냅니다; 128K 토큰에서 그 처리량은 Mixtral의 3배입니다. 이는 커뮤니티가 지난 6년 동안 순수 Transformer 모델에 대해 개발한 최적화를 Jamba가 아직 활용하지 않았음에도 불구하고 그렇습니다. Jamba에 대한 이러한 최적화가 개발됨에 따라 처리량 격차는 더욱 커질 것으로 기대할 수 있습니다.

 

Context 길이뿐만 아니라, 처리량에서 Transformer 구조를 압도하는 모습이다. 

 

4 Training Infrastructure and Dataset

Jamba 모델은 NVIDIA H100 GPU를 사용하여 훈련되었습니다. 훈련 과정은 FSDP(Fully Sharded Data Parallelism), 텐서 병렬 처리, 시퀀스 병렬 처리, 그리고 전문가 병렬 처리를 포함한 대규모 훈련을 효율적으로 수행할 수 있는 회사 내부의 독자적인 프레임워크를 사용하였습니다. Jamba는 웹, 책, 그리고 코드에서 가져온 텍스트 데이터를 포함하는 회사 내부 데이터셋에서 훈련되었습니다. 이 데이터는 2024년 3월을 마지막으로 업데이트되었습니다. 데이터 처리 파이프라인에는 품질 필터와 중복 제거가 포함되어 있습니다. 이렇게 함으로써, 데이터셋의 질을 보장하며, 모델이 보다 다양하고 정제된 데이터에서 학습할 수 있도록 합니다. 모델 훈련에 사용된 인프라와 데이터셋은 Jamba의 성능과 효율성을 최적화하는 데 중요한 역할을 하며, 최신 하드웨어와 전략적인 데이터 관리를 통해 이루어집니다. 

 

5 Evaluation

Jamba는 표준 학술 벤치마크를 포함한 다양한 평가에서 그 성능을 보여줍니다. 벤치마크는 실제 응용 프로그램에서 중요한 것과는 부분적으로만 상관관계가 있으며, 자칫 이를 높이기 위해 체계를 조작하는 것을 초래할 수 있지만, 그럼에도 불구하고 몇 가지 지표 결과를 제시합니다.

5.1 Academic Benchmarks

Jamba는 상식 추론, 독해 이해력 등을 평가하기 위한 범위 넓은 표준 학술 벤치마크에서 좋은 결과를 보고합니다. 이러한 벤치마크에는 HellaSwag, WinoGrande, ARC-E와 ARC-Challenge, PIQA(상식 추론), BoolQ와 QuAC(독해 이해력), GSM8 K, HumanEval, NQ, 그리고 TruthfulQA(기타) 등이 포함되며, 이 외에도 MMLU와 BBH와 같은 종합 벤치마크 결과도 포함됩니다.

 

표 2는 Jamba를 Llama-2 13B, Llama-2 70B, Gemma, 그리고 Mixtral과 같은 다른 공개 모델들과 비교합니다. Jamba는 유사하거나 더 큰 크기의 선도적인 공개 모델들과 비교하여 비슷하거나 더 나은 성능을 달성합니다. 이는 Llama-2 70B와 Mixtral 모델을 포함합니다. 동시에, Jamba의 총 사용 가능한 파라미터 수는 Llama-2 70B에 비해 작습니다(52B 대 70B). 또한, 희소 모델인 Jamba는 Mixtral의 12.9B 활성 파라미터와 유사한 12B의 활성 파라미터만을 가지고 있습니다. 그러나 Mixtral은 완전 주의 기반 모델로 긴 시퀀스에서 큰 메모리 사용량을 가지며, 256K 토큰에서 KV 캐시에 32GB가 필요합니다. 반면, 하이브리드 주의-Mamba 아키텍처 덕분에 Jamba의 KV 캐시는 긴 문맥에서도 오직 4GB만을 사용합니다(섹션 2). 중요한 것은, Jamba가 Llama-2 70B와 Mixtral보다 훨씬 더 나은 처리량을 가지면서, 최대 3배 개선된 성능을 달성하는 동안 강력한 성능을 나타낸다는 점입니다

 

Jamba는 동일한 크기 클래스의 최신 Transformer 기반 모델의 성능에 도달할 수 있는 하이브리드 아키텍처의 능력을 시연하면서, SSM(상태 공간 모델, 맘바의 아키텍처)의 이점을 갖추고 있음을 보여준다.

 

5.2 Long-Context Evaluations

Jamba 모델은 최대 1백만(1M) 토큰의 문맥 길이로 훈련되었으며, 공개된 모델은 최대 25만 6천(256K) 토큰의 문맥 길이를 처리할 수 있습니다. 이 섹션에서는 Jamba의 긴 문맥 처리 능력을 평가하기 위해 합성적이고 자연스러운 벤치마크를 사용하여 테스트합니다.

긴 문맥 평가는 언어 모델이 긴 문맥 정보를 효과적으로 이해하고, 이를 기반으로 적절한 예측을 수행하는 능력을 측정합니다. 이는 특히 소설이나 기술 문서, 코드 등과 같이 긴 문맥 정보가 필수적인 문서를 처리할 때 중요합니다.

합성적 벤치마크는 특별히 설계된 테스트를 통해 모델이 어떻게 긴 문맥 정보를 처리하고, 이를 기반으로 결론을 도출하는지를 평가합니다. 자연스러운 벤치마크는 실제 세계의 데이터를 기반으로 하여 모델이 실제 상황에서 얼마나 잘 작동하는지를 평가합니다.

긴 문맥 평가를 통해, Jamba가 실제 사용 사례에 얼마나 적합한지, 그리고 기존 모델들과 비교했을 때 얼마나 큰 개선이 이루어졌는지를 확인할 수 있습니다. Jamba가 실제 애플리케이션에서 사용될 준비가 되어 있는지, 특히 복잡하고 정보가 많은 시나리오에서 그 성능을 발휘할 수 있는지를 이해하는 것은 모델의 실용성을 평가하는 데 매우 중요합니다.

 

5.2.1 Needle-in-a-haystack

"Needle-in-a-haystack" 평가에서 Jamba는 긴 문맥 창에서 간단한 진술을 검색하는 데 뛰어난 성능을 보였습니다. 특히 Jamba의 구현이 단 4개의 주의 레이어만을 사용함에도 이러한 결과를 달성한 점이 주목할 만합니다.

 

5.2.2 Naturalistic long-context evaluation

"Naturalistic long-context evaluation"에서는 긴 입력을 포함하는 질문-응답 벤치마크를 사용하여 Jamba의 긴 문맥 처리 능력을 평가합니다. 이를 위해 L-Eval에서 가장 긴 문맥 데이터셋 다섯 가지를 몇 샷 형식으로 구조화하여 재활용했습니다(여기서 진행된 모든 실험에서 3샷을 사용). 구체적으로, 다음 데이터셋에서 모델을 평가했습니다: NarrativeQA(이야기에 대한 QA), LongFQA(금융), Natural Questions(NQ; 위키백과), CUAD(법률), 그리고 SFiction(과학 소설). 이 데이터셋에서 평균 입력 길이는 6K에서 62K 토큰까지 다양합니다. 이러한 문맥 길이는 몇 샷 형식으로 더 확장됩니다.

표 3은 평가 결과를 F1 점수 측면에서 요약합니다. Jamba는 대부분의 데이터셋에서 Mixtral을 능가하며 평균적으로도 더 높은 성능을 보입니다. 또한, 이러한 긴 문맥 작업은 상당한 계산을 요구하기 때문에, 긴 문맥에서 Jamba의 효율성이 빛을 발하며, 훨씬 더 나은 처리량을 보여줍니다(3.2 섹션 참조).

 

Needle-in-a-haystack" 평가를 통해 Jamba는 최대 256K 토큰 길이의 문맥 중간에 배치된 진술을 회상하는 능력을 보여줍니다. 이 그림은 Jamba가 긴 문맥 속에서 중요한 정보를 추출해 낼 수 있는 능력을 시각적으로 나타내며, 문맥 길이가 증가함에 따라 검색 성공률이 어떻게 변화하는지를 보여줍니다.

이러한 평가들은 Jamba가 긴 문맥에서 정보를 처리하고 접근할 수 있는 능력이 뛰어나다는 것을 입증하며, 특히 주의를 기반으로 한 다른 모델들과 비교했을 때 더 나은 성능과 효율성을 가진다는 점을 강조합니다.

샷 형식의 긴 컨텍스트 QA 벤치마크 결과(F1)

 

6 Ablations and Insights

 

Jamba 아키텍처의 다양한 설계 선택에 대한 어블레이션 실험과 그에 대한 통찰력을 공유합니다. 여기서는 Attention과  Mamba 레이어를 결합하는 이점, 어떤 비율로 결합해야 하는지, 그리고 어떻게 교차 배치해야 하는지를 살펴봅니다. 순수 Mamba 모델이 실패하는 경우를 조사함으로써, 이 모델이 문맥 내 학습 능력을 개발하는 데 어려움을 겪는다는 것을 제안합니다. 반면에 Attention-Mamba 하이브리드는 바닐라 Transformer와 유사한 문맥 내 학습을 보여줍니다. 그다음으로는 하이브리드 Attention-Mamba 모델 위에 MoE를 추가하는 이점을 보여줍니다. 마지막으로, Jamba에 명시적인 위치 정보가 필요하지 않고, 대규모 훈련에서 안정성을 위해 Mamba 레이어가 특별한 정규화가 필요하다는 두 가지 추가적인 학습 사항을 공유합니다.

 

 

6.1 Benefits of combining Attention and Mamba

Attention과 Mamba의 결합의 장점

1.3B 파라미터 모델에서 250B 토큰을 훈련한 결과, 하이브리드 Jamba 모델은 순수 주의 또는 Mamba 모델을 능가합니다. 주의와 Mamba 레이어의 비율은 1:3 또는 1:7로 거의 성능 차이가 없습니다. 이러한 결과는 훈련 손실 그림에서 볼 수 있으며, 여기서 Jamba는 훈련 중 개선된 손실을 보여줍니다. 계산 효율성이 더 높고 비슷한 성능을 보이는 1:7 비율을 대규모 실험에서 선택합니다.

 

 

Mamba Layer를 구성할 때 아까 본 대로 1:3, 1:7의 차이가 없는 것을 볼 수 있었고, 기존의 attention 메커니즘과 mamba 보다 Jamba가 성능이 좋은 것을 볼 수가  있었다. 이건 나중에 유저가 시도해 보면 될 듯함.

 

6.2 Why does the Combination Work?

조합은 왜 작동하나?

순수 Mamba 모델은 일반적인 벤치마크 작업에서 주의 모델보다 상당히 나쁜 성능을 보여줍니다. 하지만, Attention-Mamba 하이브리드는 이러한 데이터셋에서 주의 모델과 유사하게 수행됩니다. 이 결과는 주의 메커니즘이 없는 순수 Mamba 모델이 문맥 내에서 학습하기 어려울 수 있다는 가능성을 시사합니다. 반면에 하이브리드 Attention-Mamba 모델은 문맥 내 학습에 성공적으로 수행됩니다.

 

6.3 The Effect of Mixture-of-Experts (MoE)

Moe 효과

MoE 계산을 관리 가능하게 유지하면서 Transformer 언어 모델의 성능을 개선하는 것으로 나타났습니다. MoE 대규모에서 상태 공간 모델, 특히 하이브리드 Attention-Mamba 아키텍처와 잘 통합되는지 여부는 명확하지 않습니다. 하지만, MoE 대규모(7B 파라미터, 50B 토큰 훈련)에서 하이브리드 Attention-Mamba 아키텍처의 성능을 향상시키는 것으로 나타났습니다.

 

 

6.4 Stabilizing Mamba at large scale 

대규모에서의 안정화

1.3B 파라미터 모델을 훈련할 때는 특별한 문제없이 안정적인 훈련을 관찰했습니다. 그러나 여기서 발표된 가장 큰 모델(기반 7B, 활성/총 파라미터 12B/52B)로 확장할 때, 큰 손실 급증을 경험했습니다. 이를 조사한 결과, Mamba 레이어의 내부 부분이 큰 활성화 값을 겪고 있어, 급증을 일으킨다는 것을 발견했습니다. 그래서 내부 활성화에 RMSNorm을 추가했습니다. 이로 인해 훈련이 안정화되고 추가적인 손실 급증이 방지되었습니다.

그림 8: Mamba 레이어에 RMSNorm을 추가하면 손실 급증을 방지할 수 있습니다.

 

신기하다.. 어떻게 저걸 찾았을까 RMSNorm을 추가하면 Loss 손실을 방지할 수 있다고 한다.

 

6.5 Jamba does not Require Explicit Positional Information

Jamba는 명시적인 위치 정보가 필요 없음

MoE가 포함된 Jamba 아키텍처의 결과를 위치 정보가 없는 경우와 주의 레이어에 RoPE를 적용한 경우(1.3B 파라미터 모델, 250B 토큰)와 비교합니다. 결과는 비슷하며, 이는 하이브리드 아키텍처에 명시적인 위치 정보가 필요하지 않을 수 있음을 시사합니다. 추측컨대, 주의 레이어 앞에 배치된 Mamba 레이어가 암시적인 위치 정보를 제공합니다.

 

기존 Transformer는 포지셔널 인코딩을 해서 위치정보를 알게 되는데, Jamba는 Mamba(SSM) 레이어를 통해 암시적인 위치정보를 줄 수 있다고 하니 모델복잡성을 완전히 줄 일수 있다는 게 혁신 같았다.

 

7. Conclusion

 

Jamba는 Attention과 Mamba 레이어를 결합하고, 혼합 전문가(MoE) 모듈을 포함하는 새로운 아키텍처를 제시하며, 이를 통해 최신 성능을 달성하고 긴 문맥을 지원하는 열린 구현을 제공합니다. Jamba는 성능과 메모리 요구 사항 간의 균형을 유연하게 조절하면서 높은 처리량을 유지하는 방법을 보여줍니다. 우리는 Attention-to-Mamba 레이어의 비율과 같은 여러 설계 선택을 실험하고, Hybrid attention-ssm에 대한 향후 작업을 안내할 개발 과정 중 발견된 몇 가지 발견에 대해 논의했습니다. 이러한 연구를 촉진하기 위해, 우리는 소규모 훈련 실행에서 모델 체크포인트를 공개할 계획입니다.

이번에 제공하는 가장 큰 모델은 12B의 활성 파라미터와 52B의 총 사용 가능 파라미터를 가지고 있으며, 최대 256K 토큰의 문맥 길이를 지원하고, 140K 토큰 텍스트를 처리할 때조차 단일 80GB GPU에 적합합니다.

Jamba의 개발은 특히 복잡하고 정보가 풍부한 언어 처리 작업에서 모델의 효율성과 성능을 개선할 수 있는 새로운 방향을 제시합니다. 이는 향후 언어 모델링뿐만 아니라 다양한 AI 분야에서 하이브리드 모델 아키텍처의 탐색과 발전에 중요한 기여를 할 것으로 기대됩니다.

 

8. Hugging Face

Inference

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]

 

Fine-tunning Examples

 

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

trainer.train()

 

총평 : 전체적으로 짜임새 있게 구성한 모델 같았다. Attention-mamba 구조를 연결하고 MoE Layer 구성을 통해 성능과 메모리 효율 두 마리 토끼를 잡은 듯한데 사실 써봐야 알겠지만  또 mamba (SSM) 구조를 알고 있어야 리딩이 더 쉬울 듯하다 맘바의 구조를 제대로 모르니 그냥 좋다.라고 생각하게 되기만 한 것 같다. 다음에 리뷰해 봐야겠다.
한번 Fine-tunning 시도해 봐야겠다.!!

728x90