import argparse

from ddm4signal.experiment import Experiment

def search_space(trial):
    """Sample one set of hyper-parameters from ``trial``.

    ``trial`` is presumably an ``optuna.trial.Trial``; only its
    ``suggest_int``, ``suggest_categorical`` and ``suggest_float`` methods
    are used. The returned keys have the same names as the command-line
    options ``--timestep``, ``--max_step``, ``--ratio``, ``--min_noise`` and
    ``--max_noise``.

    Args:
        trial: Trial object used to draw each value.

    Returns:
        dict mapping each parameter name to its sampled value.
    """
    return {
        "timestep": trial.suggest_int("timestep", 1, 100),
        "max_step": trial.suggest_categorical("max_step", [1000, 3000, 5000]),
        "ratio": trial.suggest_float("ratio", 0.1, 1.0),
        # Both noise bounds span several orders of magnitude. With a uniform
        # draw, ~90% of samples would land in the top decade of each range,
        # so sample them log-uniformly to explore every scale evenly.
        "min_noise": trial.suggest_float("min_noise", 1e-9, 1e-6, log=True),
        "max_noise": trial.suggest_float("max_noise", 1e-5, 1e-1, log=True),
    }

def main(argv=None):
    """Parse command-line options, build an ``Experiment`` and run it.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) reads the real command line.

    Malformed ``--gpu`` or ``--dataset`` values are reported through
    ``parser.error``, which prints the usage message and exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="signal", type=str, help="name of models")
    parser.add_argument("--task", "-t", default="train", type=str, help="name of task")
    parser.add_argument("--dataset", "-d", default="RML2016.10a", type=str, help="list of datasets separated by #")
    parser.add_argument("--gpu", "-g", default="0", type=str, help="GPU id(s) separated by #; -1 means cpu")
    parser.add_argument("--load_from_pretrained", action="store_true", help="load model from the checkpoint")
    parser.add_argument("--compile", action="store_true", help="compile model")

    parser.add_argument("--search", action="store_true",
                        help="pass search_space() to the experiment as its HPO search space")
    parser.add_argument("--hpo_trials", default=10, type=int)

    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--length", default=1024, type=int)
    parser.add_argument("--num_layers", default=4, type=int)
    parser.add_argument("--num_classes", default=11, type=int)
    parser.add_argument("--max_step", default=3000, type=int)
    parser.add_argument("--timestep", default=4, type=int)
    parser.add_argument("--ratio", default=0.414, type=float)
    parser.add_argument("--min_noise", default=5.45e-6, type=float)
    parser.add_argument("--max_noise", default=0.0072, type=float)
    # Previously hard-coded to 1; the default keeps the old behaviour.
    parser.add_argument("--evaluate_interval", default=1, type=int,
                        help="forwarded to Experiment as evaluate_interval")

    args = parser.parse_args(argv)

    # "#" separates multiple datasets; drop empty names produced by stray
    # separators (e.g. a trailing "#").
    dataset = [name for name in args.dataset.split("#") if name]
    if not dataset:
        parser.error("--dataset must name at least one dataset")

    # Report non-integer ids as a CLI usage error instead of a raw traceback.
    try:
        gpu = [int(x) for x in args.gpu.split("#")]
    except ValueError:
        parser.error(f"--gpu expects integers separated by #, got {args.gpu!r}")
    if len(gpu) == 1:
        # A single id is forwarded as a plain int. This is the only form in
        # which a negative value (-1 = cpu) is accepted.
        gpu = gpu[0]
    elif any(x < 0 for x in gpu):
        parser.error("Negative numbers should not appear in the GPU list!")

    # Only hand the search-space function to the experiment when --search is set.
    space = search_space if args.search else None

    experiment = Experiment(model=args.model, dataset=dataset, task=args.task, gpu=gpu, batch_size=args.batch_size,
                            load_from_pretrained=args.load_from_pretrained, compile_flag=args.compile,
                            evaluate_interval=args.evaluate_interval, max_step=args.max_step, ratio=args.ratio,
                            signal_length=args.length, num_classes=args.num_classes, min_noise=args.min_noise,
                            max_noise=args.max_noise, timestep=args.timestep, hpo_search_space=space,
                            hpo_trials=args.hpo_trials, num_layers=args.num_layers)

    experiment.run()


if __name__ == "__main__":
    main()