import os
import sys
import time

import gymnasium as gym
import jax.numpy as jnp
import optax
import wandb
from wandb.integration.sb3 import WandbCallback

from experiments.parser import parse
from sbx import SAC
from sbx.common.callbacks import EvalCallback

# Process-wide settings for the libraries imported above.
#   XLA_PYTHON_CLIENT_PREALLOCATE="false": documented JAX switch for GPU memory
#       preallocation (NOTE(review): presumably has to be set before the first
#       JAX computation initializes the backend; confirm).
#   WANDB_DIR="/tmp": directory Weights & Biases uses for its local run files.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "WANDB_DIR": "/tmp",
    }
)


def _run_subdir(args, started_at):
    """Return the ``<group>/seed_<seed>/time=<timestamp>`` fragment shared by both log trees."""
    return f"{args['group']}/seed_{args['seed']}/time={started_at}"


def _aggregator(name):
    """Map an aggregation name to a reduction: ``"min"`` -> ``jnp.min``, anything else -> ``jnp.mean``."""
    return jnp.min if name == "min" else jnp.mean


def run_cli(argvs=None):
    """Train a SAC agent from command-line style arguments.

    The arguments are parsed with ``experiments.parser.parse``; the result is
    used as a mapping (indexed by key and queried with ``.get``). A Weights &
    Biases run is opened for the duration of training and is disabled when no
    ``wandb_project`` is configured. TensorBoard logs go under ``logs/`` and
    evaluation results under ``./eval_logs/``, both in a sub-directory made of
    the group, the seed and the start timestamp.

    Args:
        argvs: Sequence of argument strings handed to ``parse``. When ``None``
            (the default), ``sys.argv[1:]`` is read at call time.
    """
    # Resolve the default lazily: a ``sys.argv[1:]`` default in the signature
    # would be evaluated once at definition time and then shared by every call.
    if argvs is None:
        argvs = sys.argv[1:]

    args = parse(argvs)
    experiment_time = time.time()
    run_subdir = _run_subdir(args, experiment_time)

    # Look the project up once with ``.get`` so a missing key disables W&B
    # instead of raising before the ``mode`` expression is reached.
    wandb_project = args.get("wandb_project")

    with wandb.init(
        entity="YOUR_ENTITY",
        project=wandb_project,
        name=f"seed={args['seed']}",
        group=args["group"],
        tags=[],
        sync_tensorboard=True,
        config=args,
        mode="online" if wandb_project is not None else "disabled",
    ):
        training_env = gym.make(args["env"])
        eval_env = None
        try:
            model = SAC(
                "MlpPolicy",
                training_env,
                learning_rate=0.001,
                qf_learning_rate=args["learning_rates"],
                learning_starts=args["n_initial_samples"],
                tensorboard_log=f"logs/{run_subdir}/",
                gradient_steps=args["utd"],
                policy_delay=args["utd"],
                random_target_qf=args["random_target_qf"],
                all_policy_qf=args["all_policy_qf"],
                policy_kwargs={
                    "net_arch": {"pi": [256, 256], "qf": args["net_archs_qf"]},
                    "activation_fn": args["activation_fns"],
                    "optimizer_class": args["optimizer_classes"],
                    "m_critics": args["m_critics"],
                    "random_target_qf": args["random_target_qf"],
                    "aggregate_target_qf": _aggregator(args["aggregate_target_qf"]),
                    "all_policy_qf": args["all_policy_qf"],
                    "end_epsilon": args["end_epsilon"],
                    "epsilon_duration": args["n_samples"],
                    "aggregate_policy_qf": _aggregator(args["aggregate_policy_qf"]),
                },
                seed=args["seed"],
            )

            # Create log dir where evaluation results will be saved
            eval_log_dir = f"./eval_logs/{run_subdir}/eval/"
            os.makedirs(eval_log_dir, exist_ok=True)

            # Create callback that evaluates agent on a separate environment instance
            eval_env = gym.make(args["env"])
            eval_callback = EvalCallback(
                eval_env,
                jax_random_key_for_seeds=args["seed"],
                best_model_save_path=None,
                log_path=eval_log_dir,
                n_eval_episodes=1,
                deterministic=True,
                render=False,
            )
            callback_list = [eval_callback, WandbCallback(verbose=0)]

            model.learn(total_timesteps=args["n_samples"], progress_bar=True, callback=callback_list)
        finally:
            # Release environment resources even when training fails part-way.
            if eval_env is not None:
                eval_env.close()
            training_env.close()
