import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, Any, Tuple, Optional
from torch_geometric.data import Data


def accuracy(pred, truth):
    """Return the fraction of positions where ``pred`` equals ``truth``.

    The element-wise equality mask is cast to float32 and averaged; the
    result is returned as a plain Python float. An empty input gives ``nan``
    (mean of an empty tensor).
    """
    matches = (truth == pred).float()
    return matches.mean().item()


def sp(pred, s):
    """Statistical parity difference between the groups ``s == 0`` and ``s == 1``.

    For each group, the rate of samples predicted as class 1 is computed; the
    absolute difference of the two rates is returned as a Python float. A
    group with no members has an empty-tensor mean, i.e. ``nan``.
    """
    group_rates = []
    for group in (0, 1):
        members = pred[s == group]
        group_rates.append((members == 1).float().mean())
    gap = group_rates[0] - group_rates[1]
    return gap.abs().item()


def eo(pred, truth, s):
    """Equal-opportunity gap ``|TPR(s == 0) - TPR(s == 1)|`` as a Python float.

    A group's true-positive rate is the fraction of its ``truth == 1`` samples
    that are predicted as 1. A group with no positive samples contributes a
    rate of 0.0.

    Args:
        pred: Hard class predictions.
        truth: Ground-truth labels aligned with ``pred``.
        s: Sensitive attribute aligned with ``pred`` (groups 0 and 1).

    Returns:
        The absolute TPR difference as a plain float. Each rate is converted
        with ``.item()``, so no 0-d tensor escapes, whichever groups are
        empty.
    """
    tprs = []
    for group in (0, 1):
        positives = (truth == 1) & (s == group)
        if positives.sum() == 0:
            tprs.append(0.0)
        else:
            tprs.append((pred[positives] == 1).float().mean().item())
    return abs(tprs[0] - tprs[1])


def auc_score(pred_probs, truth):
    """Return the binary ROC AUC of ``pred_probs`` against ``truth``.

    Samples with ``truth == 1`` are positives and samples with ``truth == 0``
    are negatives; any other label value is ignored. The AUC is the fraction
    of (positive, negative) pairs in which the positive gets the higher
    score. Ties count as half a pair (the Mann-Whitney U convention), so a
    constant predictor scores 0.5.

    Args:
        pred_probs: Scores as a 1-D tensor/array, or a 2-D one with a single
            column (used as the score) or two columns (the second column is
            the positive-class score).
        truth: Labels aligned with ``pred_probs``.

    Returns:
        The AUC as a Python float. Returns 0.5 when either class is absent.
        Returns 0.0 when the inputs cannot be read as aligned 1-D scores and
        labels; this keeps the previous catch-all fallback value without
        swallowing arbitrary exceptions.
    """
    # detach() so grad-tracking tensors can be converted; double() avoids
    # dtypes NumPy cannot represent (e.g. bfloat16).
    if torch.is_tensor(pred_probs):
        scores = pred_probs.detach().cpu().double().numpy()
    else:
        scores = np.asarray(pred_probs, dtype=np.float64)
    if torch.is_tensor(truth):
        labels = truth.detach().cpu().numpy()
    else:
        labels = np.asarray(truth)

    # Column -1 is the only column for (n, 1) and the positive class for (n, 2).
    if scores.ndim == 2 and scores.shape[1] in (1, 2):
        scores = scores[:, -1]
    if scores.ndim != 1 or labels.shape != scores.shape:
        return 0.0

    positives = scores[labels == 1]
    negatives = np.sort(scores[labels == 0])
    if positives.size == 0 or negatives.size == 0:
        return 0.5

    # Binary search on the sorted negatives finds, for every positive, how
    # many negatives lie strictly below it (wins) and how many equal it
    # (ties). This is O((P + N) log N) instead of a pairwise O(P * N) loop.
    below = np.searchsorted(negatives, positives, side='left')
    ties = np.searchsorted(negatives, positives, side='right') - below
    wins = below.sum() + 0.5 * ties.sum()
    return float(wins / (positives.size * negatives.size))


def f1_score(pred, truth):
    """Macro-averaged F1 over the classes that occur in ``truth``.

    For every class present in ``truth``, F1 is ``2*TP / (2*TP + FP + FN)``.
    This equals ``2PR / (P + R)``, with F1 = 0 whenever TP = 0, and it never
    divides by zero because each class in ``truth`` has TP + FN >= 1. The
    per-class scores are averaged with equal weight.

    Args:
        pred: Hard class predictions (tensor or array-like).
        truth: Ground-truth labels with the same shape as ``pred``.

    Returns:
        The macro F1 as a Python float. Returns 0.0 for empty inputs or
        mismatched shapes; this keeps the previous fallback value, which came
        from a bare ``except``.
    """
    # detach() lets grad-tracking tensors be converted to NumPy.
    pred_np = pred.detach().cpu().numpy() if torch.is_tensor(pred) else np.asarray(pred)
    truth_np = truth.detach().cpu().numpy() if torch.is_tensor(truth) else np.asarray(truth)
    if pred_np.shape != truth_np.shape or truth_np.size == 0:
        return 0.0

    per_class = []
    for cls in np.unique(truth_np):
        predicted = pred_np == cls
        actual = truth_np == cls
        tp = np.count_nonzero(predicted & actual)
        fp = np.count_nonzero(predicted & ~actual)
        fn = np.count_nonzero(~predicted & actual)
        per_class.append(2 * tp / (2 * tp + fp + fn))

    return sum(per_class) / len(per_class)


