import torch
from typing import Optional, Dict, List
import torch.nn as nn
import torch.nn.functional as F
import os
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, AutoConfig

import torch
from modelscope import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

from datasets import load_dataset
from torch.utils.data import Dataset

import copy
import math

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

class DeepseekRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Create the per-feature scale (initialised to ones) and store the
        epsilon added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        # Statistics are computed in float32; the normalised tensor is cast
        # back to the caller's dtype before the learnable scale is applied.
        caller_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(caller_dtype)

class DeepseekRotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    Args:
        dim: size of the rotated dimension; every second index in
            ``range(0, dim, 2)`` contributes one inverse frequency.
        max_position_embeddings: table length built at construction time.
        base: base of the geometric frequency progression.
        device: device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
        # Resetting the marker makes the first `forward` call rebuild the
        # tables on the device and dtype of the tensor it receives.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached`` / ``sin_cached`` buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables truncated to ``seq_len`` rows, cast to ``x.dtype``.

        Args:
            x: tensor used for its device and dtype; laid out as
                [bs, num_attention_heads, seq_len, head_size].
            seq_len: number of positions required. When omitted it is taken
                from ``x.shape[-2]``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to reach `torch.arange(None)` and crash;
            # fall back to the sequence axis of `x` instead.
            seq_len = x.shape[-2]

        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). Input is (batch, num_key_value_heads, seqlen, head_dim);
    output is (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    bsz, kv_heads, seq_len, width = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, width)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, width)

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    width = x.shape[-1]
    split_at = width // 2
    front, back = torch.split(x, [split_at, width - split_at], dim=-1)
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with Rotary Position Embedding tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by `position_ids`.
        sin (`torch.Tensor`): Sine table, indexed by `position_ids`.
        position_ids (`torch.Tensor`):
            Indices used to select rows of `cos` and `sin`, e.g. offset
            positions when decoding with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the selected `cos`/`sin` rows so that they
            broadcast against `q` and `k`. Use 1 when `q`/`k` are laid out as
            [batch_size, heads, seq_len, head_dim] and 2 when they are laid
            out as [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_rows = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_rows = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_rows) + (rotate_half(states) * sin_rows)

    return _rotate(q), _rotate(k)

# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Deepseek
class DeepseekAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (matmul + softmax) attention with rotary position embeddings and
    grouped key/value heads. The following `config` attributes are read:
    `attention_dropout`, `hidden_size`, `num_attention_heads`,
    `num_key_value_heads`, `max_position_embeddings`, `rope_theta`,
    `attention_bias`, `rope_scaling` and `pretraining_tp`.
    """

    def __init__(self, config, layer_idx=None):
        """
        Args:
            config: model configuration (see the class docstring for the fields used).
            layer_idx: index of this layer, forwarded to the KV cache in `forward`.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`,
                or if `config.rope_scaling["type"]` is not recognised.
        """
        # Imported locally so the class does not depend on a file-level import.
        import warnings

        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Previously called an undefined `logger`, which raised NameError here.
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.",
                stacklevel=2,
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling`."""
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            # NOTE(review): the two scaling classes referenced below are not
            # defined in the visible part of this file - confirm they are
            # provided elsewhere before enabling `rope_scaling`.
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View `tensor` as (bsz, seq_len, num_heads, head_dim) and move the head axis to dim 1."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: tensor of size (bsz, q_len, hidden_size).
            attention_mask: optional additive mask of size (bsz, 1, q_len, kv_seq_len).
            position_ids: indices used to select rows of the rotary tables.
            past_key_value: optional cache object exposing `get_usable_length` and `update`.
            output_attentions: when False the returned attention weights are None.
            use_cache: accepted for interface compatibility; not read here.

        Returns:
            `(attn_output, attn_weights, past_key_value)`.

        Raises:
            ValueError: on a cache without `layer_idx`, or on unexpected
                attention-weight, mask or output sizes.
        """
        if "padding_mask" in kwargs:
            # Imported locally; this branch used to raise NameError because
            # `warnings` was never imported.
            import warnings

            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Split the projection weights into `pretraining_tp` slices,
            # project per slice and concatenate the results.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SharedExpertLayer(nn.Module):
    """Linear map whose weight is rebuilt from shared factors and per-core tensors.

    For every core the effective weight is
    ``W_shared_out @ (core @ W_shared_in.T)``; the layer output is the mean over
    cores of ``x @ weight.T``. This fixes the expected shapes:
    ``x.shape[-1] == W_shared_in.shape[0]`` and the output width is
    ``W_shared_out.shape[0]``.

    Args:
        W_shared_out: output-side factor, shape [out_dim, r1].
        W_shared_in: input-side factor, shape [in_dim, r2].
        core_tensors: iterable of core tensors, each of shape [r1, r2]; always trainable.
        use_residual: add a trainable ``nn.Linear(in_dim, out_dim)`` applied to the input.
        grad: whether the two shared factors require gradients.
    """

    def __init__(self, W_shared_out, W_shared_in, core_tensors, use_residual=False, grad=False):
        super().__init__()

        shared_trainable = bool(grad)
        self.W_shared_out = nn.Parameter(W_shared_out.contiguous(), requires_grad=shared_trainable)  # [out_dim, r1]
        self.W_shared_in = nn.Parameter(W_shared_in.contiguous(), requires_grad=shared_trainable)    # [in_dim, r2]

        # Ensure cores are contiguous
        self.cores = nn.ParameterList([
            nn.Parameter(core.contiguous()) for core in core_tensors
        ])

        self.use_residual = use_residual
        if self.use_residual:
            # `forward` multiplies x by weight.T, whose leading dimension is
            # W_shared_in.shape[0]; the residual path consumes the same x, so
            # it needs that input width (shape[1] was used before, which only
            # worked for square factors).
            self.residual = nn.Linear(W_shared_in.shape[0], W_shared_out.shape[0])

    def forward(self, x):
        """Return the mean over cores of ``x @ weight.T`` (plus the residual branch if enabled)."""
        per_core_outputs = []
        for core in self.cores:
            weight = torch.matmul(self.W_shared_out, torch.matmul(core, self.W_shared_in.T))
            x = x.to(weight.dtype)  # match the reconstructed weight's dtype
            per_core_outputs.append(x @ weight.T)

        out = torch.stack(per_core_outputs).mean(0)

        if self.use_residual:
            out = out + self.residual(x)

        return out


class MoEAdapterLayer(nn.Module):
    """Gated MLP whose gate/up/down projections are `SharedExpertLayer` instances.

    `shared_params` supplies the `<proj>_out` / `<proj>_in` factors and
    `task_params` the `<proj>_core` tensors for each of the three projections.
    The activation is looked up in `activation_fn` under `config.hidden_act`.
    """

    def __init__(self, shared_params, task_params, config, use_residual=False, activation_fn=ACT2FN, grad=False):
        super().__init__()

        def _expert(out_factor, in_factor, core_list):
            return SharedExpertLayer(out_factor, in_factor, core_list, use_residual, grad=grad)

        self.gate_proj = _expert(
            shared_params['gate_proj_out'], shared_params['gate_proj_in'], task_params['gate_proj_core']
        )
        self.up_proj = _expert(
            shared_params['up_proj_out'], shared_params['up_proj_in'], task_params['up_proj_core']
        )
        self.down_proj = _expert(
            shared_params['down_proj_out'], shared_params['down_proj_in'], task_params['down_proj_core']
        )

        self.act_fn = activation_fn[config.hidden_act]

    def forward(self, x):
        # down(act(gate(x)) * up(x))
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
    

