본문 바로가기
code review

효과적인 Attention 매커니즘 infini-attention 의 Code 리뷰

by AI미남홀란드 2024. 4. 18.
728x90

https://github.com/jlamprou/Infini-Attention/blob/main/infiniAttention.py

infini-attention

 

Infini-Attention/infiniAttention.py at main · jlamprou/Infini-Attention

Efficient Infinite Context Transformers with Infini-attention Pytorch Implementation + QwenMoE Implementation + Training Script + 1M context keypass retrieval - jlamprou/Infini-Attention

github.com

 

+

 

블로그가 잘안보이는 관계로
https://github.com/jh941213/Code_review/blob/main/infini_attention.py

 

Code_review/infini_attention.py at main · jh941213/Code_review

코드리뷰하고 주석달아놓은 파일을 정리해두는 저장소. Contribute to jh941213/Code_review development by creating an account on GitHub.

github.com

코드를 올려두었습니다.(주석과 함께 코드를 읽어보세요)


 

infiniattention.py의 코드는 다음과 같습니다. 함수, 별로 한번 코드를 뜯어보겠습니다.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from transformers import AutoConfig

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class InfiniAttention(nn.Module):
    def __init__(self, config: AutoConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

        self.beta = nn.Parameter(torch.randn(1))
        
        self.segment_size = 2048

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        M_Z: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        
        # Initialize memory and normalization term 
        if M_Z is None:
            M = torch.zeros(self.num_heads, self.head_dim, self.head_dim).to(hidden_states.device)
            z = torch.zeros(self.num_heads, self.head_dim).to(hidden_states.device)
        else:
            M, z = M_Z

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)


        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )


        # GQA
        # Memory retrieval and attention calculation per segment
        memory_output = self._retrieve_from_memory(query_states, M, z)
        # Update memory with current segment's key and value states
        M, z  = self._update_memory(key_states, value_states, M, z)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        combined_output = self._long_term_injection(attn_output, memory_output)

        # Prepare output for this segment
        combined_output = combined_output.transpose(1, 2).contiguous()
        combined_output = combined_output.view(bsz, q_len, self.hidden_size)
        final_output = self.o_proj(combined_output)
        return final_output, None, None, (M, z)

    def _retrieve_from_memory(self, Q, M, z):
        # Retrieve context from compressive memory using linear attention (Eq. 3)
        M_s_1 = torch.matmul(F.elu(Q) + 1, M)
        Z_s_1 = torch.matmul(F.elu(Q) + 1, z.unsqueeze(-1)) + 1e-8
        A_mem = M_s_1 / Z_s_1
        return A_mem

    def _update_memory(self, K, V, M, z, use_delta=False):
        if use_delta:
            V_retrieved = torch.matmul(F.elu(K).transpose(-2, -1) + 1, M) / (torch.matmul(F.elu(K).transpose(-2, -1) + 1, z.unsqueeze(-1)) + 1e-8)
            updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V - V_retrieved)
        else:
            updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V)
        
        updated_z = z + (F.elu(K) + 1).sum(dim=-2)
        M = updated_M.detach()
        z = updated_z.detach()
        return M, z

    def _long_term_injection(self, A_dot, A_mem):
        beta = torch.sigmoid(self.beta)
        A = beta * A_mem + (1 - beta) * A_dot
        return A
    

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

 

 


repeat_kv 함수

 

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from transformers import AutoConfig

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

 

위함수의 목적은 텐서를 차원을 조작해서, 특정 차원을 반복확장 하는 것이다. Key, Value 텐선들의 차원수를 조정한다.

hidden_state.shape을 각각의 변수에 차원(batch, num_key_value_heads, seqlen, head_dim으로 구성이 되어있다)을 맵핑해서 넣어주고. n_rep(반복할 횟수)에 따라서 1이면 변환할 필요가 없어서 그냥 리턴하고, 1이 아니면 텐서에 새로운 축을 추가하고 텐서를 확장시킨다. 이 과정에서 key-value pair 가 반복되어, 확장 연산은 추가메모리를 사용하지 않고, 원본 텐서의 데이터를 재사용한다. 마지막 리턴할 때 reshape을 해서  최종 텐서의 형태를 만들어서 반환한다.

