#!/usr/bin/env python3
"""
Usage Examples:
    
    # Run specific experiment
    python run_experiments.py --experiment tofu --model llama2-7b --forget_ratio 0.05
    
    # Run with custom config
    python run_experiments.py --config my_config.yaml
"""

import os
import sys
import argparse
import yaml
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import json

# Make the script's ``src`` directory and the script's own directory importable
# (appended in that order, after the existing sys.path entries).
sys.path.extend(str(entry) for entry in (Path(__file__).parent / "src", Path(__file__).parent))


class ExperimentManager:
    """Manages and orchestrates OFMU experiments.

    Builds command lines for the experiment scripts listed under
    ``config["experiments"]``, runs them as subprocesses (sequentially or in a
    thread pool), and saves/summarizes the resulting records. Every record
    produced here is a dict with at least ``"experiment"``, ``"parameters"``
    and ``"success"`` keys.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.setup_logging()
        self.load_config(config_path)
        self.results_dir = Path("results")
        self.results_dir.mkdir(exist_ok=True)

    def setup_logging(self):
        """Configure logging to ``experiments.log`` and the console.

        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first configuration takes effect.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('experiments.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_config(self, config_path: Optional[str] = None):
        """Load the experiment configuration into ``self.config``.

        Starts from built-in defaults. If *config_path* names an existing YAML
        file, its top-level mapping is recursively merged on top (user values
        win). A path that does not exist is reported and the defaults are used.

        Args:
            config_path: Optional path to a YAML configuration file.

        Raises:
            ValueError: If the YAML document is not a mapping.
        """
        default_config = {
            "models": {
                "llama2-7b": {
                    "name": "xyz",  # NOTE(review): placeholder model identifier - TODO fill in
                    "max_length": 2048,
                    "batch_size": 4
                }
                # TODO: list all other models
            },
            "experiments": {
                "tofu": {
                    "script": "experiments/tofu_experiments.py",
                    "forget_ratios": [0.01, 0.05, 0.10],
                    "methods": ["ofmu", "grad_ascent", "grad_diff", "npo", "rmu"],
                    "eval_metrics": ["forget_quality", "model_utility", "truth_ratio"]
                },
                "cifar": {
                    "script": "experiments/cifar_experiments.py",
                    "datasets": ["cifar10", "cifar100"],
                    "forget_classes": [1, 5, 10],
                    "methods": ["ofmu", "finetune", "retrain"],
                    "eval_metrics": ["unlearn_acc", "retain_acc", "test_acc", "mia_auc"]
                },
                "wmdp": {
                    "script": "experiments/wmdp_experiments.py",
                    "domains": ["bio", "cyber", "chem"],
                    "methods": ["ofmu", "finetune"],
                    "eval_metrics": ["qa_accuracy", "retain_performance", "safety_metrics"]
                }
            },
            "training": {
                "num_epochs": 10,
                "learning_rate": 1e-5,
                "warmup_steps": 100,
                "eval_steps": 500,
                "save_steps": 1000,
                "gradient_accumulation_steps": 4
            },
            "ofmu": {
                "beta": 0.1,
                "penalty_coeff": 1.0,
                "inner_steps": 5,
                "outer_steps": 3,
                "similarity_threshold": 0.7
            },
            "evaluation": {
                "batch_size": 16,
                "max_eval_samples": 1000,
                "statistical_tests": True
            },
            "hardware": {
                "device": "auto",
                "num_gpus": 1,
                "mixed_precision": True,
                "gradient_checkpointing": True
            }
        }

        if not config_path:
            self.config = default_config
            return

        path = Path(config_path)
        if not path.exists():
            # Warn instead of silently falling back, so a mistyped --config is noticed.
            self.logger.warning(f"Config file not found: {path}; using default configuration")
            self.config = default_config
            return

        with open(path, 'r', encoding="utf-8") as f:
            user_config = yaml.safe_load(f)
        if user_config is None:
            # safe_load returns None for an empty document.
            user_config = {}
        if not isinstance(user_config, dict):
            raise ValueError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(user_config).__name__}"
            )
        # Merge configs (user config overrides default)
        self.config = self.merge_configs(default_config, user_config)

    def merge_configs(self, default: Dict, user: Optional[Dict]) -> Dict:
        """Recursively merge *user* on top of *default* and return a new dict.

        Keys whose values are dicts on both sides are merged recursively; any
        other user value replaces the default. Neither input dict is modified.
        A ``None`` *user* (e.g. an empty YAML file) leaves the defaults as is.
        """
        merged = dict(default)
        for key, value in (user or {}).items():
            base = merged.get(key)
            if isinstance(base, dict) and isinstance(value, dict):
                merged[key] = self.merge_configs(base, value)
            else:
                merged[key] = value
        return merged

    def _experiment_args(self, experiment_type: str, kwargs: Dict[str, Any]) -> List[str]:
        """Return the script-specific CLI arguments for one experiment run."""
        if experiment_type not in ("tofu", "cifar", "wmdp"):
            # Other (user-configured) experiment types only get the common arguments.
            return []

        method = kwargs.get("method", "ofmu")
        num_epochs = str(kwargs.get("num_epochs", self.config["training"]["num_epochs"]))

        if experiment_type == "wmdp":
            return [
                "--domain", kwargs.get("domain", "bio"),
                "--method", method,
                "--model_name", kwargs.get("model", "llama2-7b"),
                "--num_epochs", num_epochs,
            ]

        beta = str(kwargs.get("beta", self.config["ofmu"]["beta"]))
        if experiment_type == "tofu":
            return [
                "--model_name", kwargs.get("model", "llama2-7b"),
                "--forget_ratio", str(kwargs.get("forget_ratio", 0.05)),
                "--method", method,
                "--num_epochs", num_epochs,
                "--beta", beta,
            ]
        return [
            "--dataset", kwargs.get("dataset", "cifar10"),
            "--forget_classes", str(kwargs.get("forget_classes", 1)),
            "--method", method,
            "--num_epochs", num_epochs,
            "--beta", beta,
        ]

    @staticmethod
    def _make_result(experiment_type: str, kwargs: Dict[str, Any], command: str,
                     returncode: int, stdout: str, stderr: str,
                     execution_time: float) -> Dict[str, Any]:
        """Build the result record returned by ``run_single_experiment``."""
        return {
            "experiment": experiment_type,
            "parameters": kwargs,
            "command": command,
            "returncode": returncode,
            "stdout": stdout,
            "stderr": stderr,
            "execution_time": execution_time,
            "success": returncode == 0
        }

    def run_single_experiment(self, experiment_type: str, **kwargs) -> Dict[str, Any]:
        """Run one experiment script as a subprocess and return its result record.

        Args:
            experiment_type: Key of ``config["experiments"]``.
            **kwargs: Experiment parameters (e.g. ``model``, ``method``,
                ``forget_ratio``), plus optional ``quick_test``, ``output_dir``
                and ``timeout`` (seconds, default 3600).

        Returns:
            Dict with ``experiment``, ``parameters``, ``command``,
            ``returncode``, ``stdout``, ``stderr``, ``execution_time`` and
            ``success``. Timeouts and launch errors are reported as
            ``returncode == -1`` rather than raised.

        Raises:
            ValueError: If *experiment_type* is not configured.
            FileNotFoundError: If the experiment's script does not exist.
        """
        if experiment_type not in self.config["experiments"]:
            raise ValueError(f"Unknown experiment type: {experiment_type}")

        exp_config = self.config["experiments"][experiment_type]
        script_path = Path(exp_config["script"])

        if not script_path.exists():
            raise FileNotFoundError(f"Experiment script not found: {script_path}")

        cmd = [sys.executable, str(script_path)]
        cmd.extend(self._experiment_args(experiment_type, kwargs))

        # Common arguments
        if kwargs.get("quick_test", False):
            cmd.append("--quick_test")
        if kwargs.get("output_dir"):
            cmd.extend(["--output_dir", str(kwargs["output_dir"])])

        command = " ".join(cmd)
        self.logger.info(f"Running command: {command}")

        # perf_counter is monotonic, so durations are immune to wall-clock changes.
        start_time = time.perf_counter()
        try:
            completed = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=kwargs.get("timeout", 3600)  # 1 hour default timeout
            )
        except subprocess.TimeoutExpired:
            return self._make_result(experiment_type, kwargs, command, -1, "",
                                     "Experiment timed out",
                                     time.perf_counter() - start_time)
        except Exception as e:
            return self._make_result(experiment_type, kwargs, command, -1, "",
                                     str(e), time.perf_counter() - start_time)

        return self._make_result(experiment_type, kwargs, command,
                                 completed.returncode, completed.stdout,
                                 completed.stderr, time.perf_counter() - start_time)

    def _sweep_parameters(self, experiment_type: str, exp_config: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Expand *exp_config*'s parameter grid into one kwargs dict per run.

        Experiment types other than tofu/cifar/wmdp yield an empty list.
        """
        def entry(**params: Any) -> Dict[str, Any]:
            # Output directory: "<type>_<param values in the order given>".
            run_name = "_".join([experiment_type, *(str(v) for v in params.values())])
            return {
                "experiment_type": experiment_type,
                **params,
                "output_dir": str(self.results_dir / run_name),
            }

        if experiment_type == "tofu":
            return [
                entry(model=model, forget_ratio=forget_ratio, method=method)
                for model in self.config["models"]
                for forget_ratio in exp_config["forget_ratios"]
                for method in exp_config["methods"]
            ]
        if experiment_type == "cifar":
            return [
                entry(dataset=dataset, forget_classes=forget_classes, method=method)
                for dataset in exp_config["datasets"]
                for forget_classes in exp_config["forget_classes"]
                for method in exp_config["methods"]
            ]
        if experiment_type == "wmdp":
            return [
                entry(domain=domain, method=method, model=model)
                for domain in exp_config["domains"]
                for method in exp_config["methods"]
                for model in self.config["models"]
            ]
        return []

    def _run_sweep_entry(self, exp: Dict[str, Any]) -> Dict[str, Any]:
        """Run one sweep entry, turning setup errors into a failure record.

        Errors raised before the subprocess starts (e.g. a missing script) are
        caught so a single bad entry does not abort the rest of the sweep.
        """
        try:
            return self.run_single_experiment(**exp)
        except Exception as e:
            return {
                "experiment": exp["experiment_type"],
                "parameters": exp,
                "success": False,
                "error": str(e)
            }

    @staticmethod
    def _error_message(result: Dict[str, Any]) -> str:
        """Return the best available failure description for a result record."""
        return str(result.get("stderr") or result.get("error") or "Unknown error")

    def run_experiment_sweep(self, experiment_type: str, parallel: bool = False, max_workers: int = 4) -> List[Dict[str, Any]]:
        """Run every parameter combination configured for *experiment_type*.

        Args:
            experiment_type: Key of ``config["experiments"]``.
            parallel: Run experiments in a thread pool when there is more than one.
            max_workers: Thread-pool size when *parallel* is set.

        Returns:
            One result record per run. In parallel mode records are in
            completion order, otherwise in sweep order.

        Raises:
            ValueError: If *experiment_type* is not configured.
        """
        if experiment_type not in self.config["experiments"]:
            raise ValueError(f"Unknown experiment type: {experiment_type}")

        experiments_to_run = self._sweep_parameters(
            experiment_type, self.config["experiments"][experiment_type]
        )
        total = len(experiments_to_run)
        self.logger.info(f"Running {total} experiments for {experiment_type}")

        results = []

        if parallel and total > 1:
            # Threads suffice: each worker mostly waits on its subprocess.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_exp = {
                    executor.submit(self._run_sweep_entry, exp): exp
                    for exp in experiments_to_run
                }
                for future in as_completed(future_to_exp):
                    exp = future_to_exp[future]
                    result = future.result()
                    results.append(result)
                    if result["success"]:
                        self.logger.info(f"✓ Completed {exp}")
                    else:
                        self.logger.error(f"✗ Failed {exp}: {self._error_message(result)}")
        else:
            for i, exp in enumerate(experiments_to_run, start=1):
                self.logger.info(f"Running experiment {i}/{total}: {exp}")
                result = self._run_sweep_entry(exp)
                results.append(result)
                if result["success"]:
                    self.logger.info(f"✓ Completed experiment {i}")
                else:
                    self.logger.error(f"✗ Failed experiment {i}: {self._error_message(result)}")

        return results

    def save_results(self, results: Dict[str, Any], output_file: Optional[str] = None) -> Path:
        """Write *results* as JSON and return the path written.

        Args:
            results: Mapping of experiment type to result records.
            output_file: Destination path. Defaults to a timestamped file
                inside ``self.results_dir``. Parent directories are created.

        Returns:
            The path of the written file.
        """
        if output_file is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_path = self.results_dir / f"experiment_results_{timestamp}.json"
        else:
            output_path = Path(output_file)

        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding="utf-8") as f:
            # default=str keeps non-JSON values (e.g. Paths in parameters) serializable.
            json.dump(results, f, indent=2, default=str)

        self.logger.info(f"Results saved to {output_path}")
        return output_path

    def print_summary(self, results: Dict[str, List[Dict[str, Any]]]):
        """Print per-type and overall counts, timings and failures to stdout."""
        print("\n" + "="*80)
        print("EXPERIMENT RESULTS SUMMARY")
        print("="*80)

        total_experiments = 0
        total_success = 0
        total_time = 0

        for exp_type, exp_results in results.items():
            print(f"\n{exp_type.upper()} Experiments:")
            print("-" * 40)

            failed_experiments = [r for r in exp_results if not r.get("success", False)]
            success_count = len(exp_results) - len(failed_experiments)
            total_exp_time = sum(r.get("execution_time", 0) for r in exp_results)

            print(f"  Total: {len(exp_results)}")
            print(f"  Success: {success_count}")
            print(f"  Failed: {len(failed_experiments)}")
            print(f"  Total Time: {total_exp_time:.1f}s ({total_exp_time/60:.1f}m)")

            if failed_experiments:
                print("  Failed experiments:")
                for result in failed_experiments:
                    params = result.get("parameters", {})
                    # Exception records carry "error" instead of "stderr".
                    error = self._error_message(result)[:100]
                    print(f"    - {params}: {error}")

            total_experiments += len(exp_results)
            total_success += success_count
            total_time += total_exp_time

        # Guard against an empty results mapping (ZeroDivisionError otherwise).
        success_rate = total_success / total_experiments * 100 if total_experiments else 0.0

        print("\nOVERALL SUMMARY:")
        print(f"  Total Experiments: {total_experiments}")
        print(f"  Overall Success Rate: {success_rate:.1f}%")
        print(f"  Total Execution Time: {total_time:.1f}s ({total_time/60:.1f}m)")
        print("="*80)


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for this script."""
    parser = argparse.ArgumentParser(description="Run OFMU experiments")

    # Experiment selection
    parser.add_argument("--experiment", choices=["tofu", "cifar", "wmdp"],
                        help="Run specific experiment type")
    parser.add_argument("--all", action="store_true",
                        help="Run all experiments")

    # Configuration
    parser.add_argument("--config", type=str,
                        help="Path to custom config file")

    # Experiment parameters
    parser.add_argument("--model", type=str, default="llama2-7b",
                        help="Model to use")
    parser.add_argument("--method", type=str, default="ofmu",
                        help="Unlearning method")
    parser.add_argument("--forget_ratio", type=float, default=0.05,
                        help="Forget ratio for TOFU experiments")
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="Dataset for CIFAR experiments")
    parser.add_argument("--domain", type=str, default="bio",
                        help="Domain for WMDP experiments")
    parser.add_argument("--forget_classes", type=int, default=1,
                        help="Number of classes to forget for CIFAR")

    # Execution options
    parser.add_argument("--parallel", action="store_true",
                        help="Run experiments in parallel")
    parser.add_argument("--max_workers", type=int, default=4,
                        help="Maximum parallel workers")
    parser.add_argument("--quick_test", action="store_true",
                        help="Run quick test version")
    parser.add_argument("--output_dir", type=str,
                        help="Output directory for results")

    # Analysis
    parser.add_argument("--analyze_results", type=str,
                        help="Analyze results from given JSON file")

    return parser


