"""
Defines modules related to dynamics modelling.
"""

import copy
import itertools
import logging
import math
import os
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, Normal
import gymnasium as gym

from . import op
from .utils import make_mlp, get_activation


LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


def tril_covariance(values : torch.Tensor, dim : int,
                    diagonal_covariance : bool = False,
                    minimum_std : float = 0.0):
    """
    Builds a (lower-triangular) scale matrix from a flat parameter vector.

    The first `dim` entries of the last axis are passed through a softplus
    and clipped from below by `minimum_std`; they form the diagonal. Unless
    `diagonal_covariance` is set, the remaining entries are squashed with
    tanh and written to the strictly lower triangle in the order given by
    `torch.tril_indices(dim, dim, offset=-1)`.

    Dimensions:
     - values: [*, dim + dim * (dim - 1) / 2]
               (only the first `dim` entries are read if diagonal_covariance)
     - return: [*, dim, dim], with the dtype and device of `values`
    """
    diag_values = values[..., :dim]

    # Compute the softplus once; it is used both for the result and for the
    # NaN diagnostic below.
    raw_std = F.softplus(diag_values)
    if torch.isnan(raw_std).any():
        LOG.warning(f"NaN covariance: {raw_std}\nGenerated from these values: {diag_values}")

    # Clip std if too small
    std = torch.maximum(
        raw_std,
        torch.as_tensor(minimum_std, dtype=values.dtype, device=values.device),
    )
    # diag_embed keeps the dtype/device of `values`, so non-float32 inputs
    # (e.g. float64) no longer collide with a float32 buffer.
    diagdata = torch.diag_embed(std)

    if diagonal_covariance:
        return diagdata

    tril_idxs = torch.tril_indices(dim, dim, offset=-1)
    covdata = values.new_zeros(tuple(values.shape[:-1]) + (dim, dim))
    covdata[..., tril_idxs[0], tril_idxs[1]] = torch.tanh(values[..., dim:])
    return diagdata + covdata


def mv_log_prob_withdist(mv : MultivariateNormal, value, distance=torch.sub):
    """
    Log-density of a multivariate normal where the residual between `value`
    and the mean is computed by `distance(value, mv.loc)` instead of a plain
    subtraction.

    With the default `torch.sub` the result is expected to agree with
    `mv.log_prob(value)`.

    Dimensions:
     - value: [*, D]
     - distance: [*,D],[*,D] -> [*,D]
     - return: [*]
    """
    # Private torch helper; computes the squared Mahalanobis distance from
    # the Cholesky factor.
    from torch.distributions.multivariate_normal import _batch_mahalanobis

    if mv._validate_args:
        mv._validate_sample(value)

    scale_tril = mv._unbroadcasted_scale_tril
    residual = distance(value, mv.loc)
    mahalanobis = _batch_mahalanobis(scale_tril, residual)

    # log|Σ|^(1/2) is the sum of the log-diagonal of the Cholesky factor
    half_log_det = scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    event_dim = mv._event_shape[0]
    return -0.5 * (event_dim * math.log(2 * math.pi) + mahalanobis) - half_log_det


def normal_log_prob_withdist(dist : Normal, value, distance=torch.sub):
    """
    Element-wise log-density of a univariate normal where the residual
    between `value` and the mean is computed by `distance(value, dist.loc)`.

    With the default `torch.sub` this equals `dist.log_prob(value)`.

    Dimensions:
     - value: broadcastable against dist.loc / dist.scale
     - return: same shape as the broadcast residual
    """
    # N(x|μ,σ) = 1/(σ * sqrt(2π)) * exp(-0.5 * (d(x, μ)/σ)²)
    # log N(x|μ,σ) = -0.5 * (d(x, μ)/σ)² - log(σ * sqrt(2π))
    #              = -0.5 * (d(x, µ)/σ)² - log(σ) - 0.5 * log(2π)
    d = distance(value, dist.loc)
    return (
        -0.5 * torch.square(d / dist.scale)
        - torch.log(dist.scale)
        # log(sqrt(2π)), not log(2π): the square root contributes the 0.5
        - 0.5 * math.log(2 * math.pi)
    )



class NNStepModel(nn.Module):
    """
    Abstract interface for single-step models. Subclasses must override
    both methods; the base implementations only raise.
    """
    def sample(self, s, a, deterministic=True):
        """
        Samples from the model given `s` and `a` (presumably the current
        state and the action -- confirm against implementations).
        """
        raise NotImplementedError()

    def loss(self, s_next, s, a, deterministic=True):
        """
        Computes the training loss.

        Returns a tuple (loss, info)
        """
        raise NotImplementedError()


# ----------------------------------- #
#    Explicitly predictive models     #
# (Special case of sequential models) #
# ----------------------------------- #


