본문 바로가기
Paper review

Prompt-prompted Mixture of Experts for Efficient LLM Generation 논문 리뷰

by AI미남홀란드 2024. 4. 15.
728x90

오늘 리뷰해 볼 논문은 제가 평소에 LInkdin에서 평소 논문 LLM , RAG 관련 리서치할 때 종종 보는 Pascal Biese

가 소개해준 'Prompt-prompted Mixture of Experts for Efficient LLM Generation '이라는 논문입니다. 궁금해서 공유하기를 눌러두고 오늘 리뷰를 해봅니다.

 

Paper

 

Prompt-prompted Mixture of Experts for Efficient LLM Generation

With the development of transformer-based large language models (LLMs), they have been applied to many fields due to their remarkable utility, but this comes at a considerable computational cost at deployment. Fortunately, some methods such as pruning or c

arxiv.org

Github

 

GitHub - hdong920/GRIFFIN

Contribute to hdong920/GRIFFIN development by creating an account on GitHub.

github.com

 

Abstract

 

이 연구는 변환기 기반 대규모 언어 모델(LLMs)의 발전과 그들이 여러 분야에 적용되면서 높은 유용성을 보이고 있지만, 이는 배포 시 상당한 계산 비용을 수반한다는 점을 설명합니다. 이러한 문제를 해결하기 위해, 연구진은 GRIFFIN이라는 새로운 훈련이 필요 없는 전문가 혼합(Mixture of Experts, MoE) 방식을 도입합니다. 이 방법은 특정한 비선형 활성화 함수를 사용하지 않는 다양한 LLM들에서 효율적인 생성을 가능하게 하는 시퀀스 레벨에서 고유한 전달(Feedforward, FF) 전문가를 선택합니다. 이는 많은 훈련된 LLM들이 시퀀스 내에서 고도로 구조화된 FF 활성화 패턴을 자연스럽게 생성한다는 중요한 관찰에 기반합니다. 이러한 패턴을 연구진은 'flocking'이라고 부릅니다. GRIFFIN은 FF 파라미터의 50%만을 사용하면서도 원래 모델의 성능을 거의 손상 없이 유지하고, 다양한 분류 및 생성 작업에서 속도 개선(예: NVIDIA L40에서 Llama 2 13B 모델의 1.25배 속도 향상)을 달성합니다.

 

위 논문에서는 계산 비용을 줄이기 위해 GRIFFIN 을 소개했고, 특정한 패턴을 flocking이라 하며 효율적인 생성을 가능하게 하고, 필요한 파라미터 수를 줄이면서 원래 모델의 성능을 유지한다. 이 방법은 분류 생성 작업에서 속도를 개선한다고 한다. FF네트워크를 뭔가 유연하게 처리를 했을 것 같은데, 코드가 있으니 나중에 한번 봐야겠다. LoRA도 그렇고 코드로는 생각보다 어렵지 않은 이론이라서 말만 들었을 때는 복잡해 보이지만 코드를 같이 보면 쉽게 이해가 되는 경우가 많다.

Introduction

본 논문에서는 트랜스포머 모델들이 다양한 분야에서 놀라운 성능을 보여준 것에 이어, 그 후속 모델인 대규모 언어 모델(LLM)들이 성능을 한층 끌어올렸지만, 이는 막대한 계산 및 메모리 요구로 이루어졌다고 설명합니다. 특히, LLM들의 피드포워드(FF) 블록에서는 중요하지 않은 중간 특징들에 대해 많은 계산이 낭비되고 있으며, 예를 들어 OPT-175B 모델에서는 FF 블록의 95% 계산이 낭비되고 있습니다. 이러한 문제를 해결하기 위해 다양한 접근법이 시도되고 있지만, 특히 MoE(전문가의 혼합) 방법이 원래의 성능을 유지하면서도 모델을 효율적으로 만드는 방법으로 주목받고 있습니다. MoE 방법은 입력마다 모델의 부분 집합을 선택적으로 사용하지만, Non-ReLU 활성화 함수를 사용하는 기존 LLM에는 효과적으로 적용되지 않는 문제가 있습니다. 이에 대한 해결책으로 'flocking' 현상이 제안되며, 이는 특정 시퀀스에서 지속되는 일관된 희소 활성화를 통해 관찰됩니다. 'flocking'은 높은 상대적 크기를 생성하는 뉴런이 시퀀스 내에서 공유된다는 것을 의미하며, 이는 Llama 2 7B와 Gemma 7B 모델에서도 확인되었습니다.

 

