import os
import time
from util.args_loader import get_args
from util import metrics
import torch
# import faiss
from sklearn.neighbors import NearestNeighbors
from annoy import AnnoyIndex
import numpy as np

# Parse command-line options first. The GPU selection has to be exported
# before any torch.cuda call: torch.cuda.is_available() queries the CUDA
# device count, and the CUDA runtime only honours CUDA_VISIBLE_DEVICES when
# it is first initialised. Setting it afterwards may have no effect on this
# process.
args = get_args()

if args.gpu is not None:
    # os.environ only accepts str values.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# Fixed seeds so repeated runs are reproducible.
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def _rows_f32(arr):
    """Return ``arr`` transposed and cast to float32.

    NOTE(review): the cached arrays appear to be stored dimension-major, so the
    transpose yields one row per sample -- confirm against the cache writer.
    """
    return arr.T.astype(np.float32)


# In-distribution training cache: features, scores and labels.
# allow_pickle=True is required to unpack the cached object array; only load
# caches produced by this project, never files from untrusted sources.
cache_name = f"cache/{args.in_dataset}_train_{args.name}_in_alllayers.npy"
feat_log, score_log, label_log = np.load(cache_name, allow_pickle=True)
feat_log = _rows_f32(feat_log)
score_log = _rows_f32(score_log)
# Column count of the transposed score array, used as the number of classes.
class_num = score_log.shape[1]

# In-distribution validation (test) cache, same layout as the training cache.
cache_name = f"cache/{args.in_dataset}_val_{args.name}_in_alllayers.npy"
feat_log_val, score_log_val, label_log_val = np.load(cache_name, allow_pickle=True)
feat_log_val = _rows_f32(feat_log_val)
score_log_val = _rows_f32(score_log_val)

# Out-of-distribution caches hold only features and scores (no labels);
# only the features are kept, keyed by dataset name.
ood_feat_log_all = {}
for ood_dataset in args.out_datasets:
    cache_name = f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out_alllayers.npy"
    ood_feat_log, ood_score_log = np.load(cache_name, allow_pickle=True)
    ood_feat_log = _rows_f32(ood_feat_log)
    ood_score_log = _rows_f32(ood_score_log)
    ood_feat_log_all[ood_dataset] = ood_feat_log

def normalizer(x, eps=1e-10):
    """L2-normalise ``x`` along its last axis.

    ``eps`` is added to the norm so that all-zero rows map to zeros instead
    of producing NaN from a 0/0 division.
    """
    return x / (np.linalg.norm(x, ord=2, axis=-1, keepdims=True) + eps)


def prepos_feat(x, start=448, stop=960):
    """Select feature columns ``start:stop`` of a 2-D array and L2-normalise each row.

    The default 448:960 range is the one this script labels "Last Layer only"
    (confirm against the feature extractor). A basic slice is used instead of
    indexing with ``range``: it clamps to ``x.shape[1]`` the same way and
    returns a view, so only the normalisation allocates a new array. The
    result is C-contiguous.
    """
    return np.ascontiguousarray(normalizer(x[:, start:stop]))

# Preprocess every feature set the same way: slice the selected columns and
# L2-normalise each row (see prepos_feat).
ftrain = prepos_feat(feat_log)
ftest = prepos_feat(feat_log_val)
food_all = {
    name: prepos_feat(raw_feats)
    for name, raw_feats in ood_feat_log_all.items()
}


#################### KNN score OOD detection #################

# Number of neighbours; the OOD score is the negated distance to the K-th one.
K = 50
# Directory the score files are written to.
OUTPUT_DIR = './output/knn_cifar'
# Output file stems for OOD datasets whose file name is not simply the
# lower-cased dataset name.
OOD_FILE_NAMES = {'dtd': 'textures', 'LSUN': 'lsuncrop', 'LSUN_resize': 'lsunresize'}

# Exhaustive (brute-force) Euclidean search over the normalised training features.
index = NearestNeighbors(n_neighbors=K, algorithm='brute', metric='euclidean')
index.fit(ftrain)

# In-distribution scores: negated distance to the K-th nearest training
# feature. kneighbors returns distances sorted ascending, so the last column
# is the K-th neighbour; a larger score means a closer K-th neighbour.
distances, _ = index.kneighbors(ftest, n_neighbors=K)
scores_in = -distances[:, -1]

# np.savetxt does not create directories. The in-distribution scores do not
# depend on the OOD dataset, so write them once here instead of on every
# loop iteration.
os.makedirs(OUTPUT_DIR, exist_ok=True)
np.savetxt(os.path.join(OUTPUT_DIR, 'in_scores.txt'), scores_in, delimiter='\n')

all_results = []
all_score_ood = []

for ood_dataset, food in food_all.items():
    # Same K-th-neighbour score for this out-of-distribution dataset.
    distances, _ = index.kneighbors(food, n_neighbors=K)
    scores_ood_test = -distances[:, -1]
    all_score_ood.extend(scores_ood_test)

    # Save this dataset's scores under its mapped (or lower-cased) file name.
    ood_dataset_name = OOD_FILE_NAMES.get(ood_dataset, ood_dataset.lower())
    np.savetxt(os.path.join(OUTPUT_DIR, f'{ood_dataset_name}.txt'), scores_ood_test, delimiter='\n')

    # Compare in-distribution vs this OOD dataset's scores.
    results = metrics.cal_metric(scores_in, scores_ood_test)
    all_results.append(results)

metrics.print_all_results(all_results, list(food_all.keys()), f'KNN k={K}')


