import argparse
import copy
import json
import random

from trainer import train

# Random-search space. main() draws one value from each *_arr list with
# random.choice for every trial; a single-element list pins that setting.

# --- Base training schedule / optimizer ---
ep_arr = [170]           # training epochs (also used as boosting epochs)
lr_arr = [0.1]           # initial learning rate
lr_decay_arr = [0.1]     # learning-rate decay factor
batch_arr = [128]        # batch size
w_decay_arr = [5e-4]     # weight decay
scheduler_arr = ['cosine']

# --- Distillation / compression settings ---
T_arr = [2]
lambda_kd_arr = [1]
lambda_fe_arr = [1]
beta1_arr = [0.96]
beta2_arr = [0.97]
comp_ep_arr = [130]      # compression-stage epochs

# --- Run control ---
seed_arr = list(range(5))  # seeds 0..4; main() passes the whole list on
total_rand_num = 1         # number of random-search trials main() runs

def _milestones(epochs):
    """Return LR milestones at 2/5 and 4/5 of ``epochs``, truncated to int."""
    return [int(epochs * (2 / 5)), int(epochs * (4 / 5))]


def main():
    """Run ``total_rand_num`` random-search trials and train once per trial.

    The base configuration is the parsed command line (``--config``) merged
    with the JSON file it points to; JSON values override CLI values. For
    every trial one value is drawn from each module-level ``*_arr`` list, the
    LR milestones and an experiment ``prefix`` are derived from the draws,
    and ``train`` is called with a fresh copy of the base configuration
    overlaid by those sampled values.
    """
    base_args = vars(setup_parser().parse_args())  # argparse Namespace -> dict
    base_args.update(load_json(base_args["config"]))  # JSON overrides CLI

    for _ in range(total_rand_num):
        # One independent draw per search dimension.
        ep = random.choice(ep_arr)
        lr = random.choice(lr_arr)
        lr_decay = random.choice(lr_decay_arr)
        batch = random.choice(batch_arr)
        w_decay = random.choice(w_decay_arr)
        scheduler = random.choice(scheduler_arr)

        T = random.choice(T_arr)
        lambda_kd = random.choice(lambda_kd_arr)
        lambda_fe = random.choice(lambda_fe_arr)
        beta1 = random.choice(beta1_arr)
        beta2 = random.choice(beta2_arr)
        comp_ep = random.choice(comp_ep_arr)

        # Experiment name encoding every sampled hyperparameter.
        prefix = (
            "0_origin_ep_{}_lr_{}_lr_decay_{}_batch_{}_w_decay_{}_scheduler_{}"
            "_T_{}_lambda_kd_{}_fe_{}_beta_{}_{}_comp_ep_{}"
        ).format(ep, lr, lr_decay, batch, w_decay, scheduler,
                 T, lambda_kd, lambda_fe, beta1, beta2, comp_ep)

        parameters = {
            # NOTE(review): the whole seed list is passed, not a single seed;
            # presumably ``train`` iterates over it -- confirm in trainer.
            # A copy keeps the module-level list from being aliased.
            "seed": list(seed_arr),
            "prefix": prefix,
            "epochs": ep,
            "boosting_epochs": ep,
            "lrate": lr,
            "milestones": _milestones(ep),
            "lrate_decay": lr_decay,
            "batch_size": batch,
            "weight_decay": w_decay,
            "scheduler": scheduler,
            "T": T,
            "lambda_okd": lambda_kd,
            "lambda_fe": lambda_fe,
            "beta1": beta1,
            "beta2": beta2,
            "compression_epochs": comp_ep,
            "comp_milestones": _milestones(comp_ep),
        }

        # Fresh deep copy per trial: anything ``train`` (or a previous trial)
        # writes into the config must not leak into later trials.
        trial_args = copy.deepcopy(base_args)
        trial_args.update(parameters)  # sampled hyperparameters override config

        print(trial_args)

        train(trial_args)




def load_json(settings_path):
    """Read and decode a JSON settings file.

    Args:
        settings_path: Path (str or path-like) of the JSON file to read.

    Returns:
        The decoded JSON document. ``main()`` merges it into a dict with
        ``dict.update``, so config files are expected to hold a JSON object.
    """
    # JSON is defined as UTF-8 (RFC 8259); don't depend on the locale default.
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)


def setup_parser():
    """Build the command-line parser for this hyperparameter-sweep script.

    Returns:
        argparse.ArgumentParser: parser with a single ``--config`` option,
        the path of the JSON settings file (default ``./exps/finetune.json``).
    """
    parser = argparse.ArgumentParser(
        description='Reproduce multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, default='./exps/finetune.json',
                        help='Json file of settings.')

    return parser


if __name__ == "__main__":
    # Launch the sweep only when run as a script, never on import.
    main()
