import argparse
import os
import shlex
import subprocess

def _build_parser():
    """Create the command-line parser for this training launcher.

    Every option is a single-valued flag with a help string, a type and a
    default; they are registered in the order listed below.
    """
    cli = argparse.ArgumentParser()
    # (flag, help text, value type, default)
    option_table = [
        ("--method", "the method to use", str, 'base'),
        ('--device', 'what device to perform training on', str, 'cuda:0'),
        ('--dataset', 'what dataset to perform training on', str, 'all'),
        ("--kurtosis_weight", "weight for the kurtosis term", float, 0.1),
        ("--pareto_weight", "weight for the pareto margin loss", float, 1.0),
        ("--pareto_xi", "GPD fit parameter (shape)", float, 0.1),
        ("--pareto_sigma", "GPD fit parameter (scale)", float, 1.0),
        ("--model_name", "name to use for saved models", str, 'model'),
    ]
    for flag, help_text, value_type, default_value in option_table:
        cli.add_argument(flag, help=help_text, type=value_type, default=default_value)
    return cli


parser = _build_parser()

# Dataset names other than 'nuscenes'; presumably the set the 'all' default of
# --dataset is meant to expand to (TODO: confirm with the caller).
datasets = ["hotel", "univ", "zara1", "zara2", "eth"]

def _method_flags(method, args):
    """Return the method-specific CLI flags forwarded to train.py (list of str).

    'base' adds nothing; any other method adds ``--<method>`` plus, for the
    known loss variants, their hyper-parameters read from ``args``.
    """
    if method == "base":
        return []
    flags = ["--" + method]
    if method == "kurtosis":
        flags += ["--kurtosis_weight", str(args.kurtosis_weight)]
    elif method in ("pareto_weighted", "pareto_margin"):
        flags += [
            "--pareto_weight", str(args.pareto_weight),
            "--pareto_xi", str(args.pareto_xi),
            "--pareto_sigma", str(args.pareto_sigma),
        ]
    return flags


def _build_command(method, device, dataset, args):
    """Return the argv list that launches train.py for one method/dataset run."""
    data_dir = "./Trajectron_plus_plus/experiments/processed"
    flags = _method_flags(method, args)
    if dataset != "nuscenes":
        # NOTE(review): the base method names its model "base_<dataset>_EWTA"
        # and ignores --model_name; preserved because saved-model paths may
        # depend on it.
        prefix = method if method == "base" else str(args.model_name)
        return [
            "python", "train.py",
            "--save_every", "50",
            "--train_data_dict", dataset + "_train.pkl",
            "--eval_data_dict", dataset + "_val.pkl",
            "--offline_scene_graph", "yes",
            "--preprocess_workers", "5",
            "--log_dir", "./models",
            "--log_tag", "_" + dataset + "_EWTA",
            "--train_epochs", "500",
            "--augment",
            "--batch_size", "512",
            "--conf", "./models/" + dataset + "_EWTA/config.json",
            "--data_dir", data_dir,
            *flags,
            "--model_name", prefix + "_" + dataset + "_EWTA",
            "--device", str(device),
        ]
    # nuScenes: the base run uses 5 preprocessing workers, all other methods 10.
    workers = "5" if method == "base" else "10"
    return [
        "python", "train.py",
        "--device", str(device),
        "--model_name", str(args.model_name),
        *flags,
        "--conf", "./models/nuScenes_EWTA/config.json",
        "--data_dir", data_dir,
        "--train_data_dict", "nuScenes_train_full.pkl",
        "--offline_scene_graph", "yes",
        "--preprocess_workers", workers,
        "--batch_size", "1024",
        "--log_dir", "./models",
        "--train_epochs", "25",
        "--node_freq_mult_train",
        "--log_tag", "_int_ee_me",
        "--map_encoding",
        "--augment",
    ]


def train(method, device, dataset, args):
    """Launch one train.py run for ``method`` on ``dataset``.

    Args:
        method: training method; 'base' adds no extra flags, any other value
            is forwarded as ``--<method>`` (plus hyper-parameters for
            'kurtosis', 'pareto_weighted' and 'pareto_margin').
        device: device string passed to ``--device``.
        dataset: dataset name; 'nuscenes' selects the nuScenes configuration,
            anything else is used to build ``<dataset>_train.pkl`` etc.
        args: parsed CLI namespace providing model_name and the loss weights.

    Returns:
        The exit status of the train.py process (0 on success).
    """
    command = _build_command(method, device, dataset, args)
    print(shlex.join(command))
    # Run without a shell so names containing spaces/metacharacters stay one argument.
    return subprocess.run(command, check=False).returncode

def main():
    """Parse CLI arguments and launch training on one dataset or on all of them.

    ``--dataset all`` (the parser default) trains on every name in the
    module-level ``datasets`` list in turn; any other value is passed to
    ``train`` unchanged. Exits with a non-zero status if any run reports failure.
    """
    args = parser.parse_args()
    # Previously 'all' was forwarded verbatim, producing a nonexistent all_train.pkl.
    targets = datasets if args.dataset == "all" else [args.dataset]
    failed = []
    for dataset in targets:
        # A truthy return (non-zero exit status) marks the run as failed;
        # a None return is treated as success.
        if train(args.method, args.device, dataset, args):
            failed.append(dataset)
    if failed:
        raise SystemExit("training failed for: " + ", ".join(failed))


if __name__ == '__main__':
    main()