import os
import time
from util.args_loader import get_args
from util.model_loader import get_model
from util import metrics
import torch
import faiss
import numpy as np
import torchvision.models as models
import pdb
import torch.nn.functional as F
from typing import Optional


torch.set_default_tensor_type(torch.DoubleTensor)



def LAFO(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    percentile: int,
    thold: Optional[torch.Tensor]=None,
) -> np.ndarray:
    """Score every sample by the largest angle to its decision-boundary projections.

    Each feature batch is passed through ``scale`` and, optionally, capped at
    ``thold``. Logits are recomputed with ``model.fc``. For every class the
    feature is moved along the difference between the predicted-class weight
    and that class's weight by ``logit_diff / ||weight_diff||**2``. The score
    of a sample is the maximum, over classes, of the angle (divided by pi)
    between the feature and the moved feature, both measured from the scaled
    mean of ``class_means``.

    Args:
        model: Network exposing a linear head as ``model.fc``. All tensors are
            moved to the device of ``model.fc.weight``.
        test_loader: Iterable yielding ``(features, logits)`` pairs. The second
            element is ignored; logits are recomputed from the scaled features.
        num_classes: Number of classes to iterate over.
        class_means: Per-class mean features; averaged over dim 0.
        percentile: Percentile forwarded to ``scale`` for the test features.
        thold: Optional upper bound; scaled features above it are replaced by it.

    Returns:
        1-D ``float32`` array with one score per sample, in loader order.
    """
    model.eval()

    fc_weight = model.fc.weight
    device = fc_weight.device

    all_scores = []
    with torch.inference_mode():
        # The reference point depends only on ``class_means``, so compute it
        # once instead of twice per class per batch.
        # NOTE(review): this uses ``scale``'s default percentile, not the
        # ``percentile`` argument, and is never capped by ``thold`` - kept
        # exactly as the original computed it; confirm this is intended.
        reference_point = scale(torch.mean(class_means, dim=0)).to(device)

        for feats, _ in test_loader:
            feats = scale(feats.to(device), percentile)
            if thold is not None:
                feats[thold < feats] = thold

            logits = model.fc(feats.double())
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Both are identical for every class, so they are hoisted out of
            # the class loop.
            pred_weights = fc_weight[preds]
            centered = F.normalize(feats - reference_point, p=2, dim=1)

            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff
                centered_db = F.normalize(feats_db - reference_point, p=2, dim=1)

                cos_sim = torch.sum(centered * centered_db, dim=1)
                # Rounding can push |cos| marginally above 1, which arccos maps
                # to NaN; clamping keeps those angles valid. Real NaNs (the
                # predicted class itself, where weight_diff is 0) pass through.
                cos_sim = cos_sim.clamp(-1.0, 1.0)
                angles[:, class_id] = torch.arccos(cos_sim) / torch.pi

            # The predicted class yields 0/0; treat it as angle 0.
            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)

    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores



def fDBD(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
    percentile: int,
    thold: Optional[torch.Tensor]=None,
) -> np.ndarray:
    """Score samples by their mean normalised distance to the decision boundaries.

    For every class the feature is moved along the difference between the
    predicted-class weight and that class's weight by
    ``logit_diff / ||weight_diff||**2``. The length of that move is divided by
    the distance between the feature and the mean of ``class_means``. The score
    of a sample is the average of these ratios over all classes; the predicted
    class contributes 0.

    Args:
        model: Network exposing a linear head as ``model.fc``. All tensors are
            moved to the device of ``model.fc.weight``.
        test_loader: Iterable yielding ``(features, logits)`` pairs. The second
            element is ignored; logits are recomputed with ``model.fc``.
        num_classes: Number of classes to iterate over.
        class_means: Per-class mean features; averaged over dim 0.
        percentile: Unused; kept so the signature matches ``LAFO``.
        thold: Optional upper bound applied to both the features and the
            training mean.

    Returns:
        1-D ``float32`` array with one score per sample, in loader order.
    """
    model.eval()

    fc_weight = model.fc.weight
    device = fc_weight.device

    # The mean is needed whether or not a threshold is given; previously it was
    # only defined when ``thold`` was set, so ``thold=None`` raised NameError.
    train_mean = torch.mean(class_means, dim=0).to(device)
    if thold is not None:
        train_mean[thold < train_mean] = thold

    all_scores = []
    with torch.inference_mode():
        for feats, _ in test_loader:
            feats = feats.to(device)
            if thold is not None:
                # ``.to`` may return the loader's own tensor; cap a private copy.
                feats = feats.clone()
                feats[thold < feats] = thold

            logits = model.fc(feats.double())
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Both are identical for every class, so they are hoisted out of
            # the class loop.
            pred_weights = fc_weight[preds]
            dist_to_mean = torch.linalg.norm(feats - train_mean, dim=1)

            ratios = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                feats_db = feats - torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1) * weight_diff
                distance_to_db = torch.linalg.norm(feats - feats_db, dim=1)
                ratios[:, class_id] = distance_to_db / dist_to_mean

            # The predicted class yields 0/0; treat it as distance 0.
            ratios[torch.isnan(ratios)] = 0
            all_scores.append(torch.mean(ratios, dim=1))

    scores = np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)
    return scores



