
'''
Latent-space classification via a linear classifier (LC) or k-nearest neighbors (KNN).
'''

import os, sys
import gzip, pickle
import json
import numpy as np
import pandas as pd
import torch
import argparse

import sklearn
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

sys.path.append('..')
from classifiers import MultiClassLinearClassifier
from utils.argparse_utils import *

def latent_space_prediction(train_invariants_ND, train_labels_N, eval_invariants_ND, eval_labels_N, valid_invariants_ND=None, valid_labels_N=None, classifier='linear', optimize_hyps=False):
    '''
    Fit a classifier on latent invariants and evaluate it on a held-out set.

    Parameters
    ----------
    train_invariants_ND : 2-D array, features along axis 1 (axis 1 is used as the
        number of input features for the linear classifier).
    train_labels_N : training labels, one per row of `train_invariants_ND`.
    eval_invariants_ND, eval_labels_N : data the fitted model is evaluated on.
    valid_invariants_ND, valid_labels_N : optional validation data; forwarded to
        `MultiClassLinearClassifier.fit` (as `x_valid_MF` / `y_valid_M`) and
        ignored for KNN.
    classifier : 'LC' (alias: 'linear') for `MultiClassLinearClassifier`, or
        'KNN' for `sklearn.neighbors.KNeighborsClassifier`.
    optimize_hyps : if True, wrap the estimator in `GridSearchCV` over a small
        fixed grid ('lr' for LC, 'n_neighbors' for KNN).

    Returns
    -------
    report : dict produced by `classification_report(..., output_dict=True)`.
    predictions_N : predicted class label for each evaluation example.

    Raises
    ------
    NotImplementedError : if `classifier` is not one of the supported names.
    '''
    if classifier == 'linear':
        # The default value was never handled by the branches below (calling with
        # the default used to raise); treat it as the linear classifier.
        classifier = 'LC'

    if classifier == 'LC':
        n_features = train_invariants_ND.shape[1]
        n_classes = len(np.unique(train_labels_N))
        estimator = MultiClassLinearClassifier(n_features, n_classes, verbose=True)
        hyperparams = {'lr': [0.1, 0.01, 0.001]}
        # The linear classifier accepts a validation set at fit time.
        fit_kwargs = {'x_valid_MF': valid_invariants_ND, 'y_valid_M': valid_labels_N}
    elif classifier == 'KNN':
        estimator = KNeighborsClassifier()
        hyperparams = {'n_neighbors': [5, 10, 20]}
        fit_kwargs = {}
    else:
        raise NotImplementedError("classifier must be 'LC' (or 'linear') or 'KNN', got %r" % (classifier,))

    model = GridSearchCV(estimator, hyperparams) if optimize_hyps else estimator
    model = model.fit(train_invariants_ND, train_labels_N, **fit_kwargs)

    proba_NC = model.predict_proba(eval_invariants_ND)
    predicted_idx_N = np.argmax(proba_NC, axis=1)

    # For sklearn estimators the predict_proba columns follow `classes_` (the
    # sorted unique training labels), so a column index is only a label when the
    # labels are exactly 0..C-1. Map indices back to labels when `classes_` is
    # available and consistent; otherwise keep the raw column index.
    classes_C = getattr(model, 'classes_', None)
    if classes_C is not None and np.ndim(classes_C) == 1 and len(classes_C) == proba_NC.shape[1]:
        predictions_N = np.asarray(classes_C)[predicted_idx_N]
    else:
        predictions_N = predicted_idx_N

    report = classification_report(eval_labels_N, predictions_N, output_dict=True)
    return report, predictions_N

# Filename tag carried by the inference outputs for each --model_type.
# NOTE(review): 'best' and 'best_04' both map to '' and therefore read the same
# inference files; this mirrors the original mapping — confirm it is intended.
_MODEL_TYPE_STRS = {
    'best': '',
    'best_04': '',
    'best_05': '-best_05',
    'best_06': '-best_06',
    'best_higher_kld': '-best_model_higher_kld',
    'lowest_rec_loss': '-lowest_rec_loss',
    'final': '-final_model',
    'no_training': '-no_training',
    'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
}

