"""
Main script to run all DataOpt experiments.
Provides a unified interface for running individual experiments or the complete suite.
"""

import argparse
import json
import logging
import os
import shlex
import subprocess
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

# Module logger; every progress/error message in this runner goes through it.
logger = logging.getLogger(__name__)
# Root handler at INFO so the runner's progress messages are shown. Note that
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)


def run_experiment(experiment_name: str, args: Optional[List[str]] = None) -> int:
    """Run a single experiment script in a child Python interpreter.

    The child inherits this process's stdout/stderr, so its output streams
    straight to the console.

    Args:
        experiment_name: Experiment key, one of ``'exp1'`` .. ``'exp5'``.
        args: Extra command-line arguments forwarded verbatim to the script.

    Returns:
        The child process's exit code. Returns ``1`` if the name is unknown,
        the script cannot be found, or the process could not be started.
    """
    experiment_files = {
        'exp1': 'experiments/exp1_sota_enhancement.py',
        'exp2': 'experiments/exp2_llm_unlearning.py',
        'exp3': 'experiments/exp3_retain_composition.py',
        'exp4': 'experiments/exp4_controllability.py',
        'exp5': 'experiments/exp5_delete_comparison.py'
    }

    if experiment_name not in experiment_files:
        logger.error(f"Unknown experiment: {experiment_name}")
        return 1

    relative_path = experiment_files[experiment_name]

    # Prefer the path relative to the current working directory (the original
    # behaviour). Fall back to the directory containing this runner, so the
    # suite also works when launched from somewhere else.
    script_path = relative_path
    if not os.path.exists(script_path):
        runner_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(runner_dir, relative_path)
        if os.path.exists(candidate):
            script_path = candidate

    if not os.path.exists(script_path):
        logger.error(f"Experiment script not found: {relative_path}")
        return 1

    # Build command: run the script with the same interpreter as this runner.
    cmd = [sys.executable, script_path]
    if args:
        cmd.extend(args)

    logger.info(f"Running experiment: {experiment_name}")
    # Quote each part so arguments containing spaces are logged unambiguously.
    logger.info(f"Command: {' '.join(shlex.quote(part) for part in cmd)}")

    # Run experiment. Output is not captured; the child writes to our console.
    try:
        result = subprocess.run(cmd)
    except (OSError, ValueError) as e:
        # OSError: the interpreter/script could not be executed.
        # ValueError: invalid arguments (e.g. embedded NUL bytes).
        logger.error(f"Error running experiment {experiment_name}: {e}")
        return 1
    return result.returncode


def run_all_experiments(output_dir: str = "results", device: str = "cuda") -> Dict[str, int]:
    """Run the whole suite, exp1 through exp5, one after another.

    Each experiment is launched via ``run_experiment`` with a fixed,
    suite-specific argument list. ``device`` and ``output_dir`` are forwarded
    to every experiment. A failing experiment does not stop the remaining ones.

    Args:
        output_dir: Value passed to each experiment as ``--output_dir``.
        device: Value passed to each experiment as ``--device``.

    Returns:
        Mapping from experiment key to the return code reported by
        ``run_experiment``, in execution order.
    """
    logger.info("Starting complete experimental suite...")

    shared = ['--device', device, '--output_dir', output_dir]
    separator = "=" * 60

    # (key, banner title, argument list) in execution order. Argument order
    # mirrors each experiment's expected CLI exactly.
    schedule = [
        ('exp1', "EXPERIMENT 1: SOTA Enhancement", [
            '--dataset', 'both',
            '--baselines', 'NEGGRAD', 'SCRUB', 'BadTeacher', 'SalUn',
            *shared,
        ]),
        ('exp2', "EXPERIMENT 2: LLM Unlearning", [
            '--models', 'llama-3-8b', 'phi-3',
            '--forget_ratios', '0.01', '0.05', '0.10',
            '--baselines', 'GA', 'NPO', 'ICU', 'DataOpt',
            *shared,
        ]),
        ('exp3', "EXPERIMENT 3: Retain Set Composition", [
            '--forget_class', '0',
            '--num_retain_samples', '200',
            '--strategies', 'Random', 'Neighborhood', 'Boundary', 'DataOpt',
            *shared,
        ]),
        ('exp4', "EXPERIMENT 4: Unlearning Controllability", [
            '--forget_class', '0',
            '--k_values', '1', '3', '5', '7', '9',
            *shared,
            '--runs', '3',
        ]),
        ('exp5', "EXPERIMENT 5: DELETE Comparison", [
            '--forget_class', '0',
            *shared,
            '--runs', '5',
        ]),
    ]

    outcome: Dict[str, int] = {}
    for key, title, exp_argv in schedule:
        for banner_line in (separator, title, separator):
            logger.info(banner_line)
        outcome[key] = run_experiment(key, exp_argv)

    return outcome