class AdaptedDeepseekLayer(nn.Module):
    """Decoder layer: self-attention followed by a `MoEAdapterLayer`, each with a residual add.

    With `use_random_init` the attention and both norms are freshly
    constructed from `config`; otherwise they are deep-copied from
    `original_layer`.
    """

    def __init__(self, original_layer, shared_params, task_params, config, layer_idx, grad=False, use_random_init=False):
        super().__init__()

        if not use_random_init:
            self.self_attn = copy.deepcopy(original_layer.self_attn)
            self.input_layernorm = copy.deepcopy(original_layer.input_layernorm)
            self.post_attention_layernorm = copy.deepcopy(original_layer.post_attention_layernorm)
        else:
            self.self_attn = DeepseekAttention(config, layer_idx)
            self.input_layernorm = DeepseekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_attention_layernorm = DeepseekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.moe_layer = MoEAdapterLayer(shared_params, task_params, config, grad=grad)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Return a 1-tuple holding the layer output; extra kwargs go to the attention module."""
        # Attention sub-block (pre-norm) with residual connection.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), attention_mask=attention_mask, **kwargs
        )[0]
        after_attn = hidden_states + attn_out

        # Adapter MLP sub-block (pre-norm) with residual connection.
        mlp_out = self.moe_layer(self.post_attention_layernorm(after_attn))
        return (after_attn + mlp_out,)
    
class AdaptedDeepseek(nn.Module):
    """DeepSeek decoder stack whose layers 1..N-1 use `AdaptedDeepseekLayer`.

    Layer 0 is a deep copy of the base model's first layer. Every later layer
    is an `AdaptedDeepseekLayer` built from factor tensors loaded from disk.
    With `classification=True` the model scores `C` choices per example with
    a single-logit classifier; otherwise it is a causal language model with a
    freshly created `lm_head`.

    Args:
        base_model: model exposing `.config`, `.model.embed_tokens`,
            `.model.layers` and `.model.norm`. Its parameters are moved to CPU.
        shared_param_paths: name -> path of a tensor file shared by all layers.
        task_param_paths: name -> {layer index -> path or list of paths}.
        n_layers: number of layers to keep; defaults to all base layers.
        grad: whether the shared factors are trainable.
        use_random_init: build attention/norm modules from scratch instead of
            copying them from the base model.
        classification: select the multiple-choice forward path.
    """

    def __init__(self, base_model: nn.Module,
                 shared_param_paths: Dict[str, str],
                 task_param_paths: Dict[str, Dict[str, str]],
                 n_layers: Optional[int] = None,
                 grad=False, use_random_init=True,
                 classification=False):
        super().__init__()

        # Keep the base weights on CPU while layers are copied.
        for param in base_model.parameters():
            param.data = param.data.cpu()

        self.config = base_model.config
        self.embed_tokens = base_model.model.embed_tokens

        self.classification = classification

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: torch.load(path, map_location='cpu').float()
            for name, path in shared_param_paths.items()
        }

        for i in range(total_layers):
            original_layer = copy.deepcopy(base_model.model.layers[i])
            if i == 0:
                # The first layer is kept exactly as in the base model.
                self.layers.append(original_layer)
            else:
                # Each entry may be a single path or a list of paths (one per core).
                task_params = {
                    name: [torch.load(p, map_location='cpu').float() for p in (layer_paths[i] if isinstance(layer_paths[i], list) else [layer_paths[i]])]
                    for name, layer_paths in task_param_paths.items()
                }
                layer = AdaptedDeepseekLayer(original_layer, shared_params, task_params, self.config, i, grad=grad,
                                             use_random_init=use_random_init)
                self.layers.append(layer)

        if use_random_init:
            self.norm = DeepseekRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        else:
            self.norm = copy.deepcopy(base_model.model.norm)

        if not classification:
            self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.classifier = nn.Linear(self.config.hidden_size, 1)

        del base_model
        torch.cuda.empty_cache()

    def save_pretrained(self, save_directory):
        """Write `pytorch_model.bin` (state dict) and the config into `save_directory`."""
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        if hasattr(self, "config"):
            if hasattr(self.config, "save_pretrained"):
                self.config.save_pretrained(save_directory)
            else:
                import json
                with open(os.path.join(save_directory, "config.json"), "w") as f:
                    json.dump(self.config, f)
        print(f"save {save_directory}")

    @classmethod
    def from_pretrained(cls, load_directory, shared_param_paths=None, task_param_paths=None, **kwargs):
        """Load a base causal LM from `load_directory` and wrap it in this class.

        NOTE(review): the state dict written by `save_pretrained` is not read
        back here - only the base model is loaded. Confirm whether callers
        restore `pytorch_model.bin` themselves.
        """
        from transformers import AutoConfig, AutoModelForCausalLM
        config = AutoConfig.from_pretrained(load_directory)

        base_model = AutoModelForCausalLM.from_pretrained(load_directory, config=config)

        model = cls(
            base_model,
            shared_param_paths=shared_param_paths,
            task_param_paths=task_param_paths,
            **kwargs
        )

        return model

    def _run_layers(self, hidden_states, attention_mask):
        """Apply every layer in order, failing fast on non-finite activations."""
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
            if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
                raise ValueError(f"NaN detected after layer {idx}")
        return hidden_states

    def _forward_classification(self, input_ids, attention_mask, labels):
        """Multiple-choice path: score each of the C choices with one logit."""
        B, C, S = input_ids.shape
        H = self.embed_tokens.embedding_dim

        # Embed, then fold the choice axis into the batch: [B, C, S, H] -> [B*C, S, H]
        hidden_states = self.embed_tokens(input_ids).view(B * C, S, H)

        if attention_mask is not None:
            # Turn the 0/1 mask into an additive mask of size [B*C, 1, S, S].
            attention_mask = attention_mask.view(B * C, 1, 1, S).to(hidden_states.dtype)
            attention_mask = attention_mask.expand(-1, 1, S, S)
            attention_mask = (1.0 - attention_mask) * -1e9
            if torch.isnan(attention_mask).any() or torch.isinf(attention_mask).any():
                raise ValueError("NaN or Inf detected in attention mask!")

        hidden_states = self.norm(self._run_layers(hidden_states, attention_mask))

        # The first position represents each choice.
        cls_rep = hidden_states[:, 0, :]          # [B*C, H]
        choice_logits = self.classifier(cls_rep)  # [B*C, 1]
        logits = choice_logits.view(B, C)         # [B, C]

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

    def _forward_lm(self, input_ids, attention_mask, labels):
        """Causal language-modelling path with shifted next-token loss."""
        input_shape = input_ids.size()
        if len(input_shape) == 3:
            # Collapse a 3-D batch to [batch, seq_len]; `view` only succeeds
            # when the middle axis has size 1.
            batch_size = input_shape[0]
            seq_len = input_shape[-1]
            input_ids = input_ids.view(batch_size, seq_len)
            if attention_mask is not None:
                attention_mask = attention_mask.view(batch_size, seq_len)

        hidden_states = self.embed_tokens(input_ids)  # [B, S, H]

        # Always hand the layers a causal mask. Previously a missing
        # `attention_mask` meant no mask at all, i.e. every position could
        # attend to future tokens while being trained to predict them.
        if attention_mask is None or attention_mask.dim() == 2:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                input_shape=hidden_states.shape[:2],   # (batch, seq_len)
                inputs_embeds=hidden_states,
                past_key_values_length=0
            )

        hidden_states = self.norm(self._run_layers(hidden_states, attention_mask))  # [B, S, H]

        logits = self.lm_head(hidden_states)  # [B, S, vocab_size]

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {"loss": loss, "logits": logits}

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Dispatch to the classification or language-modelling path.

        Returns:
            dict with keys ``"loss"`` (None when `labels` is None) and ``"logits"``.

        Raises:
            ValueError: if a NaN/Inf appears in the mask or after any layer.
        """
        if self.classification:
            return self._forward_classification(input_ids, attention_mask, labels)
        return self._forward_lm(input_ids, attention_mask, labels)

       
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

