import json
import os
from pathlib import Path


# SSWIM ablation study: generate one JSON experiment config per hyper-parameter combination.

def export(name,
           dataset,
           seeds,
           horizon,
           n_hidden,
           sample_kernel,
           objective,
           normaliser,
           size_init_batch,
           batch_size,
           metrics,
           reg,
           *,
           results_dir="results_ablation",
           ):
    """Write one ablation experiment config to ``configs_ablation/<name>.json``.

    The ``configs_ablation`` directory is resolved next to this source file
    (not the current working directory) and is created if missing. An
    existing file with the same name is overwritten.

    Args:
        name: Experiment name; used both as the config's ``"name"`` field and
            as the JSON file stem.
        dataset, seeds, horizon: Stored under the ``"data"`` section.
        n_hidden, sample_kernel, objective, normaliser: Stored under the
            ``"model"`` section. Tuples (e.g. ``normaliser``) are serialised
            as JSON arrays by ``json.dump``.
        size_init_batch, batch_size, reg, metrics: Stored under the
            ``"train"`` section.
        results_dir: Value written to the config's ``"results_dir"`` field.
            Defaults to ``"results_ablation"``, the previously hard-coded value.

    Returns:
        The ``pathlib.Path`` of the JSON file that was written.
    """
    experiments_dir = Path(__file__).resolve().parent / "configs_ablation"
    experiments_dir.mkdir(parents=True, exist_ok=True)

    # Key insertion order is kept stable so the serialised JSON is unchanged.
    config = {
        "name": name,
        "results_dir": results_dir,
        "data": {
            "dataset": dataset,
            "horizon": horizon,
            "seeds": seeds,
        },
        "model": {
            "n_hidden": n_hidden,
            "sample_kernel": sample_kernel,
            "objective": objective,
            "normaliser": normaliser,
        },
        "train": {
            "size_init_batch": size_init_batch,
            "batch_size": batch_size,
            "reg": reg,
            "metrics": metrics,
        },
    }

    config_path = experiments_dir / f"{name}.json"
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(config, f)
    return config_path





# Sweep every (dataset, n_hidden, objective, normaliser) combination and write
# a config for each one lacking a results_ablation/<run_name>.json file.
# Run names embed str(normaliser), e.g. "electricity_50_random_('MS', 0.5, 0.5)".
# NOTE(review): the results lookup below is relative to the current working
# directory, whereas export() places configs next to this script via
# __file__ — confirm the script is always launched from its own directory.
for dataset, n_hidden, objective, normaliser in (
    (ds, width, obj, norm)
    for ds in ["electricity", "pems-bay"]
    for width in [50, 100, 150, 250, 350, 450, 550, 650, 750]
    for obj in ["random", "dot", "dist"]
    for norm in [("MS", 0.5, 0.5), ("MS", 0, 1), ("F", 0.5), ("F", 1)]
):
    run_name = f"{dataset}_{n_hidden}_{objective}_{normaliser}"
    if os.path.isfile(f"results_ablation/{run_name}.json"):
        print(f"Skipping {run_name}")
        continue
    export(
        name=run_name,
        dataset=dataset,
        seeds=[0, 42, 100],
        horizon=48,
        sample_kernel="hat",
        normaliser=normaliser,
        objective=objective,
        n_hidden=n_hidden,
        reg=1e-4,
        size_init_batch=1000,
        # Electricity runs use a smaller batch than the other datasets.
        batch_size=700 if dataset == "electricity" else 1000,
        metrics="best",
    )
