import argparse
import random

import rl_utils as dc

from parallel_pbt_ac import *


def train_sac(args):
    """Build train/test environments, a SAC agent and a replay buffer, then train.

    Args:
        args: parsed command-line namespace. Fields read directly here are
            ``make_env_func``, ``log_std_low``, ``log_std_high``,
            ``buffer_size``, ``random_hparams``, ``num_steps``,
            ``max_episode_steps`` and ``name``. When ``random_hparams`` is
            false, the whole namespace is forwarded to ``dc.sac.sac`` as
            keyword arguments.

    Returns:
        The value returned by ``dc.sac.sac``. This is presumably the trained
        agent; confirm against ``rl_utils``. Callers that ignore the return
        value are unaffected.
    """
    # SECURITY NOTE(review): ``make_env_func`` comes straight from the command
    # line and is passed to ``eval``. It is expected to name an environment
    # factory made available by ``from parallel_pbt_ac import *``. Never run
    # this script with untrusted arguments.
    make_env = eval(args.make_env_func)
    train_env = make_env()
    test_env = make_env()

    obs_shape = train_env.observation_space.shape
    action_shape = train_env.action_space.shape

    agent = dc.sac.SACAgent(
        obs_shape[0],
        action_shape[0],
        args.log_std_low,
        args.log_std_high,
        hidden_size=256,
    )

    buffer = dc.replay.ReplayBuffer(
        args.buffer_size,
        state_dtype=float,
        state_shape=obs_shape,
        action_shape=action_shape,
    )

    if not args.random_hparams:
        # Train with the hyperparameters supplied on the command line.
        return dc.sac.sac(
            agent=agent,
            train_env=train_env,
            test_env=test_env,
            buffer=buffer,
            **vars(args)
        )

    # Random search: sample one configuration per run. The order of the
    # ``random`` draws below is kept fixed so a seeded run is reproducible.
    is_inventory = "inventory" in args.make_env_func

    # Env-specific parameter ``k`` (its meaning is defined by the env's
    # ``set_k``). Inventory envs draw from 1..5, all others from 1..15.
    k_max = 6 if is_inventory else 16
    k = random.choice(range(1, k_max))
    train_env.set_k(k)
    test_env.set_k(k)
    train_env.return_history = False
    test_env.return_history = False

    gamma = random.uniform(0.65, 0.999)
    transitions_per_step = 20
    # Presumably ``args.num_steps`` is a budget of env transitions. Dividing
    # by ``transitions_per_step`` keeps the total roughly constant; confirm.
    num_steps = round(args.num_steps / transitions_per_step)
    tau = random.uniform(0.001, 0.05)
    target_delay = random.choice([1, 2])
    actor_delay = random.choice([1, 2, 3, 4, 5])
    gradient_updates_per_step = random.choice(range(1, 40))
    # Entropy target scaled relative to -|A|.
    target_entropy = random.uniform(0.25, 1.75) * -float(action_shape[0])

    uses_small_batch = is_inventory or "industrial_benchmark" in args.make_env_func
    batch_size = 128 if uses_small_batch else 512

    return dc.sac.sac(
        num_steps=num_steps,
        agent=agent,
        train_env=train_env,
        test_env=test_env,
        buffer=buffer,
        batch_size=batch_size,
        transitions_per_step=transitions_per_step,
        tau=tau,
        gamma=gamma,
        target_delay=target_delay,
        actor_delay=actor_delay,
        gradient_updates_per_step=gradient_updates_per_step,
        target_entropy=target_entropy,
        actor_lr=3e-4,
        critic_lr=3e-4,
        max_episode_steps=args.max_episode_steps,
        verbosity=1,
        name=args.name,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seeds",
        type=int,
        default=1,
        help="number of independent training runs to perform",
    )
    parser.add_argument(
        "--make_env_func",
        type=str,
        required=True,
        help="Python expression (passed to eval) giving a zero-argument env factory",
    )
    parser.add_argument(
        "--random_hparams",
        action="store_true",
        help="randomly sample SAC hyperparameters for each run",
    )
    # add sac-related cl args
    dc.sac.add_args(parser)
    # Parser-level defaults override argument-level defaults. This yields the
    # same values as before when these flags are omitted, while still honouring
    # explicit command-line values instead of silently overwriting them after
    # parsing. This must come after ``add_args`` so any matching actions pick
    # up the new defaults.
    parser.set_defaults(max_episode_steps=1000, actor_lr=3e-4, critic_lr=3e-4)
    args = parser.parse_args()
    # Always forced off: train_sac builds a float-valued state buffer from the
    # observation space, so pixel observations are not supported here.
    args.from_pixels = False
    for _ in range(args.seeds):
        train_sac(args)