class AdaptedDeepseekForSequenceClassification(nn.Module):
    """Adapted DeepSeek decoder stack with a sequence-classification head.

    Layer 0 is a deep copy of the base model's first layer; later layers are
    `AdaptedDeepseekLayer` instances built from factor tensors loaded from
    disk. The hidden state of one position per sequence is projected to
    `num_labels` logits by `classifier_mc`.

    Args:
        base_model: model exposing `.config`, `.model.embed_tokens`,
            `.model.layers` and `.model.norm`. Its parameters are moved to CPU.
        shared_param_paths: name -> path of a tensor file shared by all layers.
        task_param_paths: name -> {layer index -> path or list of paths}.
        n_layers: number of layers to keep; defaults to all base layers.
        grad: whether the shared factors are trainable.
        use_random_init: build attention/norm modules from scratch instead of
            copying them from the base model.
        config: configuration consulted for `_attn_implementation`; falls back
            to `base_model.config` when omitted.
        num_labels: width of the classification head.
    """

    def __init__(self, base_model: nn.Module,
                 shared_param_paths,
                 task_param_paths,
                 n_layers=None,
                 grad=False,
                 use_random_init=False,
                 config=None,
                 num_labels=4):
        super().__init__()

        # Keep the base weights on CPU while layers are copied.
        for param in base_model.parameters():
            param.data = param.data.cpu()

        self.config = base_model.config

        self.num_labels = num_labels
        self.problem_type = getattr(self.config, 'problem_type', None)
        self.embed_tokens = base_model.model.embed_tokens

        # `config` defaults to None; dereferencing it directly raised
        # AttributeError, so fall back to the base model's configuration.
        attn_config = config if config is not None else self.config
        attn_implementation = getattr(attn_config, "_attn_implementation", None)
        self._use_sdpa = attn_implementation == "sdpa"
        self._use_flash_attention_2 = attn_implementation == "flash_attention_2"

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: torch.load(path, map_location='cpu').float()
            for name, path in shared_param_paths.items()
        }

        for i in range(total_layers):
            original_layer = copy.deepcopy(base_model.model.layers[i])
            if i == 0:
                # The first layer is kept exactly as in the base model.
                self.layers.append(original_layer)
            else:
                # Each entry may be a single path or a list of paths (one per core).
                task_params = {
                    name: [torch.load(p, map_location='cpu').float() for p in (layer_paths[i] if isinstance(layer_paths[i], list) else [layer_paths[i]])]
                    for name, layer_paths in task_param_paths.items()
                }
                layer = AdaptedDeepseekLayer(original_layer, shared_params, task_params, self.config, i, grad=grad,
                                             use_random_init=use_random_init)
                self.layers.append(layer)

        if use_random_init:
            self.norm = DeepseekRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        else:
            self.norm = copy.deepcopy(base_model.model.norm)

        print("=================== num label 111111  {}".format(self.num_labels))
        self.classifier_mc = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        # Not used by `forward`; kept so existing state dicts still load.
        self.classifier = nn.Linear(self.config.hidden_size, 1, bias=False)
        del base_model
        torch.cuda.empty_cache()

    def forward(self,
                input_ids,
                attention_mask=None,
                labels=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=True):
        """Classify each sequence in the batch.

        Exactly one of `input_ids` / `inputs_embeds` must be given.
        `position_ids`, `past_key_values`, `use_cache`, `output_attentions`
        and `output_hidden_states` are accepted for interface compatibility
        and are not read.

        Returns:
            `SequenceClassifierOutputWithPast` when `return_dict` is truthy,
            otherwise `(loss, logits)` or `(logits,)` when there is no loss.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds`
                are supplied.
        """
        # 1. embed
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Embed once and reuse the result for both the mask helper and the layers.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        past_key_values_length = 0
        if self._use_flash_attention_2:
            # Keep the 2-D mask only when it actually masks something.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # 2. transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
        hidden_states = self.norm(hidden_states)  # [B, S, H]

        # 3. pick one position per sequence to represent it
        if self.config.pad_token_id is None:
            # NOTE(review): under a causal mask position 0 only attends to
            # itself - confirm that pooling the first token is intended.
            pooled_hidden = hidden_states[:, 0, :]  # fallback to CLS
        elif input_ids is None:
            # Padding cannot be located without token ids (the comparison
            # below used to crash on None); use the last position.
            pooled_hidden = hidden_states[:, -1, :]
        else:
            seq_mask = (input_ids == self.config.pad_token_id)

            # Index of the token just before the first pad; rows without any
            # pad (or starting with one) fall back to the last position.
            sequence_lengths = seq_mask.int().argmax(dim=-1) - 1
            sequence_lengths[sequence_lengths < 0] = hidden_states.size(1) - 1
            batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
            pooled_hidden = hidden_states[batch_index, sequence_lengths.to(hidden_states.device)]

        logits = self.classifier_mc(pooled_hidden)  # [B, num_labels]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once, from the head width and label dtype.
            if self.problem_type is None:
                if self.num_labels == 1:
                    self.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype in [torch.long, torch.int]):
                    self.problem_type = "single_label_classification"
                else:
                    self.problem_type = "multi_label_classification"

            if self.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            elif self.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


