import copy
import math
import os
from typing import Optional, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.utils import logging

import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

logger = logging.get_logger(__name__)

# Based on transformers.models.llama.modeling_llama.LlamaRMSNorm (Llama -> Qwen2)
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer normalization (equivalent to T5LayerNorm).

    Normalizes over the last dimension in float32, casts the result back to the
    input dtype, then multiplies by a learned per-feature weight initialized to
    ones. There is no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        # The weight multiply happens after the cast back, so the output dtype follows
        # normal type promotion between the weight and the input dtype.
        return self.weight * (as_fp32 * inv_rms).to(original_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"


# Based on transformers.models.llama.modeling_llama.LlamaRotaryEmbedding (Llama -> Qwen2)
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin table generator.

    ``forward(x, position_ids)`` returns ``(cos, sin)``, each of shape
    ``[position_ids.shape[0], position_ids.shape[1], 2 * len(inv_freq)]`` and cast to
    ``x.dtype``. ``x`` itself is only used for its device and dtype.

    Two construction modes exist:
      * ``config`` given (preferred): the RoPE type is read from ``config.rope_scaling``
        (falling back to ``"default"``).
      * ``config=None`` (legacy): the keyword arguments are forwarded to the RoPE init
        function as ``rope_kwargs``.
    The init function is looked up in ``ROPE_INIT_FUNCTIONS`` by RoPE type.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config=None,
    ):
        super().__init__()
        # TODO (joao): remove the legacy (config=None) branch, only used for BC
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            seq_limit = max_position_embeddings
        else:
            self.rope_kwargs = {}
            scaling = config.rope_scaling
            # BC: "rope_type" was originally "type"
            self.rope_type = "default" if scaling is None else scaling.get("rope_type", scaling.get("type"))
            seq_limit = config.max_position_embeddings
        self.max_seq_len_cached = seq_limit
        self.original_max_seq_len = seq_limit

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so the dynamic variant can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute or restore ``inv_freq`` for "dynamic" RoPE types.

        Frequencies are recomputed when the requested length grows past the cached
        length. They are reset to the originals once a sequence shorter than the
        original maximum arrives after such growth, to avoid losing precision on short
        inputs.
        """
        needed_len = torch.max(position_ids) + 1
        if needed_len > self.max_seq_len_cached:  # growth
            new_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=needed_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", new_inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = needed_len

        grown_past_original = self.max_seq_len_cached > self.original_max_seq_len
        if needed_len < self.original_max_seq_len and grown_past_original:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        pos_row = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            angles = torch.cat((angles, angles), dim=-1)
            cos, sin = angles.cos(), angles.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scale, equivalent to scaling attention
        scale = self.attention_scaling
        return (cos * scale).to(dtype=x.dtype), (sin * scale).to(dtype=x.dtype)


# Based on transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    Splits the last dimension at ``x.shape[-1] // 2`` into ``front`` and ``back``
    and returns ``cat(-back, front)`` along that dimension.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)


# Based on transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Each tensor ``t`` is transformed as ``t * cos + rotate_half(t) * sin`` after
    ``cos``/``sin`` receive a new axis at ``unsqueeze_dim``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into ``cos`` and ``sin`` so
            they broadcast against ``q`` and ``k``. With ``cos``/``sin`` of shape
            [batch_size, seq_len, head_dim], use 1 for q/k shaped
            [batch_size, heads, seq_len, head_dim] and 2 for q/k shaped
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_rotated = _rotate(q)
    k_rotated = _rotate(k)
    return q_rotated, k_rotated


# Based on transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned unchanged (same object) when ``n_rep == 1``.
    """
    # Unpack first so a non-4D input is rejected even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


class Qwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need', with grouped-query key/value heads and rotary
    position embeddings, computed eagerly (explicit matmul + softmax).

    No sliding window is applied here. Any restriction on which keys a query may attend to (causal, padding,
    windowing) must come from the additive `attention_mask` passed to `forward`.
    """

    def __init__(self, config, layer_idx = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Informational flag only: forward() below never reads it, so causality must be
        # supplied through `attention_mask`.
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Only used when forward() is not given precomputed `position_embeddings`.
        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run eager self-attention.

        Args:
            hidden_states: `[bsz, q_len, hidden_size]` input (unpacked as 3-D below).
            attention_mask: optional additive mask broadcastable to `[bsz, num_heads, q_len, kv_len]`.
                Its last dim is sliced to the key length before being added to the scores.
            position_ids: used only when `position_embeddings` is None, to compute RoPE internally.
            past_key_value: optional cache object. Its `update(key, value, layer_idx, cache_kwargs)`
                must return the full key/value states (assumed to be a transformers `Cache`; confirm).
            output_attentions: if False, the returned attention weights are None.
            use_cache: accepted for interface compatibility; not read here.
            cache_position: forwarded to the cache in `cache_kwargs`.
            position_embeddings: precomputed `(cos, sin)` pair.

        Returns:
            `(attn_output [bsz, q_len, hidden_size], attn_weights or None, past_key_value)`.

        Raises:
            ValueError: if the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # -> [bsz, heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SharedExpertLayer(nn.Module):
    """Linear map whose weight is rebuilt from shared factors and per-expert cores.

    For every core ``C`` the effective weight is ``W_shared_out @ C @ W_shared_in.T``.
    That weight has shape ``[out_dim, in_dim]``, as required by the matmuls in ``forward``.
    Each weight is applied to the input and the results are averaged over all cores.

    Args:
        W_shared_out: left factor, ``[out_dim, r1]``.
        W_shared_in: right factor, ``[in_dim, r2]`` (transposed in ``forward``).
        core_tensors: non-empty sequence of cores, each ``[r1, r2]``. Cores are always trainable.
        use_residual: if True, add a trainable ``nn.Linear(in_dim, out_dim)`` branch.
        grad: whether the two shared factors are trainable.

    Raises:
        ValueError: if ``core_tensors`` is empty.
    """

    def __init__(self, W_shared_out, W_shared_in, core_tensors, use_residual=False, grad=False):
        super().__init__()

        core_tensors = list(core_tensors)
        if not core_tensors:
            raise ValueError("SharedExpertLayer requires at least one core tensor")

        # NOTE(review): nn.Parameter wraps the given storage without copying, and
        # .contiguous() is a no-op on contiguous input. Instances built from the same
        # tensors therefore share storage; confirm this is intended when grad=True.
        self.W_shared_out = nn.Parameter(W_shared_out.contiguous(), requires_grad=grad)  # [out_dim, r1]
        self.W_shared_in = nn.Parameter(W_shared_in.contiguous(), requires_grad=grad)    # [in_dim, r2]
        self.cores = nn.ParameterList([
            nn.Parameter(core.contiguous()) for core in core_tensors
        ])

        self.use_residual = use_residual
        if self.use_residual:
            # in_dim is dim 0 of W_shared_in; dim 1 is the rank r2.
            self.residual = nn.Linear(W_shared_in.shape[0], W_shared_out.shape[0])

    def forward(self, x):
        """Average ``x @ W_i.T`` over all cores (plus the optional residual branch)."""
        outputs = []
        for core in self.cores:
            weight = self.W_shared_out @ (core @ self.W_shared_in.T)  # [out_dim, in_dim]
            # Cast the input to the factors' dtype so the matmul dtypes agree.
            x = x.to(weight.dtype)
            outputs.append(x @ weight.T)

        out = torch.stack(outputs).mean(0)

        if self.use_residual:
            out = out + self.residual(x)

        return out


class MoEAdapterLayer(nn.Module):
    """Gated MLP built from three ``SharedExpertLayer`` projections.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``, where ``act`` is
    ``activation_fn[config.hidden_act]``. ``shared_params`` supplies the
    ``*_out``/``*_in`` factors and ``task_params`` the ``*_core`` lists for each
    projection. ``use_residual`` and ``grad`` are forwarded to every projection.
    """

    def __init__(self, shared_params, task_params, config, use_residual=False, activation_fn=ACT2FN, grad=False):
        super().__init__()

        self.gate_proj = SharedExpertLayer(
            shared_params['gate_proj_out'],
            shared_params['gate_proj_in'],
            task_params['gate_proj_core'],
            use_residual,
            grad=grad,
        )
        self.up_proj = SharedExpertLayer(
            shared_params['up_proj_out'],
            shared_params['up_proj_in'],
            task_params['up_proj_core'],
            use_residual,
            grad=grad,
        )
        self.down_proj = SharedExpertLayer(
            shared_params['down_proj_out'],
            shared_params['down_proj_in'],
            task_params['down_proj_core'],
            use_residual,
            grad=grad,
        )

        self.act_fn = activation_fn[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


class AdaptedQwenLayer(nn.Module):
    """Pre-norm decoder layer whose MLP is replaced by a ``MoEAdapterLayer``.

    ``forward`` computes ``h = x + self_attn(input_layernorm(x))`` and then
    ``h + moe_layer(post_attention_layernorm(h))``, returning a 1-tuple.

    With ``use_random_init`` the attention and both norms are freshly constructed
    from ``config``. Otherwise they are deep copies of the corresponding sub-modules
    of ``original_layer``, so the original is left untouched.
    """

    def __init__(self, original_layer, shared_params, task_params, config, layer_idx, grad=False, use_random_init=False):
        super().__init__()

        if not use_random_init:
            self.self_attn = copy.deepcopy(original_layer.self_attn)
            self.input_layernorm = copy.deepcopy(original_layer.input_layernorm)
            self.post_attention_layernorm = copy.deepcopy(original_layer.post_attention_layernorm)
        else:
            self.self_attn = Qwen2Attention(config, layer_idx)
            self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.moe_layer = MoEAdapterLayer(shared_params, task_params, config, grad=grad)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Extra kwargs (e.g. position_ids, position_embeddings) go straight to self_attn.
        attn_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(attn_input, attention_mask=attention_mask, **kwargs)[0]
        after_attn = hidden_states + attn_out

        mlp_out = self.moe_layer(self.post_attention_layernorm(after_attn))
        return (after_attn + mlp_out,)
    
class AdaptedQwen(nn.Module):
    """Multiple-choice classifier built from a Qwen2 causal LM with MoE-adapted MLPs.

    The token embedding and rotary embedding modules are taken from ``base_model``
    (shared, not copied). The first ``n_layers`` decoder layers are rebuilt as
    ``AdaptedQwenLayer`` using shared factors and per-layer cores loaded from disk.
    A linear head scores the first-position hidden state of every choice.
    """

    def __init__(self, base_model: nn.Module,
                 shared_param_paths: Dict[str, str],
                 task_param_paths: Dict[str, Dict[str, str]],
                 n_layers: Optional[int] = None,
                 grad=False, use_random_init=False):
        """
        Args:
            base_model: HF causal LM exposing ``.config`` and ``.model.{embed_tokens, rotary_emb, layers, norm}``.
                Its parameters are moved to CPU in place (this mutates the caller's model).
            shared_param_paths: name -> path of a tensor file, loaded once and shared by every layer.
            task_param_paths: name -> {layer index -> path or list of paths}; each file is one core.
            n_layers: number of layers to build; ``None`` builds all of the base model's layers.
            grad: whether the shared factors are trainable.
            use_random_init: build fresh attention/norm modules instead of copying the base model's.
        """
        super().__init__()

        for param in base_model.parameters():
            param.data = param.data.cpu()

        self.config = base_model.config
        self.embed_tokens = base_model.model.embed_tokens
        self.rotary_emb = base_model.model.rotary_emb

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: torch.load(path, map_location='cpu').float()
            for name, path in shared_param_paths.items()
        }

        for i in range(total_layers):
            task_params = {}
            for name, layer_paths in task_param_paths.items():
                paths = layer_paths[i]
                if not isinstance(paths, list):
                    paths = [paths]
                task_params[name] = [torch.load(p, map_location='cpu').float() for p in paths]

            # AdaptedQwenLayer deep-copies the sub-modules it keeps, so the base layer is
            # passed directly rather than deep-copying the whole layer (MLP included) first.
            layer = AdaptedQwenLayer(base_model.model.layers[i], shared_params, task_params, self.config, i,
                                     grad=grad, use_random_init=use_random_init)
            self.layers.append(layer)

        if use_random_init:
            self.norm = Qwen2RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        else:
            self.norm = copy.deepcopy(base_model.model.norm)
        self.classifier = nn.Linear(self.config.hidden_size, 1)

        # Only drops this function's reference; the caller still holds base_model.
        del base_model
        torch.cuda.empty_cache()

    def save(self, path):
        """Save this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` previously written by ``save``."""
        self.load_state_dict(torch.load(path))

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None):
        """Score every choice and optionally compute the cross-entropy loss.

        Args:
            input_ids: ``[B, C, S]`` token ids (B examples, C choices, S tokens).
            attention_mask: optional ``[B, C, S]`` mask, 1 = attend, 0 = ignore. It is turned into an
                additive bias. NOTE(review): no causal term is added, so tokens attend to every
                unmasked position in their sequence; confirm this is intended.
            labels: optional targets accepted by ``CrossEntropyLoss`` for ``[B, C]`` logits
                (typically ``[B]`` choice indices).
            position_ids: optional, viewed as ``[B*C, S]``; defaults to ``0..S-1`` for every sequence.

        Returns:
            ``{"logits": [B, C]}``, plus ``"loss"`` when ``labels`` is given.

        Raises:
            ValueError: on NaN/Inf in the mask, hidden states, or logits, or NaN in labels.
        """
        B, C, S = input_ids.shape
        H = self.embed_tokens.embedding_dim

        if position_ids is None:
            position_ids = torch.arange(S, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(B * C, S)
        else:
            position_ids = position_ids.view(B * C, S)

        # Embed, compute RoPE tables, then flatten choices into the batch: [B*C, S, H].
        hidden_states = self.embed_tokens(input_ids)  # [B, C, S, H]
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        hidden_states = hidden_states.view(B * C, S, H)

        if attention_mask is not None:
            attention_mask = attention_mask.view(B * C, 1, 1, S).to(hidden_states.dtype)
            attention_mask = attention_mask.expand(-1, 1, S, S)
            # Additive bias: 0 where mask == 1, the dtype's minimum where mask == 0. A fixed
            # -1e9 overflows to -inf in float16 and would trip the check below.
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
            if torch.isnan(attention_mask).any() or torch.isinf(attention_mask).any():
                raise ValueError("NaN or Inf detected in attention mask!")

        for idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings
            )[0]
            if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
                raise ValueError(f"NaN or Inf detected after layer {idx}")

        hidden_states = self.norm(hidden_states)

        # Representation of the first position of every sequence.
        # NOTE(review): with left padding this position is a pad token; confirm padding side.
        cls_rep = hidden_states[:, 0, :]  # [B*C, H]

        choice_logits = self.classifier(cls_rep)  # [B*C, 1]
        logits = choice_logits.view(B, C)         # [B, C]

        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("NaN or Inf in logits before loss computation!")

        if labels is None:
            return {"logits": logits}

        if torch.isnan(labels).any():
            raise ValueError("NaN in labels!")
        labels = labels.to(logits.device)
        loss = torch.nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}




def build_qwen(model_path, load_param_path, n_layers=4, num_experts=2, classification=True, grad=False,
               use_random_init=False, cache_dir="ckpt"):
    """Load a causal LM on CPU in float32 and wrap it in an ``AdaptedQwen`` classifier.

    Args:
        model_path: name or path passed to ``AutoModelForCausalLM.from_pretrained``.
        load_param_path: directory with ``{down,gate,up}_proj/U_out.pt``, ``U_in.pt`` and
            ``layer{i}_coeff.pt`` files.
        n_layers: number of decoder layers to build; ``None`` builds all of the base model's layers.
        num_experts: currently unused; kept for signature compatibility.
        classification: currently unused; kept for signature compatibility.
        grad: forwarded to ``AdaptedQwen`` (trainable shared factors).
        use_random_init: forwarded to ``AdaptedQwen``.
        cache_dir: Hugging Face download cache directory (default ``"ckpt"``, as before).

    Returns:
        The constructed ``AdaptedQwen`` module.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="cpu", trust_remote_code=True, torch_dtype=torch.float32,
        cache_dir=cache_dir
    )

    # Shared factors for each MLP projection (AdaptedQwen loads them and casts them to float32).
    shared_param_paths = {
        'down_proj_out': f'{load_param_path}/down_proj/U_out.pt',
        'down_proj_in': f'{load_param_path}/down_proj/U_in.pt',
        'gate_proj_out': f'{load_param_path}/gate_proj/U_out.pt',
        'gate_proj_in': f'{load_param_path}/gate_proj/U_in.pt',
        'up_proj_out': f'{load_param_path}/up_proj/U_out.pt',
        'up_proj_in': f'{load_param_path}/up_proj/U_in.pt',
    }

    # Per-layer core paths for exactly the layers AdaptedQwen will build. A fixed layer
    # range raises KeyError as soon as more layers are requested than it covers.
    num_layers = n_layers if n_layers is not None else len(base_model.model.layers)
    task_param_paths = {
        f'{proj}_core': {i: [f'{load_param_path}/{proj}/layer{i}_coeff.pt'] for i in range(num_layers)}
        for proj in ('down_proj', 'gate_proj', 'up_proj')
    }

    model = AdaptedQwen(base_model, shared_param_paths, task_param_paths, n_layers=n_layers,
                        grad=grad, use_random_init=use_random_init)

    return model