-> attention head의 key 와 value벡터들의 수를 인위적으로 늘려서 많은 문맥정보를 포함시키도록 한다.

 

 

infiniAttention Class

 

class InfiniAttention(nn.Module):
    def __init__(self, config: AutoConfig, layer_idx: Optional[int] = None):
        super().__init__()
        #파라미터
        self.config = config #autoconfig
        self.layer_idx = layer_idx #attention layer index
        self.hidden_size = config.hidden_size #입력 차원특성
        self.num_heads = config.num_attention_heads # 멀티헤드 갯수
        self.head_dim = self.hidden_size // self.num_heads # 헤드에서의 hidden size hiddensize / num_head
        self.num_key_value_heads = config.num_key_value_heads 
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads #key value group
        self.max_position_embeddings = config.max_position_embeddings # 포지셔널 임베딩
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
		
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        #프로젝션 레이터
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
		# 공간 정보를 기억하기 위한 로타리 임베딩
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
		#추가파라미터
        self.beta = nn.Parameter(torch.randn(1)) # memory 와 attention 결합시 사용되는 가중치 파라미터
        
        self.segment_size = 2048 #입력시퀀스 처리시 한번에 처리할 세그먼트의 크기

파라미터를 선언해 주고,  프로젝션레이어에서는 Linear 선형변환을 정의한다. 여기서 o_projsms 최종 attention 결과를 hidden_size로 맵핑하는 역할이다. 로터리 임베딩은 하이퍼크로버에서도 쓰였는데 회전 불변의 방식을 제공해서 긴 시퀀스 간 처리에 유용해서 위치정보를 고유하게 유지한다.

 

infiniAttention forward

 

 def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        M_Z: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        #초기 메모리 설정
        # Initialize memory and normalization term 
        if M_Z is None:
            M = torch.zeros(self.num_heads, self.head_dim, self.head_dim).to(hidden_states.device)
            z = torch.zeros(self.num_heads, self.head_dim).to(hidden_states.device)
        else:
            M, z = M_Z

        bsz, q_len, _ = hidden_states.size() #bsz, qlen, _ (batch_size, sequence_length, feature_dim)추출
		#입력 텐서 처리
        query_states = self.q_proj(hidden_states) #쿼리, 키, 값 벡터를 위한 텐서
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
		#view를 써서 mha 계산을 하기 위한 형태로 추출
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        #rotary positional embedding
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) #cos, sin 값계산

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # 로터리 임베딩 적용해서 상대적인 위치를 보여줌
		
        #과거의 key-value 쌍 값 처리
        #이전값의 키와 밸류의 정보를 현재계산에서 재사용하고, 과거의 저장된 값과 함께 업데이트해서 cache_kwargs를 통해 cos, sin 도 같이 전달한다.
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
		
        #attention 마스크 검증(차원확인)		
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )


        # GQA
        # Memory retrieval and attention calculation per segment
        # 메모리 업데이트 및 검색
        memory_output = self._retrieve_from_memory(query_states, M, z) #메모리검색
        # Update memory with current segment's key and value states
        M, z  = self._update_memory(key_states, value_states, M, z) #새로운 키밸류 메모리 업데이트
        #위의 repeat_kv 함수 키와 벨류 상태를 attention head 에 맞게 반복확장
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
		#어텐션 계산
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )
		#출력물 결합 및 최종 처리
        # attn_output 결과와 메모리로부터 검색된 출력(memory_output)을 결합합니다. 여기서 beta 파라미터를 사용하여 두 출력 사이의 결합 비율을 조절할 수 있습니다.
        combined_output = self._long_term_injection(attn_output, memory_output)
        
		
        # Prepare output for this segment
        # 차원조정 및 선형 프로젝션을 통해 차원 수정을 통해 최종 출력함
        combined_output = combined_output.transpose(1, 2).contiguous()
        combined_output = combined_output.view(bsz, q_len, self.hidden_size)
        final_output = self.o_proj(combined_output)
        return final_output, None, None, (M, z)

 

