#!/usr/bin/env python3
"""
Usage:
    from evaluation.evaluator import TOFUEvaluator, CIFAREvaluator, WMDPEvaluator
    from evaluation.metrics import compute_forget_quality, compute_mia_efficacy
    from evaluation.analysis import generate_comparison_plots
"""

import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Union
from abc import ABC, abstractmethod
import json
from pathlib import Path


class BaseEvaluator(ABC):
    """Abstract base class for all evaluators."""
    
    def __init__(self, model, tokenizer=None, dataset=None):
        self.model = model
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.device = next(model.parameters()).device if hasattr(model, 'parameters') else 'cpu'
    
    @abstractmethod
    def evaluate(self) -> Dict:
        """Run evaluation and return metrics."""
        pass


class TOFUEvaluator(BaseEvaluator):
    """Evaluator for TOFU benchmark with FQ, MU, and FTR metrics."""
    
    def __init__(self, model, tokenizer, dataset, max_new_tokens=50):
        super().__init__(model, tokenizer, dataset)
        self.max_new_tokens = max_new_tokens
    
    def compute_forget_quality(self, forget_set: List[Dict]) -> float:
        """
        Compute Forget Quality (FQ) - measures forgetting efficacy.
        """
        total_loss = 0.0
        total_samples = 0
        
        self.model.eval()
        with torch.no_grad():
            for example in forget_set:
                question = example["question"]
                answer = example["answer"]
                
                # Format input
                if "chat" in self.tokenizer.name_or_path.lower():
                    input_text = f"User: {question}\nAssistant: {answer}"
                else:
                    input_text = f"Question: {question}\nAnswer: {answer}"
                
                # Tokenize
                inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Compute loss
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss
                
                total_loss += loss.item()
                total_samples += 1
        
        # Return negative log-likelihood (higher = better forgetting)
        avg_loss = total_loss / total_samples if total_samples > 0 else float('inf')
        forget_quality = np.exp(-avg_loss)  # Convert to probability-like metric
        
        return forget_quality
    
    def compute_model_utility(self, retain_set: List[Dict]) -> float:
        correct = 0
        total = 0
        
        self.model.eval()
        with torch.no_grad():
            for example in retain_set:
                question = example["question"]
                correct_answer = example["answer"]
                
                # Generate response
                if "chat" in self.tokenizer.name_or_path.lower():
                    prompt = f"User: {question}\nAssistant:"
                else:
                    prompt = f"Question: {question}\nAnswer:"
                
                inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Generate
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    do_sample=False,
                    temperature=0.0,
                    pad_token_id=self.tokenizer.eos_token_id
                )
                
                # Decode response
                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                response = generated_text[len(prompt):].strip()
                
                # Check if correct (simple string matching)
                if self._check_answer_correctness(response, correct_answer):
                    correct += 1
                total += 1
        
        utility = correct / total if total > 0 else 0.0
        return utility
    
    def compute_forget_truth_ratio(self, forget_set: List[Dict]) -> float:
        truthful_responses = 0
        total_samples = 0
        
        self.model.eval()
        with torch.no_grad():
            for example in forget_set:
                question = example["question"]
                correct_answer = example["answer"]
                
                # Generate response
                if "chat" in self.tokenizer.name_or_path.lower():
                    prompt = f"User: {question}\nAssistant:"
                else:
                    prompt = f"Question: {question}\nAnswer:"
                
                inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Generate
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    do_sample=False,
                    temperature=0.0,
                    pad_token_id=self.tokenizer.eos_token_id
                )
                
                # Decode response
                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                response = generated_text[len(prompt):].strip()
                
                # Check if response contains truthful information
                if self._check_answer_correctness(response, correct_answer):
                    truthful_responses += 1
                total_samples += 1
        
        ftr = truthful_responses / total_samples if total_samples > 0 else 0.0
        return ftr
    
    def _check_answer_correctness(self, response: str, correct_answer: str) -> bool:
        response_lower = response.lower().strip()
        answer_lower = correct_answer.lower().strip()
        
        # Simple substring matching
        if answer_lower in response_lower:
            return True
        
        # Check for key phrases
        answer_words = set(answer_lower.split())
        response_words = set(response_lower.split())
        
        # If most answer words are present, consider it correct
        if len(answer_words) > 0:
            overlap = len(answer_words.intersection(response_words))
            if overlap / len(answer_words) > 0.7:  # 70% word overlap threshold
                return True
        
        return False
    
    def evaluate(self) -> Dict:
        """Run complete TOFU evaluation."""
        if not hasattr(self.dataset, 'forget_set') or not hasattr(self.dataset, 'retain_set'):
            raise ValueError("Dataset must have 'forget_set' and 'retain_set' attributes")
        
        print("Computing Forget Quality...")
        forget_quality = self.compute_forget_quality(self.dataset.forget_set)
        
        print("Computing Model Utility...")
        model_utility = self.compute_model_utility(self.dataset.retain_set)
        
        print("Computing Forget Truth Ratio...")
        forget_truth_ratio = self.compute_forget_truth_ratio(self.dataset.forget_set)
        
        results = {
            "forget_quality": forget_quality,
            "model_utility": model_utility,
            "forget_truth_ratio": forget_truth_ratio,
            "num_forget_samples": len(self.dataset.forget_set),
            "num_retain_samples": len(self.dataset.retain_set)
        }
        
        print(f"Evaluation Results:")
        print(f"  Forget Quality: {forget_quality:.4f}")
        print(f"  Model Utility: {model_utility:.4f}") 
        print(f"  Forget Truth Ratio: {forget_truth_ratio:.4f}")
        
        return results


