import torch as t

from .model_free_q_network import ModelFreeQNetwork
from ..torch_utils import merge


class TreeQN(ModelFreeQNetwork):
    """Q-network that computes Q-values by look-ahead in latent space.

    Starting from the (encoded) root states, the network repeatedly applies
    ``f_transition`` and ``f_reward`` to expand a tree down to ``depth``
    levels, evaluates the leaves with ``f_value``, and backs values up to the
    root by taking the maximum over ``reward + value`` at every internal
    node.

    NOTE(review): ``f_encoder``, ``f_transition``, ``f_reward``, ``f_value``
    and ``n_actions`` are not defined in this class; they are presumably
    provided by ``ModelFreeQNetwork`` -- confirm.
    """

    def __init__(
        self,
        d_hidden,
        depth: int,
    ):
        """Build the network.

        Args:
            d_hidden: Forwarded unchanged to ``ModelFreeQNetwork.__init__``.
            depth: Number of tree levels expanded below the root. Must be a
                positive integer.

        Raises:
            ValueError: If ``depth`` is smaller than 1 or not integer-valued.
        """
        # The recursion terminates only when the remaining depth hits exactly
        # 0, so a non-positive or fractional depth would recurse without
        # bound. Reject it up front instead of failing deep inside forward().
        if depth < 1 or depth != int(depth):
            raise ValueError(
                f"depth must be a positive integer, got {depth!r}"
            )

        super().__init__(d_hidden)

        self.depth = depth

    def _lookahead_q_values(
        self,
        states,
        depth,
    ):
        """Return ``reward + value(child)`` for every child of ``states``.

        Args:
            states: Latent states to expand.
            depth: Levels of look-ahead to use, counting this expansion
                (so the children are evaluated with ``depth - 1``).

        Returns:
            ``f_reward(states)`` plus the backed-up child values, the latter
            reshaped to ``(-1, n_actions)``.
        """
        next_states = self.f_transition(states)
        rewards = self.f_reward(states)

        # Presumably ``merge`` folds dims 0 and 1 (batch and action) of the
        # children into one axis so they can be evaluated as a flat batch;
        # the view below regroups the results per parent -- TODO confirm
        # against ``torch_utils.merge``.
        value_of_next_states = self._expand_and_backup(
            merge(
                next_states,
                dims=(0, 1),
            ),
            depth - 1,
        ).view(-1, self.n_actions)

        return rewards + value_of_next_states

    def _expand_and_backup(
        self,
        states,
        depth,
    ):
        """Return the value of ``states`` using ``depth`` levels of look-ahead.

        Args:
            states: Latent states to evaluate.
            depth: Remaining levels to expand; ``0`` evaluates ``states``
                directly with ``f_value``.

        Returns:
            ``f_value(states)`` when ``depth`` is 0, otherwise the maximum
            over the last dimension of the look-ahead Q-values.

        Raises:
            ValueError: If ``depth`` is negative, which can only happen when
                the starting depth was not a positive integer.
        """
        if depth == 0:
            return self.f_value(states)

        # Guards against ``self.depth`` having been changed to an invalid
        # value after construction; without it the recursion never ends.
        if depth < 0:
            raise ValueError(
                f"depth must be a non-negative integer, got {depth!r}"
            )

        # Max-backup: the value of a node is that of its best child.
        return t.max(self._lookahead_q_values(states, depth), dim=-1).values

    def forward(
        self,
        states,
        encoded: bool = False,
        training: bool = False,
    ):
        """Compute the Q-values at the root.

        Args:
            states: Input states, or latent states when ``encoded`` is true.
            encoded: If false, ``states`` is first passed through
                ``f_encoder``.
            training: Not used by this implementation.

        Returns:
            ``reward + backed-up value`` for the children of the root,
            without the final max, so one Q-value per child is returned.
        """
        # Encode states to the latent space if not already encoded
        if not encoded:
            states = self.f_encoder(states)

        # Expand the root and back up the values of its children; unlike the
        # internal nodes, the root keeps one Q-value per child instead of
        # reducing them with a max.
        return self._lookahead_q_values(states, self.depth)