그림 1: 라마 2 7B(왼쪽)와 젬마 7B(오른쪽)의 레이어 10에 있는 PG-19 [RPJ+19,GBB+20]의 시퀀스에서 처음 512개의 피처와 토큰의 상대적 FF 활성화 크기. 이 히트맵은 한 시퀀스 내에서 상대적인 활성화 크기가 공유되는 플록킹을 보여줍니다. 더 많은 예는 부록 B에서 확인할 수 있습니다.

 

FF block 에서 계산낭비가 심한 llm에서 대체적으로 발견되는데, 이 문제가 비활성화된 뉴런들로 인해 낭비가 생기는 것이고 RelU활성화 함수를 쓰는 모델에서 흔하게 모인다. LLaMA2는 SwiGLU, gemma 또한 GEGLU를 활성화함수를 쓰는데, RelU 뿐만 아니라 위 활성화 함수들에서도 뉴런들의 비활성화로 인하여, 계산리소스의 낭비를 초래한다. 그래서 이 낭비를 줄이기 위해 MoE라는 방법을 제시한다

 

기존의 가지치기(pruning) 방법이나 MoE(전문가의 혼합) 방법과 달리, 이 논문에서는 GRIFFIN이라는 새로운 접근 방식을 소개합니다. GRIFFIN은 트레이닝이 필요 없는 고성능 방법으로, LLM을 MoE로 변환합니다. 이 방식은 시퀀스의 프롬프트를 사용하여 생성 동안 활성화될 전문가를 결정합니다. GRIFFIN은 준비 작업 없이도 간단히 FF 블록에 적용할 수 있으며, 전문가 선택 과정은 매개변수가 필요 없고 추가 비용이 거의 들지 않습니다. 다양한 모델과 활성화 함수에서 GRIFFIN의 효과가 검증되었으며, 이는 ReLU, SwiGLU, GEGLU, ReGLU 등을 포함합니다. 이 연구는 GRIFFIN이 분류 및 생성 작업에서 원래의 성능을 유지하면서도 FF 뉴런의 50%를 제거해도 지연 시간을 줄일 수 있음을 보여줍니다.

 

sequence prompt 를 활용해서, 필요한 전문가를 선택한다? 가 핵심 같다. 어떻게 하는 거지? 일단 이방식을 쓰면 여러 활성화함수에서도 효율적인 성능을 입증을 했고, 학습 시간이 매우 개선이 되며 효과적이라고 한다. 기존의 성능을 유지하면서도 처리속도를 유지한다고 한다.

 

Background

이 논문의 배경은 피드포워드(FF) 블록의 활성화 희소성(sparsity)과 대규모 언어 모델(LLMs)의 가속화를 탐구하는 이전 연구에 근거합니다. 특히, ReLU 기반 LLMs에서는 활성화가 매우 희소하며, 큰 모델에서 더욱 두드러진다고 합니다. GLU 변형 같은 비희소 활성화 함수를 사용하는 모델에서는 출력에 기여하지 않는 뉴런이 거의 없어, 이러한 방법들의 효과가 제한될 수 있습니다.

 

Pruning(가지치기)는 계산 및 메모리 병목 현상을 해결하기 위한 또 다른 방법으로, 가중치를 반복적으로 0으로 만들고 성능 손실을 회복하기 위해 재훈련하는 방식입니다. 구조화된 가지치기는 실제 계산 절약을 보기 어렵고 성능 저하가 심할 수 있습니다.
MoEs(mixture of Experts) 는 더 동적인 희소성을 만들기 위해 설계되었으며, 작은 뉴런 집합만을 사용하여 레이어의 출력을 계산하는 게이팅 기능을 사용합니다. 현재의 방법들은 대부분 훈련이 필요하거나 ReLU 활성화 함수에 의존합니다.

 

일단 Pruning 은 말 그대로 재학습을 시켜한다는 점에서 리소스가 낭비가 될 수 있고, 내가 아는 개념에서는 불필요한 파라미터를 없애는 거라고 알 고 있는데, 모델 정확도 성능에도 영향을 끼칠 수 있다. 그래서 MoE를 통해서 동적인 접근법을 사용해 모델의 일부만 활성화를 통해 계산을 수행하게끔 하는 듯하다.

 

