"""
Evaluation metrics for machine unlearning experiments.
Implements metrics for both classification and LLM unlearning tasks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from typing import Dict, List, Tuple, Optional, Any
import logging

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class UnlearningMetrics:
    """Metrics for evaluating machine-unlearning performance.

    Classification metrics (``evaluate_classification``) run the wrapped model
    over DataLoader-style iterables yielding ``(data, target)`` pairs. LLM
    metrics (``evaluate_llm``) operate on lists of dicts that carry a
    ``'prompt'`` key (and, for the retain set, an optional ``'answer'`` key).

    Args:
        model: The (unlearned) model under evaluation. For the classification
            metrics its output is treated as per-class logits along dim 1.
        device: Device each input batch is moved to before the forward pass.
            The models themselves are NOT moved; callers must place ``model``
            (and any reference model passed to ``compute_rud``) on this device.
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device

    def evaluate_classification(self,
                                forget_loader: torch.utils.data.DataLoader,
                                retain_loader: torch.utils.data.DataLoader,
                                test_loader: torch.utils.data.DataLoader,
                                original_model: Optional[nn.Module] = None) -> Dict[str, float]:
        """
        Evaluate classification unlearning performance.

        Args:
            forget_loader: Samples the model is supposed to have forgotten.
            retain_loader: Samples whose knowledge should be preserved.
            test_loader: Samples used to compare against ``original_model``.
            original_model: Reference model for RUD; RUD is skipped when None.

        Returns:
            Dict with keys ``'acc_rt'`` (retain accuracy), ``'acc_ft'`` (forget
            accuracy), ``'mia'`` (membership-inference attack accuracy) and,
            only when ``original_model`` is given, ``'rud'``.
        """
        metrics = {}

        # Retain Accuracy (Acc_rt)
        metrics['acc_rt'] = self.compute_accuracy(retain_loader)

        # Forget Accuracy (Acc_ft) - should be low
        metrics['acc_ft'] = self.compute_accuracy(forget_loader)

        # Membership Inference Attack (MIA) - should be low
        metrics['mia'] = self.compute_mia(forget_loader, retain_loader)

        # Retain Utility Deviation (RUD) - should be low
        if original_model is not None:
            metrics['rud'] = self.compute_rud(test_loader, original_model)

        return metrics

    def evaluate_llm(self,
                     forget_data: List[Dict],
                     retain_data: List[Dict],
                     retrain_model: Optional[Any] = None) -> Dict[str, float]:
        """
        Evaluate LLM unlearning performance (TOFU-style).

        Args:
            forget_data: Dicts with a ``'prompt'`` key from the forget set.
            retain_data: Dicts with ``'prompt'`` and optional ``'answer'`` keys.
            retrain_model: Model retrained without the forget set; forget
                quality is only computed when it is provided.

        Returns:
            Dict with ``'model_utility'`` and, when ``retrain_model`` is
            given, ``'forget_quality'``.
        """
        metrics = {}

        # Forget Quality - similarity with the retrained model's outputs
        if retrain_model is not None:
            metrics['forget_quality'] = self.compute_forget_quality(forget_data, retrain_model)

        # Model Utility - performance on retain set and general knowledge
        metrics['model_utility'] = self.compute_model_utility(retain_data)

        return metrics

    def compute_accuracy(self, data_loader: torch.utils.data.DataLoader) -> float:
        """Return top-1 accuracy of ``self.model`` over ``data_loader``.

        The predicted class is ``argmax`` over dim 1 of the model output.
        Puts the model in eval mode. Returns 0.0 for an empty loader.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(self.device), target.to(self.device)
                pred = self.model(data).argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        return correct / total if total > 0 else 0.0

    def _max_softmax_confidences(self, data_loader: torch.utils.data.DataLoader) -> list:
        """Return the model's max softmax probability for every sample in ``data_loader``."""
        confidences = []
        with torch.no_grad():
            for data, _target in data_loader:
                output = self.model(data.to(self.device))
                max_probs = torch.max(F.softmax(output, dim=1), dim=1)[0]
                confidences.extend(max_probs.cpu().numpy())
        return confidences

    def compute_mia(self,
                    forget_loader: torch.utils.data.DataLoader,
                    retain_loader: torch.utils.data.DataLoader,
                    num_shadow_models: int = 1) -> float:
        """
        Compute membership-inference attack accuracy.

        A logistic-regression attacker is trained on a single feature, the
        model's max softmax confidence, to separate forget samples (label 1)
        from retain samples (label 0). It is evaluated on a 30% hold-out split.
        Because the classes are usually imbalanced, the chance level is the
        majority-class rate rather than 0.5.

        Args:
            forget_loader: Forget-set samples (positive class).
            retain_loader: Retain-set samples (negative class).
            num_shadow_models: Currently unused; kept for API compatibility.

        Returns:
            Attack accuracy on the hold-out split. Returns 0.0 if either set
            is empty or if the split/fit fails (a warning is logged).
        """
        self.model.eval()

        forget_confidences = self._max_softmax_confidences(forget_loader)
        retain_confidences = self._max_softmax_confidences(retain_loader)

        # Prepare data for MIA classifier
        X = np.array(forget_confidences + retain_confidences).reshape(-1, 1)
        y = np.array([1] * len(forget_confidences) + [0] * len(retain_confidences))

        if len(np.unique(y)) < 2:
            return 0.0

        # Stratify so both folds contain forget and retain samples; without it a
        # small forget set can vanish from the training fold and the fit fails.
        # Stratified splitting needs at least 2 samples per class.
        min_class_size = min(len(forget_confidences), len(retain_confidences))
        stratify = y if min_class_size >= 2 else None

        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.3, random_state=42, stratify=stratify)

            mia_classifier = LogisticRegression(random_state=42)
            mia_classifier.fit(X_train, y_train)

            # Evaluate attack success rate
            y_pred = mia_classifier.predict(X_test)
            return accuracy_score(y_test, y_pred)
        except Exception as e:
            logger.warning(f"MIA computation failed: {e}")
            return 0.0

    def compute_rud(self,
                    test_loader: torch.utils.data.DataLoader,
                    original_model: nn.Module,
                    k_neighbors: int = 5) -> float:
        """
        Compute Retain Utility Deviation (RUD).

        Returns the mean, over all samples in ``test_loader``, of the
        per-sample KL divergence KL(original || current) between the softmax
        predictions of ``original_model`` and of ``self.model``.

        Args:
            test_loader: Iterable of ``(data, target)`` batches; targets are ignored.
            original_model: Reference model, assumed already on ``self.device``.
            k_neighbors: Currently unused; kept for API compatibility.

        Returns:
            Mean KL divergence, or 0.0 for an empty loader.
        """
        self.model.eval()
        original_model.eval()

        total_deviation = 0.0
        total_samples = 0

        with torch.no_grad():
            for data, _target in test_loader:
                data = data.to(self.device)

                # log_softmax works in log space and stays finite even when a
                # probability underflows to 0; softmax(...).log() would yield
                # -inf there, and F.kl_div would then produce nan (0 * -inf).
                current_log_probs = F.log_softmax(self.model(data), dim=1)
                original_probs = F.softmax(original_model(data), dim=1)

                # F.kl_div(input=log Q, target=P) computes P * (log P - log Q)
                # element-wise; summing over classes gives KL(P || Q) with
                # P = original predictions and Q = current predictions.
                kl_div = F.kl_div(current_log_probs, original_probs, reduction='none')
                kl_div = kl_div.sum(dim=1)

                total_deviation += kl_div.sum().item()
                total_samples += data.size(0)

        return total_deviation / total_samples if total_samples > 0 else 0.0

    def compute_forget_quality(self,
                               forget_data: List[Dict],
                               retrain_model: Any) -> float:
        """
        Compute forget quality for LLM unlearning.

        For each forget sample, generates a response from the unlearned and the
        retrained model and scores their token-level Jaccard similarity. Returns
        the mean similarity. Returns 0.5 when ``self.model`` has no ``generate``
        attribute or when no sample could be scored.

        NOTE(review): ``_generate_response`` is currently a placeholder whose
        output depends only on the prompt, so both models yield identical text
        and every per-sample similarity is 1.0.
        """
        if not hasattr(self.model, 'generate'):
            logger.warning("Model does not support generation, returning default forget quality")
            return 0.5

        similarities = []

        for sample in forget_data:
            prompt = sample['prompt']

            try:
                unlearned_response = self._generate_response(self.model, prompt)
                retrain_response = self._generate_response(retrain_model, prompt)
                similarities.append(
                    self._compute_text_similarity(unlearned_response, retrain_response))
            except Exception as e:
                # Best-effort: skip samples that fail rather than abort the evaluation.
                logger.warning(f"Failed to compute forget quality for sample: {e}")
                continue

        return np.mean(similarities) if similarities else 0.5

    def compute_model_utility(self, retain_data: List[Dict]) -> float:
        """
        Compute model utility for LLM unlearning.

        Scores each retain-set response against its ``'answer'`` (empty if
        missing) and combines retain and general-knowledge performance with a
        harmonic mean. General-knowledge performance is currently set equal to
        retain performance, so the result equals the mean retain score.
        Returns 0.5 when ``self.model`` has no ``generate`` attribute.
        """
        if not hasattr(self.model, 'generate'):
            logger.warning("Model does not support generation, returning default utility")
            return 0.5

        retain_scores = []

        for sample in retain_data:
            prompt = sample['prompt']
            expected_answer = sample.get('answer', '')

            try:
                response = self._generate_response(self.model, prompt)
                retain_scores.append(self._score_response(response, expected_answer))
            except Exception as e:
                # Best-effort: skip samples that fail rather than abort the evaluation.
                logger.warning(f"Failed to compute utility for sample: {e}")
                continue

        retain_performance = np.mean(retain_scores) if retain_scores else 0.5

        # For now, assume general knowledge performance is the same as retain performance.
        # In practice, this would be evaluated on a separate general knowledge dataset.
        general_knowledge_performance = retain_performance

        # Harmonic mean
        if retain_performance + general_knowledge_performance == 0:
            return 0.0

        return 2 * (retain_performance * general_knowledge_performance) / (
            retain_performance + general_knowledge_performance
        )

    def _generate_response(self, model: Any, prompt: str, max_length: int = 100) -> str:
        """Generate a response from ``model`` for ``prompt``.

        Placeholder: ignores ``model`` and ``max_length`` and returns a string
        built from the first 50 characters of the prompt.
        """
        return f"Generated response for: {prompt[:50]}..."

    def _compute_text_similarity(self, text1: str, text2: str) -> float:
        """Return the Jaccard similarity of the lower-cased whitespace tokens.

        Two empty texts score 1.0; exactly one empty text scores 0.0.
        """
        tokens1 = set(text1.lower().split())
        tokens2 = set(text2.lower().split())

        if not tokens1 and not tokens2:
            return 1.0
        if not tokens1 or not tokens2:
            return 0.0

        # Both sets are non-empty here, so the union is never empty.
        return len(tokens1 & tokens2) / len(tokens1 | tokens2)

    def _score_response(self, response: str, expected: str) -> float:
        """Score ``response`` against ``expected`` (case-insensitive).

        Returns 1.0 if ``expected`` occurs as a substring of ``response``
        (always true for an empty ``expected``). Otherwise, returns the fraction
        of expected words that appear among the response's words.
        """
        response_lower = response.lower()
        expected_lower = expected.lower()

        if expected_lower in response_lower:
            return 1.0

        # Partial credit based on word overlap
        response_words = set(response_lower.split())
        expected_words = set(expected_lower.split())

        if not expected_words:
            return 1.0  # If no expected answer, give full credit

        overlap = len(response_words & expected_words)
        return overlap / len(expected_words)


