from .base_solver import BaseSolver
from ..modules import Encoder, RewardModel, TransitionModel, Value
from ..simulators import SIMULATOR
from ..torch_utils import merge


class ModelFreeQNetwork(BaseSolver):
    """Solver that scores actions as predicted reward plus next-state value.

    Owns four sub-networks, each constructed from the same ``d_hidden``
    argument: a state encoder, a transition model, a reward model and a
    value network.
    """

    def __init__(self, d_hidden):
        """Create the sub-networks.

        Args:
            d_hidden: Passed unchanged to every sub-network constructor
                (presumably the latent width -- confirm against the
                ``modules`` package).
        """
        super().__init__()

        # The action count is read from the global simulator rather than
        # being supplied by the caller.
        self.n_actions = SIMULATOR.n_actions
        self.d_hidden = d_hidden

        # Construction order is kept fixed: encoder, transition, reward, value.
        self.f_encoder = Encoder(d_hidden)  # state encoding network
        self.f_transition = TransitionModel(d_hidden)  # transition network
        self.f_reward = RewardModel(d_hidden)  # reward network
        self.f_value = Value(d_hidden)  # value network

    def f_qvalue(self, states):
        """Return reward predictions plus the values of predicted next states.

        Args:
            states: Input handed to both the transition and reward models
                (``forward`` passes already-encoded states here).

        Returns:
            The sum of the reward model output and the value network output,
            the latter reshaped to ``(-1, n_actions)``.
        """
        predicted_next = self.f_transition(states)
        predicted_rewards = self.f_reward(states)

        # Merge dims 0 and 1 of the predicted next states before valuing
        # them (assumes these are the batch and action axes -- TODO confirm),
        # then lay the values back out with one column per action.
        flat_next = merge(predicted_next, dims=(0, 1))
        next_values = self.f_value(flat_next).view(-1, self.n_actions)

        return predicted_rewards + next_values

    def forward(self, states, encoded=False, training=False):
        """Compute Q-values for ``states``.

        Args:
            states: Raw states, or encoded states when ``encoded`` is truthy.
            encoded: When falsy, ``states`` is first passed through the
                encoder; otherwise it is used as given.
            training: Accepted but not used by this method.

        Returns:
            The result of ``f_qvalue`` on the (possibly freshly encoded)
            states.
        """
        latent = states if encoded else self.f_encoder(states)
        return self.f_qvalue(latent)
