import os
import time
from util.args_loader import get_args
from util import metrics
import torch
import faiss
import numpy as np
import torchvision.models as models
from typing import Optional
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.covariance import EmpiricalCovariance
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA


def calculate_acc_val(score_log_val, label_log_val):
    """Print and return top-1 accuracy of per-class scores against labels.

    Args:
        score_log_val: 2-D array/tensor of scores; the arg-max along axis 1 is the
            predicted class.
        label_log_val: 1-D array/tensor of ground-truth class indices.

    Returns:
        Fraction of rows whose predicted class equals the label, as a Python float.
    """
    hits = (score_log_val.argmax(1) == label_log_val).sum().item()
    total = len(label_log_val)
    accuracy = hits / total
    print(f"Accuracy on validation set: {accuracy}")
    return accuracy


def plot_helper(
    score_in: np.ndarray,
    scores_out_test: np.ndarray,
    in_dataset: str,
    ood_dataset_name: str,
    function_name: str = "ORA",
) -> None:
    """Save overlaid density histograms of in-distribution vs OOD scores.

    The dataset name "dtd" is displayed as "Texture" in the title and file name.
    The (possibly renamed) OOD dataset name is printed.

    Args:
        score_in: Scores of the in-distribution samples.
        scores_out_test: Scores of the OOD samples.
        in_dataset: In-distribution dataset name (title and file name).
        ood_dataset_name: OOD dataset name (title and file name).
        function_name: Name of the scoring function. It is used for the x-axis label
            ``"{function_name}(x)"`` and the file suffix. The default "ORA"
            reproduces the previous hard-coded output.

    The figure is written to ``./{in_dataset}_{ood_name}_{function_name}.png`` at 600 dpi.
    """
    display_name = "Texture" if ood_dataset_name == "dtd" else ood_dataset_name
    print(display_name)

    # Use a dedicated figure so we never draw onto (and then close) an unrelated
    # "current" pyplot figure.
    fig, ax = plt.subplots()
    ax.hist(score_in, bins=70, alpha=0.5, label="in", density=True)
    ax.hist(scores_out_test, bins=70, alpha=0.5, label="out", density=True)
    ax.legend(loc="upper right", fontsize=9)
    ax.set_xlabel(f"{function_name}(x)", fontsize=16)
    ax.set_ylabel("Density", fontsize=16)
    ax.set_title(f"{in_dataset} vs {display_name}", fontsize=16)
    fig.savefig(f"./{in_dataset}_{display_name}_{function_name}.png", dpi=600)
    plt.close(fig)


def react_filter(feats, thold):
    """Clip activations from above at ``thold``.

    This is used by the ReAct-style scorers in this module. Entries strictly greater
    than ``thold`` are replaced by ``thold``. Every other entry, including NaN (for
    which ``>`` is False), is returned unchanged.
    """
    above = feats > thold
    return torch.where(above, thold, feats)


def react_thold(id_feats, percentile=90):
    """Return the ``percentile``-th percentile of all entries of ``id_feats``.

    The tensor is flattened, so a single scalar threshold is computed over every
    element regardless of shape. The result is a NumPy scalar.
    """
    flat = id_feats.cpu().numpy().ravel()
    return np.percentile(flat, percentile, axis=0)


