import torch
from moe_factory import MoEFactory
from pathlib import Path
import json
from typing import Dict, Any
import tempfile
import shutil
from datasets import Dataset
import numpy as np
from transformers import DataCollatorForLanguageModeling
import os

class MOEFrameworkTester:
    def __init__(self, base_model_path: str):
        """
        Initialize the tester
        
        Args:
            base_model_path: Base model path (e.g., model path)
        """
        self.base_model_path = Path(base_model_path)
        if not self.base_model_path.exists():
            raise ValueError(f"Model path does not exist: {base_model_path}")
        
        # Check for required model files
        required_files = ['tokenizer.model', 'config.json']
        missing_files = [f for f in required_files if not (self.base_model_path / f).exists()]
        if missing_files:
            raise ValueError(f"Model directory is missing required files: {', '.join(missing_files)}")
        
        self.temp_dir = None  # Initialize as None
        self.test_configs = self._get_test_configs()
        
    def _get_test_configs(self) -> Dict[str, Dict[str, Any]]:
        """Return the settings for each tested variant, keyed by variant name."""
        # Dimensions common to both variants
        shared = {'num_experts': 2, 'd_model': 512, 'd_ff': 2048}

        # 'advanced' is the shared dimensions plus routing/regularisation options
        advanced = dict(shared)
        advanced.update(
            top_k=2,
            capacity_factor=1.2,
            use_load_balancing=True,
            use_z_loss=True,
            z_loss_coef=1e-3,
        )

        return {'baseline': dict(shared), 'advanced': advanced}
    
    def _create_dummy_data(self, size: int = 100) -> Dataset:
        """Build a synthetic dataset of `size` rows.

        Each row has a 'text' sentence, 50 random 'input_ids' drawn from
        [0, 1000) and an all-ones 'attention_mask' of the same length.
        """
        tokens_per_row = 50

        texts, token_rows, mask_rows = [], [], []
        for idx in range(size):
            texts.append(f"This is test sentence {idx}")
            # One RNG draw per row, in row order
            token_rows.append(np.random.randint(0, 1000, tokens_per_row).tolist())
            mask_rows.append([1] * tokens_per_row)

        columns = {
            'text': texts,
            'input_ids': token_rows,
            'attention_mask': mask_rows,
        }
        return Dataset.from_dict(columns)
    
    def test_model_creation(self) -> Dict[str, Any]:
        """Test model creation"""
        print("\n1. Testing model creation...")
        results = {}
        
        for moe_type, config in self.test_configs.items():
            try:
                print(f"\nTesting creation of {moe_type} model...")
                model = AnonymizedFactory.create_model(
                    model_type=moe_type,
                    model_path=self.base_model_path,
                    config=config
                )
                results[moe_type] = {
                    'status': 'success',
                    'model': model
                }
                print(f"✅ {moe_type} model creation successful")
            except Exception as e:
                results[moe_type] = {
                    'status': 'failed',
                    'error': str(e)
                }
                print(f"❌ {moe_type} model creation failed: {str(e)}")
        
        return results
    
    def test_forward_pass(self, models: Dict[str, Any]):
        """Test forward pass"""
        print("\n2. Testing forward pass...")
        results = {}
        
        batch_size, seq_len = 4, 16
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    print(f"\nTesting forward pass for {moe_type} model...")
                    model = model_info['model']
                    device = model.device
                    
                    # Create input data
                    input_ids = torch.randint(0, 32000, (batch_size, seq_len)).long().to(device)
                    attention_mask = torch.ones(batch_size, seq_len).long().to(device)
                    
                    # Build inputs
                    inputs = {
                        'input_ids': input_ids,
                        'attention_mask': attention_mask
                    }
                    
                    with torch.no_grad():
                        outputs = model(inputs)
                    
                    # Check output
                    if hasattr(outputs, 'last_hidden_state'):
                        output_shape = outputs.last_hidden_state.shape
                    elif hasattr(outputs, 'logits'):
                        output_shape = outputs.logits.shape
                    else:
                        raise ValueError("Model output format not recognized")
                    
                    results[moe_type] = {
                        'status': 'success',
                        'output_shape': list(output_shape)
                    }
                    print(f"✅ {moe_type} model forward pass successful")
                    print(f"Output shape: {output_shape}")
                    
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model forward pass failed: {str(e)}")
        
        return results
    
    def _setup_temp_dir(self):
        """Set up temporary directory"""
        if self.temp_dir is None:
            self.temp_dir = tempfile.mkdtemp()
        return self.temp_dir
    
    def test_save_load(self, models: Dict[str, Any]):
        """Test model save and load"""
        print("\n3. Testing save and load...")
        results = {}
        
        # Check available space and select appropriate temporary directory
        def get_free_space(path):
            stats = shutil.disk_usage(path)
            return stats.free / (1024 * 1024 * 1024)  # GB
        
        # Try different temporary directory locations
        temp_dirs = ['/anonymized-tmp/']
        selected_dir = None
        
        for dir_path in temp_dirs:
            if os.path.exists(dir_path):
                free_space = get_free_space(dir_path)
                if free_space > 20:
                    selected_dir = dir_path
                    break
        
        if not selected_dir:
            print("Warning: Insufficient storage space, skipping save/load test")
            for moe_type in models:
                results[moe_type] = {
                    'status': 'skipped',
                    'error': 'Insufficient storage space'
                }
            return results
        
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    print(f"\nTesting save and load for {moe_type} model...")
                    model = model_info['model']
                    
                    # Use selected temporary directory
                    with tempfile.TemporaryDirectory(dir=selected_dir) as temp_dir:
                        save_path = os.path.join(temp_dir, f"test_save_{moe_type}")
                        
                        # Save model
                        AnonymizedFactory.save_model(model, save_path)
                        
                        # Release original model memory
                        del model
                        torch.cuda.empty_cache()
                        
                        # Load model
                        loaded_model = AnonymizedFactory.load_model(save_path)
                        
                        # Simple validation
                        input_ids = torch.randint(0, 32000, (2, 8)).to(loaded_model.device)
                        attention_mask = torch.ones(2, 8).to(loaded_model.device)
                        inputs = {
                            'input_ids': input_ids,
                            'attention_mask': attention_mask
                        }
                        
                        with torch.no_grad():
                            outputs = loaded_model(inputs)
                        
                        # Validate output format
                        if not (hasattr(outputs, 'last_hidden_state') or hasattr(outputs, 'logits')):
                            raise ValueError("Model output format not recognized")
                        
                        results[moe_type] = {
                            'status': 'success'
                        }
                        print(f"✅ {moe_type} model save/load test passed")
                        
                        # Clean up loaded model
                        del loaded_model
                        torch.cuda.empty_cache()
                        
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model save/load test failed: {str(e)}")
                    
                    # Ensure memory is cleaned up on error
                    if 'model' in locals():
                        del model
                    if 'loaded_model' in locals():
                        del loaded_model
                    torch.cuda.empty_cache()
        
        return results
    
    def test_training(self, models: Dict[str, Any]):
        """Test training functionality"""
        print("\n4. Testing training functionality...")
        results = {}
        
        # Create simple test dataset
        train_texts = [
            "This is a test sentence.",
            "Another test sentence.",
            "Third test sentence."
        ]
        
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    print(f"\nTesting training for {moe_type} model...")
                    model = model_info['model']
                    
                    # Use tokenizer to process data
                    tokenizer = model.tokenizer
                    encoded = tokenizer(
                        train_texts,
                        padding=True,
                        truncation=True,
                        max_length=16,
                        return_tensors="pt"
                    )
                    
                    # Create simple dataset
                    from torch.utils.data import Dataset
                    
                    class SimpleDataset(Dataset):
                        def __init__(self, encodings):
                            self.encodings = encodings
                        
                        def __getitem__(self, idx):
                            item = {key: val[idx] for key, val in self.encodings.items()}
                            item['labels'] = item['input_ids'].clone()
                            return item
                        
                        def __len__(self):
                            return len(self.encodings['input_ids'])

                    train_dataset = SimpleDataset(encoded)
                    
                    # Use temporary directory
                    with tempfile.TemporaryDirectory() as temp_dir:
                        output_dir = os.path.join(temp_dir, f"test_output_{moe_type}")
                        
                        # Correct: Call supervised_finetuning with correct parameters
                        AnonymizedFactory.supervised_finetuning(
                            model=model,
                            train_dataset=train_dataset,  # Ensure parameter name is correct
                            output_dir=output_dir,
                            num_epochs=1,
                            batch_size=2,
                            learning_rate=1e-5,
                            logging_steps=1,
                            save_strategy="no"
                        )
                        
                        results[moe_type] = {'status': 'success'}
                        print(f"✅ {moe_type} model training test successful")
                        
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model training test failed: {str(e)}")
                    import traceback
                    print(traceback.format_exc())
        
        return results
    
    def test_model_parameters(self, models: Dict[str, Any]):
        """Test model parameter configuration"""
        print("\n5. Testing model parameter configuration...")
        results = {}
        
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    model = model_info['model']
                    config = self.test_configs[moe_type]
                    
                    # Validate key parameters
                    actual_experts = model.num_experts
                    actual_d_model = model.d_model
                    
                    results[moe_type] = {
                        'status': 'success',
                        'parameter_check': {
                            'num_experts': actual_experts == config['num_experts'],
                            'd_model': actual_d_model == config['d_model']
                        }
                    }
                    print(f"✅ {moe_type} model parameter validation passed")
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model parameter validation failed: {str(e)}")
        return results
    
    def check_dependencies(self):
        """Check if necessary dependencies are installed"""
        try:
            import torch
            import transformers
            import sentencepiece
            import datasets
            import numpy
            import tqdm
            
            print("✅ All necessary dependencies are installed")
            return True
        except ImportError as e:
            print(f"❌ Missing necessary dependency: {str(e)}")
            print("\nPlease run the following command to install dependencies:")
            print("pip install -r requirements.txt")
            return False

    def run_all_tests(self):
        """Run all tests"""
        print("Starting comprehensive testing of the framework...")
        
        # First check dependencies
        if not self.check_dependencies():
            return
        
        # 1. Test model creation
        model_results = self.test_model_creation()
        
        # 2. Test forward pass
        forward_results = self.test_forward_pass(model_results)
        
        # 3. Test save and load
        save_load_results = self.test_save_load(model_results)
        
        # 4. Test training functionality
        training_results = self.test_training(model_results)
        
        # 5. Test model parameter configuration
        parameter_results = self.test_model_parameters(model_results)
        
        # 6. Test memory usage
        memory_results = self.test_memory_usage(model_results)
        
        # 7. Test expert usage balance
        expert_balance_results = self.test_expert_balance(model_results)
        
        # 8. Test model performance
        performance_results = self.test_performance(model_results)
        
        # Summarize results
        final_results = {
            'model_creation': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in model_results.items()
            },
            'forward_pass': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in forward_results.items()
            },
            'save_load': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in save_load_results.items()
            },
            'training': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in training_results.items()
            },
            'parameter_check': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in parameter_results.items()
            },
            'memory_usage': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in memory_results.items()
            },
            'expert_balance': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in expert_balance_results.items()
            },
            'performance': {
                k: {
                    'status': v['status'],
                    'error': str(v.get('error', ''))
                } for k, v in performance_results.items()
            }
        }
        
        # Print summary report
        print("\n=== Test Report ===")
        for test_name, results in final_results.items():
            print(f"\n{test_name}:")
            for moe_type, result in results.items():
                status = "✅ Passed" if result['status'] == 'success' else f"❌ Failed: {result.get('error', '')}"
                print(f"{moe_type}: {status}")
        
        # Save results
        with open("test_results.json", "w") as f:
            json.dump(final_results, f, indent=2, default=str)
        
        return final_results

    def test_memory_usage(self, models: Dict[str, Any]):
        """Test memory usage"""
        print("\n6. Testing memory usage...")
        results = {}
        
        batch_size, seq_len = 4, 16
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    model = model_info['model']
                    device = model.device
                    
                    # Create input data
                    input_ids = torch.randint(0, 32000, (batch_size, seq_len)).to(device)
                    attention_mask = torch.ones(batch_size, seq_len).to(device)
                    
                    # Record initial memory
                    torch.cuda.empty_cache()
                    start_mem = torch.cuda.memory_allocated()
                    
                    # Run model
                    with torch.no_grad():
                        _ = model({'input_ids': input_ids, 'attention_mask': attention_mask})
                    
                    # Record peak memory
                    peak_mem = torch.cuda.max_memory_allocated()
                    
                    results[moe_type] = {
                        'status': 'success',
                        'memory_usage': {
                            'start_mem': start_mem,
                            'peak_mem': peak_mem,
                            'diff_mem': peak_mem - start_mem
                        }
                    }
                    print(f"✅ {moe_type} model memory usage test passed")
                    print(f"Memory usage: {(peak_mem - start_mem) / 1024 / 1024:.2f}MB")
                    
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model memory usage test failed: {str(e)}")
        
        return results

    def test_expert_balance(self, models: Dict[str, Any]):
        """Test expert usage balance"""
        print("\n7. Testing expert balance...")
        results = {}
        
        batch_size, seq_len = 4, 16
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    model = model_info['model']
                    device = model.device
                    
                    # Create input data
                    input_ids = torch.randint(0, 32000, (batch_size, seq_len)).to(device)
                    attention_mask = torch.ones(batch_size, seq_len).to(device)
                    
                    # Run multiple forward passes
                    with torch.no_grad():
                        for _ in range(10):
                            _ = model({'input_ids': input_ids, 'attention_mask': attention_mask})
                    
                    # Collect expert usage statistics
                    stats = model.get_model_stats()
                    if 'routing_stats' in stats and stats['routing_stats']:
                        # Calculate usage variance
                        expert_usage = torch.tensor([
                            layer_stats['expert_utilization'] 
                            for layer_stats in stats['routing_stats']
                            if 'expert_utilization' in layer_stats
                        ])
                        usage_variance = expert_usage.var().item() if expert_usage.numel() > 0 else 0
                        
                        results[moe_type] = {
                            'status': 'success',
                            'usage_variance': usage_variance,
                            'is_balanced': usage_variance < 0.1  # Set threshold
                        }
                        print(f"✅ {moe_type} model expert balance test passed")
                        print(f"Usage variance: {usage_variance:.4f}")
                    else:
                        results[moe_type] = {
                            'status': 'failed',
                            'error': 'No routing statistics available'
                        }
                        print(f"❌ {moe_type} model expert balance test failed: No routing statistics available")
                
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model expert balance test failed: {str(e)}")
        
        return results

    def test_performance(self, models: Dict[str, Any]):
        """Test model performance"""
        print("\n8. Testing model performance...")
        results = {}
        
        batch_size, seq_len = 4, 16
        num_iterations = 100
        
        for moe_type, model_info in models.items():
            if model_info['status'] == 'success':
                try:
                    model = model_info['model']
                    device = model.device
                    
                    # Create input data
                    input_ids = torch.randint(0, 32000, (batch_size, seq_len)).to(device)
                    attention_mask = torch.ones(batch_size, seq_len).to(device)
                    inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
                    
                    # Warm-up
                    with torch.no_grad():
                        for _ in range(10):
                            _ = model(inputs)
                    
                    # Timing
                    start_time = torch.cuda.Event(enable_timing=True)
                    end_time = torch.cuda.Event(enable_timing=True)
                    
                    start_time.record()
                    with torch.no_grad():
                        for _ in range(num_iterations):
                            _ = model(inputs)
                    end_time.record()
                    
                    torch.cuda.synchronize()
                    elapsed_time = start_time.elapsed_time(end_time)
                    avg_time = elapsed_time / num_iterations
                    
                    results[moe_type] = {
                        'status': 'success',
                        'avg_inference_time': avg_time,
                        'throughput': batch_size * num_iterations / (elapsed_time / 1000)  # tokens/sec
                    }
                    print(f"✅ {moe_type} model performance test passed")
                    print(f"Average inference time: {avg_time:.2f}ms")
                    print(f"Throughput: {results[moe_type]['throughput']:.2f} tokens/sec")
                    
                except Exception as e:
                    results[moe_type] = {
                        'status': 'failed',
                        'error': str(e)
                    }
                    print(f"❌ {moe_type} model performance test failed: {str(e)}")
        
        return results

    def __del__(self):
        """Clean up temporary files"""
        try:
            if hasattr(self, 'temp_dir') and self.temp_dir is not None:
                shutil.rmtree(self.temp_dir)
                self.temp_dir = None
        except Exception as e:
            print(f"Error cleaning up temporary files: {str(e)}")

# Usage example
if __name__ == "__main__":
    try:
        # MOEFrameworkTester is the tester class defined in this file
        tester = MOEFrameworkTester(base_model_path="/anonymized-tmp/model")

        # run_all_tests() prints the report and writes test_results.json itself.
        # It returns None when the dependency check fails, in which case
        # nothing is written (previously the file was overwritten with "null").
        results = tester.run_all_tests()
        if results is None:
            print("❌ Tests were not run: dependency check failed")
    except Exception as e:
        print(f"❌ Error during testing: {str(e)}")