"""
Anonymized Mixture of Experts (MoE) Implementation

This module implements a framework for converting traditional Dense layers into MoE structures.
Main features:
1. Distributes the hidden dimension of FFN layers (d_ff) among multiple experts
2. Implements a dynamic routing mechanism to select the most suitable expert for each input token
3. Supports load balancing to avoid uneven expert usage
4. Provides detailed expert usage statistics

Main components:
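- AnonymizedMoELayer: Single-layer MoE implementation with a router and multiple experts
- AnonymizedMoEModel: Multi-layer MoE model that can stack multiple MoE layers
- AdvancedAnonymizedMoEModel: Wraps a pretrained base model and applies a MoE layer to its final hidden states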

Usage:
1. Initialize model:
   model = AnonymizedMoEModel(
       num_experts=8,      # Number of experts
       d_model=512,       # Input dimension
       d_ff=2048,        # Total FFN hidden dimension
       num_layers=1,     # Number of MoE layers
       use_load_balancing=True  # Whether to use load balancing
   )

2. Forward pass:
   output = model(input_tensor)

3. Get statistics:
   stats = model.get_model_stats()

"""

import torch
import torch.nn as nn
from transformers import (
    LlamaForCausalLM, 
    LlamaTokenizer, 
    LlamaConfig,
    PreTrainedModel,
    AutoModel,
)
from typing import List, Optional, Any, Dict, Tuple
from pathlib import Path
import json
import os
from torch.utils.checkpoint import checkpoint

class AnonymizedMoELayer(nn.Module):
    """Single MoE feed-forward layer: a linear router plus several small FFN experts.

    ``d_ff`` is split evenly among the experts (``d_ff // num_experts`` hidden
    units each; any remainder is dropped). For every token the router keeps the
    ``top_k`` most probable experts, renormalizes their probabilities to sum to
    one, and the layer returns the weighted sum of the selected experts'
    outputs. Every expert is still evaluated on every token; routing only
    decides the mixing weights.

    Attributes:
        aux_loss: Scalar load-balancing loss of the most recent forward pass.
            It is zero unless the layer is in training mode and
            ``use_load_balancing`` is true. It is *not* added to the layer
            output; callers that want load balancing must add it to their
            training loss themselves.
    """

    def __init__(self,
                 num_experts: int,
                 d_model: int,
                 d_ff: int,
                 use_load_balancing: bool = False,
                 top_k: int = 2):
        """Build the experts, the router and the statistics buffers.

        Args:
            num_experts: Number of expert FFNs.
            d_model: Size of the last dimension of the input and output.
            d_ff: Total FFN hidden size, divided among the experts.
            use_load_balancing: Whether to compute ``aux_loss`` during training.
            top_k: Number of experts mixed per token; values larger than
                ``num_experts`` are clamped to ``num_experts``.

        Raises:
            ValueError: If ``num_experts`` or ``top_k`` is smaller than 1, or
                if ``d_ff`` is too small to give each expert one hidden unit.
        """
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        if d_ff // num_experts < 1:
            raise ValueError(
                f"d_ff ({d_ff}) must be at least num_experts ({num_experts})"
            )

        self.num_experts = num_experts
        self.d_model = d_model
        # Distribute d_ff among experts
        self.d_ff_per_expert = d_ff // num_experts
        self.use_load_balancing = use_load_balancing
        # topk cannot select more experts than exist
        self.top_k = min(top_k, num_experts)

        # Create experts, each handling a portion of d_ff
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, self.d_ff_per_expert),
                nn.GELU(),
                nn.Linear(self.d_ff_per_expert, d_model)
            ) for _ in range(num_experts)
        ])

        # Router
        self.router = nn.Linear(d_model, num_experts)

        # Register buffers for statistics
        self.register_buffer('_router_stats', torch.zeros(num_experts))
        self.register_buffer('_total_tokens', torch.tensor(0))

        # Plain attribute (not a buffer) so the state_dict layout is unchanged
        self.aux_loss = torch.tensor(0.0)

    def _compute_routing_weights(self, router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the renormalized top-k routing weights and their expert indices.

        Both returned tensors have ``self.top_k`` entries along the last
        dimension; position ``j`` holds the j-th most probable expert, not
        expert ``j``.
        """
        router_probs = torch.softmax(router_logits, dim=-1)
        weights, indices = torch.topk(router_probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        return weights, indices

    def _compute_load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """Return KL(mean routing distribution || uniform) as a scalar.

        Args:
            router_probs: Routing probabilities with the experts on the last
                dimension; all leading dimensions are treated as tokens.

        Returns:
            A scalar tensor; zero when load balancing is disabled, the module
            is not in training mode, or there are no tokens.
        """
        if (not self.use_load_balancing or not self.training
                or router_probs.numel() == 0):
            return torch.tensor(0.0, device=router_probs.device)

        # Average over every token, whatever the number of leading dimensions
        freq = router_probs.reshape(-1, router_probs.shape[-1]).mean(dim=0)
        # Keep log() finite if an expert receives exactly zero probability mass
        freq = freq.clamp_min(torch.finfo(freq.dtype).tiny)
        # Compute ideal frequency
        ideal_freq = torch.full_like(freq, 1.0 / self.num_experts)
        # Use KL divergence as loss
        return torch.sum(freq * torch.log(freq / ideal_freq))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix the expert outputs for ``x`` according to the top-k routing weights.

        Args:
            x: Tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Compute routing scores
        router_logits = self.router(x)
        weights, indices = self._compute_routing_weights(router_logits)

        # Scatter the top-k weights back to one column per expert so that
        # column i really is the weight of expert i (zero when not selected).
        dense_weights = weights.new_zeros(router_logits.shape).scatter(-1, indices, weights)

        expert_outputs = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            expert_outputs = expert_outputs + dense_weights[..., i:i + 1] * expert(x)

        # Update statistics once per forward pass, detached from autograd
        if self.training:
            with torch.no_grad():
                per_expert = dense_weights.reshape(-1, self.num_experts).sum(dim=0)
                self._router_stats += per_expert.to(self._router_stats.dtype)
                self._total_tokens += dense_weights.numel() // self.num_experts

        # Expose the load balancing loss instead of adding it to the activations
        if self.training and self.use_load_balancing:
            self.aux_loss = self._compute_load_balancing_loss(
                torch.softmax(router_logits, dim=-1)
            )
        else:
            self.aux_loss = router_logits.new_zeros(())

        return expert_outputs

    def get_routing_stats(self) -> Optional[Dict[str, Any]]:
        """Return routing statistics gathered since the last call, then reset them.

        Returns:
            ``None`` if no tokens were routed in training mode since the last
            reset; otherwise a dict with ``expert_utilization`` (mean routing
            weight per expert, one entry per expert) and ``total_tokens``.
        """
        if int(self._total_tokens) == 0:
            return None

        stats = {
            'expert_utilization': (self._router_stats / self._total_tokens).tolist(),
            'total_tokens': self._total_tokens.item()
        }

        # Reset statistics
        self._router_stats.zero_()
        self._total_tokens.zero_()

        return stats

class AnonymizedMoEModel(nn.Module):
    """Stack of identically configured ``AnonymizedMoELayer`` modules applied in sequence."""

    def __init__(self,
                 num_experts: int,
                 d_model: int,
                 d_ff: int,
                 num_layers: int = 1,
                 use_load_balancing: bool = False):
        """Create ``num_layers`` MoE layers that all share the same configuration."""
        super().__init__()
        self.num_experts = num_experts
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_layers = num_layers

        # Build the stack one layer at a time
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(
                AnonymizedMoELayer(
                    num_experts=num_experts,
                    d_model=d_model,
                    d_ff=d_ff,
                    use_load_balancing=use_load_balancing,
                )
            )
        self.layers = stack

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for moe_layer in self.layers:
            hidden = moe_layer(hidden)
        return hidden

    def get_model_stats(self) -> Dict[str, Any]:
        """Collect per-layer routing statistics.

        Each layer's ``get_routing_stats()`` is called once, in layer order.
        Layers that return a falsy value are left out; the others contribute
        their stats dict extended with a ``layer_index`` key.
        """
        per_layer = (
            (position, moe_layer.get_routing_stats())
            for position, moe_layer in enumerate(self.layers)
        )
        collected = [
            {'layer_index': position, **layer_stats}
            for position, layer_stats in per_layer
            if layer_stats
        ]
        return {'routing_stats': collected}

class AdvancedAnonymizedMoEModel(nn.Module):
    """Wrap a pretrained base model and apply one MoE layer to its final hidden states.

    The base model is run with ``output_hidden_states=True``; the last entry
    of ``outputs.hidden_states`` is passed through an
    ``AdvancedAnonymizedMoELayer``. Non-``None`` routing statistics returned
    by that layer are appended to ``self.routing_stats`` on every forward
    pass; the list is never trimmed here, so long-running callers should
    clear it themselves.
    """

    def __init__(
        self,
        base_model: PreTrainedModel,
        num_experts: int = 8,
        expert_dim: Optional[int] = None,
        hidden_dim: Optional[int] = None,
        top_k: int = 2,
        capacity_factor: float = 1.2,
        eval_capacity_factor: Optional[float] = None,
        min_capacity: int = 4,
        use_load_balancing: bool = True,
        load_balancing_weight: float = 1.0,
        use_z_loss: bool = True,
        z_loss_coef: float = 1e-3,
        **kwargs
    ):
        """Store the configuration and build the MoE layer.

        Args:
            base_model: Model whose last hidden state feeds the MoE layer.
            num_experts: Number of experts in the MoE layer.
            expert_dim: Passed to the MoE layer as ``expert_dim``; falls back
                to ``base_model.config.hidden_size`` when falsy.
            hidden_dim: Passed to the MoE layer as ``input_dim``; falls back
                to ``base_model.config.hidden_size`` when falsy.
            top_k: Passed through to the MoE layer.
            capacity_factor: Passed through to the MoE layer.
            eval_capacity_factor: Falls back to ``capacity_factor`` when falsy.
            min_capacity: Passed through to the MoE layer.
            use_load_balancing: Passed through to the MoE layer.
            load_balancing_weight: Passed through to the MoE layer.
            use_z_loss: Whether ``forward`` adds the layer's ``z_loss``.
            z_loss_coef: Weight applied to ``z_loss`` in ``forward``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.base_model = base_model
        self.num_experts = num_experts
        self.expert_dim = expert_dim or self.base_model.config.hidden_size
        self.hidden_dim = hidden_dim or self.base_model.config.hidden_size
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor or capacity_factor
        self.min_capacity = min_capacity
        self.use_load_balancing = use_load_balancing
        self.load_balancing_weight = load_balancing_weight
        self.use_z_loss = use_z_loss
        self.z_loss_coef = z_loss_coef

        # Initialize gradient checkpointing flag
        self.gradient_checkpointing = False

        # Initialize expert layer. The resolved eval capacity factor is passed
        # (not the raw, possibly-None argument) so that a freshly built model
        # and one restored through from_pretrained configure the layer alike.
        self.moe_layer = AdvancedAnonymizedMoELayer(
            input_dim=self.hidden_dim,
            expert_dim=self.expert_dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            eval_capacity_factor=self.eval_capacity_factor,
            min_capacity=min_capacity,
            use_load_balancing=use_load_balancing,
            load_balancing_weight=load_balancing_weight,
            use_z_loss=use_z_loss,
            z_loss_coef=z_loss_coef
        )

        # Initialize routing statistics
        self.routing_stats = []

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for the MoE layer and the base model."""
        self.gradient_checkpointing = True
        self.base_model.gradient_checkpointing_enable()

    def get_model_stats(self):
        """Return the routing statistics accumulated so far (the live list, not a copy)."""
        return {
            'routing_stats': self.routing_stats
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Run the base model, apply the MoE layer and optionally compute a loss.

        Args:
            input_ids: Forwarded to the base model.
            attention_mask: Forwarded to the base model.
            labels: If given, a cross-entropy loss is computed against the
                flattened MoE output.
            **kwargs: Forwarded to the base model.

        Returns:
            ``MoEModelOutput`` with ``loss`` (``None`` without labels),
            ``hidden_states`` (the MoE output) and ``routing_stats`` (this
            call's stats, or ``None``).
        """
        # Get hidden states from base model
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        hidden_states = outputs.hidden_states[-1]

        # Apply MoE layer
        if self.gradient_checkpointing and self.training:
            # Non-reentrant checkpointing keeps gradients flowing to the MoE
            # parameters even when hidden_states itself does not require grad.
            moe_output = torch.utils.checkpoint.checkpoint(
                self.moe_layer,
                hidden_states,
                use_reentrant=False
            )
        else:
            moe_output = self.moe_layer(hidden_states)

        # Get routing statistics; skip empty results so the history holds no None
        stats = None
        if hasattr(self.moe_layer, 'get_routing_stats'):
            stats = self.moe_layer.get_routing_stats()
            if stats is not None:
                self.routing_stats.append(stats)

        # Compute loss
        loss = None
        if labels is not None:
            # NOTE(review): the MoE output is used directly as logits over
            # hidden_dim classes (no LM head is applied) - confirm this is the
            # intended objective.
            # NOTE(review): load_balancing_weight is never applied here;
            # presumably the MoE layer handles it - verify.
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view) also accepts non-contiguous tensors such as
            # shifted label slices
            loss = loss_fct(
                moe_output.reshape(-1, self.hidden_dim),
                labels.reshape(-1)
            )

            # Add z-loss (if enabled)
            if self.use_z_loss and hasattr(self.moe_layer, 'z_loss'):
                loss = loss + self.z_loss_coef * self.moe_layer.z_loss

        return MoEModelOutput(
            loss=loss,
            hidden_states=moe_output,
            routing_stats=stats
        )

    def save_pretrained(self, save_directory: str):
        """Save the base model, the MoE configuration and the MoE layer weights.

        Writes ``base_model/`` (via the base model's own ``save_pretrained``),
        ``moe_config.json`` and ``moe_layer.pt`` under ``save_directory``,
        creating the directory if needed.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save base model
        self.base_model.save_pretrained(os.path.join(save_directory, "base_model"))

        # Save MoE configuration
        moe_config = {
            "num_experts": self.num_experts,
            "expert_dim": self.expert_dim,
            "hidden_dim": self.hidden_dim,
            "top_k": self.top_k,
            "capacity_factor": self.capacity_factor,
            "eval_capacity_factor": self.eval_capacity_factor,
            "min_capacity": self.min_capacity,
            "use_load_balancing": self.use_load_balancing,
            "load_balancing_weight": self.load_balancing_weight,
            "use_z_loss": self.use_z_loss,
            "z_loss_coef": self.z_loss_coef
        }

        config_path = os.path.join(save_directory, "moe_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(moe_config, f)

        # Save MoE layer state dict
        torch.save(
            self.moe_layer.state_dict(),
            os.path.join(save_directory, "moe_layer.pt")
        )

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        device: str = "cuda",
        **kwargs
    ):
        """Load a model previously written by ``save_pretrained``.

        Args:
            model_path: Directory containing ``base_model/``,
                ``moe_config.json`` and ``moe_layer.pt``.
            device: Device the base model, the MoE weights and the returned
                model are moved to.
            **kwargs: Extra constructor arguments; they override values of the
                same name stored in ``moe_config.json``.

        Returns:
            The restored model on ``device``.
        """
        # Load base model
        base_model = AutoModel.from_pretrained(
            os.path.join(model_path, "base_model")
        ).to(device)

        # Load MoE configuration
        config_path = os.path.join(model_path, "moe_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            moe_config = json.load(f)

        # Merge so caller overrides win instead of raising
        # "got multiple values for keyword argument"
        init_kwargs = {**moe_config, **kwargs}

        # Create model instance
        model = cls(base_model=base_model, **init_kwargs)

        # Load MoE layer state dict
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources, or pass weights_only=True where the installed
        # torch version supports it.
        moe_state_dict = torch.load(
            os.path.join(model_path, "moe_layer.pt"),
            map_location=device
        )
        model.moe_layer.load_state_dict(moe_state_dict)

        return model.to(device)