import argparse

import rl_utils as dc
from rl_utils.envs import load_dmc, ActionRepeatOutputWrapper
from gym.wrappers import TimeLimit


def _make_dmc_env(args, repeat_multiplier=15):
    """Build one DMC environment with a step limit and action repeat.

    The base environment comes from ``load_dmc(**vars(args))`` (every parsed
    argument is forwarded as a keyword), is wrapped in ``TimeLimit`` using
    ``args.max_episode_steps`` and finally in ``ActionRepeatOutputWrapper``.

    Args:
        args: Parsed command-line namespace; must provide
            ``max_episode_steps`` plus whatever ``load_dmc`` expects.
        repeat_multiplier: Forwarded to ``ActionRepeatOutputWrapper``.

    Returns:
        The fully wrapped environment.
    """
    time_limited = TimeLimit(load_dmc(**vars(args)), args.max_episode_steps)
    return ActionRepeatOutputWrapper(
        time_limited, repeat_multiplier=repeat_multiplier
    )


def train_dmc_sac(args):
    """Run one SAC training job on a DMC task described by ``args``.

    Builds identically configured train and test environments, a SAC agent
    and a replay buffer sized from the environment's spaces, then hands them
    to ``dc.sac.sac`` together with every parsed argument as keywords.

    Args:
        args: Parsed command-line namespace. Reads ``max_episode_steps``,
            ``log_std_low``, ``log_std_high`` and ``buffer_size`` directly;
            the full namespace is also forwarded to ``load_dmc`` and
            ``dc.sac.sac``.

    Returns:
        Whatever ``dc.sac.sac`` returns. The original code bound this to
        ``agent`` and discarded it, so it is presumably the trained agent
        -- TODO confirm against ``dc.sac.sac``.
    """
    # Separate instances so evaluation never shares state with training.
    train_env = _make_dmc_env(args)
    test_env = _make_dmc_env(args)

    obs_shape = train_env.observation_space.shape
    action_shape = train_env.action_space.shape

    # Only the first dimension of each shape is passed to the agent, so this
    # assumes flat (1-D) observation and action spaces.
    agent = dc.sac.SACAgent(
        obs_shape[0],
        action_shape[0],
        args.log_std_low,
        args.log_std_high,
        hidden_size=256,
    )

    buffer = dc.replay.ReplayBuffer(
        args.buffer_size,
        state_dtype=float,
        state_shape=obs_shape,
        action_shape=action_shape,
    )

    # Return the training result instead of dropping it in an unused local.
    return dc.sac.sac(
        agent=agent,
        train_env=train_env,
        test_env=test_env,
        buffer=buffer,
        **vars(args),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Number of times the training function below is invoked.
    parser.add_argument("--seeds", type=int, default=1)
    # Let the DMC environment helpers register their command-line options.
    dc.envs.add_dmc_args(parser)
    # Let the SAC module register its command-line options.
    dc.sac.add_args(parser)
    args = parser.parse_args()

    # Settings this script always forces. They are applied after parsing, so
    # they replace whatever value the parser produced for the same name.
    forced_settings = (
        ("from_pixels", False),
        ("max_episode_steps", 1000),
        ("actor_lr", 3e-4),
        ("critic_lr", 3e-4),
    )
    for option_name, option_value in forced_settings:
        setattr(args, option_name, option_value)

    # Launch one independent training run per requested repetition.
    for _ in range(args.seeds):
        train_dmc_sac(args)
