import json
import math
from pathlib import Path
from typing import List, Optional, Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    AutoModelForCausalLM,
)

class AdvancedMoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer with top-k routing and expert capacity.

    Each token is sent to its ``top_k`` highest-probability experts. Within one
    forward pass an expert accepts at most ``capacity`` token assignments, where
    ``capacity = ceil(capacity_factor * num_tokens * top_k / num_experts)``.
    Assignments past that budget are dropped first-come-first-served in token
    order, and the token gets no contribution from that expert.

    Auxiliary losses (load balancing, router z-loss) are NOT mixed into the
    returned activations. After every forward pass they are stored on
    ``self.aux_loss`` as a scalar tensor, so the training loop can add them to
    the task loss. The value is 0 in eval mode or when both losses are disabled.

    Args:
        num_experts: Number of expert feed-forward networks.
        d_model: Input/output feature size.
        d_ff: Hidden size of each expert.
        top_k: Experts per token; must satisfy ``1 <= top_k <= num_experts``.
        capacity_factor: Multiplier on the even-split per-expert token budget.
        use_load_balancing: Compute the load-balancing loss in training mode.
        use_z_loss: Compute the router z-loss in training mode.
        z_loss_coef: Scale applied to the z-loss.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(self,
                 num_experts: int,
                 d_model: int,
                 d_ff: int,
                 top_k: int = 1,
                 capacity_factor: float = 1.0,
                 use_load_balancing: bool = False,
                 use_z_loss: bool = False,
                 z_loss_coef: float = 1e-3):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}")
        self.num_experts = num_experts
        self.d_model = d_model
        self.d_ff = d_ff
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_load_balancing = use_load_balancing
        self.use_z_loss = use_z_loss
        self.z_loss_coef = z_loss_coef

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_experts)
        ])

        self.router = nn.Linear(d_model, num_experts)

        # Accepted (token, expert) assignment counts, accumulated in training mode
        # and drained by get_routing_stats().
        self.register_buffer('_router_stats', torch.zeros(num_experts))
        self.register_buffer('_total_tokens', torch.tensor(0))

        # Scalar auxiliary loss from the most recent forward pass.
        self.aux_loss = torch.tensor(0.0)

    def _expert_capacity(self, num_tokens: int) -> int:
        """Return the maximum number of assignments one expert accepts for ``num_tokens`` tokens."""
        # Capacity must scale with the batch; ceil avoids dropping tokens when a
        # perfectly balanced split does not divide evenly.
        return max(1, math.ceil(
            self.capacity_factor * num_tokens * self.top_k / self.num_experts))

    def _compute_routing_weights(self, router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token.

        Returns:
            ``(weights, indices, z_loss)``. ``weights`` and ``indices`` have the
            leading shape of ``router_logits`` with a last dim of ``top_k``.
            ``z_loss`` is a scalar (0 unless z-loss is enabled and in training mode).
        """
        router_probs = torch.softmax(router_logits, dim=-1)
        weights, indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Renormalise only when several experts share a token. With top_k == 1
        # the renormalised weight is identically 1, so the router would get no
        # gradient from the task loss; Switch-style top-1 uses the raw probability.
        if self.top_k > 1:
            weights = weights / weights.sum(dim=-1, keepdim=True)

        z_loss = router_logits.new_zeros(())
        if self.use_z_loss and self.training:
            z_loss = self.z_loss_coef * torch.logsumexp(router_logits, dim=-1).pow(2).mean()

        return weights, indices, z_loss

    def _compute_load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """Return KL(mean routing distribution || uniform), or 0 when disabled or not in training mode."""
        if not self.use_load_balancing or not self.training:
            return router_probs.new_zeros(())

        # Average over every token (all leading dims), not just the first axis.
        freq = router_probs.reshape(-1, self.num_experts).mean(dim=0)
        # freq / (1 / E) == freq * E. The clamp keeps the log finite if a mean
        # probability underflows to 0; such terms still contribute 0.
        return torch.sum(freq * torch.log(freq.clamp_min(1e-9) * self.num_experts))

    def _process_expert_outputs(self, x: torch.Tensor, weights: torch.Tensor,
                                indices: torch.Tensor) -> torch.Tensor:
        """Dispatch tokens to experts under the capacity limit and combine the weighted outputs.

        Args:
            x: ``(batch, seq, d_model)`` input (the shape is unpacked below).
            weights: Per-token top-k gate values, reshapeable to ``(batch*seq, top_k)``.
            indices: Per-token top-k expert ids, same layout as ``weights``.
        """
        batch_size, seq_len, d_model = x.shape
        flat_x = x.reshape(-1, d_model)
        flat_weights = weights.reshape(-1, self.top_k)
        flat_indices = indices.reshape(-1, self.top_k)
        capacity = self._expert_capacity(flat_x.shape[0])

        output = torch.zeros_like(flat_x)
        for expert_idx, expert in enumerate(self.experts):
            # (num_tokens, top_k) mask of the slots that picked this expert.
            slot_mask = flat_indices == expert_idx
            # topk yields distinct experts per token, so each token appears at most
            # once. nonzero() is ascending, so [:capacity] keeps the earliest tokens.
            token_ids = slot_mask.any(dim=-1).nonzero(as_tuple=True)[0][:capacity]
            if token_ids.numel() == 0:
                continue

            gate = (flat_weights[token_ids] * slot_mask[token_ids]).sum(dim=-1)
            contribution = gate.unsqueeze(-1) * expert(flat_x[token_ids])
            output.index_add_(0, token_ids, contribution.to(output.dtype))

            if self.training:
                accepted = token_ids.numel()
                self._router_stats[expert_idx] += accepted
                self._total_tokens += accepted

        return output.reshape(batch_size, seq_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts and return the combined output (same shape as ``x``).

        Side effect: sets ``self.aux_loss`` to the sum of the enabled auxiliary losses.
        """
        router_logits = self.router(x)
        weights, indices, z_loss = self._compute_routing_weights(router_logits)
        balance_loss = self._compute_load_balancing_loss(torch.softmax(router_logits, dim=-1))

        # Keep auxiliary losses out of the activations; the caller adds them to the task loss.
        self.aux_loss = balance_loss + z_loss

        return self._process_expert_outputs(x, weights, indices)

    def get_routing_stats(self) -> Optional[Dict[str, Any]]:
        """Return and reset the accumulated routing statistics.

        Returns:
            ``None`` if no assignments were recorded since the last call. Otherwise a
            dict with ``expert_utilization`` (per-expert share of accepted
            assignments) and ``total_tokens`` (number of accepted assignments).
        """
        total = int(self._total_tokens.item())
        if total == 0:
            return None

        stats = {
            'expert_utilization': (self._router_stats / total).tolist(),
            'total_tokens': total,
        }

        self._router_stats.zero_()
        self._total_tokens.zero_()

        return stats

"""
Implementation:
1. Add Switch Transformer-style load balancing
2. Implement Top-k routing to allow multiple experts to collaborate
3. Add Expert Capacity limitation
4. Include auxiliary losses such as Z-Loss to improve training
"""
class AdvancedMoEModel(nn.Module):
    """Wrap a pretrained causal LM and mix its final hidden states through routed expert FFNs.

    Routing is top-k over ``num_experts`` two-layer GELU experts with hidden size
    ``4 * d_model``. Routing counts are accumulated for :meth:`get_model_stats`.

    NOTE(review): the expert mixture is computed *after* ``self.model`` has produced
    its logits and loss, so it does not change ``outputs.logits`` / ``outputs.loss``.
    It is exposed only as ``outputs.hidden_states[-1]`` / ``outputs.last_hidden_state``.
    ``generate`` and the Trainer-based training methods call ``self.model`` directly,
    so they neither use nor train the router and experts. ``capacity_factor`` is
    stored and saved but is not enforced in this class.

    Args:
        model_path: Path or hub id passed to ``from_pretrained`` for model and tokenizer.
        device: Device that the base model, router and experts are moved to.
        **kwargs: ``num_experts`` (required), plus optional ``top_k`` (1),
            ``capacity_factor`` (1.0), ``use_load_balancing`` (False),
            ``use_z_loss`` (False) and ``z_loss_coef`` (1e-3).
    """

    def __init__(self, model_path: str, device: str = "cuda", **kwargs):
        super().__init__()
        self.device = device
        self.num_experts = kwargs.pop('num_experts')
        self.top_k = kwargs.get('top_k', 1)
        self.capacity_factor = kwargs.get('capacity_factor', 1.0)
        self.use_load_balancing = kwargs.get('use_load_balancing', False)
        self.use_z_loss = kwargs.get('use_z_loss', False)
        self.z_loss_coef = kwargs.get('z_loss_coef', 1e-3)

        # save_model, generate and both training methods need a tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            # Some causal-LM tokenizers (e.g. LLaMA) have no pad token, which breaks padding=True.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.d_model = self.model.config.hidden_size

        self.router = nn.Linear(self.d_model, self.num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.d_model, 4 * self.d_model),
                nn.GELU(),
                nn.Linear(4 * self.d_model, self.d_model)
            ) for _ in range(self.num_experts)
        ])

        # Per-expert count of routed (token, slot) assignments. It is a non-persistent
        # buffer, so it follows .to(device) without appearing in state_dict().
        self.register_buffer('_router_stats', torch.zeros(self.num_experts), persistent=False)
        self._total_tokens = 0

        # Move the base model too: generate() sends its inputs to `device`.
        self.to(device)

    def forward(self, inputs):
        """Run the base model, then mix its last hidden states through the experts.

        Args:
            inputs: Either a dict with ``input_ids`` and optional ``attention_mask`` /
                ``labels``, or an ``input_ids`` tensor.

        Returns:
            The base model's output. The last entry of ``hidden_states`` and the
            ``last_hidden_state`` attribute are replaced by the expert mixture.
            ``aux_loss`` holds the enabled router losses (0 outside training mode).
        """
        if isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask')
            labels = inputs.get('labels')
        else:
            input_ids = inputs
            attention_mask = None
            labels = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True
        )

        hidden_states = outputs.hidden_states[-1]
        batch_size, seq_len, hidden_dim = hidden_states.shape
        flat_hidden = hidden_states.reshape(-1, hidden_dim)

        router_logits = self.router(flat_hidden)
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        if self.top_k > 1:
            # Renormalise so the selected experts' gates sum to 1. For top_k == 1 keep
            # the raw probability, otherwise the gate is identically 1 and the router gets no gradient.
            top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        expert_outputs = torch.zeros_like(flat_hidden)
        for i, expert in enumerate(self.experts):
            slot_mask = top_k_indices == i
            token_ids = slot_mask.any(dim=-1).nonzero(as_tuple=True)[0]
            if token_ids.numel() == 0:
                continue
            # Gate of expert i for each selected token (distinct experts per token from topk).
            gate = (top_k_probs[token_ids] * slot_mask[token_ids]).sum(dim=-1)
            contribution = gate.unsqueeze(-1) * expert(flat_hidden[token_ids])
            expert_outputs.index_add_(0, token_ids, contribution.to(expert_outputs.dtype))

        with torch.no_grad():
            self._router_stats += torch.bincount(
                top_k_indices.reshape(-1), minlength=self.num_experts
            ).to(self._router_stats.dtype)
            self._total_tokens += top_k_indices.numel()

        aux_loss = router_logits.new_zeros(())
        if self.training:
            if self.use_load_balancing:
                aux_loss = aux_loss + self._compute_load_balancing_loss(router_probs)
            if self.use_z_loss:
                aux_loss = aux_loss + self.z_loss_coef * torch.logsumexp(
                    router_logits, dim=-1).pow(2).mean()
        outputs.aux_loss = aux_loss

        expert_outputs = expert_outputs.view(batch_size, seq_len, hidden_dim)
        outputs.hidden_states = outputs.hidden_states[:-1] + (expert_outputs,)
        outputs.last_hidden_state = expert_outputs

        return outputs

    def _compute_load_balancing_loss(self, router_probs):
        """Return the mean squared deviation of average expert probability from uniform.

        Args:
            router_probs: ``(num_tokens, num_experts)`` routing probabilities.
        """
        expert_usage = router_probs.mean(dim=0)
        target_usage = torch.ones_like(expert_usage) / self.num_experts
        balance_loss = torch.mean((expert_usage - target_usage) ** 2)
        return balance_loss

    def save_model(self, save_path: str):
        """Write the config, MoE settings, base weights, MoE weights and tokenizer to ``save_path``."""
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        self.model.config.save_pretrained(save_path)

        moe_config = {
            'num_experts': self.num_experts,
            'top_k': self.top_k,
            'capacity_factor': self.capacity_factor,
            'use_load_balancing': self.use_load_balancing,
            'use_z_loss': self.use_z_loss,
            'z_loss_coef': self.z_loss_coef
        }
        with open(save_path / "moe_config.json", 'w', encoding='utf-8') as f:
            json.dump(moe_config, f, indent=2)

        torch.save(self.model.state_dict(), save_path / "pytorch_model.bin")
        # Router and experts live outside self.model, so their weights need their own file.
        torch.save(
            {'router': self.router.state_dict(), 'experts': self.experts.state_dict()},
            save_path / "moe_weights.bin",
        )

        self.tokenizer.save_pretrained(save_path)

    @torch.no_grad()
    def generate(self,
                 prompts: List[str],
                 max_length: int = 100,
                 temperature: float = 0.7,
                 batch_size: int = 4) -> List[str]:
        """Sample continuations for ``prompts`` with the base model, ``batch_size`` prompts at a time.

        ``max_length`` is passed straight to ``self.model.generate``. Decoded texts
        (prompt included) are returned in prompt order.
        """
        self.model.eval()
        all_outputs = []

        # Decoder-only batched generation needs left padding, so that new tokens
        # follow each prompt directly. Restore the tokenizer's side afterwards so
        # training padding is unaffected.
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            for start in range(0, len(prompts), batch_size):
                batch_prompts = prompts[start:start + batch_size]
                inputs = self.tokenizer(batch_prompts,
                                        return_tensors="pt",
                                        padding=True).to(self.device)

                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                )

                all_outputs.extend(
                    self.tokenizer.batch_decode(outputs, skip_special_tokens=True))
        finally:
            self.tokenizer.padding_side = previous_side

        return all_outputs

    def continue_pretraining(self,
                             train_data: Any,
                             output_dir: str,
                             num_epochs: int = 3,
                             batch_size: int = 4,
                             learning_rate: float = 2e-5,
                             save_steps: int = 1000,
                             gradient_accumulation_steps: int = 4,
                             warmup_steps: int = 100,
                             **kwargs):
        """Run causal-LM training (``mlm=False`` collator) of ``self.model`` on ``train_data``.

        Extra ``kwargs`` are forwarded to ``TrainingArguments``. The Trainer receives
        ``self.model``, so the router and experts are not trained.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            save_steps=save_steps,
            **kwargs
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            data_collator=data_collator,
        )

        trainer.train()

    def supervised_finetuning(self,
                              train_data: Any,
                              output_dir: str,
                              eval_data: Optional[Any] = None,
                              num_epochs: int = 3,
                              batch_size: int = 4,
                              learning_rate: float = 2e-5,
                              save_steps: int = 1000,
                              gradient_accumulation_steps: int = 4,
                              warmup_steps: int = 100,
                              **kwargs):
        """Fine-tune ``self.model`` on ``train_data``, evaluating every logging step if ``eval_data`` is given.

        Extra ``kwargs`` are forwarded to ``TrainingArguments``. The router and experts
        are not trained.
        """
        # NOTE(review): newer transformers releases renamed `evaluation_strategy` to
        # `eval_strategy` and Trainer's `tokenizer` to `processing_class`; confirm
        # against the pinned version.
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            save_steps=save_steps,
            evaluation_strategy="steps" if eval_data is not None else "no",
            **kwargs
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            tokenizer=self.tokenizer
        )

        trainer.train()

    def get_model_stats(self):
        """Report per-expert routing utilization accumulated by :meth:`forward`.

        ``expert_utilization`` has one entry per expert: the share of routed
        (token, slot) assignments. ``load_balancing`` is the mean squared
        deviation of that share from uniform, or ``None`` if nothing has been routed yet.
        """
        total = self._total_tokens
        with torch.no_grad():
            if total > 0:
                expert_usage = self._router_stats / total
                load_balance = ((expert_usage - 1.0 / self.num_experts) ** 2).mean().item()
            else:
                expert_usage = torch.zeros_like(self._router_stats)
                load_balance = None

        return {
            'routing_stats': [{
                'expert_utilization': expert_usage.tolist(),
                'expert_capacity': self.capacity_factor,
                'load_balancing': load_balance,
                'total_tokens': int(total),
            }]
        }

if __name__ == "__main__":
    model = AdvancedMoEModel(
        model_path="/path/to/model",
        num_experts=8,
        top_k=2,
        capacity_factor=1.2,
        use_load_balancing=True,
        use_z_loss=True,
        z_loss_coef=1e-3
    )

    model.save_model("saved_moe_model")
    

    prompts = ["Hello, how are", "The weather is"]
    outputs = model.generate(prompts)
    print(outputs)

    stats = model.get_model_stats()
    print(json.dumps(stats, indent=2))