import numpy as np
import torch
from torch import nn
from transformers import AutoModelForCausalLM
import sys
import time
import os
from torch.cuda.amp import autocast

from torch import matmul
from transformers.models.mistral.modeling_mistral import MistralConfig, MistralDecoderLayer, MistralForCausalLM, MistralRMSNorm
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.processing_utils import Unpack
from typing import Callable, List, Optional, Tuple, Union

from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights

def get_mistral(model):
    """Load a pretrained causal LM with torch's random weight initialisers disabled.

    Args:
        model: Model id or local path, forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained`` with ``torch_dtype='auto'``.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute set to 2048.

    NOTE(review): the ``torch.nn.init`` patches below are process-wide and are
    never undone, so every module constructed afterwards (anywhere in the
    process) also skips ``kaiming_uniform_`` / ``uniform_`` / ``normal_``.
    """
    import torch

    def _no_init(*_args, **_kwargs):
        # Presumably to speed up loading: the pretrained checkpoint replaces
        # these values anyway, so random initialisation would be wasted work.
        return None

    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, _no_init)

    from transformers import AutoModelForCausalLM

    loaded = AutoModelForCausalLM.from_pretrained(model, torch_dtype='auto')
    loaded.seqlen = 2048
    return loaded

def mistral_fuse_rms_single_layer(layer):
    """Fold one decoder layer's RMSNorm gains into the linears that consume them.

    For each norm with gain ``g``, every following linear weight ``W`` becomes
    ``W @ diag(g)``, i.e. column ``j`` of ``W`` is scaled by ``g[j]``. The norm's
    gain is then reset to ones. This preserves the layer's output assuming the
    norm applies ``g`` as an elementwise multiply on its normalised output
    (TODO confirm for the RMSNorm implementation in use).

    Affected pairs:
      * ``input_layernorm`` -> ``self_attn.q_proj`` / ``k_proj`` / ``v_proj``
      * ``post_attention_layernorm`` -> ``mlp.up_proj`` / ``mlp.gate_proj``

    The column scaling is a broadcast multiply instead of a matmul against a
    dense ``hidden x hidden`` diagonal matrix. The gain is cast to each
    weight's dtype/device, so weights keep their dtype even when the gain is
    stored differently.

    Args:
        layer: Decoder layer whose weights are modified in place.
    """

    def _fold(norm, linears):
        # Read the gain before it is overwritten with ones below.
        gain = norm.weight.data
        for linear in linears:
            weight = linear.weight.data
            # Broadcasting over rows == weight @ torch.diag(gain).
            linear.weight.data = weight * gain.to(device=weight.device, dtype=weight.dtype)
        norm.weight.data = torch.ones_like(gain)

    attn = layer.self_attn
    _fold(layer.input_layernorm, (attn.q_proj, attn.k_proj, attn.v_proj))
    _fold(layer.post_attention_layernorm, (layer.mlp.up_proj, layer.mlp.gate_proj))

def apply_R1_transform(layer, R1):
    """Fuse the residual-stream rotation ``R1`` into one decoder layer's weights.

    Linears that read the residual stream (``q/k/v_proj``, ``up_proj``,
    ``gate_proj``) become ``W @ R1``. Linears that write back into it
    (``o_proj``, ``down_proj``) become ``R1.T @ W``.

    For orthogonal ``R1`` the layer then accepts ``x @ R1`` and emits its
    residual update already multiplied by ``R1``: ``(x R1)(W R1)^T = x W^T``.
    This assumes the preceding RMSNorm gains have been folded to ones, e.g. by
    ``mistral_fuse_rms_single_layer``.

    ``R1`` is cast to each weight's device and dtype before multiplying, so an
    fp32 rotation can be applied to e.g. a bf16 or GPU-resident layer.

    Args:
        layer: Decoder layer modified in place.
        R1: Square ``hidden x hidden`` rotation matrix.
    """
    attn, mlp = layer.self_attn, layer.mlp

    # Input side: the rotation acts on the in_features columns.
    for linear in (attn.q_proj, attn.k_proj, attn.v_proj, mlp.up_proj, mlp.gate_proj):
        weight = linear.weight.data
        linear.weight.data = weight @ R1.to(device=weight.device, dtype=weight.dtype)

    # Output side: the rotation acts on the out_features rows.
    for linear in (attn.o_proj, mlp.down_proj):
        weight = linear.weight.data
        linear.weight.data = R1.to(device=weight.device, dtype=weight.dtype).T @ weight

