# imports
import sys
import os
import time

import numpy as np

from evaluation_configs import get_configs
from evaluation_modules import get_knn, compute_mrr
from refuse.utils.datasets import AssemblageFunctionsDataset

start = time.time()

# load the evaluation configuration
configs = get_configs()

# Validate the label-normalization settings. These checks raise explicitly
# instead of using `assert`, which is stripped when Python runs with -O and
# would let an invalid configuration through silently.
if configs['normalize_labels']:
    if configs['dataset'] != 'assemblage':
        raise ValueError("label normalization is only supported with the assemblage dataset")
    # the test data set is only needed to normalize labels; it is handed to
    # compute_mrr together with 'how_normalize'
    dataset = AssemblageFunctionsDataset(**configs['dataset_configs'])
else:
    if configs['how_normalize'] is not None:
        raise ValueError("'how_normalize' must be None if 'normalize_labels' is False")
    dataset = None
    

# Load embeddings and labels as copy-on-write memmaps (mode='c'). The arrays
# stay writable in memory, but in-place changes are never flushed back to the
# files on disk, and the files do not need write permission. This guards the
# data if get_knn modifies its input in place (e.g. normalizing vectors for a
# cosine metric), which cannot be confirmed from here.
embeddings = np.memmap(configs['embeddings_file'], dtype='float32', mode='c')
# one row of length embd_size per embedding; np.reshape raises ValueError if
# the number of stored floats is not a multiple of embd_size
embeddings = np.reshape(embeddings, (-1, configs['embd_size']))
labels = np.memmap(configs['labels_file'], dtype='int64', mode='c')

# Each embedding row needs exactly one label. Presumably compute_mrr looks up
# labels by neighbor index, so a length mismatch would silently misalign them
# (or fail deep inside compute_mrr); fail early with a clear message instead.
if labels.shape[0] != embeddings.shape[0]:
    raise ValueError(
        "number of labels ({}) does not match number of embeddings ({})".format(
            labels.shape[0], embeddings.shape[0]))

# build the NN search index (faiss, via get_knn) and retrieve each
# embedding's neighbors and their distances
nearest_neighbors, distances = get_knn(embeddings, configs['dist_fn'], configs['K'], configs['M'])

# compute lower/upper bounds on the mean reciprocal rank, optionally
# normalizing labels with the test data set
mrr_upper, mrr_lower = compute_mrr(labels, nearest_neighbors,
                                   normalize=configs['how_normalize'], dataset=dataset)

end = time.time()

# Append this run's report to the output file. The final empty line keeps
# consecutive runs visually separated. Each line matches what print() emits
# (space-joined str() of its arguments plus a newline).
with open(configs['output_file'], 'a+') as out:
    out.write(
        f"Results for normalize =  {configs['how_normalize']!s}\n"
        "The lower and upper bounds on the mean reciprocal rank over the test set are:  "
        f"{mrr_lower!s} {mrr_upper!s}\n"
        f"The time to evaluate was {end - start!s} seconds\n"
        "\n"
    )