# Prefix added to the output filenames for each --classifier.
_CLASSIFIER_STRS = {'LC': '', 'KNN': 'KNN_'}


def _load_inference_arrays(run_dir, model_type_str, split, hparams):
    '''
    Load the latent invariants and labels saved for one data split.

    Prefers the file whose name includes `hparams['input_type']`. Falls back to
    the name without it when that key is missing or that file does not exist.

    Returns
    -------
    (invariants_ND, labels_N) : the 'invariants_ND' and 'labels_N' arrays stored
        in the .npz file.
    '''
    results_dir = os.path.join(run_dir, 'results_arrays')
    path = os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, split))
    if 'input_type' in hparams:
        typed_path = os.path.join(results_dir, 'inference%s-split=%s-input_type=%s.npz' % (model_type_str, split, hparams['input_type']))
        if os.path.exists(typed_path):
            path = typed_path
    # Index inside the context so the arrays are read before the archive is closed.
    with np.load(path) as arrays:
        return arrays['invariants_ND'], arrays['labels_N']


def main():
    '''Parse CLI arguments, classify saved latent invariants and write the report and predictions.'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/mnist/local_equiv_fibers')
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model', choices=sorted(_MODEL_TYPE_STRS),
                        help='which saved model the inference arrays came from (e.g. lowest_total_loss_with_final_kl_model, lowest_rec_loss)')
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--classifier', type=str, default='KNN', choices=sorted(_CLASSIFIER_STRS))
    parser.add_argument('--use_validation_to_train', type=str_to_bool, default=False)
    parser.add_argument('--seed', type=int, default=10000000)

    args = parser.parse_args()

    # --seed used to be parsed but never applied; seed the global RNGs so
    # stochastic classifier training is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_type_str = _MODEL_TYPE_STRS[args.model_type]
    classifier_str = _CLASSIFIER_STRS[args.classifier]
    run_dir = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    # assumes `inference_fibers.py` has already been run
    train_invariants_ND, train_labels_N = _load_inference_arrays(run_dir, model_type_str, 'train', hparams)
    valid_invariants_ND, valid_labels_N = _load_inference_arrays(run_dir, model_type_str, 'valid', hparams)
    eval_invariants_ND, eval_labels_N = _load_inference_arrays(run_dir, model_type_str, args.split, hparams)

    if args.use_validation_to_train:
        fit_invariants_ND = np.vstack([train_invariants_ND, valid_invariants_ND])
        fit_labels_N = np.hstack([train_labels_N, valid_labels_N])
        validation_tag = '_WITH_VALIDATION_DATA'
    else:
        fit_invariants_ND, fit_labels_N = train_invariants_ND, train_labels_N
        validation_tag = ''

    # NOTE(review): the validation split is still passed as the validation set
    # even when it is also part of the training data — confirm this is intended.
    report, predictions_N = latent_space_prediction(fit_invariants_ND, fit_labels_N, eval_invariants_ND, eval_labels_N, valid_invariants_ND=valid_invariants_ND, valid_labels_N=valid_labels_N, classifier=args.classifier, optimize_hyps=False)

    out_dir = os.path.join(run_dir, 'latent_space_classification')
    os.makedirs(out_dir, exist_ok=True)
    # 'classificaton' (sic) is kept so existing consumers of these files still find them.
    out_stem = '__%sclassificaton_on_latent_space_default_classes%s%s' % (classifier_str, validation_tag, model_type_str)

    pd.DataFrame(report).to_csv(os.path.join(out_dir, out_stem + '.csv'))

    print(report)

    np.save(os.path.join(out_dir, out_stem + '.npy'), predictions_N)


if __name__ == '__main__':
    main()



