"""
Evaluation Framework for Cross-Modal Adversarial Training (CMAT)
Generated by AI Research Agent for Agents4Science 2025
"""

import torch
import torch.nn as nn
import numpy as np
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd

class Evaluator:
    """Evaluate a three-modality classifier on clean and adversarial inputs.

    The wrapped model is called as ``model(face, voice, behavioral)`` and must
    return a 2-tuple whose first element holds the class logits.  Every batch
    produced by a loader is indexed with the keys ``'face'``, ``'voice'``,
    ``'behavioral'`` and (except for latency measurement) ``'label'``.

    All public evaluation methods accept a ``preprocessor`` argument that is
    kept for interface compatibility; none of them uses it.
    """

    # Order in which modality tensors are passed to the model.
    _MODALITIES = ('face', 'voice', 'behavioral')

    def __init__(self, model, device):
        """Store ``model`` and ``device`` and switch the model to eval mode."""
        self.model = model
        self.device = device
        self.model.eval()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _to_device(self, batch):
        """Return the (face, voice, behavioral) tensors of ``batch`` on ``self.device``."""
        return tuple(batch[key].to(self.device) for key in self._MODALITIES)

    def _predict(self, face, voice, behavioral):
        """Return the arg-max class index of the model's logits for each sample."""
        logits, _ = self.model(face, voice, behavioral)
        return logits.argmax(dim=1)

    @staticmethod
    def _percentage(count, total):
        """Return ``100 * count / total``, or 0.0 when ``total`` is zero."""
        return 100 * count / total if total else 0.0

    def _synchronize(self):
        """Wait for pending CUDA work, but only when evaluating on a CUDA device."""
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    def _input_gradients(self, face, voice, behavioral, labels, wrt):
        """Return the cross-entropy gradient w.r.t. every tensor in ``wrt``.

        Gradient tracking is enabled explicitly so attack generation keeps
        working when the caller is inside a ``torch.no_grad()`` block.  A
        single ``autograd.grad`` call is used so the graph is released as soon
        as the gradients are available.
        """
        with torch.enable_grad():
            logits, _ = self.model(face, voice, behavioral)
            loss = nn.functional.cross_entropy(logits, labels)
            return torch.autograd.grad(loss, wrt)

    # ------------------------------------------------------------------
    # Accuracy evaluation
    # ------------------------------------------------------------------
    def evaluate_clean(self, test_loader, preprocessor):
        """Return the top-1 accuracy (in percent) on unperturbed data.

        Args:
            test_loader: Iterable of batches (see class docstring).
            preprocessor: Unused; kept for interface compatibility.

        Returns:
            Accuracy in percent; 0.0 if the loader yields no samples.
        """
        print("Evaluating on clean data...")

        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Clean Evaluation"):
                face, voice, behavioral = self._to_device(batch)
                labels = batch['label'].to(self.device)

                predicted = self._predict(face, voice, behavioral)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = self._percentage(correct, total)
        print(f"Clean Accuracy: {accuracy:.2f}%")

        return accuracy

    def evaluate_adversarial(self, test_loader, preprocessor, attack_type='pgd', epsilon=0.03, num_steps=10):
        """Return accuracy and prediction-flip rate under a PGD or FGSM attack.

        Args:
            test_loader: Iterable of batches (see class docstring).
            preprocessor: Unused; kept for interface compatibility.
            attack_type: ``'pgd'`` or ``'fgsm'``.
            epsilon: L-infinity perturbation budget applied to every modality.
            num_steps: Number of PGD iterations (ignored for FGSM).

        Returns:
            ``(accuracy, success_rate)`` in percent.  ``success_rate`` is the
            share of samples whose predicted class differs between the clean
            and the adversarial input, regardless of the true label.  Both
            values are 0.0 if the loader yields no samples.

        Raises:
            ValueError: If ``attack_type`` is not recognised.
        """
        # Validate up front so a bad argument fails before any data is touched.
        if attack_type == 'pgd':
            def craft(face, voice, behavioral, labels):
                return self._generate_pgd_attack(
                    face, voice, behavioral, labels, epsilon, num_steps
                )
        elif attack_type == 'fgsm':
            def craft(face, voice, behavioral, labels):
                return self._generate_fgsm_attack(
                    face, voice, behavioral, labels, epsilon
                )
        else:
            raise ValueError(f"Unknown attack type: {attack_type}")

        print(f"Evaluating on {attack_type.upper()} adversarial data...")

        correct = 0
        total = 0
        attack_success = 0

        for batch in tqdm(test_loader, desc="Adversarial Evaluation"):
            face, voice, behavioral = self._to_device(batch)
            labels = batch['label'].to(self.device)

            face_adv, voice_adv, behavioral_adv = craft(face, voice, behavioral, labels)

            with torch.no_grad():
                predicted_clean = self._predict(face, voice, behavioral)
                predicted_adv = self._predict(face_adv, voice_adv, behavioral_adv)

            total += labels.size(0)
            correct += (predicted_adv == labels).sum().item()
            attack_success += (predicted_clean != predicted_adv).sum().item()

        accuracy = self._percentage(correct, total)
        success_rate = self._percentage(attack_success, total)

        print(f"Adversarial Accuracy: {accuracy:.2f}%")
        print(f"Attack Success Rate: {success_rate:.2f}%")

        return accuracy, success_rate

    def evaluate_cross_modal_adversarial(self, test_loader, preprocessor, epsilon=0.03):
        """Return accuracy (in percent) when all modalities are attacked jointly.

        Args:
            test_loader: Iterable of batches (see class docstring).
            preprocessor: Unused; kept for interface compatibility.
            epsilon: L-infinity perturbation budget applied to every modality.

        Returns:
            Accuracy in percent; 0.0 if the loader yields no samples.
        """
        print("Evaluating on cross-modal adversarial attacks...")

        correct = 0
        total = 0

        for batch in tqdm(test_loader, desc="Cross-Modal Adversarial Evaluation"):
            face, voice, behavioral = self._to_device(batch)
            labels = batch['label'].to(self.device)

            face_adv, voice_adv, behavioral_adv = self._generate_cross_modal_attack(
                face, voice, behavioral, labels, epsilon
            )

            with torch.no_grad():
                predicted = self._predict(face_adv, voice_adv, behavioral_adv)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = self._percentage(correct, total)
        print(f"Cross-Modal Adversarial Accuracy: {accuracy:.2f}%")

        return accuracy

    # ------------------------------------------------------------------
    # Latency
    # ------------------------------------------------------------------
    def measure_latency(self, test_loader, preprocessor, num_samples=100):
        """Return the mean forward-pass latency in milliseconds.

        One forward pass is timed per batch, so ``num_samples`` is the maximum
        number of *batches* that are timed.  Ten untimed warm-up passes are run
        on the first batch.

        Args:
            test_loader: Iterable of batches; the ``'label'`` key is not read.
            preprocessor: Unused; kept for interface compatibility.
            num_samples: Maximum number of batches to time.

        Returns:
            Mean latency in milliseconds, or NaN if nothing was timed.
        """
        print("Measuring inference latency...")

        latencies = []

        with torch.no_grad():
            for i, batch in enumerate(test_loader):
                if i >= num_samples:
                    break

                face, voice, behavioral = self._to_device(batch)

                # Warm up once so lazy initialisation does not pollute timings.
                if i == 0:
                    for _ in range(10):
                        self.model(face, voice, behavioral)

                # perf_counter is monotonic and high resolution, unlike time.time.
                self._synchronize()
                start_time = time.perf_counter()

                self.model(face, voice, behavioral)

                self._synchronize()
                elapsed_ms = (time.perf_counter() - start_time) * 1000
                latencies.append(elapsed_ms)

        if not latencies:
            print("No batches were timed; latency is undefined.")
            return float('nan')

        avg_latency = np.mean(latencies)
        std_latency = np.std(latencies)

        print(f"Average Latency: {avg_latency:.2f} ± {std_latency:.2f} ms")

        return avg_latency

    # ------------------------------------------------------------------
    # Transferability
    # ------------------------------------------------------------------
    def evaluate_transferability(self, test_loader, preprocessor, epsilon=0.03):
        """Return how often perturbing a single modality changes the prediction.

        For every source modality a one-step attack perturbs only that
        modality, and the model is re-evaluated with the other two left clean.
        The result has one ``'<source>_to_<target>'`` entry per ordered pair of
        distinct modalities.  The target part of the key does not influence
        the computation, so the two entries sharing a source always hold the
        same value.

        Args:
            test_loader: Iterable of batches (see class docstring).
            preprocessor: Unused; kept for interface compatibility.
            epsilon: L-infinity budget of the single-modality attack.

        Returns:
            Dict mapping each key to a percentage of samples whose prediction
            changed; all values are 0.0 if the loader yields no samples.
        """
        print("Evaluating adversarial transferability...")

        flipped = dict.fromkeys(self._MODALITIES, 0)
        total_samples = 0

        for batch in tqdm(test_loader, desc="Transferability Evaluation"):
            face, voice, behavioral = self._to_device(batch)
            labels = batch['label'].to(self.device)

            total_samples += face.size(0)

            clean = {'face': face, 'voice': voice, 'behavioral': behavioral}
            perturbed = {
                source: self._generate_single_modal_attack(
                    face, voice, behavioral, labels, source, epsilon
                )
                for source in self._MODALITIES
            }

            with torch.no_grad():
                pred_orig = self._predict(face, voice, behavioral)

                # One forward pass per source modality is sufficient: both
                # keys derived from the same source use identical inputs.
                for source in self._MODALITIES:
                    inputs = dict(clean)
                    inputs[source] = perturbed[source]
                    pred = self._predict(**inputs)
                    flipped[source] += (pred_orig != pred).sum().item()

        transfer_rates = {
            f'{source}_to_{target}': self._percentage(flipped[source], total_samples)
            for source in self._MODALITIES
            for target in self._MODALITIES
            if source != target
        }

        print("Transferability Rates:")
        for key, rate in transfer_rates.items():
            print(f"  {key}: {rate:.2f}%")

        return transfer_rates

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------
    def generate_comparison_table(self, results_dict, output_path='comparison_results.csv'):
        """Build, save and print a table comparing several methods.

        Args:
            results_dict: Mapping of method name to a mapping of metric name
                to value; each method becomes one row.
            output_path: Destination of the CSV file.

        Returns:
            The resulting ``pandas.DataFrame``.
        """
        print("Generating comparison table...")

        df = pd.DataFrame(results_dict).T

        numeric_columns = ['Clean Acc', 'Adv Acc', 'Cross-Modal Acc', 'Latency (ms)']
        for col in numeric_columns:
            if col not in df.columns:
                continue
            # Transposing a frame with mixed value types yields object columns,
            # which cannot be rounded directly; convert first.  A column that
            # genuinely holds non-numeric entries is left untouched.
            try:
                df[col] = pd.to_numeric(df[col]).round(2)
            except (ValueError, TypeError):
                continue

        df.to_csv(output_path)

        print("\nComparison Table:")
        print(df.to_string())

        return df

    def plot_confusion_matrix(self, test_loader, preprocessor, num_classes=10,
                              output_path='confusion_matrix.png'):
        """Compute the confusion matrix on clean data and save it as a heatmap.

        Args:
            test_loader: Iterable of batches (see class docstring).
            preprocessor: Unused; kept for interface compatibility.
            num_classes: Currently unused; the matrix covers exactly the
                classes that occur in the labels or predictions.
            output_path: Destination of the rendered figure.

        Returns:
            The confusion matrix as returned by
            ``sklearn.metrics.confusion_matrix``.
        """
        print("Generating confusion matrix...")

        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Confusion Matrix"):
                face, voice, behavioral = self._to_device(batch)
                labels = batch['label'].to(self.device)

                predicted = self._predict(face, voice, behavioral)

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        cm = confusion_matrix(all_labels, all_predictions)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return cm

    # ------------------------------------------------------------------
    # Attack generation
    # ------------------------------------------------------------------
    def _generate_pgd_attack(self, face, voice, behavioral, labels, epsilon, num_steps):
        """Run an iterative sign-gradient attack on all three modalities.

        Each of the ``num_steps`` iterations moves every modality by
        ``epsilon / num_steps`` along the sign of the loss gradient and then
        clamps it to within ``epsilon`` of the clean input.

        Returns:
            ``(face_adv, voice_adv, behavioral_adv)`` as detached tensors.
            With ``num_steps <= 0`` these are unperturbed copies of the inputs.
        """
        clean = (face, voice, behavioral)
        adversarial = [tensor.clone().detach() for tensor in clean]

        # No iterations means no perturbation; also avoids dividing by zero.
        if num_steps <= 0:
            return tuple(adversarial)

        alpha = epsilon / num_steps

        for _ in range(num_steps):
            for tensor in adversarial:
                tensor.requires_grad_(True)

            grads = self._input_gradients(*adversarial, labels, wrt=adversarial)

            adversarial = [
                torch.clamp(
                    current.detach() + alpha * grad.sign(),
                    original - epsilon,
                    original + epsilon,
                ).detach()
                for current, grad, original in zip(adversarial, grads, clean)
            ]

        return tuple(adversarial)

    def _generate_fgsm_attack(self, face, voice, behavioral, labels, epsilon):
        """Run a single sign-gradient step of size ``epsilon`` on all modalities.

        Returns:
            ``(face_adv, voice_adv, behavioral_adv)`` as detached tensors.
        """
        adversarial = [
            tensor.clone().detach().requires_grad_(True)
            for tensor in (face, voice, behavioral)
        ]

        grads = self._input_gradients(*adversarial, labels, wrt=adversarial)

        return tuple(
            (tensor.detach() + epsilon * grad.sign()).detach()
            for tensor, grad in zip(adversarial, grads)
        )

    def _generate_cross_modal_attack(self, face, voice, behavioral, labels, epsilon):
        """Attack all modalities simultaneously with a 10-step PGD attack."""
        return self._generate_pgd_attack(
            face, voice, behavioral, labels, epsilon, num_steps=10
        )

    def _generate_single_modal_attack(self, face, voice, behavioral, labels, modality, epsilon):
        """Perturb one modality with a single sign-gradient step.

        Args:
            modality: ``'face'``, ``'voice'`` or ``'behavioral'``; the other
                two modalities are fed to the model unchanged.
            epsilon: Step size and L-infinity bound of the perturbation.

        Returns:
            The perturbed tensor for ``modality`` (detached).

        Raises:
            ValueError: If ``modality`` is not one of the three known names.
        """
        inputs = {'face': face, 'voice': voice, 'behavioral': behavioral}
        if modality not in inputs:
            raise ValueError(f"Unknown modality: {modality}")

        original = inputs[modality]
        candidate = original.clone().detach().requires_grad_(True)
        inputs[modality] = candidate

        (grad,) = self._input_gradients(
            inputs['face'], inputs['voice'], inputs['behavioral'], labels,
            wrt=[candidate],
        )

        stepped = candidate.detach() + epsilon * grad.sign()
        # Guards against floating-point overshoot of the epsilon ball.
        stepped = torch.clamp(stepped, original - epsilon, original + epsilon)

        return stepped.detach()

