import pandas as pd
import numpy as np
import torch
from sklearn.metrics import (
    roc_curve, precision_recall_curve,
    roc_auc_score, average_precision_score, matthews_corrcoef,
    f1_score, precision_score, recall_score, accuracy_score,
    brier_score_loss, log_loss
)


def _nan_on_error(metric_fn, *args, **kwargs):
    """Evaluate ``metric_fn(*args, **kwargs)``, returning NaN if it is undefined here.

    scikit-learn raises ``ValueError`` when a metric cannot be computed for the
    given data (e.g. ROC-AUC when ``y_true`` contains only one class, or empty /
    NaN-containing input). Such cases are reported as NaN so one undefined
    metric does not abort the whole evaluation. Unlike a bare ``except:``, this
    does not swallow ``KeyboardInterrupt``/``SystemExit`` or unrelated bugs.
    """
    try:
        return metric_fn(*args, **kwargs)
    except (ValueError, IndexError, FloatingPointError):
        return float("nan")


def calc_metrics(preds, probs, targets):
    """Compute binary classification metrics for one set of predictions.

    Args:
        preds: Hard 0/1 predictions (array-like).
        probs: Predicted probability of the positive class, aligned with ``preds``.
        targets: Ground-truth 0/1 labels, aligned with ``preds``.

    Returns:
        dict mapping metric name to value. Threshold metrics ("IoU",
        "Precision", "Recall", "F1", "MCC", "Accuracy") are computed from
        ``preds``. Score-based metrics ("AUC", "PR-AUC", "PCC", "Brier", "BCE")
        are computed from ``probs`` and are NaN when they cannot be evaluated
        on the given data.
    """
    # Accept lists/tuples as well as arrays: the IoU expression below relies on
    # elementwise numpy comparison and ``.sum()``.
    preds = np.asarray(preds)
    probs = np.asarray(probs)
    targets = np.asarray(targets)

    pred_pos = preds == 1
    true_pos = targets == 1
    # max(1, ...) avoids 0/0 when neither preds nor targets contain a positive.
    union = (pred_pos | true_pos).sum()

    return {
        "IoU": (pred_pos & true_pos).sum() / max(1, union),
        "Precision": precision_score(targets, preds, zero_division=0),
        "Recall": recall_score(targets, preds, zero_division=0),
        "F1": f1_score(targets, preds, zero_division=0),
        "MCC": matthews_corrcoef(targets, preds),
        "Accuracy": accuracy_score(targets, preds),
        "AUC": _nan_on_error(roc_auc_score, targets, probs),
        "PR-AUC": _nan_on_error(average_precision_score, targets, probs),
        # Pearson correlation between scores and labels.
        "PCC": _nan_on_error(lambda: np.corrcoef(probs, targets)[0, 1]),
        "Brier": _nan_on_error(brier_score_loss, targets, probs),
        # labels=[0, 1] keeps log_loss defined when targets hold a single class.
        "BCE": _nan_on_error(log_loss, targets, probs, labels=[0, 1]),
    }


@torch.no_grad()
def evaluate(data_loader, model, device,
             thr_H: float = 0.20, thr_L: float = 0.13, thr_Ag: float = 0.30,
             return_curves: bool = False):
    """Evaluate a three-chain (H, L, Ag) model and compute per-chain metrics.

    Each batch must provide ``"<chain>_embedding"``, ``"<chain>_labels"`` and
    ``"<chain>_mask"`` for chain in ("H", "L", "Ag"). The model is called as
    ``model(H_embed, L_embed, Ag_embed)`` and must return three logits tensors.
    Column 1 of the softmax over the last dimension is taken as the positive
    class probability. A position is predicted positive when that probability
    is strictly greater than the chain's threshold. Only positions where the
    mask is true are scored.

    Args:
        data_loader: Iterable of batch dicts as described above.
        model: Module returning ``(logits_H, logits_L, logits_Ag)``.
        device: Device (string or ``torch.device``) to run the model on.
        thr_H, thr_L, thr_Ag: Decision thresholds for the H, L and Ag chains.
        return_curves: If true, also return ROC / PR curve data.

    Returns:
        ``metrics`` — a dict with keys "H", "L", "Ag" and "Ab" ("Ab" pools the
        H and L positions), each mapping to the output of ``calc_metrics``.
        If ``return_curves`` is true, ``(metrics, curves)`` is returned
        instead. ``curves[name]`` is
        ``{"roc": (fpr, tpr, auc), "pr": (precision, recall, average_precision)}``.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    import contextlib

    model.eval()

    chains = ("H", "L", "Ag")
    thresholds = {"H": thr_H, "L": thr_L, "Ag": thr_Ag}
    buffers = {chain: {"probs": [], "preds": [], "targets": []} for chain in chains}
    device_type = torch.device(device).type

    for batch in data_loader:
        # NOTE(review): transpose(1, 2) presumably converts (batch, seq, dim) into
        # the (batch, dim, seq) layout the model expects; confirm against the model.
        embeds = [batch[f"{chain}_embedding"].to(device).transpose(1, 2) for chain in chains]

        # Mixed precision only on CUDA (the original used CUDA autocast, which has
        # no effect on other devices). torch.cuda.amp.autocast is deprecated.
        amp = torch.autocast(device_type="cuda") if device_type == "cuda" else contextlib.nullcontext()
        with amp:
            logits_H, logits_L, logits_Ag = model(*embeds)
        chain_logits = {"H": logits_H, "L": logits_L, "Ag": logits_Ag}

        for chain in chains:
            mask = batch[f"{chain}_mask"].to(device).bool()
            labels = batch[f"{chain}_labels"].to(device)
            # Softmax in fp32 (autocast upcasts softmax anyway). This also keeps
            # .numpy() working if the logits are half precision.
            probs = torch.softmax(chain_logits[chain].float(), dim=-1)[..., 1]
            preds = (probs > thresholds[chain]).long()

            # Boolean-mask indexing keeps only the valid positions.
            buffers[chain]["probs"].append(probs[mask].cpu())
            buffers[chain]["preds"].append(preds[mask].cpu())
            buffers[chain]["targets"].append(labels[mask].cpu())

    if not buffers["H"]["probs"]:
        raise RuntimeError("evaluate(): data_loader yielded no batches")

    arrays = {
        chain: {key: torch.cat(parts).numpy() for key, parts in buf.items()}
        for chain, buf in buffers.items()
    }
    # "Ab" pools heavy- and light-chain positions (H first, then L).
    arrays["Ab"] = {
        key: np.concatenate([arrays["H"][key], arrays["L"][key]])
        for key in ("probs", "preds", "targets")
    }

    metrics = {
        name: calc_metrics(arr["preds"], arr["probs"], arr["targets"])
        for name, arr in arrays.items()
    }

    if not return_curves:
        return metrics

    roc_pr_curves = {}
    for name, arr in arrays.items():
        fpr, tpr, _ = roc_curve(arr["targets"], arr["probs"])
        precision, recall, _ = precision_recall_curve(arr["targets"], arr["probs"])
        # calc_metrics already computed AUC / PR-AUC (NaN when undefined); reuse them.
        roc_pr_curves[name] = {
            "roc": (fpr, tpr, metrics[name]["AUC"]),
            "pr": (precision, recall, metrics[name]["PR-AUC"]),
        }

    return metrics, roc_pr_curves