import os
import math
import json
import pandas as pd
from datasets import Dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
import gc


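##########################################
# RSL router
# - top3 routing (top-3 expert indices and renormalized weights per token)
# - p_i = average gating score of expert i
# - f_i = fraction of top-3 assignments that go to expert i
# - aux_loss = alpha * sum_{i} (p_i * f_i) - gamma * mean token entropy
##########################################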
class RSL(nn.Module):
    """Token-level router that scores every token against every expert.

    A two-layer MLP (Linear -> GELU -> Linear) maps each hidden state to one
    logit per expert, and a softmax turns those logits into routing scores.

    Args:
        in_features: Size of the last dimension of ``hidden_states``.
        num_experts: Number of experts to score.
        alpha: Weight of the load-balance term in the auxiliary loss. The
            auxiliary loss is only computed in training mode with ``alpha > 0``.
        top_k: Number of experts reported per token (default 3). It is clamped
            to ``num_experts`` so that small expert pools do not make
            ``torch.topk`` fail.
        gamma: Weight of the entropy term subtracted from the auxiliary loss
            (default 0.1; the suggested range is 0.01 to 0.1).
    """

    def __init__(self, in_features, num_experts, alpha=0.1, top_k=3, gamma=0.1):
        super().__init__()
        self.alpha = alpha
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)
        self.gamma = gamma
        self.mlp = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Linear(256, num_experts),
        )
        # Xavier init; gain sqrt(2) on the layer that feeds the GELU.
        nn.init.xavier_normal_(self.mlp[0].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.xavier_normal_(self.mlp[2].weight, gain=1.0)

    def forward(self, hidden_states):
        """Route ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(B, seq_len, in_features)``.

        Returns:
            A tuple ``(router_weights, aux_loss, topk_indices, topk_weights)``:

            - ``router_weights``: the full float32 softmax over experts, shape
              ``(B, seq_len, num_experts)``.
            - ``aux_loss``: a scalar tensor in training mode with ``alpha > 0``,
              otherwise ``None``.
            - ``topk_indices``: shape ``(B, seq_len, top_k)``.
            - ``topk_weights``: the selected scores, renormalized to sum to 1.
        """
        B, seq_len, _ = hidden_states.shape
        logits = self.mlp(hidden_states).float()
        scores = F.softmax(logits, dim=-1)

        topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1)
        topk_weights = F.normalize(topk_weights, p=1, dim=-1)

        aux_loss = None
        if self.training and self.alpha > 0.0:
            # p_i: mean routing probability of each expert over ALL tokens
            # (batch and sequence), so it is an (E,) vector summing to 1.
            # Averaging over dim 0 alone would leave a (seq_len, E) tensor
            # whose product with f_i grows with the sequence length.
            Pi = scores.reshape(-1, self.num_experts).mean(dim=0)
            # f_i: fraction of the top-k slots assigned to each expert.
            fi = F.one_hot(
                topk_indices.reshape(-1), num_classes=self.num_experts
            ).float().mean(dim=0)

            balance = (Pi * fi).sum()
            entropy = -(scores * scores.clamp(min=1e-9).log()).sum(dim=-1).mean()

            # When this loss is minimized, the negative entropy term pushes
            # each token's scores toward a flatter distribution.
            aux_loss = self.alpha * balance - self.gamma * entropy

        router_weights = scores.view(B, seq_len, self.num_experts)
        return router_weights, aux_loss, topk_indices, topk_weights


#########################################
# Data Loading Functions
#########################################

from transformers import AutoConfig, AutoModelForCausalLM

def load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer from ``model_name``.

    The model is loaded with ``device_map="auto"`` and ``use_cache=False``.
    The dtype shared by the largest number of parameter tensors is printed.
    The tokenizer falls back to its EOS token for padding, pads on the left,
    and has ``add_eos_token`` switched on.

    Returns:
        ``(model, tokenizer)``
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        use_cache=False,
    )

    # Tally parameter tensors per dtype and report the most common one.
    dtype_counts = {}
    for tensor in model.parameters():
        dtype_counts[tensor.dtype] = dtype_counts.get(tensor.dtype, 0) + 1
    main_dtype = max(dtype_counts, key=dtype_counts.get)
    print(f"[INFO] Main model dtype: {main_dtype}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True
    return model, tokenizer


def load_medical_dataset(path):
    """Load the medical QA CSV at ``path`` as a ``Dataset``.

    Keeps the ``question`` column (renamed to ``text``) and the ``answer``
    column, keeps a random 99% of the rows via ``train_test_split``, and tags
    every row with ``domain='medical'`` and ``domain_id=2``.
    """
    frame = pd.read_csv(path)[['question', 'answer']]
    frame = frame.rename(columns={'question': 'text'})
    subset = Dataset.from_pandas(frame).train_test_split(test_size=0.99)['test']
    return subset.map(lambda _row: {'domain': 'medical', 'domain_id': 2})


def load_math_dataset(path):
    """Load the math QA parquet file at ``path`` as a ``Dataset``.

    Keeps the ``question`` column (renamed to ``text``) and the ``answer``
    column, keeps a random 99% of the rows via ``train_test_split``, and tags
    every row with ``domain='math'`` and ``domain_id=1``.
    """
    frame = pd.read_parquet(path)[['question', 'answer']]
    frame = frame.rename(columns={'question': 'text'})
    subset = Dataset.from_pandas(frame).train_test_split(test_size=0.99)['test']
    return subset.map(lambda _row: {'domain': 'math', 'domain_id': 1})


def load_coding_dataset(path):
    """Load a coding parquet file with ``prompt``/``canonical_solution`` columns.

    ``prompt`` is renamed to ``text``. All rows are kept and tagged with
    ``domain='coding'`` and ``domain_id=0``.
    """
    frame = pd.read_parquet(path)[['prompt', 'canonical_solution']]
    frame = frame.rename(columns={'prompt': 'text'})
    dataset = Dataset.from_pandas(frame)
    return dataset.map(lambda _row: {'domain': 'coding', 'domain_id': 0})


def load_coding_dataset_v2(path):
    """Load a coding parquet file with ``text``/``code`` columns.

    ``code`` is renamed to ``canonical_solution`` so the schema matches
    ``load_coding_dataset``. All rows are kept and tagged with
    ``domain='coding'`` and ``domain_id=0``.
    """
    frame = pd.read_parquet(path)[['text', 'code']]
    frame = frame.rename(columns={'code': 'canonical_solution'})
    dataset = Dataset.from_pandas(frame)
    return dataset.map(lambda _row: {'domain': 'coding', 'domain_id': 0})


def load_glue_dataset(path, task_name):
    """Load a single-sentence GLUE classification split (SST-2 or CoLA).

    Args:
        path: Parquet file with ``sentence`` and ``label`` columns.
        task_name: ``"sst2"`` or ``"cola"``.

    Returns:
        A ``Dataset`` holding a random 99% of the rows, with ``sentence``
        renamed to ``text``. Every row is tagged with ``domain=task_name`` and
        ``domain_id`` 3 (sst2) or 4 (cola).

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks. This is
            checked before any file I/O.
    """
    domain_ids = {"sst2": 3, "cola": 4}
    if task_name not in domain_ids:
        raise ValueError(
            f"Unsupported GLUE task: {task_name!r}; expected one of {sorted(domain_ids)}"
        )
    domain_id = domain_ids[task_name]

    df = pd.read_parquet(path)
    # Both supported tasks share the same single-sentence schema.
    df = df[['sentence', 'label']].rename(columns={'sentence': 'text'})
    dataset = Dataset.from_pandas(df)
    dataset = dataset.train_test_split(test_size=0.99)['test']
    return dataset.map(lambda x: {'domain': task_name, 'domain_id': domain_id})


def load_arc_dataset(path):
    """Load an ARC parquet split as a ``Dataset``.

    If both ``question`` and ``choices`` columns exist, each ``choices`` entry
    (with parallel ``label``/``text`` sequences) is flattened into a
    ``"label: text, label: text"`` string stored as ``canonical_solution``,
    and ``answerKey`` is kept. Otherwise only ``question`` is kept. In both
    cases ``question`` is renamed to ``text``. A random 30% of the rows is
    kept and tagged with ``domain='arc'`` and ``domain_id=5``.
    """
    df = pd.read_parquet(path)
    has_choices = 'question' in df.columns and 'choices' in df.columns

    if not has_choices:
        df = df[['question']].rename(columns={'question': 'text'})
    else:
        def render_choices(choice):
            pairs = zip(choice['label'], choice['text'])
            return ', '.join(f"{lbl}: {txt}" for lbl, txt in pairs)

        df['choices_text'] = df['choices'].apply(render_choices)
        df = df[['question', 'choices_text', "answerKey"]].rename(
            columns={'question': 'text', 'choices_text': 'canonical_solution'}
        )

    subset = Dataset.from_pandas(df).train_test_split(test_size=0.3)['test']
    arc_domain_id = 5  # ARC is assigned domain_id 5
    return subset.map(lambda _row: {'domain': 'arc', 'domain_id': arc_domain_id})


#########################################
# Prompt Formatting Functions
#########################################

def create_prompt_formats(sample: dict):
    """Render one raw sample into the training prompt for its domain.

    Args:
        sample: Row with ``domain``, ``domain_id`` and ``text``, plus the
            domain's answer field:

            - ``answer`` for medical and math;
            - ``canonical_solution`` for coding;
            - ``label`` for sst2 and cola;
            - ``answerKey`` for arc, with an optional ``canonical_solution``
              holding the choices.

    Returns:
        ``{"text": prompt, "domain": domain, "domain_id": sample['domain_id']}``

    Raises:
        ValueError: If the domain has no template.
    """
    domain = sample['domain']

    def solution_prompt(header, solution_key):
        # Shared "Problem Description / Solution" layout of math and coding.
        return (
            header
            + "### Problem Description:\n"
            + f"{sample['text']}\n"
            + "### Solution:\n"
            + f"{sample[solution_key]}"
        )

    def sentence_prompt(instructions):
        # Shared single-sentence classification layout of sst2 and cola.
        return (
            f"{instructions}\n"
            f"Sentence: {sample['text']}\n"
            f"### Answer: {str(sample['label'])}"
        )

    def medical_prompt():
        persona = (
            "You are a knowledgeable and empathetic medical expert, focused on providing accurate, clear, and patient-friendly answers."
        )
        return f"{persona}\n### Question: {sample['text']}\n\n### Answer: {sample['answer']} </s>"

    def arc_prompt():
        question = sample['text']
        answer_key = sample['answerKey']
        choices = sample.get('canonical_solution', "")
        return (
            "Answer the following multiple choice questions and give me the correct choice, it is a common sense reasoning question.\n"
            f"Question: {question}\n"
            f"Choices: {choices}\n"
            f"### Answer: {answer_key} "
        )

    math_header = "You are a math expert. Solve the problem step by step, and put your final answer within \\boxed{}\n"
    coding_header = "Complete the following Python function. Respond only with the code implementation, no explanations.\n\n"

    builders = {
        'medical': medical_prompt,
        'math': lambda: solution_prompt(math_header, 'answer'),
        'coding': lambda: solution_prompt(coding_header, 'canonical_solution'),
        'sst2': lambda: sentence_prompt(
            "Classify the sentiment of the sentence as negative (0) or positive (1). Respond only with 0 or 1."
        ),
        'cola': lambda: sentence_prompt(
            "Classify the sentence as grammatically acceptable (1) or unacceptable (0). Respond only with 0 or 1."
        ),
        'arc': arc_prompt,
    }
    if domain not in builders:
        raise ValueError(f"Unknown domain: {domain}")

    prompt = builders[domain]()
    return {"text": prompt, "domain": domain, "domain_id": sample['domain_id']}


def create_model_inputs(example, tokenizer, max_length=1024, device=None):
    """Tokenize ``example["text"]`` into fixed-length causal-LM inputs.

    Args:
        example: Mapping with a ``"text"`` entry.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors when called with ``return_tensors="pt"``.
        max_length: Length to pad/truncate to.
        device: Optional device to move the outputs to. The default ``None``
            leaves them where the tokenizer created them. Inside
            ``Dataset.map`` the result is serialized to Arrow anyway, so
            moving it to CUDA only adds transfers and breaks CPU-only runs.

    Returns:
        Dict with flattened ``input_ids``, ``attention_mask`` and ``labels``.
        ``labels`` copies ``input_ids`` with every padded position
        (``attention_mask == 0``) set to ``-100``.
    """
    input_encoding = tokenizer(
        example["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = input_encoding["input_ids"]
    attention_mask = input_encoding["attention_mask"]

    labels = input_ids.clone()
    # Mask padding via the attention mask rather than pad_token_id. When the
    # pad token is the EOS token, matching on the id would also hide the real
    # EOS tokens from the loss.
    labels[attention_mask == 0] = -100

    outputs = {
        "input_ids": input_ids.flatten(),
        "attention_mask": attention_mask.flatten(),
        "labels": labels.flatten(),
    }
    if device is not None:
        outputs = {key: value.to(device) for key, value in outputs.items()}
    return outputs


def preprocess_function(examples, tokenizer, max_length=512, device=None):
    """Prefix each text with its domain tag and tokenize the whole batch.

    Args:
        examples: Batched mapping (a dict or a ``datasets`` batch) with
            parallel ``domain``, ``text`` and ``domain_id`` lists.
        tokenizer: Callable tokenizer, called once on the tagged texts with
            ``return_tensors="pt"``.
        max_length: Length to pad/truncate to.
        device: Optional device for the returned tensors. The default ``None``
            keeps them where they were created; ``Dataset.map`` serializes the
            result to Arrow anyway.

    Returns:
        The tokenizer output plus a ``domain_ids`` LongTensor.

    Raises:
        KeyError: If a domain has no tag.
    """
    # NOTE: the "[LANGUAGE UNSTANDING]" spelling is part of the text the model
    # sees; keep it unless every consumer of these tags is updated together.
    domain_tags = {
        'medical': "[MEDICAL]",
        'math': "[MATH]",
        'coding': "[CODING]",
        'sst2': "[LANGUAGE UNSTANDING]",
        'cola': "[LANGUAGE UNSTANDING]",
        'arc': "[COMMON SENSE REASONING]"
    }
    if isinstance(examples, dict):
        examples_dict = examples
    else:
        examples_dict = {k: examples[k] for k in examples.keys()}

    texts = [
        f"{domain_tags[domain]} {text}"
        for domain, text in zip(examples_dict['domain'], examples_dict['text'])
    ]

    tokenized = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    tokenized['domain_ids'] = torch.tensor(examples_dict['domain_id'], dtype=torch.long)

    if device is not None:
        for key in list(tokenized.keys()):
            if isinstance(tokenized[key], torch.Tensor):
                tokenized[key] = tokenized[key].to(device)

    return tokenized


def prepare_datasets(config, tokenizer):
    """Build the shuffled multi-domain train/validation datasets.

    Loads every domain from the paths in ``config``, concatenates and shuffles
    them (seed 42), renders prompts, and tokenizes with domain tags. It then
    rebuilds ``labels`` from the final ``input_ids``/``attention_mask`` and
    splits off 10% for validation (seed 42).

    Args:
        config: Dict with ``medical_path``, ``math_path``, ``coding_path``,
            ``coding_path_v2``, ``sst2_path``, ``cola_path``,
            ``arc_easy_path`` and ``arc_challenge_path``.
        tokenizer: Tokenizer used by the tokenization steps.

    Returns:
        ``(train_dataset, validation_dataset)``
    """
    medical_ds = load_medical_dataset(config['medical_path'])
    math_ds = load_math_dataset(config['math_path'])
    coding_ds = load_coding_dataset(config['coding_path'])
    coding_dsv2 = load_coding_dataset_v2(config['coding_path_v2'])
    sst2_ds = load_glue_dataset(config['sst2_path'], 'sst2')
    cola_ds = load_glue_dataset(config['cola_path'], 'cola')
    arc_easy_ds = load_arc_dataset(config['arc_easy_path'])
    arc_challenge_ds = load_arc_dataset(config['arc_challenge_path'])

    # Concatenate datasets and shuffle
    mixed_ds = concatenate_datasets([
        coding_ds, coding_dsv2, math_ds, medical_ds,
        sst2_ds, cola_ds, arc_easy_ds, arc_challenge_ds
    ])
    mixed_ds = mixed_ds.shuffle(seed=42)

    # Format prompt
    mixed_ds = mixed_ds.map(create_prompt_formats)

    # Map to (input_ids, labels) format
    mixed_ds = mixed_ds.map(
        lambda x: create_model_inputs(x, tokenizer),
        remove_columns=['answer', 'canonical_solution', 'label']
    )

    # Further preprocess (input_ids, domain_id)
    mixed_ds = mixed_ds.map(
        lambda x: preprocess_function(x, tokenizer),
        batched=True,
        remove_columns=['domain', 'text', "answerKey"]
    )

    # preprocess_function re-tokenizes with a domain tag and its own
    # max_length. That overwrites input_ids/attention_mask but not the labels
    # produced by create_model_inputs. Rebuild the labels from the final
    # inputs so they line up token-for-token; padding becomes -100.
    mixed_ds = mixed_ds.map(
        lambda x: {
            "labels": [
                token if mask else -100
                for token, mask in zip(x["input_ids"], x["attention_mask"])
            ]
        }
    )

    # Seeded, like the shuffle above, so the split is reproducible.
    mixed_ds = mixed_ds.train_test_split(test_size=0.1, seed=42)
    return mixed_ds['train'], mixed_ds['test']


#########################################
# LoRA_Mixer Class
#########################################
class LoRA_Mixer(nn.Module):
    """Mixture of LoRA experts injected into the attention projections of a
    frozen causal LM.

    For each of the first ``num_layers`` decoder layers and each projection in
    ``proj_names``, the projection's ``forward`` is replaced so that it
    returns ``base_out + lora_out``. Two training stages are supported:

      - ``freeze_router=True``: every sample uses only the expert whose index
        equals its entry in ``domain_ids``. No learned routing is used.
      - ``freeze_router=False``: an ``RSL`` router scores every expert per
        token. ``lora_out`` is the router-weighted sum over ALL experts; the
        top-3 indices only feed the activation statistics. The router's
        auxiliary loss is accumulated into ``self.aux_loss``.

    The base model's parameters are frozen. Expert ``i`` is the LoRA A/B pair
    loaded from ``expert_paths[i]``.

    NOTE(review): the expert deltas are applied as ``B(A(x))`` with no
    ``lora_alpha / r`` scaling; confirm this matches how the experts were
    trained.
    """
    def __init__(
        self,
        base_model,
        expert_paths,
        num_layers=32,
        proj_names=("q_proj", "k_proj", "v_proj", "o_proj"),
        alpha=0.1,
        freeze_router=False
    ):
        super().__init__()

        # Copy into a fresh list so no sequence (including the default) is
        # shared between instances or with the caller.
        proj_names = list(proj_names)

        self.num_experts = len(expert_paths)
        # Per "layer{idx}_{proj}" key: number of token slots each expert got.
        self.expert_activation = {
            f"layer{layer_idx}_{proj_name}": [0] * self.num_experts
            for layer_idx in range(num_layers)
            for proj_name in proj_names
        }

        self.base_model = base_model
        self.num_layers = num_layers
        self.proj_names = proj_names
        # Global activation count per expert, summed over all patched projections.
        self.register_buffer("expert_count", torch.zeros(self.num_experts))
        self.router_logits_cache = {}
        self.aux_loss = 0.0
        self.alpha = alpha
        self.freeze_router = freeze_router

        # Router network & LoRA
        self.routers = nn.ModuleDict()
        self.lora_params = nn.ParameterDict()

        # Store expert assignments
        self.expert_assignments = []
        self.expert_usage_list = []

        # Freeze base model parameters
        for param in base_model.parameters():
            param.requires_grad = False

        # Initialize routers and LoRA
        self._init_routers_and_lora(expert_paths)

        # Decide whether routers are trainable based on freeze_router
        if freeze_router:
            for name, param in self.routers.named_parameters():
                param.requires_grad = False

        # Monkey patch linear layers
        self._monkey_patch_linear_layers()

    def _init_routers_and_lora(self, expert_paths):
        """Create one ``RSL`` router per patched projection and load each
        expert's ``lora_A``/``lora_B`` weights for it from the safetensors
        files in ``expert_paths``."""
        print("\n[INFO] Initializing router network and LoRA parameters...")
        experts = [load_file(path) for path in expert_paths]
        for layer_idx in range(self.num_layers):
            for proj_name in self.proj_names:
                linear_layer = getattr(self.base_model.model.layers[layer_idx].self_attn, proj_name)
                if not isinstance(linear_layer, nn.Linear):
                    raise TypeError(f"Expected nn.Linear, but {proj_name} is {type(linear_layer)}")

                # RSL
                router = RSL(
                    in_features=linear_layer.in_features,
                    num_experts=self.num_experts,
                    alpha=self.alpha
                )
                self.routers[f"layer{layer_idx}_{proj_name}"] = router

                # Load LoRA parameters (keys follow the PEFT naming scheme).
                lora_As, lora_Bs = [], []
                for expert in experts:
                    key_A = f"base_model.model.model.layers.{layer_idx}.self_attn.{proj_name}.lora_A.weight"
                    key_B = f"base_model.model.model.layers.{layer_idx}.self_attn.{proj_name}.lora_B.weight"
                    lora_As.append(nn.Parameter(expert[key_A]))
                    lora_Bs.append(nn.Parameter(expert[key_B]))
                self.lora_params[f"layer{layer_idx}_{proj_name}_A"] = nn.ParameterList(lora_As)
                self.lora_params[f"layer{layer_idx}_{proj_name}_B"] = nn.ParameterList(lora_Bs)

    def _record_activation(self, router_key, expert_ids, repeat=1):
        """Add ``repeat`` activations for every expert index in ``expert_ids``.

        Counting is done with one ``torch.bincount`` instead of a per-token
        Python loop. That avoids one device op per token and avoids float32
        ``+= 1`` stalling once a count reaches 2**24.
        """
        counts = torch.bincount(expert_ids.reshape(-1).long(), minlength=self.num_experts) * repeat
        self.expert_count.add_(counts.to(self.expert_count))
        layer_counts = self.expert_activation[router_key]
        for idx, count in enumerate(counts.tolist()):
            layer_counts[idx] += count

    def _monkey_patch_linear_layers(self):
        """Replace each target projection's ``forward`` with base + LoRA mixing."""
        print("\n[INFO] Replacing forward methods of linear layers (DeepSeek gating or domain-forced) ...")
        for layer_idx in range(self.num_layers):
            for proj_name in self.proj_names:
                key = f"layer{layer_idx}_{proj_name}"
                linear_layer = getattr(self.base_model.model.layers[layer_idx].self_attn, proj_name)
                if not isinstance(linear_layer, nn.Linear):
                    raise TypeError(f"Expected nn.Linear, but {proj_name} is {type(linear_layer)}")

                orig_forward = linear_layer.forward

                # Factory binds router_key/orig per layer (avoids late binding).
                def make_forward(router_key, orig):
                    def new_forward(x):
                        B, seq_len, _ = x.shape
                        router = self.routers[router_key]
                        domain_ids = getattr(self, '_temp_domain_ids', None)

                        # Base model output
                        base_out = orig(x)

                        lora_As = self.lora_params[f"{router_key}_A"]
                        lora_Bs = self.lora_params[f"{router_key}_B"]

                        if self.freeze_router:
                            if domain_ids is None:
                                raise ValueError("freeze_router=True but domain_ids not provided")

                            expert_outputs = []
                            for i in range(B):
                                expert_id = domain_ids[i].item()
                                delta = F.linear(x[i], lora_As[expert_id])
                                delta = F.linear(delta, lora_Bs[expert_id])
                                expert_outputs.append(delta.unsqueeze(0))
                            lora_out = torch.cat(expert_outputs, dim=0)

                            # Each sample's expert serves all seq_len tokens.
                            self._record_activation(router_key, domain_ids, repeat=seq_len)
                        else:
                            router_weights, aux_loss, top3_indices, _ = router(x)

                            if self.training and aux_loss is not None:
                                self.aux_loss += aux_loss

                            # (B, seq_len, out_features, num_experts)
                            expert_outputs = torch.stack(
                                [
                                    F.linear(F.linear(x, lora_As[e]), lora_Bs[e])
                                    for e in range(self.num_experts)
                                ],
                                dim=-1,
                            )
                            # Router scores are float32; match the expert
                            # outputs' dtype so einsum also works outside autocast.
                            lora_out = torch.einsum(
                                'bsen,bsn->bse',
                                expert_outputs,
                                router_weights.to(expert_outputs.dtype),
                            )

                            self._record_activation(router_key, top3_indices)

                        assert base_out.shape == lora_out.shape, f"[{router_key}] shape mismatch: base={base_out.shape}, lora={lora_out.shape}"
                        final_out = base_out + lora_out
                        return final_out

                    return new_forward

                linear_layer.forward = make_forward(router_key=key, orig=orig_forward)

    def forward(self, input_ids, attention_mask=None, labels=None, domain_ids=None):
        """Run the patched base model.

        ``domain_ids`` is required when ``freeze_router=True`` (it forces
        routing). ``self.aux_loss`` is reset at the start and, in training
        with ``alpha > 0``, averaged over all patched projections afterwards.

        Returns:
            ``{'loss': ..., 'logits': ...}`` from the base model output.
        """
        self.aux_loss = 0.0
        # Read by the patched projection forwards during this call.
        self._temp_domain_ids = domain_ids

        self.temp_domain_ids = domain_ids  # Temporarily save

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        if self.training and self.alpha > 0:
            total_router_layers = self.num_layers * len(self.proj_names)
            self.aux_loss = self.aux_loss / total_router_layers

        expert_usage = self.get_expert_usage()
        self.expert_usage_list.append(expert_usage)

        return {
            'loss': outputs.loss,
            'logits': outputs.logits,
        }

    def generate(self, *args, **kwargs):
        """Let LoRA_Mixer directly call the base_model's generate method"""
        return self.base_model.generate(*args, **kwargs)

    def reset_expert_activation(self):
        """Reset expert activation counts for all layers"""
        for key in self.expert_activation:
            self.expert_activation[key] = [0] * self.num_experts

    def get_expert_activation_stats(self):
        """Return per-layer and global expert activation counts/percentages.

        Layers with no recorded activations are omitted from ``layer_stats``.
        """
        total_stats = defaultdict(int)
        layer_stats = {}

        for key, counts in self.expert_activation.items():
            total = sum(counts)
            if total == 0:
                continue

            percentages = [c / total for c in counts]
            layer_stats[key] = {
                "total": total,
                "counts": counts.copy(),
                "percentages": percentages
            }

            for expert_idx, count in enumerate(counts):
                total_stats[expert_idx] += count

        global_total = sum(total_stats.values())
        global_percent = {}
        if global_total > 0:
            for expert in total_stats:
                global_percent[expert] = total_stats[expert] / global_total

        return {
            "layer_stats": layer_stats,
            "global_stats": {
                "total": global_total,
                "counts": dict(total_stats),
                "percentages": global_percent
            }
        }

    def get_expert_usage(self):
        """Return ``{"expert_i": share of all recorded activations}``."""
        total = self.expert_count.sum()
        usage_dict = {}
        if total > 0:
            for i, cnt in enumerate(self.expert_count):
                usage_dict[f"expert_{i}"] = float(cnt / total)
        else:
            for i, _ in enumerate(self.expert_count):
                usage_dict[f"expert_{i}"] = 0.0
        return usage_dict

    def save_moe_data(self, output_dir):
        """Save expert usage & token-level expert assignments"""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "expert_usage.json"), "w") as f:
            json.dump(self.expert_usage_list, f, indent=4)

        torch.save(self.expert_assignments, os.path.join(output_dir, "expert_assignments.pt"))
        print(f"✅ Expert data saved to {output_dir}")


###########################################
# Custom Trainer: Supports DeepSeek aux_loss + L2 Regularization
###########################################
class BalancedLoRATrainer(Trainer):
    """``Trainer`` that adds the mixer's routing auxiliary loss and an optional
    L2 pull of the LoRA expert weights toward a reference copy.

    The training loss is::

        total = task_loss
                + balance_loss_weight * model.aux_loss
                + distill_l2_reg * sum ||p - p_old||^2   (constrained experts only)

    Args:
        balance_loss_weight: Multiplier for ``model.aux_loss``.
        distill_l2_reg: Multiplier for the L2 penalty; ``0`` disables it.
        old_expert_params: ``{parameter_name: tensor}`` reference weights,
            keyed like ``named_parameters()`` of the unwrapped model.
        unconstrained_experts: Expert indices exempt from the L2 penalty
            (default: none).
    """
    def __init__(
        self,
        *args,
        balance_loss_weight=0.1,
        distill_l2_reg=0.0,
        old_expert_params=None,
        unconstrained_experts=None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.balance_loss_weight = balance_loss_weight
        self.distill_l2_reg = distill_l2_reg
        self.old_expert_params = old_expert_params
        # None instead of a shared mutable [] default.
        self.unconstrained_experts = [] if unconstrained_experts is None else unconstrained_experts
        # name -> (source tensor, copy on the parameter's device)
        self._reference_cache = {}

    def _reference_on(self, name, device):
        """Return ``old_expert_params[name]`` on ``device``.

        The device copy is cached so it is not re-transferred every step.
        """
        source = self.old_expert_params[name]
        if source.device == device:
            return source
        cached = self._reference_cache.get(name)
        if cached is None or cached[0] is not source or cached[1].device != device:
            cached = (source, source.to(device))
            self._reference_cache[name] = cached
        return cached[1]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the task loss plus the auxiliary and L2 terms.

        Also logs the task loss, the aux loss and per-expert usage shares.
        """
        domain_ids = inputs.get("domain_ids", None)
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            labels=inputs['labels'],
            domain_ids=domain_ids,
        )
        # The Trainer may pass a wrapper (e.g. DataParallel/DDP). The mixer's
        # attributes and its un-prefixed parameter names live on ``.module``.
        core = getattr(model, "module", model)

        task_loss = outputs['loss']
        aux_loss = core.aux_loss

        total_loss = task_loss + self.balance_loss_weight * aux_loss

        # Skip L2 regularization for unconstrained experts
        if self.distill_l2_reg > 0.0 and self.old_expert_params is not None:
            skip = set(self.unconstrained_experts)
            l2_distill_loss = 0.0
            for name, param in core.named_parameters():
                if "lora_params" not in name or name not in self.old_expert_params:
                    continue
                # Names end in the ParameterList index, i.e. the expert id.
                expert_index = int(name.split('.')[-1])
                if expert_index in skip:
                    continue
                old_p = self._reference_on(name, param.device)
                l2_distill_loss += (param - old_p).pow(2).sum()
            total_loss += self.distill_l2_reg * l2_distill_loss

        usage_dict = core.get_expert_usage()
        log_dict = {
            "task_loss": task_loss.item(),
            "aux_loss": aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
        }
        log_dict.update(usage_dict)
        self.log(log_dict)

        return (total_loss, outputs) if return_outputs else total_loss


#############################
# Helper Functions
#############################
def print_trainable_parameters(model: nn.Module):
    """Print how many of ``model``'s parameter elements require gradients.

    Prints ``Trainable parameters: <trainable> / <total> (<pct>%)``. A module
    without parameters reports ``0.00%`` instead of raising
    ``ZeroDivisionError``.
    """
    trainable_params = 0
    total_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    percent = 100 * trainable_params / total_params if total_params else 0.0
    print(f"Trainable parameters: {trainable_params} / {total_params} ({percent:.2f}%)")


def save_mamba_moe(model, tokenizer, save_path):
    """Persist only the router and LoRA weights of ``model``, plus ``tokenizer``.

    Every ``state_dict`` entry whose key contains ``"routers"`` or
    ``"lora_params"`` is written with ``torch.save`` to
    ``<save_path>/mamba_moe_top3-gelu.pth``. The tokenizer is saved with
    ``save_pretrained`` into the same directory, which is created if needed.
    """
    os.makedirs(save_path, exist_ok=True)

    full_state = model.state_dict()
    keep_markers = ("routers", "lora_params")
    moe_state_dict = {}
    for key, value in full_state.items():
        if any(marker in key for marker in keep_markers):
            moe_state_dict[key] = value
    torch.save(moe_state_dict, os.path.join(save_path, "mamba_moe_top3-gelu.pth"))

    tokenizer.save_pretrained(save_path)
    print(f"[INFO] Only LoRA + Router parameters saved to {save_path}/mamba_moe_top3-gelu.pth")

    # Drop the full state-dict reference before releasing cached CUDA memory.
    del full_state
    torch.cuda.empty_cache()


def main():
    """Stage-2 joint training of the LoRA expert mixture.

    Steps:

    1. Load the base model and the mixed-domain datasets.
    2. Wrap the model in ``LoRA_Mixer`` with a trainable router.
    3. Make only router and LoRA parameters trainable.
    4. Snapshot the LoRA weights as the L2 reference.
    5. Train with ``BalancedLoRATrainer``.
    6. Save the router + LoRA weights.

    The expert and dataset paths in ``config`` must be filled in before
    running.
    """
    config = {
        "base_model": "/root/llama",
        # LoRA expert checkpoints (safetensors), one per expert.
        "expert_paths": [],
        # Dataset paths.
        "medical_path": "",
        "math_path": "",
        "coding_path": "",
        "coding_path_v2": "",
        "sst2_path": "",
        "cola_path": "",
        "arc_easy_path": "",
        "arc_challenge_path": "",
        "output_dir": "/root/llama/lora-moe/",
    }

    base_model, tokenizer = load_model_and_tokenizer(config["base_model"])
    train_dataset, validation_dataset = prepare_datasets(config, tokenizer)

    mixer = LoRA_Mixer(
        base_model=base_model,
        expert_paths=config["expert_paths"],
        alpha=0.01,
        freeze_router=False,  # the router is trained in this second stage
        num_layers=32,
    )

    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,
    )

    # Only router and LoRA parameters receive gradients.
    for name, param in mixer.named_parameters():
        param.requires_grad = ("routers" in name) or ("lora_params" in name)
    print_trainable_parameters(mixer)

    # Snapshot of the first-stage expert weights that the L2 penalty pulls toward.
    reference_params = {
        name: param.clone().detach()
        for name, param in mixer.named_parameters()
        if "lora_params" in name
    }

    training_args = TrainingArguments(
        output_dir=os.path.join(config['output_dir'], "stage2"),
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=1e-4,
        warmup_ratio=0.1,
        num_train_epochs=1,
        logging_steps=50,
        evaluation_strategy="no",
        save_strategy="no",
        gradient_accumulation_steps=4,
        fp16=True,
        fp16_full_eval=True,
        report_to="wandb",
        load_best_model_at_end=False,
        max_grad_norm=0.5,
        lr_scheduler_type="cosine",
        save_total_limit=1,
        optim="adamw_torch",
        remove_unused_columns=False,
    )

    trainer = BalancedLoRATrainer(
        model=mixer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=collator,
        balance_loss_weight=0.1,
        distill_l2_reg=1e-5,  # L2 pull of the experts toward the snapshot
        old_expert_params=reference_params,
        unconstrained_experts=[],  # expert indices exempt from the L2 pull
    )

    torch.cuda.empty_cache()
    gc.collect()

    print_trainable_parameters(mixer)
    trainer.train()

    stage2_path = os.path.join(config["output_dir"], "final_joint_v2")
    save_mamba_moe(mixer, tokenizer, stage2_path)


if __name__ == "__main__":
    main()
