import os

import jax
import json

from algorithms.adqn.adqn import adqn
from algorithms.adqn.adqn import adqn
from algorithms.adqn.utils.adqn_parser import parse_args
from environments.lunar_lander import LunarLander


def comparison_run(args, t_key, env):
    """Run one A-DQN / DQN training configuration and store its outputs.

    The output directory is chosen from ``args``:

    * ``len(args.hidden_layers) > 1`` -> ``runs/<args.path>/adqn/<args.criterion>``.
      In this case the whole argument namespace is also dumped to
      ``runs/<args.path>/params.json``.
    * otherwise -> ``runs/<args.path>/dqn/<hidden_layers[0]>_<m_seeds[0]>``.

    Args:
        args: Parsed argument namespace. Every attribute is forwarded to
            ``adqn`` as a keyword argument. It must not itself define
            ``save_path``; if it did, ``adqn`` would receive that keyword twice.
        t_key: JAX PRNG key passed through to ``adqn``.
        env: Environment instance passed through to ``adqn``.

    Returns:
        Whatever ``adqn`` returns. The original version discarded this value.
    """
    run_path = os.path.join("runs", args.path)

    # NOTE(review): more than one hidden-layer entry is what selects the A-DQN
    # layout over the single-architecture DQN layout — confirm against parse_args.
    is_adqn = len(args.hidden_layers) > 1

    if is_adqn:
        hist_path = os.path.join(run_path, "adqn", str(args.criterion))
    else:
        hist_path = os.path.join(
            run_path, "dqn", f"{args.hidden_layers[0]}_{args.m_seeds[0]}"
        )
    # hist_path lives under run_path, so this also creates run_path for params.json.
    os.makedirs(hist_path, exist_ok=True)

    if is_adqn:
        # Stored at run level (shared by all criterion subfolders). Every A-DQN
        # run with the same args.path therefore overwrites this file.
        with open(os.path.join(run_path, "params.json"), "w", encoding="utf-8") as fp:
            json.dump(vars(args), fp, indent=4)

    hist = adqn(
        env,
        t_key,
        **vars(args),
        save_path=hist_path,
    )
    return hist


if __name__ == "__main__":
    # Load the default hyper-parameters shipped with the experiment and hand
    # them to the CLI parser as its defaults.
    with open("experiments/lunar_lander_params.json", "r") as params_file:
        default_params = json.load(params_file)
    args = parse_args(default_args=default_params)

    # Derive two keys from the seed. The second one seeds the environment; the
    # first one is passed on to the training run.
    train_key, env_key = jax.random.split(jax.random.PRNGKey(args.t_seed))
    env = LunarLander(env_key)

    # Echo the effective configuration before training starts.
    print(json.dumps(vars(args), indent=4))

    comparison_run(args, train_key, env)
