"""
Offline model-based PPO (MOReL-style) trained on imagined world-model rollouts.

Based on CleanRL PPO.
"""
import argparse
import os
import random
import time
from distutils.util import strtobool
from functools import partial
from multiprocessing import dummy
from typing import Dict

import dm_env
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tree
from ml_collections import ConfigDict
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from rosmo.agent.mbrl.morel import utils
from rosmo.agent.mbrl.morel.utils import get_data_loader
from rosmo.agent.world_model import Dreamer, preprocessing
from rosmo.data.rl_unplugged import atari

from stable_baselines3.common.atari_wrappers import (  # isort:skip
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)


def parse_args():
  """Parse command-line flags for the offline model-based PPO run.

  Returns:
    An ``argparse.Namespace`` holding every flag below plus the derived
    ``minibatch_size`` (``batch_size // num_minibatches``).
  """

  def str2bool(value):
    # Same accepted spellings as `distutils.util.strtobool` (distutils is
    # deprecated by PEP 632 and removed in Python 3.12), returning a bool.
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
      return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
      return False
    raise ValueError(f"invalid truth value {value!r}")

  # fmt: off
  parser = argparse.ArgumentParser()
  parser.add_argument(
    "--exp-name",
    type=str,
    # splitext removes only the ".py" suffix; `str.rstrip(".py")` strips any
    # trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    default=os.path.splitext(os.path.basename(__file__))[0],
    help="the name of this experiment"
  )
  parser.add_argument(
    "--seed",
    type=int,
    default=int(time.time()),
    help="seed of the experiment"
  )
  parser.add_argument(
    "--torch-deterministic",
    type=str2bool,
    default=True,
    nargs="?",
    const=True,
    help="if toggled, `torch.backends.cudnn.deterministic=False`"
  )
  parser.add_argument(
    "--cuda",
    type=str2bool,
    default=True,
    nargs="?",
    const=True,
    help="if toggled, cuda will be enabled by default"
  )
  parser.add_argument(
    "--track",
    type=str2bool,
    default=False,
    nargs="?",
    const=True,
    help="if toggled, this experiment will be tracked with Weights and Biases"
  )
  parser.add_argument(
    "--wandb-project-name",
    type=str,
    default="cleanRL",
    help="the wandb's project name"
  )
  parser.add_argument(
    "--wandb-entity",
    type=str,
    default=None,
    help="the entity (team) of wandb's project"
  )
  parser.add_argument(
    "--capture-video",
    type=str2bool,
    default=False,
    nargs="?",
    const=True,
    help="whether to capture videos of the agent performances (check out `videos` folder)"
  )

  # Algorithm specific arguments
  parser.add_argument(
    "--env-id",
    type=str,
    default="MsPacmanNoFrameskip-v0",
    help="the id of the environment"
  )
  parser.add_argument(
    "--total-timesteps",
    type=int,
    default=1_000_000,
    help="total timesteps of the experiments"
  )
  parser.add_argument(
    "--learning-rate",
    type=float,
    default=2.5e-4,
    help="the learning rate of the optimizer"
  )
  parser.add_argument(
    "--num-envs",
    type=int,
    default=50,
    help="the number of parallel game environments"
  )
  parser.add_argument(
    "--num-steps",
    type=int,
    default=50,
    help="the number of steps to run in each environment per policy rollout"
  )
  parser.add_argument("--batch-size", type=int, default=50, help="batch size")
  parser.add_argument(
    "--anneal-lr",
    type=str2bool,
    default=True,
    nargs="?",
    const=True,
    help="Toggle learning rate annealing for policy and value networks"
  )
  parser.add_argument(
    "--gamma", type=float, default=0.99, help="the discount factor gamma"
  )
  parser.add_argument(
    "--gae-lambda",
    type=float,
    default=0.95,
    help="the lambda for the general advantage estimation"
  )
  parser.add_argument(
    "--num-minibatches",
    type=int,
    default=4,
    help="the number of mini-batches"
  )
  parser.add_argument(
    "--update-epochs",
    type=int,
    default=4,
    help="the K epochs to update the policy"
  )
  parser.add_argument(
    "--norm-adv",
    type=str2bool,
    default=True,
    nargs="?",
    const=True,
    help="Toggles advantages normalization"
  )
  parser.add_argument(
    "--clip-coef",
    type=float,
    default=0.1,
    help="the surrogate clipping coefficient"
  )
  parser.add_argument(
    "--clip-vloss",
    type=str2bool,
    default=True,
    nargs="?",
    const=True,
    help="Toggles whether or not to use a clipped loss for the value function, as per the paper."
  )
  parser.add_argument(
    "--ensemble",
    type=str2bool,
    default=False,
    nargs="?",
    const=True,
    help="if toggled, penalise and truncate imagined rollouts where the world-model ensemble disagrees"
  )
  parser.add_argument(
    "--ent-coef", type=float, default=0.01, help="coefficient of the entropy"
  )
  parser.add_argument(
    "--vf-coef",
    type=float,
    default=0.5,
    help="coefficient of the value function"
  )
  parser.add_argument(
    "--negative-reward", type=float, default=0.0, help="reward penalty"
  )
  parser.add_argument(
    "--truncate-lim", type=float, default=24.7545, help="truncate threshold"
  )

  parser.add_argument(
    "--max-grad-norm",
    type=float,
    default=0.5,
    help="the maximum norm for the gradient clipping"
  )
  parser.add_argument(
    "--target-kl",
    type=float,
    default=None,
    help="the target KL divergence threshold"
  )
  args = parser.parse_args()
  # Unlike CleanRL, `batch_size` is taken from the CLI (it is also the offline
  # data loader batch size) rather than derived as num_envs * num_steps.
  args.minibatch_size = int(args.batch_size // args.num_minibatches)
  # fmt: on
  return args


def make_env(env_id, seed, idx, capture_video, run_name):
  """Return a zero-argument factory that builds one wrapped Atari env.

  Wrapping order: episode statistics, video recording (only when
  ``capture_video`` is set and ``idx == 0``), ``NoopResetEnv(noop_max=30)``,
  ``MaxAndSkipEnv(skip=4)``, ``EpisodicLifeEnv``, ``FireResetEnv`` (only for
  games whose action meanings include FIRE), ``ClipRewardEnv``, a 64x64
  resize, grayscale and a one-frame stack. The env and its action and
  observation spaces are all seeded with ``seed``.
  """
  record_video = capture_video and idx == 0

  def _build():
    env = gym.wrappers.RecordEpisodeStatistics(gym.make(env_id))
    if record_video:
      env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
    env = MaxAndSkipEnv(NoopResetEnv(env, noop_max=30), skip=4)
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
      env = FireResetEnv(env)
    env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (64, 64))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 1)
    env.seed(seed)
    for space in (env.action_space, env.observation_space):
      space.seed(seed)
    return env

  return _build


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
  """Initialise ``layer`` in place and hand it back.

  The weight gets an orthogonal init scaled by ``std``; the bias is filled
  with the constant ``bias_const``.
  """
  nn.init.orthogonal_(layer.weight, gain=std)
  nn.init.constant_(layer.bias, val=bias_const)
  return layer


