from enum import Enum
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
from pathlib import Path
import json
import os
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

from baseline_moe import BaselineMoEModel
from advanced_moe import AdvancedMoEModel
from sliced_moe import SlicedMoEModel
class MoEType(Enum):
    """Closed set of MoE model flavours the factory knows how to build.

    Each member's value is the lowercase string callers pass as ``moe_type``
    to ``MoEFactory.create_moe`` (the factory lowercases its input before
    looking the member up).
    """

    # Built with BaselineMoEModel.
    BASELINE = "baseline"
    # Built with AdvancedMoEModel.
    ADVANCED = "advanced"
    # Built with SlicedMoEModel.
    SLICED = "sliced"

class MoEFactory:
    """Build, persist, reload and train the MoE model variants.

    The variant is selected by the string value of an ``MoEType`` member:
    "baseline" -> ``BaselineMoEModel``, "advanced" -> ``AdvancedMoEModel``,
    "sliced" -> ``SlicedMoEModel``.
    """

    # Maps the class name stored in a checkpoint's ``model_type`` field back to
    # the ``moe_type`` string understood by ``create_moe``.
    _MOE_TYPE_BY_CLASS_NAME = {
        "BaselineMoEModel": "baseline",
        "AdvancedMoEModel": "advanced",
        "SlicedMoEModel": "sliced",
    }

    @staticmethod
    def create_moe(
        moe_type: str,
        model_path: str,
        config: Dict[str, Any],
        device: str = "cuda"
    ) -> nn.Module:
        """
        Create a specified type of MoE model.

        Args:
            moe_type: Type of MoE ("baseline", "advanced", or "sliced");
                matched case-insensitively.
            model_path: Base model path; the tokenizer is loaded from here too.
            config: MoE configuration parameters. "baseline" reads only
                ``num_experts``; "sliced" reads ``num_experts`` and optionally
                ``use_load_balancing`` (default False); "advanced" forwards
                every key as a keyword argument to ``AdvancedMoEModel``.
            device: Device to run on.

        Returns:
            The created model, moved to ``device``, with a ``tokenizer``
            attribute and a copy of ``config`` attached (used by ``save_moe``).

        Raises:
            ValueError: If ``moe_type`` is not one of the known types.
            KeyError: If ``config`` lacks ``num_experts`` for "baseline"/"sliced".
        """
        try:
            kind = MoEType(moe_type.lower())
        except ValueError as err:
            valid = ", ".join(t.value for t in MoEType)
            raise ValueError(
                f"Unknown MOE type: {moe_type!r} (expected one of: {valid})"
            ) from err

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if tokenizer.pad_token is None:
            # Reuse EOS for padding. Registering a brand-new '[PAD]' token here
            # would grow the tokenizer's vocabulary without resizing the
            # model's embedding matrix, so the new pad id would index past the
            # end of the embedding table.
            tokenizer.pad_token = tokenizer.eos_token

        # Create the corresponding type of model
        if kind == MoEType.BASELINE:
            model = BaselineMoEModel(
                model_path=model_path,
                num_experts=config['num_experts'],
                device=device
            )
        elif kind == MoEType.ADVANCED:
            model = AdvancedMoEModel(
                model_path=model_path,
                device=device,
                **config
            )
        elif kind == MoEType.SLICED:
            model = SlicedMoEModel(
                model_path=model_path,
                num_experts=config['num_experts'],
                use_load_balancing=config.get('use_load_balancing', False),
                device=device
            )
        else:
            # Defensive: only reachable if MoEType gains a member without a
            # matching branch above.
            raise ValueError(f"Unknown MOE type: {kind}")

        # Ensure the model is on the correct device
        model = model.to(device)
        model.tokenizer = tokenizer
        # Remember how the model was built so save_moe/load_moe can rebuild the
        # same architecture (e.g. top_k / capacity_factor for "advanced").
        model._moe_factory_config = dict(config)
        return model

    @staticmethod
    def save_moe(model: nn.Module, save_path: str):
        """Save an MoE model to ``save_path``.

        Layout written:
            ``save_path/base_model/``: ``model.model`` and ``model.tokenizer``
                via their ``save_pretrained`` methods.
            ``save_path/moe_state.pt``: router and expert state dicts plus
                metadata (expert count, ``d_model``, class name, build config).

        Args:
            model: Model exposing ``model``, ``tokenizer``, ``router``,
                ``experts``, ``num_experts`` and ``d_model`` attributes.
            save_path: Target directory; created if missing.

        Raises:
            Any exception from the underlying save calls is printed and re-raised.
        """
        try:
            save_path = Path(save_path)
            save_path.mkdir(parents=True, exist_ok=True)

            # 1. Save the base pretrained model to a subdirectory
            base_model_path = save_path / "base_model"
            base_model_path.mkdir(exist_ok=True)
            model.model.save_pretrained(base_model_path)
            model.tokenizer.save_pretrained(base_model_path)

            # Models not built through create_moe have no recorded config;
            # fall back to the one field every variant needs.
            config = getattr(model, "_moe_factory_config", None)
            if config is None:
                config = {'num_experts': model.num_experts}

            # 2. Save MoE-specific components
            moe_state = {
                'router': model.router.state_dict(),
                'experts': [expert.state_dict() for expert in model.experts],
                'num_experts': model.num_experts,
                'd_model': model.d_model,
                'model_type': model.__class__.__name__,
                # Kept for compatibility with older readers; load_moe prefers
                # the base_model directory next to moe_state.pt.
                'base_model_path': str(base_model_path),
                'config': dict(config),
            }
            torch.save(moe_state, save_path / "moe_state.pt")

        except Exception as e:
            print(f"Error saving model: {str(e)}")
            raise

    @staticmethod
    def load_moe(load_path: str, device: str = "cuda") -> nn.Module:
        """Load an MoE model previously written by ``save_moe``.

        Args:
            load_path: Directory containing ``moe_state.pt`` (and normally a
                ``base_model`` subdirectory).
            device: Device to map tensors to and to build the model on.

        Returns:
            The rebuilt model with router and expert weights restored.

        Raises:
            ValueError: If the checkpoint's model type is unknown or its expert
                count does not match the rebuilt model.
            Any other error from loading is printed and re-raised.
        """
        try:
            load_path = Path(load_path)

            # 1. Load MoE state
            # NOTE(review): only load checkpoints from trusted sources; on older
            # torch versions torch.load unpickles arbitrary objects.
            moe_state = torch.load(load_path / "moe_state.pt", map_location=device)

            # 2. Resolve which variant to rebuild. Unknown classes are an error
            # rather than silently being treated as "advanced".
            model_type = moe_state['model_type']
            moe_type = MoEFactory._MOE_TYPE_BY_CLASS_NAME.get(model_type)
            if moe_type is None:
                raise ValueError(
                    f"Unknown MoE model type in checkpoint: {model_type!r}"
                )

            # Prefer the base model stored next to this checkpoint so a moved
            # or copied directory still loads; fall back to the recorded path.
            local_base = load_path / "base_model"
            if local_base.is_dir():
                base_model_path = str(local_base)
            else:
                base_model_path = moe_state['base_model_path']

            # Older checkpoints carry no build config, only the expert count.
            config = moe_state.get('config') or {'num_experts': moe_state['num_experts']}

            model = MoEFactory.create_moe(
                moe_type=moe_type,
                model_path=base_model_path,
                config=config,
                device=device
            )

            # 3. Load the state of MoE components
            expert_states = moe_state['experts']
            if len(expert_states) != len(model.experts):
                raise ValueError(
                    f"Checkpoint has {len(expert_states)} experts but the rebuilt "
                    f"model has {len(model.experts)}"
                )
            model.router.load_state_dict(moe_state['router'])
            for expert, expert_state in zip(model.experts, expert_states):
                expert.load_state_dict(expert_state)

            return model

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    @staticmethod
    def continue_pretraining(
        model: nn.Module,
        train_data: Any,
        output_dir: str,
        **training_args
    ):
        """
        Continue pretraining the MoE model by delegating to its own
        ``continue_pretraining`` method.

        Args:
            model: MoE model; must be a BaselineMoEModel or AdvancedMoEModel.
            train_data: Training data, passed through unchanged.
            output_dir: Output directory, passed through unchanged.
            **training_args: Extra training parameters, passed through unchanged.

        Raises:
            ValueError: If ``model`` is not a supported type.
        """
        if isinstance(model, (BaselineMoEModel, AdvancedMoEModel)):
            model.continue_pretraining(
                train_data=train_data,
                output_dir=output_dir,
                **training_args
            )
        else:
            raise ValueError(
                f"Unsupported model type for continue_pretraining: "
                f"{type(model).__name__}"
            )

    @staticmethod
    def supervised_finetuning(
        model,
        train_dataset,  # passed straight to Trainer(train_dataset=...)
        output_dir: str,
        num_epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 2e-5,
        warmup_steps: int = 100,
        logging_steps: int = 10,
        save_strategy: str = "steps",
        data_collator=None,
        **kwargs
    ):
        """Run supervised fine-tuning with a Hugging Face ``Trainer``.

        Args:
            model: Model to train; must be usable by ``Trainer``.
            train_dataset: Training dataset.
            output_dir: Directory for checkpoints and logs.
            num_epochs: Number of training epochs.
            batch_size: Per-device training batch size.
            learning_rate: Peak learning rate.
            warmup_steps: Linear warmup steps.
            logging_steps: Logging interval in steps.
            save_strategy: Checkpoint save strategy ("steps", "epoch", "no").
            data_collator: Optional collator forwarded to ``Trainer``; when
                None, the Trainer's default collator is used.
            **kwargs: Extra keyword arguments forwarded to ``TrainingArguments``.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            logging_steps=logging_steps,
            save_strategy=save_strategy,
            **kwargs
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
        )

        trainer.train()

# Usage example
if __name__ == "__main__":
    base_model = 'llama_base'

    # One configuration per MoE flavour; models are built in this insertion
    # order (sliced, then baseline, then advanced).
    example_configs = {
        'sliced': {
            'num_experts': 8,
            'use_load_balancing': True,
        },
        'baseline': {
            'num_experts': 8,
        },
        'advanced': {
            'num_experts': 8,
            'top_k': 2,
            'capacity_factor': 1.2,
            'use_load_balancing': True,
            'use_z_loss': True,
            'z_loss_coef': 1e-3,
        },
    }

    models = {}
    for flavour, cfg in example_configs.items():
        models[flavour] = MoEFactory.create_moe(
            moe_type=flavour,
            model_path=base_model,
            config=cfg,
        )

    # Persist every model to its own directory.
    save_targets = (
        ('baseline', "saved_baseline_moe"),
        ('advanced', "saved_advanced_moe"),
        ('sliced', "saved_sliced_moe"),
    )
    for flavour, target_dir in save_targets:
        MoEFactory.save_moe(models[flavour], target_dir)

    # Round-trip check: reload the advanced model from disk.
    loaded_model = MoEFactory.load_moe("saved_advanced_moe")
    