import torch as t
import torch.distributions as td
import torch.nn.functional as F

from .model_free_q_network import ModelFreeQNetwork
from ..torch_utils import merge, select

# Finite stand-in for -inf: the logit used to keep already-expanded nodes out
# of the tree-search selection softmax.
NEG_INF = -100_000.0


class DifferentiableTreeSearch(ModelFreeQNetwork):
    """Q-network that refines its estimates with a sampled best-first tree search.

    Starting from the (encoded) input states as roots, every forward pass
    performs ``n_trials`` expansions. Each expansion samples one not-yet-expanded
    node per tree from a softmax over ``policy_factor * (value_to_nodes + values)``,
    applies the learned transition to it and appends its children. Values are
    then backed up with a max over ``reward + value`` of each node's children,
    and the Q-values of the root's children are returned.

    NOTE(review): ``f_encoder``, ``f_transition``, ``f_value``, ``f_reward`` and
    ``n_actions`` are not defined in this class; they are presumably provided by
    ``ModelFreeQNetwork`` — confirm against the base class.
    """

    def __init__(
        self,
        d_hidden,
        n_trials,
    ):
        """Create the tree-search Q-network.

        Args:
            d_hidden: Forwarded unchanged to ``ModelFreeQNetwork.__init__``.
            n_trials: Number of node expansions per forward pass; must be >= 1.

        Raises:
            ValueError: If ``n_trials`` is smaller than 1.
        """
        super().__init__(d_hidden)

        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``: with zero trials the root has no children, so there
        # are no Q-values to read.
        if n_trials < 1:
            raise ValueError("--n-trials must be >= 1")

        self.n_trials = n_trials
        # Learnable scale on the selection logits (acts as an inverse temperature).
        self.policy_factor = t.nn.Parameter(t.ones(1))
        # Per-expansion negative log-likelihood of the sampled node (training only).
        self.tree_policy_nll = []
        self.q_values = None
        self.is_training = False

    def _initialise(
        self,
        states,
        batch_size=1,
        device="cpu",
    ):
        """Reset the search tree so that every batch element holds only its root.

        Per-node tensors are indexed ``[batch, node]``; node 0 is the root and
        new nodes are appended along dim 1. ``self.states`` additionally keeps
        the trailing dimensions of ``states``.

        Args:
            states: Encoded root states, one per batch element.
            batch_size: Number of independent trees searched in parallel.
            device: Device on which the node tensors are allocated.
        """
        self.batch_size = batch_size
        self.device = device

        self.states = states.unsqueeze(1)
        # Value estimate of each node (zero for the root).
        self.values = t.zeros(self.batch_size, 1, dtype=t.float32, device=self.device)
        # Reward received on the transition into each node.
        self.rewards = t.zeros_like(self.values)
        # Sum of rewards along the path from the root to each node.
        self.value_to_nodes = t.zeros_like(self.values)
        # Parent index of each node; -1 marks the root.
        self.parents = -t.ones_like(self.values, dtype=t.long)
        # 1 for nodes that have already been expanded, 0 otherwise.
        self.masks = t.zeros_like(self.values)

        self.tree_policy_nll = []
        self.q_values = [None] * self.n_trials

    def _add_new_nodes(
        self,
        state_of_new_nodes,
        value_of_new_nodes,
        reward_to_new_nodes,
        value_to_new_nodes,
        parent_of_new_nodes,
    ):
        """Append a group of new child nodes to the tree.

        Each argument is concatenated onto the matching node tensor along dim 1.
        The new nodes get a mask of 0, so they are eligible for expansion.
        """
        self.states = t.cat((self.states, state_of_new_nodes), dim=1)
        self.values = t.cat((self.values, value_of_new_nodes), dim=1)
        self.rewards = t.cat((self.rewards, reward_to_new_nodes), dim=1)
        self.value_to_nodes = t.cat((self.value_to_nodes, value_to_new_nodes), dim=1)
        self.parents = t.cat((self.parents, parent_of_new_nodes), dim=1)
        self.masks = t.cat((self.masks, t.zeros_like(value_of_new_nodes)), dim=1)

    def _expand(self):
        """Sample one unexpanded node per tree, expand it and append its children.

        When ``self.is_training`` is set, the negative log-likelihood of the
        sampled node under the selection policy is appended to
        ``self.tree_policy_nll`` as a ``(batch_size, 1)`` tensor.
        """
        # Selection logits: scaled (path reward + node value). Expanded nodes
        # are hard-masked with masked_fill rather than by *adding* NEG_INF:
        # with an additive mask a large enough score (policy_factor is
        # learnable) can cancel the penalty and re-select an expanded node,
        # attaching a second child group to the same parent.
        scores = self.policy_factor * (self.value_to_nodes + self.values)
        tree_search_policy = scores.masked_fill(self.masks.bool(), NEG_INF)

        # Sample the node to expand from the tree search policy
        idx_of_selected_nodes = td.Categorical(logits=tree_search_policy).sample()

        # cross_entropy on the logits is -log p(selected node)
        if self.is_training:
            self.tree_policy_nll.append(
                F.cross_entropy(
                    tree_search_policy,
                    idx_of_selected_nodes,
                    reduction="none",
                ).view(self.batch_size, -1)
            )

        # Mark the selected nodes as expanded. masked_fill above was
        # out-of-place, so this in-place write does not touch the logits.
        self.masks[range(self.batch_size), idx_of_selected_nodes] = 1

        # NOTE(review): `select` is a project helper; it presumably gathers
        # tensor[b, idx[b]] per batch element — confirm in torch_utils.
        state_of_selected_nodes = select(self.states, idx_of_selected_nodes)

        # Successor states of the selected nodes; presumably one per action,
        # shape (batch_size, n_actions, ...) — confirm against f_transition.
        state_of_new_nodes = self.f_transition(state_of_selected_nodes)

        # Fold the action dim into the batch dim to score every successor at
        # once, then restore one row per batch element.
        value_of_new_nodes = self.f_value(
            merge(
                state_of_new_nodes,
                dims=(0, 1),
            )
        ).view(self.batch_size, -1)

        # Rewards for taking each action at the selected states
        reward_to_new_nodes = self.f_reward(state_of_selected_nodes)

        # Path reward to each child = path reward to its parent + step reward
        value_to_new_nodes = (
            select(
                self.value_to_nodes,
                idx_of_selected_nodes,
                keepdim=True,
            )
            + reward_to_new_nodes
        )

        # Every child in the new group shares the selected node as its parent
        parent_of_new_nodes = t.zeros_like(
            value_of_new_nodes,
            dtype=t.long,
        ) + idx_of_selected_nodes.view(-1, 1)

        self._add_new_nodes(
            state_of_new_nodes,
            value_of_new_nodes,
            reward_to_new_nodes,
            value_to_new_nodes,
            parent_of_new_nodes,
        )

    def _backup(self, trial):
        """Back up values through the first ``trial`` expansions; return root Q-values.

        Expansion ``i`` (0-based) appended its ``n_actions`` children at node
        indices ``i * n_actions + 1`` to ``(i + 1) * n_actions``. A node exists
        before it is expanded, so a group's parent always lives in an earlier
        group. Visiting groups newest-first therefore finalises every child
        group before its parent's group is read. Each visited group overwrites
        its parent's value with ``max(reward + value)`` over the group. Group 0
        (the root's children) is not visited because the root's value is not
        needed: the Q-values are read directly from the root's children.

        Args:
            trial: Number of expansions performed so far.

        Returns:
            Tensor ``(batch_size, n_actions)``: reward + backed-up value of each
            child of the root.
        """
        for i in reversed(range(1, trial)):
            idx = i * self.n_actions + 1
            # Parent of this child group (shared by all children in the group)
            parent_node_idx = select(self.parents, idx)

            value_of_children_node = self.values[:, idx : idx + self.n_actions]
            reward_to_children_node = self.rewards[:, idx : idx + self.n_actions]

            # One-hot mask selecting the parent node in each tree
            update_mask = t.zeros_like(self.values)
            update_mask[range(self.batch_size), parent_node_idx] = 1
            forget_mask = (update_mask == 0).float()

            # Out-of-place blend keeps the update differentiable:
            # parent <- max over children of (reward + value)
            self.values = (
                forget_mask * self.values
                + update_mask
                * t.max(
                    reward_to_children_node + value_of_children_node,
                    dim=-1,
                    keepdim=True,
                ).values
            )

        return (
            self.rewards[:, 1 : 1 + self.n_actions]
            + self.values[:, 1 : 1 + self.n_actions]
        )

    def forward(
        self,
        states,
        encoded=False,
        training=False,
    ):
        """Run ``n_trials`` tree-search expansions and return the root Q-values.

        Args:
            states: Batch of input states. They are passed through
                ``f_encoder`` first unless ``encoded`` is True.
            encoded: Whether ``states`` are already latent states.
            training: If True, record the selection NLL of every expansion in
                ``self.tree_policy_nll`` and back up after every expansion,
                storing each intermediate result in ``self.q_values``.

        Returns:
            Q-values of the root's children after the final expansion.
        """
        if not encoded:
            states = self.f_encoder(states)

        # Start a fresh search tree rooted at the input states
        self._initialise(
            states,
            batch_size=len(states),
            device=states.device,
        )
        self.is_training = training

        for trial in range(self.n_trials):
            self._expand()
            if training:
                # Back up after every expansion so that a Q-estimate exists for
                # every intermediate tree size.
                self.q_values[trial] = self._backup(trial + 1)

        if training:
            return self.q_values[-1]
        return self._backup(self.n_trials)
