#!/usr/bin/env python3
"""
Usage:
    python wmdp_experiments.py --model zephyr --domain bio
    python wmdp_experiments.py --model zephyr --domain cyber --method gradient_ascent
"""

import os
import sys
import argparse
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
import pandas as pd

# Make the ``src`` directory (sibling of this script's parent directory)
# importable by appending it to the module search path. It is appended, not
# prepended, so already-importable modules of the same name take precedence.
# NOTE(review): no import from ``src`` is visible in this file — confirm the
# path tweak is still needed.
sys.path.append(str(Path(__file__).parent.parent / "src"))


class WMDPDataset:
    """WMDP dataset loader and processor.

    Loads the ``test`` split for one WMDP domain, trying the HuggingFace hub
    first and falling back to a local JSON file, and formats examples as
    lettered multiple-choice prompts.

    Attributes:
        domain: Domain key; must be one of the keys of ``_HF_REPOS``.
        data_dir: Directory searched for the local JSON fallback files.
        dataset: Mapping with a ``"test"`` entry, set by the loader methods.
    """

    # Hub repository id per domain (values unchanged from the original code).
    # NOTE(review): the public WMDP release may be published as configs of a
    # single "cais/wmdp" repository — confirm these ids actually resolve;
    # if they do not, every run silently ends up in the local fallback.
    _HF_REPOS = {
        "bio": "cais/wmdp-bio",
        "cyber": "cais/wmdp-cyber",
        "chem": "cais/wmdp-chem",
    }

    # File name of the local JSON fallback per domain.
    _LOCAL_FILES = {
        "bio": "WMDP_BIO_TEST.json",
        "cyber": "WMDP_CYBER_TEST.json",
        "chem": "WMDP_CHEM_TEST.json",
    }

    def __init__(self, domain: str = "bio", data_dir: Optional[Path] = None):
        """Load the dataset for ``domain``.

        Args:
            domain: WMDP domain key ("bio", "cyber" or "chem").
            data_dir: Directory holding the local JSON fallback files.
                Defaults to ``<parent of this file's directory>/data/wmdp``.

        Raises:
            ValueError: If ``domain`` is not a known domain.
            FileNotFoundError: If the hub load fails and no local file exists.
        """
        self.domain = domain
        if data_dir is None:
            data_dir = Path(__file__).parent.parent / "data" / "wmdp"
        self.data_dir = Path(data_dir)
        self.load_dataset()

    def _require_known_domain(self) -> None:
        """Raise ``ValueError`` unless ``self.domain`` is a supported domain."""
        if self.domain not in self._HF_REPOS:
            raise ValueError(f"Unknown domain: {self.domain}")

    def load_dataset(self):
        """Load WMDP dataset for specified domain.

        Any failure of the hub load (network, missing repo, missing ``test``
        split, ...) is reported and triggers the local-file fallback.

        Raises:
            ValueError: If the domain is unknown.
            FileNotFoundError: If the fallback file does not exist either.
        """
        # Validate outside the ``try`` so a bad domain is reported as such
        # instead of being mislabelled as a dataset loading error.
        self._require_known_domain()

        try:
            # Load from HuggingFace datasets
            self.dataset = load_dataset(self._HF_REPOS[self.domain])
            print(f"Loaded WMDP-{self.domain} dataset: {len(self.dataset['test'])} examples")
        except Exception as e:
            # Deliberate best-effort: whatever went wrong, try the local copy.
            print(f"Error loading dataset: {e}")
            self.load_local_dataset()

    def load_local_dataset(self):
        """Load dataset from local files if HuggingFace fails.

        Reads ``<data_dir>/<domain file>`` as JSON and stores the parsed
        content under the ``"test"`` key of ``self.dataset``.

        Raises:
            ValueError: If the domain is unknown.
            FileNotFoundError: If the JSON file does not exist.
        """
        self._require_known_domain()
        file_path = self.data_dir / self._LOCAL_FILES[self.domain]

        if not file_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {file_path}")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.dataset = {"test": data}
        print(f"Loaded local WMDP-{self.domain} dataset: {len(data)} examples")

    def format_question(self, example: Dict) -> str:
        """Format question with multiple choice options.

        Args:
            example: Mapping with a ``"question"`` string and a ``"choices"``
                sequence.

        Returns:
            The prompt text: the question, one lettered line per choice
            (``A.``, ``B.``, ...), and a trailing ``Answer:`` cue.
        """
        option_lines = [
            f"{chr(65 + i)}. {choice}\n"
            for i, choice in enumerate(example["choices"])
        ]
        return (
            f"Question: {example['question']}\n\nOptions:\n"
            + "".join(option_lines)
            + "\nAnswer:"
        )

    def get_correct_answer(self, example: Dict) -> str:
        """Get the correct answer letter (A, B, C, D).

        ``example["answer"]`` is used as a zero-based index into the choices
        (0 -> "A", 1 -> "B", ...).
        """
        return chr(65 + example["answer"])


