import os
import time
import torch
# import faiss
from sklearn.neighbors import NearestNeighbors
import numpy as np

# Column window of the concatenated all-layer features that the original code
# marks as "Last Layer only". The upper bound is clipped by slicing when the
# cached features are narrower than 960 columns.
_LAST_LAYER_START, _LAST_LAYER_STOP = 448, 960


def _last_layer_normalized(feat_log):
    """Return the row-wise L2-normalised last-layer slice of a cached feature array.

    The cached array is transposed first, so that rows become the samples fed to
    the nearest-neighbour index. The column slice is taken BEFORE the float32
    cast, so only the needed columns are copied.
    """
    feats = feat_log.T[:, _LAST_LAYER_START:_LAST_LAYER_STOP].astype(np.float32)
    # The epsilon guards against division by zero for all-zero rows.
    feats = feats / (np.linalg.norm(feats, ord=2, axis=-1, keepdims=True) + 1e-10)
    return np.ascontiguousarray(feats)


def _kth_neighbor_scores(index, queries, k):
    """Score each query row as exp(-d_k).

    d_k is the Euclidean distance from the row to its k-th nearest neighbour in
    the fitted ``index``. ``kneighbors`` returns distances sorted ascending, so
    column -1 is the k-th one.
    """
    distances, _ = index.kneighbors(queries, n_neighbors=k)
    return np.exp(-distances[:, -1])


def get_knn_score(config):
    """Compute KNN-based scores for the ID validation set and every OOD set.

    Loads cached features from ``./knn/cache/``, keeps the last-layer column
    slice, L2-normalises each row and fits a brute-force Euclidean
    nearest-neighbour index on the in-distribution training features. Each
    query sample is scored by ``exp(-d_k)``, where ``d_k`` is its distance to
    the k-th nearest training sample. A higher score means closer to the
    training data.

    Args:
        config: mapping with keys
            ``'id_dataset'``: name used to build the ID cache file names.
            ``'ood_datasets'``: iterable of OOD dataset names.
            ``'k'`` (optional, default 50): which nearest neighbour's distance
            is used as the score.

    Returns:
        A tuple ``(scores_in, all_score_ood)``:
            ``scores_in``: ndarray with one score per ID validation sample.
            ``all_score_ood``: flat Python list with the scores of all OOD
            datasets, concatenated in ``config['ood_datasets']`` order.

    Raises:
        ValueError: propagated from scikit-learn when k exceeds the number of
            training samples.
        OSError: propagated from ``np.load`` when a cache file is missing.
    """
    # Optional override; the historical behaviour used a fixed k of 50.
    k = config.get('k', 50) if hasattr(config, 'get') else 50

    # NOTE(review): allow_pickle=True executes pickled content from these
    # files. Only point this at cache files you generated yourself.
    cache_name = f"./knn/cache/{config['id_dataset']}_train_in_alllayers.npy"
    feat_log, _score_log, _label_log = np.load(cache_name, allow_pickle=True)
    ftrain = _last_layer_normalized(feat_log)
    del feat_log  # drop the full all-layer array before further loads

    index = NearestNeighbors(n_neighbors=k, algorithm='brute', metric='euclidean')
    index.fit(ftrain)

    cache_name = f"./knn/cache/{config['id_dataset']}_val_in_alllayers.npy"
    feat_log_val, _score_log_val, _label_log_val = np.load(cache_name, allow_pickle=True)
    scores_in = _kth_neighbor_scores(index, _last_layer_normalized(feat_log_val), k)
    del feat_log_val

    # Process one OOD dataset at a time, so only one raw all-layer array is
    # held in memory at once.
    all_score_ood = []
    for ood_dataset in config['ood_datasets']:
        cache_name = f"./knn/cache/{ood_dataset}vs{config['id_dataset']}_out_alllayers.npy"
        ood_feat_log, _ood_score_log = np.load(cache_name, allow_pickle=True)
        food = _last_layer_normalized(ood_feat_log)
        # extend() keeps the original return type: a flat list of scalars.
        all_score_ood.extend(_kth_neighbor_scores(index, food, k))
    return scores_in, all_score_ood