def build_deepseek(model_path, load_param_path, n_layers=4, num_experts=2, classification=True, grad=False, use_random_init=False):
    """Load a DeepSeek checkpoint and wrap it in `AdaptedDeepseek`.

    Args:
        model_path: directory or hub id passed to `from_pretrained`.
        load_param_path: root directory holding `<proj>/U_out.pt`,
            `<proj>/U_in.pt` and `<proj>/layer<i>_coeff.pt` for the
            `down_proj`, `gate_proj` and `up_proj` projections.
        n_layers: number of decoder layers to keep.
        num_experts: unused; kept for backward compatibility.
        classification: forwarded to `AdaptedDeepseek`.
        grad: whether the shared factors are trainable.
        use_random_init: forwarded to `AdaptedDeepseek`.

    Returns:
        The constructed `AdaptedDeepseek` model.
    """
    import sys

    # Make the local `deepseek` package importable without stacking a
    # duplicate entry on every call.
    if "ckpt_llm" not in sys.path:
        sys.path.append("ckpt_llm")

    from deepseek.modeling_deepseek import DeepseekForCausalLM
    from deepseek.configuration_deepseek import DeepseekConfig  # noqa: F401  (import kept from the original)

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    base_model = DeepseekForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch.float32,
        trust_remote_code=True
    )

    projections = ('down_proj', 'gate_proj', 'up_proj')

    # Factors shared by every adapted layer: one out/in pair per projection.
    shared_param_paths = {
        f'{proj}_{side}': f'{load_param_path}/{proj}/U_{side}.pt'
        for proj in projections
        for side in ('out', 'in')
    }

    # Per-layer coefficient files; paths are generated for layers 1..26 only,
    # so `n_layers` must not exceed 27.
    task_param_paths = {
        f'{proj}_core': {i: [f'{load_param_path}/{proj}/layer{i}_coeff.pt'] for i in range(1, 27)}
        for proj in projections
    }

    model = AdaptedDeepseek(base_model, shared_param_paths, task_param_paths, n_layers=n_layers,
                            grad=grad, use_random_init=use_random_init, classification=classification)

    return model


