"""
Main experiment runner for Multi-Scale Attention U-Net
This is the entrypoint script for reproducing all results
"""

import torch
import torch.nn as nn
import numpy as np
import json
import os
import time
from datetime import datetime
from typing import Dict, List, Tuple
import argparse
import os
import random
import numpy as np

try:
    import torch
except Exception:
    torch = None


def set_reproducibility(seed: int = 42) -> None:
    """Seed every RNG this script uses and request deterministic cuDNN.

    Python's ``random`` module and NumPy's global generator are always seeded.
    When ``torch`` is available (the module-level fallback import may leave it
    as ``None``), torch is seeded on CPU and on all CUDA devices. cuDNN is then
    switched to deterministic kernels with autotuning (benchmark mode) off.

    Args:
        seed: Seed value shared by all generators.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    if torch is None:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
import random

# Set random seeds for reproducibility
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Unlike ``set_reproducibility``, this does not check ``torch is None``. It
    uses ``torch`` directly, seeding the CPU generator, the current CUDA device
    and then every CUDA device.

    Args:
        seed: Value handed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda = torch.cuda
    cuda.manual_seed(seed)
    cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Import our modules
from dataset import create_dataloaders, DatasetConfig
from model import MSAUNet, BaselineUNet, count_parameters
from trainer import ModelTrainer, create_trainer_config
from metrics import compute_model_efficiency
from visualization import SegmentationVisualizer, create_visualization_report
from losses import create_loss_function

class ExperimentRunner:
    """Main experiment runner class.

    Drives the experiment suite: baseline training, a ``num_heads`` ablation,
    an efficiency benchmark and a final test-set evaluation. Checkpoints,
    training histories, figures and the combined ``metrics.json`` are written
    below a timestamped ``results_<YYYYmmdd_HHMMSS>`` directory that is
    created relative to the current working directory.
    """

    def __init__(self, config: Dict):
        """Store *config*, select the device and create the output directories.

        Args:
            config: Experiment options. Keys read by this class are
                ``num_epochs``, ``ablation_epochs``, ``batch_size`` and
                ``learning_rate``; each falls back to a default when absent.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.results = {}
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = f"results_{self.timestamp}"

        # Create output directory tree: <root>/, <root>/figures, <root>/checkpoints
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(os.path.join(self.output_dir, 'figures'), exist_ok=True)
        os.makedirs(os.path.join(self.output_dir, 'checkpoints'), exist_ok=True)

        print(f"Experiment started at {datetime.now()}")
        print(f"Device: {self.device}")
        print(f"Output directory: {self.output_dir}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _slug(name: str) -> str:
        """Turn a display name such as 'MSA-UNet' into a path-safe 'msa_unet'."""
        return name.lower().replace('-', '_')

    @staticmethod
    def _print_header(title: str) -> None:
        """Print *title* framed by 60-character '=' rules."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    @staticmethod
    def _json_default(obj):
        """Fallback for ``json.dump`` covering array/scalar types it rejects.

        NumPy scalars (e.g. ``np.float32``), NumPy arrays and torch tensors all
        expose ``tolist()``, which returns plain Python numbers or lists. Any
        other object raises ``TypeError``, just as ``json`` itself would.
        """
        tolist = getattr(obj, 'tolist', None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _create_loaders(self):
        """Return (train, val, test) loaders built with the configured batch size."""
        return create_dataloaders(
            DatasetConfig(),
            batch_size=self.config.get('batch_size', 16)
        )

    def _build_trainer_config(self, num_epochs: int) -> Dict:
        """Return the default trainer config overridden with this run's settings."""
        trainer_config = create_trainer_config()
        trainer_config.update({
            'num_epochs': num_epochs,
            'learning_rate': self.config.get('learning_rate', 0.001),
            'batch_size': self.config.get('batch_size', 16)
        })
        return trainer_config

    def _train_and_evaluate(self, model, name: str, loaders, num_epochs: int):
        """Train *model*, checkpointing under ``checkpoints/<slug>``, then test it.

        Returns:
            Tuple ``(trainer, history, test_metrics)``.
        """
        train_loader, val_loader, test_loader = loaders
        trainer_config = self._build_trainer_config(num_epochs)
        trainer = ModelTrainer(
            model, train_loader, val_loader, test_loader,
            self.device, trainer_config
        )
        history = trainer.train(
            num_epochs=trainer_config['num_epochs'],
            save_dir=os.path.join(self.output_dir, 'checkpoints', self._slug(name))
        )
        test_metrics = trainer.evaluate_model()
        return trainer, history, test_metrics

    # ------------------------------------------------------------------
    # Experiment stages
    # ------------------------------------------------------------------

    def run_baseline_experiments(self) -> Dict[str, Dict[str, float]]:
        """Train and test the plain U-Net and the 4-head MSA-UNet.

        For each model, this writes checkpoints to ``checkpoints/<slug>``, the
        training history to ``<slug>_history.json`` and curves to
        ``figures/<slug>_training_curves.png``.

        Returns:
            Mapping from model display name to the metrics returned by
            ``ModelTrainer.evaluate_model``.
        """
        self._print_header("RUNNING BASELINE EXPERIMENTS")

        baseline_results = {}
        loaders = self._create_loaders()

        # Both models are built up front, so both weight initialisations draw
        # from the global RNG before any training happens.
        baseline_models = {
            'U-Net': BaselineUNet(in_channels=3, num_classes=5),
            'MSA-UNet': MSAUNet(in_channels=3, num_classes=5, num_heads=4)
        }

        for model_name, model in baseline_models.items():
            print(f"\nTraining {model_name}...")
            slug = self._slug(model_name)

            trainer, history, test_metrics = self._train_and_evaluate(
                model, model_name, loaders, self.config.get('num_epochs', 50)
            )
            baseline_results[model_name] = test_metrics

            # History may contain NumPy/torch values; _json_default converts them.
            with open(os.path.join(self.output_dir, f'{slug}_history.json'), 'w') as f:
                json.dump(history, f, indent=2, default=self._json_default)

            trainer.plot_training_history(
                save_path=os.path.join(self.output_dir, 'figures', f'{slug}_training_curves.png')
            )

            print(f"{model_name} Results:")
            print(f"  Dice Score: {test_metrics['mean_dice']:.4f}")
            print(f"  IoU Score: {test_metrics['mean_iou']:.4f}")
            print(f"  Hausdorff Distance: {test_metrics['mean_hausdorff']:.4f}")
            print(f"  Boundary F1: {test_metrics['mean_boundary_f1']:.4f}")
            print(f"  Parameters: {count_parameters(model):,}")

        return baseline_results

    def run_ablation_studies(self) -> Dict[str, Dict[str, float]]:
        """Train and test MSA-UNet with 1, 2, 4 and 8 attention heads.

        Returns:
            Mapping from configuration name to its test metrics.
        """
        self._print_header("RUNNING ABLATION STUDIES")

        ablation_results = {}
        loaders = self._create_loaders()

        ablation_configs = {
            'MSA-UNet-1Head': {'num_heads': 1},
            'MSA-UNet-2Heads': {'num_heads': 2},
            'MSA-UNet-4Heads': {'num_heads': 4},
            'MSA-UNet-8Heads': {'num_heads': 8}
        }

        for config_name, model_config in ablation_configs.items():
            print(f"\nTesting {config_name}...")

            model = MSAUNet(
                in_channels=3,
                num_classes=5,
                num_heads=model_config['num_heads']
            )

            _, _, test_metrics = self._train_and_evaluate(
                model, config_name, loaders, self.config.get('ablation_epochs', 30)
            )
            ablation_results[config_name] = test_metrics

            print(f"{config_name} Results:")
            print(f"  Dice Score: {test_metrics['mean_dice']:.4f}")
            print(f"  Parameters: {count_parameters(model):,}")

        return ablation_results

    def run_efficiency_analysis(self) -> Dict[str, Dict[str, float]]:
        """Benchmark freshly initialised models on a (3, 512, 512) input.

        Returns:
            Mapping from model name to the dict returned by
            ``compute_model_efficiency``.
        """
        self._print_header("RUNNING EFFICIENCY ANALYSIS")

        efficiency_results = {}

        models = {
            'U-Net': BaselineUNet(in_channels=3, num_classes=5),
            'MSA-UNet': MSAUNet(in_channels=3, num_classes=5, num_heads=4)
        }

        for model_name, model in models.items():
            print(f"\nAnalyzing {model_name} efficiency...")

            efficiency = compute_model_efficiency(
                model, (3, 512, 512), self.device, num_runs=100
            )
            efficiency_results[model_name] = efficiency

            print(f"{model_name} Efficiency:")
            print(f"  Inference Time: {efficiency['inference_time_ms']:.2f} ms")
            print(f"  Memory Usage: {efficiency['memory_usage_mb']:.2f} MB")
            print(f"  Parameters: {efficiency['num_parameters']:,}")
            print(f"  FLOPs: {efficiency['flops']:,}")
            print(f"  FPS: {efficiency['fps']:.2f}")

        return efficiency_results

    def run_comprehensive_evaluation(self) -> Dict[str, float]:
        """Reload the best baseline MSA-UNet checkpoint and evaluate it on the test set.

        Returns:
            The evaluator's metrics dict, or ``{}`` when no checkpoint exists.
        """
        self._print_header("RUNNING COMPREHENSIVE EVALUATION")

        _, _, test_loader = self._create_loaders()

        # Same directory naming as run_baseline_experiments uses for 'MSA-UNet'.
        best_model_path = os.path.join(
            self.output_dir, 'checkpoints', self._slug('MSA-UNet'), 'best_model.pth'
        )
        if not os.path.exists(best_model_path):
            print("Best model not found. Please run baseline experiments first.")
            return {}

        model = MSAUNet(in_channels=3, num_classes=5, num_heads=4)
        checkpoint = torch.load(best_model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)
        # Switch to inference mode before evaluating; eval() is idempotent, so
        # this is harmless if ModelEvaluator also sets it.
        model.eval()

        from metrics import ModelEvaluator
        evaluator = ModelEvaluator(model, self.device, 5)
        test_metrics = evaluator.evaluate(test_loader)

        print("Comprehensive Test Results:")
        print(f"  Mean Dice Score: {test_metrics['mean_dice']:.4f}")
        print(f"  Mean IoU Score: {test_metrics['mean_iou']:.4f}")
        print(f"  Mean Hausdorff Distance: {test_metrics['mean_hausdorff']:.4f}")
        print(f"  Mean Boundary F1: {test_metrics['mean_boundary_f1']:.4f}")
        print(f"  Pixel Accuracy: {test_metrics['pixel_accuracy']:.4f}")

        return test_metrics

    def generate_visualizations(self, results: Dict[str, Dict[str, float]]):
        """Save a metrics-comparison figure and an MSA-UNet results summary.

        Args:
            results: Mapping from model name to metrics (e.g. the baseline results).
        """
        # Imported locally: the script-level matplotlib import only runs under
        # __main__, so relying on a global `plt` breaks programmatic use.
        import matplotlib.pyplot as plt

        self._print_header("GENERATING VISUALIZATIONS")

        class_names = ['Heart', 'Liver', 'Kidney', 'Lung', 'Brain']
        visualizer = SegmentationVisualizer(5, class_names)
        figures_dir = os.path.join(self.output_dir, 'figures')

        if results:
            comparison_fig = visualizer.plot_metrics_comparison(results)
            comparison_fig.savefig(
                os.path.join(figures_dir, 'metrics_comparison.png'),
                dpi=300, bbox_inches='tight'
            )
            plt.close(comparison_fig)

        if 'MSA-UNet' in results:
            summary_fig = visualizer.create_results_summary(results['MSA-UNet'])
            summary_fig.savefig(
                os.path.join(figures_dir, 'results_summary.png'),
                dpi=300, bbox_inches='tight'
            )
            plt.close(summary_fig)

        print("Visualizations generated successfully!")

    def save_results(self, results: Dict):
        """Write *results* to ``<output_dir>/metrics.json``.

        NumPy scalars/arrays and torch tensors are converted to plain Python
        values; any other non-JSON type still raises ``TypeError``.
        """
        results_file = os.path.join(self.output_dir, 'metrics.json')
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=2, default=self._json_default)

        print(f"Results saved to {results_file}")

    def run_all_experiments(self) -> Dict:
        """Run every stage, save the combined results, plot them and print a summary."""
        print("Starting comprehensive experiment suite...")

        set_seed(42)

        baseline_results = self.run_baseline_experiments()
        ablation_results = self.run_ablation_studies()
        efficiency_results = self.run_efficiency_analysis()
        test_results = self.run_comprehensive_evaluation()

        all_results = {
            'baseline_results': baseline_results,
            'ablation_results': ablation_results,
            'efficiency_results': efficiency_results,
            'test_results': test_results,
            'experiment_config': self.config,
            'timestamp': self.timestamp
        }

        self.save_results(all_results)
        self.generate_visualizations(baseline_results)
        self.print_summary(all_results)

        return all_results

    def print_summary(self, results: Dict):
        """Print baseline and efficiency highlights from the combined *results* dict."""
        self._print_header("EXPERIMENT SUMMARY")

        if 'baseline_results' in results:
            baseline = results['baseline_results']
            print("\nBaseline Results:")
            for model_name, metrics in baseline.items():
                print(f"  {model_name}:")
                print(f"    Dice Score: {metrics['mean_dice']:.4f}")
                print(f"    IoU Score: {metrics['mean_iou']:.4f}")
                print(f"    Hausdorff Distance: {metrics['mean_hausdorff']:.4f}")

        if 'efficiency_results' in results:
            efficiency = results['efficiency_results']
            print("\nEfficiency Results:")
            for model_name, metrics in efficiency.items():
                print(f"  {model_name}:")
                print(f"    Inference Time: {metrics['inference_time_ms']:.2f} ms")
                print(f"    Memory Usage: {metrics['memory_usage_mb']:.2f} MB")
                print(f"    Parameters: {metrics['num_parameters']:,}")

        print(f"\nAll results saved to: {self.output_dir}")
        print("Experiment completed successfully!")

