from src.models.feature_net import FeatureNet
from src.models.linear_attention import LinearAttnCache, linear_attention

from transformers.models.gpt2.modeling_gpt2 import (GPT2MLP, GPT2Attention, GPT2Block, 
                                                    GPT2PreTrainedModel, GPT2Model, GPT2LMHeadModel)
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.pytorch_utils import Conv1D
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple, Union, Dict

# Registry from an ``_attn_implementation`` name to the attention class that
# implements it. Only the eager variant is registered here.
GPT2_ATTENTION_CLASSES = dict(eager=GPT2Attention)

class ExtendedGPT2Config(GPT2Config):
    """``GPT2Config`` with extra options for the (linear) attention variants.

    Extra keyword arguments read by ``__init__`` (all optional):

    * ``use_linear_attn`` (default ``False``): build blocks with
      ``GPT2LinearAttention`` instead of softmax attention.
    * ``feature_type`` (default ``"fourier"``).
    * ``coef_type`` (default ``"standard"``).
    * ``adaptive_shift`` (default ``None``): cleared by ``post_init`` unless
      ``feature_type == "positive"``.
    * ``num_features`` (default ``64``): an int or a per-layer list; an int is
      expanded by ``post_init`` to one entry per hidden layer.
    * ``_attn_implementation``: defaults to the value already present on the
      instance after the parent constructor ran, otherwise ``"eager"``.
    """

    use_linear_attn: bool

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_linear_attn = kwargs.get("use_linear_attn", False)
        self.feature_type = kwargs.get("feature_type", "fourier")
        self.coef_type = kwargs.get("coef_type", "standard")
        self.adaptive_shift = kwargs.get("adaptive_shift")
        self.num_features = kwargs.get("num_features", 64)
        self._attn_implementation = kwargs.get(
            "_attn_implementation",
            self._attn_implementation if hasattr(self, "_attn_implementation") else "eager",
        )
        self.post_init()

    def update(self, config_dict):
        """Update attributes from ``config_dict`` and re-run ``post_init`` normalization."""
        super().update(config_dict)
        self.post_init()

    def post_init(self):
        """Normalize ``num_features`` and ``adaptive_shift``.

        * ``feature_type == "dijiang"`` forces ``num_features`` to ``head_dim``
          (``hidden_size // num_attention_heads``) for every layer; a warning is
          printed only if some configured value actually differed.
        * An integer ``num_features`` is expanded to a list of length
          ``num_hidden_layers``.
        * ``adaptive_shift`` is reset to ``None`` (with a warning) unless
          ``feature_type == "positive"``.

        Called from both ``__init__`` and ``update``, so it must tolerate
        ``num_features`` already being a per-layer list.
        """
        head_dim = self.hidden_size // self.num_attention_heads
        if self.feature_type == "dijiang":
            # After a previous post_init (or when reloading a saved config)
            # num_features is a list; compare its entries, not the list itself,
            # so an already-correct config does not trigger a spurious warning.
            if isinstance(self.num_features, (list, tuple)):
                per_layer = list(self.num_features)
            else:
                per_layer = [self.num_features]
            if any(n != head_dim for n in per_layer):
                print("Warning: num_features is set to the same value as head_dim when feature_type == 'dijiang'")
            self.num_features = head_dim
        if isinstance(self.num_features, int):
            self.num_features = [self.num_features] * self.num_hidden_layers

        if self.feature_type != "positive" and self.adaptive_shift is not None:
            print("Warning: adaptive_shift is set to None when feature_type != 'positive'")
            self.adaptive_shift = None

