'''
Train a reverse dynamics ensemble and generate a dataset with a reverse policy.

Adapted mainly from run_mopo.py.
'''
import argparse
from distutils.util import strtobool
import os
import random
import sys
import numpy as np
import torch
import sys
# Third-party and project imports are performed only when this file is run as
# a script; importing it as a module binds none of the names below.
# NOTE(review): train() further down references gym, qlearning_dataset,
# ReplayBuffer, EnsembleDynamicsModel, EnsembleDynamics, ReverseEnsembleDynamics,
# StandardScaler, get_termination_fn, obs_unnormalization, DivergentPolicy,
# Logger, ReversePolicyTrainer and time, all of which are bound only inside
# this guard -- calling train() after a plain `import` of this file would raise
# NameError. Confirm that is intended.
if __name__ == '__main__':

    # Make the current working directory importable before the project imports
    # below -- presumably so `offlinerlkit` resolves when launched from the
    # repository root; verify.
    sys.path.append(os.getcwd())
    import gym
    # presumably imported for the side effect of registering the D4RL tasks
    # with gym -- TODO confirm
    import d4rl
    import d4rl.gym_mujoco

    import csv
    import time
    import matplotlib.pyplot as plt
    from offlinerlkit.nets import MLP
    from offlinerlkit.modules import EnsembleDynamicsModel
    from offlinerlkit.dynamics import ReverseEnsembleDynamics, EnsembleDynamics
    from offlinerlkit.utils.scaler import StandardScaler
    from offlinerlkit.utils.termination_fns import get_termination_fn, obs_unnormalization
    from offlinerlkit.utils.load_dataset import qlearning_dataset
    from offlinerlkit.buffer import ReplayBuffer
    from offlinerlkit.utils.logger import Logger
    from offlinerlkit.policy_trainer import ReversePolicyTrainer

    from offlinerlkit.policy import DivergentPolicy

"""
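Suggested hyperparameters per task (rollout length and penalty coefficient).
Note: get_args() in this script exposes --rollout_length but no penalty-coef option.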

halfcheetah-medium-v2: rollout-length=5, penalty-coef=0.5
hopper-medium-v2: rollout-length=5, penalty-coef=5.0
walker2d-medium-v2: rollout-length=5, penalty-coef=0.5
halfcheetah-medium-replay-v2: rollout-length=5, penalty-coef=0.5
hopper-medium-replay-v2: rollout-length=5, penalty-coef=2.5
walker2d-medium-replay-v2: rollout-length=1, penalty-coef=2.5
halfcheetah-medium-expert-v2: rollout-length=5, penalty-coef=2.5
hopper-medium-expert-v2: rollout-length=5, penalty-coef=5.0
walker2d-medium-expert-v2: rollout-length=1, penalty-coef=2.5
"""

"""
Suggested hyperparameters per task from the MOPO paper (rollout length and penalty coefficient).

halfcheetah-medium-v2: rollout-length=1, penalty-coef=1.0
hopper-medium-v2: rollout-length=5, penalty-coef=5.0
walker2d-medium-v2: rollout-length=5, penalty-coef=5.0
halfcheetah-medium-replay-v2: rollout-length=5, penalty-coef=1.0
hopper-medium-replay-v2: rollout-length=5, penalty-coef=1.0
walker2d-medium-replay-v2: rollout-length=1, penalty-coef=1.0
halfcheetah-medium-expert-v2: rollout-length=5, penalty-coef=1.0
hopper-medium-expert-v2: rollout-length=5, penalty-coef=1.0
walker2d-medium-expert-v2: rollout-length=1, penalty-coef=2.0
"""

def _default_exp_name(script_path):
    """Return the file name of ``script_path`` without its extension."""
    # str.rstrip(".py") would strip any trailing '.', 'p' or 'y' characters
    # (e.g. "run_policy.py" -> "run_polic"), so split the extension instead.
    return os.path.splitext(os.path.basename(script_path))[0]


def _str_to_bool(value):
    """Convert a truth-value string to ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool`` (case
    insensitive) so the flag keeps working without depending on ``distutils``.

    Raises:
        ValueError: if ``value`` is not a recognized truth value; argparse
            turns this into a normal usage error.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (lowered,))


def get_args(argv=None):
    """Build the command-line parser for this script and parse ``argv``.

    Args:
        argv: Optional list of argument strings. When ``None``, the process
            command line is parsed strictly with ``parse_args``. When a list
            is given, it is parsed with ``parse_known_args``, so options the
            parser does not know are ignored instead of being errors.

    Returns:
        An ``argparse.Namespace`` with one attribute per option. When ``argv``
        is given, the namespace additionally carries the parser itself as
        ``_original_parser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", type=str, default=_default_exp_name(__file__),
        help="the name of this experiment")
    parser.add_argument("--algo_name", type=str, default="reverse")
    parser.add_argument("--task", type=str, default="walker2d-medium-expert-v2")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--load_policy_path", "-lpp", type=str, default=None)
    parser.add_argument("--hidden_dims", "-hd", type=int, nargs='*', default=[256, 256])

    parser.add_argument("--max_holdout_size", "-mhs", type=int, default=1000)
    parser.add_argument("--reverse_policy_mode", "-rpm", type=str, default="cvae", help="reverse policy mode, 'divergent' or 'cvae' or 'uncertainty'")
    parser.add_argument("--action_mode", "-acm", type=str, default="sample", help='sample from realbuffer and perturb it')
    # divergent policy
    parser.add_argument("--scale_coef", "-sc", type=float, default=None)
    parser.add_argument("--noise_coef", "-nc", type=float, default=None)

    # dynamics ensemble and rollout settings
    parser.add_argument("--dynamics_lr", "-dlr", type=float, default=1e-3)
    parser.add_argument("--dynamics_hidden_dims", "-dhd", type=int, nargs='*', default=[200, 200, 200, 200])
    parser.add_argument("--dynamics_weight_decay", type=float, nargs='*', default=[2.5e-5, 5e-5, 7.5e-5, 7.5e-5, 1e-4])
    parser.add_argument("--n_ensemble", "-ne", type=int, default=7)
    parser.add_argument("--n_elites", type=int, default=5)
    parser.add_argument("--rollout_epoch", "-re", type=int, default=100)
    parser.add_argument("--rollout_batch_size", "-rbs", type=int, default=3000)
    parser.add_argument("--rollout_length", "-rl", type=int, default=5)
    parser.add_argument("--model_retain_epochs", type=int, default=5)
    parser.add_argument("--real_ratio", "-rr", type=float, default=0.05)
    parser.add_argument("--load_dynamics_path", "-ldp", type=str, default=None)
    parser.add_argument("--load_reverse_dynamics_path", "-lrdp", type=str, default=None)
    # A bare "-ra" means True (const); "-ra <value>" parses the given truth value.
    parser.add_argument("--rollout_augmentation", "-ra", type=_str_to_bool, nargs="?", const=True, default=False)
    parser.add_argument("--holdout_ratio", "-hr", type=float, default=0.2)
    parser.add_argument("--logvar_loss_coef", "-llc", type=float, default=0.01)
    parser.add_argument("--max_epochs_since_update", "-mesu", type=int, default=5)
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step_per_epoch", type=int, default=1000)
    parser.add_argument("--eval_episodes", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    if argv is None:
        # Strict parsing of the real command line.
        return parser.parse_args()

    # Programmatic call: tolerate options this parser does not define.
    args, _unknown = parser.parse_known_args(argv)
    args._original_parser = parser
    return args


def train(args):
    """Set up environment, buffers and dynamics models, then generate a dataset.

    Steps, in order: optionally restore the hyper-parameters recorded next to
    a reverse-dynamics checkpoint, create the gym environments and the
    q-learning dataset, seed the RNGs, build the real and fake replay buffers,
    build the forward and reverse ensemble dynamics (loading checkpoints when
    paths are given), build the divergent reverse policy, train the reverse
    dynamics unless a reverse-dynamics path was supplied, and finally call
    ``policy_trainer.generate()``.

    NOTE: this function uses ``gym`` and the ``offlinerlkit`` names that this
    file imports under its ``if __name__ == '__main__':`` guard.

    Args:
        args: Namespace produced by ``get_args``. It is mutated in place:
            ``obs_shape``, ``action_dim``, ``max_action`` and
            ``entropy_weight`` are added, and when a reverse-dynamics
            checkpoint is loaded, ``reverse_dynamics_args``,
            ``dynamics_hidden_dims``, ``n_ensemble`` and ``n_elites`` are set
            from its recorded hyper-parameters.

    Returns:
        The directory passed to ``Logger`` (the current working directory).

    Raises:
        NotImplementedError: if ``args.reverse_policy_mode`` is not
            ``'divergent'``, the only mode this script constructs.
    """
    # Fail fast: checking this before creating environments, loading the
    # dataset and building models avoids doing all that work only to abort.
    if args.reverse_policy_mode != 'divergent':
        raise NotImplementedError(
            f"reverse_policy_mode={args.reverse_policy_mode!r} is not supported by this script; "
            "only 'divergent' is implemented"
        )

    # The literal path 'dummy' is a placeholder: nothing is read from it.
    # NOTE(review): 'dummy' still counts as "a path was supplied" further down,
    # so reverse-dynamics training is skipped for it as well -- confirm intended.
    has_reverse_checkpoint = bool(args.load_reverse_dynamics_path) and args.load_reverse_dynamics_path != 'dummy'
    has_forward_checkpoint = bool(args.load_dynamics_path) and args.load_dynamics_path != 'dummy'

    if has_reverse_checkpoint:
        import json
        # The checkpoint's hyper-parameters live in a sibling "record" directory.
        hyperparams_path = os.path.join(args.load_reverse_dynamics_path, "../record/hyper_param.json")
        with open(hyperparams_path, "r") as f:
            dynamics_hyperparams = json.load(f)
        assert args.task == dynamics_hyperparams["task"], \
            "reverse dynamics checkpoint was recorded for a different task"
        # Rebuild the ensemble with the architecture the checkpoint was saved with.
        args.reverse_dynamics_args = dynamics_hyperparams
        args.dynamics_hidden_dims = dynamics_hyperparams["dynamics_hidden_dims"]
        args.n_ensemble = dynamics_hyperparams["n_ensemble"]
        args.n_elites = dynamics_hyperparams["n_elites"]

    # create env and dataset
    env = gym.make(args.task)
    eval_env = gym.make(args.task)
    dataset = qlearning_dataset(env, with_goal=False)

    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    highs = env.action_space.high
    neg_lows = -env.action_space.low
    # A single scalar max_action is used below, so the action box must be
    # symmetric and identical in every dimension.
    assert np.all(highs == highs[0]) and np.all(neg_lows == highs[0]), \
        "action space must have the same symmetric bound in every dimension"
    args.max_action = env.action_space.high[0]
    args.entropy_weight = 0.5

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)
    # eval_env is the environment handed to the policy and trainer below, so
    # it is seeded as well instead of being left unseeded.
    eval_env.seed(args.seed)

    # create buffer
    real_buffer = ReplayBuffer(
        args,
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    real_buffer.load_dataset(dataset)

    # Capacity for rollout_batch_size * rollout_length transitions per rollout epoch.
    fake_buffer_size = args.rollout_batch_size * args.rollout_length * args.rollout_epoch

    fake_buffer = ReplayBuffer(
        args,
        buffer_size=fake_buffer_size,
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )

    # Dummy statistics (zero mean, unit std), passed to obs_unnormalization below.
    obs_mean_np, obs_std_np = np.zeros_like(dataset['observations'][0]), np.ones_like(dataset['observations'][0])
    obs_mean = torch.tensor(obs_mean_np, dtype=torch.float32, device=args.device)
    obs_std = torch.tensor(obs_std_np, dtype=torch.float32, device=args.device)

    # create forward dynamics
    dynamics_model = EnsembleDynamicsModel(
        args,
        obs_dim=np.prod(args.obs_shape),
        action_dim=args.action_dim,
        hidden_dims=args.dynamics_hidden_dims,
        num_ensemble=args.n_ensemble,
        num_elites=args.n_elites,
        weight_decays=args.dynamics_weight_decay,
        device=args.device
    )
    dynamics_optim = torch.optim.Adam(
        dynamics_model.parameters(),
        lr=args.dynamics_lr
    )
    dynamics_scaler = StandardScaler()
    termination_fn = obs_unnormalization(get_termination_fn(task=args.task), obs_mean, obs_std)
    dynamics = EnsembleDynamics(
        args,
        dynamics_model,
        dynamics_optim,
        dynamics_scaler,
        termination_fn,
    )

    if has_forward_checkpoint:
        dynamics.load(args.load_dynamics_path)
    # The forward model is never trained in this function; keep it in eval mode.
    dynamics.model.eval()

    # create reverse dynamics
    # Any non-empty path (including 'dummy') means "do not train the reverse model here".
    load_reverse_dynamics_model = bool(args.load_reverse_dynamics_path)
    reverse_dynamics_model = EnsembleDynamicsModel(
        args,
        obs_dim=np.prod(args.obs_shape),
        action_dim=args.action_dim,
        hidden_dims=args.dynamics_hidden_dims,
        num_ensemble=args.n_ensemble,
        num_elites=args.n_elites,
        weight_decays=args.dynamics_weight_decay,
        with_reward=False,
        device=args.device
    )
    reverse_dynamics_optim = torch.optim.Adam(
        reverse_dynamics_model.parameters(),
        lr=args.dynamics_lr
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(reverse_dynamics_optim, args.rollout_epoch)
    reverse_dynamics_scaler = StandardScaler()
    # Built separately so the two dynamics objects do not share one termination_fn object.
    termination_fn = obs_unnormalization(get_termination_fn(task=args.task), obs_mean, obs_std)
    reverse_dynamics = ReverseEnsembleDynamics(
        args,
        dynamics,
        reverse_dynamics_model,
        reverse_dynamics_optim,
        reverse_dynamics_scaler,
        termination_fn,
    )

    if has_reverse_checkpoint:
        reverse_dynamics.load(args.load_reverse_dynamics_path)

    # create reverse policy (mode was validated at the top of the function)
    policy = DivergentPolicy(
        args,
        eval_env,
        dynamics,
        reverse_dynamics,
        real_buffer,
        args.action_dim,
        args.max_action,
        scale_coef=args.scale_coef,
        noise_coef=args.noise_coef,
        device=args.device,
        seed=args.seed,
    )

    log_dirs = os.getcwd()
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "reverse_dynamics_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)

    logger.log_hyperparameters(vars(args))

    # create policy trainer
    policy_trainer = ReversePolicyTrainer(
        args=args,
        reverse_dynamics=reverse_dynamics,
        dynamics=dynamics,
        policy=policy,
        eval_env=eval_env,
        real_buffer=real_buffer,
        fake_buffer=fake_buffer,
        logger=logger,
        rollout_setting=(args.rollout_epoch, args.rollout_batch_size, args.rollout_length),
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        lr_scheduler=lr_scheduler,
    )

    # train the reverse dynamics on the full real buffer unless a path was supplied
    if not load_reverse_dynamics_model:
        reverse_dynamics.train(
            real_buffer.sample_all(),
            logger,
            holdout_ratio=args.holdout_ratio,
            logvar_loss_coef=args.logvar_loss_coef,
            max_epochs_since_update=args.max_epochs_since_update,
        )
    # generate dataset
    policy_trainer.generate()

    return log_dirs

if __name__ == "__main__":
    # Script entry point: parse the process command line, then run the
    # pipeline. The returned log directory is kept in a module-level name.
    cli_args = get_args()
    log_dirs = train(cli_args)

