import os
import time
from util.args_loader import get_args
from util import metrics
import torch
import faiss
import numpy as np
import torchvision.models as models
from typing import Optional
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.covariance import EmpiricalCovariance
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA


def calculate_acc_val(score_log_val, label_log_val):
    """Print and return the top-1 accuracy of ``score_log_val``.

    The prediction for each row is the arg-max over dimension 1; it is
    compared element-wise with ``label_log_val`` and the fraction of matches
    is both printed and returned as a Python float.
    """
    predictions = score_log_val.argmax(1)
    hits = (predictions == label_log_val).sum().item()
    accuracy = hits / len(label_log_val)
    print(f"Accuracy on validation set: {accuracy}")
    return accuracy


def plot_helper(
    score_in: np.ndarray,
    scores_out_test: np.ndarray,
    in_dataset: str,
    ood_dataset_name: str,
    function_name: str = "ORA",
) -> None:
    """Save overlaid density histograms of ID and OOD scores as a PNG.

    Args:
        score_in: Scores of the in-distribution samples.
        scores_out_test: Scores of the out-of-distribution samples.
        in_dataset: Name of the in-distribution dataset (title and filename).
        ood_dataset_name: Name of the OOD dataset; ``"dtd"`` is displayed and
            saved as ``"Texture"``.
        function_name: Name of the scoring function. It labels the x axis as
            ``"<function_name>(x)"`` and is the filename suffix. With the
            default ``"ORA"`` the output is identical to the previous
            hard-coded label and ``*_ORA.png`` filename.

    The figure is written to
    ``./{in_dataset}_{ood_dataset_name}_{function_name}.png`` at 600 dpi.
    """
    if ood_dataset_name == "dtd":
        ood_dataset_name = "Texture"
    print(ood_dataset_name)

    try:
        plt.hist(score_in, bins=70, alpha=0.5, label="in", density=True)
        plt.hist(scores_out_test, bins=70, alpha=0.5, label="out", density=True)
        plt.legend(loc="upper right", fontsize=9)
        plt.xlabel(f"{function_name}(x)", fontsize=16)
        plt.ylabel("Density", fontsize=16)
        plt.title(f"{in_dataset} vs {ood_dataset_name}", fontsize=16)
        plt.savefig(
            f"./{in_dataset}_{ood_dataset_name}_{function_name}.png", dpi=600
        )
    finally:
        # Always release the figure so a failed draw/save cannot leak into
        # the histogram of the next dataset.
        plt.close()


def react_filter(feats, thold):
    """Cap activations at ``thold``.

    Every entry of ``feats`` strictly greater than ``thold`` is replaced by
    ``thold``; all other entries are passed through unchanged.
    """
    exceeds = feats > thold
    return torch.where(exceeds, thold, feats)


def react_thold(id_feats, percentile=90):
    """Return the ``percentile``-th percentile over all entries of ``id_feats``.

    The tensor is moved to the CPU, flattened to one dimension, and a single
    global percentile (not one per feature dimension) is computed with NumPy.
    """
    flat_feats = id_feats.cpu().numpy().reshape(-1)
    return np.percentile(flat_feats, percentile, axis=0)


