import json
import os
from pathlib import Path


# SSWIM

def export(name,
           dataset,
           seeds,
           horizon,
           n_hidden,
           sample_kernel,
           objective,
           normaliser,
           size_init_batch,
           batch_size,
           metrics,
           reg,
           ):
    """Write one experiment configuration to ``configs_metric/<name>.json``.

    The ``configs_metric`` directory lives next to this script (it is resolved
    from ``__file__``, not from the current working directory) and is created
    if missing. The written JSON object has the keys ``name``,
    ``results_dir`` (always ``"results_metric"``), ``data``, ``model`` and
    ``train``. Tuple values (e.g. ``normaliser`` or a pair of ``metrics``) are
    serialised as JSON arrays.

    Args:
        name: Experiment name; used as the file stem and stored as ``name``.
        dataset, horizon, seeds: Stored under ``config["data"]``.
        n_hidden, sample_kernel, objective, normaliser: Stored under
            ``config["model"]``.
        size_init_batch, batch_size, reg, metrics: Stored under
            ``config["train"]``.

    Returns:
        The ``Path`` of the JSON file that was written (an existing file with
        the same name is overwritten).
    """
    experiments_dir = Path(__file__).resolve().parent / "configs_metric"
    experiments_dir.mkdir(parents=True, exist_ok=True)

    # Key insertion order is kept as-is so the emitted JSON is unchanged.
    data_dict = {
        "dataset": dataset,
        "horizon": horizon,
        "seeds": seeds,
    }
    model_dict = {
        "n_hidden": n_hidden,
        "sample_kernel": sample_kernel,
        "objective": objective,
        "normaliser": normaliser,
    }
    train_dict = {
        "size_init_batch": size_init_batch,
        "batch_size": batch_size,
        "reg": reg,
        "metrics": metrics,
    }
    config = {
        "name": name,
        "results_dir": "results_metric",
        "data": data_dict,
        "model": model_dict,
        "train": train_dict,
    }

    config_path = experiments_dir / f"{name}.json"
    # json.dumps escapes non-ASCII by default, so the output is plain ASCII;
    # the explicit encoding just avoids relying on the platform default.
    config_path.write_text(json.dumps(config), encoding="utf-8")
    return config_path


# Sweep every (dataset, horizon, metric setting) combination, in the same
# order as three nested loops would. A metric setting is either the string
# "best" or a pair of metric names; its str() becomes part of the run name.
# NOTE(review): the existence check below is relative to the current working
# directory, while export() anchors configs to this script's directory —
# confirm the script is always run from its own folder. There is also no
# `if __name__ == "__main__":` guard, so importing this module runs the sweep.
for dataset, H, metrics in [
    (ds, horizon, metric_setting)
    for ds in ["solar", "metr-la", "electricity", "pems-bay"]
    for horizon in [6, 24, 48]
    for metric_setting in [
        "best",
        ("FourierMag", "FourierMag"),
        ("L2", "L2"),
        ("BandedFourierHigh", "BandedFourierHigh"),
        ("CosineDistance", "CosineDistance"),
        ("FourierAngle", "FourierAngle"),
        ("BandedFourierLow", "BandedFourierLow"),
    ]
]:
    name = f"{dataset}_{H}_{metrics}"

    # A results file for this run already exists: leave its config alone.
    if os.path.isfile(f"results_metric/{name}.json"):
        print(f"Skipping {name}")
        continue

    export(
        name=name,
        dataset=dataset,
        seeds=[0, 42, 100],
        horizon=H,
        sample_kernel="hat",
        normaliser=("MS", 0.5, 0.5),
        objective="dot",
        n_hidden=750,
        metrics=metrics,
        reg=-1,
        size_init_batch=1000,
        # The electricity dataset uses a smaller batch size than the others.
        batch_size=700 if dataset == "electricity" else 1000,
    )
