import numpy as np
import pandas as pd
from typing import Optional
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve

def load_adjacency_matrix(path: str) -> np.ndarray:
    """Load a matrix from disk and binarize it into a 0/1 integer array.

    A path ending in ``.npy`` (case-insensitive) is read with ``np.load``.
    Any other path is parsed as CSV, taking the first column as the index and
    the first row as the header; only numeric columns are retained.

    Args:
        path: Location of the ``.npy`` or CSV file.

    Returns:
        Integer array with 1 wherever the stored value is strictly positive
        and 0 everywhere else.
    """
    if path.lower().endswith(".npy"):
        raw = np.load(path)
    else:
        table = pd.read_csv(path, index_col=0, header=0)
        numeric_only = table.select_dtypes(include=[np.number])
        raw = numeric_only.to_numpy()

    # Only the presence of a positive entry matters, not its magnitude.
    is_positive = raw > 0
    return is_positive.astype(int)

def _offdiag_flatten(y_true_mat: np.ndarray,
                     scores_mat: np.ndarray,
                     mask_diagonal: bool = True):

    if y_true_mat.shape != scores_mat.shape:
        raise ValueError(f"Shape mismatch: y_true {y_true_mat.shape} vs scores {scores_mat.shape}")

    y_true = y_true_mat.copy().astype(int)
    scores = scores_mat.copy().astype(np.float64)

    if mask_diagonal:
        np.fill_diagonal(y_true, 0)
        np.fill_diagonal(scores, np.nan)

    mask = ~np.isnan(scores)
    y = y_true[mask].ravel().astype(int)
    s = scores[mask].ravel().astype(np.float64)
    return y, s

def auroc_against_ground_truth(S: np.ndarray, A_gt: np.ndarray, mask_diagonal: bool = True) -> float:
    """Compute the ROC AUC of score matrix ``S`` against labels ``A_gt``.

    Args:
        S: Score matrix.
        A_gt: Ground-truth label matrix with the same shape as ``S``.
        mask_diagonal: Whether diagonal entries are left out of the evaluation.

    Returns:
        The ROC AUC as a float, or NaN when the retained labels contain fewer
        than two distinct values.
    """
    labels, preds = _offdiag_flatten(A_gt, S, mask_diagonal=mask_diagonal)

    # With a single class present the ROC curve is undefined.
    has_two_classes = np.unique(labels).size >= 2
    if not has_two_classes:
        return np.nan
    return float(roc_auc_score(labels, preds))

def pr_auc_against_ground_truth(S: np.ndarray, A_gt: np.ndarray, mask_diagonal: bool = True) -> float:
    """Compute the average precision of score matrix ``S`` against ``A_gt``.

    Args:
        S: Score matrix.
        A_gt: Ground-truth label matrix with the same shape as ``S``.
        mask_diagonal: Whether diagonal entries are left out of the evaluation.

    Returns:
        The average-precision score as a float, or NaN when the retained
        labels contain fewer than two distinct values.
    """
    labels, preds = _offdiag_flatten(A_gt, S, mask_diagonal=mask_diagonal)
    distinct_labels = np.unique(labels)
    if distinct_labels.size >= 2:
        return float(average_precision_score(labels, preds))
    # Only one class among the retained entries: report the metric as undefined.
    return np.nan

def best_f1_precision_recall_threshold(S: np.ndarray, A_gt: np.ndarray, mask_diagonal: bool = True):
    """Find the score threshold that maximizes F1 against the ground truth.

    The candidate thresholds are those returned by
    ``precision_recall_curve``. Its final precision/recall pair has no
    associated threshold and is therefore excluded from the search.

    Args:
        S: Score matrix.
        A_gt: Ground-truth label matrix with the same shape as ``S``.
        mask_diagonal: Whether diagonal entries are left out of the evaluation.

    Returns:
        Dict with keys ``"best_f1"``, ``"precision"``, ``"recall"`` and
        ``"threshold"`` (all floats). Every value is NaN when the retained
        labels contain fewer than two distinct values or when no threshold is
        available. On ties the first maximum in threshold order is reported.
    """
    undefined = {
        "best_f1": np.nan,
        "precision": np.nan,
        "recall": np.nan,
        "threshold": np.nan
    }

    y, s = _offdiag_flatten(A_gt, S, mask_diagonal=mask_diagonal)
    if len(np.unique(y)) < 2:
        return undefined

    precision, recall, thresholds = precision_recall_curve(y, s)
    # Drop the trailing point so p, r and t are aligned element-wise.
    p = precision[:-1]
    r = recall[:-1]
    t = thresholds

    if t.size == 0 or p.size == 0:
        return undefined

    # F1 is defined as 0 where precision + recall == 0. The masked divide
    # skips those entries instead of computing 0/0, which would emit a
    # RuntimeWarning even though the value is discarded afterwards.
    denom = p + r
    f1 = np.zeros_like(denom, dtype=np.float64)
    np.divide(2.0 * p * r, denom, out=f1, where=denom > 0)

    idx = int(np.nanargmax(f1))
    return {
        "best_f1": float(f1[idx]),
        "precision": float(p[idx]),
        "recall": float(r[idx]),
        "threshold": float(t[idx]),
    }