class NNPredictiveModel(nn.Module):
    """
    Abstract interface for predictive models: given an initial state and a
    sequence of actions, predict the resulting state.

    Model dimensions:
        N = batch size
        L = maximum length of any batch
        S = state size
        A = action size

    States and actions are assumed to be flat vectors.
    """
    def __init__(self, optimizable=False):
        """
        optimizable: stored as-is on the instance; subclasses set it to
                     indicate whether the model is meant to be trained.
        """
        super().__init__()
        self.optimizable = optimizable

    def loss(self,
             state : torch.Tensor,
             actions : torch.Tensor,
             next_states : torch.Tensor,
             lengths : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.sub,
             deterministic : bool = True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Computes the prediction loss for a batch of action sequences.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - next_states: [N, L, S]
         - lengths: [N] (long tensor)
         - distance: [*,S],[*,S] -> [*,S]

        Return: (
            a 0-dim tensor (i.e. a wrapped scalar value),
            loss information structured as {"stats": {k: v}}
        )
        """
        raise NotImplementedError()

    def sample(self,
               state : torch.Tensor,
               actions : torch.Tensor,
               lengths : torch.Tensor,
               override_mu : Optional[torch.Tensor] = None,
               n_samples : Optional[int] = None,
               deterministic : bool = False,
               with_info : bool = False
              ) -> Tuple[torch.Tensor, dict]:
        """
        Predicts the state reached after applying the action sequences.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - lengths: [N] (long tensor)

        Returned prediction:
         - [N, S] when n_samples is None
         - [N, n_samples, S] otherwise

        The returned info follows the {"stats": {k: v}} structure.
        """
        raise NotImplementedError()

    def generate_actions(self,
                         policy : nn.Module,
                         state : torch.Tensor,
                         actions : torch.Tensor,
                         lengths : torch.Tensor,
                         blank_horizon : int = 1,
                         generate_from_latent : bool = True,
                         deterministic : bool = False):
        """
        Generates actions based on predictions, and then keeps generating
        for a blank horizon.

        The blank horizon is the number of actions generated from any given
        state prediction. With blank_horizon = 1 only a single action is
        generated from the predicted state and no blank rollout is performed
        to generate further actions.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - lengths: [N] (long tensor)

        Return dim: [N, L, blank_horizon, A]

        Expected policy mapping:
         - generate_from_latent: [N, HidRec] -> [N, A]
         - otherwise:            [N, S] -> [N, A]
        """
        raise NotImplementedError()


class NNPredictivePerfectEnvironment(NNPredictiveModel):
    """
    Baseline "model" that produces predictions by stepping a copy of the
    actual environment instead of using a learned network.

    Model dimensions:
        N = batch size
        L = maximum length of any batch
        S = state size
        A = action size

    Note that we assume a flat state and action size.
    """
    def __init__(self, env, use_vec_env=False):
        """
        env: environment to copy. Its `unwrapped` instance must be a MuJoCo
             environment or a `latency_env` Furuta ODE pendulum, otherwise a
             ValueError is raised.
        use_vec_env: when set, an AsyncVectorEnv holding one deep copy of
             `env` per CPU (os.cpu_count()) is created and used by `sample`.
        """
        # Local import, used below for the environment type check.
        import latency_env
        super().__init__(optimizable=False)

        # Exactly one of these flags is enabled; it selects how `sample`
        # injects the requested start state into the environment.
        self.mujoco_update_0state = False
        self.internal_state_update = False

        if isinstance(env.unwrapped, (gym.envs.mujoco.MujocoEnv,)):
            self.mujoco_update_0state = True
        elif isinstance(env.unwrapped, (latency_env.furuta_ode.FurutaODEPendulum,)):
            self.internal_state_update = True
        else:
            raise ValueError(f"Unsupported environment class: {type(env)}")

        def env_fn():
            return copy.deepcopy(env)

        self.base_env = env_fn()

        self.use_vec_env = use_vec_env
        self.vec_env = None
        if use_vec_env:
            self.vec_env = gym.vector.AsyncVectorEnv([env_fn for _ in range(os.cpu_count())])

        # NOTE(review): presumably present so that the module exposes at
        # least one parameter (e.g. for optimizers) -- confirm.
        self.dummy_parameter = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    def extra_repr(self):
        """Shows the wrapped environment in the module's repr."""
        return f"env={self.base_env.unwrapped}"

    def loss(self,
             state : torch.Tensor,
             actions : torch.Tensor,
             next_states : torch.Tensor,
             lengths : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.sub,
             deterministic : bool = True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        There is nothing to train, so the loss is always zero. All arguments
        are accepted (and ignored) to match the NNPredictiveModel signature.

        Expected dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - next_states: [N, L, S]
         - lengths: [N] (long tensor)
         - distance: [*,S],[*,S] -> [*,S]

        Return: (
            singular 0-dim tensor (i.e. a wrapped scalar value),
            loss information, {"stats": {}}
        )
        """
        return (torch.tensor(0.0), {"stats": {}})

    def sample(self,
               state : torch.Tensor,
               actions : torch.Tensor,
               lengths : torch.Tensor,
               override_mu : Optional[torch.Tensor] = None,
               n_samples : Optional[int] = None,
               deterministic : bool = False,
               with_info : bool = False):
        """
        Expected dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - lengths: [N] (long tensor)

        if n_samples is None:
         - return dim: [N, S]
        else:
         - return dim: [N, n_samples, S]

        Since this is a baseline model based on a perfect environment, we are
        essentially ignoring most parameters and just stepping through the
        environment in the best way we can.

        If `override_mu` ([N, S]) is given, it is returned directly (repeated
        over the sample dimension) and the environment is not stepped.

        Return: (prediction, {"stats": {}})
        """
        N, S = state.shape
        # Only the batch size is needed; the unpacking also checks that
        # `actions` is three-dimensional.
        N, _, _ = actions.shape

        n_samples_squeeze = n_samples is None
        if n_samples_squeeze:
            n_samples = 1

        if override_mu is not None:
            assert override_mu.shape == (N, S)
            output = override_mu.unsqueeze(1).repeat(1, n_samples, 1)
        else:
            output = torch.zeros((N, n_samples, S), dtype=torch.float32)

            if self.mujoco_update_0state:
                # Reference qpos/qvel of a freshly reset environment, used as
                # a template when writing the requested state below.
                self.base_env.reset()
                ref_qpos = self.base_env.unwrapped.data.qpos.copy()
                ref_qvel = self.base_env.unwrapped.data.qvel.copy()

            for n in range(N):
                s = state[n].cpu().numpy()
                n_steps = int(lengths[n])

                if self.mujoco_update_0state:
                    # The flat state is split into positions (all but the
                    # first qpos entry, which keeps its reset value) and
                    # velocities.
                    # NOTE(review): assumes the observation layout is
                    # [qpos[1:], qvel] -- confirm per environment.
                    qpos = ref_qpos.copy()
                    qvel = ref_qvel.copy()
                    qpos[1:] = s[:len(qpos)-1]
                    qvel[:] = s[len(qpos)-1:]

                if self.use_vec_env:
                    # Fill the n_samples slots in chunks of num_envs rollouts
                    i = 0
                    while i < n_samples:
                        self.vec_env.reset()
                        if self.mujoco_update_0state:
                            self.vec_env.call("set_state", qpos, qvel)
                        elif self.internal_state_update:
                            self.vec_env.call("set_internal_state", s)

                        # With zero steps the start state itself is returned
                        bs = np.stack([s]*self.vec_env.num_envs)
                        for step in range(n_steps):
                            a = actions[n, step].cpu().numpy()
                            bs, _, _, _, _ = self.vec_env.step([a]*self.vec_env.num_envs)

                        # The last chunk may be larger than what is left
                        j = self.vec_env.num_envs
                        if (i + j) > n_samples:
                            j = n_samples - i
                            bs = bs[:j]

                        output[n, i:(i+j)] = torch.as_tensor(bs)

                        i += self.vec_env.num_envs
                else:
                    for i in range(n_samples):
                        self.base_env.reset()
                        if self.mujoco_update_0state:
                            self.base_env.unwrapped.set_state(qpos, qvel)
                        elif self.internal_state_update:
                            self.base_env.unwrapped.set_internal_state(s)

                        # With zero steps the start state itself is returned
                        si = copy.copy(s)
                        for step in range(n_steps):
                            a = actions[n, step].cpu().numpy()
                            si, _, _, _, _ = self.base_env.step(a)

                        output[n, i] = torch.as_tensor(si)

        if n_samples_squeeze:
            output = output.squeeze(1)

        output = output.to(state.device)

        return (output, {"stats": {}})


class NNEmitterBase(nn.Module):
    """
    Common interface of all emitters, i.e. modules that turn embeddings into
    states (`emit`) and into per-sample training losses (`loss`).
    """
    def emit(self,
             embed : torch.Tensor,
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Produces a state from each embedding, together with an info dict.

        Dimensions:
          embed : [N, H]
          return : [N, S]
        """
        raise NotImplementedError()

    def loss(self,
             embed : torch.Tensor,
             targets : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Produces one loss value per batch element from the embeddings and
        the target states, together with an info dict.

        Dimensions:
          embed : [N, H]
          targets : [N, S]
          return : [N]
        """
        raise NotImplementedError()


class NNRegressiveEmitter(NNEmitterBase):
    """
    Deterministic emitter: a plain MLP regressing the state from the
    embedding, trained with a mean squared distance.
    """
    def __init__(self, state_size, embed_size, hidden_size,
                 n_layers,
                 dropout=0.0,
                 activation=nn.SiLU,
                 **kwargs,
                ):
        """
        Builds an MLP embed_size -> state_size with `n_layers` hidden layers
        of width `hidden_size`. Extra keyword arguments are forwarded to the
        parent constructor.
        """
        super().__init__(**kwargs)

        hidden_dims = [hidden_size]*n_layers
        self.emitter = make_mlp(
            indim=embed_size,
            hidden_dims=hidden_dims,
            outdim=state_size,
            dropout=dropout,
            activation=get_activation(activation),
        )

    def emit(self,
             embed : torch.Tensor,
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Regresses a state from each embedding. `deterministic` has no effect
        here; the info dict is always empty.

        Dimensions:
          embed : [N, H]
          return : [N, S]
        """
        state = self.emitter(embed)
        return (state, {})

    def loss(self,
             embed : torch.Tensor,
             targets : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Mean (over the state dimension) of the squared distance between the
        regressed state and the target. The info dict is always empty.

        Dimensions:
          embed : [N, H]
          targets : [N, S]
          return : [N]
        """
        prediction = self.emitter(embed)
        squared_error = distance(prediction, targets).square()
        return (squared_error.mean(dim=1), {})


class NNGaussianEmitterBase(NNEmitterBase):
    """
    Base for emitters that describe the state as a gaussian with diagonal
    covariance. Subclasses provide `mean_std`; this class implements `emit`
    on top of it.
    """
    def __init__(self, log_std_min=-20.0, log_std_max=2.0):
        """
        log_std_min / log_std_max: bounds stored for subclasses to clamp
        their predicted log standard deviation with.
        """
        super().__init__()
        self.log_std_max = log_std_max
        self.log_std_min = log_std_min

    def emit(self,
             embed : torch.Tensor,
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Emits the mean (deterministic) or a reparameterized sample from
        N(mu, std). The info dict holds the detached "mu" and "std".

        Dimensions:
          embed : [N, H]
          return : [N, S]
        """
        mu, std = self.mean_std(embed, deterministic=deterministic)
        # The distribution is built in both modes, so its argument
        # validation (if enabled) still applies when deterministic.
        dist = Normal(loc=mu, scale=std)
        ret = mu if deterministic else dist.rsample()

        info = {
            "mu": mu.detach(),
            "std": std.detach(),
        }
        return (ret, info)

    def mean_std(self,
                 embed : torch.Tensor,
                 deterministic=True,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the gaussian parameters for each embedding. Must be
        overridden by subclasses.

        Dimensions:
          embed : [N, H]
          return : ([N, S], [N, S])
        """
        raise NotImplementedError()


class NNGaussianProbabilisticEmitter(NNGaussianEmitterBase):
    """
    Gaussian emitter that predicts mean and log-std directly from the
    embedding and is trained by negative log-likelihood.
    """
    def __init__(self, state_size, embed_size, hidden_size,
                 common_layers, head_layers,
                 dropout=0.0,
                 activation=nn.SiLU,
                 **kwargs,
                ):
        """
        Builds a shared trunk MLP (`common_layers` hidden layers) followed by
        two head MLPs (`head_layers` hidden layers each) for the mean and the
        log standard deviation. Extra keyword arguments go to the parent
        constructor (log-std clamp bounds).
        """
        super().__init__(**kwargs)

        activation = get_activation(activation)
        head_hidden = [hidden_size]*head_layers

        # Sub-networks are created in the order trunk, mu head, log-std head
        trunk = make_mlp(
            indim=embed_size,
            hidden_dims=[hidden_size]*common_layers,
            outdim=hidden_size,
            dropout=dropout,
            activation=activation,
        )
        trunk_activation = activation()
        mu_head = make_mlp(
            indim=hidden_size,
            hidden_dims=head_hidden,
            outdim=state_size,
            dropout=dropout,
            activation=activation,
        )
        log_std_head = make_mlp(
            indim=hidden_size,
            hidden_dims=head_hidden,
            outdim=state_size,
            dropout=dropout,
            activation=activation,
        )
        heads = op.NNLayerHeadSplit(mu=mu_head, log_std=log_std_head)

        self.emitter = op.NNLayerConcat(
            dim=-1,
            next=nn.Sequential(trunk, trunk_activation, heads),
        )

    def mean_std(self,
                 embed : torch.Tensor,
                 deterministic=True,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the predicted mean and the standard deviation obtained by
        exponentiating the clamped log-std. `deterministic` has no effect.

        Dimensions:
          embed : [N, H]
          return : ([N, S], [N, S])
        """
        heads = self.emitter(embed)
        log_std = heads.log_std.clamp(min=self.log_std_min, max=self.log_std_max)
        return (heads.mu, log_std.exp())

    def loss(self,
             embed : torch.Tensor,
             targets : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Negative log-likelihood of the targets under the predicted gaussian,
        summed over the state dimension. The info dict is always empty.

        Dimensions:
          embed : [N, H]
          targets : [N, S]
          return : [N]
        """
        # Unpacking doubles as a check that both inputs are two-dimensional
        _, _ = embed.shape
        N, _ = targets.shape

        mu, std = self.mean_std(embed, deterministic=deterministic)
        log_prob = normal_log_prob_withdist(
            Normal(loc=mu, scale=std), targets, distance=distance,
        )
        neg_lp = -log_prob.sum(dim=-1)
        assert neg_lp.shape == (N,)

        return (neg_lp, {})


class NNGaussianVAEEmitter(NNGaussianEmitterBase):
    """
    A VAE emitter for embedded states.

    To predict the state s, we use the following approach:

    z ~ N(*|0,I)
    ŝ ~ dec(* | z, h)

    Then, for the loss, we also train an encoder as:

    ẑ ~ enc(* | s, h)

    Note that this is the regular VAE setup, except that all distributions are
    conditioned on the embedding h.
    """
    def __init__(self, state_size, embed_size, hidden_size,
                 vae_latent_size,
                 common_layers, head_layers,
                 dropout=0.0,
                 activation=nn.SiLU,
                 **kwargs,
                ):
        """
        Builds the encoder (state + embedding -> latent gaussian) and the
        decoder (latent + embedding -> state gaussian). Both share the same
        layout: a trunk MLP with `common_layers` hidden layers followed by
        mean and log-std head MLPs with `head_layers` hidden layers each.
        Extra keyword arguments go to the parent constructor.
        """
        super().__init__(**kwargs)

        activation = get_activation(activation)

        self.vae_latent_size = vae_latent_size

        # Standard-normal prior parameters, stored as buffers
        self.register_buffer("prior_mu", torch.zeros(vae_latent_size))
        self.register_buffer("prior_std", torch.ones(vae_latent_size))

        def conditional_gaussian_net(indim, outdim):
            # Trunk followed by (mu, log_std) heads; sub-networks are created
            # in the order trunk, mu head, log-std head.
            trunk = make_mlp(
                indim=indim,
                hidden_dims=[hidden_size]*common_layers,
                outdim=hidden_size,
                dropout=dropout,
                activation=activation,
            )
            trunk_activation = activation()
            mu_head = make_mlp(
                indim=hidden_size,
                hidden_dims=[hidden_size]*head_layers,
                outdim=outdim,
                dropout=dropout,
                activation=activation,
            )
            log_std_head = make_mlp(
                indim=hidden_size,
                hidden_dims=[hidden_size]*head_layers,
                outdim=outdim,
                dropout=dropout,
                activation=activation,
            )
            heads = op.NNLayerHeadSplit(mu=mu_head, log_std=log_std_head)
            return op.NNLayerConcat(
                dim=-1,
                next=nn.Sequential(trunk, trunk_activation, heads),
            )

        self.encoder = conditional_gaussian_net(
            indim=(state_size + embed_size), outdim=vae_latent_size,
        )
        self.decoder = conditional_gaussian_net(
            indim=(vae_latent_size + embed_size), outdim=state_size,
        )

    def mean_std(self,
                 embed : torch.Tensor,
                 deterministic=False,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes a latent into the mean and std of the state. The latent is
        the prior mean when deterministic, otherwise a sample from the prior.

        Dimensions:
          embed : [N, H]
          return : ([N, S], [N, S])
        """
        N, _ = embed.shape
        Z = self.vae_latent_size

        prior = Normal(self.prior_mu, self.prior_std)
        z = prior.mean.reshape(1, Z).repeat(N, 1) if deterministic else prior.rsample((N,))

        decoded = self.decoder(z, embed)
        log_std = decoded.log_std.clamp(min=self.log_std_min, max=self.log_std_max)
        return (decoded.mu, log_std.exp())

    def loss(self,
             embed : torch.Tensor,
             targets : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             deterministic=True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Negative ELBO per batch element: reconstruction negative
        log-likelihood plus KL(encoder || prior). The info dict holds the
        two detached terms as "loss_obs" and "loss_kl".

        Dimensions:
          embed : [N, H]
          targets : [N, S]
          return : [N]
        """
        # Unpacking doubles as a check that both inputs are two-dimensional
        _, _ = embed.shape
        N, _ = targets.shape

        # Approximate posterior over the latent, conditioned on target and
        # embedding
        encoded = self.encoder(targets, embed)
        enc_std = encoded.log_std.clamp(min=self.log_std_min, max=self.log_std_max).exp()
        enc_dist = MultivariateNormal(
            loc=encoded.mu,
            scale_tril=torch.diag_embed(enc_std),
        )
        prior = MultivariateNormal(
            loc=self.prior_mu,
            scale_tril=torch.diag_embed(self.prior_std),
        )

        # [N, Z]
        z = enc_dist.mean if deterministic else enc_dist.rsample()

        decoded = self.decoder(z, embed)
        dec_std = decoded.log_std.clamp(min=self.log_std_min, max=self.log_std_max).exp()
        dec_dist = Normal(decoded.mu, dec_std)

        # ELBO loss
        log_prob_obs = normal_log_prob_withdist(dec_dist, targets, distance=distance)
        loss_obs = -log_prob_obs.sum(dim=-1)
        loss_kl = torch.distributions.kl.kl_divergence(enc_dist, prior)
        assert loss_obs.shape == (N,), f"loss_obs.shape = {loss_obs.shape}"
        assert loss_kl.shape == (N,), f"loss_kl.shape = {loss_kl.shape}"

        info = {
            "loss_obs": loss_obs.detach(),
            "loss_kl": loss_kl.detach(),
        }
        return (loss_obs + loss_kl, info)




class NNPredictiveRecurrent(NNPredictiveModel):
    """
    A recurrent predictor with deterministic latent transitions and emitters
    from those latent embeddings.

    The initial state is embedded into the initial hidden state of a GRU, the
    (embedded) actions form the GRU input sequence, and the supplied emitter
    maps recurrent hidden states to state predictions and losses.
    """
    def __init__(self, state_size, action_size,
                 emitter,
                 hidden_rec_size=256,
                 hidden_size=256,
                 dropout=0.0,
                 emit_dropout=0.0,
                 state_embed_layers=2,
                 action_embed_layers=2,
                 recurrent_layers=1,
                 ignore_h0=False,
                 length_scale=True,
                 activation=nn.SiLU):
        """
        emitter: NNEmitterBase instance operating on [*, hidden_rec_size]
        hidden_rec_size: GRU hidden size
        hidden_size: width of the state/action embedding MLPs
        dropout: dropout used in the embedding MLPs and the GRU
        emit_dropout: dropout applied to the per-step emitter losses
        state_embed_layers: hidden layers of the state embedding MLP
        action_embed_layers: layers of the action embedding MLP; 0 feeds the
            raw actions to the GRU
        recurrent_layers: number of stacked GRU layers
        ignore_h0: exclude the reconstruction of the initial state from the loss
        length_scale: scale each sequence loss by 1 / (length + 1)
        """
        super().__init__(optimizable=True)
        assert isinstance(emitter, NNEmitterBase)
        self.emitter = emitter

        activation = get_activation(activation)

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.hidden_rec_size  = hidden_rec_size
        self.recurrent_layers = recurrent_layers
        self.dropout = dropout
        self.emit_dropout = emit_dropout
        self.ignore_h0 = ignore_h0
        self.length_scale = length_scale

        # One initial hidden state per recurrent layer, produced jointly
        self.net_embed_state = make_mlp(
            indim=state_size,
            hidden_dims=[hidden_size]*state_embed_layers,
            outdim=(recurrent_layers * hidden_rec_size),
            dropout=dropout,
            activation=activation,
        )
        if action_embed_layers == 0:
            self.net_embed_action = nn.Identity()
            self.input_rec_size = action_size
        else:
            self.net_embed_action = make_mlp(
                indim=action_size,
                hidden_dims=[hidden_size]*(action_embed_layers - 1),
                outdim=hidden_size,
                dropout=dropout,
                activation=activation,
            )
            self.input_rec_size = hidden_size

        self.net_rec = nn.GRU(
            input_size=self.input_rec_size,
            hidden_size=self.hidden_rec_size,
            num_layers=self.recurrent_layers,
            dropout=dropout,
            batch_first=True,
        )

    def embed_inputs(self,
                     state : torch.Tensor,
                     actions : torch.Tensor,
                     omit_h0 = False,
                    ) -> torch.Tensor:
        """
        Runs the recurrent network over the action sequence, starting from
        the embedded state.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]

        Returned dimension: [N, L+1, H], where index 0 along the length axis
        is the last layer's initial hidden state; [N, L, H] if omit_h0.
        """
        N, S = state.shape
        N, L, A = actions.shape
        H = self.hidden_rec_size
        Rec = self.recurrent_layers

        h0 = self.net_embed_state(state) # [N, S] -> [N, Rec * H]
        h0 = h0.reshape(N, Rec, H).permute(1, 0, 2).contiguous() # [N, Rec * H] -> [N, Rec, H] -> [Rec, N, H]
        emb_act = self.net_embed_action(actions) # [N, L, A] -> [N, L, InRec]

        emb_recur, _ = self.net_rec(emb_act, h0) # [N, L, InRec] -> [N, L, H]

        if omit_h0:
            emb = emb_recur
            assert emb.shape == (N, L, H)
        else:
            emb = torch.cat([h0[-1].unsqueeze(1), emb_recur], dim=1) # ([N, H], [N, L, H]) -> [N, L+1, H]
            assert emb.shape == (N, L + 1, H)

        return emb

    def embed_onestep(self,
                      h0 : torch.Tensor,
                      action : torch.Tensor,
                     ) -> torch.Tensor:
        """
        A single-step embedding of an action, given an initial hidden recurrent
        state. The reshapes below only work for a single recurrent layer.

        Dimensions:
         - h0: [N, H]
         - action: [N, A]

        Returned dimension: [N, H]
        """
        N, H = h0.shape
        N, A = action.shape
        InRec = self.input_rec_size

        emb_act = self.net_embed_action(action) # [N, A] -> [N, InRec]

        # Add a "pretend length" dim for the GRU, then remove it before returning
        emb_act = emb_act.reshape(N, 1, InRec) # [N, InRec] -> [N, 1, InRec]
        h0 = h0.reshape(1, N, H) # [N, H] -> [1, N, H]

        emb_recur, _ = self.net_rec(emb_act.contiguous(), h0.contiguous()) # [N, 1, InRec] -> [N, 1, H]
        emb = emb_recur.reshape(N, H) # [N, 1, H] -> [N, H]
        return emb

    def loss(self,
             state : torch.Tensor,
             actions : torch.Tensor,
             next_states : torch.Tensor,
             lengths : torch.Tensor,
             distance : Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.sub,
             deterministic : bool = True,
            ) -> Tuple[torch.Tensor, dict]:
        """
        Masked emitter loss over the initial state and all predicted states,
        summed over each sequence and averaged over the batch.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - next_states: [N, L, S]
         - lengths: [N] (long tensor)
         - distance: [*,S],[*,S] -> [*,S]

        Return: (
            singular 0-dim tensor (i.e. a wrapped scalar value),
            loss information, {"stats": {k: v}} with every tensor-valued
            emitter info entry reduced in the same way as the loss
        )
        """
        N, S = state.shape
        N, L, A = actions.shape
        H = self.hidden_rec_size

        embed = self.embed_inputs(state, actions)
        assert embed.shape == (N, L + 1, H)

        # Position 0 reconstructs the initial state itself
        targets = torch.cat([state.unsqueeze(1), next_states], dim=1)
        assert targets.shape == (N, L + 1, S)

        # Set up mask and scaler
        with torch.no_grad():
            # Positions 0..lengths[n] (inclusive) are valid for sequence n
            loss_mask = torch.arange(L + 1, device=lengths.device) <= lengths.unsqueeze(-1)
            loss_scaler = torch.ones((N, L + 1), device=lengths.device)
            assert loss_mask.shape == (N, L + 1)
            assert loss_scaler.shape == (N, L + 1)

            if self.ignore_h0:
                loss_mask[:, 0] = False

            if self.length_scale:
                loss_invlength = (1.0 / (lengths + 1)).unsqueeze(-1)
                assert loss_invlength.shape == (N, 1)
                loss_scaler = loss_scaler * loss_invlength

        emit_loss, emit_info = self.emitter.loss(
            embed.reshape(N * (L + 1), H),
            targets.reshape(N * (L + 1), S),
            distance=distance,
            deterministic=deterministic,
        )
        if self.emit_dropout > 0.0:
            # Respect train/eval mode; F.dropout is otherwise always active
            emit_loss = F.dropout(emit_loss, p=self.emit_dropout, training=self.training)

        loss = (loss_scaler * (loss_mask * emit_loss.reshape(N, L + 1))).sum(dim=1).mean()

        with torch.no_grad():
            info = {"stats": {}}
            for k, infoloss in emit_info.items():
                if isinstance(infoloss, torch.Tensor):
                    info["stats"][k] = (loss_scaler * (loss_mask * infoloss.reshape(N, L + 1))).sum(dim=1).mean()

        return (loss, info)

    def sample(self,
               state : torch.Tensor,
               actions : torch.Tensor,
               lengths : torch.Tensor,
               override_mu : Optional[torch.Tensor] = None,
               n_samples : Optional[int] = None,
               deterministic : bool = False,
               with_info : bool = False
              ) -> Tuple[torch.Tensor, dict]:
        """
        Predicts the state after `lengths[n]` actions for every sequence.
        `override_mu` and `with_info` are not used by this model.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]
         - lengths: [N] (long tensor)

        if n_samples is None:
         - return dim: [N, S]
        else:
         - return dim: [N, n_samples, S]

        The info dict holds the emitter input under "embed" ([N, H] if
        n_samples is None, else [N * n_samples, H]) and the tensor-valued
        emitter info entries under "stats".
        """
        N, S = state.shape
        N, L, A = actions.shape
        H = self.hidden_rec_size

        info = {}

        n_samples_squeeze = n_samples is None
        if n_samples_squeeze:
            n_samples = 1

        embed = self.embed_inputs(state, actions)
        assert embed.shape == (N, L + 1, H)

        # Pick, for every sequence, the embedding reached after its own
        # number of actions (a single gather instead of a Python loop).
        batch_idx = torch.arange(N, device=embed.device)
        step_idx = torch.as_tensor(lengths, dtype=torch.long, device=embed.device)
        embstack = embed[batch_idx, step_idx]
        assert embstack.shape == (N, H)

        embed = embstack.reshape(N, 1, H).repeat(1, n_samples, 1).reshape(N * n_samples, H)
        # When n_samples is None this is already [N, H]; no squeeze is
        # applied, since squeezing would drop the hidden axis if H == 1.
        info["embed"] = embed

        output, output_info = self.emitter.emit(embed, deterministic=deterministic)

        output = output.reshape(N, n_samples, S)
        if n_samples_squeeze:
            output = output.squeeze(1)

        with torch.no_grad():
            info["stats"] = {}
            for k, oi in output_info.items():
                if isinstance(oi, torch.Tensor):
                    info["stats"][k] = oi.reshape(*((N, n_samples) + oi.shape[1:]))
                    if n_samples_squeeze:
                        info["stats"][k] = info["stats"][k].squeeze(1)

        return (output, info)

    def generate_actions(self,
                         policy : nn.Module,
                         state : torch.Tensor,
                         actions : torch.Tensor,
                         blank_horizon : int = 1,
                         generate_from_latent : bool = True,
                         deterministic : bool = False):
        """
        Generates actions based on predictions, and then keeps generating
        for a blank horizon.

        The blank horizon states that, for any given state prediction, we
        generate this many actions based on that state. A blank_horizon of 1
        states that we only generate an action based on the predicted state,
        and do not perform any blank rollout to generate more actions. For
        larger horizons, each generated action is fed back through the
        recurrent network to obtain the latent used for the next action.

        Unlike the NNPredictiveModel signature, this override takes no
        `lengths` argument; all L positions are used as they are.

        Dimensions:
         - state: [N, S]
         - actions: [N, L, A]

        Return: (
            actions [N, L, blank_horizon, A],
            action logprobs [N, L, blank_horizon],
        )

        if generate_from_latent:
            policy: [N, HidRec] -> [N, A] x [N]
        else:
            policy: [N, S] -> [N, A] x [N]

        The second output from the policy is the logprob that the action was
        chosen.
        """
        N, S = state.shape
        N, L, A = actions.shape
        H = self.hidden_rec_size

        assert L > 0
        assert self.recurrent_layers == 1, "Can only use this with a single recurrent layer"
        assert not torch.isnan(actions).any(), "Cannot have NaN actions"

        # Perform the memorized prediction, retaining latent representations
        embed = self.embed_inputs(state, actions)
        assert embed.shape == (N, L + 1, H)
        embed = embed[:, 1:, :]
        assert embed.shape == (N, L, H)

        # No length masking is performed here. Masking is only really
        # necessary for a single case, when we have no actions and just the
        # state. Since we will always have at least sent a horizon of actions
        # to the buffer, or the buffer is simply empty. Then again, we could
        # probably just have the buffer be filled with something
        # corresponding to the do-nothing action.

        embed = embed.reshape(N * L, H) # [N, L, H] -> [N * L, H]

        # Do blank predictions
        output = torch.zeros((N, L, blank_horizon, A), device=embed.device)
        output_logprob = torch.zeros((N, L, blank_horizon), device=embed.device)
        for i in range(blank_horizon):
            if generate_from_latent:
                pi_input = embed
            else:
                pi_input, _ = self.emitter.emit(embed, deterministic=deterministic)

            (a, a_logp) = policy(pi_input)
            output[:, :, i, :] = a.reshape(N, L, A)
            output_logprob[:, :, i] = a_logp.reshape(N, L)

            # Do a blank step: advance the latent with the generated action
            # so the next action is based on the rolled-out state. Nothing
            # consumes the latent after the final action, so skip it there.
            if i + 1 < blank_horizon:
                embed = self.embed_onestep(embed, a)

        return (output, output_logprob)
