from typing import Callable, Dict, Iterator, List, Type

import gin
import torch
from torch import Tensor
from torch.distributions import Categorical
from torch.nn import Parameter

from rgfn.api.type_variables import TAction, TState
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionAction,
    ReactionActionSpace,
    ReactionActionSpace0,
    ReactionActionSpace0Invalid,
    ReactionActionSpaceA,
    ReactionActionSpaceB,
    ReactionActionSpaceC,
    ReactionActionSpaceEarlyTerminate,
    ReactionState,
    ReactionState0,
    ReactionStateA,
    ReactionStateB,
    ReactionStateC,
)
from rgfn.shared.policies.exploitation_penalty_helper import ExploitationPenaltyHelper
from rgfn.shared.policies.few_phase_policy import FewPhasePolicyBase, TSharedEmbeddings
from rgfn.shared.policies.uniform_policy import TIndexedActionSpace

from .reaction_forward_policy import ReactionForwardPolicy, SharedEmbeddings


@gin.configurable()
class ReactionForwardPolicyExploitationPenalty(
    FewPhasePolicyBase[ReactionState, ReactionActionSpace, ReactionAction, SharedEmbeddings],
):
    """
    A policy that combines the reaction forward policy with the exploitation penalty.

    The wrapped ``reaction_forward_policy`` produces per-phase logits. This class converts them to
    probabilities with a softmax and multiplies them by weights from
    ``exploitation_penalty_helper.compute_weights``. It then renormalizes, so that actions the
    helper weights lower receive (relatively) less probability mass.

    Because of this, the phase-specific forward functions of this class return *probabilities*
    rather than logits, and ``_sample_actions_from_logits`` is overridden to sample from
    probabilities.

    Args:
        reaction_forward_policy: the underlying policy. It supplies the logits, the shared
            embeddings and all trainable parameters of this wrapper.
        exploitation_penalty_helper: provides the per-action weights. After every reweighting it
            is asked to advance its temperature via ``set_next_temperature``.
    """

    def __init__(
        self,
        reaction_forward_policy: ReactionForwardPolicy,
        exploitation_penalty_helper: ExploitationPenaltyHelper,
    ):
        super().__init__()
        self.reaction_forward_policy = reaction_forward_policy
        self.exploitation_penalty_helper = exploitation_penalty_helper

        # Dispatch table used by the few-phase base class. Both the early-terminate and the
        # "invalid phase 0" action spaces are handled by the same trivial forward function.
        self._action_space_type_to_forward_fn = {
            ReactionActionSpace0: self._forward_0,
            ReactionActionSpaceA: self._forward_a,
            ReactionActionSpaceB: self._forward_b,
            ReactionActionSpaceC: self._forward_c,
            ReactionActionSpaceEarlyTerminate: self._forward_early_terminate,
            ReactionActionSpace0Invalid: self._forward_early_terminate,
        }

    @property
    def hook_objects(self) -> List["TrainingHooksMixin"]:
        """
        Objects whose training hooks belong to this policy: the wrapped forward policy and the
        exploitation penalty helper. Presumably consumed by the training loop's hook dispatch.
        """
        return [self.reaction_forward_policy, self.exploitation_penalty_helper]

    @property
    def action_space_to_forward_fn(
        self,
    ) -> Dict[
        Type[TIndexedActionSpace],
        Callable[[List[TState], List[TIndexedActionSpace], TSharedEmbeddings], Tensor],
    ]:
        """
        Mapping from action-space type to the phase-specific forward function.

        Every function in this mapping returns probabilities (not logits).
        """
        return self._action_space_type_to_forward_fn

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """
        Return the trainable parameters of the wrapped reaction forward policy.

        Only the wrapped policy's parameters are exposed; the exploitation penalty helper
        contributes none.

        Args:
            recurse: forwarded to the wrapped policy's ``parameters``. Previously this argument
                was silently ignored.

        Returns:
            an iterator over the wrapped policy's parameters.
        """
        return self.reaction_forward_policy.parameters(recurse=recurse)

    def _calculate_augmented_probs(
        self,
        logits: Tensor,
        states: List[ReactionState],
        action_spaces: List[ReactionActionSpace],
    ) -> Tensor:
        """
        A helper function to calculate the augmented probabilities. The more exploited the action in a given state is,
        the (relatively) lower its weight will be.

        Args:
            logits: logits from the wrapped policy, normalized over ``dim=1``. Presumably of shape
                (N, max_num_actions), with N == len(states).
            states: the states the logits were computed for.
            action_spaces: the action spaces matching ``states``.

        Returns:
            row-normalized probabilities with the same shape as ``logits``.
        """
        probs = torch.softmax(logits, dim=1)
        weights = self.exploitation_penalty_helper.compute_weights(
            states, action_spaces, action_space_size=probs.shape[1]
        )
        probs = probs * weights
        # NOTE(review): if every action in a row ends up with zero weighted mass, this division
        # produces NaN and sampling will fail. Confirm compute_weights never zeroes all valid
        # actions.
        probs = probs / probs.sum(dim=-1, keepdim=True)
        # NOTE(review): this advances the temperature once per phase forward call, not once per
        # training step. Confirm this is the intended schedule.
        self.exploitation_penalty_helper.set_next_temperature()
        return probs

    def _forward_0(
        self,
        states: List[ReactionState0],
        action_spaces: List[ReactionActionSpace0],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Phase-0 forward: the wrapped policy's phase-0 logits, turned into penalized probabilities."""
        logits = self.reaction_forward_policy._forward_0(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_a(
        self,
        states: List[ReactionStateA],
        action_spaces: List[ReactionActionSpaceA],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Phase-A forward: the wrapped policy's phase-A logits, turned into penalized probabilities."""
        logits = self.reaction_forward_policy._forward_a(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_b(
        self,
        states: List[ReactionStateB],
        action_spaces: List[ReactionActionSpaceB],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Phase-B forward: the wrapped policy's phase-B logits, turned into penalized probabilities."""
        logits = self.reaction_forward_policy._forward_b(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_c(
        self,
        states: List[ReactionStateC],
        action_spaces: List[ReactionActionSpaceC],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Phase-C forward: the wrapped policy's phase-C logits, turned into penalized probabilities."""
        logits = self.reaction_forward_policy._forward_c(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_early_terminate(
        self,
        states: List[ReactionState],
        action_spaces: List[ReactionActionSpace],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """
        Forward for action spaces with a single possible action.

        Used for both ``ReactionActionSpaceEarlyTerminate`` and ``ReactionActionSpace0Invalid``.
        Each state gets probability 1 for its only action, and the penalty is not applied.

        Returns:
            a float32 tensor of ones with shape (len(states), 1) on ``self.device``.
        """
        return torch.ones((len(states), 1), device=self.device, dtype=torch.float32)

    def get_shared_embeddings(
        self, states: List[ReactionState], action_spaces: List[ReactionActionSpace]
    ) -> SharedEmbeddings:
        """Delegate shared-embedding computation to the wrapped reaction forward policy."""
        return self.reaction_forward_policy.get_shared_embeddings(states, action_spaces)

    def _sample_actions_from_logits(
        self, logits: Tensor, action_spaces: List[TIndexedActionSpace]
    ) -> List[TAction]:
        """
        This method overrides the parent method to use probabilities instead of logits.

        The forward functions of this class already return normalized probabilities. The values
        must therefore be passed to ``Categorical`` as ``probs``, not as ``logits``.

        Args:
            logits: probabilities of the shape (N, max_num_actions).
                Called logits for consistency with the parent method.
            action_spaces: the list of action spaces of the length N.

        Returns:
            the list of sampled actions.
        """
        action_indices = Categorical(probs=logits).sample()  # logits are actually probabilities
        return [
            action_space.get_action_at_idx(idx.item())
            for action_space, idx in zip(action_spaces, action_indices)
        ]
