#!/usr/bin/env python3
"""

Usage:
    
    # Run specific experiment
    python main_experiments.py --experiment tofu --model llama2-7b --scenario forget05
    
    # Run with custom settings
    python main_experiments.py --experiment cifar --dataset CIFAR10 --forget_classes 1
    
    # Quick test run
    python main_experiments.py --experiment tofu --quick_test
"""

import os
import sys
import argparse
import json
import logging
import yaml
import time
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime
import traceback

import torch
import numpy as np
from transformers import set_seed

# Append the script's directory (so ``utils.*`` packages resolve) and its
# ``utils/`` subdirectory (so bare helper modules resolve) to the import path.
PROJECT_ROOT = Path(__file__).parent
sys.path.extend(str(entry) for entry in (PROJECT_ROOT, PROJECT_ROOT / "utils"))


class OFMUExperimentRunner:
    """Main experiment runner for OFMU paper experiments.

    Loads a YAML configuration, configures logging and seeding, and runs
    TOFU, WMDP and CIFAR unlearning experiments. Every ``run_*_experiment``
    method returns a summary dict carrying a ``success`` flag instead of
    raising, so one failing configuration does not abort a whole suite.

    NOTE: ``run_unlearning_method``, ``run_vision_unlearning_method`` and the
    ``evaluate_*`` methods are still mocks that return random metrics.
    """

    # Sections that load_config() guarantees to exist (as dicts).
    _DEFAULT_SECTIONS = ("experiment", "training", "ofmu", "evaluation")
    # Output directory used when neither the CLI nor the config names one.
    _DEFAULT_OUTPUT_DIR = "results"
    # Maximum number of runs per experiment family in quick-test mode.
    _QUICK_TEST_MAX_RUNS = 2
    _LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    def __init__(self, config_path: str, output_dir: Optional[str] = None):
        """Load the config and prepare output dir, logging and seeding.

        Args:
            config_path: Path to the YAML experiment configuration.
            output_dir: Optional override for ``experiment.output_dir``; when
                neither is given, ``results`` is used.

        Raises:
            OSError: If the config file cannot be read or the output
                directory cannot be created.
            ValueError: If the config's top level is not a mapping or a log
                level name is unknown.
        """
        self.config_path = config_path
        self.load_config(config_path)

        # CLI override first, then the config, then a fixed fallback.
        chosen_dir = output_dir or self.config["experiment"].get(
            "output_dir", self._DEFAULT_OUTPUT_DIR
        )
        self.output_dir = Path(chosen_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.setup_logging()
        self.setup_reproducibility()

        # Populated by initialize_components() (or lazily by the run methods).
        self.model_loader = None
        self.data_loader = None

        self.logger.info("OFMU Experiment Runner initialized")
        self.logger.info(f"Output directory: {self.output_dir.absolute()}")

    # ------------------------------------------------------------------ config

    def _section(self, *keys: str) -> Dict[str, Any]:
        """Return the nested config mapping found under ``keys``, or ``{}``.

        Missing keys, explicit YAML ``null`` values and non-mapping values all
        yield an empty dict, so callers can safely chain ``.get`` on the result.
        """
        node: Any = self.config
        for key in keys:
            node = node.get(key) if isinstance(node, dict) else None
            if node is None:
                return {}
        return node if isinstance(node, dict) else {}

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Wrap a scalar config value in a list; copy lists/tuples as lists."""
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def load_config(self, config_path: str):
        """Load the YAML config into ``self.config`` and fill default sections.

        Raises:
            ValueError: If the file's top level is not a mapping.
        """
        with open(config_path, 'r', encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as an empty config.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config file {config_path} must contain a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
        self.config = loaded

        # Missing sections and bare `section:` (YAML null) both become {}.
        for section in self._DEFAULT_SECTIONS:
            if self.config.get(section) is None:
                self.config[section] = {}

    # ----------------------------------------------------------------- logging

    @staticmethod
    def _parse_log_level(level: Any) -> int:
        """Convert a level name (any case) or number to a ``logging`` level.

        Raises:
            ValueError: If ``level`` is not a known level name.
        """
        if isinstance(level, int):
            return level
        value = getattr(logging, str(level).upper(), None)
        if not isinstance(value, int):
            raise ValueError(f"Unknown logging level: {level!r}")
        return value

    def setup_logging(self):
        """Configure console logging and, optionally, a per-run log file.

        Reads ``logging.console_level`` (default INFO), ``logging.file_logging``
        (default True), ``logging.log_file`` (default ``experiment.log``, placed
        in the output directory) and ``logging.file_level`` (default DEBUG).
        """
        logging_cfg = self._section("logging")
        console_level = self._parse_log_level(logging_cfg.get("console_level", "INFO"))
        formatter = logging.Formatter(self._LOG_FORMAT)

        # The console handler carries its own level so that lowering this
        # module logger's level for the file (below) does not leak DEBUG
        # records to the console through propagation.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(console_level)
        console_handler.setFormatter(formatter)
        # NOTE: basicConfig is a no-op if the root logger already has handlers.
        logging.basicConfig(level=console_level, handlers=[console_handler])

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(console_level)

        if logging_cfg.get("file_logging", True):
            file_level = self._parse_log_level(logging_cfg.get("file_level", "DEBUG"))
            log_file = self.output_dir / logging_cfg.get("log_file", "experiment.log")
            log_path = os.path.abspath(log_file)

            # Avoid duplicate lines when several runners share this logger.
            already_attached = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == log_path
                for h in self.logger.handlers
            )
            if not already_attached:
                file_handler = logging.FileHandler(log_file, encoding="utf-8")
                file_handler.setFormatter(formatter)
                file_handler.setLevel(file_level)
                self.logger.addHandler(file_handler)

            # The logger level filters records before any handler sees them,
            # so it must be at least as permissive as the file handler.
            self.logger.setLevel(min(console_level, file_level))

    # --------------------------------------------------------- reproducibility

    def setup_reproducibility(self):
        """Seed all RNGs and configure deterministic PyTorch behaviour.

        Reads ``reproducibility.seed`` (default 42),
        ``reproducibility.deterministic_algorithms`` (default True) and
        ``reproducibility.benchmark_cudnn`` (default False).
        """
        repro_cfg = self._section("reproducibility")
        seed = repro_cfg.get("seed", 42)
        set_seed(seed)

        if repro_cfg.get("deterministic_algorithms", True):
            # Deterministic cuBLAS kernels need this workspace setting, and it
            # must be in place before the first cuBLAS call. A value the user
            # already exported is kept.
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
            torch.use_deterministic_algorithms(True, warn_only=True)

        # Benchmark mode picks the fastest (non-deterministic) conv kernels.
        benchmark = bool(repro_cfg.get("benchmark_cudnn", False))
        torch.backends.cudnn.benchmark = benchmark
        torch.backends.cudnn.deterministic = not benchmark

    # -------------------------------------------------------------- components

    def initialize_components(self):
        """Create model/data loaders from ``utils``, or fallbacks if unavailable."""
        try:
            from utils.model_loader import ModelLoader
            from utils.data_loader import DataLoaderManager

            self.model_loader = ModelLoader()
            self.data_loader = DataLoaderManager()

            self.logger.info("Model and data loaders initialized successfully")

        except ImportError as e:
            self.logger.error(f"Failed to import utilities: {e}")
            self.logger.info("Creating minimal fallback components")
            self.create_fallback_components()

    def _ensure_components(self):
        """Initialize the loaders on first use if the caller has not done so."""
        if getattr(self, "model_loader", None) is None or getattr(self, "data_loader", None) is None:
            self.initialize_components()

    def create_fallback_components(self):
        """Install minimal stand-in loaders used when ``utils`` cannot be imported.

        The fallback model loader downloads a Hugging Face causal LM (Phi-3.5
        mini for unknown keys). The fallback data loader only supports TOFU
        and returns a 100-item dummy QA dataset. WMDP data, CIFAR data and
        vision models raise ``NotImplementedError`` with an explanatory message,
        which the run methods report as a failed experiment.
        """
        class FallbackModelLoader:
            def load_language_model(self, model_key, **kwargs):
                from transformers import AutoTokenizer, AutoModelForCausalLM

                model_map = {
                    "llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
                    "llama3-8b": "meta-llama/Meta-Llama-3-8B-Instruct"
                }

                model_name = model_map.get(model_key, "microsoft/Phi-3.5-mini-instruct")

                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                # Some causal-LM tokenizers ship without a pad token.
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True
                )

                return model, tokenizer

            def load_vision_model(self, architecture, **kwargs):
                raise NotImplementedError(
                    f"Vision model '{architecture}' is unavailable in fallback mode "
                    "(utils.model_loader could not be imported)"
                )

        class FallbackDataLoader:
            def __init__(self):
                self.data_root = Path("./data")

            def load_tofu_dataset(self, scenario="forget05", **kwargs):
                # Return dummy dataset for testing
                return self.create_dummy_text_dataset()

            def load_wmdp_dataset(self, domain="bio", **kwargs):
                raise NotImplementedError(
                    f"WMDP domain '{domain}' is unavailable in fallback mode "
                    "(utils.data_loader could not be imported)"
                )

            def load_cifar_dataset(self, dataset_name="CIFAR10", **kwargs):
                raise NotImplementedError(
                    f"Dataset '{dataset_name}' is unavailable in fallback mode "
                    "(utils.data_loader could not be imported)"
                )

            def create_dummy_text_dataset(self):
                class DummyDataset:
                    def __init__(self):
                        # First 5 items form the forget split, the rest retain.
                        self.data = [
                            {"question": f"What is question {i}?", "answer": f"Answer {i}", "split": "forget" if i < 5 else "retain"}
                            for i in range(100)
                        ]

                    def __len__(self):
                        return len(self.data)

                    def __getitem__(self, idx):
                        return self.data[idx]

                return DummyDataset()

        self.model_loader = FallbackModelLoader()
        self.data_loader = FallbackDataLoader()

    # ------------------------------------------------------- single experiment

    def _timed_run(self, label: str, metadata: Dict[str, Any], body: Any) -> Dict[str, Any]:
        """Run ``body()`` and wrap its metrics in a timed summary dict.

        Args:
            label: Human-readable family name used in the failure log message.
            metadata: Identifying fields copied to the front of the summary.
            body: Zero-argument callable returning the metrics dict.

        Returns:
            ``metadata`` plus ``results``, ``execution_time`` and ``success``.
            On failure, ``results`` is replaced by ``error`` and ``success`` is
            False. Exceptions are logged (traceback at DEBUG), not raised.
        """
        start_time = time.perf_counter()
        try:
            self._ensure_components()
            results = body()
        except Exception as e:
            self.logger.error(f"{label} experiment failed: {e}")
            self.logger.debug(traceback.format_exc())
            return {
                **metadata,
                "error": str(e),
                "execution_time": time.perf_counter() - start_time,
                "success": False
            }

        return {
            **metadata,
            "results": results,
            "execution_time": time.perf_counter() - start_time,
            "success": True
        }

    def _build_training_args(self, quick_test: bool) -> Dict[str, Any]:
        """Merge ``training`` args (with quick-test overrides) and ``ofmu`` args."""
        training_args = self.get_training_args(quick_test)
        # OFMU hyper-parameters take precedence over generic training args.
        training_args.update(self._section("ofmu"))
        return training_args

    def run_tofu_experiment(
        self,
        model_key: str,
        scenario: str,
        method: str,
        quick_test: bool = False
    ) -> Dict[str, Any]:
        """Run one TOFU unlearning experiment and return its summary dict."""
        self.logger.info(f"Running TOFU experiment: {model_key}, {scenario}, {method}")

        def body() -> Dict[str, Any]:
            model, tokenizer = self.model_loader.load_language_model(model_key)
            dataset = self.data_loader.load_tofu_dataset(scenario, tokenizer=tokenizer)
            results = self.run_unlearning_method(
                model=model,
                tokenizer=tokenizer,
                dataset=dataset,
                method=method,
                training_args=self._build_training_args(quick_test)
            )
            results.update(self.evaluate_tofu_model(model, tokenizer, dataset, scenario))
            return results

        metadata = {"experiment": "tofu", "model": model_key, "scenario": scenario, "method": method}
        return self._timed_run("TOFU", metadata, body)

    def run_wmdp_experiment(
        self,
        model_key: str,
        domain: str,
        method: str,
        quick_test: bool = False
    ) -> Dict[str, Any]:
        """Run one WMDP unlearning experiment and return its summary dict."""
        self.logger.info(f"Running WMDP experiment: {model_key}, {domain}, {method}")

        def body() -> Dict[str, Any]:
            model, tokenizer = self.model_loader.load_language_model(model_key)
            dataset = self.data_loader.load_wmdp_dataset(domain, tokenizer=tokenizer)
            results = self.run_unlearning_method(
                model=model,
                tokenizer=tokenizer,
                dataset=dataset,
                method=method,
                training_args=self._build_training_args(quick_test)
            )
            results.update(self.evaluate_wmdp_model(model, tokenizer, dataset, domain))
            return results

        metadata = {"experiment": "wmdp", "model": model_key, "domain": domain, "method": method}
        return self._timed_run("WMDP", metadata, body)

    def run_cifar_experiment(
        self,
        dataset_name: str,
        forget_classes: int,
        method: str,
        quick_test: bool = False
    ) -> Dict[str, Any]:
        """Run one CIFAR class-unlearning experiment and return its summary dict.

        ``forget_classes`` is the number of classes to forget; classes
        ``0 .. forget_classes - 1`` are used. Values outside
        ``[0, num_classes]`` are reported as a failed experiment.
        """
        self.logger.info(f"Running CIFAR experiment: {dataset_name}, forget_classes={forget_classes}, {method}")

        def body() -> Dict[str, Any]:
            # Case-insensitive so "cifar10" from a config is not treated as CIFAR-100.
            num_classes = 10 if str(dataset_name).upper() == "CIFAR10" else 100
            if not 0 <= forget_classes <= num_classes:
                raise ValueError(
                    f"forget_classes must be between 0 and {num_classes} for {dataset_name}, "
                    f"got {forget_classes}"
                )

            model = self.model_loader.load_vision_model("resnet18", num_classes=num_classes)
            dataset = self.data_loader.load_cifar_dataset(
                dataset_name,
                forget_classes=list(range(forget_classes))
            )
            results = self.run_vision_unlearning_method(
                model=model,
                dataset=dataset,
                method=method,
                training_args=self._build_training_args(quick_test)
            )
            results.update(self.evaluate_cifar_model(model, dataset, dataset_name))
            return results

        metadata = {"experiment": "cifar", "dataset": dataset_name, "forget_classes": forget_classes, "method": method}
        return self._timed_run("CIFAR", metadata, body)

    def get_training_args(self, quick_test: bool = False) -> Dict[str, Any]:
        """Return a copy of the ``training`` config, shrunk for quick tests."""
        args = dict(self._section("training"))

        if quick_test:
            # Reduce parameters for quick testing
            args["num_epochs"] = 1
            args["eval_steps"] = 10
            args["logging_steps"] = 5

        return args

    # ------------------------------------------------------------------- mocks

    def run_unlearning_method(
        self,
        model,
        tokenizer,
        dataset,
        method: str,
        training_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the specified unlearning method (currently a random-metric mock)."""
        self.logger.info(f"Running unlearning method: {method}")

        # Mock implementation - replace with actual trainer calls
        results = {
            "method": method,
            "training_loss": np.random.random(),
            "validation_loss": np.random.random(),
            "convergence_steps": np.random.randint(100, 1000)
        }

        if method == "ofmu":
            # Add OFMU-specific metrics
            results.update({
                "beta": training_args.get("beta", 0.1),
                "gradient_similarity": np.random.random(),
                "inner_steps": training_args.get("inner_steps", 5),
                "outer_steps": training_args.get("outer_steps", 3)
            })

        return results

    def run_vision_unlearning_method(
        self,
        model,
        dataset,
        method: str,
        training_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run an unlearning method for vision models (currently a mock)."""
        self.logger.info(f"Running vision unlearning method: {method}")

        # Mock implementation
        results = {
            "method": method,
            "training_accuracy": np.random.uniform(0.7, 0.95),
            "training_loss": np.random.uniform(0.1, 0.5),
            "convergence_epochs": np.random.randint(5, 20)
        }

        return results

    def evaluate_tofu_model(self, model, tokenizer, dataset, scenario: str) -> Dict[str, Any]:
        """Evaluate TOFU model performance (currently a mock)."""
        self.logger.info(f"Evaluating TOFU model on {scenario}")

        # Mock evaluation results
        results = {
            "forget_quality": np.random.uniform(0.8, 0.95),
            "model_utility": np.random.uniform(0.7, 0.9),
            "truth_ratio": np.random.uniform(0.1, 0.3),
            "rouge_score": np.random.uniform(0.3, 0.6),
            "perplexity": np.random.uniform(10, 50)
        }

        return results

    def evaluate_wmdp_model(self, model, tokenizer, dataset, domain: str) -> Dict[str, Any]:
        """Evaluate WMDP model performance (currently a mock)."""
        self.logger.info(f"Evaluating WMDP model on {domain}")

        # Mock evaluation results
        results = {
            "qa_accuracy": np.random.uniform(0.3, 0.7),
            "safety_score": np.random.uniform(0.6, 0.9),
            "retain_performance": np.random.uniform(0.8, 0.95),
            "multiple_choice_accuracy": np.random.uniform(0.25, 0.8)
        }

        return results

    def evaluate_cifar_model(self, model, dataset, dataset_name: str) -> Dict[str, Any]:
        """Evaluate CIFAR model performance (currently a mock)."""
        self.logger.info(f"Evaluating CIFAR model on {dataset_name}")

        # Mock evaluation results
        results = {
            "unlearn_accuracy": np.random.uniform(0.1, 0.3),
            "retain_accuracy": np.random.uniform(0.8, 0.95),
            "test_accuracy": np.random.uniform(0.7, 0.9),
            "mia_auc": np.random.uniform(0.5, 0.6)
        }

        return results

    # ------------------------------------------------------------------ suites

    def run_experiment_suite(self, experiment_types: List[str], quick_test: bool = False) -> Dict[str, List[Dict[str, Any]]]:
        """Run each requested experiment family that is enabled in the config.

        Unknown family names and families with ``enabled: false`` under
        ``experiments.<family>_experiments`` are logged and skipped.

        Returns:
            Mapping from family name to its list of per-run summary dicts.
        """
        suite_runners = {
            "tofu": self.run_tofu_experiments,
            "wmdp": self.run_wmdp_experiments,
            "cifar": self.run_cifar_experiments,
        }
        all_results = {}

        for exp_type in experiment_types:
            self.logger.info(f"Running {exp_type} experiments...")

            suite_runner = suite_runners.get(exp_type)
            if suite_runner is None:
                self.logger.warning(f"Unknown experiment type '{exp_type}', skipping")
                continue
            if not self._section("experiments", f"{exp_type}_experiments").get("enabled", True):
                self.logger.info(f"{exp_type} experiments are disabled in the config, skipping")
                continue

            all_results[exp_type] = suite_runner(quick_test)

        return all_results

    def _run_grid(self, axes: List[List[Any]], run_one: Any, quick_test: bool) -> List[Dict[str, Any]]:
        """Call ``run_one(*combo, quick_test)`` for the cartesian product of ``axes``.

        In quick-test mode only the first ``_QUICK_TEST_MAX_RUNS`` combinations
        are executed.
        """
        import itertools  # only needed for grid expansion

        combos = itertools.product(*axes)
        if quick_test:
            combos = itertools.islice(combos, self._QUICK_TEST_MAX_RUNS)
        return [run_one(*combo, quick_test) for combo in combos]

    def run_tofu_experiments(self, quick_test: bool = False) -> List[Dict[str, Any]]:
        """Run every configured (model, scenario, method) TOFU combination."""
        tofu_config = self._section("experiments", "tofu_experiments")
        axes = [
            self._as_list(tofu_config.get("models", ["SAMPLE_MODEL"])),
            self._as_list(tofu_config.get("scenarios", ["forget05"])),
            self._as_list(tofu_config.get("methods", ["ofmu"])),
        ]
        return self._run_grid(axes, self.run_tofu_experiment, quick_test)

    def run_wmdp_experiments(self, quick_test: bool = False) -> List[Dict[str, Any]]:
        """Run every configured (model, domain, method) WMDP combination."""
        wmdp_config = self._section("experiments", "wmdp_experiments")
        axes = [
            self._as_list(wmdp_config.get("models", ["SAMPLE_MODEL"])),
            self._as_list(wmdp_config.get("domains", ["bio"])),
            self._as_list(wmdp_config.get("methods", ["ofmu"])),
        ]
        return self._run_grid(axes, self.run_wmdp_experiment, quick_test)

    def run_cifar_experiments(self, quick_test: bool = False) -> List[Dict[str, Any]]:
        """Run every configured (dataset, forget-class count, method) CIFAR combination."""
        cifar_config = self._section("experiments", "cifar_experiments")
        axes = [
            self._as_list(cifar_config.get("datasets", ["CIFAR10"])),
            self._as_list(cifar_config.get("forget_classes", [1])),
            self._as_list(cifar_config.get("methods", ["ofmu"])),
        ]
        return self._run_grid(axes, self.run_cifar_experiment, quick_test)

    # ------------------------------------------------------------------ output

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback: paths become strings; numpy/torch values become numbers.

        Objects exposing ``tolist()`` (numpy scalars/arrays, torch tensors) are
        converted to native Python numbers/lists instead of being stringified.
        Anything else falls back to ``str(obj)``.
        """
        if isinstance(obj, Path):
            return str(obj)
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        return str(obj)

    def save_results(self, results: Dict[str, Any], filename: Optional[str] = None):
        """Write ``results`` as indented JSON into the output directory.

        Args:
            results: Results mapping to serialise.
            filename: Target file name; defaults to a timestamped
                ``ofmu_results_YYYYmmdd_HHMMSS.json``.

        Returns:
            The path of the written file.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"ofmu_results_{timestamp}.json"

        output_file = self.output_dir / filename

        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=self._json_default)

        self.logger.info(f"Results saved to {output_file}")
        return output_file

    def print_summary(self, results: Dict[str, List[Dict[str, Any]]]):
        """Print per-family and overall success counts and timings to stdout."""
        print("\n" + "="*80)
        print("OFMU EXPERIMENT RESULTS SUMMARY")
        print("="*80)

        total_experiments = 0
        total_success = 0

        for exp_type, exp_results in results.items():
            print(f"\n{exp_type.upper()} Experiments:")
            print("-" * 40)

            success_count = sum(1 for r in exp_results if r.get("success", False))
            total_exp_time = sum(r.get("execution_time", 0) for r in exp_results)

            print(f"  Total: {len(exp_results)}")
            print(f"  Success: {success_count}")
            print(f"  Failed: {len(exp_results) - success_count}")
            print(f"  Total Time: {total_exp_time:.1f}s")

            total_experiments += len(exp_results)
            total_success += success_count

        print("\nOVERALL SUMMARY:")
        print(f"  Total Experiments: {total_experiments}")
        # Guard against an empty run (e.g. every family disabled in the config).
        if total_experiments:
            print(f"  Success Rate: {total_success/total_experiments*100:.1f}%")
        else:
            print("  Success Rate: n/a (no experiments were run)")
        print("="*80)


def main():
    """Parse CLI arguments, run the requested experiments and return an exit code.

    Returns:
        0 when every executed experiment succeeded. 1 otherwise, including
        data-setup failures, runner initialisation failures, a missing
        ``--all``/``--experiment`` selection and user interrupts.
    """
    parser = argparse.ArgumentParser(description="Run OFMU experiments")

    # Configuration
    parser.add_argument("--config", type=str, default="config/experiment_config.yaml",
                       help="Path to experiment configuration file")
    parser.add_argument("--output_dir", type=str, help="Output directory for results")

    # Experiment selection
    parser.add_argument("--all", action="store_true", help="Run all experiments")
    parser.add_argument("--experiment", choices=["tofu", "wmdp", "cifar"],
                       help="Run specific experiment type")

    # Specific experiment parameters
    parser.add_argument("--model", type=str, default="llama2-7b", help="Model to use")
    parser.add_argument("--scenario", type=str, default="forget05", help="TOFU scenario")
    parser.add_argument("--domain", type=str, default="bio", help="WMDP domain")
    parser.add_argument("--dataset", type=str, default="CIFAR10", help="CIFAR dataset")
    parser.add_argument("--forget_classes", type=int, default=1, help="Number of classes to forget")
    parser.add_argument("--method", type=str, default="ofmu", help="Unlearning method")

    # Options
    parser.add_argument("--quick_test", action="store_true", help="Run quick test version")
    parser.add_argument("--setup_data", action="store_true", help="Setup datasets before running")

    args = parser.parse_args()

    # Setup data if requested
    if args.setup_data:
        print("Setting up datasets...")
        try:
            import subprocess
            # Resolve the script next to this file rather than against the
            # current working directory, so --setup_data works from anywhere.
            setup_script = PROJECT_ROOT / "setup_data.py"
            result = subprocess.run([sys.executable, str(setup_script), "--all"],
                                  capture_output=True, text=True)
            if result.returncode != 0:
                print(f"Data setup failed: {result.stderr}")
                return 1
            print("Data setup completed successfully")
        except Exception as e:
            print(f"Data setup error: {e}")
            return 1

    # Initialize experiment runner
    try:
        runner = OFMUExperimentRunner(args.config, args.output_dir)
        runner.initialize_components()

    except Exception as e:
        print(f"Failed to initialize experiment runner: {e}")
        traceback.print_exc()
        return 1

    # Run experiments
    try:
        if args.all:
            # Run all experiment types
            results = runner.run_experiment_suite(["tofu", "wmdp", "cifar"], args.quick_test)

        elif args.experiment:
            # Run a single experiment with the CLI-provided parameters.
            single_runs = {
                "tofu": lambda: runner.run_tofu_experiment(args.model, args.scenario, args.method, args.quick_test),
                "wmdp": lambda: runner.run_wmdp_experiment(args.model, args.domain, args.method, args.quick_test),
                "cifar": lambda: runner.run_cifar_experiment(args.dataset, args.forget_classes, args.method, args.quick_test),
            }
            results = {args.experiment: [single_runs[args.experiment]()]}
        else:
            print("Please specify --all or --experiment")
            return 1

        # Save and display results
        runner.save_results(results)
        runner.print_summary(results)

        # Succeed only if every executed experiment succeeded.
        all_success = all(
            all(r.get("success", False) for r in exp_results)
            for exp_results in results.values()
        )

        return 0 if all_success else 1

    except KeyboardInterrupt:
        print("\nExperiments interrupted by user")
        return 1
    except Exception as e:
        print(f"Experiment execution failed: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    exit(main())