class Agent(nn.Module):
  """Actor-critic with a shared Nature-DQN-style convolutional trunk.

  The trunk expects single-channel 64x64 inputs (its flattened size is
  64 * 4 * 4). Observations are fed in as given, with no division by 255,
  so callers are responsible for any scaling.
  """

  def __init__(self, envs):
    super().__init__()
    num_actions = envs.single_action_space.n
    # Keep this exact Sequential layout: module indices define the
    # state_dict keys, and construction order fixes RNG consumption.
    self.network = nn.Sequential(
      layer_init(nn.Conv2d(1, 32, 8, stride=4)),
      nn.ReLU(),
      layer_init(nn.Conv2d(32, 64, 4, stride=2)),
      nn.ReLU(),
      layer_init(nn.Conv2d(64, 64, 3, stride=1)),
      nn.ReLU(),
      nn.Flatten(),
      layer_init(nn.Linear(64 * 4 * 4, 512)),  # 64x64 input -> 4x4 maps
      nn.ReLU(),
    )
    self.actor = layer_init(nn.Linear(512, num_actions), std=0.01)
    self.critic = layer_init(nn.Linear(512, 1), std=1)

  def get_value(self, x):
    """Return critic outputs of shape (N, 1) for a batch ``x``."""
    features = self.network(x)
    return self.critic(features)

  def get_action_and_value(self, x, action=None):
    """Evaluate the policy and value heads on a batch ``x``.

    When ``action`` is None, an action is sampled from the categorical
    policy; otherwise the given actions are scored.

    Returns:
      ``(action, log_prob(action), entropy, value)``.
    """
    features = self.network(x)
    dist = Categorical(logits=self.actor(features))
    chosen = dist.sample() if action is None else action
    return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(features)


