import argparse
from pathlib import Path
import time
from functools import partial

import jax, jax.numpy as jnp


from relax.algorithm.dsact import DSACT
from relax.algorithm.dsace import DSACE
from relax.buffer import TreeBuffer
from relax.network.dsact import create_dsact_net
from relax.network.dsace import create_dsace_net
from relax.trainer.off_policy import OffPolicyTrainer
from relax.env import create_env, create_vector_env
from relax.utils.experience import Experience, ObsActionPair
from relax.utils.fs import PROJECT_ROOT
from relax.utils.random_utils import seeding


def _build_parser():
    """Build the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--alg", type=str, default="dsace")
    parser.add_argument("--env", type=str, default="Humanoid-v3")
    parser.add_argument("--num_vec_envs", type=int, default=20)
    parser.add_argument("--hidden_num", type=int, default=3)
    parser.add_argument("--hidden_dim", type=int, default=256)
    # NOTE(review): this flag is parsed but nothing in this script consumes
    # it; it is kept so existing command lines keep working.
    parser.add_argument("--diffusion_hidden_dim", type=int, default=256)
    parser.add_argument("--start_step", type=int, default=int(1e4))
    parser.add_argument("--total_step", type=int, default=int(3e7))
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--entropy_ratio", type=float, default=20.0)
    return parser


def main(argv=None):
    """Train a DSACT or DSACE agent with the off-policy trainer.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Raises:
        ValueError: If ``--alg`` is neither ``"dsact"`` nor ``"dsace"``.
    """
    args = _build_parser().parse_args(argv)

    # Resolve the algorithm first so that a typo in --alg fails immediately,
    # before any environment or replay buffer has been created.
    if args.alg == "dsact":
        create_net, algorithm_cls = create_dsact_net, DSACT
    elif args.alg == "dsace":
        create_net, algorithm_cls = create_dsace_net, DSACE
    else:
        raise ValueError(f"Invalid algorithm {args.alg}!")

    # Derive one independent integer seed per component from the master seed.
    master_rng, _ = seeding(args.seed)
    env_seed, env_action_seed, eval_env_seed, buffer_seed, init_network_seed, train_seed = map(
        int, master_rng.integers(0, 2**32 - 1, 6)
    )
    init_network_key = jax.random.key(init_network_seed)
    train_key = jax.random.key(train_seed)

    if args.num_vec_envs > 0:
        env, obs_dim, act_dim = create_vector_env(
            args.env, args.num_vec_envs, env_seed, env_action_seed, mode="futex"
        )
    else:
        env, obs_dim, act_dim = create_env(args.env, env_seed, env_action_seed)
    # No evaluation environment is created; eval_env_seed is drawn above only
    # to keep the order of the remaining seeds unchanged.
    eval_env = None

    hidden_sizes = [args.hidden_dim] * args.hidden_num

    buffer = TreeBuffer.from_experience(obs_dim, act_dim, size=int(1e6), seed=buffer_seed)

    # Exact (non-approximate) GELU is handed to the network factory.
    gelu = partial(jax.nn.gelu, approximate=False)
    agent, params = create_net(init_network_key, obs_dim, act_dim, hidden_sizes, gelu, args.entropy_ratio)
    algorithm = algorithm_cls(agent, params, lr=args.lr)

    # Log directory: <PROJECT_ROOT>/logs/<env>/<alg>_<timestamp>_s<seed>_ent<entropy_ratio>
    run_name = args.alg + '_' + time.strftime("%Y-%m-%d_%H-%M-%S") + f'_s{args.seed}_ent{args.entropy_ratio}'

    trainer = OffPolicyTrainer(
        env=env,
        algorithm=algorithm,
        buffer=buffer,
        start_step=args.start_step,
        total_step=args.total_step,
        sample_per_iteration=1,
        evaluate_env=eval_env,
        save_policy_every=300000,
        warmup_with="random",
        log_path=PROJECT_ROOT / "logs" / args.env / run_name,
    )

    trainer.setup(Experience.create_example(obs_dim, act_dim, trainer.batch_size))
    trainer.run(train_key)


if __name__ == "__main__":
    main()
