# Copyright (c) 2019 Kai Arulkumaran (Original PlaNet parts) Copyright (c) 2020 Yusuke Urakami (Dreamer parts)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import *

import dataclasses
import enum

import torch
from torch import jit, nn
from torch.nn import functional as F
import torch.distributions

from .utils import BottledModule, get_activation_module, no_jit_fuser
from ... import utils


class LatentStateTuple(NamedTuple):
    # Plain-tuple counterpart of `LatentState`; this is the type the
    # `@jit.script_method` methods in this file take as their
    # `previous_latent_state` argument (see `LatentState.as_namedtuple`).
    #
    # This hack is needed because JIT doesn't support dataclasses w/
    # cached computed property or even custom methods.
    # https://github.com/pytorch/pytorch/issues/48984
    x_belief: torch.Tensor  # passed to the GRU cell as its hidden state
    x_state: torch.Tensor   # concatenated with the action before the GRU cell


@dataclasses.dataclass
class LatentState:
    """Belief tensor and state tensor of the x latent component.

    All dimensions but the last one are treated as batch dimensions (see
    ``batch_shape``); the last dimension holds features.  Methods that
    return a ``LatentState`` build a new instance and leave ``self`` as is.
    """

    x_belief: torch.Tensor
    x_state: torch.Tensor

    @classmethod
    def stack(cls, list: 'List[LatentState]', dim=0):
        """Stack each field of the given latent states along a new dimension ``dim``."""
        stacked_belief = torch.stack([item.x_belief for item in list], dim=dim)
        stacked_state = torch.stack([item.x_state for item in list], dim=dim)
        return LatentState(x_belief=stacked_belief, x_state=stacked_state)

    def replace(self,
                x: Optional['LatentState'] = None) -> 'LatentState':
        """Return a new instance whose fields come from ``x`` when ``x`` is given.

        With ``x=None`` the new instance holds the same tensors as ``self``.
        """
        if x is None:
            return dataclasses.replace(self)
        return dataclasses.replace(self, x_belief=x.x_belief, x_state=x.x_state)

    def new_zeros(self, x=False) -> 'LatentState':
        """Return a new instance; if ``x`` is truthy both fields are zero-filled."""
        if x:
            return LatentState(
                x_belief=torch.zeros_like(self.x_belief),
                x_state=torch.zeros_like(self.x_state),
            )
        return LatentState(x_belief=self.x_belief, x_state=self.x_state)

    def new_emptydim(self, x=False) -> 'LatentState':
        """Return a new instance; if ``x`` is truthy both fields become zero-width.

        The zero-width tensor is a view of ``x_belief`` with its last
        dimension narrowed to length 0, and it is used for *both* fields.
        """
        zero_width = self.x_belief.narrow(-1, 0, 0)
        if x:
            return LatentState(x_belief=zero_width, x_state=zero_width)
        return LatentState(x_belief=self.x_belief, x_state=self.x_state)

    @property
    def batch_shape(self) -> torch.Size:
        """Shape of ``x_belief`` without its last (feature) dimension."""
        belief_shape = self.x_belief.shape
        return belief_shape[:-1]

    @utils.lazy_property
    def full_feature(self) -> torch.Tensor:
        """Belief and state concatenated along the last dimension."""
        parts = [self.x_belief, self.x_state]
        return torch.cat(parts, dim=-1)

    @utils.lazy_property
    def x_feature(self) -> torch.Tensor:
        """Leading slice of ``full_feature`` covering the belief and state widths."""
        features = self.full_feature
        width = self.x_belief.shape[-1] + self.x_state.shape[-1]
        return features.narrow(-1, 0, width)

    def as_namedtuple(self) -> LatentStateTuple:
        """Convert to the NamedTuple form accepted by the scripted methods."""
        return LatentStateTuple(x_belief=self.x_belief, x_state=self.x_state)

    def detach(self) -> 'LatentState':
        """Return a new instance whose tensors are detached from the autograd graph."""
        detached_belief = self.x_belief.detach()
        detached_state = self.x_state.detach()
        return LatentState(x_belief=detached_belief, x_state=detached_state)

    def flatten(self, start_dim=0, end_dim=-2) -> 'LatentState':
        """Flatten dimensions ``start_dim..end_dim`` of both fields.

        The defaults merge every dimension except the last one.
        """
        flat_belief = self.x_belief.flatten(start_dim, end_dim)
        flat_state = self.x_state.flatten(start_dim, end_dim)
        return LatentState(x_belief=flat_belief, x_state=flat_state)

    def unflatten(self, dim, sizes) -> 'LatentState':
        """Expand batch dimension ``dim`` of both fields into ``sizes``.

        ``dim`` must be a non-negative index into the batch dimensions.
        """
        assert 0 <= dim < len(self.batch_shape)
        wide_belief = self.x_belief.unflatten(dim, sizes)
        wide_state = self.x_state.unflatten(dim, sizes)
        return LatentState(x_belief=wide_belief, x_state=wide_state)

    def narrow(self, dim: int, start: int, length: int) -> 'LatentState':
        """Narrow batch dimension ``dim`` of both fields to ``[start, start + length)``.

        ``dim`` must be a non-negative index into the batch dimensions.
        """
        assert 0 <= dim < len(self.batch_shape)
        narrow_belief = self.x_belief.narrow(dim, start, length)
        narrow_state = self.x_state.narrow(dim, start, length)
        return LatentState(x_belief=narrow_belief, x_state=narrow_state)

    def __getitem__(self, slice) -> 'LatentState':
        """Index both fields with ``slice``; the feature dimension must survive."""
        picked_belief = self.x_belief[slice]
        # Reject indices that reach into (or drop) the feature dimension.
        assert picked_belief.shape[-1] == self.x_belief.shape[-1], "can only slice batch dims"
        picked_state = self.x_state[slice]
        return LatentState(x_belief=picked_belief, x_state=picked_state)