def evaluate(eval_env: dm_env.Environment, num_episodes=2, policy=None):
  """Roll the policy out in ``eval_env`` and average episode statistics.

  Args:
    eval_env: environment whose ``reset()``/``step(action)`` return timesteps
      exposing ``observation``, ``reward`` and ``last()``. Observations are
      treated as HWC images: they are permuted to CHW, given a batch
      dimension and divided by 255 before reaching the policy.
    num_episodes: number of episodes to run.
    policy: actor-critic exposing ``get_action_and_value`` and
      ``parameters()``. Defaults to the module-level ``agent`` created in
      ``__main__`` (the previous implicit behaviour).

  Returns:
    Tuple ``(mean episode return, mean episode length)``.
  """
  if policy is None:
    policy = agent  # module-level global created in __main__
  # Follow the policy's device instead of hard-coding "cuda", so evaluation
  # also works when training fell back to CPU.
  device = next(policy.parameters()).device
  returns = []
  lengths = []
  # Inference only: don't build autograd graphs for every environment step.
  with torch.no_grad():
    for _ in range(num_episodes):
      episode_return = 0.
      episode_length = 0
      timestep = eval_env.reset()
      while not timestep.last():
        obs = torch.tensor(timestep.observation).permute(2, 0, 1)[None,
                                                                  ...] / 255.
        action, _, _, _ = policy.get_action_and_value(obs.to(device))
        # NOTE(review): `action` is passed on as a 1-element torch tensor on
        # `device`, exactly as before; confirm the env accepts that type.
        timestep = eval_env.step(action)
        episode_return += timestep.reward
        episode_length += 1
      returns.append(episode_return)
      lengths.append(episode_length)
  return np.mean(returns), np.mean(lengths)


# Per-game threshold on world-model ensemble disagreement. In __main__, an
# entry here takes precedence over the `--truncate-lim` CLI flag.
_truncate_lims = {
  "MsPacman": 24,
}
# Per-game world-model checkpoint paths. __main__ builds one Dreamer per path
# and loads `model_state_dict` from each file.
wm_paths = {
  "MsPacman":
    [
      "./wm/14kcebow-train_state_290001.pt",
      "./wm/2n4ccnd8-train_state_290001.pt"
    ],
}

