
'''
Perform amino-acid (AA) type classification via k-fold cross validation (5 folds by default) on the test data (10,000 datapoints).
The test data was unseen by all models regardless of training data quantity, so this is a very fair comparison.
The fold splits are seeded so that they are the same across runs.
Folds are not stratified by amino-acid type; the sample size is large enough that this should be fine.
Either way, with cross validation every datapoint is used for testing exactly once and for training in all the other folds.
The MultiClassLinearClassifier ('LC') holds out a seeded validation subset of the training folds (--perc_valid, 10% by default); the KNN classifier (the default) trains on all non-test folds.
'''


import os, sys
import gzip, pickle
import json
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import argparse

import sklearn
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

sys.path.append('..')
from classifiers import MultiClassLinearClassifier

def sanity_check(train_idxs, valid_idxs, test_idxs):
    """Assert that the train, validation and test index collections are pairwise disjoint.

    Raises AssertionError if any index appears in more than one of the three splits.
    """
    train, valid, test = (set(idxs) for idxs in (train_idxs, valid_idxs, test_idxs))
    assert train.isdisjoint(valid)
    assert train.isdisjoint(test)
    assert valid.isdisjoint(test)



# Maps each --model_type value to (suffix used in the inference .npz and output .csv filenames,
# checkpoint filename). The checkpoint filename is not used by this script; it only documents
# which checkpoint the inference arrays are expected to come from.
MODEL_TYPES = {
    'best': ('', 'best_model.pt'),
    # NOTE(review): 'best_04' shares the empty suffix with 'best' (as in the original if/elif chain) — confirm intended.
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'no_training': ('-no_training', None),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
}


def make_folds(n, n_folds, rng):
    """Shuffle the indices 0..n-1 with `rng` and split them into `n_folds` contiguous folds.

    The RNG is consumed by exactly one `rng.shuffle` call, so a given seed always yields the
    same folds. Every fold has floor(n / n_folds) indices except the last one, which also
    absorbs the remainder (at most n_folds - 1 extra indices).

    Returns a list of `n_folds` 1-D integer arrays; the k-th array holds the indices of fold k.
    """
    fold_size = n // n_folds
    indices = np.arange(n)
    rng.shuffle(indices)
    folds = [indices[k * fold_size:(k + 1) * fold_size] for k in range(n_folds - 1)]
    folds.append(indices[(n_folds - 1) * fold_size:])
    return folds


def load_test_inference(run_dir, model_type_str):
    """Load (invariants_ND, labels_N) from the test-split inference arrays saved in `run_dir`.

    Raises FileNotFoundError if the corresponding .npz file does not exist.
    """
    path = os.path.join(run_dir, 'results_arrays/inference%s-split=test.npz' % (model_type_str))
    # Use a context manager so the underlying zip file handle is closed after reading.
    with np.load(path) as arrays:
        return arrays['invariants_ND'], arrays['labels_N']


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model', choices=list(MODEL_TYPES))
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--n_folds', type=int, default=5)
    parser.add_argument('--classifier', type=str, default='KNN', choices=['LC', 'KNN'])
    parser.add_argument('--perc_valid', type=float, default=10.0)
    parser.add_argument('--seed', type=int, default=12345678) # DO NOT CHANGE THIS!!!

    args = parser.parse_args()
    if args.n_folds < 2:
        parser.error('--n_folds must be at least 2')
    if not 0.0 <= args.perc_valid < 100.0:
        parser.error('--perc_valid must be in [0, 100)')

    run_dir = os.path.join(args.model_dir, args.hash)
    model_type_str = MODEL_TYPES[args.model_type][0]

    # hparams are not used below; loading them still fails fast if the run directory has no hparams.json.
    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    rng = np.random.default_rng(args.seed)

    # Assumes `inference_fibers.py` has already been run.
    # NB: this loads the test data normalized by the training data specific to the current model,
    # so the test data is at different scales across models. This is standard machine learning practice,
    # and since the training datasets come from the same distribution it should matter little
    # (though there might be some variance when using the smaller datasets).
    try:
        invariants_ND, labels_N = load_test_inference(run_dir, model_type_str)
    except FileNotFoundError:
        if args.model_type != 'lowest_rec_loss':
            raise
        # Fall back to the lowest-total-loss inference arrays; this suffix is also used for the output file.
        model_type_str = '-lowest_total_loss_with_final_kl_model'
        invariants_ND, labels_N = load_test_inference(run_dir, model_type_str)
        print('Warning, using model_type_str: %s' % (model_type_str), file=sys.stderr)

    N = labels_N.shape[0]
    assert N == invariants_ND.shape[0]
    n_features = invariants_ND.shape[1]
    classes = np.unique(labels_N)
    n_classes = classes.shape[0]
    # Labels must be class indices 0..n_classes-1 (the linear classifier is built with n_classes outputs).
    assert np.array_equal(classes, np.arange(n_classes))

    # Seeded fold split; must be identical across runs for the comparison between models to be fair.
    folds_idxs = make_folds(N, args.n_folds, rng)
    classifier_str = 'KNN_' if args.classifier == 'KNN' else ''

    # For each fold: hold out the fold as test, carve a validation subset from the remaining
    # (already shuffled) indices, train, and predict. Metrics are computed once on the
    # predictions stacked over all folds.
    y_true = []
    y_pred = []
    for k in tqdm(range(args.n_folds)):
        test_idxs = folds_idxs[k]
        train_and_valid_idxs = np.hstack([fold for i, fold in enumerate(folds_idxs) if i != k])
        n_valid = int((args.perc_valid / 100) * train_and_valid_idxs.shape[0])
        valid_idxs = train_and_valid_idxs[:n_valid]
        train_idxs = train_and_valid_idxs[n_valid:]
        sanity_check(train_idxs, valid_idxs, test_idxs)

        if args.classifier == 'LC':
            model = MultiClassLinearClassifier(n_features, n_classes, verbose=False)
            model = model.fit(invariants_ND[train_idxs], labels_N[train_idxs], x_valid_MF=invariants_ND[valid_idxs], y_valid_M=labels_N[valid_idxs])
            # NOTE(review): assumes predict_proba column j corresponds to class index j — confirm in classifiers.py.
            probas = model.predict_proba(invariants_ND[test_idxs])
            predictions = np.argmax(probas, axis=1)
        else:  # 'KNN' — needs no validation split, so it trains on all non-test folds.
            model = KNeighborsClassifier()
            model.fit(invariants_ND[train_and_valid_idxs], labels_N[train_and_valid_idxs])
            probas = model.predict_proba(invariants_ND[test_idxs])
            # predict_proba columns follow model.classes_ (only the labels seen in training),
            # so map the argmax column back to the actual label.
            predictions = model.classes_[np.argmax(probas, axis=1)]

        y_true.append(labels_N[test_idxs])
        y_pred.append(predictions)

    y_true = np.hstack(y_true)
    y_pred = np.hstack(y_pred)

    report = classification_report(y_true, y_pred, output_dict=True)
    out_dir = os.path.join(run_dir, 'latent_space_classification')
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(report).to_csv(os.path.join(out_dir, '__residue_type_cross_val_on_test_%sclassificaton_on_latent_space%s.csv' % (classifier_str, model_type_str)))




