#!/usr/bin/env python3
"""
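Run TOFU unlearning experiments for OFMU and baseline unlearning methods.

Each selected method unlearns the TOFU forget split and is evaluated on the
TOFU metrics (forget quality, model utility, forget truth ratio). Its results
are written to results/<method>_<model>_<forget_scenario>.json. Without
--method, every method is run and a summary table is printed.

Usage:
    python tofu_experiments.py --model llama2 --forget_scenario forget05
    python tofu_experiments.py --model llama3 --forget_scenario forget01 --method ofmu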
"""

import os
import sys
import argparse
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
import wandb

# Add src to path
sys.path.append(str(Path(__file__).parent.parent / "src"))

from trainer.unlearn.ofmu import OFMUTrainer
from trainer.unlearn.gradient_ascent import GradientAscentTrainer
from trainer.unlearn.gradient_diff import GradientDifferenceTrainer
from trainer.unlearn.npo import NPOTrainer
from trainer.unlearn.simnpo import SimNPOTrainer
from trainer.unlearn.rmu import RMUTrainer
from data.tofu import TOFUDataset
from evals.tofu import TOFUEvaluator


class TOFUExperiment:
    """TOFU experiment runner for comparing unlearning methods.

    Loads a causal LM, its tokenizer and a TOFU forget/retain split. It then
    runs one or more unlearning trainers, evaluates each result with
    ``TOFUEvaluator`` and writes the metrics to
    ``results/<method>_<model>_<forget_scenario>.json``.

    Attributes:
        config: Experiment settings. Keys read here are ``model``,
            ``forget_scenario``, ``batch_size``, ``num_epochs`` and the
            optional ``save_steps`` (default 100).
        model_name: Hugging Face model id resolved from ``config["model"]``.
        model, tokenizer, dataset: The loaded model, tokenizer and TOFU data.
    """

    # Hugging Face model ids. Includes the short names the CLI accepts
    # (``--model llama2`` / ``llama3``) as well as the size-qualified names.
    MODEL_CONFIGS: Dict[str, str] = {
        "llama2": "meta-llama/Llama-2-7b-chat-hf",
        "llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
        "llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
        "llama3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    }

    # Methods executed by ``run_all_methods``, in this order.
    ALL_METHODS: Tuple[str, ...] = (
        "ofmu", "gradient_ascent", "gradient_diff", "npo", "simnpo", "rmu",
    )

    def __init__(self, config: Dict):
        """Store *config*, configure logging, and load the model and data."""
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_logging()
        self.setup_model_and_data()

    def setup_logging(self):
        """Configure root logging to a per-run file and to the console.

        ``logging.basicConfig`` has no effect if the root logger already has
        handlers, so only the first configuration in a process applies.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(f'tofu_experiment_{self.config["model"]}_{self.config["forget_scenario"]}.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    @classmethod
    def _resolve_model_name(cls, model_key: str) -> str:
        """Map a CLI/config model key to its Hugging Face model id.

        Raises:
            ValueError: If *model_key* is not in ``MODEL_CONFIGS``.
        """
        try:
            return cls.MODEL_CONFIGS[model_key]
        except KeyError:
            known = ", ".join(sorted(cls.MODEL_CONFIGS))
            raise ValueError(
                f"Unknown model: {model_key!r} (expected one of: {known})"
            ) from None

    @staticmethod
    def _fmt_metric(value) -> str:
        """Format a metric with 3 decimals, or return 'n/a' if it is missing or non-numeric."""
        try:
            return f"{float(value):.3f}"
        except (TypeError, ValueError):
            return "n/a"

    def _load_model(self):
        """Load (or reload) the base model weights into ``self.model``."""
        import gc

        # Release any previous, possibly unlearned, model before allocating a
        # fresh copy, so two full models are never resident at once.
        if getattr(self, "model", None) is not None:
            self.model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        # These are untouched base weights, so the next experiment can use them as-is.
        self._model_is_pristine = True

    def setup_model_and_data(self):
        """Initialize the tokenizer, the base model and the TOFU dataset."""
        self.model_name = self._resolve_model_name(self.config["model"])
        self.logger.info(f"Loading model: {self.model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            # Batched collation needs a pad token; reuse EOS when the tokenizer has none.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._load_model()

        dataset_path = Path(__file__).parent.parent / "data" / "tofu"
        self.dataset = TOFUDataset(
            dataset_path / f"{self.config['forget_scenario']}.json",
            tokenizer=self.tokenizer
        )

        self.logger.info(f"Dataset loaded: {len(self.dataset.forget_set)} forget, {len(self.dataset.retain_set)} retain")

    def get_trainer(self, method: str):
        """Instantiate the unlearning trainer for *method* with its hyperparameters.

        Raises:
            ValueError: If *method* is not a known unlearning method.
        """
        trainer_configs = {
            "ofmu": {
                "class": OFMUTrainer,
                "params": {
                    "beta": 0.1,
                    "rho_init": 0.01,
                    "inner_steps": 5,
                    "inner_lr": 1e-5,
                    "outer_lr": 1e-5,
                    "similarity_metric": "cosine"
                }
            },
            "gradient_ascent": {
                "class": GradientAscentTrainer,
                "params": {
                    "lr": 1e-5,
                    "max_steps": 1000
                }
            },
            "gradient_diff": {
                "class": GradientDifferenceTrainer,
                "params": {
                    "lr": 1e-5,
                    "alpha": 1.0
                }
            },
            "npo": {
                "class": NPOTrainer,
                "params": {
                    "lr": 1e-5,
                    "beta": 0.1
                }
            },
            "simnpo": {
                "class": SimNPOTrainer,
                "params": {
                    "lr": 1e-5,
                    "beta": 0.1,
                    "gamma": 0.85
                }
            },
            "rmu": {
                "class": RMUTrainer,
                "params": {
                    "lr": 1e-5,
                    "alpha": 1.0,
                    "regularization": "l2"
                }
            }
        }

        if method not in trainer_configs:
            raise ValueError(f"Unknown method: {method}")

        trainer_config = trainer_configs[method]
        return trainer_config["class"](
            model=self.model,
            tokenizer=self.tokenizer,
            **trainer_config["params"]
        )

    def _make_loader(self, subset) -> DataLoader:
        """Build a shuffled DataLoader over *subset* using the dataset's collate_fn."""
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            shuffle=True,
            collate_fn=self.dataset.collate_fn
        )

    def run_experiment(self, method: str) -> Dict:
        """Unlearn with *method*, evaluate the result and save it as JSON.

        Returns:
            The evaluator's metrics dict, extended with ``method`` and ``config``.
        """
        self.logger.info(f"Running experiment with method: {method}")

        # Trainers receive self.model directly and may update its weights in
        # place. Reload after any earlier run so that every method starts from
        # the same base model.
        if not self._model_is_pristine:
            self.logger.info("Reloading base model so this run starts from the original weights")
            self._load_model()

        trainer = self.get_trainer(method)
        forget_loader = self._make_loader(self.dataset.forget_set)
        retain_loader = self._make_loader(self.dataset.retain_set)

        self.logger.info("Starting unlearning process...")
        # From this point self.model may be modified, even if unlearning fails.
        self._model_is_pristine = False
        unlearned_model = trainer.unlearn(
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            num_epochs=self.config["num_epochs"],
            save_steps=self.config.get("save_steps", 100)
        )

        self.logger.info("Evaluating results...")
        evaluator = TOFUEvaluator(
            model=unlearned_model,
            tokenizer=self.tokenizer,
            dataset=self.dataset
        )

        results = evaluator.evaluate()
        results["method"] = method
        results["config"] = self.config

        results_path = Path(f"results/{method}_{self.config['model']}_{self.config['forget_scenario']}.json")
        results_path.parent.mkdir(parents=True, exist_ok=True)

        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

        self.logger.info(f"Results saved to {results_path}")
        return results

    def run_all_methods(self) -> Dict[str, Dict]:
        """Run every method in ``ALL_METHODS``, recording failures instead of aborting.

        Returns:
            Mapping of method name to its results dict. A failed method maps
            to ``{"error": "<message>"}``.
        """
        all_results = {}

        for method in self.ALL_METHODS:
            try:
                results = self.run_experiment(method)
            except Exception as e:
                # Top-level sweep boundary: keep going with the remaining
                # methods, but keep the traceback in the log.
                self.logger.exception(f"Failed to run {method}: {e}")
                all_results[method] = {"error": str(e)}
                continue

            all_results[method] = results
            fq = self._fmt_metric(results.get("forget_quality"))
            mu = self._fmt_metric(results.get("model_utility"))
            self.logger.info(f"Completed {method}: FQ={fq}, MU={mu}")

        return all_results


def collect_scalar_metrics(result: Dict, prefix: str = "") -> Dict[str, float]:
    """Return the numeric (non-bool) entries of *result* as floats.

    Nested dicts, strings and other non-scalar values are skipped, so an
    ``{"error": ...}`` entry yields nothing. Each key is prefixed with *prefix*.
    """
    return {
        f"{prefix}{key}": float(value)
        for key, value in result.items()
        if isinstance(value, (int, float, np.integer, np.floating))
        and not isinstance(value, bool)
    }


def print_summary_table(model: str, forget_scenario: str, results: Dict[str, Dict]) -> None:
    """Print a fixed-width FQ / MU / FTR table, one row per method.

    Methods whose result contains an ``error`` key are shown as ``ERROR``.
    Missing metrics are printed as 0.
    """
    print("\n" + "=" * 80)
    print("TOFU EXPERIMENT RESULTS SUMMARY")
    print("=" * 80)
    print(f"Model: {model}, Scenario: {forget_scenario}")
    print("-" * 80)
    print(f"{'Method':<15} {'FQ':<8} {'MU':<8} {'FTR':<8} {'Status':<10}")
    print("-" * 80)

    for method, result in results.items():
        if "error" in result:
            print(f"{method:<15} {'--':<8} {'--':<8} {'--':<8} {'ERROR':<10}")
        else:
            fq = result.get('forget_quality', 0)
            mu = result.get('model_utility', 0)
            ftr = result.get('forget_truth_ratio', 0)
            print(f"{method:<15} {fq:<8.3f} {mu:<8.3f} {ftr:<8.3f} {'OK':<10}")

    print("=" * 80)


def main():
    """Parse CLI arguments, run the requested experiment(s) and report results."""
    import random

    parser = argparse.ArgumentParser(description="TOFU Experiments for OFMU")
    parser.add_argument("--model", choices=["llama2", "llama3"], default="llama2",
                       help="Model to use for experiments")
    parser.add_argument("--forget_scenario", choices=["forget01", "forget05", "forget10"],
                       default="forget05", help="Forget scenario")
    parser.add_argument("--method", choices=["ofmu", "gradient_ascent", "gradient_diff",
                       "npo", "simnpo", "rmu", "all"], default="all",
                       help="Unlearning method to run")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs")
    parser.add_argument("--wandb", action="store_true", help="Use wandb logging")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Seed every RNG the run may touch (Python, torch, NumPy) for reproducibility.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.wandb:
        wandb.init(
            project="ofmu-tofu-experiments",
            config=vars(args),
            name=f"{args.method}_{args.model}_{args.forget_scenario}"
        )

    config = {
        "model": args.model,
        "forget_scenario": args.forget_scenario,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "seed": args.seed
    }

    experiment = TOFUExperiment(config)

    if args.method == "all":
        results = experiment.run_all_methods()
        print_summary_table(args.model, args.forget_scenario, results)

        metrics = {}
        for method, result in results.items():
            metrics.update(collect_scalar_metrics(result, prefix=f"{method}/"))
    else:
        results = experiment.run_experiment(args.method)
        print(f"\nResults for {args.method}:")
        print(f"  Forget Quality: {results['forget_quality']:.3f}")
        print(f"  Model Utility: {results['model_utility']:.3f}")
        print(f"  Forget Truth Ratio: {results['forget_truth_ratio']:.3f}")
        metrics = collect_scalar_metrics(results)

    if args.wandb:
        # Previously only the run config reached wandb. Push the final
        # metrics too, then close the run explicitly.
        if metrics:
            wandb.log(metrics)
        wandb.finish()


if __name__ == "__main__":
    main()