
import os, sys
import argparse

sys.path.append('..')
from utils.argparse_utils import *


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_class', type=str, default='fibers')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--input_type', type=str)
    parser.add_argument('--model_dir', type=str, default='../runs/mnist/local_equiv_fibers')
    parser.add_argument('--splits', type=str)
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model')
    parser.add_argument('--hash', type=str)
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--seed', type=int, default=1000005)
    parser.add_argument('--do_inference', type=str_to_bool, default=True)

    args = parser.parse_args()

    ## call inference on all splits
    if args.do_inference:
        if args.model_class == 'fibers':
            command = 'python inference_fibers.py'
            command += ' --w3j_filepath %s' % (args.w3j_filepath)
            command += ' --input_type %s' % (args.input_type)
            command += ' --model_dir %s' % (args.model_dir)
            command += ' --splits %s' % (args.splits)
            command += ' --model_type %s' % (args.model_type)
            command += ' --hash %s' % (args.hash)
            command += ' --seed %d' % (args.seed)
            os.system(command)
        
        elif args.model_class == 'e3nn':
            command = 'python inference.py'
            command += ' --input_type %s' % (args.input_type)
            command += ' --model_dir %s' % (args.model_dir)
            command += ' --splits %s' % (args.splits)
            command += ' --model_type %s' % (args.model_type)
            command += ' --hash %s' % (args.hash)
            command += ' --seed %d' % (args.seed)
            os.system(command)


    ## call evaluation on each split separately
    for split in args.splits.split(','):
        command = 'python evaluation.py'
        command += ' --input_type %s' % (args.input_type)
        command += ' --model_dir %s' % (args.model_dir)
        command += ' --split %s' % (split)
        command += ' --model_type %s' % (args.model_type)
        command += ' --hash %s' % (args.hash)
        command += ' --seed %d' % (args.seed)
        os.system(command)


    ## call reconstructions on all splits except training (don't really need them)
    for split in args.splits.split(','):
        if split != 'train':
            command = 'python reconstructions.py'
            command += ' --model_class %s' % (args.model_class)
            command += ' --w3j_filepath %s' % (args.w3j_filepath)
            command += ' --input_type %s' % (args.input_type)
            command += ' --model_dir %s' % (args.model_dir)
            command += ' --split %s' % (split)
            command += ' --model_type %s' % (args.model_type)
            command += ' --hash %s' % (args.hash)
            command += ' --n_samples %d' % (args.n_samples)
            command += ' --seed %d' % (args.seed)
            os.system(command)