def build_deepseek_classification(model_path, load_param_path, num_labels, n_layers=4, num_experts=2, classification=True, grad=False, use_random_init=False):
    """Load a DeepSeek checkpoint and wrap it in `AdaptedDeepseekForSequenceClassification`.

    Args:
        model_path: directory or hub id passed to `from_pretrained`.
        load_param_path: root directory holding `<proj>/U_out.pt`,
            `<proj>/U_in.pt` and `<proj>/layer<i>_coeff.pt` for the
            `down_proj`, `gate_proj` and `up_proj` projections.
        num_labels: width of the classification head.
        n_layers: number of decoder layers to keep.
        num_experts: unused; kept for backward compatibility.
        classification: unused; kept for backward compatibility.
        grad: whether the shared factors are trainable.
        use_random_init: forwarded to the model constructor.

    Returns:
        The constructed `AdaptedDeepseekForSequenceClassification` model.
    """
    import sys

    # Make the `deepseek` package importable without stacking a duplicate
    # entry on every call.
    if "/ckpt_llm" not in sys.path:
        sys.path.append("/ckpt_llm")

    from deepseek.modeling_deepseek import DeepseekForCausalLM
    from deepseek.configuration_deepseek import DeepseekConfig  # noqa: F401  (import kept from the original)

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    base_model = DeepseekForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch.float32,
        trust_remote_code=True
    )

    projections = ('down_proj', 'gate_proj', 'up_proj')

    # Factors shared by every adapted layer: one out/in pair per projection.
    shared_param_paths = {
        f'{proj}_{side}': f'{load_param_path}/{proj}/U_{side}.pt'
        for proj in projections
        for side in ('out', 'in')
    }

    # Per-layer coefficient files; paths are generated for layers 1..26 only,
    # so `n_layers` must not exceed 27.
    task_param_paths = {
        f'{proj}_core': {i: [f'{load_param_path}/{proj}/layer{i}_coeff.pt'] for i in range(1, 27)}
        for proj in projections
    }

    model = AdaptedDeepseekForSequenceClassification(
        base_model, shared_param_paths, task_param_paths, n_layers=n_layers,
        grad=grad, use_random_init=use_random_init, config=config, num_labels=num_labels,
    )

    return model
