
import argparse
import os
import subprocess
import sys

sys.path.append('..')
from utils.argparse_utils import *


MODEL_NAME = 'cgvae_symm_simp_flex-VAE_min_loss_with_final_kl-z=%d-x_lambda=400-data=20000-bs=200-kl_lambda=%s_v%d'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    parser.add_argument('--z', type=int, default=2)
    parser.add_argument('--kl_lambdas', type=comma_sep_str_list, default='0.05,0.25,0.5')
    parser.add_argument('--repetitions', type=int, default=3)
    parser.add_argument('--classifier', type=str, default='KNN')
    parser.add_argument('--n_folds', type=int, default=5)
    parser.add_argument('--perc_valid', type=float, default=10.0)
    parser.add_argument('--seed', type=int, default=12345678) # DO NOT CHANGE THIS!!!
    
    args = parser.parse_args()


    # create list of all hashes and a parallel list of model types

    model_type = 'lowest_total_loss_with_final_kl_model'

    hashes, model_types = [], []
    for kl_lambda in args.kl_lambdas:
        for rep in range(args.repetitions):
            hashes.append(MODEL_NAME % (args.z, kl_lambda, rep+1))
            model_types.append(model_type)
    

    # launch latent space classification script for each hash
    for i, (hash, model_type) in enumerate(zip(hashes, model_types)):
        print('%d/%d: working on %s...' % (i+1, len(hashes), hash))
        sys.stdout.flush()

        command = 'python latent_space_classification_cross_val_on_test.py'
        command += ' --model_dir %s' %(args.model_dir)
        command += ' --model_type %s' % (model_type)
        command += ' --hash %s' % (hash)
        command += ' --n_folds %d' % (args.n_folds)
        command += ' --classifier %s' % (args.classifier)
        command += ' --perc_valid %.9f' % (args.perc_valid) # ridiculously high floating point precision because why not, even though this is usually going to be a integer
        command += ' --seed %d' % (args.seed)
        os.system(command)