import argparse
import copy
import json
import random

from trainer import train

# --- Search space for the random hyperparameter search -------------------
# Each list is a discrete grid; one value is drawn uniformly per trial.

# Main training schedule.
ep_arr = [30, 70, 120, 160, 200]          # total training epochs
milestone_arr = [2, 3, 4]                 # number of LR-decay milestones

# Optimizer settings.
lr_arr = [0.05, 0.1, 0.15, 0.2, 0.3]
lr_decay_arr = [0.1, 0.3, 0.5]
batch_arr = [32, 64, 128, 256, 512]
w_decay_arr = [0.0001, 0.0005, 0.001, 0.005]
scheduler_arr = ['steplr', 'cosine']

# Method-specific settings (meaning is defined by the trainer).
lambda_c_arr = [1, 3, 5, 7, 9]
lambda_f_arr = [0.5, 1, 1.5, 2, 3]
nb_proxy_arr = [10, 20, 30, 50, 100]
ft_epochs_arr = [5, 10, 20, 30, 50]       # fine-tuning epochs
ft_lrate_arr = [0.001, 0.003, 0.005, 0.007, 0.01]  # fine-tuning learning rate

# Seeds handed to the trainer and the number of random trials to run.
seed_arr = [0, 1, 2]
total_rand_num = 20

def _epoch_milestones(total_epochs, num_milestones):
    """Return ``num_milestones`` evenly spaced milestone epochs.

    Milestone ``k`` (1-based) is ``int(total_epochs * (2k / (2n + 1)))``
    with ``n = num_milestones``. For n = 2, 3, 4 this yields the fractions
    2/5, 4/5; 2/7, 4/7, 6/7; and 2/9, 4/9, 6/9, 8/9 respectively (the
    same float arithmetic as the original hand-written tables), and it
    works for any other milestone count as well.

    Args:
        total_epochs: Length of the schedule in epochs.
        num_milestones: How many milestones to place (0 yields ``[]``).

    Returns:
        List of ``num_milestones`` integer epochs.
    """
    denominator = 2 * num_milestones + 1
    return [
        int(total_epochs * (2 * k / denominator))
        for k in range(1, num_milestones + 1)
    ]


def _sample_trial_params(rand_num):
    """Sample one configuration uniformly from the module-level grids.

    Args:
        rand_num: Index of the trial; only used to build the ``prefix``.

    Returns:
        Dict of training parameters to merge over the base configuration.
    """
    ep = random.choice(ep_arr)
    milestone_num = random.choice(milestone_arr)
    lr = random.choice(lr_arr)
    lr_decay = random.choice(lr_decay_arr)
    batch = random.choice(batch_arr)
    w_decay = random.choice(w_decay_arr)
    scheduler = random.choice(scheduler_arr)

    lambda_c = random.choice(lambda_c_arr)
    lambda_f = random.choice(lambda_f_arr)
    nb_proxy = random.choice(nb_proxy_arr)
    ft_epochs = random.choice(ft_epochs_arr)
    ft_lrate = random.choice(ft_lrate_arr)

    adaptive_factor = random.random() > 0.5  # fair coin flip

    # Run name encoding every sampled value.
    prefix = "rand_num_{}_ep_{}_milestone_{}_lr_{}_lr_decay_{}_batch_{}_w_decay_{}_scheduler_{}_lambda_c_{}_lambda_f_{}_nb_proxy_{}_ft_epochs_{}_ft_lrate_{}_adaptive_factor_{}".format(
        rand_num,
        ep,
        milestone_num,
        lr,
        lr_decay,
        batch,
        w_decay,
        scheduler,
        lambda_c,
        lambda_f,
        nb_proxy,
        ft_epochs,
        ft_lrate,
        adaptive_factor,
    )

    return {
        # NOTE(review): the whole seed list is passed, not a single sampled
        # seed; presumably train() runs once per seed -- confirm against
        # trainer. A copy is passed so the module-level list cannot be mutated.
        "seed": list(seed_arr),
        "prefix": prefix,
        "epochs": ep,
        "lrate": lr,
        "milestones": _epoch_milestones(ep, milestone_num),
        "lrate_decay": lr_decay,
        "batch_size": batch,
        "weight_decay": w_decay,
        "scheduler": scheduler,
        "lambda_c": lambda_c,
        "lambda_f": lambda_f,
        "nb_proxy": nb_proxy,
        "ft_epochs": ft_epochs,
        "ft_lrate": ft_lrate,
        # Fine-tuning uses the same number of milestones as main training.
        "ft_milestones": _epoch_milestones(ft_epochs, milestone_num),
        "adaptive_factor": adaptive_factor,
    }


def main():
    """Run ``total_rand_num`` random-search trials through ``train``.

    The base configuration is the parsed command line overlaid with the
    JSON file named by ``--config``. For each trial a configuration is
    sampled from the module-level grids, merged over a fresh copy of the
    base configuration, printed, and passed to ``train``.
    """
    args = setup_parser().parse_args()
    param = load_json(args.config)
    base_args = vars(args)  # Convert the argparse Namespace to a dict.
    base_args.update(param)  # Add parameters from json (they override CLI).

    for rand_num in range(total_rand_num):
        # Each trial gets its own deep copy so that anything train() writes
        # into its args dict (or nested lists in it) cannot leak into later
        # trials.
        trial_args = copy.deepcopy(base_args)
        trial_args.update(_sample_trial_params(rand_num))  # sampled values win

        print(trial_args)

        train(trial_args)
                        
    


def load_json(settings_path):
    """Load and return the JSON document stored at ``settings_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses)
    instead of the platform's locale-dependent default encoding.

    Args:
        settings_path: Path to the JSON settings file.

    Returns:
        The decoded JSON value (a dict for a typical settings file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)


def setup_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with a single ``--config`` option
        naming the JSON file of base settings
        (default ``./exps/finetune.json``).
    """
    parser = argparse.ArgumentParser(
        description='Reproduction of multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, default='./exps/finetune.json',
                        help='Json file of settings.')

    return parser


# Run the random search only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
