from src.models.feature_net import FeatureNet
from src.models.linear_attention import LinearAttnCache, linear_attention

from transformers.models.gpt_neox.modeling_gpt_neox import (GPTNeoXPreTrainedModel, GPTNeoXForCausalLM, GPTNeoXModel, 
                                                            GPTNeoXAttention, GPTNeoXLayer, GPTNeoXRotaryEmbedding, 
                                                            GPTNeoXLinearScalingRotaryEmbedding, GPTNeoXDynamicNTKScalingRotaryEmbedding,
                                                            GPTNeoXMLP, 
                                                            apply_rotary_pos_emb)
from transformers.models.gpt_neox.configuration_gpt_neox import GPTNeoXConfig
from transformers.cache_utils import Cache, DynamicCache

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

# GPT-2 style attribute names mapped to their GPT-NeoX equivalents.  Each table
# is consulted by the ``__getattr__`` of the class it is named after.
ConfigAttrAlias = dict(
    n_embd="hidden_size",
    n_head="num_attention_heads",
    n_layer="num_hidden_layers",
    attn_pdrop="attention_dropout",
    n_positions="max_position_embeddings",
)
AttnAttrAlias = dict(
    attn_dropout="attention_dropout",
)
LayerAttrAlias = dict(
    attn="attention",
)
ModelAttrAlias = dict(
    h="layers",
    ln_f="final_layer_norm",
)
CausalLMAttrAlias = dict(
    transformer="gpt_neox",
    lm_head="embed_out",
)

class ExtendedGPTNeoXConfig(GPTNeoXConfig):
    """GPT-NeoX configuration extended with linear-attention options.

    Extra keyword arguments read on construction (defaults in brackets):

    * ``use_linear_attn`` [False] -- build linear attention layers.
    * ``feature_type`` ["fourier"] -- feature-map type; ``"dijiang"`` forces
      every entry of ``num_features`` to ``head_dim``.
    * ``coef_type`` ["standard"].
    * ``adaptive_shift`` [None] -- reset to None unless ``feature_type == "positive"``.
    * ``num_features`` [64] -- int or per-layer list; normalised to a
      per-layer list by ``post_init``.
    * ``recurrence`` [True].
    * ``_attn_implementation`` ["eager" unless already set by the parent].

    GPT-2 style names in ``ConfigAttrAlias`` (``n_embd``, ``n_head``,
    ``n_layer``, ``attn_pdrop``, ``n_positions``) alias the GPT-NeoX names for
    both reads and writes.
    """
    use_linear_attn: bool
    n_positions: int

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_linear_attn = kwargs.get("use_linear_attn", False)
        self.feature_type = kwargs.get("feature_type", "fourier")
        self.coef_type = kwargs.get("coef_type", "standard")
        self.adaptive_shift = kwargs.get("adaptive_shift")
        self.num_features = kwargs.get("num_features", 64)
        self.recurrence = kwargs.get("recurrence", True)
        self._attn_implementation = kwargs.get(
            "_attn_implementation", 
            self._attn_implementation if hasattr(self, "_attn_implementation") else "eager"
        )
        self.post_init()
    
    def __getattr__(self, name):
        # Only reached when normal lookup fails, so aliases never shadow real attributes.
        if name in ConfigAttrAlias:
            return getattr(self, ConfigAttrAlias[name])
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
            
    def __setattr__(self, name, value):
        # Mirror ``__getattr__``: writing an alias updates the canonical attribute.
        # Storing the alias itself would shadow ``__getattr__`` and let the two
        # names drift apart.
        if name in ConfigAttrAlias:
            return setattr(self, ConfigAttrAlias[name], value)
        return super().__setattr__(name, value)
    
    def update(self, config_dict):
        """Update attributes from ``config_dict``, then re-normalise derived options."""
        super().update(config_dict)
        self.post_init()

    def post_init(self):
        """Normalise ``num_features`` and ``adaptive_shift``.

        Safe to call repeatedly: ``update`` re-runs it, and configs loaded from
        a dict already carry ``num_features`` as a list.
        """
        head_dim = self.hidden_size // self.num_attention_heads
        if isinstance(self.num_features, int):
            self.num_features = [self.num_features] * self.num_hidden_layers

        # Compare per layer so an already-normalised list does not trigger the
        # warning again on every re-run.
        expected = [head_dim] * self.num_hidden_layers
        if self.feature_type == "dijiang" and self.num_features != expected:
            print("Warning: num_features is set to the same value as head_dim when feature_type == 'dijiang'")
            self.num_features = expected
        
        if self.feature_type != "positive" and self.adaptive_shift is not None:
            print("Warning: adaptive_shift is set to None when feature_type != 'positive'")
            self.adaptive_shift = None

