import torch
import copy
import os
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.activations import ACT2FN
import math
from typing import List, Optional, Tuple, Union
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoConfig
import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

def randomly_initialize_module(module):
    """Re-draw every trainable parameter of ``module`` in place.

    Parameters whose ``requires_grad`` is False are left untouched. Tensors
    with two or more dimensions are filled with ``kaiming_uniform_``; all
    other tensors are filled with ``uniform_`` using its default range.

    Args:
        module: The ``nn.Module`` whose parameters (including those of its
            sub-modules) are re-initialised.
    """
    trainable = (p for _, p in module.named_parameters(recurse=True) if p.requires_grad)
    for tensor in trainable:
        is_matrix_like = tensor.data.dim() >= 2
        fill_ = torch.nn.init.kaiming_uniform_ if is_matrix_like else torch.nn.init.uniform_
        fill_(tensor)

def build_causal_attention_mask(attention_mask):
    """Turn a 2-D padding mask into an additive causal attention mask.

    Args:
        attention_mask: Tensor of shape ``[batch, seq_len]`` holding 1 for
            tokens that may be attended to and 0 for padding.

    Returns:
        Tensor of shape ``[batch, 1, seq_len, seq_len]`` that is 0 where query
        position ``i`` may attend to key position ``j`` (``j <= i`` and key
        ``j`` is not padding) and ``-1e9`` everywhere else.
    """
    # Unpacking (rather than .size(1)) keeps the implicit "must be 2-D" check.
    _, seq_len = attention_mask.size()

    lower_triangle = torch.ones(
        (seq_len, seq_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    ).tril()

    # [1, 1, seq_len, seq_len] * [batch, 1, 1, seq_len] -> [batch, 1, seq_len, seq_len]
    visible = lower_triangle[None, None, :, :] * attention_mask[:, None, None, :]

    # 1 -> 0 (keep), 0 -> -1e9 (suppressed once added to the attention scores).
    return (1.0 - visible) * -1e9

class DeepseekRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Length of the last axis; size of the learned ``weight``.
            eps: Constant added to the mean square before taking ``rsqrt``.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` in float32, cast back, then apply ``weight``."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

class DeepseekRotaryEmbedding(nn.Module):
    """Provider of cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the rotary dimension; the tables have ``dim`` columns.
        max_position_embeddings: Number of positions pre-computed at construction.
        base: Base of the geometric frequency progression.
        device: Device on which ``inv_freq`` is created (``None`` keeps the default).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
        # Resetting the cached length to None makes the first forward() call
        # rebuild the tables on the input's device and dtype.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached`` / ``sin_cached`` buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the first ``seq_len`` rows of the cos and sin tables, cast to ``x.dtype``.

        Args:
            x: Tensor laid out as ``[bs, num_attention_heads, seq_len, head_size]``;
                only its device, dtype and (when ``seq_len`` is omitted) its
                second-to-last dimension are used.
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``.

        Returns:
            ``(cos, sin)``, each of shape ``[seq_len, dim]``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default previously crashed (`None > int` or
            # `torch.arange(None)`); fall back to the sequence axis of `x`.
            seq_len = x.shape[-2]

        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
        hidden_states = expanded.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states

def rotate_half(x):
    """Split the last axis at its midpoint and return ``cat(-second_part, first_part)``."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine table of the rotary embedding.
        sin (`torch.Tensor`): The sine table of the rotary embedding.
        position_ids (`torch.Tensor`):
            Index used to select rows of `cos` and `sin` (`cos[position_ids]`), e.g. offset positions when a
            KV-cache is in use. Note that indexing with `None` does not gather anything; it only adds a leading
            axis to the tables.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the selected `cos` / `sin` so they broadcast against `q` and `k`. With
            `cos[position_ids]` of shape [batch_size, seq_len, head_dim], use 1 when `q` and `k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query tensor followed by the rotated key tensor.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Same formula for both tensors: x * cos + rotate_half(x) * sin.
    q_embed, k_embed = ((t * cos_sel) + (rotate_half(t) * sin_sel) for t in (q, k))
    return q_embed, k_embed

# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Deepseek
class DeepseekAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Eager (matmul + softmax) attention with rotary position embeddings and
    grouped key/value heads.

    The following attributes are read from ``config``: ``attention_dropout``,
    ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
    ``max_position_embeddings``, ``rope_theta``, ``attention_bias``,
    ``rope_scaling`` and ``pretraining_tp``.

    Args:
        config: Model configuration object exposing the attributes above.
        layer_idx: Index of this layer; required when a key/value cache is
            passed to ``forward``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by the number of heads,
            or if ``rope_scaling["type"]`` is not ``"linear"`` / ``"dynamic"``.
    """

    def __init__(self, config, layer_idx = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # The file defines no `logger`, so the previous `logger.warning_once`
            # call raised NameError on this path. `warnings` is imported locally
            # because it is not among the file-level imports.
            import warnings

            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``."""
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            # NOTE(review): DeepseekLinearScalingRotaryEmbedding and
            # DeepseekDynamicNTKScalingRotaryEmbedding are not defined in the
            # visible part of this file; confirm they exist before setting
            # `rope_scaling`, otherwise these branches raise NameError.
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to ``[bsz, num_heads, seq_len, head_dim]`` (contiguous)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``[batch, q_len, hidden_size]``.
            attention_mask: Optional additive mask of shape
                ``[batch, 1, q_len, kv_seq_len]``. When ``None`` no masking at
                all is applied (this method adds no causal mask of its own).
            position_ids: Index used to select rows of the rotary tables
                (passed to ``apply_rotary_pos_emb``).
            past_key_value: Optional cache object; its ``get_usable_length`` and
                ``update`` methods are called with ``self.layer_idx``.
            output_attentions: When False the returned attention weights are ``None``.
            use_cache: Accepted for interface compatibility; not read here.
            **kwargs: Only ``padding_mask`` is inspected, to emit a deprecation warning.

        Returns:
            ``(attn_output, attn_weights, past_key_value)`` where ``attn_output``
            has shape ``[batch, q_len, hidden_size]``.

        Raises:
            ValueError: If a cache is passed but ``layer_idx`` is ``None``, or if
                the attention weights, mask or output have an unexpected shape.
        """
        if "padding_mask" in kwargs:
            # `warnings` is not imported at file level; importing it here keeps
            # this deprecation path from raising NameError.
            import warnings

            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Project slice by slice: split the projection weights along their
            # output dimension and concatenate the partial results.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # [batch, q_len, heads * head_dim] -> [batch, heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # [batch, heads, q_len, head_dim] -> [batch, q_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SharedExpertLayer(nn.Module):
    """Linear map whose weight is rebuilt from two shared factors and per-expert cores.

    For every core ``G`` the effective weight is ``W_shared_out @ G @ W_shared_in.T``
    and the layer output is the mean over all cores of ``x @ weight.T``.

    Shapes required by ``forward``: ``W_shared_out`` is ``[out_dim, r1]``, each
    core is ``[r1, r2]`` and ``W_shared_in`` is ``[in_dim, r2]``.

    Args:
        W_shared_out: Output-side factor.
        W_shared_in: Input-side factor.
        core_tensors: Iterable of core tensors; each becomes a trainable parameter.
        use_residual: When True, a bias-carrying ``nn.Linear(in_dim, out_dim)``
            of the input is added to the output.
        grad: Whether the two shared factors require gradients.
    """

    def __init__(self, W_shared_out, W_shared_in, core_tensors, use_residual=False, grad=False):
        super(SharedExpertLayer, self).__init__()

        # `.contiguous()` returns its input unchanged when it is already
        # contiguous, so these parameters may alias the tensors passed in.
        shared_trainable = bool(grad)
        self.W_shared_out = nn.Parameter(W_shared_out.contiguous(), requires_grad=shared_trainable)  # [out_dim, r1]
        self.W_shared_in = nn.Parameter(W_shared_in.contiguous(), requires_grad=shared_trainable)    # [in_dim, r2]
        # Ensure cores are contiguous
        self.cores = nn.ParameterList([
            nn.Parameter(core.contiguous()) for core in core_tensors
        ])

        self.use_residual = use_residual
        if self.use_residual:
            # forward() multiplies x by (W_out @ core @ W_in.T).T, so the input
            # feature size is W_shared_in.shape[0]; using shape[1] (the rank)
            # made the residual branch fail for non-square factors.
            self.residual = nn.Linear(W_shared_in.shape[0], W_shared_out.shape[0])

    def forward(self, x):
        """Apply every expert weight to ``x`` and average the results.

        Args:
            x: Tensor whose last dimension equals ``W_shared_in.shape[0]``.

        Returns:
            Tensor with last dimension ``W_shared_out.shape[0]``, in the dtype
            of the reconstructed weight.
        """
        outputs = []
        for core in self.cores:
            # [out_dim, r1] @ ([r1, r2] @ [r2, in_dim]) -> [out_dim, in_dim]
            weight = torch.matmul(self.W_shared_out, torch.matmul(core, self.W_shared_in.T))
            # The input is cast to the weight dtype, not the other way round.
            x = x.to(weight.dtype)
            outputs.append(x @ weight.T)

        out = torch.stack(outputs).mean(0)

        if self.use_residual:
            out = out + self.residual(x)

        return out


class MoEAdapterLayer(nn.Module):
    """Gated feed-forward block built from three ``SharedExpertLayer`` projections.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        shared_params: Mapping with the ``*_proj_out`` / ``*_proj_in`` shared factors.
        task_params: Mapping with the ``*_proj_core`` core-tensor collections.
        config: Object whose ``hidden_act`` selects the activation.
        use_residual: Forwarded to every ``SharedExpertLayer``.
        activation_fn: Mapping from activation name to callable.
        grad: Forwarded to every ``SharedExpertLayer``.
    """

    def __init__(self, shared_params, task_params, config, use_residual=False, activation_fn=ACT2FN, grad=False):
        super(MoEAdapterLayer, self).__init__()

        # (attribute name, out-factor key, in-factor key, core key), in registration order.
        projection_specs = (
            ("gate_proj", 'gate_proj_out', 'gate_proj_in', 'gate_proj_core'),
            ("up_proj", 'up_proj_out', 'up_proj_in', 'up_proj_core'),
            ("down_proj", 'down_proj_out', 'down_proj_in', 'down_proj_core'),
        )
        for attr_name, out_key, in_key, core_key in projection_specs:
            projection = SharedExpertLayer(
                shared_params[out_key],
                shared_params[in_key],
                task_params[core_key],
                use_residual,
                grad=grad,
            )
            setattr(self, attr_name, projection)

        self.act_fn = activation_fn[config.hidden_act]

    def forward(self, x):
        """Return the gated projection of ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
    

class AdaptedDeepseekLayer(nn.Module):
    """Pre-norm decoder layer: self-attention followed by a ``MoEAdapterLayer``.

    Both sub-blocks are wrapped in a residual connection and preceded by a
    ``DeepseekRMSNorm``.

    Args:
        original_layer: Accepted for interface compatibility; it is not read
            here (all sub-modules are constructed fresh).
        shared_params: Shared factors forwarded to ``MoEAdapterLayer``.
        task_params: Per-layer core tensors forwarded to ``MoEAdapterLayer``.
        config: Model configuration (``hidden_size``, ``rms_norm_eps`` plus
            whatever the attention and adapter modules read).
        layer_idx: Layer index forwarded to ``DeepseekAttention``.
        grad: Forwarded to ``MoEAdapterLayer``.
    """

    def __init__(self, original_layer, shared_params, task_params, config, layer_idx, grad=False):
        super().__init__()
        width = config.hidden_size

        self.self_attn = DeepseekAttention(config, layer_idx)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = DeepseekRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = DeepseekRMSNorm(width, eps=config.rms_norm_eps)

        self.moe_layer = MoEAdapterLayer(shared_params, task_params, config, grad=grad)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Run the layer and return a one-element tuple with the new hidden states.

        ``attention_mask`` and any extra keyword arguments are passed straight
        to the attention module.
        """
        # Attention sub-block with residual connection.
        attn_input = self.input_layernorm(hidden_states)
        attn_result = self.self_attn(attn_input, attention_mask=attention_mask, **kwargs)
        after_attention = hidden_states + attn_result[0]

        # Feed-forward (adapter) sub-block with residual connection.
        mlp_input = self.post_attention_layernorm(after_attention)
        after_mlp = after_attention + self.moe_layer(mlp_input)

        return (after_mlp,)
    
class AdaptedDeepseek(nn.Module):
    """Randomly initialised decoder stack whose MLPs are factorised MoE adapters.

    Layer 0 is a deep copy of the base model's first decoder layer; every later
    layer is an ``AdaptedDeepseekLayer``. The files named in
    ``shared_param_paths`` / ``task_param_paths`` are used only as shape
    templates: each loaded tensor is replaced by ``0.01 * randn`` of the same
    shape in float32. At the end of construction
    ``randomly_initialize_module`` re-draws every parameter that requires grad.

    Args:
        base_model: Model exposing ``config``, ``model.embed_tokens`` and
            ``model.layers``.
        shared_param_paths: Maps shared-factor name to a ``torch.load``-able file.
        task_param_paths: Maps core name to ``{layer_index: path or list of paths}``.
        n_layers: Number of layers to build; defaults to all base-model layers.
        grad: Whether the shared factors require gradients.
        classification: Selects the multiple-choice head (``classifier``)
            over the language-modelling head (``lm_head``) in ``forward``.
    """

    def __init__(self, base_model: nn.Module, 
                 shared_param_paths: dict[str, str], 
                 task_param_paths: dict[str, dict[int, Union[str, List[str]]]],
                 n_layers: Optional[int] = None,
                 grad=False,
                 classification=False):
        super().__init__()

        self.classification = classification

        self.config = base_model.config
        self.embed_tokens = base_model.model.embed_tokens

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: self._random_like_file(path)
            for name, path in shared_param_paths.items()
        }
        # One list of cores per (projection, layer). The list is built per layer
        # so that a layer with several core files keeps all of them; iterating
        # the paths in the dict-comprehension key loop kept only the last one.
        task_params = {
            name: {
                i: [
                    self._random_like_file(p)
                    for p in (plist if isinstance(plist, list) else [plist])
                ]
                for i, plist in layer_paths.items()
            }
            for name, layer_paths in task_param_paths.items()
        }

        for i in range(total_layers):
            base_layer = base_model.model.layers[i]
            if i == 0:
                # Only layer 0 is kept from the base model, so it is the only
                # one that needs an independent copy.
                self.layers.append(copy.deepcopy(base_layer))
            else:
                task_params_i = {name: task_params[name][i] for name in task_params}
                layer = AdaptedDeepseekLayer(base_layer, shared_params, task_params_i, self.config, layer_idx=i, grad=grad)
                self.layers.append(layer)

        self.norm = DeepseekRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        if not classification:
            self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.classifier = nn.Linear(self.config.hidden_size, 1)

        del base_model
        torch.cuda.empty_cache()

        randomly_initialize_module(self)

    @staticmethod
    def _random_like_file(path):
        """Return ``0.01 * randn`` shaped like the tensor stored at ``path`` (float32, CPU)."""
        # NOTE(review): torch.load unpickles the file; only point this at trusted files.
        template = torch.load(path, map_location="cpu").to(torch.float32)
        return 0.01 * torch.randn_like(template)

    def save_pretrained(self, save_directory):
        """Write ``pytorch_model.bin`` (the state dict) and the config into ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        if hasattr(self, "config"):

            if hasattr(self.config, "save_pretrained"):
                self.config.save_pretrained(save_directory)
            else:
                # Fallback for configs without save_pretrained; works only if
                # the config object is JSON-serialisable.
                import json
                with open(os.path.join(save_directory, "config.json"), "w") as f:
                    json.dump(self.config, f)
        print(f"save {save_directory}")

    @classmethod
    def from_pretrained(cls, load_directory, shared_param_paths=None, task_param_paths=None, **kwargs):
        """Build a model from the base checkpoint found in ``load_directory``.

        Extra keyword arguments are forwarded to the constructor.

        NOTE(review): this does not load the ``pytorch_model.bin`` state dict
        written by ``save_pretrained``, and the constructor re-initialises the
        trainable parameters, so the result is not a restored model.
        """
        from transformers import AutoConfig, AutoModelForCausalLM
        config = AutoConfig.from_pretrained(load_directory)

        base_model = AutoModelForCausalLM.from_pretrained(load_directory, config=config)

        model = cls(
            base_model,
            shared_param_paths=shared_param_paths,
            task_param_paths=task_param_paths,
            **kwargs
        )

        return model

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model in classification or language-modelling mode.

        Args:
            input_ids: ``[B, C, S]`` token ids in classification mode; ``[B, S]``
                (or a 3-D tensor viewable as ``[B, S]``) in language-modelling mode.
            attention_mask: Optional padding mask (1 = keep, 0 = pad). In
                language-modelling mode a 2-D mask is combined with a causal
                mask, a mask of any other rank is passed to the layers as is,
                and ``None`` means "no padding" (causal masking still applies).
            labels: Optional targets; class indices of shape ``[B]`` in
                classification mode, token ids aligned with ``input_ids`` otherwise.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and ``"logits"``.

        Raises:
            ValueError: If NaN/Inf values appear in the mask or after any layer.
        """
        if self.classification:
            return self._forward_classification(input_ids, attention_mask, labels)
        return self._forward_causal_lm(input_ids, attention_mask, labels)

    def _run_decoder(self, hidden_states, attention_mask):
        """Apply all layers (checking for NaN/Inf after each) and the final norm."""
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attention_mask=attention_mask)[0]
            if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
                raise ValueError(f"NaN detected after layer {idx}")
        return self.norm(hidden_states)

    def _forward_classification(self, input_ids, attention_mask, labels):
        """Score each of the C choices from the hidden state at position 0."""
        B, C, S = input_ids.shape
        H = self.embed_tokens.embedding_dim

        # Embed, then fold the choice axis into the batch: [B, C, S, H] -> [B*C, S, H]
        hidden_states = self.embed_tokens(input_ids).view(B * C, S, H)

        # Padding-only additive mask (no causal masking in this mode).
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * C, 1, 1, S).to(hidden_states.dtype)
            attention_mask = attention_mask.expand(-1, 1, S, S)
            attention_mask = (1.0 - attention_mask) * -1e9
            if torch.isnan(attention_mask).any() or torch.isinf(attention_mask).any():
                raise ValueError("NaN or Inf detected in attention mask!")

        hidden_states = self._run_decoder(hidden_states, attention_mask)

        cls_rep = hidden_states[:, 0, :]          # [B*C, H]
        choice_logits = self.classifier(cls_rep)  # [B*C, 1]
        logits = choice_logits.view(B, C)         # [B, C]

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}

    def _forward_causal_lm(self, input_ids, attention_mask, labels):
        """Next-token prediction with a causal mask and shifted-label loss."""
        input_shape = input_ids.size()
        if len(input_shape) == 3:
            batch_size = input_shape[0]
            seq_len = input_shape[-1]
            input_ids = input_ids.view(batch_size, seq_len)
            if attention_mask is not None:
                attention_mask = attention_mask.view(batch_size, seq_len)
        else:
            batch_size, seq_len = input_shape

        hidden_states = self.embed_tokens(input_ids)  # [B, S, H]

        if attention_mask is None:
            # The attention layers add no causal mask themselves, so without
            # this fallback every token could attend to future tokens and the
            # shifted-label loss below would leak its targets.
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        if attention_mask.dim() == 2:
            attention_mask = build_causal_attention_mask(attention_mask)

        hidden_states = self._run_decoder(hidden_states, attention_mask)  # [B, S, H]

        # language modeling head
        logits = self.lm_head(hidden_states)  # [B, S, vocab_size]

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {"loss": loss, "logits": logits}

        

def build_deepseek_random(model_path, load_param_path, n_layers=4, num_experts=2, classification=True, grad=False, use_random_init=False,
                          load_ckpt_path=None):
    """Load the base model on CPU in float32 and wrap it in ``AdaptedDeepseek``.

    Args:
        model_path: Location passed to ``AutoModelForCausalLM.from_pretrained``.
        load_param_path: Directory containing ``<proj>/U_out.pt``,
            ``<proj>/U_in.pt`` and ``<proj>/layer<i>_coeff.pt`` for the
            ``down_proj``, ``gate_proj`` and ``up_proj`` projections.
        n_layers: Number of layers forwarded to ``AdaptedDeepseek``.
        classification: Forwarded to ``AdaptedDeepseek``.
        grad: Forwarded to ``AdaptedDeepseek``.
        num_experts, use_random_init, load_ckpt_path: Accepted but not used
            by this function.

    Returns:
        The constructed ``AdaptedDeepseek`` model.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="cpu",
        trust_remote_code=True,
        torch_dtype=torch.float32,
        attn_implementation="eager",
    )

    # Insertion order (down, gate, up; out before in) is kept on purpose.
    projections = ("down_proj", "gate_proj", "up_proj")

    # Paths of the factors shared by all layers.
    shared_param_paths = {}
    for proj in projections:
        shared_param_paths[f"{proj}_out"] = f"{load_param_path}/{proj}/U_out.pt"
        shared_param_paths[f"{proj}_in"] = f"{load_param_path}/{proj}/U_in.pt"

    # One single-element path list per layer, for layers 1..26.
    task_param_paths = {}
    for proj in projections:
        task_param_paths[f"{proj}_core"] = {
            layer: [f"{load_param_path}/{proj}/layer{layer}_coeff.pt"]
            for layer in range(1, 27)
        }

    return AdaptedDeepseek(
        base_model,
        shared_param_paths,
        task_param_paths,
        n_layers=n_layers,
        grad=grad,
        classification=classification,
    )

if __name__ == "__main__":
    # Script entry point: build a 2-layer model, then load a saved checkpoint into it.
    # strict=False means missing or unexpected checkpoint keys are not reported as errors.
    model = build_deepseek_random(
        model_path="/ckpt_llm/deepseek",
        load_param_path="deepseek/tucker_rank512",
        n_layers=2,
    )

    state_dict = load_file("/deepseek_random_layer2/model_batch000/model.safetensors")
    model.load_state_dict(state_dict, strict=False)