if __name__ == '__main__':
    # Smoke test: wire the evaluator to the project's model and dataset and
    # run each headline metric once.
    print("Testing Evaluator...")

    # Project-local components are imported lazily so that importing this
    # module elsewhere does not require them.
    from model import ProposedModel
    from dataset import ResearchDataset
    from preprocessor import Preprocessor
    from torch.utils.data import DataLoader

    # Model, placed on the GPU when one is available.
    model = ProposedModel(num_classes=100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Data pipeline.
    dataset = ResearchDataset(num_subjects=100, samples_per_subject=10)
    preprocessor = Preprocessor(augment=True)
    test_loader = DataLoader(dataset, batch_size=32, shuffle=False)

    evaluator = Evaluator(model, device)

    # Run the individual evaluations in a fixed order.
    clean_acc = evaluator.evaluate_clean(test_loader, preprocessor)
    adv_acc, success_rate = evaluator.evaluate_adversarial(test_loader, preprocessor)
    cross_modal_acc = evaluator.evaluate_cross_modal_adversarial(test_loader, preprocessor)
    latency = evaluator.measure_latency(test_loader, preprocessor, num_samples=10)

    # Final summary, one metric per line.
    summary_lines = (
        f"Clean Accuracy: {clean_acc:.2f}%",
        f"Adversarial Accuracy: {adv_acc:.2f}%",
        f"Cross-Modal Accuracy: {cross_modal_acc:.2f}%",
        f"Latency: {latency:.2f} ms",
    )
    print("\n".join(summary_lines))

    print("Evaluator test complete.")