class ExtendedGPTNeoXAttention(GPTNeoXAttention):
    """Eager softmax GPT-NeoX attention that can expose intermediate tensors.

    ``forward``'s ``output_attentions`` may be:

    * a bool -- when True, the attention weights are appended to the outputs;
    * a list of keys -- a dict holding the requested entries of the
      ``attn_info`` dict built by ``_attn`` is appended instead.  Available keys
      are ``"query"``, ``"key"``, ``"weight"``, and either ``"qk"`` (default) or
      ``"kernel"`` (when kernel-grad retention is enabled).

    GPT-2 style names in ``AttnAttrAlias`` (``attn_dropout``) alias the GPT-NeoX
    attribute names for both reads and writes.
    """
    _retain_kernel_grad: bool

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        # When True, ``_attn`` builds the softmax by hand and retains the gradient
        # of the unnormalised kernel so it can be inspected after backward().
        self._retain_kernel_grad = False

    def __getattr__(self, name):
        # Only reached when regular attribute / nn.Module lookup fails.
        if name in AttnAttrAlias:
            return getattr(self, AttnAttrAlias[name])
        return super().__getattr__(name)
        
    def __setattr__(self, name, value):
        # Mirror ``__getattr__`` so alias writes update the canonical attribute
        # instead of registering a second, disconnected attribute/module.
        if name in AttnAttrAlias:
            return setattr(self, AttnAttrAlias[name], value)
        return super().__setattr__(name, value)

    def activate_or_deactivate_kernel_grad(self, activate_or_deactivate: bool):
        """Enable or disable gradient retention on the attention kernel in ``_attn``."""
        self._retain_kernel_grad = activate_or_deactivate

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Eager scaled dot-product attention that also returns diagnostics.

        Args:
            query: 4-D tensor ``(batch, heads, q_len, head_size)``.
            key: tensor with ``k_len`` in dim -2, viewable as
                ``(batch * heads, k_len, head_size)``.
            value: presumably ``(batch, heads, k_len, head_size)`` -- TODO confirm.
            attention_mask: optional additive mask; its last dim is cropped to ``k_len``.
            head_mask: optional multiplicative mask applied to the weights.

        Returns:
            ``(attn_output, attn_info)``.  ``attn_output`` is ``weights @ value``
            with dims 1 and 2 swapped, made contiguous.  ``attn_info`` holds the
            original ``query``/``key``, the post-dropout ``weight``, and either the
            pre-softmax scores ``qk`` or the unnormalised ``kernel``.
        """
        attn_info = {"query": query, "key": key}

        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        # Fold heads into the batch dimension for a single batched matmul.
        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        # scores = 1.0 * zeros + norm_factor * (Q @ K^T)
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.norm_factor,
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        if attention_mask is not None: 
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_scores = attn_scores + causal_mask

        if self._retain_kernel_grad:
            # Softmax computed explicitly (max-shifted exp, then normalised) so the
            # intermediate kernel tensor exists and can retain its gradient.
            shift = attn_scores.amax(dim=-1, keepdim=True)
            kernel = torch.exp(attn_scores - shift)
            kernel.requires_grad_(True)
            kernel.retain_grad()
            attn_info["kernel"] = kernel
            attn_weights = kernel / kernel.sum(dim=-1, keepdim=True)
        else:
            attn_info["qk"] = attn_scores
            attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_weights = self.attention_dropout(attn_weights)

        attn_info["weight"] = attn_weights
        attn_output = torch.matmul(attn_weights, value)

        attn_output = attn_output.transpose(1, 2).contiguous()

        return attn_output, attn_info
    
    def forward(self, hidden_states, attention_mask, position_ids, 
                head_mask=None, layer_past=None, use_cache=False, 
                output_attentions: Union[bool, list[str]] = False, 
                padding_mask=None, cache_position=None, position_embeddings=None
                ):
        """Run projections + RoPE, eager attention and the output projection.

        Only the ``"eager"`` attention implementation is supported (asserted).
        ``padding_mask`` is accepted for signature compatibility and unused.

        Returns:
            ``(attn_output, present)``, plus either the attention weights
            (``output_attentions=True``) or a dict of selected ``attn_info``
            entries (``output_attentions`` is a list of keys).
        """
        bsz, seq_len, _ = hidden_states.shape

        query, key, value, present = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        attention_type = self.config._attn_implementation
        assert attention_type == "eager", f"Unsupported attention implementation: {attention_type}"

        attn_output, attn_info = self._attn(query, key, value, attention_mask, head_mask)

        # Merge heads: (bsz, seq_len, heads, head_size) -> (bsz, seq_len, hidden).
        attn_output = attn_output.contiguous()
        attn_output = attn_output.view(bsz, seq_len, -1)
        attn_output = self.dense(attn_output)
        
        outputs = (attn_output, present)
        if isinstance(output_attentions, bool) and output_attentions:
            outputs = outputs + (attn_info["weight"],)
        elif isinstance(output_attentions, list):
            outputs = outputs + ({k: v for k, v in attn_info.items() if k in output_attentions},)
        
        return outputs

class GPTNeoXLinearAttention(GPTNeoXAttention):
    """GPT-NeoX attention layer whose attention step is ``linear_attention``.

    QKV projection, rotary embedding and the output projection come from
    ``GPTNeoXAttention``; the attention itself is delegated to
    ``linear_attention`` with this layer's ``FeatureNet``.  ``transplant`` copies
    the projections from an existing ``GPTNeoXAttention``.

    GPT-2 style names in ``AttnAttrAlias`` (``attn_dropout``) alias the GPT-NeoX
    attribute names for both reads and writes.
    """
    config: ExtendedGPTNeoXConfig
    bias: torch.Tensor
    kernel_clip: torch.Tensor
    recurrence: bool

    def __init__(self, config: ExtendedGPTNeoXConfig, layer_idx=None):
        # Validate before anything indexes by layer_idx: previously
        # ``config.num_features[None]`` raised an opaque TypeError first, so this
        # check could never fire.
        if layer_idx is None:
            raise ValueError("layer_idx must be specified")
        super().__init__(config, layer_idx=layer_idx)

        # Buffer (moves with .to()/device) passed to ``linear_attention`` as
        # ``kernel_clip``; presumably a lower clamp on kernel values -- confirm.
        self.register_buffer("kernel_clip", torch.tensor(1e-30))
        self.feature_type = config.feature_type
        self.num_features = config.num_features[layer_idx]
        # Starts in non-recurrent mode; ``config.recurrence`` is not consulted
        # here.  Toggle with ``switch_recurrence``.
        self.recurrence = False

        self.feature_net = FeatureNet(config, layer_idx)

    def __getattr__(self, name):
        # Only reached when regular attribute / nn.Module lookup fails.
        if name in AttnAttrAlias:
            return getattr(self, AttnAttrAlias[name])
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        # Mirror ``__getattr__`` so alias writes update the canonical attribute.
        if name in AttnAttrAlias:
            return setattr(self, AttnAttrAlias[name], value)
        return super().__setattr__(name, value)

    def switch_recurrence(self, recurrence: bool):
        """Select whether ``forward`` builds a ``LinearAttnCache`` when ``use_cache`` is set."""
        self.recurrence = bool(recurrence)

    def forward(self, hidden_states, attention_mask, position_ids, 
                head_mask=None, layer_past=None, use_cache=False, 
                output_attentions: Union[bool, list[str]] = False, 
                padding_mask=None, cache_position=None, position_embeddings=None):
        """Project, apply RoPE, run ``linear_attention`` and project the output.

        Key/value caching through ``layer_past`` is not used: projections run
        with ``layer_past=None``, and an incoming ``DynamicCache`` must be empty.
        When ``use_cache`` and ``self.recurrence`` are both set, a fresh
        ``LinearAttnCache`` is given to ``linear_attention`` and returned as
        ``present``; otherwise ``present`` is an empty ``DynamicCache``.
        ``padding_mask`` is accepted for signature compatibility and unused.

        Returns:
            ``(attn_output, present)``, plus either ``attn_info["weight"]``
            (``output_attentions=True``) or a dict of selected ``attn_info``
            entries (``output_attentions`` is a list of keys).
        """
        bsz, seq_len, _ = hidden_states.shape

        query, key, value, _ = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=None,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        if isinstance(layer_past, DynamicCache):
            assert layer_past.get_seq_length() == 0
        # NOTE(review): any cache passed in is discarded and a new one created on
        # every call -- confirm this is intended for incremental decoding.
        if use_cache and self.recurrence:
            layer_past = LinearAttnCache()
        else:
            layer_past = None

        attn_output, attn_info = linear_attention(
            self.feature_net, self.attention_dropout, query, key, value, attention_mask, head_mask, 
            cache=layer_past, adaptive_shift=self.config.adaptive_shift, bias=self.bias, kernel_clip=self.kernel_clip
        )

        present = layer_past if layer_past is not None else DynamicCache()

        attn_output = attn_output.contiguous()
        attn_output = attn_output.view(bsz, seq_len, -1)
        attn_output = self.dense(attn_output)

        # The previous code indexed ``attn_info`` as a dict in the bool branch and
        # as a sequence of dicts (``[-1]``) in the list branch, so one branch always
        # failed.  Accept either form and use the last entry of a sequence.
        info = attn_info[-1] if isinstance(attn_info, (list, tuple)) else attn_info

        outputs = (attn_output, present)
        if isinstance(output_attentions, bool) and output_attentions:
            outputs = outputs + (info["weight"],)
        elif isinstance(output_attentions, list):
            outputs = outputs + ({k: v for k, v in info.items() if k in output_attentions},)

        return outputs 

    def transplant(self, original: GPTNeoXAttention):
        """Share ``original``'s QKV/output projections and dropout with this layer.

        Asserts that both layers carry the same ``bias`` buffer.
        """
        self.query_key_value = original.query_key_value
        self.dense = original.dense
        self.attention_dropout = original.attention_dropout
        assert (self.bias == original.bias).all().item()

class ExtendedGPTNeoXLayer(GPTNeoXLayer):
    """GPT-NeoX decoder layer whose attention class depends on ``config.use_linear_attn``.

    ``GPTNeoXLayer.__init__`` is bypassed (only ``nn.Module.__init__`` runs), and
    the sub-modules are built here instead: layer norms, dropouts, attention
    (``GPTNeoXLinearAttention`` or ``ExtendedGPTNeoXAttention``) and the MLP.
    The GPT-2 style name ``attn`` aliases ``attention`` for reads and writes.
    """

    def __init__(self, config: ExtendedGPTNeoXConfig, layer_idx=None):
        nn.Module.__init__(self)

        hidden_size = config.hidden_size
        eps = config.layer_norm_eps
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
        self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)

        attention_cls = GPTNeoXLinearAttention if config.use_linear_attn else ExtendedGPTNeoXAttention
        self.attention = attention_cls(config, layer_idx)

        self.mlp = GPTNeoXMLP(config)

    def __getattr__(self, name):
        # Reached only after regular / nn.Module lookup has failed.
        target = LayerAttrAlias.get(name)
        if target is None:
            return super().__getattr__(name)
        return getattr(self, target)

    def __setattr__(self, name, value):
        target = LayerAttrAlias.get(name)
        if target is None:
            return super().__setattr__(name, value)
        return setattr(self, target, value)

    def transplant(self, original: GPTNeoXLayer):
        """Reuse ``original``'s norms, dropouts and MLP; transplant its attention weights."""
        for attr_name in (
            "input_layernorm",
            "post_attention_layernorm",
            "post_attention_dropout",
            "post_mlp_dropout",
        ):
            setattr(self, attr_name, getattr(original, attr_name))
        self.attention.transplant(original.attention)
        self.mlp = original.mlp

