import argparse

import rl_utils as dc
import ar_aware_sac
from action_repeat_wrapper import ActionRepeatWrapper

from parallel_pbt_ac import *


def _make_dmc_env(args):
    """Build one DMC environment wrapped in ``ActionRepeatWrapper``.

    Reads ``args.domain``, ``args.task`` and ``args.gamma`` (used as the
    wrapper's ``discount``); history is not returned by the wrapper.
    """
    return ActionRepeatWrapper(
        dc.envs.load_dmc(args.domain, args.task),
        return_history=False,
        discount=args.gamma,
    )


def train_dmc_sac(args):
    """Train an ``ar_aware_sac`` SAC agent on one DeepMind Control task.

    Builds separate train and test environments, an ``ar_aware_sac.SACAgent``
    sized from the train environment's observation/action spaces, and a
    ``dc.replay.ReplayBuffer``, then hands all of them to
    ``ar_aware_sac.sac``.

    Args:
        args: parsed ``argparse.Namespace``. Fields read directly here are
            ``domain``, ``task``, ``gamma``, ``log_std_low``,
            ``log_std_high`` and ``buffer_size``. Every field of ``args`` is
            also forwarded to ``ar_aware_sac.sac`` as a keyword argument.

    Returns:
        The value returned by ``ar_aware_sac.sac`` (presumably the trained
        agent — confirm against ``ar_aware_sac``). Earlier versions bound
        this result to a local and discarded it.
    """
    # Separate instances so evaluation does not disturb training episodes.
    train_env = _make_dmc_env(args)
    test_env = _make_dmc_env(args)

    obs_shape = train_env.observation_space.shape
    action_shape = train_env.action_space.shape

    # NOTE(review): indexing [0] assumes flat (1-D) observation and action
    # vectors. The action bound (action_space.high) is not passed to the
    # agent; confirm SACAgent's output range matches the env. TODO confirm.
    agent = ar_aware_sac.SACAgent(
        obs_shape[0],
        action_shape[0],
        args.log_std_low,
        args.log_std_high,
        hidden_size=256,
    )

    buffer = dc.replay.ReplayBuffer(
        args.buffer_size,
        state_dtype=float,
        state_shape=obs_shape,
        action_shape=action_shape,
    )

    return ar_aware_sac.sac(
        agent=agent,
        train_env=train_env,
        test_env=test_env,
        buffer=buffer,
        **vars(args),
    )


if __name__ == "__main__":
    # Command-line interface: run count, action-repeat bounds and sampling
    # switches, the DMC domain/task, plus the SAC-related flags registered
    # by dc.sac.add_args.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--seeds", type=int, default=1)
    cli_parser.add_argument("--ar_max", type=int, default=15)
    cli_parser.add_argument("--ar_min", type=int, default=1)
    for switch in (
        "--ar_thompson_sampling_estimated",
        "--ar_thompson_sampling_true",
    ):
        cli_parser.add_argument(switch, action="store_true")
    cli_parser.add_argument("--domain", type=str, required=True)
    cli_parser.add_argument("--task", type=str, required=True)
    dc.sac.add_args(cli_parser)
    args = cli_parser.parse_args()

    # Script-fixed settings. They are written after parsing, so they
    # override any same-named value supplied on the command line.
    vars(args).update(
        from_pixels=False,
        max_episode_steps=1000,
        actor_lr=3e-4,
        critic_lr=3e-4,
    )

    # One full training run (fresh envs, agent and buffer) per requested
    # seed. NOTE(review): no per-run seed is set here; presumably seeding
    # happens inside ar_aware_sac.sac — confirm.
    for _run_index in range(args.seeds):
        train_dmc_sac(args)