class CIFAREvaluator(BaseEvaluator):
    """Evaluator for CIFAR experiments with UA, RA, TA, and MIA metrics."""
    
    def __init__(self, model, forget_dataset, retain_dataset, test_dataset):
        super().__init__(model)
        self.forget_dataset = forget_dataset
        self.retain_dataset = retain_dataset
        self.test_dataset = test_dataset
    
    def compute_accuracy(self, dataset, dataset_name: str) -> Tuple[float, int, int]:
        """Compute accuracy on a given dataset."""
        self.model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in dataset:
                if not isinstance(inputs, torch.Tensor):
                    inputs = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
                if not isinstance(targets, torch.Tensor):
                    targets = targets[1] if isinstance(targets, (list, tuple)) else targets
                
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        accuracy = 100. * correct / total if total > 0 else 0.0
        print(f"{dataset_name} Accuracy: {accuracy:.2f}% ({correct}/{total})")
        
        return accuracy, correct, total
    
    def compute_mia_efficacy(self) -> float:
        """
        Compute Membership Inference Attack (MIA) efficacy.
        """
        # Simple confidence-based MIA
        self.model.eval()
        
        # Get confidence scores for forget set
        forget_confidences = []
        with torch.no_grad():
            for inputs, targets in self.forget_dataset:
                if not isinstance(inputs, torch.Tensor):
                    inputs = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
                if not isinstance(targets, torch.Tensor):
                    targets = targets[1] if isinstance(targets, (list, tuple)) else targets
                    
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                probs = F.softmax(outputs, dim=1)
                max_probs = probs.max(dim=1)[0]
                forget_confidences.extend(max_probs.cpu().numpy())
        
        # Get confidence scores for a sample of retain set (same size as forget set)
        retain_confidences = []
        retain_loader_iter = iter(self.retain_dataset)
        samples_needed = len(forget_confidences)
        samples_collected = 0
        
        with torch.no_grad():
            while samples_collected < samples_needed:
                try:
                    inputs, targets = next(retain_loader_iter)
                    if not isinstance(inputs, torch.Tensor):
                        inputs = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
                    
                    inputs = inputs.to(self.device)
                    outputs = self.model(inputs)
                    probs = F.softmax(outputs, dim=1)
                    max_probs = probs.max(dim=1)[0]
                    
                    batch_confidences = max_probs.cpu().numpy()
                    needed_from_batch = min(len(batch_confidences), samples_needed - samples_collected)
                    retain_confidences.extend(batch_confidences[:needed_from_batch])
                    samples_collected = len(retain_confidences)
                    
                except StopIteration:
                    # Restart iterator if we run out of data
                    retain_loader_iter = iter(self.retain_dataset)
        
        # Threshold-based MIA attack
        # Assume lower confidence indicates successful unlearning
        threshold = np.median(retain_confidences)
        
        # Count how many forget samples have confidence below threshold
        unlearned_count = sum(1 for conf in forget_confidences if conf < threshold)
        
        # MIA efficacy: fraction of forget samples that appear "unlearned"
        mia_efficacy = unlearned_count / len(forget_confidences) * 100
        
        return mia_efficacy
    
    def evaluate(self) -> Dict:
        """Run complete CIFAR evaluation."""
        print("Evaluating CIFAR model...")
        
        # Compute accuracies
        forget_acc, forget_correct, forget_total = self.compute_accuracy(self.forget_dataset, "Forget")
        retain_acc, retain_correct, retain_total = self.compute_accuracy(self.retain_dataset, "Retain")
        test_acc, test_correct, test_total = self.compute_accuracy(self.test_dataset, "Test")
        
        # Compute MIA efficacy
        print("Computing MIA efficacy...")
        mia_efficacy = self.compute_mia_efficacy()
        
        # Unlearning accuracy (inverse of forget accuracy)
        unlearning_accuracy = 100 - forget_acc
        
        results = {
            "forget_accuracy": forget_acc,
            "retain_accuracy": retain_acc,
            "test_accuracy": test_acc,
            "unlearning_accuracy": unlearning_accuracy,
            "mia_efficacy": mia_efficacy,
            "forget_correct": forget_correct,
            "forget_total": forget_total,
            "retain_correct": retain_correct,
            "retain_total": retain_total,
            "test_correct": test_correct,
            "test_total": test_total
        }
        
        print(f"Evaluation Results:")
        print(f"  Unlearning Accuracy: {unlearning_accuracy:.2f}%")
        print(f"  Retain Accuracy: {retain_acc:.2f}%")
        print(f"  Test Accuracy: {test_acc:.2f}%")
        print(f"  MIA Efficacy: {mia_efficacy:.2f}%")
        
        return results