def React_energy(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Energy score ``logsumexp(logits)`` per sample, with optional activation clipping.

    Args:
        model: Network whose ``fc`` layer recomputes logits when ``thold`` is set.
        test_loader: Iterable of ``(features, logits)`` batch pairs.
        num_classes: Unused; kept for a uniform scorer signature.
        class_means: Unused; kept for a uniform scorer signature.
        device: Device that batches are moved to.
        thold: If given, features are clipped at this value via ``react_filter`` and
            the logits are recomputed as ``model.fc(clipped.float())``.

    Returns:
        float32 array with one energy per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    energies = []
    with torch.inference_mode():
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            if thold is not None:
                clipped = react_filter(feats, thold).float()
                logits = model.fc(clipped)
            energies.append(torch.logsumexp(logits, dim=1))
    stacked = torch.cat(energies).detach().cpu().numpy()
    return np.asarray(stacked, dtype=np.float32)


def fdbd_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
) -> np.ndarray:
    """fDBD-style OOD score computed through its law-of-sines (angle) form.

    For each sample x with predicted class p, and for every class c, x is projected
    onto the linear decision boundary of ``model.fc`` between p and c, giving x_db.
    Let mu be the mean of ``class_means``. The per-class value is

        sin(angle at mu between x and x_db) / sin(angle at x_db between x and mu)

    By the law of sines this equals ||x - x_db|| / ||x - mu||, the original fDBD ratio.
    NaNs are replaced by 0. For c == p the boundary is undefined (0/0), so that class
    contributes 0. The score is the mean over all ``num_classes`` values.

    Args:
        model: Network with a linear ``fc`` layer whose weights define the boundaries.
        test_loader: Iterable of ``(features, logits)`` batch pairs.
        num_classes: Number of classes (columns of the logits).
        class_means: Per-class mean features; only their mean (over dim 0) is used.
        device: Device used for computation.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        weight = model.fc.weight
        # Loop-invariant reference point; the original recomputed it 4x per class.
        center = torch.mean(class_means, dim=0).to(device)
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = weight[preds]  # invariant across the class loop
            per_class = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of x onto the hyperplane where logit_p == logit_c.
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )

                # Angle at the center between x and x_db.
                norm_centered = F.normalize(feats - center, p=2, dim=1)
                norm_centered_db = F.normalize(feats_db - center, p=2, dim=1)
                cos_origin = torch.sum(norm_centered * norm_centered_db, dim=1)
                angles_origin = torch.arccos(cos_origin) / torch.pi

                # Angle at x_db between x and the center.
                cos_db = F.cosine_similarity(feats - feats_db, center - feats_db, dim=1)
                angles_db = torch.arccos(cos_db) / torch.pi

                per_class[:, class_id] = torch.sin(angles_origin * torch.pi) / torch.sin(
                    angles_db * torch.pi
                )
            per_class[torch.isnan(per_class)] = 0
            all_scores.append(torch.mean(per_class, dim=1))
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def ORA_score3(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """ORA score summed over three reference points (max / min / median of class means).

    For each sample x with predicted class p and each class c, x is projected onto the
    linear decision boundary of ``model.fc`` between p and c, giving x_db. For each
    reference point r (the element-wise max, min and median of ``class_means`` over
    dim 0), the angle between x - r and x_db - r is measured in units of pi. The three
    angles are summed, NaNs (e.g. c == p) become 0, and the per-sample score is the
    maximum over classes.

    Args:
        model: Network with a linear ``fc`` layer.
        test_loader: Iterable of ``(features, logits)`` batch pairs.
        num_classes: Number of classes (columns of the logits).
        class_means: Per-class mean features, reduced over dim 0 to the reference points.
        device: Device used for computation.
        thold: Currently unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        weight = model.fc.weight
        # Reference points are fixed for the whole run; compute them once, in the same
        # order as before (max, min, median) so per-element accumulation order is kept.
        centers = (
            torch.max(class_means, dim=0).values.to(device),
            torch.min(class_means, dim=0).values.to(device),
            torch.median(class_means, dim=0).values.to(device),
        )
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = weight[preds]
            norm_centered = [F.normalize(feats - c, p=2, dim=1) for c in centers]
            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # The boundary projection does not depend on the reference point, so
                # compute it once per class instead of once per (reference, class).
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )
                for center, nc in zip(centers, norm_centered):
                    ndb = F.normalize(feats_db - center, p=2, dim=1)
                    # Clamp rounding overshoot; NaN (undefined boundary) passes through.
                    cos_sim = torch.sum(nc * ndb, dim=1).clamp(-1.0, 1.0)
                    angles[:, class_id] += torch.arccos(cos_sim) / torch.pi
            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def ORA_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """ORA score: max angle, seen from the class-mean centroid, to any decision boundary.

    For each sample x with predicted class p and each class c, x is projected onto the
    linear decision boundary of ``model.fc`` between p and c, giving x_db. The angle
    between x - mu and x_db - mu, where mu is the mean of ``class_means``, is measured
    in units of pi. NaNs (e.g. c == p, where the boundary is undefined) become 0, and
    the per-sample score is the maximum over classes.

    Args:
        model: Network with a linear ``fc`` layer.
        test_loader: Iterable of ``(features, logits)`` batch pairs.
        num_classes: Number of classes (columns of the logits).
        class_means: Per-class mean features; only their mean (over dim 0) is used.
        device: Device used for computation.
        thold: Currently unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        weight = model.fc.weight
        # Loop-invariant: previously recomputed twice per class per batch (O(C^2 D)).
        center = torch.mean(class_means, dim=0).to(device)
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = weight[preds]
            norm_centered = F.normalize(feats - center, p=2, dim=1)  # class-independent
            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )
                norm_centered_db = F.normalize(feats_db - center, p=2, dim=1)
                # Clamp so rounding past +/-1 yields 0 or pi instead of NaN (which would
                # be zeroed and could hide the maximal angle). NaN still propagates.
                cos_sim = torch.sum(norm_centered * norm_centered_db, dim=1).clamp(
                    -1.0, 1.0
                )
                angles[:, class_id] = torch.arccos(cos_sim) / torch.pi
            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def ORA_score2(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
    anchor_class_ids: Optional[list] = None,
) -> np.ndarray:
    """ORA score summed over several anchor class means used as reference points.

    For each sample x with predicted class p and each class c, x is projected onto the
    linear decision boundary of ``model.fc`` between p and c, giving x_db. For every
    anchor a, the angle between x - class_means[a] and x_db - class_means[a] is measured
    in units of pi. The angles are summed over anchors, NaNs (e.g. c == p) become 0, and
    the per-sample score is the maximum over classes.

    Args:
        model: Network with a linear ``fc`` layer.
        test_loader: Iterable of ``(features, logits)`` batch pairs.
        num_classes: Number of classes (columns of the logits).
        class_means: Per-class mean features, indexed by the anchor ids.
        device: Device used for computation.
        thold: If given, features are clipped via ``react_filter`` and the logits are
            recomputed with ``model.fc``.
        anchor_class_ids: Sequence (list/tuple) of row indices into ``class_means``.
            Defaults to (1, 250, 500, 750, 999), the previously hard-coded anchors. That
            default needs at least 1000 class means.

    Returns:
        float32 array with one score per sample, in loader order.

    Raises:
        IndexError: if an anchor id is out of range for ``class_means``.
    """
    if anchor_class_ids is None:
        anchor_class_ids = (1, 250, 500, 750, 999)
    n_means = class_means.size(0)
    for anchor_id in anchor_class_ids:
        if not -n_means <= anchor_id < n_means:
            raise IndexError(
                f"anchor class id {anchor_id} is out of range for {n_means} class means"
            )

    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        weight = model.fc.weight
        anchors = [class_means[a].to(device) for a in anchor_class_ids]
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            if thold is not None:
                feats = react_filter(feats, thold).float()
                logits = model.fc(feats)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = weight[preds]
            norm_centered = [F.normalize(feats - a, p=2, dim=1) for a in anchors]
            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # The projection is anchor-independent: compute once per class.
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )
                for anchor, nc in zip(anchors, norm_centered):
                    ndb = F.normalize(feats_db - anchor, p=2, dim=1)
                    # Clamp rounding overshoot; NaN (undefined boundary) passes through.
                    cos_sim = torch.sum(nc * ndb, dim=1).clamp(-1.0, 1.0)
                    angles[:, class_id] += torch.arccos(cos_sim) / torch.pi
            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def energy_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Energy score ``logsumexp(logits)`` per sample.

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the logits
            are read.
        num_classes: Unused; kept for a uniform scorer signature.
        class_means: Unused; kept for a uniform scorer signature.
        device: Device the logits are moved to.
        thold: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one energy per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        # The feature half of each batch is never used, so don't copy it to the device.
        for _, logits in test_loader:
            all_scores.append(torch.logsumexp(logits.to(device), dim=1))
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def msp_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Maximum softmax probability (MSP) per sample.

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the logits
            are read.
        num_classes: Unused; kept for a uniform scorer signature.
        class_means: Unused; kept for a uniform scorer signature.
        device: Device the logits are moved to.
        thold: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one max-probability per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        # The feature half of each batch is never used, so don't copy it to the device.
        for _, logits in test_loader:
            probs = F.softmax(logits.to(device), dim=1)
            all_scores.append(torch.max(probs, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def maxlogit_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Largest logit per sample (MaxLogit score).

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the logits
            are read.
        num_classes: Unused; kept for a uniform scorer signature.
        class_means: Unused; kept for a uniform scorer signature.
        device: Device the logits are moved to.
        thold: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one value per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    per_batch = []
    with torch.inference_mode():
        for _, logits in test_loader:
            per_batch.append(logits.to(device).amax(dim=1))
    return torch.cat(per_batch).cpu().numpy().astype(np.float32)


def knn_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    train_feats: torch.Tensor,
    k: int = 1000,
) -> np.ndarray:
    """Deep-kNN score: negative distance to the k-th nearest training feature.

    Test and training features are L2-normalised. The score is ``-d_k``, where ``d_k``
    is the Euclidean distance from the test feature to its k-th nearest training
    feature, so larger values mean closer to the training set.

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the features
            are read.
        num_classes: Unused; kept for a uniform scorer signature.
        class_means: Unused; kept for a uniform scorer signature.
        device: Device used for computation; ``train_feats`` is moved there too.
        train_feats: Bank of training features, one per row.
        k: Which nearest neighbour to use (1-based); must not exceed ``len(train_feats)``.

    Returns:
        float32 array with one score per test sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    # Normalise the bank once, on the same device as the test batches; otherwise
    # cdist fails when features live on GPU while the bank is on CPU.
    train_feats = F.normalize(train_feats.to(device), p=2, dim=1)
    all_scores = []
    with torch.inference_mode():
        for feats, _ in test_loader:
            feats = F.normalize(feats.to(device), p=2, dim=1)
            # Direct (non-matmul) distance computation, as in the original, for precision.
            dists = torch.cdist(
                feats,
                train_feats,
                p=2,
                compute_mode="donot_use_mm_for_euclid_dist",
            )
            kth = torch.topk(dists, k, largest=False, sorted=True).values[:, -1]
            all_scores.append(-kth)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def rcos_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Max of the softmax over cosine similarities to each class mean.

    For each sample, the cosine similarity to ``class_means[c]`` is computed for every
    class c. A softmax is applied across classes, and the largest probability is the
    score.

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the features
            are read.
        num_classes: Number of class means to compare against.
        class_means: Per-class mean features, one per row; moved to ``device``.
        device: Device used for computation.
        tholds: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        # Put the means on the features' device once, so mixed CPU/GPU inputs work.
        means = class_means.to(device)
        for feats, _ in test_loader:
            feats = feats.to(device)
            cos = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                cos[:, class_id] = F.cosine_similarity(feats, means[class_id], dim=1)
            probs = F.softmax(cos, dim=1)
            all_scores.append(torch.max(probs, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def plaincos_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Maximum cosine similarity between each sample and any class mean (no softmax).

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        test_loader: Iterable of ``(features, logits)`` batch pairs; only the features
            are read.
        num_classes: Number of class means to compare against.
        class_means: Per-class mean features, one per row; moved to ``device``.
        device: Device used for computation.
        tholds: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        # Put the means on the features' device once, so mixed CPU/GPU inputs work.
        means = class_means.to(device)
        for feats, _ in test_loader:
            feats = feats.to(device)
            cos = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                cos[:, class_id] = F.cosine_similarity(feats, means[class_id], dim=1)
            # Raw similarities are used directly (unlike rcos_score, no softmax here).
            all_scores.append(torch.max(cos, dim=1).values)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


def rel_mahalonobis_distance_score(
    model: torch.nn.Module,
    id_feats: torch.Tensor,
    in_labels: torch.Tensor,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Relative Mahalanobis distance score.

    Two precision matrices are fit with ``EmpiricalCovariance(assume_centered=True)``:
      * a shared class-wise one, on ID features centred by their own class mean;
      * a global one, on ID features centred by the global ID mean.
    For a test feature x, with MD_c(x) = (x - mu_c)^T P_cls (x - mu_c) and
    MD_g(x) = (x - mu_g)^T P_glob (x - mu_g), the score is

        (-min_c MD_c(x)) - (-MD_g(x))

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        id_feats: In-distribution features, one per row.
        in_labels: Class label of each row of ``id_feats``.
        test_loader: Iterable of ``(features, _)`` batch pairs; only features are read.
        num_classes: Number of classes; classes without ID samples are skipped when
            fitting the class-wise covariance.
        class_means: Per-class mean features, one per row.
        device: Device used for scoring.
        tholds: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per test sample, in loader order.

    Raises:
        ValueError: if no ID sample has a label in ``range(num_classes)``.
    """
    model = model.to(device)
    model.eval()

    # scikit-learn needs NumPy; detach so inputs that require grad are accepted.
    id_feats_np = id_feats.detach().cpu().numpy()
    labels_np = in_labels.detach().cpu().numpy()
    means_np = class_means.detach().cpu().numpy()

    # 1. Shared class-wise precision from class-centred ID features.
    centered_parts = []
    for c in range(num_classes):
        class_mask = labels_np == c
        if not np.any(class_mask):
            continue  # class absent from the ID set
        centered_parts.append(id_feats_np[class_mask] - means_np[c])
    if not centered_parts:
        raise ValueError("No features found for computing class-wise covariance.")
    ec_classwise = EmpiricalCovariance(assume_centered=True)
    ec_classwise.fit(np.concatenate(centered_parts, axis=0).astype(np.float64))
    prec_classwise_t = torch.from_numpy(ec_classwise.precision_).to(
        device=device, dtype=torch.double
    )

    # 2. Global precision from globally centred ID features.
    global_mean = id_feats_np.mean(axis=0)
    ec_global = EmpiricalCovariance(assume_centered=True)
    ec_global.fit((id_feats_np - global_mean).astype(np.float64))
    global_mean_t = torch.from_numpy(global_mean).to(device=device, dtype=torch.double)
    prec_global_t = torch.from_numpy(ec_global.precision_).to(
        device=device, dtype=torch.double
    )

    # 3. Score each test batch.
    all_scores = []
    with torch.no_grad():
        # Loop-invariant: class means as (1, C, D) doubles on the scoring device.
        means_t = class_means.detach().to(device, dtype=torch.double).unsqueeze(0)
        for feats_batch, _ in test_loader:
            x = feats_batch.to(device, dtype=torch.double)
            # NOTE(review): this (B, C, D) double tensor can be large for big C and D.
            diff = x.unsqueeze(1) - means_t
            maha_classwise = (torch.matmul(diff, prec_classwise_t) * diff).sum(dim=-1)
            classwise_score = -maha_classwise.min(dim=1).values  # (B,)

            diff_global = x - global_mean_t
            maha_global = (torch.matmul(diff_global, prec_global_t) * diff_global).sum(
                dim=-1
            )
            global_score = -maha_global

            all_scores.append((classwise_score - global_score).float().cpu())

    return torch.cat(all_scores).numpy().astype(np.float32)


def rel_mahalonobis_distance_score_2(
    model: torch.nn.Module,
    id_feats: torch.Tensor,
    in_labels: torch.Tensor,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Relative Mahalanobis distance score with feature centring done on ``device``.

    Same score as ``rel_mahalonobis_distance_score``:

        (-min_c MD_c(x)) - (-MD_g(x))

    Here the ID features and class means are first cast to float32 on ``device``.
    Centring happens there, and only the centred matrices are copied to the CPU for
    ``EmpiricalCovariance``.

    Args:
        model: Moved to ``device`` and set to eval mode; not otherwise used.
        id_feats: In-distribution features, one per row.
        in_labels: Class label of each row of ``id_feats``.
        test_loader: Iterable of ``(features, _)`` batch pairs; only features are read.
        num_classes: Number of classes; classes without ID samples contribute nothing
            to the class-wise covariance.
        class_means: Per-class mean features, one per row.
        device: Device used for centring and scoring.
        tholds: Unused; kept for a uniform scorer signature.

    Returns:
        float32 array with one score per test sample, in loader order.

    Raises:
        ValueError: if no ID sample has a label in ``range(num_classes)``.
    """
    model = model.to(device)
    model.eval()

    id_feats = id_feats.to(device, dtype=torch.float32)
    in_labels = in_labels.to(device)
    class_means = class_means.to(device, dtype=torch.float32)

    with torch.no_grad():
        # Class-centred features (z_i - mu_{y_i}) for the shared covariance.
        centered_parts = []
        for c in range(num_classes):
            class_mask = in_labels == c
            if not torch.any(class_mask):
                continue  # class absent from the ID set
            centered_parts.append(id_feats[class_mask] - class_means[c])
        if not centered_parts:
            raise ValueError("No features found for computing class-wise covariance.")
        centered_cpu = torch.cat(centered_parts, dim=0).cpu().numpy()

        global_mean = id_feats.mean(dim=0)
        centered_global_cpu = (id_feats - global_mean).cpu().numpy()

        global_mean_t = global_mean.to(device, dtype=torch.double)
        # Loop-invariant: class means as (1, C, D) doubles on the scoring device.
        means_t = class_means.to(device, dtype=torch.double).unsqueeze(0)

    ec_classwise = EmpiricalCovariance(assume_centered=True)
    ec_classwise.fit(centered_cpu.astype(np.float64))
    prec_classwise_t = torch.from_numpy(ec_classwise.precision_).to(
        device, dtype=torch.double
    )

    ec_global = EmpiricalCovariance(assume_centered=True)
    ec_global.fit(centered_global_cpu.astype(np.float64))
    prec_global_t = torch.from_numpy(ec_global.precision_).to(device, dtype=torch.double)

    all_scores = []
    with torch.no_grad():
        for feats_batch, _ in test_loader:
            x = feats_batch.to(device, dtype=torch.double)
            # NOTE(review): this (B, C, D) double tensor can be large for big C and D.
            diff = x.unsqueeze(1) - means_t
            maha_classwise = (torch.matmul(diff, prec_classwise_t) * diff).sum(dim=-1)
            classwise_score = -maha_classwise.min(dim=1).values  # (B,)

            diff_global = x - global_mean_t
            maha_global = (torch.matmul(diff_global, prec_global_t) * diff_global).sum(
                dim=-1
            )
            global_score = -maha_global

            all_scores.append((classwise_score - global_score).float().cpu())

    return torch.cat(all_scores).numpy().astype(np.float32)


def mds_score(
    model: torch.nn.Module,
    id_feats: torch.Tensor,
    in_labels: torch.Tensor,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Mahalanobis-distance (MDS) OOD score with one shared covariance.

    The precision matrix is fit (via ``EmpiricalCovariance``) on ID features
    centered by their own class mean, i.e. a single covariance shared by all
    classes. Each test sample is scored as the negative of its smallest
    Mahalanobis distance to any class mean, so larger scores are "more ID".

    Args:
        model: Moved to ``device`` and put in eval mode; not otherwise used,
            because ``test_loader`` is expected to yield precomputed features.
        id_feats: ID features, shape ``(N, D)``.
        in_labels: Class label of each row of ``id_feats``, shape ``(N,)``.
        test_loader: Yields ``(features, _)`` batches with features ``(B, D)``.
        num_classes: Labels ``0 .. num_classes - 1`` are used for centering.
        class_means: Per-class mean features, shape ``(C, D)``. Rows that are
            not finite (e.g. the NaN mean of a class with no samples) are
            ignored both for centering and for scoring.
        device: Torch device string.
        tholds: Unused; kept so the signature matches the other scorers.

    Returns:
        float32 array with one score per test sample (empty if the loader is
        empty).

    Raises:
        ValueError: If no class has both ID samples and a finite mean.
    """
    model = model.to(device)
    model.eval()

    id_feats = id_feats.to(device, dtype=torch.float32)  # (N, D)
    in_labels = in_labels.to(device)
    # Keep a float64 copy for scoring; the float32 copy is only used for
    # centering, so test-time distances no longer go through a lossy
    # float64 -> float32 -> float64 round trip.
    class_means_t = class_means.to(device, dtype=torch.double)  # (C, D)
    class_means_f32 = class_means_t.float()
    # A single NaN mean row would make min() over classes NaN for every sample.
    valid_classes = torch.isfinite(class_means_t).all(dim=1)  # (C,)

    # -------------------------------------------------------------------------
    # Shared covariance from classwise-centered ID features. Centering happens
    # on `device`; EmpiricalCovariance needs a CPU NumPy array.
    # -------------------------------------------------------------------------
    with torch.no_grad():
        centered_chunks = []
        for c in range(num_classes):
            class_mask = in_labels == c
            if not torch.any(class_mask) or not valid_classes[c]:
                # Class absent from the ID set, or its mean is unusable.
                continue
            centered_chunks.append(id_feats[class_mask] - class_means_f32[c])

        if not centered_chunks:
            raise ValueError("No features found to compute classwise covariance.")

        all_centered_feats = torch.cat(centered_chunks, dim=0)  # (N', D)

    ec = EmpiricalCovariance(assume_centered=True)
    ec.fit(all_centered_feats.cpu().numpy().astype(np.float64))
    prec_t = torch.from_numpy(ec.precision_).to(device=device, dtype=torch.double)

    # Loop-invariant: broadcastable class means restricted to usable classes.
    means_expanded = class_means_t[valid_classes].unsqueeze(0)  # (1, C', D)

    # -------------------------------------------------------------------------
    # Score = -min_c (x - mu_c)^T P (x - mu_c)
    # -------------------------------------------------------------------------
    all_scores = []
    with torch.no_grad():
        for feats_batch, _ in test_loader:
            # If test_loader gives raw inputs, do: feats_batch = model(feats_batch.to(device))
            feats_batch = feats_batch.to(device, dtype=torch.double)  # (B, D)
            diff = feats_batch.unsqueeze(1) - means_expanded  # (B, C', D)
            maha_dists = (torch.matmul(diff, prec_t) * diff).sum(dim=-1)  # (B, C')
            min_maha = maha_dists.min(dim=1).values  # (B,)
            all_scores.append((-min_maha).float().cpu())

    if not all_scores:
        return np.empty(0, dtype=np.float32)
    return torch.cat(all_scores).numpy().astype(np.float32)


def nnguide_score(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    train_feats: torch.Tensor,
    k: int = 10,
) -> np.ndarray:
    """NNGuide-style OOD score: energy-weighted k-NN similarity times energy.

    A guide bank is built once from the L2-normalized training features, with
    each row scaled by that training sample's value from ``energy_score`` on
    ``train_loader``. For each test sample, the score is the mean of its top-k
    dot products with the bank, multiplied by the sample's own
    ``logsumexp(logits)``.

    Args:
        model: Moved to ``device``/eval and forwarded to ``energy_score``.
        train_loader: Training loader passed to ``energy_score``.
            NOTE(review): row i of its output is paired with ``train_feats[i]``,
            so the loader is assumed to iterate in ``train_feats`` order.
        test_loader: Yields ``(features, logits)`` batches.
        num_classes: Forwarded to ``energy_score``.
        class_means: Forwarded to ``energy_score``.
        device: Torch device string.
        train_feats: Training features, shape ``(N, D)``.
        k: Number of top similarities averaged per test sample (clamped to N).

    Returns:
        float32 array with one score per test sample (empty if the loader is
        empty).
    """
    model = model.to(device)
    model.eval()
    energy_scores = torch.from_numpy(
        energy_score(model, train_loader, num_classes, class_means, device)
    ).to(device)

    all_scores = []
    with torch.inference_mode():
        # Build the (D, N) guide bank once; it does not depend on the test
        # batch, so recomputing it per batch only wasted time and memory.
        normed_train = F.normalize(train_feats.to(device), p=2, dim=1)
        guide_bank_t = (normed_train * energy_scores.view(-1, 1)).T  # (D, N)
        del normed_train
        # topk raises if k exceeds the number of bank entries.
        k_eff = min(k, guide_bank_t.shape[1])

        for feats_batch, logits_batch in test_loader:
            feats_batch = F.normalize(feats_batch.to(device), p=2, dim=1)  # (B, D)
            logits_batch = logits_batch.to(device)
            energy_batch = torch.logsumexp(logits_batch, dim=1)  # (B,)
            sims = torch.matmul(feats_batch, guide_bank_t)  # (B, N)
            knn_sim = torch.topk(sims, k_eff, largest=True, sorted=True).values.mean(
                dim=1
            )  # (B,)
            all_scores.append(knn_sim * energy_batch)

    if not all_scores:
        return np.empty(0, dtype=np.float32)
    return torch.cat(all_scores).cpu().numpy().astype(np.float32)


def neco_score(
    model: torch.nn.Module,
    id_feats: torch.Tensor,
    in_labels: torch.Tensor,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    tholds: Optional[torch.Tensor] = None,
    model_architecture_type: str = "resnet",
    neco_dim: int = 10,
) -> np.ndarray:
    """NECO OOD score from the leading principal directions of ID features.

    A ``StandardScaler`` is fit on all ID features. A ``PCA`` (at most 256
    components) is then fit on a random subset of at most 100000 scaled ID
    features. For each test feature ``x`` the score is
    ``||PCA(scale(x))[:neco_dim]|| / ||x||``. When
    ``model_architecture_type != "resnet"``, the score is also multiplied by the
    sample's max logit.

    NOTE(review): the numerator lives in the standardized PCA space while the
    denominator is the norm of the raw feature; confirm this is the intended
    NECO variant for the architecture in use.

    Args:
        model: Moved to ``device`` and put in eval mode; not otherwise used.
        id_feats: ID (training) features, shape ``(N, D)``.
        in_labels: Unused; kept for signature compatibility.
        test_loader: Yields ``(features, logits)`` batches.
        num_classes: Unused; kept for signature compatibility.
        class_means: Unused; kept for signature compatibility.
        device: Torch device string.
        tholds: Unused; kept for signature compatibility.
        model_architecture_type: Anything other than ``"resnet"`` enables the
            max-logit multiplication.
        neco_dim: Number of leading principal components in the numerator.

    Returns:
        float32 array with one score per test sample (empty if the loader is
        empty).
    """
    model = model.to(device)
    model.eval()

    max_fit_samples = 100000  # cap on rows used to fit PCA (cost bound)
    max_components = 256

    # StandardScaler / PCA need CPU NumPy arrays.
    id_feats_np = id_feats.cpu().numpy()  # (N, D)
    ss = StandardScaler()
    id_feats_scaled = ss.fit_transform(id_feats_np)  # (N, D)

    n_samples, n_dims = id_feats_scaled.shape
    if n_samples >= max_fit_samples:
        random_inds = np.random.choice(n_samples, size=max_fit_samples, replace=False)
        fit_feats = id_feats_scaled[random_inds]
    else:
        # Sampling max_fit_samples rows without replacement would raise here;
        # the whole ID set is small enough to use directly.
        fit_feats = id_feats_scaled
    # PCA requires n_components <= min(n_samples, n_features).
    n_components = min(max_components, fit_feats.shape[0], n_dims)
    pca_estimator = PCA(n_components=n_components)
    pca_estimator.fit(fit_feats)

    all_scores = []
    with torch.no_grad():
        for feats_batch, logits_batch in test_loader:
            # Assumes the loader yields precomputed features; with raw inputs
            # you would first do: feats_batch = model(feats_batch.to(device))
            feats_np = feats_batch.cpu().numpy()  # (B, D)
            feats_pca = pca_estimator.transform(ss.transform(feats_np))  # (B, n_components)

            norm_pca = np.linalg.norm(feats_pca[:, :neco_dim], axis=1)  # (B,)
            norm_orig = np.linalg.norm(feats_np, axis=1)  # (B,)
            ratio = norm_pca / (norm_orig + 1e-12)  # avoid division by zero

            if model_architecture_type != "resnet":
                max_logit = logits_batch.max(dim=1).values  # (B,)
                ratio = ratio * max_logit.cpu().numpy()

            all_scores.append(ratio)

    if not all_scores:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(all_scores, axis=0).astype(np.float32)


args = get_args()

# The CUDA driver reads CUDA_VISIBLE_DEVICES once, when it is first
# initialized; torch.cuda.is_available() can trigger that, so the variable must
# be set before any CUDA query or it may be silently ignored.
if args.gpu is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

class_num = 1000
id_train_size = 1281167
id_val_size = 50000

# Cached ID features / logits / labels, stored as float64 memmaps
# (dtype=float) with a 768-wide feature dimension.
cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
feat_log = torch.from_numpy(
    np.memmap(
        f"{cache_dir}/feat.mmap", dtype=float, mode="r", shape=(id_train_size, 768)
    )
).to(device)
score_log = torch.from_numpy(
    np.memmap(
        f"{cache_dir}/score.mmap",
        dtype=float,
        mode="r",
        shape=(id_train_size, class_num),
    )
).to(device)
label_log = torch.from_numpy(
    np.memmap(f"{cache_dir}/label.mmap", dtype=float, mode="r", shape=(id_train_size,))
).to(device)


cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = torch.from_numpy(
    np.memmap(f"{cache_dir}/feat.mmap", dtype=float, mode="r", shape=(id_val_size, 768))
).to(device)
score_log_val = torch.from_numpy(
    np.memmap(
        f"{cache_dir}/score.mmap", dtype=float, mode="r", shape=(id_val_size, class_num)
    )
).to(device)
label_log_val = torch.from_numpy(
    np.memmap(f"{cache_dir}/label.mmap", dtype=float, mode="r", shape=(id_val_size,))
).to(device)

print("ID ACC:", calculate_acc_val(score_log_val, label_log_val))

ood_feat_score_log = {}
ood_dataset_size = {
    "inat": 10000,
    "sun50": 10000,
    "places50": 10000,
    "dtd": 5640,
    "ssbhard": 49000,
    "ninco": 5878,
}

for ood_dataset in args.out_datasets:
    ood_feat_log = torch.from_numpy(
        np.memmap(
            f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/feat.mmap",
            dtype=float,
            mode="r",
            shape=(ood_dataset_size[ood_dataset], 768),
        )
    ).to(device)
    ood_score_log = torch.from_numpy(
        np.memmap(
            f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/score.mmap",
            dtype=float,
            mode="r",
            shape=(ood_dataset_size[ood_dataset], class_num),
        )
    ).to(device)
    ood_feat_score_log[ood_dataset] = ood_feat_log, ood_score_log


######## get w, b; precompute denominator matrix, training feature mean  #################
# NOTE(review): w, b, train_mean and denominator_matrix feed only the fDBD
# formulas (commented out below) and possibly fdbd_score via module globals;
# confirm before removing.

if args.name == "resnet50":
    net = models.resnet50(pretrained=True)
    for i, param in enumerate(net.fc.parameters()):
        if i == 0:
            w = param.data
        else:
            b = param.data

elif args.name == "resnet50-supcon":
    checkpoint = torch.load("ckpt/ImageNet_resnet50_supcon_linear.pth")
    w = checkpoint["model"]["fc.weight"]
    b = checkpoint["model"]["fc.bias"]

train_mean = torch.mean(feat_log, dim=0).to(device)

denominator_matrix = torch.zeros((class_num, class_num)).to(device)
# for p in range(1000):
#   w_p = w - w[p,:]
#   denominator = torch.norm(w_p, dim=1)
#   denominator[p] = 1
#   denominator_matrix[p, :] = denominator

#################### fDBD score OOD detection #################

all_results = []
all_score_out = []

values, nn_idx = score_log_val.max(1)
logits_sub = torch.abs(score_log_val - values.repeat(class_num, 1).T)
# pdb.set_trace()
# score_in = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log_val - train_mean , dim = 1)
# score_in = score_in.float().cpu().numpy()

from util.model_loader import get_model

model = get_model(args, class_num, load_ckpt=True)
model.fc = model.head
in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(
    in_dataset, batch_size=128, shuffle=False, num_workers=2
)

train_dataset = torch.utils.data.TensorDataset(feat_log.cpu(), score_log.cpu())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=False, num_workers=2
)

# Per-class means of the validation features; a class with no validation
# samples yields a NaN row here.
class_means = []
for i in range(class_num):
    class_means.append(feat_log_val[label_log_val == i].mean(0))
class_means = torch.stack(class_means).to(device)

# tholds = react_thold(feat_log, percentile=85)
tholds = None
# tholds = torch.from_numpy(tholds).to(device)
# score_in = ORA_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = rcos_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = plaincos_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = rel_mahalonobis_distance_score_2(
#     model, feat_log, label_log, in_loader, 1000, class_means, device, tholds
# )

# score_in = mds_score(
#     model, feat_log, label_log, in_loader, 1000, class_means, device, tholds
# )
# score_in = energy_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = msp_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = maxlogit_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = fdbd_score(model, in_loader, 1000, class_means, device)
# score_in = React_energy(model, in_loader, 1000, class_means, device, tholds)
# breakpoint()
# Keep references to the training tensors: the OOD loop below rebinds the
# names feat_log / score_log to each OOD dataset's tensors.
train_feats = feat_log[:]
train_labels = label_log[:]
score_in = knn_score(model, in_loader, class_num, class_means, device, train_feats)
# score_in = nnguide_score(
#     model, train_loader, in_loader, 1000, class_means, device, train_feats
# )
# score_in = neco_score(
#     model, train_feats, train_labels, in_loader, 1000, class_means, device
# )


# NOTE: this loop rebinds feat_log / score_log to the OOD tensors; use
# train_feats / train_labels for the training set from here on.
for ood_dataset, (feat_log, score_log) in ood_feat_score_log.items():
    print(ood_dataset)
    values, nn_idx = score_log.max(1)
    logits_sub = torch.abs(score_log - values.repeat(class_num, 1).T)
    # scores_out_test = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log - train_mean , dim = 1)
    # scores_out_test = scores_out_test.float().cpu().numpy()
    out_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(feat_log.cpu(), score_log.cpu()),
        batch_size=128,
        shuffle=False,
        num_workers=2,
    )

    # scores_out_test = ORA_score(model, out_loader, 1000, class_means, device, tholds)
    # scores_out_test = rcos_score(model, out_loader, 1000, class_means, device, tholds)
    # scores_out_test = plaincos_score(
    #     model, out_loader, 1000, class_means, device, tholds
    # )
    # scores_out_test = rel_mahalonobis_distance_score_2(
    #     model, train_feats, train_labels, out_loader, 1000, class_means, device, tholds
    # )

    # scores_out_test = mds_score(
    #     model, train_feats, train_labels, out_loader, 1000, class_means, device, tholds
    # )
    # scores_out_test = energy_score(model, out_loader, 1000, class_means, device, tholds)
    # scores_out_test = msp_score(model, out_loader, 1000, class_means, device, tholds)
    # scores_out_test = maxlogit_score(
    #     model, out_loader, 1000, class_means, device, tholds
    # )
    # scores_out_test = fdbd_score(model, out_loader, 1000, class_means, device)
    # scores_out_test = React_energy(model, out_loader, 1000, class_means, device, tholds)
    scores_out_test = knn_score(
        model, out_loader, class_num, class_means, device, train_feats
    )
    # scores_out_test = nnguide_score(
    #     model, train_loader, out_loader, 1000, class_means, device, train_feats
    # )
    # scores_out_test = neco_score(
    #     model, train_feats, train_labels, out_loader, 1000, class_means, device
    # )

    # plot histograms
    # import matplotlib.pyplot as plt
    # plot_helper(score_in, scores_out_test, args.in_dataset, ood_dataset, function_name="ORA")

    # scores_out_test = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log - train_mean , dim = 1)
    # scores_out_test = scores_out_test.float().cpu().numpy()
    all_score_out.extend(scores_out_test)
    results = metrics.cal_metric(score_in, scores_out_test)
    all_results.append(results)

metrics.print_all_results(all_results, args.out_datasets, "ORA")
print()