from concurrent.futures import Future
from typing import Any, Tuple

import jax
import numpy as np
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
import tree
from acme.types import Transition

from rosmo.agent.world_model import preprocessing

from .a2c import *
from .common import *
from .decoders import *
from .encoders import *
from .functions import *
from .probes import *
from .rnn import *
from .rssm import *
from .tools import *

# Default hyper-parameters for the Dreamer world model / agent. Values are
# consumed through an attribute-style config object elsewhere; `action_dim`
# is intentionally absent and must be supplied per environment.
DREAMER_CONFIG = dict(
  # ---- Features ----------------------------------------------------------
  image_key="image",
  image_size=64,
  image_channels=1,
  image_categorical=False,
  # ---- Training ----------------------------------------------------------
  reset_interval=200,
  iwae_samples=1,
  kl_balance=0.8,
  kl_weight=0.1,  # For atari.
  image_weight=1.0,
  vecobs_weight=1.0,
  reward_weight=1.0,
  terminal_weight=1.0,
  adam_lr=3e-4,
  adam_lr_actor=1e-4,
  adam_lr_critic=1e-4,
  adam_eps=1e-5,
  keep_state=True,
  batch_length=50,
  batch_size=50,
  device="cuda:0",
  grad_clip=200,
  grad_clip_ac=200,
  image_decoder_min_prob=0,
  amp=True,
  # ---- Model -------------------------------------------------------------
  model="dreamer",
  deter_dim=1024,  # For atari.
  stoch_dim=32,
  stoch_discrete=32,
  hidden_dim=1000,
  gru_layers=1,
  gru_type="gru",
  layer_norm=True,
  vecobs_size=0,
  image_encoder="cnn",
  cnn_depth=48,
  image_encoder_layers=0,
  image_decoder="cnn",
  image_decoder_layers=0,
  reward_input=False,
  reward_decoder_layers=4,
  reward_decoder_categorical=None,
  terminal_decoder_layers=4,
  map_stoch_dim=64,
  probe_model="none",
  map_decoder="dense",
  map_hidden_layers=4,
  map_hidden_dim=1024,
  # ---- Actor Critic ------------------------------------------------------
  gamma=0.995,
  lambda_gae=0.95,
  entropy=0.003,
  target_interval=100,
  imag_horizon=15,
  actor_grad="reinforce",
  actor_dist="onehot",
)


