import os
import sys
import time
import wandb
import gymnasium as gym
import optax
import jax.numpy as jnp
from wandb.integration.sb3 import WandbCallback
from sbx.common.callbacks import EvalCallback
from sbx import SAC

from experiments.parser import parse


# Process-wide settings applied as soon as this module is imported.
# - XLA_PYTHON_CLIENT_PREALLOCATE="false": per the JAX documentation this turns
#   off up-front GPU memory preallocation so memory is allocated on demand.
# - WANDB_DIR="/tmp": directory wandb uses for its local run files.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "WANDB_DIR": "/tmp",
    }
)


def _select_aggregator(name):
    """Return ``jnp.min`` when *name* is ``"min"``, otherwise ``jnp.mean``."""
    return jnp.min if name == "min" else jnp.mean


def run_cli(argvs=None):
    """Train a SAC agent configured from command-line style arguments.

    The arguments are parsed with ``experiments.parser.parse``; the resulting
    mapping is used both as the wandb run config and as the source of every
    hyper-parameter passed to ``SAC`` and ``EvalCallback``. Keys read here:
    ``env``, ``seed``, ``group``, ``wandb_project``, ``learning_rates``,
    ``n_initial_samples``, ``utd``, ``random_target_qf``, ``all_policy_qf``,
    ``net_archs_qf``, ``activation_fns``, ``optimizer_classes``, ``m_critics``,
    ``aggregate_target_qf``, ``aggregate_policy_qf``, ``end_epsilon`` and
    ``n_samples``.

    Args:
        argvs: Argument list handed to ``parse``. When ``None`` (the default),
            ``sys.argv[1:]`` is read at call time.

    Side effects:
        Creates ``./eval_logs/<group>/seed_<seed>/time=<timestamp>/eval/``,
        points TensorBoard logging at ``logs/<group>/seed_<seed>/time=<timestamp>/``,
        and opens a wandb run (disabled when ``wandb_project`` is ``None``).
    """
    # Resolve the default lazily: a `sys.argv[1:]` default in the signature
    # would be evaluated once at import time and shared between calls.
    if argvs is None:
        argvs = sys.argv[1:]
    args = parse(argvs)
    experiment_time = time.time()

    # One sub-path shared by the TensorBoard and evaluation log directories.
    run_subdir = f"{args['group']}/seed_{args['seed']}/time={experiment_time}"

    # Read once so `project` and `mode` cannot disagree about a missing key.
    wandb_project = args.get("wandb_project")

    with wandb.init(
        # NOTE(review): placeholder value — must be replaced with a real
        # wandb entity before running with logging enabled.
        entity="YOUR ENTITY",
        project=wandb_project,
        name=f"seed={args['seed']}",
        group=args["group"],
        tags=[],
        sync_tensorboard=True,
        config=args,
        mode="online" if wandb_project is not None else "disabled",
    ):
        training_env = gym.make(args["env"])
        eval_env = None
        try:
            # NOTE(review): several keyword arguments below (random_target_qf,
            # all_policy_qf, m_critics, ...) are not visible in this file and
            # presumably come from a project-specific SBX build — confirm.
            policy_kwargs = {
                "net_arch": {"pi": [256, 256], "qf": args["net_archs_qf"]},
                "activation_fn": args["activation_fns"],
                "optimizer_class": args["optimizer_classes"],
                "m_critics": args["m_critics"],
                "random_target_qf": args["random_target_qf"],
                "aggregate_target_qf": _select_aggregator(
                    args["aggregate_target_qf"]
                ),
                "all_policy_qf": args["all_policy_qf"],
                "end_epsilon": args["end_epsilon"],
                "epsilon_duration": args["n_samples"],
                "aggregate_policy_qf": _select_aggregator(
                    args["aggregate_policy_qf"]
                ),
            }

            model = SAC(
                "MlpPolicy",
                training_env,
                learning_rate=0.001,
                qf_learning_rate=args["learning_rates"],
                learning_starts=args["n_initial_samples"],
                tensorboard_log=f"logs/{run_subdir}/",
                # Both values come from the same "utd" argument.
                gradient_steps=args["utd"],
                policy_delay=args["utd"],
                random_target_qf=args["random_target_qf"],
                all_policy_qf=args["all_policy_qf"],
                policy_kwargs=policy_kwargs,
                seed=args["seed"],
            )

            # Create log dir where evaluation results will be saved
            eval_log_dir = f"./eval_logs/{run_subdir}/eval/"
            os.makedirs(eval_log_dir, exist_ok=True)

            # Create callback that evaluates agent on a separate environment
            eval_env = gym.make(args["env"])
            eval_callback = EvalCallback(
                eval_env,
                jax_random_key_for_seeds=args["seed"],
                best_model_save_path=None,
                log_path=eval_log_dir,
                n_eval_episodes=1,
                deterministic=True,
                render=False,
            )
            callback_list = [eval_callback, WandbCallback(verbose=0)]

            model.learn(
                total_timesteps=args["n_samples"],
                progress_bar=True,
                callback=callback_list,
            )
        finally:
            # Release environment resources even when training fails.
            training_env.close()
            if eval_env is not None:
                eval_env.close()