class PartialOutputWithoutPosterior(NamedTuple):
    # Raw tensors returned by the scripted prior-only step
    # (`TransitionModel.forward_x_prior_one_step`).
    belief: torch.Tensor        # output of the GRU cell
    prior_state: torch.Tensor   # prior_mean + prior_stddev * prior_noise
    prior_noise: torch.Tensor   # standard-normal noise used for the sample
    prior_mean: torch.Tensor    # first half of the prior network output
    prior_stddev: torch.Tensor  # softplus(second half of the output) + min_stddev


class PartialOutputWithPosterior(NamedTuple):
    # Raw tensors returned by the scripted `TransitionModel.forward_generic`
    # and `TransitionModel.forward_generic_one_step`.
    #
    # When those methods are called without observations, the posterior_*
    # fields are placeholder tensors whose last dimension has size 0.
    belief: torch.Tensor            # output of the GRU cell
    prior_state: torch.Tensor       # prior_mean + prior_stddev * prior_noise
    prior_noise: torch.Tensor       # standard-normal noise used for the prior sample
    prior_mean: torch.Tensor        # first half of the prior network output
    prior_stddev: torch.Tensor      # softplus(second half of the output) + min_stddev
    posterior_state: torch.Tensor   # posterior_mean + posterior_stddev * posterior_noise
    posterior_noise: torch.Tensor   # standard-normal noise used for the posterior sample
    posterior_mean: torch.Tensor    # first half of the posterior network output
    posterior_stddev: torch.Tensor  # softplus(second half of the output) + min_stddev


class LatentPart(enum.Enum):
    # Identifies a part of the latent; `XYZContainer.__getitem__` accepts
    # `LatentPart.x` as an alternative to the string 'x'.
    x = enum.auto()


XYZContrainedT = TypeVar('XYZContrainedT')


