import torch
from typing import Optional, Dict, List
import torch.nn as nn
import torch.nn.functional as F
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

from datasets import load_dataset
from torch.utils.data import Dataset

import copy
import math

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)


def randomly_initialize_module(module):
    """Re-initialise, in place, every parameter of ``module`` that has ``requires_grad=True``.

    Frozen parameters are skipped. Tensors with two or more dimensions receive
    ``torch.nn.init.kaiming_uniform_``; lower-rank tensors (biases, norm scales, ...) receive
    ``torch.nn.init.uniform_`` on [0, 1).
    """
    trainable = (p for p in module.parameters() if p.requires_grad)
    for param in trainable:
        if param.dim() < 2:
            torch.nn.init.uniform_(param)
        else:
            torch.nn.init.kaiming_uniform_(param)



# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension (no mean subtraction, no bias).

    The statistics are computed in float32; the normalised tensor is cast back to the input
    dtype before being scaled by the learned per-channel ``weight`` (initialised to ones).
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) returning per-position ``(cos, sin)`` tables.

    Inverse frequencies and the post-scaling factor come from
    ``transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]``. Prefer passing the model
    ``config``; the other keyword arguments are only used when ``config`` is None. For rope types
    containing "dynamic", ``inv_freq`` is recomputed when positions exceed the cached length and
    restored once they fall back below the original length.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config = None,
    ):
        super().__init__()
        # Local imports: this module defines neither `ROPE_INIT_FUNCTIONS` nor `logger` at top
        # level, so the constructor previously failed with NameError.
        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
        from transformers.utils import logging as hf_logging

        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            hf_logging.get_logger(__name__).warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            rope_scaling = getattr(config, "rope_scaling", None)
            if rope_scaling is not None:
                self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so the dynamic variant can reset after a long sequence.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` of shape ``[batch, seq, 2 * len(inv_freq)]`` in ``x.dtype``.

        ``x`` is only used for its device and dtype; ``position_ids`` is ``[batch, seq]``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``concat(-b, a)``.

    When the last dimension is odd, the second half is the larger one.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed RoPE tables.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin`` after ``cos``/``sin`` are
    unsqueezed at ``unsqueeze_dim`` so they broadcast against ``q``/``k``.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table, e.g. ``[batch, seq_len, head_dim]``.
        sin (`torch.Tensor`): sine table, same shape as ``cos``.
        position_ids (`torch.Tensor`, *optional*): deprecated and ignored.
        unsqueeze_dim (`int`, defaults to 1): axis inserted into ``cos``/``sin``. Use 1 when
            q/k are ``[batch, heads, seq_len, head_dim]`` and 2 when they are
            ``[batch, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)`: the rotated ``(q, k)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Matches ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` ->
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1`` the input
    tensor object itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states



class Qwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    Eager implementation with grouped-query attention: the ``num_key_value_heads`` K/V heads are repeated to
    match the ``num_attention_heads`` query heads, RoPE is applied to queries and keys, and an optional
    4-D additive ``attention_mask`` (sliced to the key length) is added to the scores before the softmax.
    NOTE(review): despite the paragraph above, this class implements no sliding-window restriction.
    """

    def __init__(self, config, layer_idx = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Local import: `logger` is not defined at module level in this file (was a NameError).
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` (``[bsz, q_len, hidden_size]``).

        Returns ``(attn_output, attn_weights or None, past_key_value)``; the weights are only
        returned when ``output_attentions`` is True. ``use_cache`` is accepted but unused here.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            # Local import: `logger` is not defined at module level in this file (was a NameError).
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SharedExpertLayer(nn.Module):
    """Linear layer whose weight is rebuilt from shared factors and per-expert cores.

    For each core ``G`` the effective weight is ``W_shared_out @ G @ W_shared_in.T`` with shape
    ``[out_dim, in_dim]``, applied as ``x @ weight.T``; the outputs of all cores are averaged.
    Shapes implied by ``forward``:

    * ``W_shared_out``: ``[out_dim, r1]``
    * each core: ``[r1, r2]``
    * ``W_shared_in``: ``[in_dim, r2]``

    Args:
        W_shared_out, W_shared_in: shared factor tensors; trainable only when ``grad`` is truthy.
        core_tensors: iterable of one or more core tensors (always trainable).
        use_residual: also add a dense ``in_dim -> out_dim`` ``nn.Linear`` applied to ``x``.
        grad: whether the shared factors receive gradients.

    Raises:
        ValueError: if ``core_tensors`` is empty.
    """

    def __init__(self, W_shared_out, W_shared_in, core_tensors, use_residual=False, grad=False):
        super(SharedExpertLayer, self).__init__()

        # Shared factors; frozen unless `grad` is set.
        self.W_shared_out = nn.Parameter(W_shared_out.contiguous(), requires_grad=bool(grad))  # [out_dim, r1]
        self.W_shared_in = nn.Parameter(W_shared_in.contiguous(), requires_grad=bool(grad))    # [in_dim, r2]

        core_tensors = list(core_tensors)
        if not core_tensors:
            # forward would otherwise fail later inside torch.stack with a less obvious error.
            raise ValueError("SharedExpertLayer needs at least one core tensor")
        # One trainable [r1, r2] core per expert; stored contiguous.
        self.cores = nn.ParameterList([
            nn.Parameter(core.contiguous()) for core in core_tensors
        ])

        self.use_residual = use_residual
        if self.use_residual:
            # forward multiplies by W_shared_in.T, so x's last dim equals W_shared_in.shape[0].
            # (shape[1] is the rank r2 and made the residual reject real inputs.)
            self.residual = nn.Linear(W_shared_in.shape[0], W_shared_out.shape[0])

    def forward(self, x):
        """Mean over cores of ``x @ (W_shared_out @ core @ W_shared_in.T).T``, plus the residual if enabled."""
        outputs = []
        for core in self.cores:
            weight = torch.matmul(self.W_shared_out, torch.matmul(core, self.W_shared_in.T))  # [out_dim, in_dim]
            # Cast the input to the weight dtype so mixed-precision inputs still multiply.
            x = x.to(weight.dtype)
            outputs.append(x @ weight.T)

        out = torch.stack(outputs).mean(0)

        if self.use_residual:
            out = out + self.residual(x)

        return out


