import os
import math
import json
import pandas as pd
from datasets import Dataset, concatenate_datasets
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from transformers import DataCollatorForSeq2Seq
from transformers import T5Tokenizer, T5ForConditionalGeneration
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
import gc


class RSL(nn.Module):
    """Token-level router that scores experts and selects the top-k per token.

    A two-layer MLP maps each hidden state to one logit per expert and a
    softmax turns the logits into routing scores. In training mode (with
    ``alpha > 0``) an auxiliary loss is also returned:
    ``alpha * sum_i(P_i * f_i) - gamma * mean_token_entropy``, where ``P_i`` is
    the mean score of expert ``i`` over all tokens and ``f_i`` the fraction of
    top-k selections that went to expert ``i``.

    Args:
        in_features: size of the last dimension of the hidden states.
        num_experts: number of experts to score.
        alpha: weight of the load-balancing term; ``0`` disables the aux loss.
        top_k: experts selected per token (clamped to ``num_experts``). Defaults
            to 1, which is what this router has always used even though some
            names elsewhere in the file say "top3".
        gamma: weight of the entropy term that is subtracted from the aux loss.
    """

    def __init__(self, in_features, num_experts, alpha=0.1, top_k=1, gamma=0.1):
        super().__init__()
        self.alpha = alpha
        self.num_experts = num_experts
        self.top_k = top_k
        self.gamma = gamma
        self.mlp = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Linear(256, num_experts),
        )
        # Custom init: gain sqrt(2) (the ReLU gain) is used as an approximation
        # for the GELU hidden layer; the output layer uses gain 1.0.
        nn.init.xavier_normal_(self.mlp[0].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.xavier_normal_(self.mlp[2].weight, gain=1.0)

    def forward(self, hidden_states):
        """Route ``hidden_states`` of shape ``(B, seq_len, in_features)``.

        Returns:
            router_weights: float32 softmax scores, ``(B, seq_len, num_experts)``.
            aux_loss: scalar tensor in training mode with ``alpha > 0``, else ``None``.
            topk_indices: ``(B, seq_len, k)`` indices of the selected experts.
            topk_weights: the selected scores renormalised to sum to 1 per token.
        """
        B, seq_len, _ = hidden_states.shape
        logits = self.mlp(hidden_states).float()
        scores = F.softmax(logits, dim=-1)

        k = min(self.top_k, self.num_experts)
        topk_weights, topk_indices = torch.topk(scores, k=k, dim=-1)
        topk_weights = F.normalize(topk_weights, p=1, dim=-1)

        aux_loss = None
        if self.training and self.alpha > 0.0:
            # P_i must be averaged over ALL tokens (batch and sequence). A mean
            # over dim 0 alone leaves a (seq_len, E) matrix, and summing its
            # product with f_i inflates the balance term by a factor of seq_len.
            Pi = scores.reshape(-1, self.num_experts).mean(dim=0)
            # f_i: fraction of top-k selections assigned to each expert.
            fi = F.one_hot(
                topk_indices.reshape(-1), num_classes=self.num_experts
            ).float().mean(dim=0)

            balance = (Pi * fi).sum()
            entropy = -(scores * scores.clamp(min=1e-9).log()).sum(dim=-1).mean()

            aux_loss = self.alpha * balance - self.gamma * entropy

        router_weights = scores.view(B, seq_len, self.num_experts)
        return router_weights, aux_loss, topk_indices, topk_weights


#########################################
# Data Loading Functions
#########################################   

def load_model_and_tokenizer(model_name: str):
    """Load a Flan-T5 model and its tokenizer from ``model_name``.

    ``model_name`` is passed straight to ``from_pretrained`` (a local directory
    or hub identifier). The model is loaded with ``device_map="auto"``.

    Returns:
        ``(model, tokenizer)``.
    """
    # Previously the path was hard-coded and ``model_name`` was ignored.
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
    return model, tokenizer


def load_glue_dataset(path, task_name):
    """Load one GLUE task from a parquet file into a HF ``Dataset``.

    Rows are shuffled (and optionally down-sampled via ``task_sample_ratios``)
    with a fixed seed, reduced to a ``text`` column plus ``label``, and tagged
    with ``domain`` (the task name) and its integer ``domain_id``.

    Args:
        path: parquet file path.
        task_name: one of ``sst2``, ``cola``, ``qqp``, ``mrpc``, ``rte``.

    Raises:
        ValueError: if ``task_name`` is not supported.
        KeyError: if the parquet file lacks the columns the task needs.
    """
    domain_ids = {"sst2": 0, "cola": 1, "qqp": 2, "mrpc": 3, "rte": 4}
    # Validate the task before doing any file I/O.
    if task_name not in domain_ids:
        raise ValueError(f"Unsupported task: {task_name}")

    df = pd.read_parquet(path)

    # Per-task sampling ratios; tasks not listed keep every row.
    task_sample_ratios = {}
    sample_size = int(len(df) * task_sample_ratios.get(task_name, 1.0))
    # Also shuffles the rows deterministically when the ratio is 1.0.
    df = df.sample(n=sample_size, random_state=42)

    if task_name in ("sst2", "cola"):
        df = df[["sentence", "label"]].rename(columns={"sentence": "text"})

    elif task_name == "qqp":
        if "question1" not in df.columns:
            raise KeyError("qqp missing question1 / question2 fields")
        # .copy(): adding a column to a column-selection slice would otherwise
        # trigger pandas' SettingWithCopyWarning.
        df = df[["question1", "question2", "label"]].copy()
        # NOTE(review): the second prefix reads "Question:" rather than
        # "Question 2:"; left unchanged to avoid altering the training prompts.
        df["text"] = df.apply(lambda x: f"Question1: {x['question1']} , Question: {x['question2']}", axis=1)

    elif task_name == "mrpc":
        if "sentence1" not in df.columns:
            raise KeyError("mrpc missing sentence1 / sentence2 fields")
        df = df[["sentence1", "sentence2", "label"]].copy()
        df["text"] = df.apply(lambda x: f"Sentence 1: {x['sentence1']} , Sentence 2: {x['sentence2']}", axis=1)

    else:  # rte
        df = df[["sentence1", "sentence2", "label"]].copy()
        df["text"] = df.apply(lambda x: f"Premise: {x['sentence1']} , Hypothesis: {x['sentence2']}", axis=1)

    # Keep only the text and label columns.
    df = df[["text", "label"]]

    # preserve_index=False: the shuffled (non-default) index would otherwise be
    # stored as an extra "__index_level_0__" column in the dataset.
    dataset = Dataset.from_pandas(df, preserve_index=False)
    domain_id = domain_ids[task_name]

    return dataset.map(lambda x: {"domain": task_name, "domain_id": domain_id})


#########################################
# Prompt Formatting Functions
#########################################
def create_prompt_formats(sample: dict):
    """Turn a tagged GLUE example into an instruction-style prompt string.

    The prompt is three newline-separated parts: a task instruction, the input
    text (prefixed with ``"Sentence: "`` for single-sentence tasks), and
    ``"Label: <verbalised label>"``.

    Args:
        sample: mapping with ``domain``, ``label``, ``text`` and ``domain_id``.

    Returns:
        ``{"text": prompt, "domain": ..., "domain_id": ...}``.

    Raises:
        KeyError: if ``domain`` has no label verbaliser.
        ValueError: if ``domain`` has no prompt template.
    """
    task = sample['domain']
    verbalizers = {
        "sst2": {0: "negative", 1: "positive"},
        "cola": {0: "unacceptable", 1: "acceptable"},
        "qqp": {0: "not duplicate", 1: "duplicate"},
        "mrpc": {0: "not paraphrase", 1: "paraphrase"},
        "rte": {0: "not entailment", 1: "entailment"},
    }
    # Looked up first, so an unknown domain surfaces as a KeyError here.
    answer = verbalizers[task][int(sample["label"])]

    # task -> (instruction, prefix placed before the input text)
    templates = {
        'sst2': ("Classify the sentiment of the following sentence as negative or positive.", "Sentence: "),
        'cola': ("Classify the following sentence as unacceptable or acceptable in terms of grammaticality.", "Sentence: "),
        'qqp': ("Determine if the following two questions are duplicate or not duplicate.", ""),
        'mrpc': ("Determine if the following two sentences are paraphrases or not paraphrases.", ""),
        'rte': ("Determine if the hypothesis entails the premise or not.", ""),
    }
    if task not in templates:
        raise ValueError(f"Unknown domain: {task}")

    instruction, prefix = templates[task]
    pieces = [instruction, f"{prefix}{sample['text']}", f"Label: {answer}"]

    return {
        "text": "\n".join(pieces),
        "domain": task,
        "domain_id": sample['domain_id'],
    }


def create_model_inputs(example, tokenizer, max_length=50, label_max_length=10, device="cuda"):
    """Tokenize a formatted prompt into seq2seq encoder inputs and target labels.

    The prompt is split on its LAST ``"Label:"`` marker: the stripped text
    before it becomes the encoder input, the stripped text after it the target.

    Args:
        example: mapping with a ``"text"`` prompt containing ``"Label:"``.
        tokenizer: callable HF-style tokenizer with a ``pad_token_id``.
        max_length: truncation length for the encoder input.
        label_max_length: truncation length for the target.
        device: device to move the tensors to; ``None`` leaves them where the
            tokenizer created them.

    Returns:
        dict with 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors;
        pad tokens in ``labels`` are replaced by -100.

    Raises:
        ValueError: if the prompt contains no ``"Label:"`` marker.
    """
    prompt_text = example["text"]
    if "Label:" not in prompt_text:
        raise ValueError("Missing 'Label' in prompt.")
    # rsplit with maxsplit=1: the input sentence itself may contain "Label:",
    # in which case a plain split() yields >2 parts and fails to unpack.
    prompt, answer = prompt_text.rsplit("Label:", 1)
    prompt = prompt.strip()
    answer = answer.strip()

    input_encoding = tokenizer(
        prompt,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    label_encoding = tokenizer(
        answer,
        max_length=label_max_length,
        truncation=True,
        return_tensors="pt",
    )

    labels = label_encoding["input_ids"]
    # -100 is the conventional ignore index for HF seq2seq losses.
    labels[labels == tokenizer.pad_token_id] = -100

    result = {
        "input_ids": input_encoding["input_ids"].squeeze(0),
        "attention_mask": input_encoding["attention_mask"].squeeze(0),
        "labels": labels.squeeze(0),
    }
    if device is not None:
        result = {key: value.to(device) for key, value in result.items()}
    return result


def preprocess_function(examples, tokenizer, max_length=128, device="cuda"):
    """Tokenize a batch of formatted prompts, padded to ``max_length``.

    Args:
        examples: a dict, or a batch object exposing ``keys()`` and item
            access, with ``"text"`` and ``"domain_id"`` columns.
        tokenizer: callable HF-style tokenizer.
        max_length: pad/truncate length.
        device: device to move tensor outputs to; ``None`` leaves them where
            they were created.

    Returns:
        The tokenizer output with an added ``domain_ids`` long tensor.
    """
    if isinstance(examples, dict):
        examples_dict = examples
    else:
        examples_dict = {k: examples[k] for k in examples.keys()}
    texts = [f"{text}" for text in examples_dict["text"]]

    tokenized = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    tokenized['domain_ids'] = torch.tensor(examples_dict['domain_id'], dtype=torch.long)

    if device is not None:
        for key in tokenized:
            if isinstance(tokenized[key], torch.Tensor):
                tokenized[key] = tokenized[key].to(device)

    return tokenized


def prepare_datasets(config, tokenizer, tasks=None, seed=42):
    """Build shuffled, tokenized train/validation splits from several GLUE tasks.

    Args:
        config: mapping with a ``"<task>_path"`` parquet path for each task.
        tokenizer: tokenizer passed to ``create_model_inputs``.
        tasks: task names to mix. Defaults to sst2, cola, qqp and mrpc
            (rte is not included by default).
        seed: seed for both the shuffle and the 90/10 train/test split. The
            split previously had no seed and so differed between runs.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    if tasks is None:
        tasks = ["sst2", "cola", "qqp", "mrpc"]

    task_datasets = [load_glue_dataset(config[f"{task}_path"], task) for task in tasks]

    mixed_ds = concatenate_datasets(task_datasets).shuffle(seed=seed)

    # Build the instruction + input + label prompt text.
    mixed_ds = mixed_ds.map(create_prompt_formats)

    # Tokenize; raw text/label/domain columns are no longer needed afterwards.
    mixed_ds = mixed_ds.map(
        lambda x: create_model_inputs(x, tokenizer),
        remove_columns=['label', 'text', 'domain'],
    )

    splits = mixed_ds.train_test_split(test_size=0.1, seed=seed)
    return splits['train'], splits['test']


class MambaLoRAMoE(nn.Module):
    """Frozen T5-style model whose attention projections mix in routed LoRA experts.

    Patched projections are the encoder and decoder self-attention projections
    named in ``proj_names``, plus the decoder cross-attention ``v``. Each
    patched projection ``W`` with input ``x`` returns::

        s0 * W(x) + s1 * sum_e r_e(x) * B_e(A_e(x))

    Here ``(s0, s1) = softmax([w_base, w_lora])``. ``r_e`` is the layer
    router's softmax score for expert ``e`` if that expert is among the
    router's selected top-k, else 0. The scores are not renormalised.
    NOTE(review): no PEFT-style ``lora_alpha / r`` scaling is applied to the
    deltas — confirm this is intended.

    Args:
        base_model: T5-style model exposing ``encoder.block[i].layer[...]``.
        expert_paths: one LoRA checkpoint per expert, loaded with ``torch.load``.
        encoder_layers: number of encoder blocks to patch.
        decoder_layers: number of decoder blocks to patch.
        proj_names: self-attention projections to patch.
        alpha: load-balancing weight passed to every router.
        freeze_router: stored on the instance. NOTE(review): forced routing by
            domain id is not implemented; routing always uses the routers.
    """

    # PEFT-style key prefixes of the LoRA weights inside each expert checkpoint.
    _LORA_KEY_TEMPLATES = {
        "encoder": "base_model.model.encoder.block.{layer_idx}.layer.0.SelfAttention.{proj_name}",
        "decoder_self": "base_model.model.decoder.block.{layer_idx}.layer.0.SelfAttention.{proj_name}",
        "decoder_cross": "base_model.model.decoder.block.{layer_idx}.layer.1.EncDecAttention.{proj_name}",
    }

    def __init__(
        self,
        base_model,
        expert_paths,
        encoder_layers=24,
        decoder_layers=24,
        proj_names=("q", "v"),
        alpha=0.1,
        freeze_router=False
    ):
        super().__init__()
        self.base_model = base_model
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        # Tuple default avoids a shared mutable default argument.
        self.proj_names = proj_names
        self.num_experts = len(expert_paths)
        self.alpha = alpha
        self.freeze_router = freeze_router
        # Router aux loss accumulated by the patched layers during a forward pass.
        self.aux_loss = 0
        self.routers = nn.ModuleDict()
        self.lora_params = nn.ParameterDict()
        self._init_routers_and_lora(expert_paths)
        # Logits of the base-vs-LoRA mixing weights (softmaxed in each layer).
        self.w_base = nn.Parameter(torch.tensor(1.0))
        self.w_lora = nn.Parameter(torch.tensor(1.0))

        # Freeze the original model's parameters.
        for param in base_model.parameters():
            param.requires_grad = False

        self._monkey_patch_linear_layers()

    def _get_attention(self, layer_type, layer_idx):
        """Return the attention module addressed by ``layer_type`` or ``None`` if unknown."""
        if layer_type == "encoder":
            return self.base_model.encoder.block[layer_idx].layer[0].SelfAttention
        if layer_type == "decoder_self":
            return self.base_model.decoder.block[layer_idx].layer[0].SelfAttention
        if layer_type == "decoder_cross":
            return self.base_model.decoder.block[layer_idx].layer[1].EncDecAttention
        return None

    def _init_routers_and_lora(self, expert_paths):
        """Create a router plus per-expert LoRA A/B parameters for every patched projection."""
        print("\n[INFO] Initializing router network and LoRA parameters...")
        # Loaded on CPU, where the routers are also created; move the whole
        # module to the target device afterwards.
        # NOTE: torch.load unpickles — only use trusted checkpoint files.
        experts = [torch.load(path, map_location="cpu") for path in expert_paths]

        for layer_idx in range(self.encoder_layers):
            for proj_name in self.proj_names:
                self._init_layer_params("encoder", layer_idx, proj_name, experts)

        for layer_idx in range(self.decoder_layers):
            for proj_name in self.proj_names:
                self._init_layer_params("decoder_self", layer_idx, proj_name, experts)
            # Cross-attention: only the value projection is patched.
            self._init_layer_params("decoder_cross", layer_idx, "v", experts)

    def _init_layer_params(self, layer_type, layer_idx, proj_name, experts):
        """Register the router and LoRA parameters for one projection.

        Skips (with a warning) projections the attention module lacks; raises
        ``KeyError`` if an expert checkpoint is missing the LoRA weights.
        """
        attention = self._get_attention(layer_type, layer_idx)
        if attention is None:
            raise ValueError(f"Unknown layer_type: {layer_type}")
        attn_key_prefix = self._LORA_KEY_TEMPLATES[layer_type].format(
            layer_idx=layer_idx, proj_name=proj_name
        )

        if not hasattr(attention, proj_name):
            print(f"[WARNING] Missing {proj_name} in {layer_type}.layer.{layer_idx}")
            return

        linear_layer = getattr(attention, proj_name)

        router_key = f"{layer_type}__{layer_idx}__{proj_name}"
        self.routers[router_key] = RSL(
            in_features=linear_layer.in_features,
            num_experts=self.num_experts,
            alpha=self.alpha
        )

        key_A = f"{attn_key_prefix}.lora_A.weight"
        key_B = f"{attn_key_prefix}.lora_B.weight"
        lora_As, lora_Bs = [], []
        for expert in experts:
            try:
                lora_As.append(nn.Parameter(expert[key_A]))
                lora_Bs.append(nn.Parameter(expert[key_B]))
            except KeyError as e:
                print(f"❗[ERROR] LoRA key missing: {e.args[0]}")
                raise

        self.lora_params[f"{router_key}_A"] = nn.ParameterList(lora_As)
        self.lora_params[f"{router_key}_B"] = nn.ParameterList(lora_Bs)

    def _monkey_patch_linear_layers(self):
        """Replace each routed projection's ``forward`` with the MoE-mixed version."""
        print("\n[INFO] Replacing forward methods of linear layers...")

        def create_new_forward(router_key, orig_layer, lora_As, lora_Bs, topk=3):
            # ``topk`` is unused: how many experts contribute per token is
            # decided by the router's own top-k selection.
            def new_forward(x):
                router = self.routers[router_key]

                # Call the class's forward explicitly: the instance attribute
                # ``forward`` is this function, so orig_layer(x) would recurse.
                base_out = orig_layer.__class__.forward(orig_layer, x)

                router_weights, aux_loss, topk_indices, _ = router(x)
                if self.training and aux_loss is not None:
                    self.aux_loss += aux_loss

                # Every expert's LoRA delta, stacked on a trailing expert axis.
                expert_outputs = torch.stack(
                    [
                        F.linear(F.linear(x, lora_As[e]), lora_Bs[e])
                        for e in range(self.num_experts)
                    ],
                    dim=-1,
                )

                # Keep only the selected experts' scores, zero elsewhere.
                # Vectorised replacement for a per-token Python loop; same
                # values and gradients.
                sparse_router_weights = torch.zeros_like(router_weights).scatter(
                    -1, topk_indices, router_weights.gather(-1, topk_indices)
                )

                lora_out = torch.einsum('bsen,bsn->bse', expert_outputs, sparse_router_weights)

                weights = torch.softmax(torch.stack([self.w_base, self.w_lora]), dim=0)
                return weights[0] * base_out + weights[1] * lora_out

            return new_forward

        for router_key in self.routers:
            try:
                layer_type, layer_idx, proj_name = router_key.split("__")
                layer_idx = int(layer_idx)
            except ValueError as e:
                print(f"[⚠️WARNING] Cannot parse router_key: {router_key}, error: {e}")
                continue

            attention = self._get_attention(layer_type, layer_idx)
            if attention is None:
                continue

            orig_layer = getattr(attention, proj_name)
            orig_layer.forward = create_new_forward(
                router_key,
                orig_layer,
                self.lora_params[f"{router_key}_A"],
                self.lora_params[f"{router_key}_B"],
                3,
            )

    def forward(self, input_ids, attention_mask=None, labels=None, domain_ids=None):
        """Run the patched base model and average the accumulated router aux loss.

        ``domain_ids`` is only stored on the instance; the patched layers do
        not read it.

        Returns:
            ``{"loss": ..., "logits": ...}`` from the base model.
        """
        self.aux_loss = 0.0
        self._temp_domain_ids = domain_ids
        self.temp_domain_ids = domain_ids

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        # Average over the routers that exist (one per patched projection).
        # The old formula (enc + 2*dec) * len(proj_names) overcounted, because
        # cross-attention only patches ``v``.
        if self.training and self.alpha > 0 and len(self.routers) > 0:
            self.aux_loss = self.aux_loss / len(self.routers)

        return {
            'loss': outputs.loss,
            'logits': outputs.logits,
        }

    def generate(self, *args, **kwargs):
        """Delegate generation to the (patched) base model."""
        return self.base_model.generate(*args, **kwargs)

    def reset_expert_activation(self):
        """Zero per-layer expert activation counters if ``expert_activation`` exists.

        Nothing in this class creates ``self.expert_activation``, so this is a
        no-op unless a caller attaches such a dict.
        """
        activation = getattr(self, "expert_activation", None)
        if activation is None:
            return
        for key in activation:
            activation[key] = [0] * self.num_experts


#############################
# Custom Trainer: adds the router aux_loss (DeepSeek-style) and an optional L2 regularizer toward the initial expert weights (selected experts can be exempted)
#############################
class BalancedLoRATrainer(Trainer):
    """``Trainer`` adding the MoE router aux loss and an optional L2 anchor on LoRA experts.

    The optimised loss is::

        task_loss
        + balance_loss_weight * model.aux_loss
        + distill_l2_reg * sum((p - p_old) ** 2)   # constrained experts only

    Keyword Args:
        balance_loss_weight: multiplier for ``model.aux_loss``.
        distill_l2_reg: weight of the L2 anchor; ``0`` disables it.
        old_expert_params: mapping parameter name (as in
            ``model.named_parameters()``) -> snapshot tensor.
        unconstrained_experts: expert indices exempt from the L2 anchor. The
            index is the last dot-separated component of the parameter name.
            ``None`` means no exemptions.
    """

    def __init__(
        self,
        *args,
        balance_loss_weight=0.1,
        distill_l2_reg=0.0,
        old_expert_params=None,
        unconstrained_experts=None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.balance_loss_weight = balance_loss_weight
        self.distill_l2_reg = distill_l2_reg
        self.old_expert_params = old_expert_params
        # None default instead of a shared mutable list.
        self.unconstrained_experts = [] if unconstrained_experts is None else unconstrained_experts

    def _l2_distill_loss(self, model):
        """Sum of squared deviations of snapshotted LoRA params, skipping exempt experts."""
        loss = 0.0
        for name, param in model.named_parameters():
            if "lora_params" not in name or name not in self.old_expert_params:
                continue
            expert_index = int(name.split('.')[-1])
            if expert_index in self.unconstrained_experts:
                continue
            old_p = self.old_expert_params[name].to(param.device)
            loss = loss + (param - old_p).pow(2).sum()
        return loss

    @staticmethod
    def _per_domain_losses(logits, labels, domain_ids):
        """Return ``{"loss_domain_<id>": mean token CE}`` per domain in the batch (logging only)."""
        result = {}
        if not isinstance(domain_ids, torch.Tensor):
            return result
        # Logging only: no autograd graph needed.
        with torch.no_grad():
            for uid in domain_ids.unique():
                mask = domain_ids == uid
                dom_logits = logits[mask].reshape(-1, logits.size(-1))
                dom_labels = labels[mask].reshape(-1)
                valid_mask = dom_labels != -100
                if valid_mask.any():
                    domain_loss = F.cross_entropy(
                        dom_logits[valid_mask],
                        dom_labels[valid_mask],
                        reduction="mean"
                    )
                    result[f"loss_domain_{uid.item()}"] = domain_loss.item()
        return result

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss and log its components.

        ``num_items_in_batch`` is accepted for API compatibility and ignored.
        """
        # The preprocessing in this file stores the domain under "domain_id";
        # accept either spelling so per-domain logging actually runs.
        domain_ids = inputs.get("domain_ids", None)
        if domain_ids is None:
            domain_ids = inputs.get("domain_id", None)

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            labels=inputs['labels'],
            domain_ids=domain_ids,
        )
        task_loss = outputs['loss']
        aux_loss = model.aux_loss

        total_loss = task_loss + self.balance_loss_weight * aux_loss

        if self.distill_l2_reg > 0.0 and self.old_expert_params is not None:
            total_loss = total_loss + self.distill_l2_reg * self._l2_distill_loss(model)

        log_dict = {
            "task_loss": task_loss.item(),
            "aux_loss": aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss,
            "w_base": model.w_base.detach().cpu().item(),
            "w_lora": model.w_lora.detach().cpu().item(),
        }
        log_dict.update(self._per_domain_losses(outputs['logits'], inputs['labels'], domain_ids))
        self.log(log_dict)

        return (total_loss, outputs) if return_outputs else total_loss


#############################
# Helper Functions
#############################
def print_trainable_parameters(model: nn.Module):
    """Print the number of trainable vs. total parameters of ``model``.

    Returns:
        ``(trainable_params, total_params)`` so callers can reuse the counts.
    """
    trainable_params = 0
    total_params = 0
    for p in model.parameters():
        count = p.numel()
        total_params += count
        if p.requires_grad:
            trainable_params += count
    # Guard against ZeroDivisionError for modules without parameters.
    percent = 100 * trainable_params / total_params if total_params else 0.0
    print(f"Trainable parameters: {trainable_params} / {total_params} ({percent:.2f}%)")
    return trainable_params, total_params


def save_mamba_moe(model, tokenizer, save_path):
    """Save the trained MoE parameters and the tokenizer to ``save_path``.

    Writes ``mamba_moe_top3-gelu.pth`` containing every state-dict entry whose
    key contains ``"routers"`` or ``"lora_params"``. It also includes the
    mixing scalars ``w_base`` / ``w_lora`` when present; they are trained
    together with the routers, so dropping them would lose part of the
    trained model. Frozen base-model weights are not saved.
    """
    os.makedirs(save_path, exist_ok=True)
    file_name = "mamba_moe_top3-gelu.pth"
    mixing_keys = ("w_base", "w_lora")

    state_dict = model.state_dict()

    moe_state_dict = {
        k: v for k, v in state_dict.items()
        if "routers" in k or "lora_params" in k or k in mixing_keys
    }
    torch.save(moe_state_dict, os.path.join(save_path, file_name))

    tokenizer.save_pretrained(save_path)
    print(f"[INFO] LoRA-Mixer saved to {save_path}/{file_name}")

    del state_dict
    torch.cuda.empty_cache()


def main():
    """Train routers, LoRA experts and mixing weights on mixed GLUE data, then save.

    Loads Flan-T5 and its tokenizer and builds the mixed train/validation
    datasets. It then wraps the model in ``MambaLoRAMoE`` (one expert per
    adapter path) and marks only router / LoRA / mixing parameters as
    trainable. Finally it trains with ``BalancedLoRATrainer`` and saves the
    MoE parameters.
    """
    config = {
        "base_model": "/root/flant5-large",
        "expert_paths": [
            "/root/flant5-large/sst2lora/adapter_model.bin",
            "/root/flant5-large/colalora/adapter_model.bin",
            "/root/flant5-large/qqplora/adapter_model.bin",
            "/root/flant5-large/mrpclora/adapter_model.bin",
            "/root/flant5-large/rtelora/adapter_model.bin",
        ],
        "sst2_path": "/root/falcon_mamba/glue/sst2/train-00000-of-00001.parquet",
        "cola_path": "/root/falcon_mamba/glue/cola/train-00000-of-00001.parquet",
        "qqp_path": "/root/falcon_mamba/glue/qqp/train-00000-of-00001.parquet",
        "mrpc_path": "/root/falcon_mamba/glue/mrpc/train-00000-of-00001.parquet",
        "rte_path": "/root/falcon_mamba/glue/rte/train-00000-of-00001.parquet",
        "output_dir": "/root/flant5-large/lora-moe/"
    }

    base_model2, tokenizer2 = load_model_and_tokenizer(config["base_model"])
    train_dataset, validation_dataset = prepare_datasets(config, tokenizer2)

    model2 = MambaLoRAMoE(
        base_model=base_model2,
        expert_paths=config["expert_paths"],
        encoder_layers=24,
        decoder_layers=24,
        freeze_router=False
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer2,
        padding=True,
        pad_to_multiple_of=8,
        return_tensors="pt"
    )

    # Train only the MoE additions: routers, LoRA experts and the mixing scalars.
    for name, param in model2.named_parameters():
        param.requires_grad = (
            "routers" in name
            or "lora_params" in name
            or name in ("w_base", "w_lora")
        )
    print_trainable_parameters(model2)

    distill_l2_reg = 0
    # Snapshot of the initial LoRA weights for the L2 anchor. Cloning every
    # expert costs memory, so only do it when the regulariser is enabled.
    old_expert_params = None
    if distill_l2_reg > 0:
        old_expert_params = {
            name: param.clone().detach()
            for name, param in model2.named_parameters()
            if "lora_params" in name
        }

    # Evaluation stays disabled (the default); the deprecated
    # `evaluation_strategy` argument is omitted because recent transformers
    # releases reject it.
    training_args2 = TrainingArguments(
        output_dir=os.path.join(config['output_dir'], "stage2"),
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=1e-4,
        warmup_ratio=0.1,
        num_train_epochs=1,
        logging_steps=50,
        save_strategy="no",
        gradient_accumulation_steps=4,
        report_to="wandb",
        load_best_model_at_end=False,
        max_grad_norm=0.5,
        lr_scheduler_type="cosine",
        save_total_limit=1,
        optim="adamw_torch",
        # Keep extra dataset columns (e.g. domain ids) in the batches.
        remove_unused_columns=False,
    )
    trainer2 = BalancedLoRATrainer(
        model=model2,
        args=training_args2,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=data_collator,
        balance_loss_weight=0.2,
        distill_l2_reg=distill_l2_reg,
        old_expert_params=old_expert_params,
        unconstrained_experts=[]
    )

    torch.cuda.empty_cache()
    gc.collect()

    print_trainable_parameters(model2)
    trainer2.train()

    STAGE2_PATH = os.path.join(config["output_dir"], "final_joint")
    save_mamba_moe(model2, tokenizer2, STAGE2_PATH)


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