def create_experiment_summary(results: Dict[str, int], output_dir: str) -> Dict[str, Any]:
    """Save a JSON summary of experiment return codes and print a report.

    Args:
        results: Mapping from experiment key (e.g. ``'exp1'``) to the exit
            code returned by ``run_experiment``.
        output_dir: Directory that receives ``experiment_suite_summary.json``.
            It is created if it does not already exist.

    Returns:
        The summary dict that was written. It contains ``timestamp`` (a naive
        local-time ISO string), ``experiments`` (description, status and
        return code per experiment) and ``overall_status``.
    """
    experiment_descriptions = {
        'exp1': 'SOTA Enhancement (CIFAR-100, Tiny-ImageNet)',
        'exp2': 'LLM Unlearning (TOFU benchmark)',
        'exp3': 'Retain Set Composition Analysis (CIFAR-10)',
        'exp4': 'Unlearning Controllability Analysis (CIFAR-10)',
        'exp5': 'DELETE Framework Comparison (CIFAR-10)'
    }

    summary = {
        'timestamp': datetime.now().isoformat(),
        'experiments': {},
        # 'partial_failure' is reported whenever any code is non-zero, even if
        # every experiment failed. An empty ``results`` counts as 'success'.
        'overall_status': 'success' if all(code == 0 for code in results.values()) else 'partial_failure'
    }

    for exp_name, return_code in results.items():
        summary['experiments'][exp_name] = {
            'description': experiment_descriptions.get(exp_name, 'Unknown experiment'),
            'status': 'success' if return_code == 0 else 'failed',
            'return_code': return_code
        }

    # Save summary. The caller may not have created the directory (main()
    # does, but this function is also usable on its own).
    os.makedirs(output_dir, exist_ok=True)
    summary_file = os.path.join(output_dir, 'experiment_suite_summary.json')
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"Experiment suite summary saved to {summary_file}")

    # Print summary
    print("\n" + "="*70)
    print("DATAOPT EXPERIMENT SUITE SUMMARY")
    print("="*70)
    print(f"Timestamp: {summary['timestamp']}")
    print(f"Overall Status: {summary['overall_status'].upper()}")
    print("-" * 70)

    for exp_name, info in summary['experiments'].items():
        status_symbol = "✓" if info['status'] == 'success' else "✗"
        print(f"{status_symbol} {exp_name.upper()}: {info['description']}")
        if info['status'] == 'failed':
            print(f"   Return code: {info['return_code']}")

    print("\nResults saved to:", output_dir)

    return summary