class WMDPEvaluator:
    """Evaluator for WMDP benchmark.

    Prompts a causal language model with each formatted multiple-choice
    question, reads an answer letter from the first generated tokens, and
    reports accuracy against the dataset's correct letters.
    """

    # Letters accepted as an answer, in the order the fallback scan tries them.
    ANSWER_LETTERS = ("A", "B", "C", "D")

    def __init__(self, model, tokenizer, dataset: "WMDPDataset"):
        """Store the model, tokenizer and dataset wrapper.

        Args:
            model: Model exposing ``parameters()`` and ``generate()``.
            tokenizer: Tokenizer matching ``model``.
            dataset: Object providing ``dataset["test"]``, ``domain``,
                ``format_question`` and ``get_correct_answer``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.dataset = dataset
        # Inputs are moved to wherever the model's first parameter lives.
        self.device = next(model.parameters()).device

    @staticmethod
    def _first_n(data, n: int):
        """Return the first ``n`` examples of ``data`` as an iterable of rows.

        Slicing a HuggingFace ``Dataset`` yields a dict of columns rather than
        rows, so iterating the slice would walk over column names. Objects
        offering ``select`` are therefore subset through it; plain sequences
        (the local JSON fallback) are sliced.
        """
        if hasattr(data, "select"):
            return data.select(range(min(n, len(data))))
        return data[:n]

    def evaluate_model(self, subset_size: Optional[int] = None) -> Dict:
        """Evaluate model on WMDP dataset.

        Args:
            subset_size: If truthy, evaluate only the first ``subset_size``
                examples; otherwise evaluate the whole test split.

        Returns:
            Dict with ``accuracy``, ``correct``, ``total``, ``domain`` and a
            per-example ``detailed_results`` list.
        """
        test_data = self.dataset.dataset["test"]

        if subset_size:
            test_data = self._first_n(test_data, subset_size)

        correct = 0
        total = 0
        detailed_results = []

        print(f"Evaluating on {len(test_data)} examples...")

        for i, example in enumerate(test_data):
            if i % 50 == 0:
                print(f"Progress: {i}/{len(test_data)}")

            formatted_question = self.dataset.format_question(example)
            correct_answer = self.dataset.get_correct_answer(example)

            predicted_answer = self.predict_answer(formatted_question).strip().upper()

            is_correct = predicted_answer == correct_answer
            if is_correct:
                correct += 1
            total += 1

            detailed_results.append({
                "question_id": i,
                "question": example["question"],
                "correct_answer": correct_answer,
                "predicted_answer": predicted_answer,
                "is_correct": is_correct
            })

        accuracy = correct / total if total > 0 else 0

        results = {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "domain": self.dataset.domain,
            "detailed_results": detailed_results
        }

        print(f"Accuracy: {accuracy:.3f} ({correct}/{total})")
        return results

    def _extract_letter(self, response: str) -> str:
        """Map generated text to an answer letter.

        Order of preference: the first character if it is an answer letter,
        then the first letter of ``ANSWER_LETTERS`` occurring anywhere in the
        text, then ``"A"``.

        NOTE(review): the substring scan and the ``"A"`` default mean that
        unparseable output (e.g. a refusal) is still scored as some letter,
        which can inflate accuracy by chance — consider scoring it as wrong.
        """
        if response:
            first_char = response[0].upper()
            if first_char in self.ANSWER_LETTERS:
                return first_char

        upper_response = response.upper()
        for letter in self.ANSWER_LETTERS:
            if letter in upper_response:
                return letter

        return "A"  # Default fallback

    def predict_answer(self, question: str) -> str:
        """Get model prediction for a question.

        Args:
            question: Fully formatted prompt ending in the answer cue.

        Returns:
            A single answer letter. Any exception during tokenization or
            generation is printed and answered with the default ``"A"``.
        """
        try:
            inputs = self.tokenizer(question, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_len = inputs["input_ids"].shape[1]

            # Greedy decoding. No ``temperature`` is passed: it has no effect
            # without sampling and a value of 0.0 is rejected or warned about
            # by generation-config validation.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=2,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens. Decoding the whole
            # sequence and cutting ``len(question)`` characters assumes decode
            # reproduces the prompt character-for-character, which does not
            # hold in general (whitespace normalisation, truncation).
            new_tokens = outputs[0][prompt_len:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            return self._extract_letter(response)

        except Exception as e:
            print(f"Error in prediction: {e}")
            return "A"


class WMDPUnlearner:
    """Simple unlearning methods for WMDP evaluation.

    Both methods update ``self.model`` in place, one example per optimizer
    step, using the language-modelling loss of a completion appended to a
    ``Question: ...\\nAnswer:`` prompt.
    """

    # Completion the model is trained to produce in ``fine_tune_unlearn``.
    REFUSAL_RESPONSE = "I cannot provide information about this topic."

    def __init__(self, model, tokenizer, device):
        """Store the model, its tokenizer and the device inputs are moved to."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    @staticmethod
    def _as_rows(forget_data) -> List[Dict]:
        """Normalise ``forget_data`` to a list of example dicts.

        A columnar mapping (``{"question": [...], "answer": [...]}``, which is
        what slicing a HuggingFace ``Dataset`` produces) is transposed into
        rows; any other iterable is materialised as-is.
        """
        if isinstance(forget_data, dict):
            columns = list(forget_data)
            return [dict(zip(columns, values)) for values in zip(*forget_data.values())]
        return list(forget_data)

    def _completion_loss(self, prompt: str, completion: str, max_length: int = 512):
        """Return the model loss of ``completion`` given ``prompt``.

        The model is fed the full ``prompt + " " + completion`` sequence so
        that inputs and labels have the same length; feeding only the prompt
        with longer labels makes the loss computation fail on a shape
        mismatch. Prompt positions are set to -100, the label value the
        causal-LM loss ignores, so only completion tokens are scored.
        """
        prompt_ids = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_length
        )["input_ids"]
        batch = self.tokenizer(
            prompt + " " + completion, return_tensors="pt", truncation=True, max_length=max_length
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}

        labels = batch["input_ids"].clone()
        # Keep at least one supervised position even if truncation removed
        # the completion, so the loss is never computed over zero tokens.
        # NOTE(review): assumes the tokenizer does not append an end token to
        # the prompt-only encoding; if it does, one completion token is masked.
        n_prompt = max(min(prompt_ids.shape[1], labels.shape[1] - 1), 0)
        labels[:, :n_prompt] = -100

        return self.model(**batch, labels=labels).loss

    def fine_tune_unlearn(self, forget_data: List[Dict], num_epochs: int = 3) -> None:
        """Fine-tune model to forget specific knowledge.

        Trains the model to answer every forget question with
        ``REFUSAL_RESPONSE``.

        Args:
            forget_data: Examples with a ``"question"`` entry (list of dicts,
                or a columnar dict of lists).
            num_epochs: Number of passes over ``forget_data``.
        """
        print("Fine-tuning for unlearning...")

        rows = self._as_rows(forget_data)
        if not rows:
            # Nothing to train on; avoids dividing by zero in the average.
            print("No forget examples provided; skipping fine-tuning.")
            return

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            total_loss = 0.0

            for example in rows:
                question = f"Question: {example['question']}\nAnswer:"
                loss = self._completion_loss(question, self.REFUSAL_RESPONSE)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(rows)
            print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    def gradient_ascent_unlearn(self, forget_data: List[Dict], num_steps: int = 100) -> None:
        """Use gradient ascent to unlearn knowledge.

        Maximises the loss of the correct answer letter by minimising its
        negation, so the printed average is the negated loss.

        Args:
            forget_data: Examples with ``"question"`` and integer ``"answer"``
                entries (list of dicts, or a columnar dict of lists).
            num_steps: Number of passes over ``forget_data``.
        """
        print("Applying gradient ascent unlearning...")

        rows = self._as_rows(forget_data)
        if not rows:
            # Nothing to unlearn; avoids dividing by zero in the average.
            print("No forget examples provided; skipping gradient ascent.")
            return

        optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-6)

        for step in range(num_steps):
            total_loss = 0.0

            for example in rows:
                question = f"Question: {example['question']}\nAnswer:"
                correct_answer = chr(65 + example["answer"])

                loss = -self._completion_loss(question, correct_answer)  # Negative for gradient ascent

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if step % 20 == 0:
                avg_loss = total_loss / len(rows)
                print(f"Step {step}/{num_steps}, Average Loss: {avg_loss:.4f}")


