"""
Comprehensive Experiment Configurations and Utilities for HPC

This module provides experiment configurations, management utilities,
and result analysis tools for reproducing the experiments in the HPC paper.
"""

import argparse
import glob
import json
import os
import subprocess
import sys
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Dict, List, Any, Optional

import numpy as np
import torch
import yaml


@dataclass
class HPCExperimentConfig:
    """Settings bundle describing a single HPC experiment.

    Every field has a default, so ``HPCExperimentConfig()`` is a valid
    configuration. ``baseline_methods`` is filled with a fresh copy of the
    built-in baseline list whenever it is left as ``None``.
    """

    # Built-in baselines. Deliberately unannotated so the dataclass machinery
    # does not turn it into a field (it never appears in ``asdict``).
    _DEFAULT_BASELINE_METHODS = (
        "temperature_scaling",
        "vector_scaling",
        "histogram_binning",
        "isotonic_regression",
    )

    # --- Dataset ---
    dataset: str = "cifar10h"
    data_root: str = "./data"
    human_annotations_path: Optional[str] = None

    # --- Model ---
    model_name: str = "resnet50"
    model_path: Optional[str] = None
    pretrained: bool = True

    # --- HPC method ---
    hpc_method: str = "empirical"  # one of: empirical, clip, dino, adaptive
    alpha: float = 0.3
    temperature: float = 1.0
    use_adaptive_alpha: bool = False
    gating_strategy: str = "confidence_based"

    # --- Training (only relevant when training from scratch) ---
    epochs: int = 100
    batch_size: int = 128
    learning_rate: float = 0.001
    weight_decay: float = 5e-4

    # --- Evaluation ---
    eval_batch_size: int = 256
    max_eval_batches: Optional[int] = None
    baseline_methods: Optional[List[str]] = None  # None -> built-in list, see __post_init__

    # --- Output ---
    results_dir: str = "./results"
    save_plots: bool = True
    save_models: bool = False

    # --- Reproducibility ---
    seed: int = 42
    device: str = "auto"

    def __post_init__(self):
        """Substitute the built-in baseline list when none was supplied."""
        # An explicitly passed list (even an empty one) is kept untouched.
        if self.baseline_methods is not None:
            return
        self.baseline_methods = list(self._DEFAULT_BASELINE_METHODS)


# Predefined experiment configurations from the paper, written as overrides of
# the HPCExperimentConfig defaults. Every preset stores its output under
# "./results/<preset name>", so that path is derived from the key below.
_PRESET_OVERRIDES = {
    "cifar10h_main": dict(
        dataset="cifar10h",
        hpc_method="empirical",
        alpha=0.3,
        temperature=1.0,
    ),
    "cifar10h_clip": dict(
        dataset="cifar10h",
        hpc_method="clip",
        alpha=0.25,
        temperature=1.0,
    ),
    "cifar10h_adaptive": dict(
        dataset="cifar10h",
        hpc_method="adaptive",
        use_adaptive_alpha=True,
        gating_strategy="uncertainty_based",
    ),
    "cifar100_proxy": dict(
        dataset="cifar100",
        hpc_method="clip",
        alpha=0.2,
        temperature=1.5,
    ),
    "imagenet_scalability": dict(
        dataset="imagenet",
        hpc_method="clip",
        alpha=0.15,
        temperature=2.0,
        eval_batch_size=64,
        max_eval_batches=100,  # For computational efficiency
    ),
    "ablation_alpha": dict(
        dataset="cifar10h",
        hpc_method="empirical",
        alpha=0.5,  # Will be varied in script
    ),
    "robustness_corruption": dict(
        dataset="cifar10h",
        hpc_method="empirical",
        alpha=0.3,
    ),
}

EXPERIMENT_CONFIGS = {
    preset_name: HPCExperimentConfig(
        results_dir=f"./results/{preset_name}",
        **overrides,
    )
    for preset_name, overrides in _PRESET_OVERRIDES.items()
}


