import torch as t
import torch.distributions as td

from .model_free_q_network import ModelFreeQNetwork
from ..torch_utils import merge, select

# Large negative offset added to the score of nodes whose mask is set, so that
# an arg-max over node scores does not pick them again.
NEG_INF = -1.0e5


class AStarSearch(ModelFreeQNetwork):
    def __init__(
        self,
        d_hidden,
        n_trials,
        exploration_constant=0.0,
    ):
        """Configure the search-based Q-network.

        Args:
            d_hidden: Passed through unchanged to the parent class constructor.
            n_trials: Number of node expansions performed by ``_expand``;
                must be at least 1.
            exploration_constant: Threshold compared against a uniform random
                draw per batch element in ``_expand``; when the draw is below
                it, the node to expand is sampled instead of taken greedily.
                Must lie in ``[0, 1]``.

        Raises:
            AssertionError: If ``n_trials`` or ``exploration_constant`` is
                outside its allowed range.
        """
        super().__init__(d_hidden)

        # Validate only after the parent has been initialised (same order as before).
        assert n_trials >= 1, "--n-trials must be in >= 1"
        assert 0 <= exploration_constant <= 1, "--exploration-constant must be in [0, 1]"

        self.exploration_constant = exploration_constant
        self.n_trials = n_trials

    def _initialise(
        self,
        states,
        batch_size=1,
        device="cpu",
    ):
        self.batch_size = batch_size
        self.device = device

        # Create node properties to facilitate tree search
        self.states = states.unsqueeze(1)
        self.values = self.f_value(states)
        self.rewards = t.zeros_like(self.values)
        self.value_to_nodes = t.zeros_like(self.values)
        self.parents = -t.ones_like(self.values, dtype=t.long)
        self.masks = t.zeros_like(self.values)

    def _add_new_nodes(
        self,
        state_of_new_nodes,
        value_of_new_nodes,
        reward_to_new_nodes,
        value_to_new_nodes,
        parent_of_new_nodes,
    ):
        # Concatenate properties of new nodes to the existing nodes
        self.states = t.cat(
            (self.states, state_of_new_nodes),
            dim=1,
        )
        self.values = t.cat(
            (self.values, value_of_new_nodes),
            dim=1,
        )
        self.rewards = t.cat(
            (self.rewards, reward_to_new_nodes),
            dim=1,
        )
        self.value_to_nodes = t.cat(
            (self.value_to_nodes, value_to_new_nodes),
            dim=1,
        )
        self.parents = t.cat(
            (self.parents, parent_of_new_nodes),
            dim=1,
        )
        self.masks = t.cat(
            (self.masks, t.zeros_like(value_of_new_nodes)),
            dim=1,
        )

    def _expand(self):
        """Grow the search tree by ``self.n_trials`` node expansions.

        In every iteration one node per batch element is selected, marked as
        expanded, pushed through ``f_transition`` and its resulting child nodes
        are appended to the tree via ``_add_new_nodes``.

        Selection is greedy (arg-max of the node score) unless a uniform random
        draw falls below ``self.exploration_constant``, in which case the node
        is sampled from a flattened categorical distribution over the scores.
        """
        for _ in range(self.n_trials):
            # Score of every node: accumulated reward to reach it plus its value
            # estimate. Nodes whose mask is set (already expanded) are pushed
            # down by NEG_INF so they are not selected again.
            leaves = self.masks * NEG_INF + self.value_to_nodes + self.values

            # Per batch element, decide whether to explore. The random draws are
            # made in the same order as before (uniform draw first, then the
            # categorical sample) so seeded runs consume the RNG identically.
            explore = t.rand(self.batch_size, device=self.device) < self.exploration_constant

            # Dividing the logits by 1000 flattens the distribution: score
            # differences become negligible while masked nodes (offset of
            # NEG_INF / 1000 = -100 in logit space) keep a vanishing probability.
            sampled_idx = td.Categorical(logits=leaves / 1000).sample()
            greedy_idx = t.argmax(leaves, dim=-1)

            # Choose between the two integer index tensors directly instead of
            # blending them with float arithmetic and casting back to long.
            idx_of_selected_nodes = t.where(explore, sampled_idx, greedy_idx)

            # Set the mask for the selected nodes
            self.masks[range(self.batch_size), idx_of_selected_nodes] = 1

            # Select the states to expand
            state_of_selected_nodes = select(self.states, idx_of_selected_nodes)

            # Perform the transition and compute the next_states
            # (presumably one next state per action, on dim 1 -- TODO confirm)
            state_of_new_nodes = self.f_transition(state_of_selected_nodes)

            # Merge the action dim into the batch dim and compute values for each of the next_states
            value_of_new_nodes = self.f_value(
                merge(
                    state_of_new_nodes,
                    dims=(0, 1),
                )
            ).view(self.batch_size, -1)

            # Compute rewards for taking each action at the current states
            reward_to_new_nodes = self.f_reward(state_of_selected_nodes)

            # Compute sum of rewards to reach the new nodes
            value_to_new_nodes = (
                select(
                    self.value_to_nodes,
                    idx_of_selected_nodes,
                    keepdim=True,
                )
                + reward_to_new_nodes
            )

            # Store the parent index for the new nodes (broadcast over the new-node axis)
            parent_of_new_nodes = t.zeros_like(
                value_of_new_nodes,
                dtype=t.long,
            ) + idx_of_selected_nodes.view(-1, 1)

            # Add the new nodes to the tree
            self._add_new_nodes(
                state_of_new_nodes,
                value_of_new_nodes,
                reward_to_new_nodes,
                value_to_new_nodes,
                parent_of_new_nodes,
            )

    def _backup(self):
        """Propagate values from the most recently added nodes back to the root.

        Expansions are visited newest-first. For each one, the value of the
        expanded (parent) node is replaced by the maximum over its children of
        ``reward + value``. This relies on every expansion having appended
        exactly ``self.n_actions`` nodes (assumed from the indexing below --
        TODO confirm against ``f_transition``).

        Returns:
            ``reward + value`` for the nodes at columns ``1 .. n_actions``,
            i.e. the nodes appended by the first expansion.
        """
        rows = range(self.batch_size)
        n_actions = self.n_actions

        for trial in range(self.n_trials - 1, -1, -1):
            # Column 0 is the root, so the children of expansion `trial` start here
            first_child = 1 + trial * n_actions
            children = slice(first_child, first_child + n_actions)

            # All children of one expansion share a parent; look it up via the first child
            parent_node_idx = select(self.parents, first_child)

            # Best achievable return through any child of this expansion
            child_values = self.values[rows, children]
            child_rewards = self.rewards[rows, children]
            best_return = t.max(
                child_rewards + child_values,
                dim=-1,
                keepdim=True,
            ).values

            # One-hot mask over nodes marking the parent, and its complement
            is_parent = t.zeros_like(self.values)
            is_parent[rows, parent_node_idx] = 1
            is_other = (is_parent == 0).float()

            # Overwrite only the parent's value; every other node keeps its own
            self.values = is_other * self.values + is_parent * best_return

        root_children = slice(1, 1 + n_actions)
        q_values = self.rewards[rows, root_children] + self.values[rows, root_children]
        return q_values

    def forward(
        self,
        states,
        encoded=False,
        training=False,
    ):
        """Run one full tree search and return the Q-values at the root.

        Args:
            states: Input states; passed through ``self.f_encoder`` unless
                ``encoded`` is true.
            encoded: Set to true when ``states`` are already in latent form.
            training: Accepted for interface compatibility; not used here.

        Returns:
            The Q-values produced by ``_backup`` after the tree was expanded.
        """
        # Work in the latent space; encode only when the caller has not done so
        latent_states = states if encoded else self.f_encoder(states)

        # Fresh tree rooted at the given states
        self._initialise(
            latent_states,
            batch_size=len(latent_states),
            device=latent_states.device,
        )

        # Grow the tree, then back the values up to the root
        self._expand()
        return self._backup()