class WMDPExperiment:
    """WMDP experiment runner.

    Loads a model and a WMDP domain, evaluates the untouched model, applies
    each unlearning method to a freshly loaded model, re-evaluates, and writes
    the combined results to ``results/wmdp_<model>_<domain>.json``.
    """

    def __init__(self, config: Dict):
        """Set up logging, model and dataset.

        Args:
            config: Dict with at least ``"model"`` and ``"domain"`` keys.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_logging()
        self.setup_model()
        self.setup_dataset()

    def setup_logging(self):
        """Setup logging to a per-run log file and to the console."""
        log_file = f"wmdp_experiment_{self.config['model']}_{self.config['domain']}.log"
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def _release_model(self) -> None:
        """Drop the currently loaded model, if any, and free cached GPU memory.

        Called before a reload so the old and new weights do not have to fit
        in memory at the same time.
        """
        if getattr(self, "model", None) is None:
            return
        import gc

        self.model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def setup_model(self):
        """Setup model and tokenizer.

        May be called repeatedly to obtain fresh weights; a previously loaded
        model is released first.

        Raises:
            ValueError: If ``config["model"]`` has no entry in the model table.
        """
        model_configs = {
            "zephyr": "SAMPLE_MODEL_LINK",
        }

        model_key = self.config["model"]
        if model_key not in model_configs:
            # Same "Unknown ..." ValueError style as the other lookups in
            # this file, instead of a bare KeyError.
            raise ValueError(
                f"Unknown model: {model_key} (available: {sorted(model_configs)})"
            )

        model_name = model_configs[model_key]
        self.logger.info(f"Loading model: {model_name}")

        self._release_model()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as padding when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

    def setup_dataset(self):
        """Setup WMDP dataset."""
        self.dataset = WMDPDataset(self.config["domain"])

    @staticmethod
    def _first_n(data, n: int):
        """Return the first ``n`` examples of ``data`` as an iterable of rows.

        Slicing a HuggingFace ``Dataset`` yields a dict of columns, whose
        ``len`` is the column count and whose iteration yields column names.
        Objects offering ``select`` are therefore subset through it; plain
        sequences (the local JSON fallback) are sliced.
        """
        if hasattr(data, "select"):
            return data.select(range(min(n, len(data))))
        return data[:n]

    def run_baseline_evaluation(self) -> Dict:
        """Run baseline evaluation without unlearning.

        Returns:
            The evaluator's result dict with ``"method"`` set to ``"baseline"``.
        """
        self.logger.info("Running baseline evaluation...")

        evaluator = WMDPEvaluator(self.model, self.tokenizer, self.dataset)
        results = evaluator.evaluate_model()
        results["method"] = "baseline"

        return results

    def run_unlearning_experiment(self, method: str) -> Dict:
        """Run unlearning experiment.

        Applies ``method`` to the currently loaded model in place, then
        evaluates it on the full test split.

        Args:
            method: ``"fine_tune"`` or ``"gradient_ascent"``.

        Returns:
            The evaluator's result dict plus ``"method"`` and ``"forget_size"``.

        Raises:
            ValueError: If ``method`` is not a known unlearning method.
        """
        self.logger.info(f"Running unlearning with method: {method}")

        # Forget set: the first 10% of the test split, capped at 100 examples.
        test_data = self.dataset.dataset["test"]
        forget_size = min(100, len(test_data) // 10)
        forget_data = self._first_n(test_data, forget_size)

        self.logger.info(f"Forgetting {len(forget_data)} examples")

        unlearner = WMDPUnlearner(self.model, self.tokenizer, self.device)

        if method == "fine_tune":
            unlearner.fine_tune_unlearn(forget_data)
        elif method == "gradient_ascent":
            unlearner.gradient_ascent_unlearn(forget_data)
        else:
            raise ValueError(f"Unknown unlearning method: {method}")

        evaluator = WMDPEvaluator(self.model, self.tokenizer, self.dataset)
        results = evaluator.evaluate_model()
        results["method"] = method
        results["forget_size"] = len(forget_data)

        return results

    def run_experiment(self) -> Dict:
        """Run complete WMDP experiment.

        Returns:
            Dict with ``"baseline"``, one entry per unlearning method (either
            results or ``{"error": ...}``) and ``"config"``; the same dict is
            written to the results JSON file.
        """
        baseline_results = self.run_baseline_evaluation()

        methods = ["fine_tune", "gradient_ascent"]
        unlearning_results = {}

        for method in methods:
            try:
                # Reload model for each method to start fresh
                self.setup_model()
                unlearning_results[method] = self.run_unlearning_experiment(method)
            except Exception as e:
                # Best effort: record the failure (with traceback in the log)
                # and continue with the remaining methods.
                self.logger.exception(f"Error with method {method}: {e}")
                unlearning_results[method] = {"error": str(e)}

        all_results = {
            "baseline": baseline_results,
            **unlearning_results,
            "config": self.config
        }

        results_path = Path(f"results/wmdp_{self.config['model']}_{self.config['domain']}.json")
        results_path.parent.mkdir(parents=True, exist_ok=True)

        with open(results_path, 'w', encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)

        self.logger.info(f"Results saved to {results_path}")
        return all_results


def main(argv: Optional[List[str]] = None) -> None:
    """Parse command-line arguments and run the requested WMDP experiment(s).

    With ``--method all`` the full experiment (baseline plus every unlearning
    method) is run and a summary table is printed; otherwise only the chosen
    method is run and its accuracy is printed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(description="WMDP Experiments for OFMU")
    parser.add_argument("--model", choices=["zephyr", "llama2", "llama3"], default="zephyr",
                       help="Model to use")
    parser.add_argument("--domain", choices=["bio", "cyber", "chem"], default="bio",
                       help="WMDP domain")
    parser.add_argument("--method", choices=["baseline", "fine_tune", "gradient_ascent", "all"],
                       default="all", help="Method to run")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args(argv)

    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    config = {
        "model": args.model,
        "domain": args.domain,
        "seed": args.seed
    }

    # Constructing the experiment loads the model and the dataset.
    experiment = WMDPExperiment(config)

    if args.method == "all":
        results = experiment.run_experiment()

        # Print summary
        print("\n" + "="*60)
        print(f"WMDP EXPERIMENT RESULTS - {args.domain.upper()}")
        print("="*60)
        print(f"Model: {args.model}")
        print("-"*60)
        print(f"{'Method':<15} {'Accuracy':<10} {'Status':<10}")
        print("-"*60)

        for method, result in results.items():
            if method == "config":
                # Not a method entry; it holds the run configuration.
                continue
            if "error" in result:
                print(f"{method:<15} {'--':<10} {'ERROR':<10}")
            else:
                acc = result.get('accuracy', 0)
                print(f"{method:<15} {acc:<10.3f} {'OK':<10}")

        print("="*60)

    else:
        # The model loaded by the constructor is still untouched here, so it
        # is used directly; loading it a second time would only cost time
        # and memory.
        if args.method == "baseline":
            results = experiment.run_baseline_evaluation()
        else:
            results = experiment.run_unlearning_experiment(args.method)

        print(f"\nResults for {args.method}:")
        print(f"  Accuracy: {results['accuracy']:.3f}")
        print(f"  Correct: {results['correct']}/{results['total']}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()