import numpy as np
import sklearn.metrics as skm
from scipy import interpolate
import torch 
from torchmetrics.classification import BinaryAveragePrecision
import sys, os
from pathlib import Path

# Extend the import path before the project-local import below:
#   * append the parent of this file's directory (``<this dir>/..``);
#   * put BASE_DIR -- the directory two levels above this file's directory,
#     with symlinks resolved -- at the front of sys.path so it takes priority.
# NOTE(review): presumably this lets ``torch_uncertainty`` resolve from the
# local checkout rather than an installed package -- confirm the layout.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
BASE_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(BASE_DIR))

from torch_uncertainty.metrics import FPR95
from torchmetrics.classification import BinaryAUROC

def _stack_scores(ood: np.ndarray, id_: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
    """Concatenate OOD and ID scores and build the matching binary labels.

    OOD samples come first and are labelled 1 (the positive class); ID samples
    follow and are labelled 0. Inputs of any shape are flattened, and scores
    are kept in float64 so nearby values are not merged by float32 round-off.

    Returns:
        ``(scores, labels)``: a 1-D float64 NumPy array of raw scores and a
        1-D ``torch.long`` tensor of labels of the same length.
    """
    ood_flat = np.ravel(np.asarray(ood, dtype=np.float64))
    id_flat = np.ravel(np.asarray(id_, dtype=np.float64))
    scores = np.concatenate((ood_flat, id_flat))
    labels = torch.cat((
        torch.ones(ood_flat.size, dtype=torch.long),
        torch.zeros(id_flat.size, dtype=torch.long),
    ))
    return scores, labels


def _to_unit_interval(scores: np.ndarray) -> torch.Tensor:
    """Map scores to their dense ranks scaled into (0, 1], as a float64 tensor.

    If any prediction lies outside [0, 1], torchmetrics' binary metrics treat
    all predictions as logits and apply a sigmoid. For unbounded OOD scores
    (energies, max logits, distances, ...) that sigmoid loses resolution as
    the magnitude grows and eventually saturates to exactly 0.0 or 1.0 (in
    float32 already above roughly +17). Distinct scores then become ties and
    AUROC/AUPR are silently distorted. The dense-rank map is strictly
    increasing and preserves genuine ties, so rank/threshold-based metrics
    are unchanged while the sigmoid branch is never taken.
    """
    # ``inverse`` is each sample's 0-based index into the sorted distinct
    # values, i.e. its dense rank; equal scores share a rank.
    distinct, inverse = np.unique(scores, return_inverse=True)
    # ``max(..., 1)`` only guards the empty-input case against division by 0.
    scaled = (inverse.reshape(-1) + 1) / max(distinct.size, 1)
    return torch.as_tensor(scaled, dtype=torch.float64)


def get_auroc(ood: np.ndarray, id_: np.ndarray) -> float:
    """Return the AUROC for separating OOD (positive) from ID (negative) samples.

    Higher scores are ranked as more likely to be OOD.

    Args:
        ood: Scores of the out-of-distribution samples (flattened if not 1-D).
        id_: Scores of the in-distribution samples (flattened if not 1-D).

    Returns:
        Area under the ROC curve as a Python float.
    """
    scores, labels = _stack_scores(ood, id_)
    metric = BinaryAUROC()
    return metric(_to_unit_interval(scores), labels).item()


def get_fpr95(ood: np.ndarray, id_: np.ndarray) -> float:
    """Return torch_uncertainty's FPR95 with OOD as the positive class.

    Args:
        ood: Scores of the out-of-distribution samples (flattened if not 1-D).
        id_: Scores of the in-distribution samples (flattened if not 1-D).

    Returns:
        The FPR95 value (false-positive rate at 95% true-positive rate) as a
        Python float.
    """
    scores, labels = _stack_scores(ood, id_)
    metric = FPR95(pos_label=1)  # OOD is positive
    # Raw float64 scores: avoids merging nearby scores via float32 rounding.
    return metric(torch.as_tensor(scores), labels).item()


def get_aupr(ood: np.ndarray, id_: np.ndarray) -> float:
    """Return the average precision (AUPR) with OOD as the positive class.

    Higher scores are ranked as more likely to be OOD.

    Args:
        ood: Scores of the out-of-distribution samples (flattened if not 1-D).
        id_: Scores of the in-distribution samples (flattened if not 1-D).

    Returns:
        Average precision as a Python float.
    """
    scores, labels = _stack_scores(ood, id_)
    metric = BinaryAveragePrecision()
    return metric(_to_unit_interval(scores), labels).item()