# CLI parameters forwarded to ``run_single_experiment`` for each experiment type.
_EXPERIMENT_PARAMS = {
    "tofu": ("model", "forget_ratio", "method"),
    "cifar": ("dataset", "forget_classes", "method"),
    "wmdp": ("domain", "method", "model"),
}


def _warn_ignored_sweep_options(manager, args) -> None:
    """Warn that --quick_test / --output_dir are not passed to parameter sweeps."""
    ignored = [
        flag for flag, value in (("--quick_test", args.quick_test),
                                 ("--output_dir", args.output_dir))
        if value
    ]
    if ignored:
        manager.logger.warning(f"{', '.join(ignored)} ignored for parameter sweeps")


def main():
    """Parse the command line, run the requested experiments, and report.

    Returns:
        Process exit code: 0 if a results file was analyzed or every experiment
        succeeded, 1 otherwise.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Initialize manager
    manager = ExperimentManager(args.config)

    # Analyze existing results
    if args.analyze_results:
        results_path = Path(args.analyze_results)
        if not results_path.exists():
            print(f"Results file not found: {args.analyze_results}")
            return 1

        try:
            with open(results_path, 'r', encoding="utf-8") as f:
                results = json.load(f)
        except json.JSONDecodeError as e:
            print(f"Invalid results file {args.analyze_results}: {e}")
            return 1

        manager.print_summary(results)
        return 0

    # Run experiments
    try:
        if args.all:
            # Full sweep of every configured experiment type.
            _warn_ignored_sweep_options(manager, args)
            results = {
                exp_type: manager.run_experiment_sweep(exp_type, args.parallel, args.max_workers)
                for exp_type in manager.config["experiments"]
            }

        elif args.experiment:
            param_names = _EXPERIMENT_PARAMS[args.experiment]
            # Only the selected experiment's own parameters decide between a
            # single run and a sweep; e.g. --dataset must not turn a TOFU sweep
            # into one default TOFU run.
            custom = any(getattr(args, name) != parser.get_default(name)
                         for name in param_names)

            if custom:
                # Single experiment with custom parameters
                kwargs = {name: getattr(args, name) for name in param_names}
                kwargs["quick_test"] = args.quick_test
                kwargs["output_dir"] = args.output_dir
                result = manager.run_single_experiment(args.experiment, **kwargs)
                results = {args.experiment: [result]}
            else:
                # Full sweep for this experiment type
                _warn_ignored_sweep_options(manager, args)
                sweep_results = manager.run_experiment_sweep(args.experiment, args.parallel, args.max_workers)
                results = {args.experiment: sweep_results}

        else:
            print("Please specify --all, --experiment, or --analyze_results")
            return 1

        # Save and summarize results
        manager.save_results(results)
        manager.print_summary(results)

        # Non-zero exit code if any experiment failed
        all_success = all(
            r.get("success", False)
            for exp_results in results.values()
            for r in exp_results
        )
        return 0 if all_success else 1

    except KeyboardInterrupt:
        print("\nExperiments interrupted by user")
        return 1
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        manager.logger.exception(f"Experiment execution failed: {e}")
        return 1


if __name__ == "__main__":
    exit(main())