"""
Taken from the rlkit repo.
"""
import argparse
import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.envs.custom import CUSTOM_ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
from rlkit.torch.sac.sac import SACTrainer
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

import ucb.envs


def experiment(variant):
    """Assemble and run a SAC training job described by ``variant``.

    Reads the following keys from ``variant``:
        "env": gym environment id passed to ``gym.make``.
        "layer_size": width used for both hidden layers of every network.
        "replay_buffer_size": first argument to ``EnvReplayBuffer``.
        "trainer_kwargs": extra keyword arguments for ``SACTrainer``.
        "algorithm_kwargs": extra keyword arguments for
            ``TorchBatchRLAlgorithm``.

    The function returns nothing; it moves the algorithm to ``ptu.device``
    and calls its ``train()`` method.
    """
    env_id = variant["env"]
    # Two independent environment instances: one for evaluation rollouts,
    # one for exploration rollouts.
    eval_env = NormalizedBoxEnv(gym.make(env_id))
    expl_env = NormalizedBoxEnv(gym.make(env_id))
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    width = variant["layer_size"]

    def build_critic():
        # Scalar-output MLP over the concatenated (observation, action) size.
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[width, width],
        )

    # Built in the same order as before (qf1, qf2, then the two targets) so
    # any random initialisation sequence is unchanged.
    qf1, qf2, target_qf1, target_qf2 = (build_critic() for _ in range(4))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[width, width],
    )
    # Evaluation uses the MakeDeterministic wrapper around the same policy.
    eval_policy = MakeDeterministic(policy)

    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_env)

    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant["trainer_kwargs"]
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"]
    )
    algorithm.to(ptu.device)
    algorithm.train()


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``parse_args()`` with no arguments behaves as before.

    Returns:
        argparse.Namespace with the attributes ``env`` (str, default
        "weirdgain-v0"), ``epochs`` (int, default 3000), ``save_every``
        (int or None), ``min_num_steps_before_training`` (int, default
        1000), ``layer_size`` (int, default 256), ``pudb`` (bool, default
        False) and ``cuda_device`` (int or None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="weirdgain-v0")
    parser.add_argument("--epochs", type=int, default=3000)
    parser.add_argument("--save_every", type=int)
    parser.add_argument("--min_num_steps_before_training", type=int, default=1000)
    parser.add_argument("--layer_size", type=int, default=256)
    parser.add_argument("--pudb", action="store_true")
    parser.add_argument("--cuda_device", type=int)
    # Forwarding argv lets callers (and tests) parse an explicit list
    # instead of the process-wide sys.argv.
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI flags, build the experiment variant,
    # configure logging and the device, then run the experiment.
    # noinspection PyTypeChecker
    args = parse_args()
    if args.pudb:
        # Imported lazily so pudb is only needed when --pudb is given.
        import pudb

        # The pudb entry point is set_trace(); the previous set_trac() call
        # raised AttributeError whenever --pudb was used.
        pudb.set_trace()
    variant = dict(
        algorithm="SAC",
        version="normal",
        # Honour --layer_size (default 256) instead of a hard-coded value,
        # which previously made the flag a no-op.
        layer_size=args.layer_size,
        replay_buffer_size=int(1e6),
        env=args.env,
        cuda_device=args.cuda_device,
        algorithm_kwargs=dict(
            num_epochs=args.epochs,
            num_eval_steps_per_epoch=5000,
            num_trains_per_train_loop=1000,
            num_expl_steps_per_train_loop=1000,
            min_num_steps_before_training=args.min_num_steps_before_training,
            max_path_length=10,
            batch_size=256,
        ),
        trainer_kwargs=dict(
            discount=0.99,
            soft_target_tau=5e-3,
            target_update_period=1,
            policy_lr=3e-4,
            qf_lr=3e-4,
            reward_scale=1,
            use_automatic_entropy_tuning=True,
        ),
    )
    save_path = "sac-%s-baseline" % variant["env"]
    # Only override the snapshot settings when a positive --save_every was
    # given; otherwise setup_logger's own defaults apply.
    logger_kwargs = dict(variant=variant)
    if args.save_every is not None and args.save_every > 0:
        logger_kwargs.update(
            snapshot_mode="gap_and_last",
            snapshot_gap=args.save_every,
        )
    setup_logger(save_path, **logger_kwargs)
    if args.cuda_device is not None:
        ptu.set_gpu_mode(True, args.cuda_device)
    # ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    experiment(variant)
