import torch.nn as nn

from .commons import ResidualConv
from ..simulators import SIMULATOR
from ..torch_utils import select


class TransitionModel(nn.Module):
    """Predict a successor hidden state for every action at once.

    The network is a ``ResidualConv`` block followed by a 1x1 convolution
    that expands ``d_hidden`` channels into ``SIMULATOR.n_actions * d_hidden``
    channels, and a final ``Tanh`` that squashes the outputs.

    NOTE(review): the 1x1 convolution and the reshape in ``forward`` only
    work if ``ResidualConv`` keeps the channel count and spatial size of its
    input unchanged -- confirm against its definition.

    Args:
        d_hidden: Number of channels of the hidden state fed to the model.
    """

    def __init__(self, d_hidden):
        super().__init__()

        residual = ResidualConv(d_hidden)
        # One group of ``d_hidden`` output channels per action.
        expanded_channels = SIMULATOR.n_actions * d_hidden
        per_action_conv = nn.Conv2d(d_hidden, expanded_channels, kernel_size=1)

        self.layers = nn.Sequential(residual, per_action_conv, nn.Tanh())

    def forward(self, state, action=None):
        """Compute the predicted next state(s).

        Args:
            state: 4-D tensor; its shape is unpacked as
                ``(batch, channels, height, width)``.
            action: Optional selector handed to ``select``. When ``None``
                the predictions for all actions are returned.

        Returns:
            With ``action is None``, a tensor viewed as
            ``(batch, SIMULATOR.n_actions, channels, height, width)``.
            Otherwise whatever ``select`` returns for that tensor and
            ``action`` (presumably the slice for the chosen action --
            verify against ``torch_utils.select``).
        """
        batch, channels, height, width = state.shape

        stacked = self.layers(state)
        # Split the widened channel axis into (action, channel).
        per_action = stacked.view(
            batch, SIMULATOR.n_actions, channels, height, width
        )

        if action is None:
            return per_action
        return select(per_action, action)


class RewardModel(nn.Module):
    """Predict one scalar reward per action from a hidden state.

    The network is a ``ResidualConv`` block, a global max-pool down to a
    1x1 spatial map, a flatten, and a linear layer mapping ``d_hidden``
    features to ``SIMULATOR.n_actions`` outputs.

    NOTE(review): the linear layer expects ``d_hidden`` features after
    pooling, so ``ResidualConv`` is assumed to keep the channel count
    unchanged -- confirm against its definition.

    Args:
        d_hidden: Number of channels of the hidden state fed to the model.
    """

    def __init__(self, d_hidden):
        super().__init__()

        stages = [
            ResidualConv(d_hidden),
            # Collapse each channel's spatial map to a single value.
            nn.AdaptiveMaxPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(d_hidden, SIMULATOR.n_actions),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, state, action=None):
        """Compute the predicted reward(s).

        Args:
            state: Hidden-state tensor passed straight to the layer stack
                (presumably ``(batch, d_hidden, height, width)`` -- TODO
                confirm with callers).
            action: Optional selector handed to ``select``. When ``None``
                the rewards for all actions are returned.

        Returns:
            With ``action is None``, the raw output of the layer stack,
            whose last dimension has ``SIMULATOR.n_actions`` entries.
            Otherwise whatever ``select`` returns for that output and
            ``action`` (presumably the reward of the chosen action --
            verify against ``torch_utils.select``).
        """
        all_rewards = self.layers(state)
        if action is None:
            return all_rewards
        return select(all_rewards, action)
