import os
import math
import json
import pandas as pd
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
import gc

#########################################
# LoRA_Mixer Class
#########################################
class LoRA_Mixer(nn.Module):
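    """
    Mixture of LoRA experts applied on top of a frozen base model's attention projections.

    This class integrates DeepSeek-style gating with an auxiliary loss and supports two-stage training:
      - freeze_router=True forces every sample to the expert given by its domain_id (no top-3 routing)
      - freeze_router=False uses DeepSeek-style top-3 routing through the routers
    """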
    def __init__(
        self,
        base_model,
        expert_paths,
        num_layers=32,
        proj_names=["q_proj", "k_proj", "v_proj", "o_proj"],
        alpha=0.1,
        freeze_router=False
    ):
        super().__init__()

        self.num_experts = len(expert_paths)
        self.expert_activation = {}  # Track expert activations by layer
        for layer_idx in range(num_layers):
            for proj_name in proj_names:
                key = f"layer{layer_idx}_{proj_name}"
                self.expert_activation[key] = [0] * self.num_experts

        self.base_model = base_model
        self.num_layers = num_layers
        self.proj_names = proj_names
        self.register_buffer("expert_count", torch.zeros(self.num_experts))
        self.router_logits_cache = {}
        self.aux_loss = 0.0
        self.alpha = alpha
        self.freeze_router = freeze_router

        # Router network & LoRA
        self.routers = nn.ModuleDict()
        self.lora_params = nn.ParameterDict()

        # Store expert assignments
        self.expert_assignments = []
        self.expert_usage_list = []

        # Freeze base model parameters
        for param in base_model.parameters():
            param.requires_grad = False

        # Initialize routers and LoRA
        self._init_routers_and_lora(expert_paths)

        # Decide whether routers are trainable based on freeze_router
        if freeze_router:
            for name, param in self.routers.named_parameters():
                param.requires_grad = False

        # Monkey patch linear layers
        self._monkey_patch_linear_layers()

    def _init_routers_and_lora(self, expert_paths):
        print("\n[INFO] Initializing router network and LoRA parameters...")
        experts = [load_file(path) for path in expert_paths]
        for layer_idx in range(self.num_layers):
            for proj_name in self.proj_names:
                linear_layer = getattr(self.base_model.model.layers[layer_idx].self_attn, proj_name)
                if not isinstance(linear_layer, nn.Linear):
                    raise TypeError(f"Expected nn.Linear, but {proj_name} is {type(linear_layer)}")

                # RSL
                router = RSL(
                    in_features=linear_layer.in_features,
                    num_experts=self.num_experts,
                    alpha=self.alpha
                )
                self.routers[f"layer{layer_idx}_{proj_name}"] = router

                # Load LoRA parameters
                lora_As, lora_Bs = [], []
                for expert in experts:
                    key_A = f"base_model.model.model.layers.{layer_idx}.self_attn.{proj_name}.lora_A.weight"
                    key_B = f"base_model.model.model.layers.{layer_idx}.self_attn.{proj_name}.lora_B.weight"
                    lora_As.append(nn.Parameter(expert[key_A]))
                    lora_Bs.append(nn.Parameter(expert[key_B]))
                self.lora_params[f"layer{layer_idx}_{proj_name}_A"] = nn.ParameterList(lora_As)
                self.lora_params[f"layer{layer_idx}_{proj_name}_B"] = nn.ParameterList(lora_Bs)

    def _monkey_patch_linear_layers(self):
        print("\n[INFO] Replacing forward methods of linear layers (DeepSeek gating or domain-forced) ...")
        for layer_idx in range(self.num_layers):
            for proj_name in self.proj_names:
                key = f"layer{layer_idx}_{proj_name}"
                linear_layer = getattr(self.base_model.model.layers[layer_idx].self_attn, proj_name)
                if not isinstance(linear_layer, nn.Linear):
                    raise TypeError(f"Expected nn.Linear, but {proj_name} is {type(linear_layer)}")

                orig_forward = linear_layer.forward

                def make_forward(router_key, orig):
                    def new_forward(x):
                        B, seq_len, _ = x.shape
                        router = self.routers[router_key]
                        domain_ids = getattr(self, '_temp_domain_ids', None)

                        # Base model output
                        base_out = orig(x)

                        lora_As = self.lora_params[f"{router_key}_A"]
                        lora_Bs = self.lora_params[f"{router_key}_B"]

                        if self.freeze_router:
                            if domain_ids is None:
                                raise ValueError("freeze_router=True but domain_ids not provided")

                            expert_outputs = []
                            for i in range(B):
                                expert_id = domain_ids[i].item()
                                lora_A = lora_As[expert_id]
                                lora_B = lora_Bs[expert_id]

                                delta = F.linear(x[i], lora_A)
                                delta = F.linear(delta, lora_B)
                                expert_outputs.append(delta.unsqueeze(0))

                            lora_out = torch.cat(expert_outputs, dim=0)

                            domain_ids_2d = domain_ids.unsqueeze(1).expand(B, seq_len)
                            domain_ids_flat = domain_ids_2d.reshape(-1).cpu().numpy()
                            for idx in domain_ids_flat:
                                self.expert_count[idx] += 1
                                self.expert_activation[router_key][idx] += 1
                        else:
                            router_weights, aux_loss, top3_indices, _ = router(x)

                            if self.training and aux_loss is not None:
                                self.aux_loss += aux_loss

                            expert_outputs = []
                            for expert_idx in range(self.num_experts):
                                lora_A = lora_As[expert_idx]
                                lora_B = lora_Bs[expert_idx]
                                delta = F.linear(x, lora_A)
                                delta = F.linear(delta, lora_B)
                                expert_outputs.append(delta)

                            expert_outputs = torch.stack(expert_outputs, dim=-1)
                            lora_out = torch.einsum('bsen,bsn->bse', expert_outputs, router_weights)

                            top_indices_flat = top3_indices.view(-1).cpu().numpy()
                            for idx in top_indices_flat:
                                self.expert_count[idx] += 1
                                self.expert_activation[router_key][idx] += 1

                        assert base_out.shape == lora_out.shape, f"[{router_key}] shape mismatch: base={base_out.shape}, lora={lora_out.shape}"
                        final_out = base_out + lora_out
                        return final_out

                    return new_forward

                linear_layer.forward = make_forward(router_key=key, orig=orig_forward)

    def forward(self, input_ids, attention_mask=None, labels=None, domain_ids=None):
        """
        Add domain_ids to pass when freeze_router=True to force routing
        """
        self.aux_loss = 0.0
        self._temp_domain_ids = domain_ids if domain_ids is not None else None

        self.temp_domain_ids = domain_ids  # Temporarily save

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        if self.training and self.alpha > 0:
            total_router_layers = self.num_layers * len(self.proj_names)
            self.aux_loss = self.aux_loss / total_router_layers

        expert_usage = self.get_expert_usage()
        self.expert_usage_list.append(expert_usage)

        return {
            'loss': outputs.loss,
            'logits': outputs.logits,
        }

    def generate(self, *args, **kwargs):
        """Let LoRA_Mixer directly call the base_model's generate method"""
        return self.base_model.generate(*args, **kwargs)

    def reset_expert_activation(self):
        """Reset expert activation counts for all layers"""
        for key in self.expert_activation:
            self.expert_activation[key] = [0] * self.num_experts

    def get_expert_activation_stats(self):
        """Get detailed expert activation stats"""
        total_stats = defaultdict(int)
        layer_stats = {}

        for key, counts in self.expert_activation.items():
            total = sum(counts)
            if total == 0:
                continue

            percentages = [c / total for c in counts]
            layer_stats[key] = {
                "total": total,
                "counts": counts.copy(),
                "percentages": percentages
            }

            for expert_idx, count in enumerate(counts):
                total_stats[expert_idx] += count

        global_total = sum(total_stats.values())
        global_percent = {}
        if global_total > 0:
            for expert in total_stats:
                global_percent[expert] = total_stats[expert] / global_total

        return {
            "layer_stats": layer_stats,
            "global_stats": {
                "total": global_total,
                "counts": dict(total_stats),
                "percentages": global_percent
            }
        }

    def get_expert_usage(self):
        total = self.expert_count.sum()
        usage_dict = {}
        if total > 0:
            for i, cnt in enumerate(self.expert_count):
                usage_dict[f"expert_{i}"] = float(cnt / total)
        else:
            for i, _ in enumerate(self.expert_count):
                usage_dict[f"expert_{i}"] = 0.0
        return usage_dict

    def save_moe_data(self, output_dir):
        """Save expert usage & token-level expert assignments"""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "expert_usage.json"), "w") as f:
            json.dump(self.expert_usage_list, f, indent=4)

        torch.save(self.expert_assignments, os.path.join(output_dir, "expert_assignments.pt"))
        print(f"✅ Expert data saved to {output_dir}")