Problem Formulation

이 섹션에서는 FF(FeedForward) 블록의 다양한 구성 요소에 대한 개요와 함께, 우리의 방법이 해결하고자 하는 MoE(Mixture of Experts) 문제에 대한 보다 자세한 소개를 제공합니다. FF 블록은 주의(attention) 메커니즘과 달리 각 토큰에 대해 독립적이고 동일하게 작동합니다. FF 블록은 단일 열 벡터 입력 x에 대해 정의되며, FF(x)는 두 단계의 변환 FF1(비선형 변환)과 FF2(선형 변환)를 거칩니다. 예를 들어, OPT 모델에서는 FF1(x) = σ(W1x + b1) 형태로, GLU 변형을 사용하는 Llama 2나 Gemma 모델에서는 FF1(x) = σ(Wgx + bg) ⊙ (W1x + b1) 형태로 계산됩니다. MoE 설정에서의 목표는 출력 값이 보존되도록 FF 블록을 재매개변수화하는 것입니다. 이는 토큰마다 다를 수 있는 더 작은 매트릭스를 사용하여 계산하는 방식으로, GPU나 TPU에서 효율적으로 작동합니다.

그림 2: 라마 2 7B(왼쪽)와 젬마 7B(오른쪽)에서 WikiText 샘플의 상위 FF 뉴런 활성화 사이의 평균 Jaccard 유사도. 값이 높을수록 유사성이 높다는 뜻입니다.

 

 

4 From Flocking to GRIFFIN

4.1 Observing Flocking

'Flocking'은 시퀀스 내에서 각 토큰별로 각 뉴런의 상대적인 영향력을 살펴볼 때 나타납니다. 이를 관찰하기 위해, Z의 행을 단위 벡터로 정규화하여 상대적 활성화를 구성합니다. 이는 Llama 2 7B와 Gemma 7B에서 시퀀스에 대한 상대적 활성화 크기의 예를 보여주는 그림 1에 나타나 있습니다. 시퀀스 내 모든 토큰에 걸쳐 상대적으로 높은 가중치를 가진 활성화가 공통적이라는 점을, 수직 줄무늬의 형태로 확인할 수 있습니다. 특히, Llama 2 7B와 Gemma 7B는 각각 SwiGLU와 GEGLU 활성화 함수를 사용하며, 그 외에도 다른 주요 아키텍처 차이를 가집니다. 'flocking' 현상은 조직화된 새 떼와 같이 매우 질서 정연하며, 거의 모든 FF 레이어에서 관찰됩니다(부록 B 참조).

상대적 활성화의 크기는 시퀀스 내에서 공유되지만, 일반적으로 시퀀스 간에는 공유되지 않습니다. 이는 토큰 축을 따라 Z의 ℓ2-노름을 취하여 각 샘플 또는 시퀀스에 대한 길이 Dff 벡터를 얻음으로써, 시퀀스 전체에서 FF 뉴런의 기여도를 대략적으로 캡처하는 방식으로 증명합니다. 샘플마다 레이어에서 상위 k개를 취해 시퀀스 간의 자카드 유사도(Jaccard similarity) 계산합니다. 이는 서로 다른 k 대해 선택된 인덱스들을 기반으로 합니다. , 각각의 고유한 상위 k 집합 쌍의 교집합 합집합을 계산합니다. 높은 값은 유사한 상위 k 집합을 나타냅니다.  그림 2에서 WikiText 샘플들에 대한 자카드 유사도를 집계한 결과, Llama 2 7B Gemma 7B 대부분의 레이어에서 시퀀스 활성화 유사도가 부족함을 관찰할 있습니다. 이는 선택된 뉴런들의 집합이 크지 않으면, 일관성이 부족하다는 것을 의미하며, 이는 재훈련 없이 FF 뉴런 전체를 가지치기하는 것이 적응적인 방법보다 효과적일 있음을 나타냅니다.

그림 3: GRIFFIN 개요. 프롬프트의 상대적 활성화에 따라 생성에 사용할 전문가 뉴런이 결정됩니다.

 

프롬프트 FF(FeedForward) 단계:

  • 입력 데이터 X가 주어지고, 비선형 활성화 함수(σ)를 거쳐 첫 번째 가중치 행렬 W1과 곱해집니다.
  • 그 결과로 나온 활성화들이 두 번째 가중치 행렬 W2와 곱해져 최종 출력을 생성합니다.

