import torch
import torch.nn as nn
from typing import Optional, Dict, List
from transformers import AutoModelForCausalLM, AutoConfig
import copy
import math
import os
try:
    from transformers.activations import ACT2FN
except ImportError:
    from transformers.models.mixtral.modeling_mixtral import ACT2FN

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

def randomly_initialize_module(module):
    """Re-initialise, in place, every trainable parameter of ``module`` (recursively).

    Parameters with two or more dimensions get ``kaiming_uniform_``; lower-rank
    parameters (vectors, scalars) get ``uniform_`` on [0, 1). Parameters whose
    ``requires_grad`` is False are left untouched.
    """
    trainable = (p for _, p in module.named_parameters(recurse=True) if p.requires_grad)
    for param in trainable:
        init_fn = torch.nn.init.kaiming_uniform_ if param.data.dim() >= 2 else torch.nn.init.uniform_
        init_fn(param)

import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

def build_causal_attention_mask(attention_mask, min_value=-1e9):
    """Combine a 2-D padding mask with a causal mask into an additive 4-D mask.

    Args:
        attention_mask: Tensor of shape ``(batch, seq_len)``; non-zero / True marks a
            key position that may be attended to. Bool, integer and floating masks
            are accepted (bool masks previously crashed on ``1.0 - mask``).
        min_value: Value written at masked positions (default ``-1e9``). Note that
            ``-1e9`` overflows to ``-inf`` in float16.

    Returns:
        Tensor of shape ``(batch, 1, seq_len, seq_len)``: 0 where query ``i`` may attend
        key ``j`` (``j <= i`` and key ``j`` not padded), ``min_value`` elsewhere. Floating
        masks keep their dtype; other masks produce the default floating dtype.
    """
    _, seq_len = attention_mask.size()

    # Work in a floating dtype so the arithmetic below is valid for bool masks too.
    mask_dtype = attention_mask.dtype if attention_mask.is_floating_point() else torch.get_default_dtype()
    keep = attention_mask.to(mask_dtype)

    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=mask_dtype, device=attention_mask.device))
    causal = causal.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
    keep = keep.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]

    allowed = causal * keep
    return (1.0 - allowed) * min_value

class DummySelfAttention(nn.Module):
    """Identity stand-in for a self-attention module.

    ``forward`` returns the hidden states unchanged inside an
    ``(output, attn_weights, past_key_value)`` triple (the last two are ``None``), so
    callers can take ``[0]`` exactly as with a real attention layer. The mask and any
    extra keyword arguments are ignored.
    """

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        no_weights = None
        no_cache = None
        return hidden_states, no_weights, no_cache

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotate the last axis by halves: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the first half is
    the shorter one.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding (RoPE) to a query/key pair.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table; its first axis is indexed by ``position_ids``.
        sin (`torch.Tensor`): Sine table, indexed the same way as ``cos``.
        position_ids (`torch.Tensor`): Positions of the tokens in ``q``/``k`` (e.g. offset
            positions when decoding with a KV cache). Passing ``None`` indexes with ``None``,
            which only inserts a leading size-1 axis, i.e. the whole table is used.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis inserted into the selected
            cos/sin so they broadcast against q/k. With ``cos[position_ids]`` shaped
            ``[batch, seq_len, head_dim]``, use 1 for q/k shaped
            ``[batch, heads, seq_len, head_dim]`` and 2 for ``[batch, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, the rotated query and key.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. The input is returned
    as-is when ``n_rep == 1`` (it must still be 4-D).
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() creates a view; reshape() then materialises the repeated heads.
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)

class MixtralRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Scales by ``1 / sqrt(mean(x**2) + eps)`` over the last axis, computed in float32,
    then multiplies by a learned per-feature ``weight`` (initialised to ones). There is
    no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before applying the weight, matching the upstream Mixtral ordering.
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

class MixtralRotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embedding (RoPE).

    ``inv_freq[i] = 1 / base ** (2i / dim)``; the tables hold ``cos``/``sin`` of
    ``position * inv_freq`` duplicated along the last axis, giving shape
    ``(seq_len, dim)``. The cache is built for ``max_position_embeddings`` positions
    and regrown on demand when a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached``/``sin_cached`` buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: Reference tensor for dtype/device; expected ``[bs, num_attention_heads,
                seq_len, head_size]``.
            seq_len: Number of positions. Defaults to ``x.shape[-2]`` (previously the
                default ``None`` raised a TypeError on the comparison below).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

