#!/usr/bin/env python
# metrics.py - Metric computation utilities
# --------------------------------------------------------------------
import numpy as np
import torch
from sklearn.metrics import cohen_kappa_score

def compute_onc_metrics_from_net(net, loss_layer, X, y):
    """Compute ordinal-neural-collapse (ONC) diagnostics for a network.

    The network must expose a feature extractor ``net.f`` and a head
    ``net.h`` whose ``weight`` holds a single projection vector (e.g. shape
    ``(1, d)``). Each feature vector is scored by its dot product with that
    vector. The ordinal thresholds come from ``loss_layer``: the result of
    ``loss_layer._b()`` if that method exists, otherwise the interior
    entries ``loss_layer.b[1:-1]``. The first and last entries are dropped,
    presumably -inf/+inf padding; confirm against the loss layer.

    Args:
        net: ``torch.nn.Module`` with ``f`` and ``h`` sub-modules and at
            least one parameter. The first parameter's device and
            floating-point dtype are used for the input.
        loss_layer: object providing ``_b()`` or ``b`` as described above.
        X: inputs accepted by ``net.f`` (array-like or tensor). Floating
            inputs are cast to the network's parameter dtype.
        y: class labels in ``[0, Q)``, where ``Q = len(thresholds) + 1``.

    Returns:
        dict of floats:
          'ONC1':   mean distance of samples to their class centroid
                    (averaged over classes), divided by the mean distance
                    of all features to the global feature mean.
          'ONC2-1': fraction of centered-centroid energy outside the first
                    principal direction (0 means the centroids are collinear).
          'ONC2-2': 1 - |cos| of the angle between the head weight and that
                    principal direction.
          'ONC3':   mean |threshold - midpoint of the projected centroids of
                    the two classes it separates|, divided by the mean gap
                    between consecutive thresholds.

    Classes absent from ``y`` are ignored rather than treated as a zero
    centroid. A metric is returned as NaN when it is undefined for the data:
    no labelled samples, no threshold with both neighbouring classes
    present, or fewer than two thresholds for the ONC3 normaliser.
    The network's train/eval mode is restored before returning.
    """
    eps = 1e-12
    nan = float('nan')

    param = next(net.parameters())
    device = param.device

    # Run the forward pass in eval mode, but restore the caller's mode so
    # that calling this during training has no lasting side effect.
    was_training = net.training
    net.eval()
    try:
        x = torch.as_tensor(X, device=device)
        if x.is_floating_point():
            # e.g. float64 numpy input to a float32 network. Integer inputs
            # (such as indices) are passed through unchanged.
            x = x.to(param.dtype)
        with torch.no_grad():
            feats = net.f(x)
    finally:
        net.train(was_training)
    feats_np = feats.detach().cpu().numpy()  # (N, d)

    # Labels are only used for masking, so a flat numpy array suffices.
    if isinstance(y, torch.Tensor):
        y_np = y.detach().cpu().numpy().reshape(-1)
    else:
        y_np = np.asarray(y).reshape(-1)

    # Head weight as a flat vector. Assumes a single output row, e.g. (1, d).
    w_np = net.h.weight.detach().cpu().numpy().reshape(-1)

    if hasattr(loss_layer, '_b'):
        b = loss_layer._b().detach().cpu().numpy()
    else:
        # Drop the first/last entries (presumably -inf/+inf padding).
        b = loss_layer.b.detach().cpu().numpy()[1:-1]
    b = b.reshape(-1)  # (Q-1,)
    Q = b.shape[0] + 1

    # Features and centroids for the classes that actually occur in y.
    class_feats = {}
    for q in range(Q):
        mask = y_np == q
        if mask.any():
            class_feats[q] = feats_np[mask]
    centroids = {q: fq.mean(axis=0) for q, fq in class_feats.items()}
    # Score of each centroid along the head direction.
    z = {q: float(c.dot(w_np)) for q, c in centroids.items()}

    # ONC1: within-class scatter relative to the overall feature spread.
    if class_feats:
        within = np.mean([np.linalg.norm(fq - centroids[q], axis=1).mean()
                          for q, fq in class_feats.items()])
        spread = np.linalg.norm(feats_np - feats_np.mean(axis=0),
                                axis=1).mean()
        onc1 = float(within / (spread + eps))
    else:
        onc1 = nan

    # ONC2: do the centroids lie on a line, and is that line aligned with w?
    if centroids:
        M = np.stack(list(centroids.values()))
        C = M - M.mean(axis=0)
        _, _, vt = np.linalg.svd(C, full_matrices=False)
        u = vt[0]  # leading right-singular vector (unit norm)
        resid = C - np.outer(C @ u, u)
        onc2_1 = float(np.square(resid).sum() / (np.square(C).sum() + eps))
        cos = abs(w_np.dot(u)) / (np.linalg.norm(w_np) * np.linalg.norm(u)
                                  + eps)
        onc2_2 = float(1.0 - cos)
    else:
        onc2_1 = onc2_2 = nan

    # ONC3: threshold q separates classes q and q+1. Compare it with the
    # midpoint of their projected centroids, skipping thresholds where
    # either neighbouring class is absent.
    offsets = [abs(b[q] - (z[q] + z[q + 1]) / 2)
               for q in range(Q - 1) if q in z and q + 1 in z]
    mean_offset = float(np.mean(offsets)) if offsets else nan
    # Normalise by the average threshold spacing (undefined for one threshold).
    mean_gap = float(np.diff(b).mean()) if b.size >= 2 else nan
    onc3 = mean_offset / (mean_gap + eps)

    return {
        'ONC1':   onc1,
        'ONC2-1': onc2_1,
        'ONC2-2': onc2_2,
        'ONC3':   onc3,
    }