args = get_args()

# CUDA reads CUDA_VISIBLE_DEVICES when the runtime is first initialised, so the
# restriction must be in place before any call that may touch the driver
# (``torch.cuda.is_available()`` below can do so). Setting it afterwards, as
# was done before, may leave the process with every GPU visible.
# ``str`` guards against a numeric ``--gpu`` value; os.environ only accepts strings.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# Seed every random number generator used by this script.
seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Device on which the cached features are placed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'



def scale(x, percentile=90, feat_dim=2048):
    """Rescale features by ``exp(total_sum / top_k_sum)`` without pruning them.

    ``x`` is reshaped to ``(-1, feat_dim)``. For every row the sum of all
    activations is divided by the sum of the ``k`` largest ones, where
    ``k = n - round(n * percentile / 100)`` and ``n = feat_dim``. Every
    activation in the row is then multiplied by the exponential of that ratio.

    Unlike the previous implementation, the caller's tensor is left untouched:
    the old code pruned a view of ``x`` in place, so the argument came back
    with everything but its top-k entries zeroed.

    Args:
        x: Tensor whose number of elements is a multiple of ``feat_dim``.
        percentile: Share (0-100) of activations excluded from the top-k sum.
        feat_dim: Number of features per sample.

    Returns:
        ``float32`` tensor of shape ``(batch, feat_dim)``.

    Raises:
        ValueError: If ``percentile`` lies outside ``[0, 100]``.
    """
    if not 0 <= percentile <= 100:
        raise ValueError(f"percentile must be within [0, 100], got {percentile}")

    flat = x.reshape(-1, feat_dim)
    n = flat.shape[1]
    k = n - int(np.round(n * percentile / 100.0))

    # Sum of all activations per sample.
    total = flat.sum(dim=1)
    # Sum of what would survive pruning; the dropped entries are zeros in the
    # pruned tensor, so summing the top-k values directly is equivalent.
    kept = torch.topk(flat, k, dim=1).values.sum(dim=1)

    # Apply sharpening to the unpruned activations.
    ratio = total / kept
    return (flat * torch.exp(ratio)[:, None]).float()



def ash_s(x, percentile=90, feat_dim=2048):
    """Prune features to their top-k entries and rescale the survivors.

    ``x`` is reshaped to ``(-1, feat_dim)``. For every row only the ``k``
    largest activations are kept (``k = n - round(n * percentile / 100)`` with
    ``n = feat_dim``); the rest become zero. The kept activations are
    multiplied by ``exp(sum_before_pruning / sum_after_pruning)``.

    Unlike the previous implementation, the caller's tensor is left untouched:
    the old code zeroed and scattered into a view of ``x``, so the argument
    itself was pruned as a side effect.

    Args:
        x: Tensor whose number of elements is a multiple of ``feat_dim``.
        percentile: Share (0-100) of activations that are zeroed.
        feat_dim: Number of features per sample.

    Returns:
        ``float32`` tensor of shape ``(batch, feat_dim)``.

    Raises:
        ValueError: If ``percentile`` lies outside ``[0, 100]``.
    """
    if not 0 <= percentile <= 100:
        raise ValueError(f"percentile must be within [0, 100], got {percentile}")

    flat = x.reshape(-1, feat_dim)
    n = flat.shape[1]
    k = n - int(np.round(n * percentile / 100.0))

    # Sum of all activations per sample, before pruning.
    total = flat.sum(dim=1)

    # Build the pruned activations in a fresh tensor instead of overwriting x.
    top_values, top_index = torch.topk(flat, k, dim=1)
    pruned = torch.zeros_like(flat).scatter_(dim=1, index=top_index, src=top_values)

    # Sum per sample after pruning.
    kept = pruned.sum(dim=1)

    # Apply sharpening to the surviving activations.
    ratio = total / kept
    return (pruned * torch.exp(ratio)[:, None]).float()

