
import argparse
import os
import subprocess
import sys

sys.path.append('..')
from utils.argparse_utils import *

def _parse_splits(splits):
    """Split a comma-separated ``--splits`` value into clean split names.

    Surrounding whitespace is stripped and empty entries (e.g. from a
    trailing comma) are dropped.
    """
    return [split.strip() for split in splits.split(',') if split.strip()]


def _build_inference_command(args):
    """Return the argv list that runs fibers inference for all of ``args.splits``.

    Uses ``sys.executable`` so the child runs under the same interpreter
    (and virtualenv) as this driver. Passing a list instead of a shell
    string means paths containing spaces are forwarded intact. The script
    path is relative, so this driver must be run from its own directory.
    """
    return [
        sys.executable, 'inference_fibers_zernicke.py',
        '--w3j_filepath', args.w3j_filepath,
        '--data_dir', args.data_dir,
        '--model_dir', args.model_dir,
        '--splits', args.splits,
        '--model_type', args.model_type,
        '--hash', args.hash,
        '--seed', str(args.seed),
    ]


def _build_evaluation_command(args, split):
    """Return the argv list that runs evaluation on a single ``split``."""
    return [
        sys.executable, 'evaluation_zernicke.py',
        '--model_dir', args.model_dir,
        '--split', split,
        '--model_type', args.model_type,
        '--hash', args.hash,
        '--seed', str(args.seed),
    ]


def _run(command):
    """Run ``command`` (an argv list) without a shell and return its exit code."""
    return subprocess.run(command).returncode


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_class', type=str, default='fibers')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--data_dir', type=str, default='../data/neighborhoods/data')
    parser.add_argument('--model_dir', type=str, default='../runs/neighborhoods/local_equiv_fibers')
    parser.add_argument('--splits', type=str, default='test')  # comma-separated list of splits
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--seed', type=int, default=1000005)  # e.g. 1000005 or 1000006
    parser.add_argument('--do_inference', type=str_to_bool, default=True)

    args = parser.parse_args()

    if args.do_inference:
        # A single inference call covers every split in --splits.
        if args.model_class == 'fibers':
            returncode = _run(_build_inference_command(args))
            if returncode != 0:
                # Evaluating after a failed inference would read missing or stale outputs.
                sys.exit('Inference failed with exit code %d; skipping evaluation.' % returncode)
        else:
            print("Warning: no inference script for --model_class '%s'; skipping inference."
                  % args.model_class, file=sys.stderr)

    # Evaluate each split separately. Keep going on failure so every split
    # gets a result, but report failures through the exit status.
    failed_splits = [
        split for split in _parse_splits(args.splits)
        if _run(_build_evaluation_command(args, split)) != 0
    ]
    if failed_splits:
        sys.exit('Evaluation failed for split(s): %s' % ', '.join(failed_splits))