def main(argv=None):
    """Parse command-line options and run the full experiment suite.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps normal CLI behaviour;
            passing a list lets callers drive ``main`` programmatically.

    Returns:
        The combined results dict returned by
        ``ExperimentRunner.run_all_experiments``.
    """
    # Seed all RNGs before any experiment objects are created.
    set_reproducibility(42)

    parser = argparse.ArgumentParser(description='Run MSA-UNet experiments')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--ablation_epochs', type=int, default=30, help='Epochs for ablation studies')
    parser.add_argument('--quick', action='store_true', help='Run quick experiments (fewer epochs)')

    args = parser.parse_args(argv)

    config = {
        'num_epochs': args.epochs,
        'batch_size': args.batch_size,
        'learning_rate': args.learning_rate,
        'ablation_epochs': args.ablation_epochs
    }

    # --quick takes precedence over any explicit --epochs / --ablation_epochs.
    if args.quick:
        config['num_epochs'] = 10
        config['ablation_epochs'] = 5
        print("Running quick experiments...")

    runner = ExperimentRunner(config)
    results = runner.run_all_experiments()

    return results

if __name__ == "__main__":
    # matplotlib is only needed when running as a script. Binding `plt` at
    # module scope here makes it available to the figure-generation code.
    import matplotlib.pyplot as plt

    results = main()

    for closing_line in ("\nExperiment completed successfully!",
                         "Check the results directory for all outputs."):
        print(closing_line)

