
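'''
Write SHREC17 retrieval lists from a trained model's latent-space class
predictions, then score them with the JavaScript SHREC17 evaluator.

Must be run after `latent_space_classification.py` has been run for the same
model: this script loads the predictions saved under `latent_space_classification/`.

Requires Node.js, which is used to run the evaluator (`evaluator/evaluate.js`).
'''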

import os
import sys
import shutil
from subprocess import check_output

import numpy as np
import argparse
import json

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

def latent_space_model_retrieval_shrec17_with_precomputed_predictions(predictions, test_ids_N, log_dir, resdir, run_evaluator=True):
    '''
    Write SHREC17 retrieval lists from precomputed class scores and evaluate them.

    Every test model "retrieves" all test models sharing its predicted (arg-max)
    class, ranked by the score of their predicted class, highest first. Ties are
    broken by id in descending order. For each query model, one file named after
    its id is written to `resdir`, containing one retrieved id per line.

    Args:
        predictions: array-like of shape (N, C) holding one row of class scores
            per test model; row i must correspond to test_ids_N[i].
        test_ids_N: 1-D numpy array of N model ids (strings), also used as the
            output file names.
        log_dir: directory handed to the node evaluator as "../<log_dir>/",
            resolved from inside the "evaluator" working directory.
        resdir: existing directory where the retrieval files are written.
        run_evaluator: if True (default), run `node evaluate.js` inside
            ./evaluator and print its output.

    Raises:
        ValueError: if predictions is not 2-D or its number of rows differs
            from the number of ids.
    '''
    predictions = np.asarray(predictions)
    num_models = test_ids_N.shape[0]
    if predictions.ndim != 2 or predictions.shape[0] != num_models:
        raise ValueError(
            "expected predictions of shape ({}, num_classes), got {}".format(num_models, predictions.shape))

    predictions_class = np.argmax(predictions, axis=1)
    # Score of each model's own predicted class (the ranking key).
    predicted_scores = predictions[np.arange(num_models), predictions_class]

    # All queries predicted as the same class retrieve the same ranked list, so
    # build each list once per class instead of rescanning all N models per
    # query (which was O(N^2)).
    ranking_per_class = {}
    for c in np.unique(predictions_class):
        members = np.flatnonzero(predictions_class == c)
        # Reverse-sorting (score, id) tuples keeps the original ordering,
        # including its descending tie-break on id.
        ranked = sorted(((predicted_scores[j], test_ids_N[j]) for j in members), reverse=True)
        ranking_per_class[c] = "\n".join(model_id for _, model_id in ranked)

    for i in range(num_models):
        if i % 100 == 0:
            print("{}/{}    ".format(i, num_models), end="\r")
        with open(os.path.join(resdir, test_ids_N[i]), "w") as f:
            f.write(ranking_per_class[predictions_class[i]])

    if run_evaluator:
        print(check_output(["node", "evaluate.js", os.path.join("..", log_dir) + "/"], cwd="evaluator").decode("utf-8"))
    


def main(argv=None):
    '''
    Build SHREC17 retrieval lists for one trained model and run the evaluator.

    Reads the model's hparams and test-split inference arrays, loads the class
    predictions saved by `latent_space_classification.py`, recreates
    <model_dir>/<hash>/<model_type>/ and writes the retrieval lists into a
    <split>_normal or <split>_perturbed subdirectory of it.

    Args:
        argv: command-line arguments; defaults to sys.argv[1:].
    '''
    # --model_type -> suffix used in the saved inference / prediction file names.
    # Note that 'best' and 'best_04' both map to the empty suffix.
    model_type_suffixes = {
        'best': '',
        'best_04': '',
        'best_05': '-best_05',
        'best_06': '-best_06',
        'best_higher_kld': '-best_model_higher_kld',
        'lowest_rec_loss': '-lowest_rec_loss',
        'final': '-final_model',
        'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
    }
    # --classifier -> prefix of the saved prediction file name.
    classifier_prefixes = {'LC': '', 'KNN': 'KNN_'}

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/shrec17/local_equiv_fibers')
    parser.add_argument('--hash', type=str, required=True)
    # choices= rejects unknown values up front instead of failing later with a NameError.
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=sorted(model_type_suffixes))
    parser.add_argument('--classifier', type=str, default='KNN', choices=sorted(classifier_prefixes))
    parser.add_argument('--split', type=str, default='test')
    args = parser.parse_args(argv)

    model_type_str = model_type_suffixes[args.model_type]
    classifier_str = classifier_prefixes[args.classifier]
    model_root = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(model_root, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    ## SHREC17 standard evaluation metrics
    ## Following the same procedure as Spherical CNNs to generate list of retrieved models

    with np.load(os.path.join(model_root, 'results_arrays/inference%s-split=test.npz' % (model_type_str))) as arrays_test:
        test_ids_N = arrays_test['ids_N']

    print(test_ids_N)

    # Load the predictions before wiping the output directory, so a missing
    # predictions file does not destroy the results of a previous run.
    predictions = np.load(os.path.join(model_root, 'latent_space_classification/__%sclassificaton_on_latent_space_default_classes%s.npy' % (classifier_str, model_type_str)))

    # Start from a clean output directory on every run.
    log_dir = os.path.join(model_root, args.model_type)
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)
    os.mkdir(log_dir)

    # NOTE(review): resdir is named after --split, but the ids and predictions
    # loaded above always come from the test split.
    resdir = os.path.join(log_dir, args.split + ('_perturbed' if hparams['perturbed'] else '_normal'))
    os.mkdir(resdir)  # log_dir was just recreated, so resdir cannot already exist

    latent_space_model_retrieval_shrec17_with_precomputed_predictions(predictions, test_ids_N, log_dir, resdir)


if __name__ == '__main__':
    main()