class ExtendedGPT2Attention(GPT2Attention):
    """Softmax ``GPT2Attention`` that can expose intermediate attention tensors.

    ``_attn`` returns an info dict instead of bare attention weights. ``forward``
    accepts ``output_attentions`` as a bool (return only the weights) or as a
    list of keys selecting which entries of that dict to return.
    """

    # When True, ``_attn`` normalizes an explicit exp-kernel whose gradient is
    # retained (exposed as ``attn_info["kernel"]``) instead of calling softmax.
    _retain_kernel_grad: bool

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._retain_kernel_grad = False

    def activate_or_deactivate_kernel_grad(self, activate_or_deactivate: bool):
        """Turn the retained-kernel-gradient path of ``_attn`` on (True) or off (False)."""
        self._retain_kernel_grad = activate_or_deactivate

    def _attn(self, query, key, value, attention_mask=None, head_mask=None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Eager attention that also returns intermediate tensors.

        Steps: ``query @ key^T``; optional scaling by ``sqrt(value.size(-1))``
        and by ``1 / (layer_idx + 1)``; causal masking with ``self.bias`` (self-
        attention only); additive ``attention_mask``; normalization; dropout;
        optional ``head_mask``; multiplication with ``value``.

        Returns:
            ``(attn_output, attn_info)``. ``attn_info["weight"]`` always holds
            the weights after dropout and ``head_mask``. If
            ``_retain_kernel_grad`` is set, ``attn_info["kernel"]`` holds the
            row-max-shifted ``exp`` scores with ``retain_grad()`` enabled;
            otherwise ``"query"``, ``"key"`` and ``"qk"`` (masked scores before
            softmax) are stored.
        """
        attn_info = {}
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # Slice the causal-mask buffer to the current query/key window and
            # replace masked positions with the most negative finite value.
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        qk = attn_weights

        if self._retain_kernel_grad:
            # Manual softmax: subtract the row max before exp to avoid overflow,
            # and keep the un-normalized kernel's .grad after backward().
            shift = qk.amax(dim=-1, keepdim=True)
            kernel = torch.exp(qk - shift)
            kernel.requires_grad_(True)
            kernel.retain_grad()
            attn_info["kernel"] = kernel
            attn_weights = kernel / kernel.sum(dim=-1, keepdim=True)
        else:
            attn_info["query"] = query
            attn_info["key"] = key
            attn_info["qk"] = qk
            attn_weights = nn.functional.softmax(qk, dim=-1)

        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_info["weight"] = attn_weights
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_info

    def forward(self, hidden_states, *args, output_attentions: Union[bool, list[str]] = False, **kwargs):
        """Run the parent forward, then post-process the attention output slot.

        ``output_attentions`` may be:
          * ``True``: the last output is the attention-weight tensor;
          * a non-empty list of keys: the last output is the ``attn_info`` dict
            restricted to those keys;
          * ``False`` or an empty list: no attention entry is returned.

        NOTE(review): relies on the parent routing through ``_attn``; with
        ``reorder_and_upcast_attn`` the last output would not be a dict.
        """
        # An empty key list selects nothing, and the parent only appends the
        # attention entry for a truthy value. Without this normalization, the
        # list branch below would slice off ``present`` (the last element) and
        # replace it with ``{}``.
        if isinstance(output_attentions, list) and not output_attentions:
            output_attentions = False

        outputs = super().forward(hidden_states, *args, output_attentions=output_attentions, **kwargs)
        if isinstance(output_attentions, bool) and output_attentions:
            outputs = outputs[:-1] + (outputs[-1]["weight"],)
        elif isinstance(output_attentions, list):
            outputs = outputs[:-1] + ({k: v for k, v in outputs[-1].items() if k in output_attentions},)

        return outputs

class GPT2LinearAttention(GPT2Attention):
    """``GPT2Attention`` variant that computes attention with ``linear_attention``.

    Reuses the parent's ``c_attn``/``c_proj`` projections, dropouts and causal
    ``bias`` buffer. It hands the per-head query/key/value to
    ``linear_attention`` together with a per-layer ``FeatureNet``.
    Cross-attention, head pruning, ``_attn`` and ``reorder_and_upcast_attn``
    are not supported (they raise ``NotImplementedError``).
    """

    config: ExtendedGPT2Config
    bias: torch.Tensor
    kernel_clip: torch.Tensor
    recurrence: bool

    def __init__(self, config: ExtendedGPT2Config, is_cross_attention=False, layer_idx=None):
        # Validate first: layer_idx indexes per-layer settings below, and
        # ``config.num_features[None]`` would otherwise fail with an unrelated
        # TypeError before reaching this check.
        if layer_idx is None:
            raise ValueError("layer_idx must be specified")

        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)

        # Passed to linear_attention as ``kernel_clip``; as a registered buffer
        # it moves with the module on .to()/.cuda().
        self.register_buffer("kernel_clip", torch.tensor(1e-30))
        self.feature_type = config.feature_type
        self.num_features = config.num_features[layer_idx]
        # When True and use_cache is set, forward creates a LinearAttnCache if
        # none is supplied.
        self.recurrence = False

        self.feature_net = FeatureNet(config, layer_idx)

    def switch_recurrence(self, recurrence: bool):
        """Enable or disable automatic ``LinearAttnCache`` creation in ``forward``."""
        self.recurrence = recurrence

    def prune_heads(self, heads):
        raise NotImplementedError

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        raise NotImplementedError

    def _attn(self, *args, **kwargs):
        raise NotImplementedError

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[LinearAttnCache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Union[bool, list[str], None] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Linear-attention forward pass.

        Args:
            hidden_states: input to ``c_attn``; split into query/key/value heads.
            layer_past: optional ``LinearAttnCache`` passed to
                ``linear_attention`` as ``cache``. If ``use_cache`` and
                ``recurrence`` are both set and none is given, a new one is
                created.
            attention_mask, head_mask: forwarded to ``linear_attention``.
            encoder_hidden_states: must be ``None`` (cross-attention unsupported).
            encoder_attention_mask: ignored.
            use_cache: if true, ``present`` is the cache object (possibly
                ``None``); presumably updated in place by ``linear_attention`` —
                confirm.
            output_attentions: ``True`` appends ``attn_info["weight"]``; a
                non-empty list of keys appends the matching ``attn_info``
                entries; ``False`` or an empty list appends nothing.

        Returns:
            ``(attn_output, present)`` optionally followed by the attention entry.

        Raises:
            NotImplementedError: for cross-attention or ``reorder_and_upcast_attn``.
        """
        # Reject unsupported modes before doing any projection work.
        if encoder_hidden_states is not None:
            raise NotImplementedError
        if self.reorder_and_upcast_attn:
            raise NotImplementedError

        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if use_cache and self.recurrence and (layer_past is None):
            layer_past = LinearAttnCache()

        attn_output, attn_info = linear_attention(
            self.feature_net, self.attn_dropout, query, key, value, attention_mask, head_mask,
            cache=layer_past, adaptive_shift=self.config.adaptive_shift, bias=self.bias, kernel_clip=self.kernel_clip
        )

        present = layer_past if use_cache else None

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if isinstance(output_attentions, bool) and output_attentions:
            outputs += (attn_info["weight"],)
        elif isinstance(output_attentions, list) and output_attentions:
            # An empty key list selects nothing; append nothing rather than {}.
            outputs += ({k: v for k, v in attn_info.items() if k in output_attentions},)

        return outputs

    def transplant(self, original: GPT2Attention):
        """Share the projection and dropout modules of a trained ``GPT2Attention``.

        Modules are assigned (shared), not copied. ``q_attn`` is taken only if
        ``original`` has one. The causal-mask ``bias`` buffers must already be
        equal; this is asserted, not copied.
        """
        self.c_attn = original.c_attn
        if hasattr(original, "q_attn"):
            self.q_attn = original.q_attn
        self.c_proj = original.c_proj
        self.attn_dropout = original.attn_dropout
        self.resid_dropout = original.resid_dropout
        assert (self.bias == original.bias).all().item()

class ExtendedGPT2Block(GPT2Block):
    """GPT-2 transformer block whose attention class is chosen by the config.

    ``config.use_linear_attn`` selects ``GPT2LinearAttention``. Otherwise
    ``ExtendedGPT2Attention`` is used, which requires the eager attention
    implementation.
    """

    def __init__(self, config, layer_idx=None):
        # Initialize as a plain nn.Module, bypassing GPT2Block.__init__; every
        # submodule is constructed explicitly below.
        nn.Module.__init__(self)

        width = config.hidden_size
        mlp_width = 4 * width if config.n_inner is None else config.n_inner

        if not config.use_linear_attn:
            assert config._attn_implementation == "eager", "Only eager implementation is supported"
            attn_cls = ExtendedGPT2Attention
        else:
            attn_cls = GPT2LinearAttention

        eps = config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = attn_cls(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(width, eps=eps)

        if config.add_cross_attention:
            self.crossattention = attn_cls(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(width, eps=eps)

        self.mlp = GPT2MLP(mlp_width, config)

    def transplant(self, original: GPT2Block):
        """Share the submodules of ``original`` with this block.

        LayerNorms, the MLP and (if present here) the cross-attention modules
        are replaced by ``original``'s module objects. The self-attention keeps
        its own class and delegates to ``self.attn.transplant``.
        """
        for name in ("ln_1", "ln_2"):
            setattr(self, name, getattr(original, name))
        self.attn.transplant(original.attn)
        if hasattr(self, "crossattention"):
            for name in ("crossattention", "ln_cross_attn"):
                setattr(self, name, getattr(original, name))
        self.mlp = original.mlp

class ExtendedGPT2Model(GPT2Model):
    """``GPT2Model`` built from ``ExtendedGPT2Block`` layers.

    Holds token (``wte``) and position (``wpe``) embeddings, embedding dropout,
    ``num_hidden_layers`` blocks in ``h``, and a final LayerNorm ``ln_f``.
    """

    def __init__(self, config):
        # Initialize through GPT2PreTrainedModel, bypassing GPT2Model.__init__;
        # all submodules are constructed explicitly below.
        GPT2PreTrainedModel.__init__(self, config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([ExtendedGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model-parallel state: disabled, no device map.
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        self.post_init()

    def transplant(self, original: GPT2Model):
        """Share the embeddings, blocks and final LayerNorm of ``original``.

        Modules are assigned (shared), not copied; each block delegates to
        ``ExtendedGPT2Block.transplant``.

        Raises:
            ValueError: if the two models have different numbers of blocks.
                Without this check, ``zip`` would silently stop at the shorter
                model and leave some blocks untransplanted.
        """
        if len(self.h) != len(original.h):
            raise ValueError(
                f"Cannot transplant a model with {len(original.h)} blocks "
                f"into one with {len(self.h)} blocks"
            )
        self.wte = original.wte
        self.wpe = original.wpe
        for block, original_block in zip(self.h, original.h):
            block.transplant(original_block)
        self.ln_f = original.ln_f

class ExtendedGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 language model with an ``ExtendedGPT2Model`` backbone.

    The LM head is a bias-free linear map from ``n_embd`` to ``vocab_size``.
    Its weight is listed in ``_tied_weights_keys``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ExtendedGPT2Config):
        # Initialize through GPT2PreTrainedModel, bypassing
        # GPT2LMHeadModel.__init__; submodules are built explicitly below.
        GPT2PreTrainedModel.__init__(self, config)

        backbone = ExtendedGPT2Model(config)
        self.transformer = backbone
        self.lm_head = nn.Linear(in_features=config.n_embd, out_features=config.vocab_size, bias=False)

        # Model-parallel state: disabled, no device map.
        self.model_parallel = False
        self.device_map = None

        self.post_init()

    def transplant(self, original: GPT2LMHeadModel):
        """Share the backbone and LM head of a trained ``GPT2LMHeadModel``.

        Only allowed when this model was configured with linear attention.
        """
        assert self.config.use_linear_attn, "Only linear attention is supported"

        self.transformer.transplant(original.transformer)
        self.lm_head = original.lm_head

    def switch_recurrence(self, recurrence: bool):
        """Forward ``recurrence`` to the self-attention module of every block."""
        for attn in (block.attn for block in self.transformer.h):
            attn.switch_recurrence(recurrence)