
import argparse
import os
import subprocess
import sys

sys.path.append('..')
from utils.argparse_utils import *

# Names of parsed CLI options that are forwarded, in exactly this order, to the
# training script as ``--<name> <value>`` pairs. Every entry must correspond to
# an option registered on this script's argument parser, since its value is
# read back off the parsed namespace by name.
TRAINING_ARGS = [
    # model / input selection
    'is_vae', 'model', 'w3j_filepath', 'input_type',
    # network layout
    'net_lmax', 'latent_dim', 'n_cg_blocks', 'lmax_list', 'ch_size_list',
    'ls_nonlin_rule_list', 'ch_nonlin_rule_list',
    'do_initial_linear_projection', 'ch_initial_linear_projection',
    'filter_symmetric', 'use_batch_norm', 'linearity_first',
    # normalization options
    'norm_type', 'normalization', 'norm_balanced', 'norm_affine',
    'norm_nonlinearity', 'norm_location',
    'use_additive_skip_connections',
    # training schedule and loss
    'weight_decay', 'x_rec_loss_fn', 'batch_size', 'learn_frame',
    'lr', 'lr_schedule', 'n_epochs',
    'lambdas', 'lambdas_schedule', 'no_kl_epochs', 'warmup_kl_epochs',
    # run bookkeeping and logging
    'seed', 'hash', 'experiments_dir', 'experiments_suffix',
    'use_wandb', 'use_tensorboard',
]

def build_parser():
    """Build the CLI parser for the combined training + evaluation launcher.

    The first group of options is forwarded verbatim to the training script
    (the names listed in ``TRAINING_ARGS``); the second group configures the
    evaluation pipeline that runs afterwards.
    """
    parser = argparse.ArgumentParser()

    ## training arguments
    parser.add_argument('--is_vae', type=str_to_bool, default=True)
    parser.add_argument('--model', type=str, default='cgvae_symmetric_simple_flexible')
    parser.add_argument('--input_type', type=str, default='NRR-avg_sqrt_power')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')

    parser.add_argument('--net_lmax', type=int, default=10)
    parser.add_argument('--latent_dim', type=int, default=16)
    parser.add_argument('--n_cg_blocks', type=int, default=6)
    parser.add_argument('--ch_size_list', type=str, default='16,16,16,16,16,16')
    parser.add_argument('--lmax_list', type=str, default='10,10,8,4,2,1')
    parser.add_argument('--ls_nonlin_rule_list', type=str, default='efficient,efficient,efficient,efficient,efficient,efficient')
    parser.add_argument('--ch_nonlin_rule_list', type=str, default='elementwise,elementwise,elementwise,elementwise,elementwise,elementwise')
    parser.add_argument('--do_initial_linear_projection', type=str_to_bool, default=False)
    parser.add_argument('--ch_initial_linear_projection', type=int, default=0)

    parser.add_argument('--filter_symmetric', type=str_to_bool, default=True)
    parser.add_argument('--linearity_first', type=str_to_bool, default=False)

    parser.add_argument('--use_batch_norm', type=str_to_bool, default=True)
    parser.add_argument('--norm_type', type=str, default='signal') # None, layer, signal, layer_and_signal
    parser.add_argument('--normalization', type=str, default='norm') # norm, component -> only considered if norm_type is not none
    parser.add_argument('--norm_balanced', type=str_to_bool_or_float, default=False)
    parser.add_argument('--norm_affine', type=str, default='per_l') # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
    # NOTE: a default of None is forwarded to the training script as the string 'None'.
    parser.add_argument('--norm_nonlinearity', type=str, default=None) # identity, relu, swish, sigmoid -> only for layer_norm
    parser.add_argument('--norm_location', type=str, default='between') # first, between, last

    parser.add_argument('--use_additive_skip_connections', type=str_to_bool, default=True)
    parser.add_argument('--weight_decay', type=str_to_bool, default=False)
    parser.add_argument('--x_rec_loss_fn', type=str, default='mse')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--learn_frame', type=str_to_bool, default=True)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_schedule', type=str, default='log_decrease_until_end_of_warmup', choices=['constant', 'log_decrease_until_end_of_warmup', 'log_decrease_until_end_by_1_OM', 'log_decrease_until_end_by_2_OM', 'log_decrease_until_end_by_3_OM', 'linear_decrease_until_end_of_warmup', 'decrease_below_threshold', 'decrease_after_warmup', 'decrease_at_half'])
    parser.add_argument('--n_epochs', type=int, default=80)
    parser.add_argument('--lambdas', type=str, default='50.0,0.2')
    parser.add_argument('--lambdas_schedule', type=str, default='linear_up_anneal_kl', choices=['constant', 'drop_kl_at_half', 'linear_up_anneal_kl'])
    parser.add_argument('--no_kl_epochs', type=int, default=25)
    parser.add_argument('--warmup_kl_epochs', type=int, default=35)

    parser.add_argument('--seed', type=int, default=420420420)

    parser.add_argument('--hash', type=str, required=True, help='Unique identifier for the run. Usually a hash of the hyperparameters.')
    parser.add_argument('--experiments_dir', type=str, default='../runs/mnist')
    parser.add_argument('--experiments_suffix', type=str, default='equiv_fibers')
    parser.add_argument('--use_wandb', type=str_to_bool, default=False)
    parser.add_argument('--use_tensorboard', type=str_to_bool, default=False)

    ## evaluation pipeline arguments
    parser.add_argument('--model_class', type=str, default='fibers')
    parser.add_argument('--model_dir', type=str, default='../runs/mnist/local_equiv_fibers')
    parser.add_argument('--splits', type=str, default='train,valid,test')
    # argparse also runs `type` on string defaults, so the default becomes a list too.
    parser.add_argument('--model_types', type=comma_sep_str_list, default='lowest_total_loss_with_final_kl_model')
    # NOTE(review): --n_frames is parsed but never forwarded to evaluation_pipeline.py;
    # confirm whether the pipeline expects it.
    parser.add_argument('--n_frames', type=int, default=4)
    parser.add_argument('--n_samples', type=int, default=5)
    parser.add_argument('--seed_eval', type=int, default=1000005)
    parser.add_argument('--do_inference', type=str_to_bool, default=True)

    parser.add_argument('--do_training', type=str_to_bool, default=True)

    return parser