def React_energy(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Return the energy score (log-sum-exp of the logits) for every sample.

    ``test_loader`` yields ``(features, logits)`` batches. When ``thold`` is
    given, the features are first capped with ``react_filter`` and the logits
    are recomputed through ``model.fc``; otherwise the logits from the loader
    are used as they are.

    ``num_classes`` and ``class_means`` are not used by this scorer.

    Returns:
        One float32 score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    batch_energies = []
    with torch.inference_mode():
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            if thold is not None:
                feats = react_filter(feats, thold).float()
                logits = model.fc(feats)
            batch_energies.append(torch.logsumexp(logits, dim=1))
    energies = torch.cat(batch_energies).detach().cpu().numpy()
    return np.asarray(energies, dtype=np.float32)


def fdbd_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
) -> np.ndarray:
    """Score samples with the angle-ratio form of the fDBD boundary distance.

    For every sample and every class ``c`` the feature is projected onto the
    decision boundary between the predicted class and ``c`` (using the rows of
    ``model.fc.weight``). With ``center`` the mean of ``class_means``, the
    per-class value is ``sin(angle at center) / sin(angle at the projection)``,
    where the first angle is between the sample and its projection as seen
    from ``center`` and the second is between the sample and ``center`` as
    seen from the projection. NaNs (e.g. the 0/0 produced for the predicted
    class itself) are set to 0 and the per-class values are averaged.

    Args:
        model: Network whose ``fc.weight`` holds one classifier row per class.
        test_loader: Yields ``(features, logits)`` batches.
        num_classes: Number of classes / columns of the logits.
        class_means: Per-class feature means; only their mean is used.
        device: Device on which the computation runs.

    Returns:
        One float32 score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        # Loop invariants: the reference point and the classifier weights do
        # not depend on the batch or on the class being compared.
        center = torch.mean(class_means, dim=0)
        fc_weight = model.fc.weight
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weights = fc_weight[preds]
            unit_centered_feats = F.normalize(feats - center, p=2, dim=1)
            ratios = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of the feature onto the boundary
                # between the predicted class and ``class_id``.
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )

                # Angle between sample and projection, seen from the center.
                unit_centered_db = F.normalize(feats_db - center, p=2, dim=1)
                cos_at_center = torch.sum(
                    unit_centered_feats * unit_centered_db, dim=1
                )
                angle_at_center = torch.arccos(cos_at_center) / torch.pi

                # Angle between sample and center, seen from the projection.
                cos_at_db = F.cosine_similarity(
                    feats - feats_db, center - feats_db, dim=1
                )
                angle_at_db = torch.arccos(cos_at_db) / torch.pi

                ratios[:, class_id] = torch.sin(
                    angle_at_center * torch.pi
                ) / torch.sin(angle_at_db * torch.pi)
            ratios[torch.isnan(ratios)] = 0
            all_scores.append(torch.mean(ratios, dim=1))
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


def ORA_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Score samples by the largest angle to their decision-boundary projections.

    For every sample and every class ``c`` the feature is projected onto the
    decision boundary between the predicted class and ``c`` (using the rows of
    ``model.fc.weight``). The angle between the sample and that projection,
    measured from the mean of ``class_means`` and divided by pi, is recorded.
    NaNs (e.g. the 0/0 produced for the predicted class itself) are set to 0
    and the maximum over classes is the score.

    Args:
        model: Network whose ``fc.weight`` holds one classifier row per class.
        test_loader: Yields ``(features, logits)`` batches.
        num_classes: Number of classes / columns of the logits.
        class_means: Per-class feature means; only their mean is used.
        device: Device on which the computation runs.
        thold: Accepted for signature parity with the other scorers; the
            ReAct clipping that would consume it is disabled, so it is unused.

    Returns:
        One float32 score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    with torch.inference_mode():
        # Loop invariants: the reference point and the classifier weights do
        # not depend on the batch or on the class being compared.
        center = torch.mean(class_means, dim=0)
        fc_weight = model.fc.weight
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weights = fc_weight[preds]
            unit_centered_feats = F.normalize(feats - center, p=2, dim=1)
            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of the feature onto the boundary
                # between the predicted class and ``class_id``.
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )
                unit_centered_db = F.normalize(feats_db - center, p=2, dim=1)
                cos_at_center = torch.sum(
                    unit_centered_feats * unit_centered_db, dim=1
                )
                angles[:, class_id] = torch.arccos(cos_at_center) / torch.pi

            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


def ORA_class_idx_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    cid: int,
    dataset_name: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Compute ORA scores centred on one class mean and save them to disk.

    Same computation as ``ORA_score`` except that angles are measured from
    ``class_means[cid]`` instead of the mean of all class means: for every
    sample the score is the maximum, over classes, of the angle (divided by
    pi) between the sample and its projection onto the decision boundary
    between the predicted class and that class. NaNs are set to 0 first.

    Side effect: the scores are written to
    ``./swinpre/class_scores/{dataset_name}_class_{cid}.npy``. The directory
    is created if it does not exist, so the save cannot fail on a missing
    directory after the scores have been computed.

    Args:
        model: Network whose ``fc.weight`` holds one classifier row per class.
        test_loader: Yields ``(features, logits)`` batches.
        num_classes: Number of classes / columns of the logits.
        class_means: Per-class feature means, indexed by ``cid``.
        device: Device on which the computation runs.
        cid: Index of the class mean used as the reference point.
        dataset_name: Dataset tag used in the output filename.
        thold: Accepted for signature parity with the other scorers; unused.

    Returns:
        One float32 score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    with torch.inference_mode():
        # Loop invariants: the reference point and the classifier weights do
        # not depend on the batch or on the class being compared.
        center = class_means[cid]
        fc_weight = model.fc.weight
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weights = fc_weight[preds]
            unit_centered_feats = F.normalize(feats - center, p=2, dim=1)
            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of the feature onto the boundary
                # between the predicted class and ``class_id``.
                feats_db = (
                    feats
                    - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                    * weight_diff
                )
                unit_centered_db = F.normalize(feats_db - center, p=2, dim=1)
                cos_at_center = torch.sum(
                    unit_centered_feats * unit_centered_db, dim=1
                )
                angles[:, class_id] = torch.arccos(cos_at_center) / torch.pi

            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)

    # Persist for ORA_classwise_score, which reads these files back.
    os.makedirs("./swinpre/class_scores", exist_ok=True)
    np.save(f"./swinpre/class_scores/{dataset_name}_class_{cid}.npy", scores)
    return scores


def ORA_classwise_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    dataset_name: str,
    score_log: torch.Tensor,
    thold: Optional[torch.Tensor] = None,
    k: int = 1,
) -> np.ndarray:
    """Fuse the per-class ORA scores previously saved by ``ORA_class_idx_score``.

    For every ``class_id`` in ``range(num_classes)`` the file
    ``./swinpre/class_scores/{dataset_name}_class_{class_id}.npy`` is loaded.
    The arrays are stacked along a new leading class dimension and, for each
    sample, the ``k`` largest values across classes are averaged. With the
    default ``k=1`` this is the maximum over classes.

    Nothing is recomputed here: ``test_loader``, ``class_means``,
    ``score_log`` and ``thold`` are accepted for signature parity with the
    other scorers but are not read. The model is still moved to ``device``
    and put in eval mode, as before.

    Args:
        model: Network; only moved to ``device`` and switched to eval mode.
        test_loader: Unused.
        num_classes: Number of per-class score files to load.
        class_means: Unused.
        device: Device on which the fusion runs.
        dataset_name: Dataset tag used in the score filenames.
        score_log: Unused.
        thold: Unused.
        k: Number of largest per-class scores averaged per sample.

    Returns:
        One float32 score per sample, in the order stored in the files.

    Raises:
        FileNotFoundError: If a per-class score file is missing (from
            ``np.load``).
    """
    model = model.to(device)
    model.eval()

    per_class_scores = [
        torch.from_numpy(
            np.load(f"./swinpre/class_scores/{dataset_name}_class_{class_id}.npy")
        ).to(device)
        for class_id in range(num_classes)
    ]
    # Shape: (num_classes, num_samples).
    stacked = torch.stack(per_class_scores)
    top_k = torch.topk(stacked, k, largest=True, dim=0).values
    fused = top_k.mean(dim=0)
    scores = np.asarray(fused.detach().cpu().numpy(), dtype=np.float32)
    return scores


args = get_args()

# Restrict the visible GPUs before the first CUDA call in this script:
# CUDA_VISIBLE_DEVICES is read when the CUDA runtime enumerates devices, so
# assigning it after torch.cuda.manual_seed / torch.cuda.is_available() can
# leave the selection without effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Seed every random number generator used by the script.
seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of classes and row counts of the cached in-distribution arrays;
# they are used below as the shapes of the memory-mapped cache files.
class_num = 1000
id_train_size = 1281167
id_val_size = 50000

def _load_cached_array(path, shape):
    """Memory-map the float array stored at ``path`` and copy it to ``device``."""
    mapped = np.memmap(path, dtype=float, mode="r", shape=shape)
    return torch.from_numpy(mapped).to(device)


# In-distribution training split: features, logits and labels.
cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
feat_log = _load_cached_array(f"{cache_dir}/feat.mmap", (id_train_size, 1024))
score_log = _load_cached_array(f"{cache_dir}/score.mmap", (id_train_size, class_num))
label_log = _load_cached_array(f"{cache_dir}/label.mmap", (id_train_size,))

# In-distribution validation split, same layout.
cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = _load_cached_array(f"{cache_dir}/feat.mmap", (id_val_size, 1024))
score_log_val = _load_cached_array(
    f"{cache_dir}/score.mmap", (id_val_size, class_num)
)
label_log_val = _load_cached_array(f"{cache_dir}/label.mmap", (id_val_size,))

# Sanity check: top-1 accuracy of the cached validation logits.
print("ID ACC:", calculate_acc_val(score_log_val, label_log_val))

# Number of cached samples for each supported OOD dataset; used as the row
# count when memory-mapping that dataset's cache files.
ood_dataset_size = {
    "inat": 10000,
    "sun50": 10000,
    "places50": 10000,
    "dtd": 5640,
    "ssbhard": 49000,
    "ninco": 5878,
}

# Maps OOD dataset name -> (features, logits), both already on ``device``.
ood_feat_score_log = {}
for ood_dataset in args.out_datasets:
    ood_cache_dir = f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out"
    num_ood_samples = ood_dataset_size[ood_dataset]

    ood_feat_map = np.memmap(
        f"{ood_cache_dir}/feat.mmap",
        dtype=float,
        mode="r",
        shape=(num_ood_samples, 1024),
    )
    ood_feat_log = torch.from_numpy(ood_feat_map).to(device)

    ood_score_map = np.memmap(
        f"{ood_cache_dir}/score.mmap",
        dtype=float,
        mode="r",
        shape=(num_ood_samples, class_num),
    )
    ood_score_log = torch.from_numpy(ood_score_map).to(device)

    ood_feat_score_log[ood_dataset] = (ood_feat_log, ood_score_log)


######## get w, b; precompute denominator matrix, training feature mean  #################

# Weight and bias of the final linear classifier. For any other ``args.name``
# neither ``w`` nor ``b`` is defined.
if args.name == "resnet50":
    net = models.resnet50(pretrained=True)
    w = net.fc.weight.data
    b = net.fc.bias.data
elif args.name == "resnet50-supcon":
    checkpoint = torch.load("ckpt/ImageNet_resnet50_supcon_linear.pth")
    w = checkpoint["model"]["fc.weight"]
    b = checkpoint["model"]["fc.bias"]

# Mean training feature vector.
train_mean = feat_log.mean(dim=0).to(device)

# Placeholder for the pairwise ||w_c - w_p|| table. It stays all zeros: the
# loop that used to fill it row by row is disabled, kept here for reference:
# for p in range(1000):
#   w_p = w - w[p,:]
#   denominator = torch.norm(w_p, dim=1)
#   denominator[p] = 1
#   denominator_matrix[p, :] = denominator
denominator_matrix = torch.zeros(1000, 1000).to(device)

#################### OOD detection: score the in-distribution set #################

all_results = []
all_score_out = []

# Per-sample max logit and the absolute gap of every logit to it. These only
# feed the closed-form fDBD expression, which is currently disabled:
# score_in = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log_val - train_mean , dim = 1)
# score_in = score_in.float().cpu().numpy()
values, nn_idx = score_log_val.max(1)
logits_sub = torch.abs(score_log_val - values.repeat(1000, 1).T)

from util.model_loader import get_model

model = get_model(args, 1000, load_ckpt=True)
# The scorers read the classifier from ``model.fc``; the original left this
# alias commented out for models that keep it elsewhere:
# model.fc = model.head.fc

# Loaders over the cached (features, logits) pairs, kept on the CPU.
in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(
    in_dataset, batch_size=128, shuffle=False, num_workers=2
)

train_dataset = torch.utils.data.TensorDataset(feat_log.cpu(), score_log.cpu())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=False, num_workers=2
)

# One mean feature vector per class.
# NOTE(review): the means are estimated from the validation features and
# labels, i.e. the same samples that are scored as in-distribution below;
# confirm this is intended rather than using the training cache.
class_means = torch.stack(
    [feat_log_val[label_log_val == i].mean(0) for i in range(1000)]
).to(device)

# ReAct clipping is disabled. To enable it:
# tholds = react_thold(feat_log, percentile=85)
# tholds = torch.from_numpy(tholds).to(device)
tholds = None

# ORA_classwise_score only reads per-class score files from disk; they must
# have been written beforehand by running, once per class centre:
# for i in range(1000):
#     print(i)
#     score_in = ORA_class_idx_score(
#         model, in_loader, 1000, class_means, device, cid=i, dataset_name="imagenet"
#     )
score_in = ORA_classwise_score(
    model, in_loader, 1000, class_means, device, "imagenet", score_log=score_log_val
)
# Alternative scorers defined in this file (swap in for the call above):
# score_in = ORA_score(model, in_loader, 1000, class_means, device, tholds)
# score_in = fdbd_score(model, in_loader, 1000, class_means, device)
# score_in = React_energy(model, in_loader, 1000, class_means, device, tholds)
# Further alternatives tried earlier, whose definitions are not in this part
# of the file: rcos_score, plaincos_score, energy_score, msp_score,
# maxlogit_score (same arguments as ORA_score); neco_score, mds_score,
# rel_mahalonobis_distance_score_2 (model, feat_log, label_log, in_loader,
# 1000, class_means, device, tholds); and, using ``train_feats`` below:
# score_in = knn_score(model, in_loader, 1000, class_means, device, train_feats)
# score_in = nnguide_score(
#     model, train_loader, in_loader, 1000, class_means, device, train_feats
# )

# Keep handles to the training tensors: the OOD loop further down rebinds
# ``feat_log`` and ``score_log``.
train_feats = feat_log[:]
train_labels = label_log[:]


# Score every OOD dataset and compare it with the in-distribution scores.
# Note that the loop targets rebind the module-level ``feat_log`` and
# ``score_log`` to the OOD tensors.
for ood_dataset, (feat_log, score_log) in ood_feat_score_log.items():
    print(ood_dataset)

    # Inputs of the disabled closed-form fDBD expression:
    # scores_out_test = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log - train_mean , dim = 1)
    # scores_out_test = scores_out_test.float().cpu().numpy()
    values, nn_idx = score_log.max(1)
    logits_sub = torch.abs(score_log - values.repeat(1000, 1).T)

    out_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(feat_log.cpu(), score_log.cpu()),
        batch_size=128,
        shuffle=False,
        num_workers=2,
    )

    # ORA_classwise_score only reads per-class score files from disk; they
    # must have been written beforehand by running, once per class centre:
    # for k in range(1000):
    #     print(k)
    #     scores_out_test = ORA_class_idx_score(
    #         model, out_loader, 1000, class_means, device,
    #         cid=k, dataset_name=ood_dataset,
    #     )
    scores_out_test = ORA_classwise_score(
        model,
        out_loader,
        1000,
        class_means,
        device,
        dataset_name=ood_dataset,
        score_log=score_log,
    )
    # Alternative scorers defined in this file (swap in for the call above):
    # scores_out_test = ORA_score(model, out_loader, 1000, class_means, device, tholds)
    # scores_out_test = fdbd_score(model, out_loader, 1000, class_means, device)
    # scores_out_test = React_energy(model, out_loader, 1000, class_means, device, tholds)
    # Further alternatives tried earlier, whose definitions are not in this
    # part of the file: rcos_score, plaincos_score, energy_score, msp_score,
    # maxlogit_score (same arguments as ORA_score); neco_score, mds_score,
    # rel_mahalonobis_distance_score_2 (model, train_feats, train_labels,
    # out_loader, 1000, class_means, device, tholds); knn_score (model,
    # out_loader, 1000, class_means, device, train_feats); nnguide_score
    # (model, train_loader, out_loader, 1000, class_means, device, train_feats).

    # Optional histogram of ID vs. OOD scores:
    # plot_helper(score_in, scores_out_test, args.in_dataset, ood_dataset, function_name="ORA")

    all_score_out.extend(scores_out_test)
    results = metrics.cal_metric(score_in, scores_out_test)
    all_results.append(results)

metrics.print_all_results(all_results, args.out_datasets, "ORA")
print()
