import argparse
import os.path as osp

from ddm4signal.experiment import Experiment

def search_space(trial):
    """Sample one hyper-parameter configuration from ``trial``.

    Args:
        trial: Object exposing ``suggest_int``, ``suggest_categorical`` and
            ``suggest_float`` (presumably an Optuna trial -- confirm against
            the HPO code inside ``Experiment``).

    Returns:
        dict: The sampled values keyed by ``"timestep"``, ``"max_step"``,
        ``"min_noise"`` and ``"max_noise"``, in that order.
    """
    # The samples are drawn in the same order as the keys of the result.
    timestep = trial.suggest_int("timestep", 1, 100)
    max_step = trial.suggest_categorical("max_step", [1000, 3000, 5000])
    # "ratio1" is currently excluded from the search; it was previously
    # sampled with trial.suggest_float("ratio1", 0.1, 2.0).
    # NOTE(review): both noise ranges span several orders of magnitude but
    # are sampled on a linear scale -- confirm this is intended.
    min_noise = trial.suggest_float('min_noise', 1e-10, 1e-6)
    max_noise = trial.suggest_float('max_noise', 1e-6, 1e-1)

    return dict(
        timestep=timestep,
        max_step=max_step,
        min_noise=min_noise,
        max_noise=max_noise,
    )

from ddm4signal.experiment import Experiment
def build_parser():
    """Build the command-line parser for a single experiment run.

    Returns:
        argparse.ArgumentParser: Parser exposing the model/task/dataset
        selection, device selection, HPO switches and the numeric
        hyper-parameters forwarded to ``Experiment``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="signal", type=str, help="name of models")
    parser.add_argument("--task", "-t", default="signal", type=str, help="name of task")
    parser.add_argument("--dataset", "-d", default="RML2016.10a", type=str, help="list of datasets splited with #")
    parser.add_argument("--gpu", "-g", default="0", type=str, help="-1 means cpu")
    parser.add_argument("--load_from_pretrained", action="store_true", help="load model from the checkpoint")
    parser.add_argument("--compile", action="store_true", help="compile model")
    parser.add_argument("--search", action="store_true")
    parser.add_argument("--hpo_trials", default=15, type=int)
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--num_classes", default=11, type=int)
    parser.add_argument("--max_step", default=1000, type=int)
    parser.add_argument("--timestep", default=30, type=int)
    parser.add_argument("--ratio1", default=1.04, type=float)
    parser.add_argument("--min_noise", default=4.28e-7, type=float)
    parser.add_argument("--max_noise", default=0.01189, type=float)
    return parser


def parse_gpu(spec):
    """Convert the ``--gpu`` string into the value handed to ``Experiment``.

    Args:
        spec: One integer, or several integers separated by ``#``
            (for example ``"0"`` or ``"0#1#2"``).

    Returns:
        int | list[int]: A bare ``int`` when exactly one id is given,
        otherwise the list of ids.

    Raises:
        ValueError: If a token is not an integer, or if a list of several
            ids contains a negative number.
    """
    try:
        ids = [int(token) for token in spec.split("#")]
    except ValueError as err:
        # Name the offending spec instead of surfacing the bare int() error.
        raise ValueError(
            f"Invalid GPU specification {spec!r}: expected integers separated by '#'"
        ) from err

    # A single id is passed through unchecked so that "-1" (documented in the
    # --gpu help as meaning cpu) remains a valid choice.
    if len(ids) == 1:
        return ids[0]
    if any(x < 0 for x in ids):
        raise ValueError("Negative numbers should not appear in the GPU list!")
    return ids


def main(argv=None):
    """Parse the command line, build the ``Experiment`` and run it.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``, which is the behaviour when run as a script.

    Raises:
        ValueError: Propagated from ``parse_gpu`` for a malformed ``--gpu``.
    """
    args = build_parser().parse_args(argv)

    dataset = args.dataset.split("#")
    gpu = parse_gpu(args.gpu)
    # The search-space callback is only supplied when --search is given.
    space = search_space if args.search else None

    experiment = Experiment(
        model=args.model,
        dataset=dataset,
        task=args.task,
        gpu=gpu,
        batch_size=args.batch_size,
        load_from_pretrained=args.load_from_pretrained,
        compile_flag=args.compile,
        evaluate_interval=1,  # fixed here; not exposed on the command line
        max_step=args.max_step,
        ratio1=args.ratio1,
        signal_length=args.length,
        num_classes=args.num_classes,
        min_noise=args.min_noise,
        max_noise=args.max_noise,
        timestep=args.timestep,
        hpo_search_space=space,
        hpo_trials=args.hpo_trials,
    )
    experiment.run()


if __name__ == "__main__":
    main()