def demographic_parity(pred, s):
    """Demographic parity gap between the groups ``s == 0`` and ``s == 1``.

    Returns ``|P(pred == 1 | s == 0) - P(pred == 1 | s == 1)|`` as a Python
    float. An empty group has an empty-tensor mean, i.e. ``nan``.
    """
    predicted_positive = (pred == 1).float()
    rate_group0 = predicted_positive[s == 0].mean()
    rate_group1 = predicted_positive[s == 1].mean()
    return torch.abs(rate_group0 - rate_group1).item()


def equality_of_opportunity(pred, truth, s):
    """Equality-of-opportunity gap ``|TPR(s == 0) - TPR(s == 1)|``.

    A group's true-positive rate is the fraction of its ``truth == 1`` samples
    that are predicted as 1. A group with no positive samples contributes a
    rate of 0.0.

    Args:
        pred: Hard class predictions.
        truth: Ground-truth labels aligned with ``pred``.
        s: Sensitive attribute aligned with ``pred`` (groups 0 and 1).

    Returns:
        The absolute TPR difference as a Python float. When neither group has
        positives the result is 0.0; the rates are plain floats, so
        ``torch.abs`` is never handed a non-tensor.
    """
    def _tpr(group):
        positives = (truth == 1) & (s == group)
        if positives.sum() == 0:
            return 0.0
        return (pred[positives] == 1).float().mean().item()

    return abs(_tpr(0) - _tpr(1))


def equalized_odds(pred, truth, s):
    """Equalized-odds gap: the larger of the TPR gap and the FPR gap.

    The TPR gap comes from ``equality_of_opportunity``. The FPR gap is
    ``|P(pred == 1 | truth == 0, s == 0) - P(pred == 1 | truth == 0, s == 1)|``.
    If either group has no negative samples, only the TPR gap is returned.
    The result is a Python float.
    """
    tpr_gap = equality_of_opportunity(pred, truth, s)

    negatives = truth == 0
    group0_negatives = negatives & (s == 0)
    group1_negatives = negatives & (s == 1)
    if not (group0_negatives.any() and group1_negatives.any()):
        return tpr_gap

    predicted_positive = (pred == 1).float()
    fpr_gap = torch.abs(
        predicted_positive[group0_negatives].mean()
        - predicted_positive[group1_negatives].mean()
    ).item()
    return max(tpr_gap, fpr_gap)


def compute_fairness_metrics(preds, labels, sens_attr, mask=None):
    """Compute utility and group-fairness metrics for one set of predictions.

    Args:
        preds: Either a 2-D tensor of per-class logits (class predictions are
            its argmax over dim 1) or a 1-D tensor of hard class predictions.
        labels: Ground-truth labels aligned with ``preds``.
        sens_attr: Sensitive attribute per sample; the fairness metrics
            compare the groups ``sens_attr == 0`` and ``sens_attr == 1``.
        mask: Optional mask/index applied to ``preds``, ``labels`` and
            ``sens_attr`` before anything else.

    Returns:
        Dict mapping metric name to value. ``statistical_parity`` and
        ``equal_opportunity`` are produced by ``sp``/``eo``, which compute the
        same quantities as ``demographic_parity``/``equality_of_opportunity``.
    """
    if mask is not None:
        preds, labels, sens_attr = preds[mask], labels[mask], sens_attr[mask]

    if preds.dim() > 1:
        pred_labels = preds.argmax(dim=1)
        probs = F.softmax(preds, dim=1)
        # auc_score ranks label 1 against label 0, so the score must be the
        # probability of class 1 for any number of classes. The max-class
        # confidence does not rank class 1 vs class 0. A single-column
        # output has no class-1 column; its softmax is constant 1.0.
        pred_probs = probs[:, 1] if probs.shape[1] > 1 else probs[:, 0]
    else:
        # Hard predictions double as (heavily tied) scores for the AUC.
        pred_labels = preds
        pred_probs = preds.float()

    metrics = {
        'accuracy': accuracy(pred_labels, labels),
        'f1': f1_score(pred_labels, labels),
        'auc': auc_score(pred_probs, labels),
        'demographic_parity': demographic_parity(pred_labels, sens_attr),
        'equality_of_opportunity': equality_of_opportunity(pred_labels, labels, sens_attr),
        'equalized_odds': equalized_odds(pred_labels, labels, sens_attr),
        'statistical_parity': sp(pred_labels, sens_attr),
        'equal_opportunity': eo(pred_labels, labels, sens_attr)
    }

    return metrics


