import argparse
import datetime
import os
import random
import importlib

import gym

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from models.transition_model import TransitionModel
from models.policy_models import MLP, ActorProb, Critic, DiagGaussian
from algo.sac import SACPolicy
from algo.primorl import PriMORL
from common.buffer import ReplayBuffer
from common.logger import Logger
from trainer import Trainer
from common.util import set_device_and_logger, qlearning_dataset_custom

from gym.spaces.utils import flatten_space

import warnings
# Install a blanket "ignore" filter: no Python warning of any category is
# shown for the rest of the process (including ones raised by the libraries
# imported above).
warnings.filterwarnings("ignore")


def _str_to_bool(value):
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below, with
        dashes in option names replaced by underscores
        (e.g. ``--actor-lr`` -> ``args.actor_lr``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="primorl")
    parser.add_argument("--task", type=str, default="dm2gym:CartpoleSwingup-v0")
    parser.add_argument("--seed", type=int, default=999999)
    parser.add_argument("--actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    # Without an explicit converter argparse stores the raw string, and any
    # non-empty string (including "False") is truthy, so the flag could never
    # be switched off. ``nargs="?"``/``const=True`` additionally lets a bare
    # ``--auto-alpha`` mean "enabled".
    parser.add_argument('--auto-alpha', type=_str_to_bool, nargs='?',
                        const=True, default=True)
    parser.add_argument('--target-entropy', type=int, default=-3)
    parser.add_argument('--alpha-lr', type=float, default=3e-4)

    # dynamics model's arguments
    parser.add_argument("--dynamics-lr", type=float, default=0.001)
    parser.add_argument("--n-ensembles", type=int, default=7)
    parser.add_argument("--n-elites", type=int, default=5)
    parser.add_argument("--reward-penalty-coef", type=float, default=1.0)
    parser.add_argument("--rollout-length", type=int, default=1)
    parser.add_argument("--rollout-batch-size", type=int, default=50000)
    parser.add_argument("--rollout-freq", type=int, default=1000)
    parser.add_argument("--model-retain-epochs", type=int, default=5)
    parser.add_argument("--real-ratio", type=float, default=0.0)
    parser.add_argument("--dynamics-model-dir", type=str, default=None)

    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    # ``--eval_episodes`` was the only underscore-spelled option; the dashed
    # form is now the primary spelling and the old one remains as an alias.
    # Both store into ``args.eval_episodes``.
    parser.add_argument("--eval-episodes", "--eval_episodes", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--log-freq", type=int, default=1000)
    parser.add_argument("--device", type=str, default="cpu")

    parser.add_argument("--noise-multiplier", type=float, default=0.0)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--model-batch-size", type=int, default=256)
    parser.add_argument("--poisson-q", type=float, default=0.001)
    parser.add_argument("--max-model-update-epochs-to-improve", type=int, default=10)
    parser.add_argument("--model-rounds", type=int, default=100)
    parser.add_argument("--save-dataset", action="store_true")
    parser.add_argument("--load-dataset", action="store_true")
    parser.add_argument("--hold-out-ratio", type=float, default=0.01)
    parser.add_argument("--max-model-train-iterations", type=int, default=100000)
    parser.add_argument("--fedavg-local-epochs", type=int, default=1)
    parser.add_argument("--fedavg-local-batch-size", type=int, default=16)
    parser.add_argument("--load-model-name", default=None)
    parser.add_argument("--continue-training", action="store_true")
    parser.add_argument("--no-clip", action="store_true")
    parser.add_argument("--uncertainty-estimator", type=str, default='max_diff')
    parser.add_argument("--clipping-mode", type=str, default='per_model')

    return parser.parse_args()


def train(args=None):
    """Run the full pipeline: build all components, fit the dynamics model, then the policy.

    Steps, in order: create the gym environment and obtain the offline
    dataset (freshly pre-processed or loaded from ``datasets/``), seed the
    RNGs, set up TensorBoard logging, import the task-specific ``static_fns``
    and ``config`` modules, build the SAC policy, the ``TransitionModel``, the
    offline/model replay buffers, the ``PriMORL`` algorithm and the
    ``Trainer``, and finally call ``train_dynamics()`` followed by
    ``train_policy()``.

    Args:
        args: Parsed argument namespace as produced by ``get_args()``. When
            ``None`` (the default) the command line is parsed at call time.
            The namespace is mutated: ``obs_shape``, ``action_dim`` and
            (when ``auto_alpha`` is set) ``target_entropy`` / ``alpha`` are
            written onto it.
    """
    # Resolve the default lazily. A ``get_args()`` call in the signature is
    # evaluated once at definition time, so merely importing this module
    # would already parse sys.argv.
    if args is None:
        args = get_args()

    # create env and dataset
    env = gym.make(args.task)
    dataset_path = os.path.join('datasets', args.task.replace(':', '') + 'preproc_dataset.npy')
    if not args.load_dataset:
        # Pre-process dataset
        dataset = qlearning_dataset_custom(env)
        if args.save_dataset:
            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
            # Only the copy written to disk is wrapped in an array. `dataset`
            # itself must stay indexable by key ("observations") for the
            # buffer creation below; rebinding it to the wrapped array broke
            # every run that used --save-dataset.
            np.save(dataset_path, np.array(dataset))
    else:
        # Load pre-processed dataset (stored as a pickled object inside a
        # 0-d array, hence allow_pickle and .item()).
        dataset = np.load(dataset_path, allow_pickle=True).item()
    env.observation_space = flatten_space(env.observation_space)
    env.unwrapped.observation_space = flatten_space(env.unwrapped.observation_space)
    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device != "cpu":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    env.seed(args.seed)

    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_{args.algo_name}'
    log_path = os.path.join(args.logdir, args.task, args.algo_name, log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = Logger(writer=writer, log_path=log_path)

    # Only the exact string 'cuda' selects device 0; anything else maps to -1.
    device_id = 0 if args.device == 'cuda' else -1
    set_device_and_logger(device_id, logger)

    # import configs
    # "dm2gym:CartpoleSwingup-v0" -> "CartpoleSwingup": drop the version
    # suffix and, for dm2gym tasks, the namespace prefix.
    task = args.task.split('-')[0]
    if task[:6] == 'dm2gym':
        task = task.split(':')[-1]
    import_path = f"static_fns.{task}"
    static_fns = importlib.import_module(import_path).StaticFns
    config_path = f"config.{task}"
    config = importlib.import_module(config_path).default_config

    # create policy model
    obs_dim = np.prod(args.obs_shape)
    actor_backbone = MLP(input_dim=obs_dim, hidden_dims=[256, 256])
    critic1_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=[256, 256])
    critic2_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=[256, 256])
    dist = DiagGaussian(
        latent_dim=actor_backbone.output_dim,
        output_dim=args.action_dim,
        unbounded=True,
        conditioned_sigma=True
    )

    actor = ActorProb(actor_backbone, dist, args.device)
    critic1 = Critic(critic1_backbone, args.device)
    critic2 = Critic(critic2_backbone, args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # A falsy target entropy (0) falls back to -|A|.
        target_entropy = args.target_entropy if args.target_entropy \
            else -np.prod(env.action_space.shape)

        args.target_entropy = target_entropy

        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        # With auto-alpha, `alpha` is handed to SACPolicy as a tuple instead
        # of the fixed float coming from the command line.
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    # create policy
    sac_device = args.device
    sac_policy = SACPolicy(
        actor,
        critic1,
        critic2,
        actor_optim,
        critic1_optim,
        critic2_optim,
        action_space=env.action_space,
        dist=dist,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        device=sac_device
    )

    # create dynamics model
    dynamics_model = TransitionModel(obs_space=env.observation_space,
                                     action_space=env.action_space,
                                     static_fns=static_fns,
                                     lr=args.dynamics_lr,
                                     noise_multiplier=args.noise_multiplier,
                                     max_grad_norm=args.max_grad_norm,
                                     local_epochs=args.fedavg_local_epochs,
                                     local_batch_size=args.fedavg_local_batch_size,
                                     load_model_name=args.load_model_name,
                                     no_clip=args.no_clip,
                                     reward_penalty_coef=args.reward_penalty_coef,
                                     uncertainty_estimator=args.uncertainty_estimator,
                                     clipping_mode=args.clipping_mode,
                                     **config["transition_params"]
    )

    # create buffer
    offline_buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32
    )
    offline_buffer.load_dataset(dataset)
    model_buffer = ReplayBuffer(
        buffer_size=args.rollout_batch_size * args.rollout_length * args.model_retain_epochs,
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32
    )

    # create PriMORL algo
    algo = PriMORL(
        sac_policy,
        dynamics_model,
        offline_buffer=offline_buffer,
        model_buffer=model_buffer,
        reward_penalty_coef=args.reward_penalty_coef,
        rollout_length=args.rollout_length,
        batch_size=args.batch_size,
        real_ratio=args.real_ratio,
        logger=logger,
        model_batch_size=args.model_batch_size,
        max_model_update_epochs_to_improve=args.max_model_update_epochs_to_improve,
        model_rounds=args.model_rounds,
        poisson_q=args.poisson_q,
        hold_out_ratio=args.hold_out_ratio,
        max_model_train_iterations=args.max_model_train_iterations,
        load_model_name=args.load_model_name,
        continue_training=args.continue_training,
        **config["primorl_params"]
    )

    # create trainer
    trainer = Trainer(
        algo,
        eval_env=env,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        rollout_freq=args.rollout_freq,
        logger=logger,
        log_freq=args.log_freq,
        eval_episodes=args.eval_episodes
    )

    # Train dynamics model on the whole dataset
    trainer.train_dynamics()

    # Optimize policy under the private model
    trainer.train_policy()


if __name__ == "__main__":
    # Script entry point: run the whole training pipeline using the options
    # given on the command line.
    train()