def build_training_command(args, arg_names=None):
    """Return the argv list that launches the training script.

    Args:
        args: parsed namespace holding (at least) every attribute in ``arg_names``.
        arg_names: option names to forward; defaults to ``TRAINING_ARGS``.

    Returns:
        A list suitable for ``subprocess.run``. Each forwarded option becomes
        ``['--<name>', str(value)]``, so ``True`` -> ``'True'`` and
        ``None`` -> ``'None'``, matching the string form the script always sent.
    """
    if arg_names is None:
        arg_names = TRAINING_ARGS
    # sys.executable runs the child under the same interpreter/environment as this launcher.
    command = [sys.executable, 'train_vae_mnist_with_fibers_simple_flexible.py']
    for name in arg_names:
        # getattr replaces the previous eval('args.<name>') lookup.
        command += ['--%s' % name, str(getattr(args, name))]
    return command


def build_evaluation_command(args, model_type):
    """Return the argv list that runs the evaluation pipeline for one model type.

    Args:
        args: parsed namespace with the evaluation options (and ``hash``).
        model_type: a single entry of ``args.model_types``.
    """
    return [
        sys.executable, 'evaluation_pipeline.py',
        '--model_class', str(args.model_class),
        '--w3j_filepath', str(args.w3j_filepath),
        '--input_type', str(args.input_type),
        '--model_dir', str(args.model_dir),
        '--splits', str(args.splits),
        '--model_type', str(model_type),
        '--hash', str(args.hash),
        '--n_samples', '%d' % args.n_samples,
        '--seed', '%d' % args.seed_eval,
        '--do_inference', str(args.do_inference),
    ]


def main():
    """Parse the CLI, optionally train, then evaluate each requested model type.

    Exits non-zero if training fails (evaluation is then skipped) or if any
    evaluation run fails.
    """
    args = build_parser().parse_args()

    ## launch training
    if args.do_training:
        # An argv list (no shell) keeps values with spaces, empty strings, or
        # shell metacharacters intact, and exposes the child's exit status.
        training = subprocess.run(build_training_command(args))
        if training.returncode != 0:
            sys.exit('Training exited with code %d; skipping evaluation.' % training.returncode)

    ## call evaluation pipeline on each requested model type
    failed = []
    for model_type in args.model_types:
        evaluation = subprocess.run(build_evaluation_command(args, model_type))
        if evaluation.returncode != 0:
            failed.append(model_type)
    if failed:
        sys.exit('Evaluation failed for model type(s): %s' % ', '.join(failed))


if __name__ == '__main__':
    main()
    
