import os
import time
from util.args_loader import get_args
from util import metrics
import torch
import faiss
import numpy as np
import torchvision.models as models
from typing import Optional
import torch.nn.functional as F


def react_filter(feats, thold):
    """Cap the entries of ``feats`` at ``thold``.

    Every entry strictly greater than ``thold`` is replaced by ``thold``;
    all other entries are passed through unchanged.
    """
    above_cap = feats > thold
    clipped = torch.where(above_cap, thold, feats)
    return clipped

def react_thold(id_feats, percentile=90):
    """Return the ``percentile``-th percentile over all entries of ``id_feats``.

    The tensor is copied to the CPU and flattened, so a single scalar
    threshold is produced regardless of the input shape.
    """
    flat_values = id_feats.cpu().numpy().ravel()
    return np.percentile(flat_values, percentile)

def React_energy(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Log-sum-exp (energy) score, optionally on clipped features.

    ``test_loader`` yields ``(features, logits)`` batches. When ``thold`` is
    given, the features are capped with ``react_filter`` and the logits are
    recomputed with ``model.fc``; otherwise the cached logits are used as is.
    ``num_classes`` and ``class_means`` are not used by this scorer.

    Returns a float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    clip_enabled = thold is not None
    per_batch = []
    with torch.inference_mode():
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            if clip_enabled:
                feats = react_filter(feats, thold).float()
                logits = model.fc(feats)
            per_batch.append(torch.logsumexp(logits, dim=1))
    stacked = torch.cat(per_batch).detach().cpu().numpy()
    return np.asarray(stacked, dtype=np.float32)


def fdbd_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
) -> np.ndarray:
    """Average relative distance of each sample to the class decision boundaries.

    For every sample ``x`` (predicted class ``p``) and every class ``c`` the
    feature is projected onto the linear boundary between ``p`` and ``c``,
    giving ``x_db``. With ``mu`` the mean of ``class_means``, the per-class
    value is ``sin(angle at mu) / sin(angle at x_db)`` in the triangle
    ``(x, x_db, mu)``. By the law of sines this equals
    ``||x - x_db|| / ||x - mu||``, i.e. the distance-ratio form of the score.
    NaNs (e.g. for ``c == p``, where the weight difference is zero) are set
    to 0 before averaging over classes.

    Args:
        model: network whose final linear layer is ``model.classifier[1]``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes / columns of the logits.
        class_means: per-class feature means; only their mean over dim 0 is used.
        device: device on which the computation runs.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        fc_weight = model.classifier[1].weight
        # Loop-invariant reference point; computed once instead of per class.
        feat_center = torch.mean(class_means, dim=0)
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weight = fc_weight[preds]
            # Direction from the centre to each sample does not depend on the class.
            norm_centered_feats = F.normalize(feats - feat_center, p=2, dim=1)
            trajectory_list = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weight - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of the feature onto the p-vs-c boundary.
                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff

                # Angle at the centre between the sample and its projection.
                norm_centered_db = F.normalize(feats_db - feat_center, p=2, dim=1)
                cos_at_center = torch.sum(norm_centered_feats * norm_centered_db, dim=1)
                angles_origin = torch.arccos(cos_at_center) / torch.pi

                # Angle at the projection between the sample and the centre.
                cos_at_db = F.cosine_similarity(feats - feats_db, feat_center - feats_db, dim=1)
                angles_db = torch.arccos(cos_at_db) / torch.pi

                trajectory_list[:, class_id] = (
                    torch.sin(angles_origin * torch.pi) / torch.sin(angles_db * torch.pi)
                )
            trajectory_list[torch.isnan(trajectory_list)] = 0
            all_scores.append(torch.mean(trajectory_list, dim=1))
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores

def ORA_score3(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Angle-based score summed over three reference points.

    For each sample and each class the feature is projected onto the linear
    decision boundary between the predicted class and that class. The angle
    (divided by pi) between the sample and its projection is measured as seen
    from three reference points: the element-wise max, min and median of
    ``class_means``. The three angles are summed per class, NaNs are set to 0,
    and the maximum over classes is returned.

    Args:
        model: network whose final linear layer is ``model.fc``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes / columns of the logits.
        class_means: per-class feature means used to build the reference points.
        device: device on which the computation runs.
        thold: accepted for signature parity with the other scorers; unused.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    with torch.inference_mode():
        fc_weight = model.fc.weight
        # Reference points are loop-invariant; order matches the original
        # accumulation order (max, min, median).
        reference_points = (
            torch.max(class_means, dim=0).values,
            torch.min(class_means, dim=0).values,
            torch.median(class_means, dim=0).values,
        )
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weight = fc_weight[preds]
            # Unit directions from each reference point to the samples.
            unit_feats_per_ref = [F.normalize(feats - ref, p=2, dim=1) for ref in reference_points]
            trajectory_list = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weight - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # The projection is independent of the reference point, so it
                # is computed once per class and shared by all three.
                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff
                for ref, unit_feats in zip(reference_points, unit_feats_per_ref):
                    unit_db = F.normalize(feats_db - ref, p=2, dim=1)
                    cos_sim = torch.sum(unit_feats * unit_db, dim=1)
                    trajectory_list[:, class_id] += torch.arccos(cos_sim) / torch.pi

            trajectory_list[torch.isnan(trajectory_list)] = 0
            all_scores.append(torch.max(trajectory_list, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores

def ORA_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Angle-based score measured from the mean of the class means.

    For each sample and each class the feature is projected onto the linear
    decision boundary between the predicted class and that class. The score
    for that class is the angle (divided by pi) between the sample and its
    projection, as seen from the mean of ``class_means``. NaNs are set to 0
    and the maximum over classes is returned.

    Args:
        model: network whose final linear layer is ``model.classifier[1]``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes / columns of the logits.
        class_means: per-class feature means; only their mean over dim 0 is used.
        device: device on which the computation runs.
        thold: accepted for signature parity with the other scorers; unused.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    with torch.inference_mode():
        fc_weight = model.classifier[1].weight
        # Loop-invariant reference point; computed once instead of per class.
        feat_center = torch.mean(class_means, dim=0)
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weight = fc_weight[preds]
            norm_centered_feats = F.normalize(feats - feat_center, p=2, dim=1)
            trajectory_list = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weight - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection of the feature onto the boundary.
                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff

                norm_centered_db = F.normalize(feats_db - feat_center, p=2, dim=1)
                cos_sim = torch.sum(norm_centered_feats * norm_centered_db, dim=1)
                trajectory_list[:, class_id] = torch.arccos(cos_sim) / torch.pi

            trajectory_list[torch.isnan(trajectory_list)] = 0
            all_scores.append(torch.max(trajectory_list, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores

def ORA_score2(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
    ref_class_ids: tuple = (1, 250, 500, 750, 999),
) -> np.ndarray:
    """Angle-based score summed over several individual class means.

    For each sample and each class the feature is projected onto the linear
    decision boundary between the predicted class and that class. The angle
    (divided by pi) between the sample and its projection is measured as seen
    from ``class_means[r]`` for every ``r`` in ``ref_class_ids`` and summed.
    NaNs are set to 0 and the maximum over classes is returned.

    Args:
        model: network whose final linear layer is ``model.fc``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes / columns of the logits.
        class_means: per-class feature means, indexed by ``ref_class_ids``.
        device: device on which the computation runs.
        thold: if given, features are capped with ``react_filter`` and the
            logits are recomputed with ``model.fc`` before scoring.
        ref_class_ids: rows of ``class_means`` used as reference points. The
            default reproduces the original hard-coded anchors and therefore
            requires ``class_means`` to have at least 1000 rows.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    with torch.inference_mode():
        fc_weight = model.fc.weight
        reference_points = [class_means[ref_id] for ref_id in ref_class_ids]
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            if thold is not None:
                feats = react_filter(feats, thold).float()
                logits = model.fc(feats)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            pred_weight = fc_weight[preds]
            # Unit directions from each reference point to the samples.
            unit_feats_per_ref = [F.normalize(feats - ref, p=2, dim=1) for ref in reference_points]
            trajectory_list = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weight - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # The projection is independent of the reference point, so it
                # is computed once per class and shared by all of them.
                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff
                for ref, unit_feats in zip(reference_points, unit_feats_per_ref):
                    unit_db = F.normalize(feats_db - ref, p=2, dim=1)
                    cos_sim = torch.sum(unit_feats * unit_db, dim=1)
                    trajectory_list[:, class_id] += torch.arccos(cos_sim) / torch.pi

            trajectory_list[torch.isnan(trajectory_list)] = 0
            all_scores.append(torch.max(trajectory_list, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


def energy_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Log-sum-exp (energy) of the cached logits for every sample.

    ``test_loader`` yields ``(features, logits)`` batches; only the logits
    are used, so the features are no longer copied to ``device``.
    ``num_classes``, ``class_means`` and ``thold`` are accepted for signature
    parity with the other scorers and are unused.

    Returns a float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        for _feats, logits in test_loader:
            logits = logits.to(device)
            all_scores.append(torch.logsumexp(logits, dim=1))
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


def msp_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    thold: Optional[torch.Tensor] = None,
) -> np.ndarray:
    """Maximum softmax probability of the cached logits for every sample.

    ``test_loader`` yields ``(features, logits)`` batches; only the logits
    are used, so the features are no longer copied to ``device``.
    ``num_classes``, ``class_means`` and ``thold`` are accepted for signature
    parity with the other scorers and are unused.

    Returns a float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()
    all_scores = []
    with torch.inference_mode():
        for _feats, logits in test_loader:
            logits = logits.to(device)
            probs = F.softmax(logits, dim=1)
            all_scores.append(torch.max(probs, dim=1).values)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


def knn_score(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    device: str,
    train_feats: torch.Tensor,
    k: int = 1000,
) -> np.ndarray:
    """Negative distance to the k-th nearest training feature.

    Both the training and the test features are L2-normalised, then the
    Euclidean distance from every test sample to its ``k``-th nearest
    training sample is computed and negated.

    Args:
        model: moved to ``device`` and put in eval mode; otherwise unused.
        test_loader: yields ``(features, logits)`` batches; logits are ignored.
        num_classes: unused, kept for signature parity with the other scorers.
        class_means: unused, kept for signature parity with the other scorers.
        device: device on which the distances are computed.
        train_feats: reference features, one row per training sample. Moved
            to ``device`` so that it matches the test batches.
        k: neighbour rank; must not exceed the number of rows in
            ``train_feats`` (``torch.topk`` raises otherwise).

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model = model.to(device)
    model.eval()

    all_scores = []
    # cdist needs both operands on the same device; .to() is a no-op when
    # the caller already placed the training features there.
    train_feats = F.normalize(train_feats.to(device), p=2, dim=1)
    with torch.inference_mode():
        for feats, _logits in test_loader:
            feats = F.normalize(feats.to(device), p=2, dim=1)
            # Exact (non-matmul) Euclidean distances to every training sample.
            dists = torch.cdist(feats, train_feats, p=2, compute_mode='donot_use_mm_for_euclid_dist')
            # Ascending top-k; the last column is the k-th nearest distance.
            kth_dist = torch.topk(dists, k, largest=False, sorted=True).values[:, -1]
            all_scores.append(-kth_dist)
    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores


args = get_args()

# Select the visible GPU(s) before the first CUDA query below: the CUDA
# runtime reads CUDA_VISIBLE_DEVICES when it initialises, so setting it after
# torch.cuda.is_available() may have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Seed every RNG used by the script for reproducibility.
seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sizes of the cached in-distribution arrays (number of classes and rows).
class_num = 1000
id_train_size = 1281167
id_val_size = 50000

# Cached training-split features, logits and labels. The memmaps are opened
# read-only with dtype=float (float64); features have 1280 columns.
cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
feat_log = torch.from_numpy(np.memmap(f"{cache_dir}/feat.mmap", dtype=float, mode='r', shape=(id_train_size, 1280))).to(device)
score_log = torch.from_numpy(np.memmap(f"{cache_dir}/score.mmap", dtype=float, mode='r', shape=(id_train_size, class_num))).to(device)
label_log = torch.from_numpy(np.memmap(f"{cache_dir}/label.mmap", dtype=float, mode='r', shape=(id_train_size,))).to(device)


# Same three arrays for the validation split.
cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = torch.from_numpy(np.memmap(f"{cache_dir}/feat.mmap", dtype=float, mode='r', shape=(id_val_size, 1280))).to(device)
score_log_val = torch.from_numpy(np.memmap(f"{cache_dir}/score.mmap", dtype=float, mode='r', shape=(id_val_size, class_num))).to(device)
label_log_val = torch.from_numpy(np.memmap(f"{cache_dir}/label.mmap", dtype=float, mode='r', shape=(id_val_size,))).to(device)

# Number of cached samples for each supported OOD dataset.
ood_dataset_size = {
    'inat':10000,
    'sun50': 10000,
    'places50': 10000,
    'dtd': 5640
}

# Maps OOD dataset name -> (features, logits), both placed on `device`.
ood_feat_score_log = {}
for ood_dataset in args.out_datasets:
    ood_feat_log = torch.from_numpy(
        np.memmap(
            f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/feat.mmap",
            dtype=float,
            mode='r',
            shape=(ood_dataset_size[ood_dataset], 1280),
        )
    ).to(device)
    ood_score_log = torch.from_numpy(
        np.memmap(
            f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/score.mmap",
            dtype=float,
            mode='r',
            shape=(ood_dataset_size[ood_dataset], class_num),
        )
    ).to(device)
    ood_feat_score_log[ood_dataset] = (ood_feat_log, ood_score_log)


######## get w, b; precompute denominator matrix, training feature mean  #################

# Final-layer weight `w` and bias `b` of the selected backbone. Neither name
# is bound when args.name is anything other than the two cases below.
if args.name == 'resnet50':
    net = models.resnet50(pretrained=True)
    w = net.fc.weight.data
    b = net.fc.bias.data

elif args.name == 'resnet50-supcon':
    checkpoint = torch.load('ckpt/ImageNet_resnet50_supcon_linear.pth')
    w = checkpoint['model']['fc.weight']
    b = checkpoint['model']['fc.bias']

# Mean training feature vector.
train_mean = feat_log.mean(dim=0).to(device)

# Placeholder for the pairwise weight-difference norms; the code that fills
# it is currently disabled below, so it stays all zeros.
denominator_matrix = torch.zeros((1000, 1000), device=device)
# for p in range(1000):
#   w_p = w - w[p,:]
#   denominator = torch.norm(w_p, dim=1)
#   denominator[p] = 1
#   denominator_matrix[p, :] = denominator

#################### fDBD score OOD detection #################

all_results = []
all_score_out = []

# Per-sample gap between the top logit and every other logit. Broadcasting the
# (N, 1) maxima avoids materialising an (N, class_num) repeated copy.
values, nn_idx = score_log_val.max(1)
logits_sub = torch.abs(score_log_val - values.unsqueeze(1))
# score_in = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(feat_log_val - train_mean , dim = 1)
# score_in = score_in.float().cpu().numpy()

# from util.model_loader import get_model
# model = get_model(args, 1000, load_ckpt=True)
# NOTE(review): the scoring model is always MobileNetV2, independent of
# args.name used for the cache paths and classifier weights above - confirm.
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.eval()
model.to(device)

# The scorers consume (features, logits) batches built from the cached arrays.
in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(in_dataset, batch_size=128, shuffle=False, num_workers=2)

# Mean validation feature of every class; a class with no samples yields NaN.
class_means = []
for i in range(class_num):
    class_means.append(feat_log_val[label_log_val == i].mean(0))
class_means = torch.stack(class_means).to(device)

# Scorer switchboard: keep exactly one `score_in = ...` line active, and the
# matching line in the OOD loop further down.
# tholds = react_thold(feat_log, percentile=85)
tholds = None
# tholds = torch.from_numpy(tholds).to(device)
# score_in = ORA_score(model, in_loader, class_num, class_means, device, tholds)
# score_in = energy_score(model, in_loader, class_num, class_means, device, tholds)
score_in = msp_score(model, in_loader, class_num, class_means, device, tholds)
# score_in = fdbd_score(model, in_loader, class_num, class_means, device)
# score_in = React_energy(model, in_loader, class_num, class_means, device, tholds)
# breakpoint()
train_feats = feat_log[:]
# score_in = knn_score(model, in_loader, class_num, class_means, device, train_feats)



# Score every OOD dataset with the same scorer used for `score_in` and compare.
# The loop variables are deliberately NOT named feat_log / score_log so the
# cached training arrays of the same name are not overwritten.
for ood_dataset, (ood_feats, ood_logits) in ood_feat_score_log.items():
    print(ood_dataset)
    # Per-sample gap between the top logit and every other logit (broadcast).
    values, nn_idx = ood_logits.max(1)
    logits_sub = torch.abs(ood_logits - values.unsqueeze(1))
    # scores_out_test = torch.sum(logits_sub/denominator_matrix[nn_idx], axis=1)/torch.norm(ood_feats - train_mean , dim = 1)
    # scores_out_test = scores_out_test.float().cpu().numpy()
    out_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(ood_feats.cpu(), ood_logits.cpu()),
        batch_size=128,
        shuffle=False,
        num_workers=2,
    )
    # scores_out_test = ORA_score(model, out_loader, class_num, class_means, device, tholds)
    # scores_out_test = energy_score(model, out_loader, class_num, class_means, device, tholds)
    scores_out_test = msp_score(model, out_loader, class_num, class_means, device, tholds)
    # scores_out_test = fdbd_score(model, out_loader, class_num, class_means, device)
    # scores_out_test = React_energy(model, out_loader, class_num, class_means, device, tholds)
    # scores_out_test = knn_score(model, out_loader, class_num, class_means, device, train_feats)

    all_score_out.extend(scores_out_test)
    results = metrics.cal_metric(score_in, scores_out_test)
    all_results.append(results)

# NOTE(review): the label passed here is 'ORA' although the active scorer
# above is msp_score - confirm whether the label should follow the scorer.
metrics.print_all_results(all_results, args.out_datasets, 'ORA')
print()