# Copyright (c) 2019 Kai Arulkumaran (Original PlaNet parts) Copyright (c) 2020 Yusuke Urakami (Dreamer parts)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import *

import torch
from torch import jit, nn
import torch.nn.functional as F

from .transition import LatentState as TransitionLatentState
from .utils import BottledModule, get_activation_module


class RewardMeanStdNetwork(jit.ScriptModule):
    """MLP head that maps a latent feature vector to two output values.

    The stack is three ``Linear`` + activation pairs followed by a final
    ``Linear`` projecting to 2 units. The first layer expects
    ``belief_size + state_size`` input features. The whole stack is wrapped
    in ``BottledModule``.
    NOTE(review): ``BottledModule`` presumably flattens leading batch
    dimensions before applying the wrapped module -- confirm in ``.utils``.
    """

    def __init__(self, belief_size, state_size, hidden_size, activation_function="relu"):
        """Build the network.

        Args:
            belief_size: Added to ``state_size`` to get the input width.
            state_size: Added to ``belief_size`` to get the input width.
            hidden_size: Width of each of the three hidden layers.
            activation_function: Passed to ``get_activation_module`` once per
                hidden layer.
        """
        super().__init__()
        num_hidden_layers = 3
        layers = []
        in_features = belief_size + state_size
        # Modules are created in the same order they appear in the stack, so
        # the Sequential indices (0..6) line up with the layer order.
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(get_activation_module(activation_function))
            in_features = hidden_size
        layers.append(nn.Linear(hidden_size, 2))
        self.net = BottledModule(nn.Sequential(*layers))

    # Scripted entry point: applies the wrapped network to `feature`.
    @jit.script_method
    def forward_feature(self, feature: torch.Tensor) -> torch.Tensor:
        return self.net(feature)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``feature`` via ``forward_feature``."""
        return self.forward_feature(feature)

class RewardModel(nn.Module):
    """Predicts a Normal distribution from a transition latent state.

    The latent state's ``x_feature`` is fed through a
    ``RewardMeanStdNetwork``; its two-channel output is split along the last
    dimension into a mean and a raw scale, and the raw scale is passed
    through softplus to obtain the standard deviation.
    """

    def __init__(self,
                 x_belief_size, x_state_size, hidden_size,
                 activation_function="relu"):
        """Create the underlying mean/std network.

        All four arguments are forwarded unchanged, in order, to
        ``RewardMeanStdNetwork``.
        """
        super().__init__()
        # Attribute name is part of the parameter key prefix; keep it stable.
        self.forward_feature = RewardMeanStdNetwork(
            x_belief_size,
            x_state_size,
            hidden_size,
            activation_function,
        )

    def get_distn_and_x_mean(self, latent_state: TransitionLatentState) -> Tuple[torch.distributions.Normal, torch.Tensor]:
        """Return the predicted Normal distribution together with its mean.

        Args:
            latent_state: Latent state whose ``x_feature`` is the network input.

        Returns:
            A ``(distribution, mean)`` tuple, where ``mean`` is the same
            tensor used as the distribution's location.
        """
        raw_output = self.forward_feature(latent_state.x_feature)
        # The last dimension holds two halves: location, then unconstrained scale.
        loc, raw_scale = torch.chunk(raw_output, 2, dim=-1)
        scale = F.softplus(raw_scale)
        distn = torch.distributions.Normal(loc, scale)
        return distn, loc

    def forward(self, latent_state: TransitionLatentState) -> torch.distributions.Normal:
        """Return only the distribution from ``get_distn_and_x_mean``."""
        distn, _ = self.get_distn_and_x_mean(latent_state)
        return distn

    def __call__(self, latent_state: TransitionLatentState) -> torch.distributions.Normal:
        """Delegate to ``nn.Module.__call__`` with a typed signature."""
        return super().__call__(latent_state)
