from typing import Callable, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils import TransformersKwargs

class LlamaLowDimAttention(nn.Module):
    """Llama attention layer wrapper with optional low-dimensional Q/K hooks.

    The wrapper reuses the configuration and the (shared, not copied) q/k/v/o
    projection modules of an existing attention module ``base_attention``.
    After rotary embeddings are applied, the query/key states can be routed
    through ``lowdim_model`` depending on three mode flags:

    * ``lowdim_collect_scores``: queries, keys and values are passed to
      ``lowdim_model.eval_batch`` before attention is computed.
    * ``lowdim_train_mode``: queries and keys are passed to
      ``lowdim_model.partial_train``.
    * ``lowdim_project``: queries and keys are replaced by
      ``lowdim_model.project(...)`` before the cache update and attention.

    All three flags start disabled. In that state the forward pass computes
    plain attention with the base layer's projections.
    """

    def __init__(self, base_attention, lowdim_model):
        """Wrap ``base_attention`` and attach ``lowdim_model``.

        Args:
            base_attention: Existing attention module. Its ``config``,
                ``layer_idx``, ``head_dim``, ``num_key_value_groups``,
                ``attention_dropout``, ``is_causal``, ``scaling`` (if present)
                and ``q_proj``/``k_proj``/``v_proj``/``o_proj`` are reused.
            lowdim_model: Object used by the hooks. It must provide
                ``eval_batch``, ``partial_train``, ``project``,
                ``finalize_epoch_training``, ``finalize_score_collection`` and
                ``save``.
        """
        super().__init__()
        self.config = base_attention.config
        self.layer_idx = base_attention.layer_idx
        self.head_dim = base_attention.head_dim
        self.num_key_value_groups = base_attention.num_key_value_groups
        # Copy the base layer's softmax scale like every other attribute, so the
        # wrapper cannot silently change it. Fall back to the standard
        # 1/sqrt(head_dim) if the base module does not expose one.
        self.scaling = getattr(base_attention, "scaling", self.head_dim**-0.5)
        self.attention_dropout = base_attention.attention_dropout
        self.is_causal = base_attention.is_causal
        self.lowdim_model = lowdim_model

        # Mode flags; see the class docstring. All disabled initially.
        self.lowdim_train_mode = False
        self.lowdim_project = False
        self.lowdim_collect_scores = False

        # Shared with base_attention: weights are the same module objects.
        self.q_proj = base_attention.q_proj
        self.k_proj = base_attention.k_proj
        self.v_proj = base_attention.v_proj
        self.o_proj = base_attention.o_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention, applying the enabled low-dim hooks after RoPE.

        The hooks run in a fixed order:

        1. score collection;
        2. training;
        3. projection.

        When projection is enabled, the projected keys (not the
        full-dimensional ones) are what get written to the KV cache.

        Args:
            hidden_states: Input activations. The last dim is split into heads
                of size ``self.head_dim``. This presumably has shape
                ``(batch, seq, hidden)``; confirm with caller.
            position_embeddings: ``(cos, sin)`` pair used for rotary
                embeddings and passed to the cache update.
            attention_mask: Forwarded unchanged to the attention function.
            past_key_values: Optional KV cache whose ``update`` is called.
            past_key_value: Legacy alias, used only when ``past_key_values``
                is ``None``.
            cache_position: Forwarded to the cache via ``cache_kwargs``.
            **kwargs: Extra arguments forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)``, where ``attn_output`` has been
            passed through ``o_proj``. ``attn_weights`` is whatever the
            attention backend returns and may be ``None`` depending on
            ``config._attn_implementation``.
        """
        # Accept both the current and the legacy cache keyword; the current
        # name wins when both are supplied.
        past_key_values = past_key_values if past_key_values is not None else past_key_value
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the projection output into heads, then move the head axis
        # before the sequence axis (transpose(1, 2)).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # LowDim eval and collect scores (sees full-dimensional, post-RoPE Q/K).
        if self.lowdim_collect_scores:
            self.lowdim_model.eval_batch(query_states, key_states, values=value_states)

        # LowDim train (sees full-dimensional, post-RoPE Q/K).
        if self.lowdim_train_mode:
            self.lowdim_model.partial_train(query_states, key_states)

        # LowDim projections: done before the cache update, so the cache
        # stores projected keys. Values are never projected.
        if self.lowdim_project:
            query_states = self.lowdim_model.project(query_states)
            key_states = self.lowdim_model.project(key_states)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into a single feature dim before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def lowdim_train(self):
        """Enable per-forward training of ``lowdim_model`` and disable projection.

        ``lowdim_collect_scores`` is left unchanged.
        """
        self.lowdim_train_mode = True
        self.lowdim_project = False

    def finalize_lowdim_epoch_training(self):
        """Delegate to ``lowdim_model.finalize_epoch_training``."""
        self.lowdim_model.finalize_epoch_training()

    def lowdim_eval(self):
        """Disable training and enable low-dim projection of Q/K.

        ``lowdim_collect_scores`` is left unchanged.
        """
        self.lowdim_train_mode = False
        self.lowdim_project = True

    def start_score_collection(self):
        """Disable projection and start passing Q/K/V to ``lowdim_model.eval_batch``."""
        self.lowdim_project = False
        self.lowdim_collect_scores = True

    def finalize_score_collection(self):
        """Stop collecting scores, enable projection, and return the collected result.

        Note that projection is enabled unconditionally. It is not restored to
        whatever it was before ``start_score_collection``.

        Returns:
            Whatever ``lowdim_model.finalize_score_collection()`` returns.
        """
        self.lowdim_project = True
        self.lowdim_collect_scores = False
        return self.lowdim_model.finalize_score_collection()

    def save(self, target_dir="./checkpoints"):
        """Delegate to ``lowdim_model.save`` with this layer's index."""
        self.lowdim_model.save(target_dir, self.layer_idx)