class MoEAdapterLayer(nn.Module):
    """Gated MLP built from three ``SharedExpertLayer`` projections.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))`` where ``act`` is
    ``activation_fn[config.hidden_act]``. ``shared_params`` must provide the
    ``{gate,up,down}_proj_{out,in}`` factor tensors and ``task_params`` the matching
    ``{gate,up,down}_proj_core`` core sequences.
    """

    def __init__(self, shared_params, task_params, config, use_residual=False, activation_fn=ACT2FN, grad=False):
        super().__init__()

        def make_projection(out_key, in_key, core_key):
            return SharedExpertLayer(
                shared_params[out_key],
                shared_params[in_key],
                task_params[core_key],
                use_residual,
                grad=grad,
            )

        # Registration order gate -> up -> down fixes parameter / state_dict ordering.
        self.gate_proj = make_projection('gate_proj_out', 'gate_proj_in', 'gate_proj_core')
        self.up_proj = make_projection('up_proj_out', 'up_proj_in', 'up_proj_core')
        self.down_proj = make_projection('down_proj_out', 'down_proj_in', 'down_proj_core')

        self.act_fn = activation_fn[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


class AdaptedQwenLayer(nn.Module):
    """Pre-norm decoder layer: self-attention followed by a ``MoEAdapterLayer`` MLP.

    With ``use_random_init`` the attention and the two RMSNorms are freshly built from
    ``config``; otherwise deep copies of ``original_layer.self_attn``,
    ``original_layer.input_layernorm`` and ``original_layer.post_attention_layernorm`` are
    used, leaving ``original_layer`` untouched.
    """

    def __init__(self, original_layer, shared_params, task_params, config, layer_idx, grad=False, use_random_init=False):
        super().__init__()

        if not use_random_init:
            self.self_attn = copy.deepcopy(original_layer.self_attn)
            self.input_layernorm = copy.deepcopy(original_layer.input_layernorm)
            self.post_attention_layernorm = copy.deepcopy(original_layer.post_attention_layernorm)
        else:
            self.self_attn = Qwen2Attention(config, layer_idx)
            eps = config.rms_norm_eps
            self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=eps)
            self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=eps)

        self.moe_layer = MoEAdapterLayer(shared_params, task_params, config, grad=grad)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, position_embeddings=None, **kwargs):
        """Apply the attention and MLP sub-blocks, each wrapped in a residual connection.

        Extra keyword arguments are accepted for call-site compatibility and ignored.
        Returns the 1-tuple ``(hidden_states,)``.
        """
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )[0]
        hidden_states = hidden_states + attn_out

        mlp_out = self.moe_layer(self.post_attention_layernorm(hidden_states))
        return (hidden_states + mlp_out,)
    