def react_thold(train_feats, percentile=90):
    """Return the ``percentile``-th percentile taken over every entry of ``train_feats``."""
    # Collapse to one dimension so the percentile spans all values at once.
    flattened = train_feats.flatten()
    threshold = np.percentile(flattened, percentile, axis=0)
    return threshold


class_num = 1000
id_train_size = 1281167
id_val_size = 50000


def _load_cached(path, shape):
    """Memory-map the ``dtype=float`` array at ``path`` read-only and move it to ``device``."""
    return torch.from_numpy(np.memmap(path, dtype=float, mode='r', shape=shape)).to(device)


# Cached features, scores and labels of the in-distribution training split.
# NOTE(review): train_score_log is not referenced anywhere else in the visible
# file, yet it places id_train_size x class_num values on ``device``.
cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
train_feat_log = _load_cached(f"{cache_dir}/feat.mmap", (id_train_size, 2048))
train_score_log = _load_cached(f"{cache_dir}/score.mmap", (id_train_size, class_num))
train_label_log = _load_cached(f"{cache_dir}/label.mmap", (id_train_size,))


# Cached features, scores and labels of the in-distribution validation split.
cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = _load_cached(f"{cache_dir}/feat.mmap", (id_val_size, 2048))
score_log_val = _load_cached(f"{cache_dir}/score.mmap", (id_val_size, class_num))
label_log_val = _load_cached(f"{cache_dir}/label.mmap", (id_val_size,))


# Cached (features, scores) of every requested out-of-distribution dataset.
ood_feat_score_log = {}
ood_dataset_size = {
    'inat':10000,
    'sun50': 10000,
    'places50': 10000,
    'dtd': 5640
}

for ood_dataset in args.out_datasets:
    ood_dir = f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out"
    ood_size = ood_dataset_size[ood_dataset]
    ood_feat_log = _load_cached(f"{ood_dir}/feat.mmap", (ood_size, 2048))
    ood_score_log = _load_cached(f"{ood_dir}/score.mmap", (ood_size, class_num))
    ood_feat_score_log[ood_dataset] = ood_feat_log, ood_score_log


# The loader works on CPU copies of the validation tensors.
in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(
    in_dataset, batch_size=512, shuffle=False, num_workers=2
)

# Mean training feature of every class; uses class_num instead of repeating 1000.
class_means = torch.zeros(class_num, train_feat_log.size(1)).to(device)
for i in range(class_num):
    class_means[i] = torch.mean(train_feat_log[train_label_log == i], dim=0)

# Reuse the class count declared with the cache shapes instead of repeating 1000.
num_classes = class_num
model = get_model(
    args, num_classes, load_ckpt=True
)


# The OOD loaders do not depend on the percentile, so build them (and their CPU
# copies of the cached tensors) once instead of once per sweep iteration.
ood_loaders = {}
for ood_dataset_name, (feat_log, score_log) in ood_feat_score_log.items():
    ood_dataset = torch.utils.data.TensorDataset(feat_log.cpu(), score_log.cpu())
    ood_loaders[ood_dataset_name] = torch.utils.data.DataLoader(
        ood_dataset, batch_size=512, shuffle=False, num_workers=2
    )


# Sweep the percentile from 95 down to 5 in steps of 5 and report the metrics
# of every OOD dataset against the in-distribution validation scores.
for percentile in range(95, 0, -5):
    # r_thold = react_thold(train_feat_log.cpu().numpy(), percentile)
    r_thold = None
    score_in = LAFO(model, in_loader, num_classes, class_means, percentile, r_thold)
    # score_in = fDBD(model, in_loader, num_classes, class_means, percentile, r_thold)
    all_results = []
    # NOTE(review): collected but not read anywhere in the visible file.
    all_score_out = []

    for ood_dataset_name, ood_loader in ood_loaders.items():
        scores_out_test = LAFO(model, ood_loader, num_classes, class_means, percentile, r_thold)
        # scores_out_test = fDBD(model, ood_loader, num_classes, class_means, percentile, r_thold)

        all_score_out.extend(scores_out_test)
        results = metrics.cal_metric(score_in, scores_out_test)
        all_results.append(results)
        print(results)


    metrics.print_all_results(all_results, args.out_datasets, "LAFA percentile: " + str(percentile))
    print()