class ExtendedGPTNeoXModel(GPTNeoXModel):
    """GPT-NeoX backbone assembled from ``ExtendedGPTNeoXLayer`` blocks.

    ``GPTNeoXModel.__init__`` is bypassed (``GPTNeoXPreTrainedModel.__init__`` is
    called directly), and the embedding, layers, final norm and rotary embedding
    are built here.  GPT-2 style names in ``ModelAttrAlias`` (``h``, ``ln_f``)
    alias ``layers`` and ``final_layer_norm`` for reads and writes.
    """
    h: nn.ModuleList

    def __init__(self, config):
        GPTNeoXPreTrainedModel.__init__(self, config)
        self.config = config

        self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
        self.emb_dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([ExtendedGPTNeoXLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.rotary_emb = GPTNeoXRotaryEmbedding(config=config)

        self._attn_implementation = config._attn_implementation

        self.gradient_checkpointing = False

        self.post_init()

    def __getattr__(self, name):
        # Reached only after regular / nn.Module lookup has failed.
        if name in ModelAttrAlias:
            return getattr(self, ModelAttrAlias[name])
        return super().__getattr__(name)
        
    def __setattr__(self, name, value):
        # Mirror ``__getattr__`` so e.g. ``model.h = ...`` replaces ``layers``
        # instead of registering a second, disconnected module.
        if name in ModelAttrAlias:
            return setattr(self, ModelAttrAlias[name], value)
        return super().__setattr__(name, value)
        
    def transplant(self, original: GPTNeoXModel):
        """Reuse ``original``'s embeddings and final norm; transplant every layer.

        Raises:
            ValueError: if the two models have a different number of layers.
                ``zip`` would otherwise silently leave some layers untransplanted.
        """
        if len(self.layers) != len(original.layers):
            raise ValueError(
                f"Cannot transplant: model has {len(self.layers)} layers, "
                f"original has {len(original.layers)}"
            )
        self.embed_in = original.embed_in
        self.emb_dropout = original.emb_dropout
        for layer, original_layer in zip(self.layers, original.layers):
            layer.transplant(original_layer)
        self.final_layer_norm = original.final_layer_norm

class ExtendedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    """Causal-LM head on top of ``ExtendedGPTNeoXModel``.

    GPT-2 style names in ``CausalLMAttrAlias`` (``transformer``, ``lm_head``)
    alias ``gpt_neox`` and ``embed_out`` for reads and writes.  ``transplant``
    converts a pretrained ``GPTNeoXForCausalLM`` into a linear-attention model.
    """
    _tied_weights_keys = ["embed_out.weight"]
    transformer: ExtendedGPTNeoXModel

    def __init__(self, config: ExtendedGPTNeoXConfig):
        GPTNeoXPreTrainedModel.__init__(self, config)

        self.gpt_neox = ExtendedGPTNeoXModel(config)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def __getattr__(self, name):
        # Reached only after regular / nn.Module lookup has failed.
        if name in CausalLMAttrAlias:
            return getattr(self, CausalLMAttrAlias[name])
        return super().__getattr__(name)
        
    def __setattr__(self, name, value):
        # Mirror ``__getattr__`` so alias writes update ``gpt_neox`` / ``embed_out``.
        if name in CausalLMAttrAlias:
            return setattr(self, CausalLMAttrAlias[name], value)
        return super().__setattr__(name, value)

    def transplant(self, original: GPTNeoXForCausalLM):
        """Copy weights from ``original`` (stock or extended GPT-NeoX causal LM)."""
        assert self.config.use_linear_attn, "Only linear attention is supported"

        # Use ``gpt_neox`` explicitly: a stock GPTNeoXForCausalLM stores its
        # backbone there and has no ``transformer`` attribute (that alias exists
        # only on this class).
        self.gpt_neox.transplant(original.gpt_neox)
        self.embed_out = original.embed_out
    
    def switch_recurrence(self, recurrence: bool):
        """Set the recurrence mode on every layer's attention module.

        Uses the attention's own ``switch_recurrence`` when its class defines one.
        Otherwise the ``recurrence`` attribute is set directly; that is the flag
        ``GPTNeoXLinearAttention.forward`` reads, and softmax attention ignores it.
        """
        for layer in self.gpt_neox.layers:
            attention = layer.attention
            if hasattr(type(attention), "switch_recurrence"):
                attention.switch_recurrence(recurrence)
            else:
                attention.recurrence = recurrence