class Dreamer(nn.Module):
  """World-model-only Dreamer agent.

  Wraps a `WorldModel` and exposes helpers to train it and to imagine
  rollouts inside it with externally supplied policies. No actor/critic
  networks are created by this class.
  """

  def __init__(self, conf):
    """Build the world model from `conf`.

    Args:
      conf: attribute-style config; must define `action_dim > 0` plus the
        world-model settings read by `WorldModel`.
    """
    super().__init__()
    assert conf.action_dim > 0, "Need to set action_dim to match environment"

    self.config = conf
    self.wm = WorldModel(conf)
    self.action_dim = conf.action_dim

  def init_optimizers(self, lr, lr_actor=None, lr_critic=None, eps=1e-5):
    """Return an AdamW optimizer over the world-model parameters.

    `lr_actor` and `lr_critic` are accepted for interface compatibility but
    are unused: this agent has no actor-critic parameters.
    """
    return torch.optim.AdamW(self.wm.parameters(), lr=lr, eps=eps)

  def grad_clip(self, grad_clip, grad_clip_ac=None):
    """Clip world-model gradients in place to total norm `grad_clip`.

    Returns a metrics dict holding the total gradient norm reported by
    `clip_grad_norm_` under 'grad_norm'. `grad_clip_ac` is unused.
    """
    return {
      'grad_norm': nn.utils.clip_grad_norm_(self.wm.parameters(), grad_clip),
    }

  def init_state(self, batch_size: int):
    """Return the initial recurrent state of the world model."""
    return self.wm.init_state(batch_size)

  def forward(
    self,
    obs: Dict[str, Tensor],
    in_state: Any,
  ) -> Tuple[Tensor, Any]:
    """Run the world model over `obs` without tracking gradients.

    Args:
      obs: observation dict; must contain the previous 'action'.
      in_state: recurrent state from `init_state` or a previous call.

    Returns:
      `(features, out_state)` as produced by `WorldModel.forward`.
    """
    assert 'action' in obs, 'Observation should contain previous action'
    with torch.no_grad():
      features, out_state = self.wm.forward(obs, in_state)
    return features, out_state

  def training_step(
    self,
    obs: Dict[str, Tensor],
    in_state: Any,
    iwae_samples: int = 1,
    imag_horizon: int = 1,
    do_open_loop=False,
    do_image_pred=False,
    do_dream_tensors=False,
  ):
    """Compute the world-model loss and optional logging tensors.

    Args:
      obs: dict with 'action', 'reward', 'reset' and 'terminal' entries plus
        whatever the encoder/decoders read ('image' when dreaming).
      in_state: initial recurrent state.
      iwae_samples: posterior samples per step, forwarded to the world model.
      imag_horizon: accepted for interface compatibility; unused here.
      do_open_loop: forwarded to the world model.
      do_image_pred: forwarded to the world model.
      do_dream_tensors: also imagine from the first step with the real
        actions and decode reward/terminal/image predictions for logging.

    Returns:
      `((loss_model,), out_state, metrics, tensors, dream_tensors)`.
    """
    assert 'action' in obs, '`action` required in observation'
    assert 'reward' in obs, '`reward` required in observation'
    assert 'reset' in obs, '`reset` required in observation'
    assert 'terminal' in obs, '`terminal` required in observation'
    T = obs['action'].shape[0]

    loss_model, features, states, out_state, metrics, tensors = \
        self.wm.training_step(obs,
                              in_state,
                              iwae_samples=iwae_samples,
                              do_open_loop=do_open_loop,
                              do_image_pred=do_image_pred)

    # Dream for a log sample.
    dream_tensors = {}
    if do_dream_tensors and self.wm.decoder.image is not None:
      # careful not to invoke modules first time under no_grad (https://github.com/pytorch/pytorch/issues/60164)
      with torch.no_grad():
        # Only dream from the first step (much smaller than dreaming from every
        # step). With H = T-1 the dream has T steps and aligns with the real
        # sequence.
        in_state_dream: StateB = map_structure(
          states, lambda x: x.detach()[0, :, 0]
        )  # type: ignore  # (T,B,I) => (B)
        features_dream, rewards_dream, terminals_dream = self.dream(
          in_state_dream, obs["action"][1:], T - 1
        )
        image_dream = self.wm.decoder.image.forward(features_dream)

        # Named like the entries in `tensors` so logged outputs look the same
        # for dreamed and non-dreamed predictions.
        dream_tensors = dict(
          reward_pred=rewards_dream.mean,
          terminal_pred=terminals_dream.mean,
          image_pred=image_dream,
        )
        assert dream_tensors['image_pred'].shape == obs['image'].shape

    return (loss_model,), out_state, metrics, tensors, dream_tensors

  def _freeze_world_model(self):
    """Disable grads on all world-model parameters; return their old flags."""
    flags = [p.requires_grad for p in self.wm.parameters()]
    self.wm.requires_grad_(False)
    return flags

  def _restore_world_model(self, flags):
    """Restore `requires_grad` flags recorded by `_freeze_world_model`."""
    for param, flag in zip(self.wm.parameters(), flags):
      param.requires_grad_(flag)

  def morel_policy_rollout(self, agent, in_state: StateB,
                           imag_horizon) -> Tuple[Tensor, ...]:
    """Imagine `imag_horizon` steps with a torch `agent` acting on decoded images.

    Args:
      agent: object with `get_action_and_value(obs)` returning
        `(action, logprob, _, value)` for a batch of images in [0, 1].
      in_state: starting RSSM state (a tuple unpacked into `to_feature`).
      imag_horizon: number of imagined steps.

    Returns:
      `(observations, actions, logprobs, values, rewards, terminals,
      actions_one_hot)`, each stacked along a leading time dimension of
      length `imag_horizon`. Rewards/terminals are the decoder means for the
      state reached after each action.
    """
    features = []
    observations = []
    actions = []
    actions_one_hot = []
    logprobs = []
    values = []
    state = in_state

    # Freeze the world model for the rollout, then restore its previous flags
    # so later world-model training still receives gradients.
    frozen_flags = self._freeze_world_model()
    try:
      with torch.no_grad():
        for _ in range(imag_horizon):
          feature = self.wm.core.to_feature(*state)
          features.append(feature)

          image_dream = self.wm.decoder.image.forward(
            feature[None]
          )[0, :, ...].detach()  # presumably [B, C, H, W]
          # Decoded images are shifted by +0.5 and clipped into [0, 1].
          obs = (image_dream + 0.5).clip(0, 1.)
          action, logprob, _, value = agent.get_action_and_value(obs)
          observations.append(obs)
          actions.append(action)
          logprobs.append(logprob)
          values.append(value)
          # One-hot encode the discrete action by indexing an identity matrix.
          action_one_hot = torch.eye(
            self.action_dim, dtype=torch.float32, device=action.device
          )[action]
          actions_one_hot.append(action_one_hot)
          _, state = self.wm.core.cell.forward_prior(action_one_hot, None, state)

        observations = torch.stack(observations)
        actions = torch.stack(actions)
        logprobs = torch.stack(logprobs)
        values = torch.stack(values)
        actions_one_hot = torch.stack(actions_one_hot)

        features.append(self.wm.core.to_feature(*state))
        # Skip the starting state's feature: predictions refer to the state
        # reached after each action.
        features = torch.stack(features[1:])

        rewards = self.wm.decoder.reward.forward(features).mean
        terminals = self.wm.decoder.terminal.forward(features).mean
    finally:
      self._restore_world_model(frozen_flags)

    return observations, actions, logprobs, values, rewards, terminals, actions_one_hot

  def combo_policy_rollout(
    self,
    batch,
    actor,
    imag_horizon,
    device,
    img_size=64
  ) -> list:
    """Imagine up to `imag_horizon` steps with a JAX actor and collect Transitions.

    The world model is first run over the real `batch`, whose last action is
    replaced by the actor's choice for the last real image. It is then
    stepped one timestep at a time from its own decoded images. Batch entries
    whose predicted terminal mean exceeds 0.5 are dropped.

    Args:
      batch: dict of arrays with time-major leading dims; 'image' is
        [T, B, C, H, W] in [-0.5, 0.5] and 'action' presumably holds
        integer actions. The caller's dict is not modified.
      actor: object exposing `batch_select_action(observations)`, called
        with [B, H, W, C] images scaled to [0, 255].
      imag_horizon: maximum number of transitions to produce (at least one
        is always produced).
      device: torch device for world-model inputs.
      img_size: side length the actor expects; images are bilinearly
        resized from 64 when different.

    Returns:
      A list of acme `Transition`s, one per imagined step. Later entries only
      cover the batch elements that have not terminated yet.
    """

    def _resize(image):
      # Resize only when the actor expects a size other than 64.
      if img_size != 64:
        image = jax.image.resize(
          image, (image.shape[0], img_size, img_size, image.shape[-1]),
          "bilinear"
        )
      return image

    def _to_actor_obs(images):
      # [B, C, H, W] in [-0.5, 0.5] -> [B, H, W, C] in [0, 255], resized.
      images = np.transpose(images, (0, 2, 3, 1))
      images = np.clip((images + 0.5) * 255, 0, 255)
      return _resize(images)

    def _world_model_step(wm_batch, state):
      # Run the world model over `wm_batch` and decode its last timestep.
      wm_batch = tree.map_structure(
        lambda x: torch.tensor(x).to(device), wm_batch
      )
      features, state = self.wm.forward(wm_batch, state)
      features = features[:, :, 0]
      rewards = self.wm.decoder.reward.forward(features).mean
      terminals = self.wm.decoder.terminal.forward(features).mean
      images = self.wm.decoder.image.forward(features)
      last_images = images.detach().cpu().numpy()[-1]  # [B, C, H, W]
      return (
        last_images,
        _to_actor_obs(last_images),
        rewards.detach().cpu().numpy()[-1],  # [B,]
        terminals.detach().cpu().numpy()[-1] > 0.5,  # [B,] bool
        state,
      )

    def _transition(observations, actions, rewards, terminals, next_obs):
      return Transition(
        observation=jax.device_get(observations),
        action=actions,
        reward=rewards,
        discount=1 - terminals,  # integer 0/1 array from the bool terminals
        next_observation=jax.device_get(next_obs),
      )

    # Shallow copy so the caller's "action" entry is not overwritten.
    batch = dict(batch)
    batch_size = batch["image"].shape[1]

    # The actor picks the final action from the last real observation.
    observations = _to_actor_obs(batch["image"][-1])
    actions = actor.batch_select_action(observations)  # [B,]
    batch["action"] = preprocessing.to_onehot(
      np.concatenate([batch["action"][:-1], actions[None]], axis=0),
      self.config.action_dim
    )

    transitions = []
    # Inference only. Without no_grad the recurrent state would keep the
    # whole imagined graph alive across the horizon.
    with torch.no_grad():
      state = self.init_state(batch_size * self.config.iwae_samples)
      last_images, next_obs, rewards, terminals, state = _world_model_step(
        batch, state
      )
      transitions.append(
        _transition(observations, actions, rewards, terminals, next_obs)
      )

      for _ in range(imag_horizon - 1):
        # Boolean mask of still-running entries. An integer 0/1 array would be
        # interpreted as fancy indices rather than as a mask.
        alive = ~terminals
        if not alive.any():
          break
        observations = next_obs[alive]  # [B', H, W, C]
        actions = actor.batch_select_action(observations)  # [B',]
        wm_batch = {
          "image": last_images[alive][None],
          "action": preprocessing.to_onehot(
            actions[None], self.config.action_dim
          ),
          "reset": terminals[alive][None],
        }
        # Materialize as a tuple (a generator could only be consumed once);
        # assumes each state component is a torch tensor batched on dim 0.
        state = tuple(
          x[torch.as_tensor(alive, device=x.device)] for x in state
        )
        last_images, next_obs, rewards, terminals, state = _world_model_step(
          wm_batch, state
        )
        transitions.append(
          _transition(observations, actions, rewards, terminals, next_obs)
        )

    return transitions

  def dream(
    self,
    in_state: StateB,
    given_actions,
    imag_horizon: int,
    policy=None,
  ):
    """Roll the RSSM prior forward open-loop with `given_actions`.

    Args:
      in_state: starting RSSM state (a tuple unpacked into `to_feature`).
      given_actions: indexable sequence with at least `imag_horizon` actions.
      imag_horizon: number of prior steps.
      policy: unused; kept for interface compatibility.

    Returns:
      `(features, rewards, terminals)`. `features` stacks the
      `imag_horizon + 1` visited states; rewards/terminals are the decoder
      outputs for those features.
    """
    features = []
    state = in_state
    # Freeze the world model so dynamics gradients do not reach it; the
    # previous flags are restored even if the rollout raises.
    frozen_flags = self._freeze_world_model()
    try:
      for i in range(imag_horizon):
        features.append(self.wm.core.to_feature(*state))
        _, state = self.wm.core.cell.forward_prior(given_actions[i], None, state)

      features.append(self.wm.core.to_feature(*state))
      features = torch.stack(features)  # (H+1,TBI,D)

      rewards = self.wm.decoder.reward.forward(features)  # (H+1,TBI)
      terminals = self.wm.decoder.terminal.forward(features)  # (H+1,TBI)
    finally:
      self._restore_world_model(frozen_flags)
    return features, rewards, terminals

  def __str__(self):
    """Short summary: total parameter count plus one line per submodule."""
    lines = [f'Model: {param_count(self)} parameters']
    submodels = (
      self.wm.encoder,
      self.wm.decoder,
      self.wm.core,
      # `ac` and `probe_model` are never set by this class; look them up
      # defensively instead of raising AttributeError.
      getattr(self, 'ac', None),
      getattr(self, 'probe_model', None),
    )
    for submodel in submodels:
      if submodel is not None:
        lines.append(
          f'  {type(submodel).__name__:<15}: {param_count(submodel)} parameters'
        )
    return '\n'.join(lines)

  def __repr__(self):
    """Long representation: the standard nn.Module module tree."""
    return super().__repr__()