전문가 선택(Expert Selection) 단계:

  • 활성화 결과 행렬 Z에서 각 행(row)이 단위 벡터가 되도록 정규화합니다. 이는 각 토큰별 활성화의 상대적 중요도를 비교하기 위함입니다.
  • 그다음,각 (column)의L2-norm을계산하여, 특징(feature) 전체 시퀀스에 걸쳐 얼마나 기여하는지 평가합니다. 이는 벡터 S 표현됩니다.

생성 FF 단계:

  • 동일한 입력 X에 대해, 이제 상대적 중요도에 기반하여 선택된 '상위-k' 전문가 뉴런만이 활성화됩니다.
  • 재조정된 가중치 행렬 W1과 W2는 원래 가중치 행렬의 상위 k 행/열로 구성
  • 전문가 가중치 행렬은 입력 X와 곱해져서 최종 출력을 생성한다.

 

4.2 GRIFFIN Algorithm

 

프롬프트 단계에서의 전문가 선택 (Prompt Phase Expert Selection): GRIFFIN에서는 전문가(뉴런)를 시퀀스 레벨에서 선택합니다. 이는 단일 토큰이 아닌 전체 입력 시퀀스의 동적인 상황을 고려할 때 이루어집니다. 전문가 뉴런을 선택하기 위해서는 각 뉴런의 중요성을 알려줄 통계값  s ∈ R DFF 가 필요합니다. 프롬프트 단계에서, 토큰 축을 따라 Z의 ℓ2-노름을 집계하여 이를 계산합니다:

s에서 상위-k 인덱스를 선택하면, 이는 생성 단계에서 이 샘플에 사용될 뉴런들을 결정합니다. 이렇게 선택된 전문가 뉴런들은 E 집합을 구성합니다. E에 있는 전문가를 사용하여, 각각의 FF 블록에 대해 W1, b1, Wg, bg, W2의 해당 행과 열을 선택함으로써 , B, , B, 그리고 를 찾을 수 있습니다. 이 작업은 프롬프트 단계에서 모든 FF 블록에 대해 수행됩니다. 부록 A에서 자세히 설명되어 있듯이, 는 상대적으로 높은 강도로 일관되게 활성화되는 뉴런들을 강조합니다

 

전문가를 사용한 생성 (Generation with Experts): 토큰을 생성할 때, 우리는 전문가 뉴런이 포함된 가지치기된 레이어인 를 직접 사용하여 모든 미래 토큰에 대해 를 추정합니다. Llama 2 13B 및 Gemma 7B에서 이는 생성 중 활성 파라미터의 수를 각각 13B에서 8.8B로, 8.5B에서 5.4B로 줄입니다.

 

Classification tasks

표 1은 FF 뉴런의 50%를 가지치기했을 때의 여러 분류 작업에 대한 zero-shot 정확도를 나타내며, GRIFFIN이 사용된 모델이 기존 모델과 비교하여 어떠한 성능을 보이는지를 나타냅니다. GRIFFIN이 전체 FF 뉴런을 사용하는 대신, 생성을 위해 가장 중요한 뉴런만을 선택하여 계산 효율성을 높이는 것을 목표로 하는 것을 알 수 있습니다.

 

이해가 되지 않아 GPT에게 도움을 구했다. 쉽게 말하면, 결국 시퀀스 단위의 문장이 들어오면 어떤 뉴럴 전문가로 사용할지 결정하고, 모델이 다루는 단어 토큰에 대해 뉴런이 얼마나 중요한지 수치화해서 통계적 계산을 한다. 그리고 중요한 뉴런을 골라내어 이들만 사용한다는 뜻이다. 뉴럴을 선별하고 , 그 후에 선별된 뉴런들을 통해서 문장을 만드는 것이다.

 

Experiments

 

GRIFFIN 알고리즘에 관한 연구에서, 연구진은 여러 작업과 모델에서 GRIFFIN의 우수한 성능을 보여줬고, 처리 속도를 개선했습니다. 또한 전문가 뉴런 샘플링, 시퀀스 길이 스케일링, 무작위 입력 등 GRIFFIN의 여러 특성에 대한 연구도 진행했습니다.

 

Performance