class WMDPEvaluator(BaseEvaluator):
    """Evaluator for WMDP benchmark with QA accuracy metrics."""
    
    def __init__(self, model, tokenizer, dataset, max_new_tokens=10):
        super().__init__(model, tokenizer, dataset)
        self.max_new_tokens = max_new_tokens
    
    def evaluate_qa_accuracy(self, questions: List[Dict]) -> Dict:
        """Evaluate QA accuracy on WMDP questions."""
        correct = 0
        total = 0
        detailed_results = []
        
        self.model.eval()
        with torch.no_grad():
            for i, example in enumerate(questions):
                if i % 50 == 0:
                    print(f"Progress: {i}/{len(questions)}")
                
                question = example["question"]
                choices = example["choices"]
                correct_answer_idx = example["answer"]
                correct_answer = chr(65 + correct_answer_idx)  # Convert to A, B, C, D
                
                # Format question with choices
                formatted_question = f"Question: {question}\n\nOptions:\n"
                for j, choice in enumerate(choices):
                    formatted_question += f"{chr(65 + j)}. {choice}\n"
                formatted_question += "\nAnswer:"
                
                # Get model prediction
                predicted_answer = self._predict_answer(formatted_question)
                
                # Check if correct
                is_correct = predicted_answer.strip().upper() == correct_answer
                if is_correct:
                    correct += 1
                total += 1
                
                detailed_results.append({
                    "question_id": i,
                    "question": question,
                    "choices": choices,
                    "correct_answer": correct_answer,
                    "predicted_answer": predicted_answer.strip().upper(),
                    "is_correct": is_correct
                })
        
        accuracy = correct / total if total > 0 else 0.0
        
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "detailed_results": detailed_results
        }
    
    def _predict_answer(self, question: str) -> str:
        """Get model prediction for a multiple choice question."""
        try:
            # Tokenize input
            inputs = self.tokenizer(question, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # Generate response
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=0.0,
                pad_token_id=self.tokenizer.eos_token_id
            )
            
            # Decode response
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = generated_text[len(question):].strip()
            
            # Extract answer (should be A, B, C, or D)
            if len(response) > 0:
                first_char = response[0].upper()
                if first_char in ['A', 'B', 'C', 'D']:
                    return first_char
            
            # Fallback: try to find A, B, C, D in response
            for char in ['A', 'B', 'C', 'D']:
                if char in response.upper():
                    return char
                    
            return "A"  # Default fallback
            
        except Exception as e:
            print(f"Error in prediction: {e}")
            return "A"
    
    def evaluate(self) -> Dict:
        """Run complete WMDP evaluation."""
        if not hasattr(self.dataset, 'dataset') or 'test' not in self.dataset.dataset:
            raise ValueError("Dataset must have 'dataset' attribute with 'test' key")
        
        test_questions = self.dataset.dataset["test"]
        print(f"Evaluating on {len(test_questions)} WMDP questions...")
        
        results = self.evaluate_qa_accuracy(test_questions)
        results["domain"] = getattr(self.dataset, 'domain', 'unknown')
        
        print(f"WMDP Evaluation Results:")
        print(f"  Domain: {results['domain']}")
        print(f"  Accuracy: {results['accuracy']:.3f}")
        print(f"  Correct: {results['correct']}/{results['total']}")
        
        return results


