import dataclasses
import inspect
import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

class BaselineMoELayer(nn.Module):
    """Top-1 mixture-of-experts feed-forward layer.

    Each token is sent to the single expert with the highest router
    probability, and that expert's output is scaled by the probability so the
    router receives gradients. While the module is in training mode, per-expert
    token counts are accumulated in buffers. They can be read (and reset) with
    :meth:`get_routing_stats`.

    Args:
        num_experts: Number of expert feed-forward networks.
        d_model: Size of the last dimension of the input and output.
        d_ff: Hidden size of each expert's feed-forward network.
    """

    def __init__(self, num_experts: int, d_model: int, d_ff: int):
        super().__init__()
        self.num_experts = num_experts
        self.d_model = d_model
        self.d_ff = d_ff

        # One independent two-layer GELU MLP per expert.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

        # Router: one logit per expert for every token.
        self.router = nn.Linear(d_model, num_experts)

        # Statistics live in buffers so they follow the module across devices.
        self.register_buffer('_router_stats', torch.zeros(num_experts))
        self.register_buffer('_total_tokens', torch.tensor(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route every token of ``x`` to its top-1 expert.

        Args:
            x: Tensor whose last dimension is ``d_model``. All leading
                dimensions are treated as token positions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        router_probs = torch.softmax(self.router(x), dim=-1)
        # expert_weights holds the probability of the chosen expert per token.
        expert_weights, expert_indices = torch.max(router_probs, dim=-1)

        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                # Scale by the gate probability. Without it the router's
                # parameters get no gradient and the routing never learns.
                expert_out = expert(x[mask]) * expert_weights[mask].unsqueeze(-1)
                # Index assignment requires matching dtypes (e.g. under autocast).
                output[mask] = expert_out.to(output.dtype)

        if self.training:
            with torch.no_grad():
                counts = torch.bincount(
                    expert_indices.reshape(-1), minlength=self.num_experts
                )
                self._router_stats += counts.to(self._router_stats.dtype)
                self._total_tokens += expert_indices.numel()

        return output

    def get_routing_stats(self) -> Optional[Dict[str, Any]]:
        """Return the accumulated routing statistics and reset the counters.

        Returns:
            ``None`` if no tokens were routed in training mode since the last
            reset. Otherwise a dict containing:

            - ``'expert_utilization'``: fraction of tokens sent to each
              expert, as a list of floats.
            - ``'total_tokens'``: total token count, as an int.
        """
        total = int(self._total_tokens.item())
        if total == 0:
            return None

        stats = {
            'expert_utilization': (self._router_stats / total).tolist(),
            'total_tokens': total,
        }

        # Reset statistics
        self._router_stats.zero_()
        self._total_tokens.zero_()

        return stats

class BaselineMoEModel(nn.Module):
    """Pretrained causal LM whose final hidden states pass through a dense MoE.

    The base model's last-layer hidden states go through ``num_experts``
    feed-forward experts. Every expert runs on every token, and the outputs
    are mixed with per-token softmax router weights. The mixture is then
    projected to vocabulary logits by the base model's output embedding.

    NOTE(review): only :meth:`forward` uses the experts and router.
    :meth:`generate`, :meth:`continue_pretraining` and
    :meth:`supervised_finetuning` operate on ``self.model`` alone. Those
    methods therefore neither train nor sample through the MoE parameters.

    Args:
        model_path: Path (or identifier) passed to ``from_pretrained`` for both
            the model and its tokenizer.
        num_experts: Number of expert feed-forward networks.
        device: Device the base model, experts and router are placed on.
    """

    def __init__(self, model_path: str, num_experts: int, device: str = "cuda"):
        super().__init__()
        self.device = device
        self.num_experts = num_experts

        # The base model goes on the same device as the experts/router.
        self.model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
        self.d_model = self.model.config.hidden_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            # Causal-LM tokenizers (e.g. LLaMA) may have no pad token, but
            # padded batch tokenization and the LM data collator require one.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Experts: two-layer GELU MLPs with a 4x hidden expansion.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.d_model, 4 * self.d_model),
                nn.GELU(),
                nn.Linear(4 * self.d_model, self.d_model),
            )
            for _ in range(num_experts)
        ]).to(device)

        # Router: one logit per expert for each token.
        self.router = nn.Linear(self.d_model, num_experts).to(device)

    def forward(
        self,
        inputs: Any = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> "CausalLMOutputWithPast":
        """Run the base model, the MoE head and, optionally, the LM loss.

        Args:
            inputs: Either a mapping (e.g. a tokenizer's output) with an
                ``'input_ids'`` entry and optional ``'attention_mask'`` /
                ``'labels'`` entries, or an ``input_ids`` tensor. Mapping
                entries take precedence over the keyword arguments below.
            input_ids: Token ids. Used when ``inputs`` is not given.
            attention_mask: Optional attention mask for the base model.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted internally (position t predicts token t + 1), and
                positions equal to -100 are ignored.

        Returns:
            ``CausalLMOutputWithPast`` containing:

            - ``logits`` computed from the MoE output.
            - ``loss`` when ``labels`` is given, otherwise ``None``.
            - The base model's ``past_key_values``.
            - ``hidden_states`` set to the MoE output tensor.
        """
        if isinstance(inputs, Mapping):
            # Mapping rather than dict, so tokenizer BatchEncoding (a
            # UserDict) objects are accepted as well.
            input_ids = inputs.get('input_ids', input_ids)
            attention_mask = inputs.get('attention_mask', attention_mask)
            labels = inputs.get('labels', labels)
        elif inputs is not None:
            input_ids = inputs

        # Labels are not passed to the base model. Its own loss would be
        # computed on the un-mixed hidden states and then discarded.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[-1]  # last-layer hidden states
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Flatten to (tokens, hidden). The MoE parameters' dtype may differ
        # from the base model's, so cast to match it.
        flat = hidden_states.reshape(-1, hidden_dim).to(self.router.weight.dtype)
        router_probs = F.softmax(self.router(flat), dim=-1)

        # Dense mixture: weight each expert's output by that token's router
        # probability for the expert, then sum.
        mixed = torch.zeros_like(flat)
        for i, expert in enumerate(self.experts):
            mixed += expert(flat) * router_probs[:, i:i + 1]

        expert_outputs = mixed.view(batch_size, seq_len, hidden_dim).to(hidden_states.dtype)
        logits = self.model.get_output_embeddings()(expert_outputs)

        loss = self._causal_lm_loss(logits, labels) if labels is not None else None
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=expert_outputs,
        )

    @staticmethod
    def _causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Next-token cross-entropy: logits at position t score labels[t + 1].

        ``DataCollatorForLanguageModeling(mlm=False)`` copies ``input_ids``
        into ``labels`` unshifted, so the shift must happen here. Labels equal
        to -100 are ignored.
        """
        shift_logits = logits[:, :-1, :].float()
        shift_labels = labels[:, 1:].to(shift_logits.device)
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )

    def save_model(self, save_path: str):
        """Save the base model, the MoE head and the tokenizer to ``save_path``.

        Writes:

        - The base config.
        - The base weights (``pytorch_model.bin``).
        - The expert/router weights (``moe_layers.bin``, which also records
          ``num_experts``).
        - The tokenizer files.

        The directory is created if it does not exist.
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        self.model.config.save_pretrained(save_path)
        torch.save(self.model.state_dict(), save_path / "pytorch_model.bin")

        # The experts and router are not part of self.model, so save them separately.
        torch.save(
            {
                'num_experts': self.num_experts,
                'experts': self.experts.state_dict(),
                'router': self.router.state_dict(),
            },
            save_path / "moe_layers.bin",
        )

        self.tokenizer.save_pretrained(save_path)

    @torch.no_grad()
    def generate(self,
                 prompts: List[str],
                 max_length: int = 100,
                 temperature: float = 0.7,
                 batch_size: int = 4) -> List[str]:
        """Sample continuations for ``prompts`` with the base model's ``generate``.

        NOTE(review): this uses ``self.model`` directly, so the experts and
        router do not influence the output.

        Args:
            prompts: Prompt strings.
            max_length: Passed to ``generate`` as ``max_length``.
            temperature: Sampling temperature (sampling is always enabled).
            batch_size: Number of prompts tokenized and generated together.

        Returns:
            One decoded string per prompt, in input order, with special tokens
            removed.
        """
        self.model.eval()
        tokenizer = self.tokenizer
        device = self.model.device
        all_outputs = []

        # Decoder-only models must be left-padded for batched generation.
        # Restore the caller's setting afterwards.
        original_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            for start in range(0, len(prompts), batch_size):
                batch_prompts = prompts[start:start + batch_size]
                inputs = tokenizer(batch_prompts,
                                   return_tensors="pt",
                                   padding=True).to(device)

                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                )

                all_outputs.extend(
                    tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )
        finally:
            tokenizer.padding_side = original_side

        return all_outputs

    def continue_pretraining(self,
                             train_data: Any,
                             output_dir: str,
                             num_epochs: int = 3,
                             batch_size: int = 4,
                             learning_rate: float = 2e-5,
                             save_steps: int = 1000,
                             **kwargs):
        """Continue causal-LM pretraining of the base model with ``Trainer``.

        NOTE(review): ``Trainer`` receives ``self.model``, so the experts and
        router are not trained here.

        Args:
            train_data: Dataset passed to ``Trainer`` as ``train_dataset``.
            output_dir: ``TrainingArguments.output_dir``.
            num_epochs: ``TrainingArguments.num_train_epochs``.
            batch_size: ``TrainingArguments.per_device_train_batch_size``.
            learning_rate: ``TrainingArguments.learning_rate``.
            save_steps: ``TrainingArguments.save_steps``.
            **kwargs: Additional ``TrainingArguments`` fields.

        Returns:
            Whatever ``Trainer.train()`` returns.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            save_steps=save_steps,
            **kwargs
        )

        # mlm=False selects causal language modeling (labels = input ids).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            data_collator=data_collator,
        )

        return trainer.train()

    def supervised_finetuning(self,
                              train_data: Any,
                              output_dir: str,
                              eval_data: Optional[Any] = None,
                              num_epochs: int = 3,
                              batch_size: int = 4,
                              learning_rate: float = 2e-5,
                              save_steps: int = 1000,
                              **kwargs):
        """Supervised fine-tuning of the base model with ``Trainer``.

        Evaluation runs every ``eval_steps`` when ``eval_data`` is given;
        otherwise it is disabled. An explicit evaluation strategy in
        ``kwargs`` overrides this default.

        NOTE(review): ``Trainer`` receives ``self.model``, so the experts and
        router are not trained here.

        Args:
            train_data: Dataset passed to ``Trainer`` as ``train_dataset``.
            output_dir: ``TrainingArguments.output_dir``.
            eval_data: Optional dataset passed as ``eval_dataset``.
            num_epochs: ``TrainingArguments.num_train_epochs``.
            batch_size: ``TrainingArguments.per_device_train_batch_size``.
            learning_rate: ``TrainingArguments.learning_rate``.
            save_steps: ``TrainingArguments.save_steps``.
            **kwargs: Additional ``TrainingArguments`` fields.

        Returns:
            Whatever ``Trainer.train()`` returns.
        """
        # Newer transformers renamed `evaluation_strategy` to `eval_strategy`.
        # Pick whichever field this installed version defines.
        arg_fields = {field.name for field in dataclasses.fields(TrainingArguments)}
        eval_key = 'eval_strategy' if 'eval_strategy' in arg_fields else 'evaluation_strategy'
        kwargs.setdefault(eval_key, "steps" if eval_data is not None else "no")

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            save_steps=save_steps,
            **kwargs
        )

        # Newer transformers take the tokenizer as `processing_class`, not `tokenizer`.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = 'processing_class' if 'processing_class' in trainer_params else 'tokenizer'

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            **{tokenizer_key: self.tokenizer},
        )

        return trainer.train()

    def get_model_stats(self):
        """Return a weight-based summary of the router.

        ``expert_utilization`` is a data-free proxy, not a count of routed
        tokens. For each input feature, the router weights are
        softmax-normalised across experts and then averaged over features.
        The result is one non-negative value per expert, summing to 1.
        """
        with torch.no_grad():
            # router.weight has shape (num_experts, d_model). Normalising over
            # dim 0 (experts) gives one entry per expert after averaging dim 1.
            utilization = self.router.weight.softmax(dim=0).mean(dim=1)
        return {
            'routing_stats': [{
                'expert_utilization': utilization.tolist(),
                'expert_capacity': 1.0,  # No capacity limit in base version
                'load_balancing': None,  # No load balancing in base version
            }]
        }

# Usage example
if __name__ == "__main__":
    # Initialize model. Fall back to CPU when CUDA is unavailable instead of
    # relying on the class's "cuda" default.
    model = BaselineMoEModel(
        model_path="/path/to/model",
        num_experts=8,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Example: continue pretraining
    # train_data = ...  # Prepare pretraining data
    # model.continue_pretraining(
    #     train_data=train_data,
    #     output_dir="pretrain_output",
    #     num_epochs=3
    # )

    # Example: supervised fine-tuning
    # train_data = ...  # Prepare fine-tuning data
    # eval_data = ...   # Prepare evaluation data
    # model.supervised_finetuning(
    #     train_data=train_data,
    #     eval_data=eval_data,
    #     output_dir="finetune_output",
    #     num_epochs=3
    # )

    # Save model
    model.save_model("saved_moe_model")

    # Generate text
    prompts = ["Hello, how are", "The weather is"]
    outputs = model.generate(prompts)
    print(outputs)

    # Get statistics
    stats = model.get_model_stats()
    print(json.dumps(stats, indent=2))