
import os, sys
import gzip, pickle
import json
import numpy as np
import pandas as pd
import torch
import argparse

import sklearn
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

sys.path.append('..')
from classifiers import MultiClassLinearClassifier

def latent_space_prediction(train_invariants_ND, train_labels_N, eval_invariants_ND, eval_labels_N, valid_invariants_ND=None, valid_labels_N=None, classifier='linear', optimize_hyps=False):
    """Fit a classifier on latent invariants and evaluate it on a held-out split.

    Parameters
    ----------
    train_invariants_ND : array, shape (N, D)
        Training features; axis 1 is taken as the feature dimension.
    train_labels_N : array-like, shape (N,)
        Training labels.
    eval_invariants_ND, eval_labels_N : array-like
        Features / labels of the split the classifier is evaluated on.
    valid_invariants_ND, valid_labels_N : array-like, optional
        Validation data forwarded to ``MultiClassLinearClassifier.fit`` as
        ``x_valid_MF`` / ``y_valid_M``. Ignored by the KNN classifier.
    classifier : {'LC', 'linear', 'KNN'}
        'LC' (alias 'linear', the default) uses ``MultiClassLinearClassifier``;
        'KNN' uses scikit-learn's ``KNeighborsClassifier``.
    optimize_hyps : bool
        If True, wrap the estimator in ``GridSearchCV`` over a small grid
        (learning rate for LC, ``n_neighbors`` for KNN).

    Returns
    -------
    report : dict
        ``classification_report(..., output_dict=True)`` on the eval split.
    predictions : ndarray, shape (N_eval, n_classes)
        Class-probability matrix returned by ``predict_proba``.

    Raises
    ------
    ValueError
        If ``classifier`` is not one of the supported names.
    """
    if classifier in ('LC', 'linear'):
        n_features = train_invariants_ND.shape[1]
        n_classes = len(np.unique(train_labels_N))
        estimator = MultiClassLinearClassifier(n_features, n_classes, verbose=True)
        hyperparams = {'lr': [0.1, 0.01, 0.001]}
    elif classifier == 'KNN':
        estimator = KNeighborsClassifier()
        hyperparams = {'n_neighbors': [5, 10, 20]}
    else:
        # Fail fast: otherwise `estimator` would be unbound below.
        raise ValueError("classifier must be one of 'LC', 'linear', 'KNN'; got %r" % (classifier,))

    model = GridSearchCV(estimator, hyperparams) if optimize_hyps else estimator

    if classifier == 'KNN':
        model = model.fit(train_invariants_ND, train_labels_N)
    else:
        # The linear classifier takes its validation set through these
        # extra keyword arguments of `fit`.
        model = model.fit(train_invariants_ND, train_labels_N,
                          x_valid_MF=valid_invariants_ND, y_valid_M=valid_labels_N)

    predictions = model.predict_proba(eval_invariants_ND)
    predicted_idx = np.argmax(predictions, axis=1)
    # In scikit-learn, predict_proba columns follow `model.classes_`, so the
    # argmax is a column index, not a label; map it back so labels that are
    # not exactly 0..K-1 are reported correctly. Estimators without
    # `classes_` (possibly MultiClassLinearClassifier -- unverified) keep the
    # raw index.
    classes = getattr(model, 'classes_', None)
    predicted_labels = predicted_idx if classes is None else np.asarray(classes)[predicted_idx]

    return classification_report(eval_labels_N, predicted_labels, output_dict=True), predictions


# File-name suffix used by the inference/output arrays for each --model_type.
_MODEL_TYPE_SUFFIXES = {
    'best': '',
    'best_04': '',
    'best_05': '-best_05',
    'best_06': '-best_06',
    'best_higher_kld': '-best_model_higher_kld',
    'lowest_rec_loss': '-lowest_rec_loss',
    'final': '-final_model',
    'no_training': '-no_training',
    'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
}

# File-name prefix used by the output files for each --classifier.
_CLASSIFIER_PREFIXES = {
    'LC': '',
    'KNN': 'KNN_',
}


def _load_split_arrays(run_dir, model_type_str, split, input_type=None):
    """Load ``(invariants_ND, labels_N)`` for one split from ``run_dir/results_arrays``.

    Prefers the file name tagged with ``input_type``; falls back to the
    untagged name when no input type is given or the tagged file is absent.
    """
    results_dir = os.path.join(run_dir, 'results_arrays')
    path = os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, split))
    if input_type is not None:
        tagged_path = os.path.join(results_dir, 'inference%s-split=%s-input_type=%s.npz' % (model_type_str, split, input_type))
        if os.path.exists(tagged_path):
            path = tagged_path
    # NpzFile is a context manager; closing it releases the underlying file.
    with np.load(path) as arrays:
        return arrays['invariants_ND'], arrays['labels_N']


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/shrec17/local_equiv_fibers')
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model', choices=list(_MODEL_TYPE_SUFFIXES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--classifier', type=str, default='KNN', choices=list(_CLASSIFIER_PREFIXES))
    parser.add_argument('--seed', type=int, default=10000000)

    args = parser.parse_args()

    # Seed the global numpy/torch RNGs so runs are reproducible for any
    # classifier that draws from them.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_type_str = _MODEL_TYPE_SUFFIXES[args.model_type]
    classifier_str = _CLASSIFIER_PREFIXES[args.classifier]
    run_dir = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)
    input_type = hparams.get('input_type')

    # assumes `inference_fibers.py` has already been run
    train_invariants_ND, train_labels_N = _load_split_arrays(run_dir, model_type_str, 'train', input_type)
    valid_invariants_ND, valid_labels_N = _load_split_arrays(run_dir, model_type_str, 'valid', input_type)
    eval_invariants_ND, eval_labels_N = _load_split_arrays(run_dir, model_type_str, args.split, input_type)

    report, predictions_N = latent_space_prediction(train_invariants_ND, train_labels_N, eval_invariants_ND, eval_labels_N, valid_invariants_ND=valid_invariants_ND, valid_labels_N=valid_labels_N, classifier=args.classifier, optimize_hyps=False)

    # Create the output directory so the first run on a fresh model dir works.
    out_dir = os.path.join(run_dir, 'latent_space_classification')
    os.makedirs(out_dir, exist_ok=True)
    out_stem = '__%sclassificaton_on_latent_space_default_classes%s' % (classifier_str, model_type_str)

    pd.DataFrame(report).to_csv(os.path.join(out_dir, out_stem + '.csv'))

    print(report)

    np.save(os.path.join(out_dir, out_stem + '.npy'), predictions_N)