if __name__ == "__main__":
  # Offline, MOReL-style PPO: the policy is optimised only on imagined
  # rollouts from pretrained Dreamer world models (started from states
  # inferred on offline data) and is periodically evaluated in the real env.
  args = parse_args()
  run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
  PESSIMISTIC = args.ensemble
  config = ConfigDict()
  config.update(utils.CONFIG)
  config.image_channels = config.stack_size
  config.game_name = "MsPacman"
  config.algo = "morel"
  config.data_percentage = 100
  config.batch_length = 10
  config.batch_size = args.batch_size
  config.negative_reward = args.negative_reward

  path = wm_paths.get(
    config.game_name,
    [
      "./wm/14kcebow-train_state_290001.pt",
      "./wm/2n4ccnd8-train_state_290001.pt",
    ]
  )
  # NOTE(review): a game listed in `_truncate_lims` overrides
  # `--truncate-lim`, so the CLI flag is ignored for those games.
  config.truncate_lim = _truncate_lims.get(config.game_name, args.truncate_lim)
  print("*" * 10, f"Truncate Limit: {config.truncate_lim}", "*" * 10)

  env = utils.get_environment(config)
  env_spec = utils.make_environment_spec(env)
  config.action_dim = env_spec.actions.num_values
  dataloader = get_data_loader(config, env, half=False)

  def _new_preprocess(data):
    # One-hot encode the offline actions in place before feeding the model.
    data["action"] = preprocessing.to_onehot(data["action"], config.action_dim)
    return data

  cast_tensor = lambda x: torch.tensor(x)

  if args.track:
    import wandb

    wandb.init(
      project=args.wandb_project_name,
      entity=args.wandb_entity,
      sync_tensorboard=True,
      config={
        **vars(args),
        **config.to_dict()
      },
      name=run_name,
      monitor_gym=False,
      save_code=True,
    )
  writer = SummaryWriter(f"runs/{run_name}")
  writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" %
    ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
  )

  # TRY NOT TO MODIFY: seeding
  random.seed(args.seed)
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  torch.backends.cudnn.deterministic = args.torch_deterministic

  device = torch.device(
    "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
  )

  # One world model per checkpoint path. `map_location` lets checkpoints
  # saved on GPU be loaded when running on CPU.
  models = [Dreamer(config) for _ in range(len(path))]
  for i, m in enumerate(models):
    m.to(device).load_state_dict(
      torch.load(path[i], map_location=device)["model_state_dict"]
    )

  print("*" * 10, "Model loaded.", "*" * 10)

  # env setup: this vector env is never stepped here; it only supplies the
  # observation/action space metadata used below.
  dummy_envs = gym.vector.SyncVectorEnv(
    [
      make_env(args.env_id, args.seed + i, i, args.capture_video, run_name)
      for i in range(args.num_envs)
    ]
  )
  assert isinstance(
    dummy_envs.single_action_space, gym.spaces.Discrete
  ), "only discrete action space is supported"

  eval_env = atari.environment(
    game=config["game_name"],
    stack_size=config["stack_size"],
    screen_size=64,
  )

  agent = Agent(dummy_envs).to(device)
  optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

  # TRY NOT TO MODIFY: start the game
  global_step = 0
  start_time = time.time()
  num_updates = args.total_timesteps

  for update in range(0, num_updates):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
      # `update` is 0-based, so this decays linearly from 1.0 down to
      # 1/num_updates (the former `update - 1.0` made the first lr exceed
      # the base learning rate).
      frac = 1.0 - update / num_updates
      lrnow = frac * args.learning_rate
      optimizer.param_groups[0]["lr"] = lrnow

    # Init state using offline data.
    data_batch = next(dataloader)
    data_batch = _new_preprocess(data_batch)
    data_batch: Dict[str, Tensor] = tree.map_structure(
      lambda x: cast_tensor(x).to(device), data_batch
    )

    all_observations = []
    all_logprobs = []
    all_actions = []
    all_actions_one_hot = []
    all_values = []

    all_rewards = []
    all_dones = []
    all_next_dones = []
    all_next_values = []

    # Generate trajectories.
    global_step += args.num_steps * args.num_envs
    for idx, model in enumerate(models):
      zero_state = model.init_state(config.batch_size * config.iwae_samples)
      _, rnn_state = model.forward(data_batch, zero_state)

      observations, actions, logprobs, values, rewards, terminals, actions_one_hot = model.morel_policy_rollout(
        agent, rnn_state, imag_horizon=args.num_steps + 1
      )
      values = values.squeeze()
      # The rollout is one step longer than the training window. Read the
      # bootstrap value/done from that extra step *before* truncating;
      # reading `values[-1]` after truncation bootstrapped the last step
      # from its own value.
      next_value = values[-1]
      next_done = terminals[-1]
      observations = observations[:args.num_steps]
      actions = actions[:args.num_steps]
      actions_one_hot = actions_one_hot[:args.num_steps]
      logprobs = logprobs[:args.num_steps]
      values = values[:args.num_steps]
      rewards = rewards[:args.num_steps]
      dones = terminals[:args.num_steps]

      all_observations.append(observations)
      all_logprobs.append(logprobs)
      all_actions.append(actions)
      all_actions_one_hot.append(actions_one_hot)
      all_values.append(values)

      all_rewards.append(rewards)
      all_dones.append(dones)
      all_next_dones.append(next_done)
      all_next_values.append(next_value)

      if True:  #not PESSIMISTIC:
        # Just use single model even enable uncertainty estimation: only the
        # first model generates rollouts; the others serve the
        # PESSIMISTIC disagreement check below.
        break

    all_observations = torch.cat(all_observations, dim=1)
    all_logprobs = torch.cat(all_logprobs, dim=1)
    all_actions = torch.cat(all_actions, dim=1)
    all_actions_one_hot = torch.cat(all_actions_one_hot, dim=1)
    all_values = torch.cat(all_values, dim=1)

    all_rewards = torch.cat(all_rewards, dim=1)
    all_dones = (torch.cat(all_dones, dim=1) > 0.5).float()

    all_next_dones = torch.cat(all_next_dones, dim=0)
    all_next_values = torch.cat(all_next_values, dim=0)

    # Discard cross trajectory transitions predicted from the model: mask
    # every step after the first predicted terminal of each trajectory.
    # NOTE(review): when that first terminal is the final step, `st -= 1`
    # masks the terminal step itself; confirm this is intended.
    loss_mask = torch.ones_like(all_rewards)
    for b in range(all_rewards.shape[1]):  # sum_B
      if 1 in all_dones[:, b]:
        st = torch.where(all_dones[:, b])[0][0] + 1
        if st == all_dones.shape[0]:
          st -= 1
        loss_mask[st:, b] = 0.

    # Model ensemble cross validate.
    morel_stats = {
      "disc_mean": 0.,
      "disc_std": 0.,
      "violate_count": 0.,
      "loss_mask": loss_mask.sum().detach().cpu().numpy().item(),
    }
    if PESSIMISTIC:
      buffer_data = {
        "image": all_observations - 0.5,
        "action": all_actions_one_hot,
        "reward": all_rewards,
        "terminal": (all_dones > 0.5).float(),
        "reset": (all_dones > 0.5),
      }
      # Max pairwise feature distance between ensemble members, per step.
      pred_err = torch.zeros_like(all_rewards)
      for idx_1, model_1 in enumerate(models):
        for idx_2, model_2 in enumerate(models):
          if idx_2 > idx_1:
            with torch.no_grad():
              rnn_state = model_1.init_state(all_rewards.shape[1])
              feat_1, _ = model_1.forward(buffer_data, rnn_state)
              rnn_state = model_2.init_state(all_rewards.shape[1])
              feat_2, _ = model_2.forward(buffer_data, rnn_state)
              model_err = torch.norm(feat_1 - feat_2, dim=-1).squeeze()
              pred_err = torch.maximum(pred_err, model_err)  # (T, sum_B)

      violations_idx = torch.where(pred_err > config.truncate_lim)
      morel_stats["disc_mean"] = pred_err.mean().detach().cpu().numpy().item()
      morel_stats["disc_std"] = pred_err.std().detach().cpu().numpy().item()
      morel_stats["violate_count"] = len(violations_idx[0])
      max_t, max_b = pred_err.shape
      # At the first violating step: penalise the reward, mark it terminal,
      # and mask every later step out of the loss.
      for b in range(max_b):  # sum_B
        if b in violations_idx[1]:
          st_t = violations_idx[0][torch.where(b == violations_idx[1]
                                              )[0]].min()
          # NOTE(review): st_t is a time index < max_t, so this never fires.
          if st_t == max_t:
            st_t -= 1
          loss_mask[st_t + 1:, b] = 0.
          all_rewards[st_t, b] = config.negative_reward
          all_dones[st_t, b] = 1.
      morel_stats["loss_mask"] = loss_mask.sum().detach().cpu().numpy().item()

    # bootstrap value if not done
    # Both branches use all_dones[t] (the done flag of step t's own
    # transition); `all_next_dones` is currently unused.
    all_advantages = torch.zeros_like(all_rewards).to(device)
    lastgaelam = 0
    for t in reversed(range(args.num_steps)):
      if t == args.num_steps - 1:
        nextnonterminal = 1.0 - all_dones[t]
        nextvalues = all_next_values
      else:
        nextnonterminal = 1.0 - all_dones[t]
        nextvalues = all_values[t + 1]

      delta = all_rewards[
        t] + args.gamma * nextvalues * nextnonterminal - all_values[t]
      all_advantages[
        t
      ] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
    all_returns = all_advantages + all_values

    # flatten the batch
    b_obs = all_observations.reshape(
      (-1,) + dummy_envs.single_observation_space.shape
    )
    b_logprobs = all_logprobs.reshape(-1)
    b_actions = all_actions.reshape(
      (-1,) + dummy_envs.single_action_space.shape
    )
    b_advantages = all_advantages.reshape(-1)
    b_returns = all_returns.reshape(-1)
    b_values = all_values.reshape(-1)
    b_loss_mask = loss_mask.reshape(-1)

    # Optimizing the policy and value network
    b_inds = np.arange(len(b_obs))
    clipfracs = []
    for epoch in range(args.update_epochs):
      np.random.shuffle(b_inds)
      # NOTE(review): `args.batch_size` is the offline data batch size, so
      # this loop visits only the first `args.batch_size` shuffled indices
      # out of len(b_inds) = num_steps x rollout-batch samples. Confirm
      # that this subsampling is intended.
      for start in range(0, args.batch_size, args.minibatch_size):
        end = start + args.minibatch_size
        mb_inds = b_inds[start:end]

        _, newlogprob, entropy, newvalue = agent.get_action_and_value(
          b_obs[mb_inds],
          b_actions.long()[mb_inds]
        )
        logratio = newlogprob - b_logprobs[mb_inds]
        ratio = logratio.exp()

        with torch.no_grad():
          # calculate approx_kl http://joschu.net/blog/kl-approx.html
          old_approx_kl = (-logratio).mean()
          approx_kl = ((ratio - 1) - logratio).mean()
          clipfracs += [
            ((ratio - 1.0).abs() > args.clip_coef).float().mean().item()
          ]

        mb_advantages = b_advantages[mb_inds]
        if args.norm_adv:
          mb_advantages = (mb_advantages - mb_advantages.mean()) / (
            mb_advantages.std() + 1e-8
          )

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(
          ratio, 1 - args.clip_coef, 1 + args.clip_coef
        )
        pg_loss = b_loss_mask[mb_inds] * torch.max(pg_loss1, pg_loss2)
        pg_loss = pg_loss.mean()

        # Value loss
        newvalue = newvalue.view(-1)
        if args.clip_vloss:
          v_loss_unclipped = (newvalue - b_returns[mb_inds])**2
          v_clipped = b_values[mb_inds] + torch.clamp(
            newvalue - b_values[mb_inds],
            -args.clip_coef,
            args.clip_coef,
          )
          v_loss_clipped = (v_clipped - b_returns[mb_inds])**2
          v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
          v_loss = 0.5 * (b_loss_mask[mb_inds] * v_loss_max).mean()
        else:
          v_loss = 0.5 * (
            b_loss_mask[mb_inds] * (newvalue - b_returns[mb_inds])**2
          ).mean()

        entropy_loss = (b_loss_mask[mb_inds] * entropy).mean()

        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()

      if args.target_kl is not None:
        if approx_kl > args.target_kl:
          break
    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(
      y_true - y_pred
    ) / var_y

    # TRY NOT TO MODIFY: record rewards for plotting purposes

    if update % config.eval_interval == 0:
      avg_returns, avg_length = evaluate(eval_env, 2)
      print("eval:", avg_returns, avg_length)
      writer.add_scalar("charts/eval_return", avg_returns, global_step)
      writer.add_scalar("charts/eval_length", avg_length, global_step)

    if update % config.log_interval == 0:
      for k, v in morel_stats.items():
        writer.add_scalar(f"morel/{k}", v, global_step)
      writer.add_scalar(
        "charts/learning_rate", optimizer.param_groups[0]["lr"], global_step
      )
      writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
      writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
      writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
      writer.add_scalar(
        "losses/old_approx_kl", old_approx_kl.item(), global_step
      )
      writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
      writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
      writer.add_scalar(
        "losses/explained_variance", explained_var, global_step
      )
      print("SPS:", int(global_step / (time.time() - start_time)))
      writer.add_scalar(
        "charts/SPS", int(global_step / (time.time() - start_time)),
        global_step
      )
      # Undiscounted imagined return: sum rewards over time (dim 0), then
      # average over trajectories. `rewards` is the rollout model's
      # un-penalised (T, B) reward tensor (the previous `sum(-1).mean(0)`
      # summed over the batch and averaged over time instead).
      mdp_return = rewards.sum(0).mean().to("cpu").item()
      print("Learned model return:", mdp_return)
      writer.add_scalar("charts/model_return", mdp_return, global_step)

  del dataloader