class WorldModel(nn.Module):
  """Encoder, RSSM core and decoders trained jointly with a KL + reconstruction loss."""

  def __init__(self, conf):
    """Build the encoder, decoders and RSSM core from `conf` and init weights."""
    super().__init__()

    self.deter_dim = conf.deter_dim
    self.stoch_dim = conf.stoch_dim
    self.stoch_discrete = conf.stoch_discrete
    self.kl_weight = conf.kl_weight
    # A balance of 0.5 is stored as None, meaning "no balancing": the exact
    # KL is used directly in `training_step`.
    self.kl_balance = None if conf.kl_balance == 0.5 else conf.kl_balance

    # Encoder
    self.encoder = MultiEncoder(conf)

    # Decoders read features of size deter_dim + stoch_dim * (stoch_discrete or 1).
    features_dim = conf.deter_dim + conf.stoch_dim * (conf.stoch_discrete or 1)
    self.decoder = MultiDecoder(features_dim, conf)

    # RSSM
    self.core = RSSMCore(
      embed_dim=self.encoder.out_dim,
      action_dim=conf.action_dim,
      deter_dim=conf.deter_dim,
      stoch_dim=conf.stoch_dim,
      stoch_discrete=conf.stoch_discrete,
      hidden_dim=conf.hidden_dim,
      gru_layers=conf.gru_layers,
      gru_type=conf.gru_type,
      layer_norm=conf.layer_norm
    )

    # Init
    for m in self.modules():
      init_weights_tf2(m)

  def init_state(self, batch_size: int) -> Tuple[Any, Any]:
    """Return the RSSM core's initial state for `batch_size` sequences."""
    return self.core.init_state(batch_size)

  def forward(self, obs: Dict[str, Tensor], in_state: Any):
    """Encode and filter `obs` through the RSSM; return `(features, out_state)`.

    No losses are computed (uses `training_step(..., forward_only=True)`).
    """
    _, features, _, out_state, _, _ = \
        self.training_step(obs, in_state, forward_only=True)
    return features, out_state

  def training_step(
    self,
    obs: Dict[str, Tensor],
    in_state: Any,
    iwae_samples: int = 1,
    do_open_loop=False,
    do_image_pred=False,
    forward_only=False
  ):
    """Run the world model over a sequence and compute its loss.

    Args:
      obs: dict with at least 'action' and 'reset' plus whatever the
        encoder/decoders read.
      in_state: initial RSSM state.
      iwae_samples: posterior samples per step. When > 1, a sampled KL
        estimate replaces the analytic one.
      do_open_loop: forwarded to `RSSMCore.forward`.
      do_image_pred: additionally decode from prior samples and log
        `logprob_*` metrics/tensors and `*_pred` tensors.
      forward_only: skip all losses; return a zero loss and empty
        metrics/tensors.

    Returns:
      `(loss, features, states, out_state, metrics, tensors)` where `loss`
      is the mean model loss.
    """
    # Encoder
    embed = self.encoder(obs)

    # RSSM
    prior, post, post_samples, features, states, out_state = \
        self.core.forward(embed,
                          obs['action'],
                          obs['reset'],
                          in_state,
                          iwae_samples=iwae_samples,
                          do_open_loop=do_open_loop)

    if forward_only:
      return torch.tensor(0.0), features, states, out_state, {}, {}

    # Decoder
    loss_reconstr, metrics, tensors = self.decoder.training_step(features, obs)

    # KL loss
    d = self.core.zdistr
    dprior = d(prior)
    dpost = d(post)
    loss_kl_exact = D.kl.kl_divergence(dpost, dprior)  # (T,B,I)
    if iwae_samples == 1:
      if self.kl_balance is None:
        # Analytic KL loss, standard for VAE.
        loss_kl = loss_kl_exact
      else:
        # KL balancing: same value as the exact KL, but the posterior receives
        # a (1 - kl_balance) share of the gradient and the prior a kl_balance
        # share. Checked with `is None` so kl_balance == 0 still balances
        # (all gradient to the posterior) instead of falling back to exact KL.
        loss_kl_postgrad = D.kl.kl_divergence(dpost, d(prior.detach()))
        loss_kl_priograd = D.kl.kl_divergence(d(post.detach()), dprior)
        loss_kl = (
          1 - self.kl_balance
        ) * loss_kl_postgrad + self.kl_balance * loss_kl_priograd
    else:
      # Sampled KL loss, for IWAE.
      z = post_samples.reshape(dpost.batch_shape + dpost.event_shape)
      loss_kl = dpost.log_prob(z) - dprior.log_prob(z)

    # Total loss: combine per sample, then log-average-exp over dim 2 (the
    # IWAE sample axis per the (T,B,I) shape note above).
    assert loss_kl.shape == loss_reconstr.shape
    loss_model_tbi = self.kl_weight * loss_kl + loss_reconstr
    loss_model = -logavgexp(-loss_model_tbi, dim=2)

    # Metrics
    with torch.no_grad():
      # Log the exact KL even when using IWAE; it avoids random negative values.
      loss_kl = -logavgexp(-loss_kl_exact, dim=2)
      entropy_prior = dprior.entropy().mean(dim=2)
      entropy_post = dpost.entropy().mean(dim=2)
      tensors.update(
        loss_kl=loss_kl.detach(),
        entropy_prior=entropy_prior,
        entropy_post=entropy_post
      )
      metrics.update(
        loss_model=loss_model.mean(),
        loss_kl=loss_kl.mean(),
        entropy_prior=entropy_prior.mean(),
        entropy_post=entropy_post.mean()
      )

    # Predictions: decode from prior samples instead of posterior samples.
    if do_image_pred:
      with torch.no_grad():
        prior_samples = self.core.zdistr(prior).sample().reshape(
          post_samples.shape
        )
        features_prior = self.core.feature_replace_z(features, prior_samples)
        _, mets, tens = self.decoder.training_step(
          features_prior, obs, extra_metrics=True
        )
        metrics_logprob = {
          k.replace('loss_', 'logprob_'): v
          for k, v in mets.items()
          if k.startswith('loss_')
        }
        tensors_logprob = {
          k.replace('loss_', 'logprob_'): v
          for k, v in tens.items()
          if k.startswith('loss_')
        }
        tensors_pred = {
          k.replace('_rec', '_pred'): v
          for k, v in tens.items()
          if k.endswith('_rec')
        }
        metrics.update(metrics_logprob)  # logprob_image, ...
        tensors.update(tensors_logprob)  # logprob_image, ...
        tensors.update(tensors_pred)  # image_pred, ...

    return loss_model.mean(), features, states, out_state, metrics, tensors