GRIFFIN은 생성 및 분류 작업에 대해 다양한 모델에서 평가되었습니다. 생성 작업에는 XSum, CNN/DailyMail, COQA, SCROLLS QASPER를 사용했고, 분류 작업에는 HellaSwag, PIQA, COPA, ARCEasy/Challenge, BoolQ를 사용했습니다. 대부분의 실험에는 LM Evaluation Harness를 사용했으며, 기존의 대규모 언어 모델(LLM)과 GRIFFIN, 그리고 뉴런 크기를 기반으로 한 정적인 시퀀스-레벨 MoE와 비교했습니다. GLU 변형이 있는 경우, 뉴런별 노름을 원소별 곱셈하여 가지치기 메트릭을 생성했습니다. 이 단순한 기준선은 분류 결과에서 좋은 성과를 보였지만 생성 작업에서는 성과가 떨어졌습니다.

표 2: 생성 작업 XSum(1샷), CNN/DailyMail(1샷), CoQA(0샷), SCROLLS QASPER (0샷), 50% FF 희소성에서. 매그니튜드 뉴런 가지치기는 거의 모든 경우에서 실패하는 반면, GRIFFIN은 효과적으로 성능을 보존합니다.

 

Efficient

GRIFFIN의 효율성 측정 결과를 제시합니다. 동일한 길이의 샘플을 가진 합성 데이터셋을 수집하고 샘플들에 대한 결과의 평균을 산출했습니다. GRIFFIN은 단일 샘플 입력에 이상적이므로, 이 실험에는 배치 크기를 1로 설정했습니다. Llama 2 13B와 Gemma 7B의 Hugging Face 구현을 FP16 정밀도로 사용하여 NVIDIA L40 GPU에서 다양한 시나리오의 지연 시간을 측정했습니다.

 

그림 4: 라마 2 7B(왼쪽), 젬마 7B(가운데), 미스트랄 7B(오른쪽)에 대한 GRIFFIN의 상대적 성능. FF 블록당 다양한 수준의 희소성을 적용하기 때문입니다. 모든 작업의 경우, 각 작업에 대한 원래 모델의 성능은 각 작업에 대한 원래 모델의 성능은 1로 정규화됩니다.
표 3: 생성 단계 지연 시간(초). 길이에서 G 토큰을 생성하는 작업을 "P + G"로 표시합니다. P 프롬프트. 해당되는 경우, 시간은 50% / 75% FF 희소성 형식으로 표시됩니다.

 

5.3 추가 분석 및 연구

  • 샘플링 기반 선택: 상위 k 전문가 선택이 샘플링 기반 방법보다 우수한 결과를 낳는다는 것을 확인했습니다.
  • 프롬프트 대비 생성 길이: GRIFFIN은 프롬프트를 길게 하여 긴 생성에 대한 강건성을 향상할 수 있음을 발견했습니다.
  • 랜덤 시퀀스에서의 희소성: 랜덤 입력에 대한 희소성을 탐구하여, 언어 내에서 실제로 활성화를 다양화하는 요소가 있음을 발견했습니다.

표 2는 XSum, CNN/DailyMail, CoQA, SCROLLS QASPER 작업에서 FF 뉴런의 50%를 가지치기했을 때 GRIFFIN의 성능을 보여줍니다. 결과는 뉴런의 크기를 기준으로 가지치기하는 방법이 거의 모든 경우에서 실패하는 반면, GRIFFIN은 성능을 효과적으로 보존하는 것을 보여줍니다. 예를 들어, Llama 2 7B 모델에서는 XSum 작업에서 Rouge-1/2/L 점수가 각각 27.15/9.06/22.62인 반면, GRIFFIN을 사용하면 24.75/7.41/20.55로, 거의 유사한 수준의 성능을 유지합니다.

이 연구는 GRIFFIN이 기존의 성능을 대체로 보존하면서도 효율성을 크게 향상시킬 수 있음을 보여주는 강력한 증거를 제공합니다.

 

50% FF 희소성에서 다양한 전문가 선정 방법 간 비교.
라마 2 7B(왼쪽) 및 젬마 7B(오른쪽)의 프롬프트 길이 대 세대 길이 측정 결과 FF 희소도 50%에서 연결된 위키텍스트의 전체 모델에 비해 난해성(PPL)의 증가.

 

Conclusion