def build_experiment_args(extra_args: str, device: str, output_dir: str) -> List[str]:
    """Build the argument list forwarded to one experiment script.

    The runner-level ``--device``/``--output_dir`` defaults come FIRST, so any
    values the user repeats in ``extra_args`` appear later on the command
    line. Under argparse's default ``store`` action the last occurrence wins,
    so explicit user values take precedence. This assumes the experiment
    scripts use plain argparse options (TODO confirm).

    Args:
        extra_args: Shell-style argument string (e.g. from ``--args``). It is
            parsed with shell quoting rules so quoted values keep their spaces.
        device: Default device forwarded via ``--device``.
        output_dir: Default output directory forwarded via ``--output_dir``.

    Returns:
        The combined argument list.
    """
    forwarded = ['--device', device, '--output_dir', output_dir]
    if extra_args:
        # Unlike str.split, shlex.split keeps "--name 'a b'" as two arguments.
        forwarded.extend(shlex.split(extra_args))
    return forwarded


def main() -> int:
    """Parse command-line options and run the requested experiment(s).

    Returns:
        Process exit status. This is 0 when every experiment that ran returned
        0, or when nothing was run (``--list`` or help). Otherwise it is 1.
    """
    parser = argparse.ArgumentParser(
        description='DataOpt Experiment Runner',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run all experiments
  python run_experiments.py --all
  
  # Run specific experiment
  python run_experiments.py --experiment exp1
  
  # Run experiment with custom args
  python run_experiments.py --experiment exp1 --args "--dataset cifar100 --device cpu"
  
  # Run multiple specific experiments
  python run_experiments.py --experiments exp1 exp3 exp5
        """
    )

    parser.add_argument('--all', action='store_true',
                       help='Run all experiments in sequence')
    parser.add_argument('--experiment', choices=['exp1', 'exp2', 'exp3', 'exp4', 'exp5'],
                       help='Run a specific experiment')
    parser.add_argument('--experiments', nargs='+',
                       choices=['exp1', 'exp2', 'exp3', 'exp4', 'exp5'],
                       help='Run multiple specific experiments')
    parser.add_argument('--args', type=str, default='',
                       help='Additional arguments to pass to experiment (for single experiment); '
                            'shell-quoted, and may override --device/--output_dir')
    parser.add_argument('--device', default='cuda',
                       help='Device to use for all experiments')
    parser.add_argument('--output_dir', default='results',
                       help='Output directory for all results')
    parser.add_argument('--list', action='store_true',
                       help='List available experiments and exit')

    args = parser.parse_args()

    # List experiments (informational only, so no filesystem side effects).
    if args.list:
        print("Available experiments:")
        print("  exp1 - SOTA Enhancement (CIFAR-100, Tiny-ImageNet)")
        print("  exp2 - LLM Unlearning (TOFU benchmark)")
        print("  exp3 - Retain Set Composition Analysis (CIFAR-10)")
        print("  exp4 - Unlearning Controllability Analysis (CIFAR-10)")
        print("  exp5 - DELETE Framework Comparison (CIFAR-10)")
        return 0

    if not (args.all or args.experiment or args.experiments):
        parser.print_help()
        return 0

    # Create output directory only once we know experiments will run.
    os.makedirs(args.output_dir, exist_ok=True)

    # Precedence (unchanged): --all, then --experiment, then --experiments.
    # --args only applies to the single-experiment branch.
    if args.args and (args.all or not args.experiment):
        logger.warning("--args is only applied together with --experiment; ignoring it")

    results: Dict[str, int] = {}

    if args.all:
        results = run_all_experiments(args.output_dir, args.device)
    elif args.experiment:
        exp_args = build_experiment_args(args.args, args.device, args.output_dir)
        results[args.experiment] = run_experiment(args.experiment, exp_args)
    else:
        for exp in args.experiments:
            exp_args = build_experiment_args('', args.device, args.output_dir)
            results[exp] = run_experiment(exp, exp_args)

    # Create summary
    if results:
        create_experiment_summary(results, args.output_dir)

    return 0 if all(code == 0 for code in results.values()) else 1


if __name__ == "__main__":
    # Propagate main()'s status so shells/CI see experiment failures.
    # sys.exit(None) exits 0, so this is also safe if main() returns nothing.
    sys.exit(main())