def apply_R2_transform(layer, R2_list):
    """Fuse per-head value rotations into ``v_proj`` and ``o_proj`` of one layer.

    ``R2_list`` holds one square matrix per block of ``v_proj`` output rows
    (presumably one ``head_dim x head_dim`` rotation per KV head, in head order).
    The updates are:

    * ``v_proj.weight <- blockdiag(R2_list).T @ v_proj.weight``, which rotates
      each value head's output.
    * ``o_proj.weight <- o_proj.weight @ blockdiag(R2 repeated)``, which undoes
      the rotation on o_proj's input. Each R2 is repeated ``n_rep`` times
      consecutively, matching query heads that share a KV head via consecutive
      grouping (as HF ``repeat_kv`` does; confirm for other attention kernels).

    ``n_rep`` is ``o_proj.in_features // v_proj.out_features``, i.e. attention
    width over KV width. The previous formula divided ``v_proj.in_features``
    (``hidden_size``) instead, which is wrong whenever
    ``hidden_size != num_heads * head_dim``.

    Args:
        layer: Decoder layer modified in place.
        R2_list: Sequence of square rotation matrices.

    Raises:
        ValueError: If o_proj's input width is not a multiple of v_proj's
            output width.
    """
    v_weight = layer.self_attn.v_proj.weight.data
    o_weight = layer.self_attn.o_proj.weight.data

    kv_width = v_weight.shape[0]    # v_proj out_features
    attn_width = o_weight.shape[1]  # o_proj in_features
    if kv_width == 0 or attn_width % kv_width != 0:
        raise ValueError(
            f"o_proj input width {attn_width} is not a multiple of v_proj output width {kv_width}"
        )
    n_rep = attn_width // kv_width

    # Cast blocks to each weight's device/dtype so matmul never mixes dtypes.
    v_blocks = [r2.to(device=v_weight.device, dtype=v_weight.dtype) for r2 in R2_list]
    o_blocks = [
        r2.to(device=o_weight.device, dtype=o_weight.dtype)
        for r2 in R2_list
        for _ in range(n_rep)
    ]
    R2_transform_v = torch.block_diag(*v_blocks)
    R2_transform_o = torch.block_diag(*o_blocks)

    layer.self_attn.v_proj.weight.data = R2_transform_v.T @ v_weight
    layer.self_attn.o_proj.weight.data = o_weight @ R2_transform_o

def mistral_fuse_rotation_single_layer(layer, R1=None, R2_list=None):
    """Fuse whichever rotations are supplied into ``layer``'s weights, in place.

    Args:
        layer: Decoder layer to modify.
        R1: Residual-stream rotation; when not ``None`` it is passed to
            ``apply_R1_transform``.
        R2_list: Per-head value rotations; when not ``None`` they are passed
            to ``apply_R2_transform`` (after the R1 step).

    Returns:
        None.
    """
    fuse_r1 = R1 is not None
    fuse_r2 = R2_list is not None

    if fuse_r1:
        apply_R1_transform(layer, R1)
    if fuse_r2:
        apply_R2_transform(layer, R2_list)


class RotatedMistralDecoderLayer(MistralDecoderLayer):
    """``MistralDecoderLayer`` whose body runs in a basis rotated by ``R1``.

    The constructor does not create ``R1``. ``replace_mistral_layer`` attaches
    it afterwards, as an ``nn.Parameter`` or ``None``.

    When ``R1`` is present and not ``None``:

    * incoming hidden states are right-multiplied by ``R1`` before the
      pre-norm attention + MLP body;
    * the result is right-multiplied by ``R1.T`` on the way out, which is the
      inverse when ``R1`` is orthogonal.

    When ``R1`` is missing or ``None``, both rotation steps are skipped.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the (optionally rotated) decoder layer.

        Returns:
            ``(hidden_states,)``, plus ``self_attn_weights`` when
            ``output_attentions`` is truthy.
        """
        # R1 is attached after construction. Tolerate a layer that never got
        # one, instead of failing with AttributeError from nn.Module.__getattr__.
        R1 = getattr(self, "R1", None)

        # hidden_states is expected to be (bsz, length, hidden_dim)
        if R1 is not None:
            hidden_states = matmul(hidden_states, R1)

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention, then residual add.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully connected (pre-norm MLP), then residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Map back out of the rotated basis (R1.T inverts R1 when R1 is orthogonal).
        if R1 is not None:
            hidden_states = matmul(hidden_states, R1.T)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs

