
import os
import yaml
import json
from itertools import product

from train import train
from evaluate import evaluate


# Driver script: for every enabled entry in experiments.yaml, train and
# evaluate one model per combination of the entry's parameter values and
# write the collected metrics to results/<experiment name>/results.json.

# Keys removed from config["model"]["params"] when the swept model name is
# GCN or GraphSAGE (the script treats them as attention-model-only settings).
_ATTENTION_ONLY_PARAMS = ("heads", "residual", "v2")
_MODELS_WITHOUT_ATTENTION = ("GCN", "GraphSAGE")


def _load_yaml(path):
    """Load a YAML file with ``yaml.FullLoader``, closing the file afterwards."""
    with open(path, "r") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _write_results(res, path):
    """Write the results collected so far to ``path`` as indented JSON."""
    with open(path, "w") as f:
        json.dump(res, f, indent=4)


def _build_config(experiment_name, res_dir, overrides):
    """Return a fresh copy of base_config.yaml with ``overrides`` applied.

    Args:
        experiment_name: the experiment's ``name``; ``"embedding_dim"`` turns on
            deriving ``hidden_channels`` from the swept ``out_channels``.
        res_dir: stored as ``config["save_folder"]``.
        overrides: ``(path, value)`` pairs, where ``path`` is a "/"-separated
            key path into the config (e.g. ``"model/params/out_channels"``).

    Raises:
        KeyError: if an intermediate key of an override path is missing from
            the base config.
    """
    # Re-read the file for every run so no mutation leaks between runs.
    config = _load_yaml("base_config.yaml")
    config["save_folder"] = res_dir

    for path, value in overrides:
        *parents, leaf = path.split("/")
        node = config
        for key in parents:
            node = node[key]

        if path == "positional_encoding/num_landmarks":
            # A non-positive landmark count leaves the base positional-encoding
            # settings untouched.
            if value > 0:
                config["positional_encoding"]["use"] = True
                node[leaf] = value
        else:
            node[leaf] = value

    # Derived settings are applied only after every override is in place, so
    # the outcome no longer depends on the key order in experiments.yaml.
    applied = dict(overrides)

    if experiment_name == "embedding_dim" and "model/params/out_channels" in applied:
        params = config["model"]["params"]
        # Done before the attention-only keys below may be removed, because it
        # reads "heads".
        params["hidden_channels"] = int(
            applied["model/params/out_channels"] * 2 // params["heads"]
        )

    if "model/name" in applied and applied["model/name"] in _MODELS_WITHOUT_ATTENTION:
        params = config["model"]["params"]
        for key in _ATTENTION_ONLY_PARAMS:
            # pop() tolerates a base config that already lacks the key.
            params.pop(key, None)

    return config


def _run_experiment(experiment):
    """Train and evaluate every parameter combination of one experiment entry.

    Results are keyed by a readable "leaf = value _ leaf = value" label and
    saved to results/<experiment name>/results.json.
    """
    res_dir = os.path.join("results", experiment["name"])
    os.makedirs(res_dir, exist_ok=True)
    results_path = os.path.join(res_dir, "results.json")

    # Mapping of config path -> list of values to sweep. An empty or missing
    # mapping produces a single run on the unmodified base config.
    varying_parameters = experiment.get("parameters") or {}
    parameters_names = tuple(varying_parameters)
    parameters_values = tuple(varying_parameters.values())

    res = {}
    for parameters in product(*parameters_values):
        overrides = tuple(zip(parameters_names, parameters))
        name = " _ ".join(f"{n.split('/')[-1]} = {p}" for n, p in overrides)
        print(name)

        config = _build_config(experiment["name"], res_dir, overrides)
        print(config)

        train(config, verbose=False)
        res[name] = evaluate(config)
        # Rewrite after every run so results finished so far are on disk even
        # if a later run fails.
        _write_results(res, results_path)

    # Also covers sweeps with zero combinations (some value list is empty), so
    # the results file always exists.
    _write_results(res, results_path)


os.makedirs("results", exist_ok=True)
experiments = _load_yaml("experiments.yaml")

for experiment in experiments:
    if experiment["do"]:
        _run_experiment(experiment)