이 작업에서는 FF 레이어에서 특별한 형태의 희소성과 이를 활용하는 간단한 방법을 보여주었습니다.  플로킹은 많은 LLM에서 나타나는 흥미로운 현상으로, 시퀀스 내의 토큰이 비슷한 강도로 활성화되는 것입니다. 이 구조는 학습이 필요 없는 MoE 선택 메커니즘인 GRIFFIN의 설계에 동기를 부여했습니다. 시퀀스 수준에서 추론하는 동안 FF 뉴런을 제거하여 전체 모델의 성능을 유지합니다. 대규모 분류 및 생성 작업에서 전체 모델의 성능을 유지하면서 지연 시간을 50% FF 희소성으로 낮출 수 있습니다. 또한, 적용 범위가 ReLU 기반 LLM을 넘어 더 많은 모델에 대한 더 많은 모델에 적용할 수 있습니다. 간단한 알고리즘과 무료 배포를 통해 GRIFFIN은 다양한 LLM에 대한 생성 추론을 위한 수많은 LLM의 접근성을 확장합니다.

 

Code

Model

# Adapted from Hugging Face implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from griffin.utils import select_neurons

# griffin 알고리즘 적용
def get_gemma_griffin(model, k_schedule):
    config = model.config
    for i, l in enumerate(model.model.layers):
        new_mlp = GemmaMLP(config, k_schedule[i])
        
        new_mlp.gate_proj = l.mlp.gate_proj
        new_mlp.up_proj = l.mlp.up_proj
        new_mlp.down_proj = l.mlp.down_proj
        new_mlp.act_fn = l.mlp.act_fn
# 뉴런의 통계적 중요성을 가중치를 기반으로 계산하고, 이를 사용하여 전문가 로 유지될 뉴런을 결정
        if config.selection_method == 'magnitude':
            assert k_schedule[i] > 0.0
            gate_stat = l.mlp.gate_proj.weight.data.norm(dim=1)
            up_stat = l.mlp.up_proj.weight.data.norm(dim=1)
            stat = (gate_stat * up_stat).unsqueeze(0)
            _, indices = torch.topk(stat, int(stat.shape[1] * new_mlp.k_factor), dim=-1)
            new_mlp.prepare_reduced_weights(indices)
            new_mlp.mag_mask = torch.ones(stat.shape[-1], dtype=bool)
            new_mlp.mag_mask[indices[0]] = False
        
        l.mlp = new_mlp

    return model


class GemmaMLP(nn.Module):
    def __init__(self, config, k_factor):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = F.gelu
        
        self.k_factor = k_factor
        self.mode = config.mode
        assert self.mode in ['gen', 'class']

#top-k 로 전문가를 선택하는 메소드
    def prepare_reduced_weights(self, topk_indices):
        assert topk_indices.shape[0] == 1
        self.gate_proj_reduced = nn.Linear(self.gate_proj.weight.data.shape[1], len(topk_indices), bias=False)
        self.up_proj_reduced = nn.Linear(self.up_proj.weight.data.shape[1], len(topk_indices), bias=False)
        self.down_proj_reduced = nn.Linear(len(topk_indices), self.down_proj.weight.data.shape[0], bias=False)

        topk_indices = topk_indices[0]
        self.gate_proj_reduced.weight.data = self.gate_proj.weight.data[topk_indices]
        self.up_proj_reduced.weight.data = self.up_proj.weight.data[topk_indices]
        self.down_proj_reduced.weight.data = self.down_proj.weight.data[:, topk_indices]

#포워드 에서 gen, class 모드로 설정
    def forward(self, x):
        k_factor = self.k_factor
        if self.mode == 'gen':
            if x.shape[1] > 1:
                int_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)

                # GRIFFIN
                if self.config.selection_method != 'magnitude' and k_factor > 0.0: ###
                    k = int(int_states.shape[-1] * k_factor)
                    neuron_stat = ((int_states / int_states.norm(dim=-1).unsqueeze(-1))).norm(dim=1) # B, D
                    topk_weight, topk_indices = select_neurons(neuron_stat, self.config.selection_method, k)
                    self.prepare_reduced_weights(topk_indices)

                down_proj = self.down_proj(int_states)
                return down_proj
                
            else:
                if k_factor == 0.0:
                    return 0 * x
                else:
                    return self.down_proj_reduced(self.act_fn(self.gate_proj_reduced(x)) * self.up_proj_reduced(x))
        
        elif self.mode == 'class':
            assert x.shape[1] > 1
            int_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
            if self.config.selection_method != 'magnitude': ###
                k = int(int_states.shape[-1] * k_factor)
                neuron_stat = ((int_states / int_states.norm(dim=-1).unsqueeze(-1)))[:, :-1].norm(dim=1) # B, D
            
                topk_weight, topk_indices = select_neurons(neuron_stat, self.config.selection_method, k)
                
                # Not tested for batch size > 1
                mask = torch.zeros_like(int_states[:, -1], dtype=torch.bool)
                mask.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(mask))
                int_states[:, -1] = mask * int_states[:, -1]
            else:
                int_states[:, -1, self.mag_mask.to(int_states.device)] = 0
                
            down_proj = self.down_proj(int_states)
            
            return down_proj
        else:
            raise NotImplementedError

 

 