def replace_mistral_layer(model, layer_idx, R1=None):
    """Replace ``model.model.layers[layer_idx]`` with a ``RotatedMistralDecoderLayer``.

    The new layer is built from ``model.config`` and placed on the device of
    the layer it replaces, in ``model.config.torch_dtype``. It is then loaded
    with that layer's state dict (``strict=True``).

    The rotation is stored on the new layer as ``R1``:

    * when ``R1`` is given, as an ``nn.Parameter`` with that same
      dtype/device;
    * when ``R1`` is ``None``, as ``None``. Previously the ``R1=None`` default
      crashed on ``None.to(...)``.

    Args:
        model: Causal LM exposing ``model.model.layers`` and ``model.config``.
        layer_idx: Index of the decoder layer to replace.
        R1: Optional square rotation matrix.
    """
    old_layer = model.model.layers[layer_idx]
    dtype = model.config.torch_dtype
    # Follow the replaced layer's device. Otherwise a GPU-resident model would
    # get a CPU layer, and with a multi-device split model.device can differ
    # from this layer's device.
    device = next(old_layer.parameters()).device

    rotated_layer = RotatedMistralDecoderLayer(model.config, layer_idx).to(device=device, dtype=dtype)
    rotated_layer.load_state_dict(old_layer.state_dict(), strict=True)

    if R1 is None:
        rotated_layer.R1 = None
    else:
        rotated_layer.R1 = torch.nn.Parameter(R1.to(device=device, dtype=dtype))
    model.model.layers[layer_idx] = rotated_layer

def load_rotated_mistral(config_path, state_dict_path):
    """Build a rotated Mistral model from a config and load a saved state dict.

    Steps:

    1. Create the model with ``AutoModelForCausalLM.from_config``.
    2. Wrap every decoder layer with ``replace_mistral_layer``, using a fresh
       identity ``R1`` as a placeholder.
    3. Load ``state_dict_path`` with ``strict=False``.

    Because the base weights come from ``from_config`` rather than a
    pretrained checkpoint, any key the file lacks keeps whatever value
    ``from_config`` produced. Missing and unexpected keys are therefore
    reported with ``warnings.warn`` instead of being dropped silently.

    Args:
        config_path: Path or id accepted by ``AutoConfig.from_pretrained``.
        state_dict_path: File readable by ``torch.load``.

    Returns:
        The model with the state dict applied.
    """
    import warnings

    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForCausalLM.from_config(config)

    for idx in range(config.num_hidden_layers):
        # A fresh eye() per layer: replace_mistral_layer may wrap the tensor
        # without copying, and shared storage would make every layer's R1 alias
        # the same values after loading.
        replace_mistral_layer(model, idx, torch.eye(config.hidden_size))

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)

    if result.missing_keys:
        warnings.warn(
            f"load_rotated_mistral: {len(result.missing_keys)} model keys were not found in "
            f"{state_dict_path!r} and keep their from_config values, e.g. {result.missing_keys[:5]}"
        )
    if result.unexpected_keys:
        warnings.warn(
            f"load_rotated_mistral: {len(result.unexpected_keys)} keys in {state_dict_path!r} "
            f"match no model parameter and were ignored, e.g. {result.unexpected_keys[:5]}"
        )

    return model

def load_rotated_mistral_fast(config_path, state_dict_path):
    """Load pretrained Mistral weights, wrap every layer, then apply a state dict.

    Steps:

    1. Load the base weights through ``get_mistral(config_path)``.
    2. Replace each decoder layer via ``replace_mistral_layer``, using a fresh
       identity ``R1``.
    3. Load ``state_dict_path`` with ``strict=False``.

    Keys absent from the file keep their pretrained (or identity) values.
    Keys in the file that match no parameter used to be discarded silently;
    they are now reported with ``warnings.warn``.

    Args:
        config_path: Model id or path forwarded to ``get_mistral``.
        state_dict_path: File readable by ``torch.load``.

    Returns:
        The model with the state dict applied.
    """
    import warnings

    model = get_mistral(config_path)

    for idx in range(model.config.num_hidden_layers):
        # Fresh eye() per layer so no two layers' R1 parameters share storage.
        replace_mistral_layer(model, idx, torch.eye(model.config.hidden_size))

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)

    if result.unexpected_keys:
        warnings.warn(
            f"load_rotated_mistral_fast: {len(result.unexpected_keys)} keys in {state_dict_path!r} "
            f"match no model parameter and were ignored, e.g. {result.unexpected_keys[:5]}"
        )

    return model