def compute_statistical_significance(results1: Dict, results2: Dict, metric: str, 
                                   test_type: str = "ttest") -> Dict:
    """
    Compute statistical significance between two sets of results.
    """
    try:
        from scipy import stats
        
        # Extract values
        val1 = results1.get(metric, 0)
        val2 = results2.get(metric, 0)
        
        # For single values, we can't compute statistical significance
        # In practice, you'd need multiple runs or bootstrapping
        if isinstance(val1, (int, float)) and isinstance(val2, (int, float)):
            # Simple comparison
            difference = abs(val1 - val2)
            relative_diff = difference / max(val1, val2, 1e-8)
            
            return {
                "metric": metric,
                "value1": val1,
                "value2": val2,
                "difference": difference,
                "relative_difference": relative_diff,
                "is_significant": relative_diff > 0.05,  # 5% threshold
                "note": "Single-value comparison, no statistical test performed"
            }
        
        # For detailed results with multiple samples
        if "detailed_results" in results1 and "detailed_results" in results2:
            # Extract binary outcomes for significance testing
            outcomes1 = [r.get("is_correct", False) for r in results1["detailed_results"]]
            outcomes2 = [r.get("is_correct", False) for r in results2["detailed_results"]]
            
            if test_type == "ttest":
                stat, p_value = stats.ttest_ind(outcomes1, outcomes2)
            elif test_type == "wilcoxon":
                stat, p_value = stats.ranksums(outcomes1, outcomes2)
            else:
                # Default to proportion test
                count1, n1 = sum(outcomes1), len(outcomes1)
                count2, n2 = sum(outcomes2), len(outcomes2)
                stat, p_value = stats.proportions_ztest([count1, count2], [n1, n2])
            
            return {
                "metric": metric,
                "test_type": test_type,
                "statistic": stat,
                "p_value": p_value,
                "is_significant": p_value < 0.05,
                "alpha": 0.05
            }
    
    except ImportError:
        return {"error": "scipy not available for statistical testing"}
    except Exception as e:
        return {"error": f"Statistical test failed: {e}"}
    
    return {"error": "Unable to perform statistical test"}


def save_evaluation_results(results: Dict, output_path: Union[str, Path]):
    """Save evaluation results to JSON file."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"Evaluation results saved to: {output_path}")


def load_evaluation_results(input_path: Union[str, Path]) -> Dict:
    """Load evaluation results from JSON file."""
    input_path = Path(input_path)
    
    if not input_path.exists():
        raise FileNotFoundError(f"Results file not found: {input_path}")
    
    with open(input_path, 'r') as f:
        results = json.load(f)
    
    return results