import torch
from transformers.models.starcoder2.modeling_starcoder2 import eager_attention_forward, apply_rotary_pos_emb, Starcoder2Attention, FlashAttentionKwargs
from transformers.cache_utils import Cache
from typing_extensions import Unpack
from typing import Optional
import code
from datasets import load_from_disk

from torch import nn as nn
from transformers import Starcoder2Config

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile each key/value head ``n_rep`` times along the head dimension.

    ``hidden`` must be 4-D; it is unpacked as (batch, kv_heads, seq_len, head_dim).
    Copies of the same kv head are placed next to each other, so the result has
    shape (batch, kv_heads * n_rep, seq_len, head_dim). When ``n_rep`` is 1 the
    input tensor itself is returned unchanged.
    """
    batch, kv_heads, seq_len, head_dim = hidden.shape  # also validates rank
    if n_rep != 1:
        # Insert a repeat axis right after the kv-head axis, broadcast it, then
        # fold it into the head axis.
        tiled = hidden[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden


class PatchedStarcoder2Attention(Starcoder2Attention):
    """Copy of a ``Starcoder2Attention`` layer that always runs the eager kernel.

    The layer is rebuilt from the original's config. Its weights, layer index,
    device/dtype and train/eval mode are then copied over, so it can replace
    the original module in place.
    """

    def __init__(self, original_attn):
        # Forward layer_idx: forward() hands self.layer_idx to Cache.update(),
        # and the base constructor would otherwise leave it as None.
        super().__init__(original_attn.config, layer_idx=getattr(original_attn, "layer_idx", None))
        # A freshly constructed module lives on CPU in the default dtype; match
        # the original's placement/precision before copying its weights.
        reference = next(original_attn.parameters(), None)
        if reference is not None:
            self.to(device=reference.device, dtype=reference.dtype)
        self.load_state_dict(original_attn.state_dict(), strict=True)
        # nn.Module starts in training mode; inherit the original's mode so the
        # attention/residual dropouts are not silently re-enabled in eval.
        self.train(original_attn.training)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project, apply RoPE, optionally update the cache, and attend eagerly.

        Args:
            hidden_states: (batch, seq_len, hidden) input. It is split into
                heads of size ``self.head_dim``.
            position_embeddings: ``(cos, sin)`` pair for ``apply_rotary_pos_emb``.
            attention_mask: additive mask forwarded to the eager kernel, or None.
            past_key_value: cache that receives this layer's keys/values.
            cache_position: forwarded to the cache through ``cache_kwargs``.
            **kwargs: forwarded unchanged to ``eager_attention_forward``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is the result after
            ``o_proj`` and residual dropout; ``attn_weights`` is whatever the
            eager kernel returned.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq_len, heads, head_dim) -> (batch, heads, seq_len, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Deliberately ignore config._attn_implementation: this patch always
        # uses the eager kernel so that attention weights are materialized.
        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # diff with Llama
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        attn_output = nn.functional.dropout(
            attn_output, p=self.residual_dropout, training=self.training
        )  # diff with Llama

        return attn_output, attn_weights

class IdentityPatchedStarcoder2Attention(Starcoder2Attention):
    """Eager Starcoder2 attention in which selected heads are forced to identity.

    The layer is rebuilt from the original's config. Its weights, layer index,
    device/dtype and train/eval mode are then copied over. ``identity_heads``
    indexes the head dimension of the attention weights. Each listed head gets
    an attention pattern in which every query attends only to its own position.
    """

    def __init__(self, original_attn, identity_heads: list):
        # Forward layer_idx: forward() hands self.layer_idx to Cache.update(),
        # and the base constructor would otherwise leave it as None.
        super().__init__(original_attn.config, layer_idx=getattr(original_attn, "layer_idx", None))
        self.identity_heads = identity_heads
        # A freshly constructed module lives on CPU in the default dtype; match
        # the original's placement/precision before copying its weights.
        reference = next(original_attn.parameters(), None)
        if reference is not None:
            self.to(device=reference.device, dtype=reference.dtype)
        self.load_state_dict(original_attn.state_dict(), strict=True)
        # nn.Module starts in training mode; inherit the original's mode so the
        # attention/residual dropouts are not silently re-enabled in eval.
        self.train(original_attn.training)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project, apply RoPE, optionally update the cache, and attend with identity heads.

        Args:
            hidden_states: (batch, seq_len, hidden) input. It is split into
                heads of size ``self.head_dim``.
            position_embeddings: ``(cos, sin)`` pair for ``apply_rotary_pos_emb``.
            attention_mask: additive mask forwarded to the kernel, or None.
            past_key_value: cache that receives this layer's keys/values.
            cache_position: forwarded to the cache through ``cache_kwargs``.
            **kwargs: forwarded unchanged to ``eager_identity_attention_forward``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_weights`` already contains
            the identity pattern for ``self.identity_heads``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq_len, heads, head_dim) -> (batch, heads, seq_len, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Deliberately ignore config._attn_implementation: the identity patch
        # needs the materialized attention matrix that only the eager path has.
        attn_output, attn_weights = eager_identity_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            self.identity_heads,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # diff with Llama
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        attn_output = nn.functional.dropout(
            attn_output, p=self.residual_dropout, training=self.training
        )  # diff with Llama

        return attn_output, attn_weights

def eager_identity_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    identity_heads: list,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager attention in which selected heads attend only to their own position.

    Computes ``softmax(Q K^T * scaling + mask) V`` and applies dropout. The
    weights of every head in ``identity_heads`` are then replaced by a one-hot
    pattern: query ``i`` attends to its matching key position, so those heads
    pass their own value vector through.

    Args:
        module: attention module; ``num_key_value_groups`` and ``training`` are read.
        query: expected (batch, num_heads, q_len, head_dim).
        key: expected (batch, num_kv_heads, k_len, head_dim); tiled to num_heads via ``repeat_kv``.
        value: same layout as ``key``.
        identity_heads: index (e.g. a list of ints) into the head dimension of
            the attention weights.
        attention_mask: additive mask whose last dim is sliced to ``k_len``, or None.
        scaling: multiplier applied to the raw scores.
        dropout: dropout probability, applied only when ``module.training``.
        **kwargs: ignored (e.g. ``sliding_window``).

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` is (batch, q_len,
        num_heads, head_dim); ``attn_weights`` is (batch, num_heads, q_len, k_len)
        and includes the identity pattern.

    Raises:
        ValueError: if there are fewer key positions than query positions.

    NOTE(review): the current queries are assumed to be the LAST ``q_len`` key
    positions. This holds with no cache and with caches that append keys
    (e.g. DynamicCache). A cache padded to a fixed length (StaticCache) would
    need ``cache_position`` to locate the diagonal; confirm whether one is used.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    num_heads, q_len, k_len = attn_weights.shape[-3:]
    offset = k_len - q_len  # > 0 during cached decoding (e.g. q_len == 1)
    if offset < 0:
        raise ValueError(
            f"identity attention needs k_len >= q_len, got q_len={q_len}, k_len={k_len}"
        )
    device = attn_weights.device
    # One-hot (q_len, k_len) pattern: query i -> key i + offset (its own position).
    identity = torch.zeros(q_len, k_len, device=device, dtype=attn_weights.dtype)
    rows = torch.arange(q_len, device=device)
    identity[rows, rows + offset] = 1
    head_mask = torch.zeros(num_heads, dtype=torch.bool, device=device)
    head_mask[identity_heads] = True
    # Out-of-place replacement: writing into attn_weights in place can clobber
    # the softmax output autograd saved (fp32, no dropout => same tensor).
    attn_weights = torch.where(head_mask.view(1, -1, 1, 1), identity, attn_weights)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