class AdaptedQwen(nn.Module):
    """Qwen2 decoder stack whose MLPs are replaced by shared-expert adapter layers.

    ``embed_tokens`` and ``rotary_emb`` are taken from ``base_model.model``; each of the first
    ``n_layers`` decoder layers (all of them when ``n_layers`` is None) is rebuilt as an
    ``AdaptedQwenLayer``. Two heads are supported:

    * ``classification=True``: multiple choice. ``input_ids`` is ``[B, C, S]``; the final hidden
      state at position 0 of every choice goes through ``classifier`` (``hidden -> 1``) and the
      loss is cross-entropy over the ``C`` choices. Attention is bidirectional (padding mask only).
    * ``classification=False``: causal language modelling through ``lm_head``.

    NOTE: the tensors behind ``shared_param_paths`` / ``task_param_paths`` only provide shapes
    (values become ``0.01 * randn``), and ``randomly_initialize_module`` finally re-initialises
    every parameter with ``requires_grad=True`` -- normally including the copied embedding and
    attention weights -- so the result is a randomly initialised model.
    """

    def __init__(self, base_model: nn.Module,
                 shared_param_paths: Dict[str, str],
                 task_param_paths: Dict[str, Dict[str, str]],
                 n_layers: Optional[int] = None,
                 grad=False, use_random_init=False,
                 classification=False):
        super().__init__()

        # Move the donor model to CPU so every module taken from it starts on CPU.
        for param in base_model.parameters():
            param.data = param.data.cpu()

        self.classification = classification

        self.config = base_model.config
        self.embed_tokens = base_model.model.embed_tokens
        self.rotary_emb = base_model.model.rotary_emb

        self.layers = nn.ModuleList()
        total_layers = n_layers if n_layers is not None else len(base_model.model.layers)

        shared_params = {
            name: self._random_like_file(path)
            for name, path in shared_param_paths.items()
        }
        # task_params[name][layer] -> list with one core per path listed for that layer. (The old
        # comprehension rebuilt the entry for every path, so only the last core survived.)
        task_params = {
            name: {
                i: [self._random_like_file(p) for p in (plist if isinstance(plist, list) else [plist])]
                for i, plist in layer_paths.items()
            }
            for name, layer_paths in task_param_paths.items()
        }

        # NOTE(review): `use_random_init` is accepted but not forwarded to AdaptedQwenLayer, so the
        # attention/norm modules are always copied from `base_model`. Every parameter with
        # requires_grad=True is re-initialised at the end of __init__ regardless.
        for i in range(total_layers):
            task_params_i = {name: per_layer[i] for name, per_layer in task_params.items()}
            # AdaptedQwenLayer deep-copies the sub-modules it keeps, so the donor layer is passed
            # directly instead of first deep-copying the whole layer (including its unused MLP).
            layer = AdaptedQwenLayer(base_model.model.layers[i], shared_params, task_params_i, self.config,
                                     layer_idx=i, grad=grad)
            self.layers.append(layer)

        self.norm = Qwen2RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        self.classifier = nn.Linear(self.config.hidden_size, 1)
        if not self.classification:
            self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)

        # Only drops this function's reference; the caller may still hold `base_model`.
        del base_model
        torch.cuda.empty_cache()

        randomly_initialize_module(self)

    @staticmethod
    def _random_like_file(path):
        """Return ``0.01 * randn`` shaped like the tensor stored at ``path`` (float32, CPU)."""
        return 0.01 * torch.randn_like(torch.load(path, map_location="cpu").to(torch.float32))

    def save(self, path):
        """Write ``state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict written by :meth:`save` (mapped to CPU, then copied into the params)."""
        self.load_state_dict(torch.load(path, map_location="cpu"))

    def save_pretrained(self, save_directory):
        """Write ``pytorch_model.bin`` and, when present, the config into ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))

        if hasattr(self, "config"):
            self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory, config=None, shared_param_paths=None, task_param_paths=None, **kwargs):
        """Rebuild a model and load the weights written by :meth:`save_pretrained`.

        ``__init__`` needs a causal-LM module (it reads ``base_model.model.*``), so a skeleton is
        built with ``AutoModelForCausalLM.from_config``; previously the config itself was passed as
        ``base_model``, which cannot work. ``shared_param_paths`` / ``task_param_paths`` are still
        required because ``__init__`` takes tensor shapes from them, and ``kwargs`` (``n_layers``,
        ``grad``, ``classification``, ...) must match the saved model for the state dict to load.

        Raises:
            ValueError: if either path mapping is missing.
            FileNotFoundError: if ``load_directory`` has no ``pytorch_model.bin``.
        """
        if shared_param_paths is None or task_param_paths is None:
            raise ValueError("from_pretrained needs shared_param_paths and task_param_paths to rebuild the adapters")
        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(load_directory)

        weights_path = os.path.join(load_directory, 'pytorch_model.bin')
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"No pytorch_model.bin found in {load_directory}")

        base_model = AutoModelForCausalLM.from_config(config)
        model = cls(base_model, shared_param_paths=shared_param_paths, task_param_paths=task_param_paths, **kwargs)
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        return model

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None):
        """Run the adapted model.

        Args:
            input_ids: ``[B, C, S]`` token ids in classification mode, ``[B, S]`` otherwise.
            attention_mask: optional 0/1 padding mask shaped like ``input_ids`` (1 = real token).
            labels: classification -- choice targets accepted by ``CrossEntropyLoss`` against
                ``[B, C]`` logits (typically ``[B]`` indices); LM -- ``[B, S]`` ids aligned with
                ``input_ids`` (shifted internally), -100 to ignore.
            position_ids: optional explicit positions with the same leading shape as ``input_ids``.

        Returns:
            Classification: ``{"logits": [B, C]}``, plus ``"loss"`` when ``labels`` is given.
            LM: ``(logits,)`` or ``(loss, logits)`` with logits ``[B, S, vocab_size]``.

        Raises:
            ValueError: when the NaN/Inf guards on masks, hidden states, logits or labels trip.
        """
        if self.classification:
            return self._forward_classification(input_ids, attention_mask, labels, position_ids)
        return self._forward_lm(input_ids, attention_mask, labels, position_ids)

    def _forward_classification(self, input_ids, attention_mask, labels, position_ids):
        """Multiple-choice head: score position 0 of every choice and compare choices per example."""
        B, C, S = input_ids.shape
        H = self.embed_tokens.embedding_dim

        if position_ids is None:
            position_ids = torch.arange(S, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(B * C, S)
        else:
            position_ids = position_ids.view(B * C, S)

        hidden_states = self.embed_tokens(input_ids)  # [B, C, S, H]
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        hidden_states = hidden_states.view(B * C, S, H)

        mask = self._bidirectional_mask(
            None if attention_mask is None else attention_mask.reshape(B * C, S),
            B * C, S, hidden_states.dtype, hidden_states.device,
        )
        if torch.isnan(mask).any() or torch.isinf(mask).any():
            raise ValueError("NaN or Inf detected in attention mask!")

        hidden_states = self._run_layers(hidden_states, mask, position_ids, position_embeddings)

        cls_rep = hidden_states[:, 0, :]            # [B*C, H]
        logits = self.classifier(cls_rep).view(B, C)  # [B, C]

        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("NaN or Inf in logits before loss computation!")

        # Inference without labels used to crash (isnan(None)) or return None.
        if labels is None:
            return {"logits": logits}
        if torch.isnan(labels).any():
            raise ValueError("NaN in labels!")
        labels = labels.to(logits.device)
        loss = torch.nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

    def _forward_lm(self, input_ids, attention_mask, labels, position_ids):
        """Causal-LM head: logits for every position and an optional next-token loss."""
        B, S = input_ids.shape

        hidden_states = self.embed_tokens(input_ids)  # [B, S, H]
        if position_ids is None:
            position_ids = torch.arange(S, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(B, S)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)  # (cos, sin)

        mask = self._causal_mask(attention_mask, B, S, hidden_states.dtype, hidden_states.device)
        hidden_states = self._run_layers(hidden_states, mask, position_ids, position_embeddings)
        logits = self.lm_head(hidden_states)  # [B, S, vocab_size]

        if labels is None:
            return (logits,)
        return (self._lm_loss(logits, labels), logits)

    def _run_layers(self, hidden_states, attention_mask, position_ids, position_embeddings):
        """Run all decoder layers (NaN/Inf guard after each) and the final norm."""
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings
            )[0]
            if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
                raise ValueError(f"NaN detected after layer {idx}")
        return self.norm(hidden_states)

    @staticmethod
    def _bidirectional_mask(attention_mask, batch, seq_len, dtype, device):
        """Additive ``[batch, 1, S, S]`` mask that hides padded keys only (no causal constraint).

        ``attention_mask`` is a 0/1 tensor reshapeable to ``[batch, S]`` or None. With None an
        all-zero mask is returned so attention matches the padded case.
        NOTE(review): passing None to SDPA-based attention may make it apply causal masking
        instead -- confirm for the attention implementation in use.
        """
        if attention_mask is None:
            return torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype, device=device)
        keep = attention_mask.reshape(batch, 1, 1, seq_len).to(device=device, dtype=dtype)
        # finfo(dtype).min instead of -1e9: -1e9 overflows to -inf in float16.
        return ((1.0 - keep) * torch.finfo(dtype).min).expand(-1, 1, seq_len, seq_len)

    @staticmethod
    def _causal_mask(attention_mask, batch, seq_len, dtype, device):
        """Additive ``[batch, 1, S, S]`` mask hiding future keys and, if given, padded keys.

        Built even when ``attention_mask`` is None: an eager attention that only adds the mask
        (like this file's Qwen2Attention) would otherwise attend to future tokens.
        """
        min_value = torch.finfo(dtype).min
        future = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
        mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype, device=device)
        mask.masked_fill_(future, min_value)
        if attention_mask is not None:
            padded = attention_mask[:, None, None, :].to(device) == 0  # [batch, 1, 1, S]
            # masked_fill (not addition) so causal + padding cannot overflow to -inf.
            mask.masked_fill_(padded, min_value)
        return mask

    @staticmethod
    def _lm_loss(logits, labels):
        """Next-token cross-entropy: logits at position t are scored against ``labels[:, t + 1]``.

        Labels of -100 are ignored. Labels are expected unshifted (same alignment as
        ``input_ids``); previously the loss compared each position with its own token.
        """
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:].to(logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        return loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )


def build_qwen_random(model_path, load_param_path, n_layers=4, num_experts=2, classification=True, grad=False,
                      use_random_init=False, num_param_layers=23):
    """Load a causal-LM donor model and wrap it in a randomly initialised ``AdaptedQwen``.

    Args:
        model_path: path or hub id for ``AutoModelForCausalLM.from_pretrained`` (loaded on CPU in
            float32 with ``trust_remote_code=True``).
        load_param_path: directory containing ``{down,gate,up}_proj/U_out.pt``, ``U_in.pt`` and
            ``layer{i}_coeff.pt``; ``AdaptedQwen`` only uses these tensors for their shapes.
        n_layers: number of decoder layers to keep (None keeps every layer of the donor model).
        num_experts: currently unused -- exactly one core file per layer and projection is listed.
        classification, grad, use_random_init: forwarded to ``AdaptedQwen``.
        num_param_layers: number of ``layer{i}_coeff.pt`` files per projection, i.e. indices
            ``0 .. num_param_layers - 1`` (default 23, the previously hard-coded value). Every
            listed file is loaded, so it must exist.

    Returns:
        The constructed ``AdaptedQwen``.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="cpu", trust_remote_code=True, torch_dtype=torch.float32
    )

    # Shared factor paths (AdaptedQwen casts the loaded tensors to float32).
    shared_param_paths = {
        'down_proj_out': f'{load_param_path}/down_proj/U_out.pt',
        'down_proj_in': f'{load_param_path}/down_proj/U_in.pt',
        'gate_proj_out': f'{load_param_path}/gate_proj/U_out.pt',
        'gate_proj_in': f'{load_param_path}/gate_proj/U_in.pt',
        'up_proj_out': f'{load_param_path}/up_proj/U_out.pt',
        'up_proj_in': f'{load_param_path}/up_proj/U_in.pt',
    }

    # One core file per layer and projection.
    layer_range = range(num_param_layers)
    task_param_paths = {
        'down_proj_core': {i: [f'{load_param_path}/down_proj/layer{i}_coeff.pt'] for i in layer_range},
        'gate_proj_core': {i: [f'{load_param_path}/gate_proj/layer{i}_coeff.pt'] for i in layer_range},
        'up_proj_core': {i: [f'{load_param_path}/up_proj/layer{i}_coeff.pt'] for i in layer_range},
    }

    model = AdaptedQwen(base_model, shared_param_paths, task_param_paths, n_layers=n_layers,
                        grad=grad, use_random_init=use_random_init,
                        classification=classification)

    return model