def compute_additional_metrics(y_true, y_pred):
    """Compute accuracy, within-1 accuracy, minimum sensitivity and QWK.

    Args:
        y_true: ground-truth class labels (array-like). It is flattened, so
            an (N,) vector and an (N, 1) column are treated the same.
        y_pred: predicted class labels with the same number of items.

    Returns:
        dict of floats:
          'accuracy':        fraction of exact matches.
          'within_1_acc':    fraction of items with |y_true - y_pred| <= 1.
          'min_sensitivity': smallest per-class recall over the classes
                             present in ``y_true`` (0.0 if there are none).
          'qwk':             Cohen's kappa with quadratic weights.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` hold different numbers of
            labels.
    """
    # Flatten so that, e.g., an (N, 1) prediction column is compared
    # element-wise instead of broadcasting against (N,) labels to (N, N).
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must contain the same number of labels, "
            f"got {y_true.size} and {y_pred.size}")

    # Accuracy (exact match)
    accuracy = float(np.mean(y_true == y_pred))

    # Within-1 accuracy. Unsigned/bool labels would wrap around (uint8:
    # 1 - 2 == 255) or fail on subtraction, so widen them to int64 first.
    diff_true = y_true.astype(np.int64) if y_true.dtype.kind in 'ub' else y_true
    diff_pred = y_pred.astype(np.int64) if y_pred.dtype.kind in 'ub' else y_pred
    within_1_acc = float(np.mean(np.abs(diff_true - diff_pred) <= 1))

    # Minimum sensitivity: recall of each class that occurs in y_true.
    sensitivities = [float(np.mean(y_pred[y_true == cls] == cls))
                     for cls in np.unique(y_true)]
    min_sensitivity = min(sensitivities) if sensitivities else 0.0

    # Quadratic Weighted Kappa
    qwk = float(cohen_kappa_score(y_true, y_pred, weights='quadratic'))

    return {
        'accuracy': accuracy,
        'within_1_acc': within_1_acc,
        'min_sensitivity': min_sensitivity,
        'qwk': qwk
    }