class ResultLogger:
    """Append experiment results to JSON files and write CSV summaries.

    All files are written under ``log_dir``, which is created on construction
    if it does not already exist.
    """

    def __init__(self, log_dir: str = "results"):
        import os

        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

    def log_results(self,
                    experiment_name: str,
                    method_name: str,
                    dataset: str,
                    metrics: Dict[str, float],
                    hyperparams: Optional[Dict[str, Any]] = None):
        """Append one result entry to ``<log_dir>/<experiment>_<dataset>_results.json``.

        The file holds a JSON list. Any existing list is read, the new entry
        (timestamp, experiment, method, dataset, metrics, hyperparameters) is
        appended, and the whole list is written back.

        Raises:
            TypeError: if ``metrics`` or ``hyperparams`` contain values the
                ``json`` module cannot serialize. The existing results file is
                left untouched in that case.
            json.JSONDecodeError: if the existing file is not valid JSON.
        """
        import datetime
        import json
        import os

        result_entry = {
            'timestamp': datetime.datetime.now().isoformat(),
            'experiment': experiment_name,
            'method': method_name,
            'dataset': dataset,
            'metrics': metrics,
            'hyperparameters': hyperparams or {}
        }

        json_file = os.path.join(self.log_dir, f"{experiment_name}_{dataset}_results.json")

        # Read existing results (a missing file just means this is the first entry).
        results = []
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                results = json.load(f)
        except FileNotFoundError:
            pass

        results.append(result_entry)

        # Serialize BEFORE opening for writing: mode 'w' truncates the file, so a
        # serialization error raised inside json.dump would wipe every previously
        # logged result.
        payload = json.dumps(results, indent=2)
        with open(json_file, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"Results logged to {json_file}")

    def log_to_csv(self,
                   experiment_name: str,
                   results_data: List[Dict[str, Any]]):
        """Write ``results_data`` (one dict per row) to ``<log_dir>/<experiment>_summary.csv``.

        Overwrites any existing summary file; the DataFrame index is not written.
        """
        import os

        import pandas as pd

        df = pd.DataFrame(results_data)
        csv_file = os.path.join(self.log_dir, f"{experiment_name}_summary.csv")
        df.to_csv(csv_file, index=False)

        logger.info(f"Summary results saved to {csv_file}")