def shd_from_threshold(S: np.ndarray, A_gt: np.ndarray, threshold: float, mask_diagonal: bool = True) -> float:
    """Count entries where the thresholded scores disagree with ``A_gt``.

    An entry is predicted as 1 when its score is ``>= threshold`` and 0
    otherwise. Entries with a NaN score are ignored, as is the diagonal when
    ``mask_diagonal`` is true. Every position of the matrix is compared
    independently. Neither input array is modified.

    Args:
        S: Score matrix.
        A_gt: Ground-truth label matrix; values are cast to ``int``.
        threshold: Cut-off applied to the scores.
        mask_diagonal: Whether diagonal entries are left out of the count.

    Returns:
        Number of disagreeing entries, as a float.
    """
    truth = A_gt.astype(int)
    score_mat = S.astype(np.float64)

    if mask_diagonal:
        # NaN on the diagonal removes those positions via the validity mask,
        # so the diagonal labels never take part in the comparison.
        np.fill_diagonal(score_mat, np.nan)

    valid = np.logical_not(np.isnan(score_mat))
    predicted = (score_mat[valid] >= threshold).astype(int)
    mismatches = predicted != truth[valid]
    return float(mismatches.sum())

def best_shd_threshold(
    S: np.ndarray,
    A_gt: np.ndarray,
    mask_diagonal: bool = True,
    thresholds: Optional[np.ndarray] = None
):
    A = A_gt.copy().astype(int)
    scores = S.copy().astype(np.float64)

    if mask_diagonal:
        np.fill_diagonal(A, 0)
        np.fill_diagonal(scores, np.nan)

    mask = ~np.isnan(scores)
    y = A[mask].ravel().astype(int)
    s = scores[mask].ravel().astype(np.float64)

    if thresholds is None:
        uniq = np.unique(s)
        thresholds = np.concatenate(([s.max() + 1.0], uniq, [s.min() - 1.0])).astype(np.float64)

    best_shd = None
    best_thr = None
    for thr in thresholds:
        pred = (s >= thr).astype(int)
        shd = int((pred != y).sum())
        if best_shd is None or shd < best_shd:
            best_shd = shd
            best_thr = float(thr)

    return {"best_shd": float(best_shd), "best_threshold": float(best_thr)}

def normalized_shd(shd: float, n: int) -> float:
    """Scale a disagreement count by ``n * n - n``.

    Args:
        shd: Count to normalize.
        n: Size used to form the denominator ``n * n - n``.

    Returns:
        ``shd / (n * n - n)`` as a float, or NaN when that denominator is not
        positive (for example ``n`` equal to 0 or 1).
    """
    off_diagonal_slots = n * n - n
    if off_diagonal_slots <= 0:
        return np.nan
    return float(shd) / float(off_diagonal_slots)

def evaluation_causal(gt_path, S_mat, col_order):
    """Score ``S_mat`` against the ground-truth matrix stored at ``gt_path``.

    The ground truth is loaded with ``load_adjacency_matrix`` and must be a
    square matrix whose side equals ``len(col_order)``. The path and the
    loaded matrix are printed to stdout. All metrics are computed with the
    diagonal masked out.

    Args:
        gt_path: Path of the ground-truth file.
        S_mat: Score matrix to evaluate.
        col_order: Sequence of variables; only its length is used here.

    Returns:
        Dict with keys ``"AUROC"``, ``"AUC_PR"``, ``"BestF1"``,
        ``"Precision"``, ``"Recall"``, ``"Best_F1_Threshold"``, ``"BestSHD"``,
        ``"Best_SHD_Threshold"`` and ``"nSHD"``.

    Raises:
        ValueError: If the loaded ground truth is not of shape
            ``(len(col_order), len(col_order))``.
    """
    ground_truth = load_adjacency_matrix(gt_path)
    print("gt_path:", gt_path)

    num_vars = len(col_order)
    expected_shape = (num_vars, num_vars)
    if ground_truth.shape != expected_shape:
        raise ValueError(
            f"Ground-truth adjacency shape {ground_truth.shape} does not match number of variables {num_vars}."
        )

    print("ground_truth:\n", ground_truth)

    # Ranking metrics first, then the two threshold searches.
    roc_auc = auroc_against_ground_truth(S_mat, ground_truth, mask_diagonal=True)
    pr_auc = pr_auc_against_ground_truth(S_mat, ground_truth, mask_diagonal=True)
    f1_result = best_f1_precision_recall_threshold(S_mat, ground_truth, mask_diagonal=True)
    shd_result = best_shd_threshold(S_mat, ground_truth, mask_diagonal=True)

    lowest_shd = shd_result["best_shd"]
    return {
        "AUROC": roc_auc,
        "AUC_PR": pr_auc,
        "BestF1": f1_result["best_f1"],
        "Precision": f1_result["precision"],
        "Recall": f1_result["recall"],
        "Best_F1_Threshold": f1_result["threshold"],
        "BestSHD": lowest_shd,
        "Best_SHD_Threshold": shd_result["best_threshold"],
        "nSHD": normalized_shd(lowest_shd, num_vars),
    }