포워드 하고 업데이트하는 과정인데 밑에 함수의 코드들이 있으니 연관 지어서 주석을 쭉 읽어보면 좋습니다. 메모리설정, key-value의 이전값을 업데이트 및 파라미터에 맞게 차원을 수정하고, 어텐션 계산해서 결과를 가져오는 코드.

 

infini Attention memory 처리,  long_term_injection 처리

 #메모리검색함수
 def _retrieve_from_memory(self, Q, M, z):
        # Retrieve context from compressive memory using linear attention (Eq. 3)
        M_s_1 = torch.matmul(F.elu(Q) + 1, M)
        Z_s_1 = torch.matmul(F.elu(Q) + 1, z.unsqueeze(-1)) + 1e-8
        A_mem = M_s_1 / Z_s_1
        return A_mem
        
#메모리업데이트
def _update_memory(self, K, V, M, z, use_delta=False):
    if use_delta: #논문에서 델타규칙, deltarule을 통해서 메모리를 업데이트 할지 결정
        V_retrieved = torch.matmul(F.elu(K).transpose(-2, -1) + 1, M) / (torch.matmul(F.elu(K).transpose(-2, -1) + 1, z.unsqueeze(-1)) + 1e-8)
        updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V - V_retrieved)
    else:
        updated_M = M + torch.matmul(F.elu(K).transpose(-2, -1) + 1, V)

    updated_z = z + (F.elu(K) + 1).sum(dim=-2)
    M = updated_M.detach() #detach() 는 기존 계산 그래프와 분리해 메모리가 그래디언트 업데이트에 영향을 미치지 않도록 함
    z = updated_z.detach()
    return M, z
    
#장기기억 결합
# 어텐션결과와 이전 메모리 장기기억인 A_mem 을 결합하고 sigmoid를 통해서 0,과 1사이의 비율을 조정하는데 사용하고 A는 장기기억과 현재기억을 반영한 최종출력
def _long_term_injection(self, A_dot, A_mem):
    beta = torch.sigmoid(self.beta)
    A = beta * A_mem + (1 - beta) * A_dot
    return A

 

RotaryEmbedding

 

Transformer 모델의 회전 위치 임베딩(Rotary Position Embedding)을 계산하는 데 사용됩니다. 회전 위치 임베딩은 모델이 입력 시퀀스의 각 요소 위치 정보를 고려할 수 있도록 돕는 방법 중 하나로, 특히 Transformer 아키텍처에서 유용하게 사용됩니다.

 

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

 

 

 

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

 

로터리 임베딩은 따로 논문과 함께 코드리뷰를 해보겠습니다. 코드만 참고해 주시고 이 함수가 위의 forward에서 이렇게 쓰이는구나 하고 넘어가시면 좋을 것 같습니다.


repeat_kv function

  • 키(Key)와 값(Value) 텐서들의 차원 수를 조정하는 함수입니다. 이 함수는 어텐션 헤드 수에 맞게 키와 값 벡터들을 반복 확장하여, 어텐션 메커니즘에서 더 많은 문맥 정보를 포함시킬 수 있도록 합니다.

InfiniAttention class

  • 클래스는 초기화에서 다수의 설정을 포함하며, 포워드 패스에서 입력 텐서를 처리하는 로직을 구현합니다. 여기에는 메모리 업데이트 및 검색, 로터리 임베딩 적용, 그리고 최종 어텐션 스코어 계산이 포함됩니다.

RotaryEmbedding class

  • 회전 위치 임베딩을 계산하여, 모델이 입력 시퀀스의 위치 정보를 보다 정확하게 처리할 수 있도록 합니다.

 

728x90