
import os, sys
import argparse

sys.path.append('..')
from utils.argparse_utils import *


# Template for the identifier of a trained model. The placeholders are filled, in order,
# with: the model-kind tag (%s), z (%d), the data quantity (%d), the batch size (%d)
# and the 1-based repetition/version number (%d).
# NOTE: the '-kl_lambda=0.025' segment is hard-coded in the template.
MODEL_NAME = 'cgvae_symm_simp_flex-%s-z=%d-x_lambda=400-data=%d-bs=%d-kl_lambda=0.025_v%d'

# Batch size paired with each supported data quantity. For every entry in this table
# the batch size is exactly one hundredth of the data quantity (0 maps to 0).
DATA_QUANTITY_TO_BATCH_SIZE_DICT = {
    quantity: quantity // 100
    for quantity in (0, 400, 1000, 2000, 5000, 20000)
}

if __name__ == '__main__':
    # Driver script: builds one model identifier per (data quantity, repetition) pair
    # and runs latent_space_classification_cross_val_on_test.py once for each of them,
    # sequentially, forwarding the shared classification options.

    # Imported locally: only the script entry point spawns child processes.
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    parser.add_argument('--is_vae', type=str_to_bool, default=True)
    parser.add_argument('--z', type=int, default=2)
    # The string default is passed through `type` by argparse, like a command-line value.
    parser.add_argument('--data_quantities', type=comma_sep_int_list, default='400,1000,2000,5000,20000')
    parser.add_argument('--repetitions', type=int, default=3)
    parser.add_argument('--classifier', type=str, default='KNN')
    parser.add_argument('--n_folds', type=int, default=5)
    parser.add_argument('--perc_valid', type=float, default=10.0)
    parser.add_argument('--seed', type=int, default=12345678) # DO NOT CHANGE THIS!!!

    args = parser.parse_args()

    # Tag inserted into the model identifier, and the checkpoint type requested
    # from the classification script, for VAE vs. plain AE runs.
    if args.is_vae:
        is_vae_str = 'VAE_min_loss_with_final_kl'
        model_type = 'lowest_total_loss_with_final_kl_model'
    else:
        is_vae_str = 'AE'
        model_type = 'lowest_rec_loss'

    # Create the list of all model hashes and a parallel list of model types.
    # The whole list is built before anything is launched, so an invalid
    # data quantity aborts the run before any child process starts.
    model_hashes, model_types = [], []
    for data_quantity in args.data_quantities:
        if data_quantity not in DATA_QUANTITY_TO_BATCH_SIZE_DICT:
            # Report a usage error instead of an opaque KeyError traceback.
            parser.error('unsupported data quantity %r; supported values are %s'
                         % (data_quantity, sorted(DATA_QUANTITY_TO_BATCH_SIZE_DICT)))
        batch_size = DATA_QUANTITY_TO_BATCH_SIZE_DICT[data_quantity]
        for rep in range(args.repetitions):
            # Version numbers in the identifier are 1-based.
            model_hashes.append(MODEL_NAME % (is_vae_str, args.z, data_quantity, batch_size, rep + 1))
            # A data quantity of 0 denotes a model that was never trained.
            model_types.append('no_training' if data_quantity == 0 else model_type)

    # Run the child script with the interpreter that is running this script, so it
    # sees the same environment; fall back to PATH lookup if that is unavailable.
    python_executable = sys.executable or 'python'

    # Launch the latent space classification script for each hash.
    for i, (model_hash, hash_model_type) in enumerate(zip(model_hashes, model_types)):
        print('%d/%d: working on %s...' % (i + 1, len(model_hashes), model_hash))
        # Flush so the progress line is emitted before the child's own output.
        sys.stdout.flush()

        # An argument list (no shell) keeps values containing spaces or shell
        # metacharacters intact as single arguments.
        command = [
            python_executable, 'latent_space_classification_cross_val_on_test.py',
            '--model_dir', args.model_dir,
            '--model_type', hash_model_type,
            '--hash', model_hash,
            '--n_folds', '%d' % args.n_folds,
            '--classifier', args.classifier,
            # High precision so a fractional percentage survives the round trip
            # through text, even though the value is usually a whole number.
            '--perc_valid', '%.9f' % args.perc_valid,
            '--seed', '%d' % args.seed,
        ]
        return_code = subprocess.call(command)
        if return_code != 0:
            # Best effort: report the failure and carry on with the remaining models.
            sys.stderr.write('WARNING: classification of %s exited with status %d\n'
                             % (model_hash, return_code))