class ExperimentManager:
    """
    Manages experiment execution, logging, and result collection.

    Each run gets its own directory under ``base_results_dir`` holding the
    saved configuration (JSON and YAML) and the captured stdout/stderr of the
    evaluation subprocess.
    """

    def __init__(self, base_results_dir: str = "./experiments"):
        """Remember the base directory and create it if it does not exist."""
        self.base_results_dir = base_results_dir
        os.makedirs(base_results_dir, exist_ok=True)

    def create_experiment_dir(self, experiment_name: str) -> str:
        """Create a new, unique timestamped experiment directory.

        The directory is named ``<experiment_name>_<YYYYmmdd_HHMMSS>``. The
        timestamp only has one-second resolution, so when that name is already
        taken a numeric suffix (``_1``, ``_2``, ...) is appended instead of
        silently reusing the directory of an earlier run.

        Args:
            experiment_name: Prefix for the directory name.

        Returns:
            Path of the newly created directory.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_dir = os.path.join(self.base_results_dir, f"{experiment_name}_{timestamp}")

        candidate = base_dir
        attempt = 0
        while True:
            try:
                # No exist_ok: an existing directory belongs to another run.
                os.makedirs(candidate)
            except FileExistsError:
                attempt += 1
                candidate = f"{base_dir}_{attempt}"
            else:
                return candidate

    def save_config(self, config: "HPCExperimentConfig", exp_dir: str):
        """Write ``config`` to ``config.json`` and ``config.yaml`` in ``exp_dir``."""
        config_fields = asdict(config)

        config_path = os.path.join(exp_dir, "config.json")
        with open(config_path, 'w') as f:
            json.dump(config_fields, f, indent=2)

        # Also save as YAML for human readability
        yaml_path = os.path.join(exp_dir, "config.yaml")
        with open(yaml_path, 'w') as f:
            yaml.dump(config_fields, f, default_flow_style=False)

    def run_experiment(
        self,
        experiment_name: str,
        config: Optional["HPCExperimentConfig"] = None,
        custom_args: Optional[List[str]] = None
    ) -> str:
        """
        Run a single experiment.

        Launches ``evaluate_hpc.py`` as a subprocess, stores its stdout and
        stderr as log files in the experiment directory, and prints a
        success/failure line. Failures are reported on stdout only; no
        exception is raised and the directory path is returned either way.

        Args:
            experiment_name: Name of experiment. When ``config`` is omitted the
                matching preset from ``EXPERIMENT_CONFIGS`` is used, falling
                back to the default configuration for unknown names.
            config: Optional custom configuration. Its ``results_dir`` is set
                to the created experiment directory.
            custom_args: Additional command line arguments

        Returns:
            Path to experiment results directory
        """
        # Get or create config
        if config is None:
            preset = EXPERIMENT_CONFIGS.get(experiment_name)
            if preset is not None:
                # Work on a copy: results_dir is overwritten below and the
                # module-level preset must not change from one run to the next.
                config = type(preset)(**asdict(preset))
            else:
                config = HPCExperimentConfig()

        # Create experiment directory
        exp_dir = self.create_experiment_dir(experiment_name)
        config.results_dir = exp_dir

        # Save configuration
        self.save_config(config, exp_dir)

        # Build command. The current interpreter is used so the child runs in
        # the same environment; "python" is only a fallback for the rare case
        # where the interpreter path is unknown.
        # NOTE(review): hpc_method, alpha, temperature and the gating options
        # are not forwarded on the command line. Confirm evaluate_hpc.py picks
        # them up some other way (e.g. from the saved config), otherwise the
        # alpha/method ablations all execute the same command.
        cmd = [
            sys.executable or "python", "evaluate_hpc.py",
            "--dataset", config.dataset,
            "--batch_size", str(config.eval_batch_size),
            "--results_dir", exp_dir,
            "--device", config.device
        ]

        # Add optional arguments
        if config.model_path:
            cmd.extend(["--model_path", config.model_path])
        if config.human_annotations_path:
            cmd.extend(["--human_annotations", config.human_annotations_path])
        if config.max_eval_batches:
            cmd.extend(["--max_batches", str(config.max_eval_batches)])

        # Add custom arguments
        if custom_args:
            cmd.extend(custom_args)

        print(f"Running experiment: {experiment_name}")
        print(f"Command: {' '.join(cmd)}")
        print(f"Results will be saved to: {exp_dir}")

        # Execute experiment
        try:
            # errors="replace" keeps undecodable child output from aborting
            # the run before the logs are written.
            result = subprocess.run(
                cmd, capture_output=True, text=True, errors="replace", cwd="."
            )

            # Save logs (UTF-8 so non-ASCII output does not depend on the locale)
            with open(os.path.join(exp_dir, "stdout.log"), 'w', encoding="utf-8") as f:
                f.write(result.stdout)
            with open(os.path.join(exp_dir, "stderr.log"), 'w', encoding="utf-8") as f:
                f.write(result.stderr)

            if result.returncode == 0:
                print(f"✓ Experiment {experiment_name} completed successfully")
            else:
                print(f"✗ Experiment {experiment_name} failed with return code {result.returncode}")
                print(f"Error: {result.stderr}")

        except Exception as e:
            # Best-effort by design: a failed launch must not abort a suite.
            print(f"✗ Failed to run experiment {experiment_name}: {e}")

        return exp_dir

    def run_experiment_suite(
        self,
        suite_name: str = "main_experiments",
        experiments: Optional[List[str]] = None
    ) -> Dict[str, str]:
        """
        Run a suite of experiments.

        Experiments run sequentially. Afterwards a ``suite_summary.json``
        mapping experiment names to their result directories is written to a
        dedicated ``suite_<suite_name>_<timestamp>`` directory.

        Args:
            suite_name: Name for the experiment suite
            experiments: List of experiment names to run

        Returns:
            Dictionary mapping experiment names to result directories
        """
        if experiments is None:
            # Default main experiments from paper
            experiments = [
                "cifar10h_main", "cifar10h_clip", "cifar10h_adaptive",
                "cifar100_proxy"
            ]

        print(f"Running experiment suite: {suite_name}")
        print(f"Experiments: {experiments}")

        results = {}

        for exp_name in experiments:
            results[exp_name] = self.run_experiment(exp_name)
            print(f"Completed {exp_name}")

        # Create suite summary
        suite_dir = self.create_experiment_dir(f"suite_{suite_name}")
        summary_path = os.path.join(suite_dir, "suite_summary.json")

        suite_summary = {
            "suite_name": suite_name,
            "timestamp": datetime.now().isoformat(),
            "experiments": results
        }

        with open(summary_path, 'w') as f:
            json.dump(suite_summary, f, indent=2)

        print(f"✓ Experiment suite {suite_name} completed")
        print(f"Summary saved to: {summary_path}")

        return results


class AblationStudyRunner:
    """
    Specialized runner for ablation studies.

    Each study runs one experiment per setting through the supplied
    ``ExperimentManager`` and then writes a JSON summary that maps the
    experiment names to their result directories.
    """

    def __init__(self, manager: "ExperimentManager"):
        self.manager = manager

    @staticmethod
    def _alpha_label(alpha: float) -> str:
        """Return the text used for ``alpha`` in an experiment name.

        One decimal place is used when that represents the value exactly
        (0.3 -> "0.3"); otherwise the full representation is used so that
        distinct values such as 0.2 and 0.25 never share a name.
        """
        label = f"{alpha:.1f}"
        if float(label) != alpha:
            label = repr(float(alpha))
        return label

    def run_alpha_ablation(
        self,
        alpha_values: Optional[List[float]] = None,
        dataset: str = "cifar10h"
    ) -> str:
        """Run ablation study over alpha values.

        Args:
            alpha_values: Alpha settings to evaluate. Defaults to
                0.0, 0.1, ..., 1.0.
            dataset: Dataset name placed in every configuration.

        Returns:
            Path of the directory containing ``ablation_summary.json``.
        """
        if alpha_values is None:
            alpha_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

        print(f"Running alpha ablation study with values: {alpha_values}")

        results = {}
        preset = EXPERIMENT_CONFIGS.get("ablation_alpha", HPCExperimentConfig())
        # Copy before editing: the preset lives in the module-level registry
        # and must keep its own dataset for later users.
        base_config = HPCExperimentConfig(**asdict(preset))
        base_config.dataset = dataset

        for alpha in alpha_values:
            exp_name = f"alpha_{self._alpha_label(alpha)}"
            config = HPCExperimentConfig(**asdict(base_config))
            config.alpha = alpha

            exp_dir = self.manager.run_experiment(exp_name, config)
            results[exp_name] = exp_dir

        # Create ablation summary
        ablation_dir = self.manager.create_experiment_dir("alpha_ablation_summary")
        summary = {
            "ablation_type": "alpha",
            "dataset": dataset,
            "alpha_values": alpha_values,
            "results": results,
            "timestamp": datetime.now().isoformat()
        }

        with open(os.path.join(ablation_dir, "ablation_summary.json"), 'w') as f:
            json.dump(summary, f, indent=2)

        return ablation_dir

    def run_method_comparison(
        self,
        methods: Optional[List[str]] = None,
        dataset: str = "cifar10h"
    ) -> str:
        """Compare different HPC methods.

        Args:
            methods: Values for ``hpc_method`` to evaluate. Defaults to
                empirical, clip, dino and adaptive.
            dataset: Dataset name placed in every configuration.

        Returns:
            Path of the directory containing ``comparison_summary.json``.
        """
        if methods is None:
            methods = ["empirical", "clip", "dino", "adaptive"]

        print(f"Running method comparison with: {methods}")

        results = {}
        base_config = HPCExperimentConfig(dataset=dataset)

        for method in methods:
            exp_name = f"method_{method}"
            config = HPCExperimentConfig(**asdict(base_config))
            config.hpc_method = method

            # The adaptive method additionally switches on adaptive alpha
            # with uncertainty-based gating.
            if method == "adaptive":
                config.use_adaptive_alpha = True
                config.gating_strategy = "uncertainty_based"

            exp_dir = self.manager.run_experiment(exp_name, config)
            results[exp_name] = exp_dir

        # Create comparison summary
        comparison_dir = self.manager.create_experiment_dir("method_comparison_summary")
        summary = {
            "comparison_type": "methods",
            "dataset": dataset,
            "methods": methods,
            "results": results,
            "timestamp": datetime.now().isoformat()
        }

        with open(os.path.join(comparison_dir, "comparison_summary.json"), 'w') as f:
            json.dump(summary, f, indent=2)

        return comparison_dir


class ResultAnalyzer:
    """
    Analyzes and aggregates experiment results.

    Result dictionaries are expected to hold per-method metric dictionaries
    under the ``hpc_methods`` and ``baseline_methods`` keys; other keys are
    ignored.
    """

    def __init__(self):
        pass

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Return True for int/float metric values (bool is not a metric)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _parse_parameter_value(exp_name: str, parameter_name: str) -> Optional[float]:
        """Extract the numeric value following ``<parameter_name>_`` in a name.

        Only the token up to the next underscore is parsed. Experiment
        directories carry a ``_<date>_<time>`` suffix and ``float()`` accepts
        underscores between digits, so parsing the whole remainder would
        silently fold the timestamp into the value.

        Returns:
            The parsed value, or ``None`` when the name does not contain the
            parameter or the token is not a number.
        """
        _, marker, remainder = exp_name.partition(f"{parameter_name}_")
        if not marker:
            return None
        token = remainder.split("_", 1)[0]
        try:
            return float(token)
        except ValueError:
            return None

    def collect_results(self, experiment_dirs: List[str]) -> Dict[str, Dict]:
        """Collect results from multiple experiment directories.

        For each directory the first ``results_*.json`` file (in sorted
        order) is loaded. Directories without such a file are reported on
        stdout and skipped.

        Args:
            experiment_dirs: Experiment directory paths.

        Returns:
            Mapping from directory name to the loaded results.
        """
        all_results = {}

        for exp_dir in experiment_dirs:
            # normpath drops a trailing separator, which would otherwise make
            # basename() return an empty name.
            exp_name = os.path.basename(os.path.normpath(exp_dir))

            # Look for results JSON files. The directory part is escaped so
            # glob metacharacters in the path are taken literally; sorting
            # makes the choice deterministic.
            pattern = os.path.join(glob.escape(exp_dir), "results_*.json")
            result_files = sorted(glob.glob(pattern))

            if not result_files:
                print(f"No results found in {exp_dir}")
                continue

            with open(result_files[0], 'r') as f:
                all_results[exp_name] = json.load(f)

        return all_results

    def create_comparison_table(
        self,
        results: Dict[str, Dict],
        metrics: List[str] = None,
        output_path: Optional[str] = None
    ) -> str:
        """Create comparison table from results.

        One row is produced per ``<experiment>_<method>`` pair found under
        ``hpc_methods`` and ``baseline_methods``. Missing metrics are shown
        as ``N/A``; numeric values use four decimal places.

        Args:
            results: Mapping from experiment name to its results dictionary.
            metrics: Metric names used as columns. Defaults to accuracy,
                ece, nll_true, nll_human and aurc.
            output_path: When given, the table is also written to this file.

        Returns:
            The formatted table as a string.
        """
        if metrics is None:
            metrics = ['accuracy', 'ece', 'nll_true', 'nll_human', 'aurc']

        # Extract HPC and baseline results. Only method rows are added: an
        # empty per-experiment placeholder has no metrics to render.
        table_data = {}

        for exp_name, exp_results in results.items():
            for section in ('hpc_methods', 'baseline_methods'):
                for method, method_results in (exp_results.get(section) or {}).items():
                    full_name = f"{exp_name}_{method}"
                    table_data[full_name] = {m: method_results.get(m, 'N/A') for m in metrics}

        # Create formatted table
        header = "Method".ljust(30) + "".join([m.upper().ljust(12) for m in metrics])
        lines = [header, "=" * 100]

        for method_name, method_results in table_data.items():
            row = method_name.ljust(30)
            for metric in metrics:
                value = method_results.get(metric, 'N/A')
                if self._is_number(value):
                    row += f"{value:.4f}".ljust(12)
                else:
                    row += str(value).ljust(12)
            lines.append(row)

        table = "\n".join(lines) + "\n"

        if output_path:
            with open(output_path, 'w') as f:
                f.write(table)

        return table

    def plot_ablation_results(
        self,
        ablation_results: Dict[str, Dict],
        parameter_name: str,
        output_dir: str
    ):
        """Plot ablation study results.

        For each of ``ece``, ``nll_human`` and ``accuracy`` that has data, one
        PNG with a curve per HPC method is written to ``output_dir``. The
        parameter value of an experiment is read from its name
        (``<parameter_name>_<value>...``); experiments whose name carries no
        parsable value are ignored. Plotting is skipped with a message when
        matplotlib cannot be imported.

        Args:
            ablation_results: Mapping from experiment name to its results.
            parameter_name: Name of the varied parameter, e.g. ``"alpha"``.
            output_dir: Directory receiving the plot files.
        """
        # method key -> metric -> list of (parameter value, metric value)
        metrics_data = {}

        for exp_name, results in ablation_results.items():
            param_val = self._parse_parameter_value(exp_name, parameter_name)
            if param_val is None:
                continue

            # Extract metrics from HPC results
            for method, method_results in (results.get('hpc_methods') or {}).items():
                per_method = metrics_data.setdefault(f"HPC_{method}", {})
                for metric, value in method_results.items():
                    if self._is_number(value):
                        per_method.setdefault(metric, []).append((param_val, value))

        if not metrics_data:
            print(f"No {parameter_name} ablation results found. Skipping plots.")
            return

        # Sort by parameter value
        for per_method in metrics_data.values():
            for points in per_method.values():
                points.sort(key=lambda point: point[0])

        try:
            import matplotlib.pyplot as plt

            os.makedirs(output_dir, exist_ok=True)

            # Create plots
            key_metrics = ['ece', 'nll_human', 'accuracy']

            for metric in key_metrics:
                curves = [
                    (method, per_method[metric])
                    for method, per_method in metrics_data.items()
                    if metric in per_method
                ]
                if not curves:
                    # Nothing to draw; avoid writing an empty figure.
                    continue

                plt.figure(figsize=(10, 6))

                for method, points in curves:
                    x_vals = [point[0] for point in points]
                    y_vals = [point[1] for point in points]
                    plt.plot(x_vals, y_vals, 'o-', label=method, linewidth=2, markersize=6)

                plt.xlabel(parameter_name.capitalize())
                plt.ylabel(metric.upper())
                plt.title(f'{metric.upper()} vs {parameter_name.capitalize()}')
                plt.legend()
                plt.grid(True, alpha=0.3)

                plot_path = os.path.join(output_dir, f'ablation_{parameter_name}_{metric}.png')
                plt.savefig(plot_path, dpi=300, bbox_inches='tight')
                plt.close()

                print(f"Saved plot: {plot_path}")

        except ImportError:
            print("Matplotlib not available. Skipping plots.")


def create_experiment_script(experiment_name: str, output_path: str = "run_experiment.py"):
    """Create a standalone script for running specific experiments.

    The generated script imports ``ExperimentManager`` and
    ``EXPERIMENT_CONFIGS`` from ``experiment_configs`` and runs the named
    experiment if it is a known preset. The file is written to
    ``output_path`` and marked executable.

    Args:
        experiment_name: Experiment the generated script should run.
        output_path: Where to write the script.
    """
    # The name is emitted as a Python literal rather than spliced into quoted
    # source text, so quotes, braces or backslashes in it cannot break (or
    # alter) the generated code.
    name_literal = repr(experiment_name)
    # Inside the generated docstring every double quote is escaped so the
    # name cannot terminate the docstring early.
    docstring_name = name_literal.replace('"', '\\"')

    script_content = f'''#!/usr/bin/env python3
"""
Auto-generated experiment script for {docstring_name}
Generated on: {datetime.now().isoformat()}
"""

import sys
import os
sys.path.append(os.path.dirname(__file__))

from experiment_configs import ExperimentManager, EXPERIMENT_CONFIGS

EXPERIMENT_NAME = {name_literal}

def main():
    manager = ExperimentManager()

    # Run the experiment this script was generated for
    if EXPERIMENT_NAME in EXPERIMENT_CONFIGS:
        result_dir = manager.run_experiment(EXPERIMENT_NAME)
        print(f"Results saved to: {{result_dir}}")
    else:
        print(f"Unknown experiment: {{EXPERIMENT_NAME}}")
        print(f"Available experiments: {{list(EXPERIMENT_CONFIGS.keys())}}")

if __name__ == "__main__":
    main()
'''

    # Python source files are UTF-8 by default, independent of the locale.
    with open(output_path, 'w', encoding="utf-8") as f:
        f.write(script_content)

    # Make executable
    os.chmod(output_path, 0o755)
    print(f"Created experiment script: {output_path}")


# Command line interface
def main():
    """Parse the command line and execute the selected sub-command.

    Sub-commands: ``list``, ``run``, ``suite``, ``ablation``, ``analyze`` and
    ``create-script``. Without a sub-command the help text is printed.
    """
    parser = argparse.ArgumentParser(description='HPC Experiment Management')
    commands = parser.add_subparsers(dest='command', help='Available commands')

    # list: takes no options
    commands.add_parser('list', help='List available experiments')

    # run: one experiment
    run_cmd = commands.add_parser('run', help='Run single experiment')
    run_cmd.add_argument('experiment', type=str, help='Experiment name')
    run_cmd.add_argument('--results_dir', type=str, default='./experiments',
                         help='Base results directory')

    # suite: several experiments in sequence
    suite_cmd = commands.add_parser('suite', help='Run experiment suite')
    suite_cmd.add_argument('--name', type=str, default='main_experiments',
                           help='Suite name')
    suite_cmd.add_argument('--experiments', nargs='+',
                           help='List of experiments to run')

    # ablation: alpha sweep or method comparison
    ablation_cmd = commands.add_parser('ablation', help='Run ablation study')
    ablation_cmd.add_argument('--type', type=str, choices=['alpha', 'methods'],
                              default='alpha', help='Ablation type')
    ablation_cmd.add_argument('--dataset', type=str, default='cifar10h',
                              help='Dataset to use')

    # analyze: aggregate finished experiments
    analyze_cmd = commands.add_parser('analyze', help='Analyze results')
    analyze_cmd.add_argument('--input_dirs', nargs='+', required=True,
                             help='Experiment directories to analyze')
    analyze_cmd.add_argument('--output_dir', type=str, default='./analysis',
                             help='Output directory for analysis')

    # create-script: emit a standalone runner
    script_cmd = commands.add_parser('create-script', help='Create experiment script')
    script_cmd.add_argument('experiment', type=str, help='Experiment name')
    script_cmd.add_argument('--output', type=str, default='run_experiment.py',
                            help='Output script path')

    args = parser.parse_args()

    def list_experiments():
        print("Available experiments:")
        for name, config in EXPERIMENT_CONFIGS.items():
            print(f"  {name}: {config.dataset} dataset, {config.hpc_method} method")

    def run_single():
        result_dir = ExperimentManager(args.results_dir).run_experiment(args.experiment)
        print(f"Results: {result_dir}")

    def run_suite():
        results = ExperimentManager().run_experiment_suite(args.name, args.experiments)
        print(f"Suite results: {results}")

    def run_ablation():
        ablation_runner = AblationStudyRunner(ExperimentManager())
        # argparse restricts --type to 'alpha' or 'methods'
        study = (
            ablation_runner.run_alpha_ablation
            if args.type == 'alpha'
            else ablation_runner.run_method_comparison
        )
        result_dir = study(dataset=args.dataset)
        print(f"Ablation results: {result_dir}")

    def analyze():
        analyzer = ResultAnalyzer()
        results = analyzer.collect_results(args.input_dirs)

        os.makedirs(args.output_dir, exist_ok=True)

        # Comparison table goes to a file and to stdout
        table_path = os.path.join(args.output_dir, 'comparison_table.txt')
        table = analyzer.create_comparison_table(results, output_path=table_path)
        print("Comparison table:")
        print(table)

        # Plot ablation results if applicable
        analyzer.plot_ablation_results(results, 'alpha', args.output_dir)

    def create_script():
        create_experiment_script(args.experiment, args.output)

    handlers = {
        'list': list_experiments,
        'run': run_single,
        'suite': run_suite,
        'ablation': run_ablation,
        'analyze': analyze,
        'create-script': create_script,
    }

    # No sub-command given (args.command is None): show usage instead.
    handlers.get(args.command, parser.print_help)()


if __name__ == "__main__":
    main()