def evaluate_robustness(model, clean_data, perturbed_data, metrics=('accuracy', 'fairness')):
    """Compare a model's test-set behaviour on a clean and a perturbed graph.

    The model is switched to eval mode (and left in it) and called as
    ``model(edge_index, x)`` on each graph under ``torch.no_grad()``. Class
    predictions are the argmax over dim 1 of its output. Both graphs are
    scored on ``clean_data.test_mask``; this presumably assumes the
    perturbation keeps the node set and ordering (TODO confirm for each
    attack used).

    Args:
        model: Callable taking ``(edge_index, x)`` and returning per-class
            scores.
        clean_data: Graph with ``edge_index``, ``x``, ``y``, ``test_mask``
            and optionally ``s`` (sensitive attribute).
        perturbed_data: Graph with ``edge_index``, ``x``, ``y`` and
            optionally ``s``.
        metrics: Collection of metric groups to compute. Recognised entries
            are ``'accuracy'`` and ``'fairness'``. The default is a tuple so
            it cannot be mutated across calls.

    Returns:
        Dict with, as requested:
        - ``clean_accuracy``, ``perturbed_accuracy``, ``accuracy_drop`` and
          ``relative_accuracy_drop`` for the accuracy group;
        - ``clean_fairness`` and ``perturbed_fairness`` (dicts from
          ``compute_fairness_metrics``) plus ``fairness_changes``
          (``<metric>_change`` -> absolute difference) for the fairness
          group. These are computed only when both graphs carry ``s``.
    """
    model.eval()
    results = {}

    with torch.no_grad():
        clean_preds = model(clean_data.edge_index, clean_data.x).argmax(dim=1)
        perturbed_preds = model(perturbed_data.edge_index, perturbed_data.x).argmax(dim=1)

        test_mask = clean_data.test_mask

        if 'accuracy' in metrics:
            clean_acc = accuracy(clean_preds[test_mask], clean_data.y[test_mask])
            perturbed_acc = accuracy(perturbed_preds[test_mask], perturbed_data.y[test_mask])

            results['clean_accuracy'] = clean_acc
            results['perturbed_accuracy'] = perturbed_acc
            results['accuracy_drop'] = clean_acc - perturbed_acc
            # The 1e-8 term avoids division by zero when clean accuracy is 0.
            results['relative_accuracy_drop'] = (clean_acc - perturbed_acc) / (clean_acc + 1e-8)

        # Both graphs are read for ``s`` below, so both must carry it.
        has_sensitive = hasattr(clean_data, 's') and hasattr(perturbed_data, 's')
        if 'fairness' in metrics and has_sensitive:
            clean_fairness = compute_fairness_metrics(
                clean_preds[test_mask], clean_data.y[test_mask], clean_data.s[test_mask]
            )
            perturbed_fairness = compute_fairness_metrics(
                perturbed_preds[test_mask], perturbed_data.y[test_mask], perturbed_data.s[test_mask]
            )

            results['clean_fairness'] = clean_fairness
            results['perturbed_fairness'] = perturbed_fairness
            results['fairness_changes'] = {
                f'{key}_change': abs(value - perturbed_fairness[key])
                for key, value in clean_fairness.items()
                if key in perturbed_fairness
            }

    return results


def compute_comprehensive_metrics(model, data, split='test'):
    """Evaluate ``model`` on one node split of ``data``.

    The model is switched to eval mode (and left in it) and called once as
    ``model(edge_index, x)`` under ``torch.no_grad()``.

    Args:
        model: Callable taking ``(edge_index, x)`` and returning per-class
            logits.
        data: Graph with ``edge_index``, ``x``, ``y``, the split masks and
            optionally ``s`` (sensitive attribute).
        split: ``'train'`` or ``'val'`` selects the matching mask. Any other
            value, including the default ``'test'``, uses ``test_mask``.

    Returns:
        Dict with ``accuracy``, ``f1`` and ``auc``. When ``data`` has ``s``,
        it also contains every ``compute_fairness_metrics`` entry prefixed
        with ``fair_``.
    """
    model.eval()

    if split == 'train':
        mask = data.train_mask
    elif split == 'val':
        mask = data.val_mask
    else:
        mask = data.test_mask

    with torch.no_grad():
        logits = model(data.edge_index, data.x)
        probs = F.softmax(logits, dim=1)
        pred_labels = probs.argmax(dim=1)

        metrics = {
            'accuracy': accuracy(pred_labels[mask], data.y[mask]),
            'f1': f1_score(pred_labels[mask], data.y[mask]),
            'auc': auc_score(probs[mask], data.y[mask])
        }

        if hasattr(data, 's'):
            # Pass logits rather than hard labels so that 'fair_auc' is computed
            # from class probabilities, like 'auc' above. Hard labels would
            # collapse the AUC ranking to two heavily tied levels. The argmax
            # labels, and therefore every other fair_ metric, are unchanged.
            fairness_metrics = compute_fairness_metrics(
                logits[mask], data.y[mask], data.s[mask]
            )
            metrics.update({f'fair_{k}': v for k, v in fairness_metrics.items()})

    return metrics