@dataclasses.dataclass(frozen=True)
class XYZContainer(Generic[XYZContrainedT]):
    """Immutable holder of one value per latent part (only ``x`` here).

    The value is reachable as the attribute ``.x`` or by indexing with
    ``'x'`` / ``LatentPart.x``.
    """

    x: XYZContrainedT

    def __getitem__(self, ii: Any) -> XYZContrainedT:
        """Return the value stored for part ``ii``.

        Raises:
            ValueError: if ``ii`` is neither ``'x'`` nor ``LatentPart.x``.
        """
        # Guard clause: anything that does not designate the x part is rejected.
        if not (ii == 'x' or ii is LatentPart.x):
            raise ValueError(f"Unexpected index {repr(ii)}")
        return self.x


PartialOutputT = TypeVar('PartialOutputT', PartialOutputWithoutPosterior, PartialOutputWithPosterior)


@dataclasses.dataclass
class _Output(Generic[PartialOutputT]):
    """Wraps a raw partial output and exposes its prior-side fields.

    Only ``x`` is passed to the constructor; the remaining fields are
    filled in by ``__post_init__`` as ``XYZContainer`` instances.
    """

    x: PartialOutputT
    belief: XYZContainer[torch.Tensor] = dataclasses.field(init=False)
    prior_state: XYZContainer[torch.Tensor] = dataclasses.field(init=False)
    prior_noise: XYZContainer[torch.Tensor] = dataclasses.field(init=False)
    prior: XYZContainer[torch.distributions.Normal] = dataclasses.field(init=False)

    def __post_init__(self):
        """Populate the derived containers from the raw tuple ``self.x``."""
        raw = self.x
        self.belief = XYZContainer(x=raw.belief)
        self.prior_state = XYZContainer(x=raw.prior_state)
        self.prior_noise = XYZContainer(x=raw.prior_noise)
        # The prior distribution is rebuilt from its stored mean and stddev.
        prior_dist = torch.distributions.Normal(raw.prior_mean, raw.prior_stddev)
        self.prior = XYZContainer(x=prior_dist)

    @utils.lazy_property
    def prior_latent_state(self) -> LatentState:
        """Latent state made of the belief and the sampled prior state."""
        belief, state = self.belief.x, self.prior_state.x
        return LatentState(x_belief=belief, x_state=state)


@dataclasses.dataclass
class OutputWithoutPosterior(_Output[PartialOutputWithoutPosterior]):
    """Prior-only output; adds nothing on top of ``_Output``."""
    pass


@dataclasses.dataclass
class OutputWithPosterior(_Output[PartialOutputWithPosterior]):
    """Output that carries posterior-side fields in addition to the prior ones.

    As in the base class, only ``x`` is a constructor argument; the
    posterior containers are filled in by ``__post_init__``.
    """

    posterior_state: XYZContainer[torch.Tensor] = dataclasses.field(init=False)
    posterior_noise: XYZContainer[torch.Tensor] = dataclasses.field(init=False)
    posterior: XYZContainer[torch.distributions.Normal] = dataclasses.field(init=False)

    def __post_init__(self):
        """Fill the prior containers (base class), then the posterior ones."""
        super().__post_init__()
        raw = self.x
        self.posterior_state = XYZContainer(x=raw.posterior_state)
        self.posterior_noise = XYZContainer(x=raw.posterior_noise)
        # The posterior distribution is rebuilt from its stored mean and stddev.
        posterior_dist = torch.distributions.Normal(raw.posterior_mean, raw.posterior_stddev)
        self.posterior = XYZContainer(x=posterior_dist)

    @utils.lazy_property
    def posterior_latent_state(self) -> LatentState:
        """Latent state made of the belief and the sampled posterior state."""
        belief, state = self.belief.x, self.posterior_state.x
        return LatentState(x_belief=belief, x_state=state)


class DummyEmptyGRUCell(nn.Module):
    """Stand-in for a GRU cell when both input and hidden have zero features."""

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Check that both last dimensions are empty and return ``hidden`` as is."""
        assert input.size(-1) == hidden.size(-1) == 0
        return hidden


class DummyEmptyNet(nn.Module):
    """Stand-in for a network whose input has zero features."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Check that the last dimension is empty and return ``input`` as is."""
        assert input.size(-1) == 0
        return input


