import json
import os
from pathlib import Path

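# SGD
# This script writes one JSON experiment config per
# (sample kernel, dataset, horizon) combination into a
# configs_<dataset>/ folder next to this file.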

def export(name,
           results_dir,
           dataset,
           seeds,
           data_norm,
           horizon,
           n_hidden,
           sample_kernel,
           mu_u,
           xi,
           size_init_batch,
           batch_size,
           initial_lr,
           num_epochs,
           patience,
           min_delta,
           final_lr,
           lambda_reg,
           *,
           output_dir=None,
           ):
    """Write one experiment configuration to ``<output_dir>/<name>.json``.

    The arguments are grouped into a nested JSON object with the top-level
    keys ``name``, ``results_dir``, ``data``, ``model`` and ``train``.

    Args:
        name: Experiment name; also used as the JSON file stem.
        results_dir: Stored verbatim under the top-level ``results_dir`` key.
        dataset, horizon, seeds, data_norm: Stored under ``data``.
        n_hidden, sample_kernel, mu_u, xi: Stored under ``model``.
        size_init_batch, batch_size, initial_lr, num_epochs, patience,
        min_delta, final_lr, lambda_reg: Stored under ``train``.
        output_dir: Directory to write the file into (str or Path). When
            omitted, defaults to ``configs_<dataset>`` beside this script,
            which was the only location supported previously. The
            directory (and any missing parents) is created if needed.

    Returns:
        pathlib.Path of the JSON file that was written (overwritten if it
        already existed).
    """
    if output_dir is None:
        # Anchor on the script's own location rather than the current
        # working directory, so the output location is stable.
        experiments_dir = Path(__file__).resolve().parent / f"configs_{dataset}"
    else:
        experiments_dir = Path(output_dir)
    experiments_dir.mkdir(parents=True, exist_ok=True)

    # Key order matches the previously generated files exactly.
    config = {
        "name": name,
        "results_dir": results_dir,
        "data": {
            "dataset": dataset,
            "horizon": horizon,
            "seeds": seeds,
            "data_norm": data_norm,
        },
        "model": {
            "n_hidden": n_hidden,
            "sample_kernel": sample_kernel,
            "mu_u": mu_u,
            "xi": xi,
        },
        "train": {
            "size_init_batch": size_init_batch,
            "batch_size": batch_size,
            "initial_lr": initial_lr,
            "num_epochs": num_epochs,
            "patience": patience,
            "min_delta": min_delta,
            "final_lr": final_lr,
            "lambda_reg": lambda_reg,
        },
    }

    config_path = experiments_dir / f"{name}.json"
    # json.dump escapes non-ASCII by default, so the encoding only makes the
    # text-mode open platform-independent; the bytes written are unchanged.
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f)
    return config_path

RESULTS_DIR = "results"

# Settings common to every generated config. Only the sample kernel, the
# dataset and the forecasting horizon vary across the grid below.
# NOTE(review): num_epochs=1 alongside patience=30 looks like a debug value —
# confirm before launching the full experiment sweep.
_COMMON_SETTINGS = {
    "results_dir": RESULTS_DIR,
    "data_norm": "zero_one",
    "n_hidden": 750,
    "mu_u": 0.5,
    "xi": 1,
    "size_init_batch": 1000,
    "batch_size": 64,
    "initial_lr": 5e-4,
    "num_epochs": 1,
    "patience": 30,
    "min_delta": 1e-6,
    "final_lr": 1e-05,
    "lambda_reg": 1e-04,
}

# Full grid, kernel-major then dataset then horizon (same order as before).
_EXPERIMENT_GRID = [
    (kernel, ds, h)
    for kernel in ["hat", "morlet"]
    for ds in ["solar", "electricity", "metr-la", "pems-bay"]
    for h in [6, 24, 48, 96]
]

for sample_kernel, dataset, horizon in _EXPERIMENT_GRID:
    export(
        name=f"{dataset}_{sample_kernel}_{horizon}",
        dataset=dataset,
        seeds=[0, 42, 100],
        horizon=horizon,
        sample_kernel=sample_kernel,
        **_COMMON_SETTINGS,
    )
