from gym.envs.mujoco import HalfCheetahEnv

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
from rlkit.torch.sac.sac_svdd_no_bear import SACTrainer
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
import numpy as np
from torch.utils.data import DataLoader
from svdd.dataset import *
from svdd import *
from svdd.svdd import *
import os
import argparse
import gym
import d4rl
import datetime


def experiment(args, variant):
    """Set up Deep SVDD and SAC components, then run the training loop.

    A Deep SVDD trainer is built on the D4RL dataset named by ``args.env``
    and is optionally pretrained / trained depending on the flags below.
    Whatever ``load_SVDD()`` returns is passed to ``SACTrainer`` as ``svdd``
    before ``TorchBatchRLAlgorithm`` is started.

    Args:
        args: Parsed command-line namespace. This function reads ``env``,
            ``batch_size_svdd``, ``svdd_pretrain`` and ``svdd_train``; the
            whole namespace is also forwarded to ``TrainerDeepSVDD``.
        variant: Dict providing ``env_name``, ``layer_size``,
            ``replay_buffer_size``, ``trainer_kwargs`` and
            ``algorithm_kwargs``.
    """
    # One environment instance is shared by exploration and evaluation.
    eval_env = gym.make(variant['env_name'])
    expl_env = eval_env

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    # --- Deep SVDD stage -------------------------------------------------
    svdd_loader = DataLoader(
        D4RLDataset(args.env),
        batch_size=args.batch_size_svdd,
        shuffle=True,
    )
    deep_SVDD = TrainerDeepSVDD(
        args, obs_dim, action_dim, svdd_loader, ptu.device
    )

    if args.svdd_pretrain:
        deep_SVDD.pretrain()
    if args.svdd_train:
        deep_SVDD.train()

    # --- SAC networks ----------------------------------------------------
    width = variant['layer_size']
    hidden = [width, width]

    def build_q_network():
        # Every critic maps a concatenated (obs, action) vector to a scalar.
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=list(hidden),
        )

    # Created in the order qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = (build_q_network() for _ in range(4))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=list(hidden),
    )

    # --- Data collection -------------------------------------------------
    eval_path_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))
    expl_path_collector = MdpPathCollector(expl_env, policy)
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)

    # --- Trainer and algorithm -------------------------------------------
    svdd_model = deep_SVDD.load_SVDD()
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        svdd=svdd_model,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()


def str2bool(v):
    """Convert a command-line token to a bool (for use as an argparse ``type``).

    Args:
        v: Either an actual ``bool``, which is returned unchanged, or a
            string compared case-insensitively against the accepted
            spellings listed below.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: If the token matches neither set.
    """
    if isinstance(v, bool):
        return v

    truthy = frozenset(('yes', 'true', 't', 'y', '1'))
    falsy = frozenset(('no', 'false', 'f', 'n', '0'))

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

if __name__ == "__main__":
    # Script entry point: parse command-line flags, assemble the rlkit
    # `variant` dict, configure logging, and launch `experiment`.

    # noinspection PyTypeChecker
    parser = argparse.ArgumentParser(description='RLAD-SAC')

    # From BEAR
    # NOTE(review): --gpu, --seed, --qf_lr and --policy_lr are parsed but not
    # read anywhere in this script; `trainer_kwargs` below hard-codes its own
    # learning rates. The whole namespace is forwarded to TrainerDeepSVDD,
    # which may consume some of them -- confirm before removing or wiring up.
    parser.add_argument("--env", type=str, default='hopper-medium-v2')
    parser.add_argument("--algo_name", type=str, default='SVDD_SAC')
    parser.add_argument("--gpu", default='0', type=str)
    parser.add_argument('--qf_lr', default=3e-4, type=float)
    parser.add_argument('--policy_lr', default=1e-4, type=float)
    parser.add_argument('--seed', default=int(np.random.randint(0, 100000)), type=int)
    parser.add_argument('--all_saves', default="saves", type=str)
    parser.add_argument('--trial_name', default="", type=str)
    parser.add_argument('--nepochs', default=1000, type=int)
    parser.add_argument('--log_dir', default='./default/', type=str,
                        help="Location for logging")

    # Deep SVDD options.
    parser.add_argument('--epochs_svdd', default=200, type=int)
    parser.add_argument('--epochs_ae', default=200, type=int)
    parser.add_argument('--patience', default=50, type=int)
    parser.add_argument('--lr_svdd', default=1e-3, type=float)
    # Accepts one or more integers, e.g. `--lr_milestones 50 100`.
    # `type=list` must not be used here: argparse applies `type` to the raw
    # string, so "50" would become ['5', '0'] rather than [50].
    parser.add_argument('--lr_milestones', default=[50], type=int, nargs='+')
    parser.add_argument('--svdd_batch_size', default=256, type=int)
    parser.add_argument('--svdd_pretrain', default=True, type=str2bool)
    parser.add_argument('--svdd_train', default=True, type=str2bool)
    parser.add_argument('--svdd_latent_dim', default=128, type=int)
    parser.add_argument('--svdd_hidden_dim', default=256, type=int)
    parser.add_argument('--normal_class', default=1, type=int)
    parser.add_argument('--weight_decay_svdd', default=0.5e-6, type=float)
    parser.add_argument('--weight_decay_ae', default=0.5e-3, type=float)
    # This is the batch size `experiment` passes to the SVDD DataLoader.
    parser.add_argument('--batch_size_svdd', default=256, type=int)

    args = parser.parse_args()

    # noinspection PyTypeChecker
    variant = dict(
        algorithm="SAC",
        env_name=args.env,
        version="normal",
        layer_size=256,
        replay_buffer_size=int(1E6),
        algorithm_kwargs=dict(
            num_epochs=args.nepochs,
            num_eval_steps_per_epoch=5000,
            num_trains_per_train_loop=1000,
            num_expl_steps_per_train_loop=1000,
            min_num_steps_before_training=1000,
            max_path_length=1000,
            batch_size=256,
        ),
        trainer_kwargs=dict(
            discount=0.99,
            soft_target_tau=5e-3,
            target_update_period=1,
            policy_lr=3E-4,
            qf_lr=3E-4,
            reward_scale=1,
            use_automatic_entropy_tuning=True,
        ),
    )

    file_name = args.trial_name
    # Each run logs to <all_saves>/<log_dir>/<trial_name>/<env>/<timestamp>.
    timestamp = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    run_log_dir = os.path.join(
        args.all_saves, args.log_dir, file_name, args.env, timestamp
    )
    setup_logger(
        file_name,
        variant=variant,
        base_log_dir=args.all_saves,
        name=file_name,
        log_dir=run_log_dir,
    )
    ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    experiment(args, variant)