class TransitionModel(jit.ScriptModule):
    """Latent transition model with a deterministic belief and a Gaussian state.

    The belief is updated by a GRU cell from the previous state and action.
    The state prior is predicted from the belief alone; the state posterior
    is predicted from the belief concatenated with an observation embedding.
    Both are parameterized as a mean and a softplus-transformed standard
    deviation offset by ``min_stddev``.
    """

    # Attributes treated as compile-time constants by TorchScript.
    __constants__ = [
        "min_stddev", "action_size",
        "x_belief_size", "x_state_size",
        "embedding_size",
    ]

    x_belief_size: int
    x_state_size: int
    embedding_size: int
    action_size: int

    def __init__(
        self,
        x_belief_size: int,  # h_t
        x_state_size: int,   # s_t
        action_size: int,
        hidden_size: int,
        embedding_size: int,  # enc(o_t)
        activation_function: str = "relu",
        min_stddev: float = 0.1,
    ):
        """Build the sub-networks.

        Args:
            x_belief_size: width of the belief (GRU hidden state); must be > 0.
            x_state_size: width of the stochastic state; must be > 0.
            action_size: width of the action vectors.
            hidden_size: width of the hidden layer in every sub-network.
            embedding_size: width of the observation embedding fed to the
                posterior network.
            activation_function: name handed to ``get_activation_module``.
            min_stddev: constant added to every softplus-transformed stddev.
        """
        super().__init__()
        self.min_stddev = min_stddev
        self.x_belief_size = x_belief_size
        self.x_state_size = x_state_size
        self.embedding_size = embedding_size
        self.action_size = action_size

        assert (x_belief_size > 0) and (x_state_size > 0)

        # NOTE: the sub-modules below are created in a fixed order; keep it,
        # since creation order determines parameter registration order.

        # x
        self.x_state_action_pre_rnn = BottledModule(nn.Sequential(
            nn.Linear(x_state_size + action_size, hidden_size),
            get_activation_module(activation_function),
        ))                                                  # p(c_{t - 1} | s_{t - 1}, a_{t - 1})
        self.x_rnn = nn.GRUCell(hidden_size, x_belief_size) # p(h_t | c_{t - 1}, h_{t - 1})

        self.x_belief_to_state_prior = BottledModule(nn.Sequential(
            nn.Linear(x_belief_size, hidden_size),
            get_activation_module(activation_function),
            # (disabled) extra hidden layer:
            # nn.Linear(x_belief_size, hidden_size),
            # get_activation_module(activation_function),
            nn.Linear(hidden_size, 2 * x_state_size),  # mean and raw stddev
        ))                      # state prior: p(s_t | h_t)

        # posterior

        self.xy_belief_obs_to_state_posterior = BottledModule(nn.Sequential(
            nn.Linear(
                x_belief_size + embedding_size,
                hidden_size,
            ),
            get_activation_module(activation_function),
            # (disabled) extra hidden layer:
            # nn.Linear(hidden_size, hidden_size),
            # get_activation_module(activation_function),
            nn.Linear(
                hidden_size,
                2 * x_state_size,  # mean and raw stddev
            ),
        ))                      # state posterior: p(s_t | h_t, o_t)

    def init_latent_state(self, *, batch_shape: Tuple[int, ...] = ()) -> LatentState:
        """Return an all-zero latent state with the given batch shape.

        The tensors are expanded views of a single scalar zero created with
        the dtype and device of the GRU cell's ``weight_ih``.
        """
        scalar_zero = self.x_rnn.weight_ih.new_zeros(())
        belief_zeros = scalar_zero.expand(*batch_shape, self.x_belief_size)
        state_zeros = scalar_zero.expand(*batch_shape, self.x_state_size)
        return LatentState(x_belief=belief_zeros, x_state=state_zeros)

    # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
    # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
    # t :    0    1    2    3    4    5
    # o :        -X--X--X--X--X-
    # a : -X--X--X--X--X-
    # n : -X--X--X--X--X-
    # pb: -X-
    # ps: -X-
    # b : -x--X--X--X--X--X-
    # s : -x--X--X--X--X--X-
    @jit.script_method
    def forward_generic(
        self,
        previous_latent_state: LatentStateTuple,
        actions: torch.Tensor,
        next_observations: Optional[torch.Tensor] = None,
        rewards: Optional[torch.Tensor] = None,
        next_observation_nonfirststeps: Optional[torch.Tensor] = None,  # multiplicative mask; presumably 0 where `next_observation` starts a new episode
    ) -> PartialOutputWithPosterior:
        """
        Roll the model over ``T = actions.size(0)`` steps.

        Args:
            previous_latent_state: belief and state preceding step 0.
            actions: indexed along dim 0 by time step; the last dim must
                equal ``action_size`` and ``T`` must be positive.
            next_observations: optional; ``next_observations[t]`` is
                concatenated with the belief along the last dim and fed to
                the posterior network, so its last dim must be
                ``embedding_size``.
            rewards: only checked to be given exactly when
                ``next_observations`` is given; otherwise unused here.
            next_observation_nonfirststeps: optional; entry ``t`` gets a
                trailing unit dim and multiplies the previous belief, the
                previous state and the action of step ``t``
                (assumes a 0/1 mask -- TODO confirm).

        Returns:
            ``PartialOutputWithPosterior`` whose per-step tensors are stacked
            along dim 0. With observations, each step continues from the
            posterior state; without them it continues from the prior state
            and the posterior fields are zero-width placeholders.

        Example shapes from the original docstring:
            input belief/state:  torch.Size([50, 200]) torch.Size([50, 30])
            output beliefs:      torch.Size([49, 50, 200])
            other outputs:       torch.Size([49, 50, 30])
        """
        # Create lists for hidden states (cannot use single tensor as buffer because autograd won't work with inplace writes)
        assert actions.size(-1) == self.action_size
        T: int = actions.size(0)
        assert T > 0
        # Zero-width placeholder with the batch shape of one action step.
        empty_tensor = actions[0].narrow(-1, 0, 0)
        (
            x_beliefs,
            x_prior_states,
            x_prior_means,
            x_prior_stddevs,
            x_posterior_states,
            x_posterior_means,
            x_posterior_stddevs,
        ) = (
            [empty_tensor] * T,
            [empty_tensor] * T,
            [empty_tensor] * T,
            [empty_tensor] * T,
            [empty_tensor] * T,
            [empty_tensor] * T,
            [empty_tensor] * T,
        )
        prev_x_belief = previous_latent_state.x_belief
        prev_x_prior_state = previous_latent_state.x_state
        prev_x_posterior_state = previous_latent_state.x_state

        # Observations and rewards must be provided together or not at all.
        assert int(next_observations is None) == int(rewards is None)

        # All noise is drawn up front, before the time loop.
        x_prior_noises = torch.randn(
            list(actions.shape[:-1]) + [self.x_state_size],
            dtype=actions.dtype, device=actions.device)

        if next_observations is not None:
            x_posterior_noises = torch.randn_like(x_prior_noises)
        else:
            x_posterior_noises = torch.randn([T, 0], dtype=actions.dtype, device=actions.device)

        # Loop over time sequence
        for t in range(T):
            action = actions[t]
            # Select appropriate previous state
            if next_observations is None:
                prev_x_state = prev_x_prior_state
            else:
                prev_x_state = prev_x_posterior_state

            # Mask if previous transition was terminal
            if next_observation_nonfirststeps is not None:
                next_observation_nonfirststep = next_observation_nonfirststeps[t].unsqueeze(-1)
                prev_x_belief = prev_x_belief * next_observation_nonfirststep
                prev_x_state = prev_x_state * next_observation_nonfirststep
                action = action * next_observation_nonfirststep

            # [X prior] Compute belief (deterministic hidden state)
            prev_x_state_action = torch.cat([prev_x_state, action], dim=-1)
            x_belief = self.x_rnn(
                self.x_state_action_pre_rnn(prev_x_state_action),
                prev_x_belief,
            )
            # [X prior] Compute state prior by applying transition dynamics
            x_prior_mean, _x_prior_stddev = torch.chunk(self.x_belief_to_state_prior(x_belief), 2, dim=-1)
            x_prior_stddev = F.softplus(_x_prior_stddev) + self.min_stddev
            x_prior_state = x_prior_mean + x_prior_stddev * x_prior_noises[t]
            # [X prior] save results
            x_beliefs[t] = x_belief
            x_prior_means[t] = x_prior_mean
            x_prior_stddevs[t] = x_prior_stddev
            x_prior_states[t] = x_prior_state


            # [XY posterior]
            if next_observations is not None:
                # Compute state posterior by applying transition dynamics (i.e., using *_belief) and using current observation
                xy_posterior_input = torch.cat(
                    [x_belief, next_observations[t]],
                    dim=-1,
                )
                xy_posterior_mean, _xy_posterior_stddev = torch.chunk(
                    self.xy_belief_obs_to_state_posterior(xy_posterior_input), 2,
                    dim=-1,
                )

                x_posterior_mean = xy_posterior_mean
                x_posterior_stddev: torch.Tensor = F.softplus(_xy_posterior_stddev) + self.min_stddev

                x_posterior_state = x_posterior_mean + x_posterior_stddev * x_posterior_noises[t]

                x_posterior_means[t] = x_posterior_mean
                x_posterior_stddevs[t] = x_posterior_stddev
                x_posterior_states[t] = x_posterior_state

            # Carry this step's results over to the next iteration; without
            # observations the posterior entry is still the placeholder.
            prev_x_belief = x_beliefs[t]
            prev_x_prior_state = x_prior_states[t]
            prev_x_posterior_state = x_posterior_states[t]

        # Return new hidden states
        return PartialOutputWithPosterior(
            belief=torch.stack(x_beliefs, dim=0),
            prior_state=torch.stack(x_prior_states, dim=0),
            prior_noise=x_prior_noises,
            prior_mean=torch.stack(x_prior_means, dim=0),
            prior_stddev=torch.stack(x_prior_stddevs, dim=0),
            posterior_state=torch.stack(x_posterior_states, dim=0),
            posterior_noise=x_posterior_noises,
            posterior_mean=torch.stack(x_posterior_means, dim=0),
            posterior_stddev=torch.stack(x_posterior_stddevs, dim=0),
        )

    def posterior_rsample(
        self,
        actions: torch.Tensor,
        next_observations: torch.Tensor,
        rewards: torch.Tensor, *,
        next_observation_nonfirststeps: Optional[torch.Tensor] = None,
        previous_latent_state: Optional[LatentState] = None,
    ) -> OutputWithPosterior:
        """Sample priors and posteriors over a whole ``[T, B, *]`` sequence.

        Args:
            actions: 3-D tensor ``[T, B, action_size]``.
            next_observations: per-step tensors fed to the posterior network
                by ``forward_generic``.
            rewards: forwarded to ``forward_generic``, which only checks that
                it is provided together with ``next_observations``.
            next_observation_nonfirststeps: optional per-step multiplicative
                mask forwarded to ``forward_generic``.
            previous_latent_state: latent state preceding step 0, with a
                single batch dimension (``[B, *]``). Defaults to the all-zero
                state from ``init_latent_state`` for batch size
                ``actions.shape[1]``.

        Returns:
            ``OutputWithPosterior`` wrapping the stacked per-step results.
        """
        assert actions.ndim == 3 and actions.shape[-1] == self.action_size, "must have [T, B, *] shape"
        if previous_latent_state is None:
            previous_latent_state = self.init_latent_state(batch_shape=(actions.shape[1],))
        else:
            # The initial state precedes the sequence, so it has no time
            # dimension: `forward_generic` concatenates it with the 2-D
            # `actions[t]` and hands the belief to the GRU cell as hidden
            # state, matching the [B, *] state built in the default branch.
            assert len(previous_latent_state.batch_shape) == 1, "previous_latent_state must have [B, *] shape"
        with no_jit_fuser():  # https://github.com/pytorch/pytorch/issues/68800
            x = self.forward_generic(
                previous_latent_state.as_namedtuple(), actions, next_observations,
                rewards, next_observation_nonfirststeps)
        return OutputWithPosterior(x)

    @jit.script_method
    def forward_generic_one_step(
        self,
        previous_latent_state: LatentStateTuple,
        action: torch.Tensor,
        next_observation: Optional[torch.Tensor] = None,
        reward: Optional[torch.Tensor] = None,
        next_observation_nonfirststep: Optional[torch.Tensor] = None,
    ) -> PartialOutputWithPosterior:
        """
        Single-step version of ``forward_generic`` (no leading time dim).

        Args:
            previous_latent_state: belief and state before this step.
            action: last dim must equal ``action_size``.
            next_observation: optional; concatenated with the new belief
                along the last dim and fed to the posterior network.
            reward: only checked to be given exactly when
                ``next_observation`` is given; otherwise unused here.
            next_observation_nonfirststep: optional; gets a trailing unit dim
                and multiplies the previous belief, previous state and action
                (assumes a 0/1 mask -- TODO confirm).

        Returns:
            ``PartialOutputWithPosterior`` for this step. Without an
            observation, every posterior field is a zero-width view of
            ``action``.
        """
        assert action.size(-1) == self.action_size
        # Observation and reward must be provided together or not at all.
        assert int(next_observation is None) == int(reward is None)

        prev_x_belief = previous_latent_state.x_belief
        prev_x_state = previous_latent_state.x_state

        x_prior_noise = torch.randn_like(prev_x_state)

        # Zero-width placeholder used for the posterior fields when there is
        # no observation.
        empty_tensor = action.narrow(-1, 0, 0)
        if next_observation is not None:
            x_posterior_noise = torch.randn_like(x_prior_noise)
        else:
            x_posterior_noise = empty_tensor

        # Mask if previous transition was terminal
        if next_observation_nonfirststep is not None:
            next_observation_nonfirststep = next_observation_nonfirststep.unsqueeze(-1)
            prev_x_belief = prev_x_belief * next_observation_nonfirststep
            prev_x_state = prev_x_state * next_observation_nonfirststep
            action = action * next_observation_nonfirststep

        # [X prior] Compute belief (deterministic hidden state)
        prev_x_state_action = torch.cat([prev_x_state, action], dim=-1)
        x_belief = self.x_rnn(
            self.x_state_action_pre_rnn(prev_x_state_action),
            prev_x_belief,
        )
        # [X prior] Compute state prior by applying transition dynamics
        x_prior_mean, _x_prior_stddev = torch.chunk(self.x_belief_to_state_prior(x_belief), 2, dim=-1)
        x_prior_stddev = F.softplus(_x_prior_stddev) + self.min_stddev
        x_prior_state = x_prior_mean + x_prior_stddev * x_prior_noise

        # [XY posterior]
        if next_observation is not None:
            # Compute state posterior by applying transition dynamics and using current observation
            xy_posterior_input = torch.cat(
                [x_belief, next_observation],
                dim=-1,
            )

            xy_posterior_mean, _xy_posterior_stddev = torch.chunk(
                self.xy_belief_obs_to_state_posterior(xy_posterior_input), 2,
                dim=-1,
            )

            x_posterior_mean = xy_posterior_mean
            x_posterior_stddev: torch.Tensor = F.softplus(_xy_posterior_stddev) + self.min_stddev

            x_posterior_state = x_posterior_mean + x_posterior_stddev * x_posterior_noise
        else:
            x_posterior_mean = empty_tensor
            x_posterior_stddev = empty_tensor
            x_posterior_state = empty_tensor

        # Return new hidden states
        return PartialOutputWithPosterior(
            belief=x_belief,
            prior_state=x_prior_state,
            prior_noise=x_prior_noise,
            prior_mean=x_prior_mean,
            prior_stddev=x_prior_stddev,
            posterior_state=x_posterior_state,
            posterior_noise=x_posterior_noise,
            posterior_mean=x_posterior_mean,
            posterior_stddev=x_posterior_stddev,
        )

    @jit.script_method
    def forward_x_prior_one_step(
        self,
        previous_latent_state: LatentStateTuple,
        action: torch.Tensor,
    ) -> PartialOutputWithoutPosterior:
        """
        Advance the latent by one step using the prior only.

        Unlike ``forward_generic_one_step``, this takes no observation and
        applies no first-step mask.

        Args:
            previous_latent_state: belief and state before this step.
            action: last dim must equal ``action_size``.

        Returns:
            ``PartialOutputWithoutPosterior`` holding the new belief, the
            sampled prior state, the noise used for the sample, and the
            prior mean and standard deviation.
        """
        assert action.size(-1) == self.action_size

        prev_x_belief = previous_latent_state.x_belief
        prev_x_state = previous_latent_state.x_state
        x_prior_noise = torch.randn_like(prev_x_state)

        # [X prior] Compute belief (deterministic hidden state)
        prev_x_state_action = torch.cat([prev_x_state, action], dim=-1)
        x_belief = self.x_rnn(
            self.x_state_action_pre_rnn(prev_x_state_action),
            prev_x_belief,
        )
        # [X prior] Compute state prior by applying transition dynamics
        x_prior_mean, _x_prior_stddev = torch.chunk(self.x_belief_to_state_prior(x_belief), 2, dim=-1)
        x_prior_stddev = F.softplus(_x_prior_stddev) + self.min_stddev
        x_prior_state = x_prior_mean + x_prior_stddev * x_prior_noise

        return PartialOutputWithoutPosterior(
            belief=x_belief,
            prior_state=x_prior_state,
            prior_noise=x_prior_noise,
            prior_mean=x_prior_mean,
            prior_stddev=x_prior_stddev,
        )

    def x_prior_rsample_one_step(
        self,
        action: torch.Tensor, *,
        previous_latent_state: Optional[LatentState] = None,
    ) -> OutputWithoutPosterior:
        """Sample one prior step for a ``[B, action_size]`` batch of actions.

        ``previous_latent_state`` must have a single batch dimension; when
        omitted, the all-zero state from ``init_latent_state`` is used.
        """
        assert action.ndim == 2 and action.shape[-1] == self.action_size, "must have [B, *] shape"
        if previous_latent_state is not None:
            assert len(previous_latent_state.batch_shape) == 1, "must have [B, *] shape"
        else:
            batch_size = action.shape[0]
            previous_latent_state = self.init_latent_state(batch_shape=(batch_size,))
        with no_jit_fuser():  # https://github.com/pytorch/pytorch/issues/68800
            partial_output = self.forward_x_prior_one_step(
                previous_latent_state.as_namedtuple(),
                action,
            )
        return OutputWithoutPosterior(partial_output)

    def posterior_rsample_one_step(
        self,
        action: torch.Tensor,
        next_observation: torch.Tensor,
        reward: torch.Tensor, *,
        next_observation_nonfirststep: Optional[torch.Tensor] = None,
        previous_latent_state: Optional[LatentState] = None,
    ) -> OutputWithPosterior:
        """Sample one prior/posterior step for a ``[B, action_size]`` batch.

        ``previous_latent_state`` must have a single batch dimension; when
        omitted, the all-zero state from ``init_latent_state`` is used.  The
        remaining arguments are forwarded to ``forward_generic_one_step``.
        """
        assert action.ndim == 2 and action.shape[-1] == self.action_size, "must have [B, *] shape"
        if previous_latent_state is not None:
            assert len(previous_latent_state.batch_shape) == 1, "must have [B, *] shape"
        else:
            batch_size = action.shape[0]
            previous_latent_state = self.init_latent_state(batch_shape=(batch_size,))
        with no_jit_fuser():  # https://github.com/pytorch/pytorch/issues/68800
            partial_output = self.forward_generic_one_step(
                previous_latent_state.as_namedtuple(),
                action,
                next_observation,
                reward,
                next_observation_nonfirststep,
            )
        return OutputWithPosterior(partial_output)