eval

if args.density < 1:
        model.config.mode = 'gen'
        model.config.selection_method = args.selection_method
        
        model = modify_dict[args.model_arch](model, schedule_k)


    model.half()
    model.eval().to(args.device)
    if args.max_length == -1:
        args.max_length = config.max_position_embeddings
    logger.info(args)

    requests = []
    for input_path in input_paths:
         with open(input_path, 'r') as f:
             for line in f:
                 if line.strip() != '':
                     requests.append(json.loads(line))

    requests = requests[:args.sample_num]

    results = []
    rouge = Rouge()

    seq_lens = []
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []

    skipped=0
    
    with torch.no_grad():
        for i, request in enumerate(tqdm.tqdm(requests)):

            stop = ['###']
            temperature = args.temp
            prompt = request['article']
            label = request['summary_gt']
            max_tokens = args.max_tokens
            result = {}
            if args.model_arch == 'gemma':
                input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
            else:
                input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)
            if len(input_ids[0]) > args.max_length-max_tokens:
                skipped+=1
                print('skipped', skipped)

            else:
                output_sequences = model.generate(
                    input_ids=input_ids,
                    max_length=max_tokens + len(input_ids[0]),
                    temperature=temperature,
                    top_k=args.k,
                    top_p=1,
                    do_sample=not args.greedy,
                    num_return_sequences=1,
                    return_dict_in_generate=True, output_scores=True,
                    )

                tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
                logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
                top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]

                generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])
                generate_text = generate_text[: generate_text.find(stop[0])]

                scores = rouge.get_scores(generate_text, label)[0]
                seq_lens.append(len(input_ids[0]))
                rouge1_score_list.append(scores['rouge-1']['f'])
                rouge2_score_list.append(scores['rouge-2']['f'])
                rougel_score_list.append(scores['rouge-l']['f'])

                result['result'] = {
                    "choices": [
                        {
                            "text": generate_text,
                            "logprobs": {
                                "tokens": tokens, 
                                "token_logprobs": logprobs, 
                                "top_logprobs": top_logprobs, 
                                "text_offset": []
                            }, 
                            "finish_reason": "length"
                        }
                    ], 
                    "request_time": {
                        "batch_time": 0, 
                        "batch_size": 1}
                }

                results.append(result)
                print('rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

    print("FINAL RESULTS")
    print('rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

 

스크립트는 프롬프트를 기반으로 입력을 받아 전문가 뉴런을 선택하고, 이를 사용하여 텍스트를 생성하는 프로세스를 자동화한다. 이 과정에서 모델은 더 적은 수의 뉴런을 사용하여 문장을 생성하게 된다, 즉 모델의 효율성을 높이는 것. 생성된 텍스트는 rouge 라이브러리를 사용하여 평가되며, 성능 지표로는 Rouge-1, Rouge-2, Rouge-L 점수가 사용된다.

 

총평 : 그리핀 알고리즘을 통해서 MoE로 전문가 FF 만 선택해서 학습을 시킨다는 개념으로 자원을 절약할 수 있겠구나 싶었다. 사실 코드를 봐도 조금 어렵긴 한데, 직관적으로 그나마 이해가 갈 수 있는 부분이라 데이터를 직접 넣는 시뮬레이션을 해보면서 이해하면 조금 더 쉽지 않을까 싶다. 아직 github repo에서는 크게 Hype 이 되지 않는 걸로 보아 이것도 나중에 크게 될 수 있지 않을까 조심스럽게 생각했다.

 

 

728x90