class MixtralAttention(nn.Module):
    """
    Multi-head attention with grouped key/value heads and rotary position embeddings,
    in the eager Mixtral layout (explicit matmul + softmax).

    No sliding-window logic is implemented here, despite the upstream docstring this
    class was adapted from. Any windowing or padding must be encoded in
    ``attention_mask``, which is added to the raw attention logits.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # This module defines no logger, so use the stdlib warnings machinery instead.
            import warnings

            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.",
                stacklevel=2,
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = MixtralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        Args:
            attention_mask: Additive mask, sliced to the key length on its last axis and
                added to the logits; presumably ``(bsz, 1, q_len, kv_len)`` — TODO confirm.
            position_ids: Token positions used to pick RoPE rows. Defaults to
                ``past_len .. past_len + q_len - 1`` (batch axis of size 1).
            past_key_value: Optional cache object exposing ``get_usable_length`` and
                ``update``; requires ``layer_idx`` to be set.
            output_attentions: When False, the returned attention weights are ``None``.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to ``past_key_value.update`` via ``cache_kwargs``.

        Returns:
            ``(attn_output, attn_weights_or_None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        past_len = 0
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            past_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            kv_seq_len += past_len
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        if position_ids is None:
            # New tokens sit after any cached ones. Without a cache this selects rows
            # 0..q_len-1 of cos/sin, the same values ``cos[None]`` used to produce.
            position_ids = torch.arange(past_len, kv_seq_len, device=hidden_states.device).unsqueeze(0)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SharedExpertLayer(nn.Module):
    """Linear map whose weight is rebuilt from shared factors and per-task cores.

    For every core ``C`` the effective weight is ``W_shared_out @ (C @ W_shared_in.T)``
    and the contribution is ``x @ weight.T``. The contributions of all cores are
    averaged. With ``use_residual`` an extra ``nn.Linear(input_dim, output_dim)`` is
    added on top. The shared factors are trainable only when ``grad`` is True; the
    cores are always trainable.
    """

    def __init__(self, W_shared_out, W_shared_in, core_tensors, input_dim, output_dim, use_residual=False, grad=False):
        super().__init__()

        def _as_param(tensor):
            param = nn.Parameter(tensor.contiguous())
            return param if grad else param.requires_grad_(False)

        self.W_shared_out = _as_param(W_shared_out)
        self.W_shared_in = _as_param(W_shared_in)

        self.cores = nn.ParameterList([nn.Parameter(c.contiguous()) for c in core_tensors])

        self.use_residual = use_residual
        if use_residual:
            self.residual = nn.Linear(input_dim, output_dim)
            nn.init.kaiming_uniform_(self.residual.weight, a=math.sqrt(5))

    def forward(self, x):
        per_core = []
        for core in self.cores:
            dense = self.W_shared_out @ (core @ self.W_shared_in.T)
            # Inputs follow the weight dtype; the last cast also feeds the residual branch.
            x = x.to(dense.dtype)
            per_core.append(x @ dense.T)

        out = torch.stack(per_core).mean(0)
        return out + self.residual(x) if self.use_residual else out

class MoEAdapterLayer(nn.Module):
    """Gated feed-forward block in the Mixtral expert layout, built from shared-factor layers.

    Computes ``w2(act(w1(x)) * w3(x))``, where each ``w*`` is a ``SharedExpertLayer``
    fed from ``shared_params['w*_out' / 'w*_in']`` and ``task_params['w*_core']``. The
    activation is looked up as ``activation_fn[config.hidden_act]``.
    """

    def __init__(self, shared_params: Dict[str, torch.Tensor], 
                 task_params: Dict[str, torch.Tensor], 
                 config,
                 hidden_dim: int, intermediate_dim: int,
                 use_residual=False, grad=False,
                 activation_fn=ACT2FN):
        super().__init__()

        def _make(out_key, in_key, core_key, in_dim, out_dim):
            return SharedExpertLayer(
                shared_params[out_key], shared_params[in_key], task_params[core_key],
                in_dim, out_dim, use_residual, grad=grad,
            )

        # gate projection: hidden -> intermediate
        self.w1 = _make('w1_out', 'w1_in', 'w1_core', hidden_dim, intermediate_dim)
        # down projection: intermediate -> hidden
        self.w2 = _make('w2_out', 'w2_in', 'w2_core', intermediate_dim, hidden_dim)
        # up projection: hidden -> intermediate
        self.w3 = _make('w3_out', 'w3_in', 'w3_core', hidden_dim, intermediate_dim)
        self.act_fn = activation_fn[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.w1(x))
        up = self.w3(x)
        return self.w2(gated * up)

class AdaptedMixtralLayer(nn.Module):
    """One decoder layer: pre-norm self-attention, then a pre-norm adapter feed-forward block.

    The Mixtral sparse-MoE block is replaced by a single ``MoEAdapterLayer`` built from
    shared/task factor tensors. Both sub-blocks are wrapped in residual connections.
    """

    def __init__(self, original_layer, shared_params, task_params, config, layer_idx, use_residual, grad):
        super().__init__()
        hidden_dim = config.hidden_size

        self.self_attn = MixtralAttention(config, layer_idx)

        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # FFN width: taken from the original MoE expert when one is supplied, else from the config.
        if original_layer is not None:
            intermediate_dim = original_layer.block_sparse_moe.experts[0].w1.out_features
        else:
            intermediate_dim = getattr(config, 'intermediate_size', 4096)

        self.moe_adapter = MoEAdapterLayer(
            shared_params, task_params, config, hidden_dim, intermediate_dim, use_residual=use_residual, grad=grad
        )

    @staticmethod
    def _normalize_mask(attention_mask, dtype):
        """Bring ``attention_mask`` to 4-D, additive form (or return ``None``).

        1-D masks become ``(1, 1, 1, S)`` and 2-D masks ``(B, 1, 1, S)``. Boolean masks
        (True = attend) are turned into 0 / ``finfo(dtype).min`` because adding a bool
        tensor to the logits would add +1.0 instead of masking. Other dtypes are assumed
        to be additive already and pass through unchanged.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() == 1:
            attention_mask = attention_mask[None, None, None, :]
        elif attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]
        if attention_mask.dtype == torch.bool:
            additive = torch.zeros(attention_mask.shape, dtype=dtype, device=attention_mask.device)
            attention_mask = additive.masked_fill(~attention_mask, torch.finfo(dtype).min)
        return attention_mask

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Run the layer; returns a 1-tuple ``(hidden_states,)``.

        ``attention_mask`` may be ``None`` (no masking at this level; previously this
        crashed on ``None.dim()``), 1-D, 2-D, or already 4-D.
        """
        attention_mask = self._normalize_mask(attention_mask, hidden_states.dtype)

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if isinstance(self.self_attn, DummySelfAttention):
            hidden_states = self.self_attn(hidden_states)[0]
        else:
            hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)[0]

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.moe_adapter(hidden_states)
        hidden_states = hidden_states + residual

        return (hidden_states,)

class AdaptedMixtral(nn.Module):
    """Mixtral-shaped decoder whose sparse-MoE blocks are replaced by shared-factor adapters.

    The token embedding is taken from ``base_model.model.embed_tokens``. Every decoder
    layer is an ``AdaptedMixtralLayer``. The factor checkpoints are only read for their
    shapes: each is replaced by ``0.01 * randn`` of the same shape. At the end of
    ``__init__`` every parameter with ``requires_grad=True`` is re-initialised by
    ``randomly_initialize_module``.

    Heads:
        * ``classification=False``: ``lm_head`` over the vocabulary.
        * ``classification=True``: ``classifier_single`` (``num_labels`` classes) for
          ``(B, S)`` inputs and ``classifier_mc`` (one score per choice) for
          ``(B, C, S)`` inputs.
    """

    def __init__(self, base_model: nn.Module, 
                 shared_param_paths: Dict[str, str], 
                 task_param_paths: Dict[str, Dict[str, str]],
                  num_labels,
                 n_layers: Optional[int] = None, device: str = "cpu", 
                 use_residual=False, grad=False,
                 classification=False):
        super().__init__()

        self.classification = classification

        self.config = base_model.config
        self.embed_tokens = base_model.model.embed_tokens
        self.device = device

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: self._random_like_checkpoint(path)
            for name, path in shared_param_paths.items()
        }
        # One list of cores per layer. Every path listed for a layer contributes a core;
        # the previous nested comprehension rebuilt key ``i`` per path and kept only the last.
        task_params = {
            name: {
                i: [self._random_like_checkpoint(p) for p in (plist if isinstance(plist, list) else [plist])]
                for i, plist in layer_paths.items()
            }
            for name, layer_paths in task_param_paths.items()
        }

        for i in range(total_layers):
            original_layer = None
            task_params_i = {name: task_params[name][i] for name in task_params}
            layer = AdaptedMixtralLayer(original_layer, shared_params, task_params_i, self.config, layer_idx=i, use_residual=use_residual, grad=grad)
            self.layers.append(layer)

        self.norm = MixtralRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        if not classification:
            self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)

        self.classifier_single = nn.Linear(self.config.hidden_size, num_labels)
        self.classifier_mc = nn.Linear(self.config.hidden_size, 1)

        del base_model
        torch.cuda.empty_cache()
        randomly_initialize_module(self)

    @staticmethod
    def _random_like_checkpoint(path):
        """Return ``0.01 * randn`` shaped like the tensor stored at ``path`` (float32, CPU).

        NOTE(review): ``torch.load`` unpickles the file; only use trusted checkpoints.
        """
        return 0.01 * torch.randn_like(torch.load(path, map_location="cpu").to(torch.float32))

    def set_device(self):
        self.to(self.device)

    def save_pretrained(self, save_directory):
        """Write ``pytorch_model.bin`` (state dict) and the config into ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        if hasattr(self, "config"):
            if hasattr(self.config, "save_pretrained"):
                self.config.save_pretrained(save_directory)
            else:
                import json
                # NOTE(review): json.dump only works if config is a plain JSON-serialisable object.
                with open(os.path.join(save_directory, "config.json"), "w") as f:
                    json.dump(self.config, f)

    @classmethod
    def from_pretrained(cls, load_directory, config=None, **kwargs):
        """Rebuild a model and load ``pytorch_model.bin`` from ``load_directory`` (non-strict, without ``lm_head``).

        NOTE(review): ``cls(config, **kwargs)`` passes ``config`` as ``base_model``, and
        ``__init__`` reads ``base_model.config`` and ``base_model.model.embed_tokens``. A
        plain ``AutoConfig`` has neither, so this presumably needs a base-model-like
        object — confirm intended usage.
        """
        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(load_directory)
        model = cls(config, **kwargs)
        state_dict = torch.load(os.path.join(load_directory, 'pytorch_model.bin'), map_location='cpu')
        for k in list(state_dict.keys()):
            if k.startswith('lm_head'):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=False)

        return model

    def _prepare_attention_mask(self, attention_mask, input_shape, hidden_states):
        """Build the 4-D additive causal mask for ``input_shape``.

        A 2-D padding mask (non-zero / True = keep) is folded into the causal mask.
        ``None`` yields the plain causal mask, so future tokens are never visible.
        """
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.bool()
        return _prepare_4d_causal_attention_mask(
            attention_mask, input_shape=input_shape,
            inputs_embeds=hidden_states, past_key_values_length=0
        )

    def _run_decoder(self, hidden_states, attention_mask):
        """Run all decoder layers (failing fast on NaNs), then apply the final norm."""
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
            if torch.isnan(hidden_states).any():
                raise ValueError(f"❗ NaN detected after layer {i}")
        return self.norm(hidden_states)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and return ``{"loss": loss_or_None, "logits": logits}``.

        * classification, ``input_ids`` ``(B, C, S)``: multiple choice; logits ``(B, C)``.
        * classification, ``input_ids`` ``(B, S)``: single label; logits ``(B, num_labels)``.
        * otherwise: language modelling; logits ``(B, S, vocab_size)``.

        ``attention_mask`` has the same shape as ``input_ids``. A causal mask is applied in
        every mode; previously a bool mask in the ``(B, S)`` classification path skipped it.

        NOTE(review): the classification heads pool position 0. Under a causal mask that
        token only attends to itself, so the pooled vector cannot see the rest of the
        sequence — confirm whether last-token pooling was intended.
        """
        if self.classification and input_ids.dim() == 3:
            B, C, S = input_ids.shape
            H = self.embed_tokens.embedding_dim
            # Fold choices into the batch so every choice is encoded independently.
            hidden_states = self.embed_tokens(input_ids).view(B * C, S, H)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(B * C, S)
            attention_mask = self._prepare_attention_mask(attention_mask, (B * C, S), hidden_states)
            hidden_states = self._run_decoder(hidden_states, attention_mask)
            cls_rep = hidden_states[:, 0, :]
            logits = self.classifier_mc(cls_rep).view(B, C)   # (B, C)
            loss = None
            if labels is not None:
                loss = nn.CrossEntropyLoss()(logits, labels)
            return {"loss": loss, "logits": logits}

        elif self.classification and input_ids.dim() == 2:
            B, S = input_ids.shape
            hidden_states = self.embed_tokens(input_ids)      # (B, S, H)
            attention_mask = self._prepare_attention_mask(attention_mask, (B, S), hidden_states)
            hidden_states = self._run_decoder(hidden_states, attention_mask)
            cls_rep = hidden_states[:, 0, :]
            logits = self.classifier_single(cls_rep)
            loss = None
            if labels is not None:
                loss = nn.CrossEntropyLoss()(logits, labels)
            return {"loss": loss, "logits": logits}

        else:
            B, S = input_ids.shape
            hidden_states = self.embed_tokens(input_ids)   # (B, S, H)
            attention_mask = self._prepare_attention_mask(attention_mask, (B, S), hidden_states)
            hidden_states = self._run_decoder(hidden_states, attention_mask)

            logits = self.lm_head(hidden_states)  # (B, S, vocab_size)

            loss = None
            if labels is not None:
                # NOTE(review): logits and labels are compared position-by-position with no
                # shift, so callers must pass labels already shifted to "next token" — confirm.
                loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.reshape(-1))

            return {"loss": loss, "logits": logits}


def build_mixtral_random(model_path, load_param_path,  num_labels=4, n_layers=4, device='cpu', use_residual=False, grad=False, classification=True):
    """Build a randomly initialised ``AdaptedMixtral`` without loading Mixtral weights.

    Args:
        model_path: Passed to ``AutoConfig.from_pretrained``; only the config is read.
        load_param_path: Directory containing ``w{1,2,3}/U_out.pt``, ``w{1,2,3}/U_in.pt``
            and ``w{1,2,3}/layer{i}_coeff.pt``. ``AdaptedMixtral`` uses these files only for
            their tensor shapes.
        num_labels: Number of classes for the single-sequence classifier head.
        n_layers: Number of decoder layers. Only coefficient files for layers
            ``0 .. n_layers - 1`` are read.
        device, use_residual, grad, classification: Forwarded to ``AdaptedMixtral``.

    Returns:
        The constructed ``AdaptedMixtral``.
    """
    from types import SimpleNamespace

    config = AutoConfig.from_pretrained(model_path)

    class DummyBaseModel:
        """Minimal object exposing ``config`` and ``model.{embed_tokens, layers}`` for AdaptedMixtral."""

        def __init__(self, config):
            self.config = config
            self.model = SimpleNamespace(
                embed_tokens=nn.Embedding(config.vocab_size, config.hidden_size),
                layers=[None for _ in range(n_layers)],
            )

    base_model = DummyBaseModel(config)

    shared_param_paths = {
        'w1_out': f'{load_param_path}/w1/U_out.pt',
        'w1_in': f'{load_param_path}/w1/U_in.pt',
        'w2_out': f'{load_param_path}/w2/U_out.pt',
        'w2_in': f'{load_param_path}/w2/U_in.pt',
        'w3_out': f'{load_param_path}/w3/U_out.pt',
        'w3_in': f'{load_param_path}/w3/U_in.pt',
    }

    # Only the layers that are actually built need coefficient files. Each listed file
    # is torch.load-ed, so a fixed range(32) read 96 files regardless of n_layers and
    # raised KeyError for n_layers > 32.
    layer_ids = range(n_layers)
    task_param_paths = {
        'w1_core': {i: [f'{load_param_path}/w1/layer{i}_coeff.pt'] for i in layer_ids},
        'w2_core': {i: [f'{load_param_path}/w2/layer{i}_coeff.pt'] for i in layer_ids},
        'w3_core': {i: [f'{load_param_path}/w3/layer{i}_coeff.pt'] for i in layer_ids},
    }

    model = AdaptedMixtral(base_model, shared_param_paths, task_param_paths,  num_labels= num_labels, n_layers=n_layers, 
                           device=device, use_residual=use_residual, grad=grad, classification=classification)
    return model


