import os

import jax
import json

from algorithms.adqn.adqn import adqn
from algorithms.adqn.adqn import adqn
from algorithms.adqn.utils.adqn_parser import parse_args
from environments.lunar_lander import LunarLander


def debug(args, t_key, env):
    """Run one debug training job, saving its parameters under ``runs/debug``.

    The run directory is ``runs/debug/<args.path>``. The parsed arguments are
    written to ``params.json`` in that directory. Then ``adqn`` is called with
    every parsed argument plus a ``save_path`` that points at a
    per-configuration history folder:

    * more than one entry in ``args.hidden_layers`` -> ``<run>/adqn/<criterion>``
    * otherwise -> ``<run>/dqn/<hidden_layers[0]>_<m_seeds[0]>``

    Args:
        args: Namespace produced by ``parse_args``. Every attribute is
            forwarded to ``adqn`` as a keyword argument.
        t_key: JAX PRNG key handed to ``adqn``.
        env: Environment instance handed to ``adqn``.
    """
    run_path = os.path.join("runs", "debug", args.path)

    if len(args.hidden_layers) > 1:
        hist_path = os.path.join(run_path, "adqn", str(args.criterion))
    else:
        # NOTE(review): str() of the first hidden_layers entry is used verbatim
        # as part of a directory name; confirm its formatting is path-safe.
        hist_path = os.path.join(
            run_path, "dqn", f"{str(args.hidden_layers[0])}_{args.m_seeds[0]}"
        )
    # hist_path is nested inside run_path, so this also creates run_path,
    # which params.json below is written into.
    os.makedirs(hist_path, exist_ok=True)

    params_file = os.path.join(run_path, "params.json")
    with open(params_file, "w", encoding="utf-8") as fp:
        json.dump(vars(args), fp, indent=4)

    # The return value of adqn is not used here. Presumably adqn writes its
    # outputs under save_path -- confirm against the adqn implementation.
    adqn(
        env,
        t_key,
        **vars(args),
        save_path=hist_path,
    )
    print("done")


if __name__ == "__main__":
    # Base hyperparameters come from the JSON file; the hard-coded command
    # line below is parsed on top of them by parse_args.
    with open("experiments/lunar_lander_debug_params.json", "r") as params_fp:
        base_params = json.load(params_fp)

    debug_cli = "-hl 100-100 100-100 -s 1 -ms 0 1 --criterion random -p test"
    args = parse_args(debug_cli.split(), base_params)

    # Split the seed key once: the second half seeds the environment and the
    # first half is passed on to the training run.
    train_key, env_key = jax.random.split(jax.random.PRNGKey(args.t_seed))
    env = LunarLander(env_key)

    print(json.dumps(vars(args), indent=4))

    debug(args, train_key, env)
