from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, List, Type

import gin
import torch
from torch import nn
from torch.distributions import Categorical
from torch.nn import Parameter
from torch import Tensor

from rgfn.api.trajectories import Trajectories
from rgfn.api.type_variables import TAction, TActionSpace, TState
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionAction,
    ReactionActionSpace,
    ReactionActionSpace0,
    ReactionActionSpace0Invalid,
    ReactionActionSpaceA,
    ReactionActionSpaceB,
    ReactionActionSpaceC,
    ReactionActionSpaceEarlyTerminate,
    ReactionState,
    ReactionState0,
    ReactionStateA,
    ReactionStateB,
    ReactionStateC,
)
from rgfn.shared.policies.exploitation_penalty_helper import ExploitationPenaltyHelper
from rgfn.shared.policies.few_phase_policy import FewPhasePolicyBase, TSharedEmbeddings
from rgfn.shared.policies.uniform_policy import TIndexedActionSpace

from .reaction_forward_policy import ReactionForwardPolicy
from .reaction_forward_policy import SharedEmbeddings
from .reaction_forward_policy import SharedEmbeddings as ForwardSharedEmbeddings
from .rnd_novelty_forward_policy import RNDNoveltyForwardPolicy
from .rnd_novelty_forward_policy import SharedEmbeddings as RNDSharedEmbeddings


@gin.configurable()
class ReactionForwardPolicyExploitationPenalty(
    FewPhasePolicyBase[ReactionState, ReactionActionSpace, ReactionAction, SharedEmbeddings],
):
    """
    A policy that combines the reaction forward policy with the exploitation penalty.

    The wrapped `ReactionForwardPolicy` produces the logits for every phase; this class turns them
    into probabilities, multiplies them by the weights obtained from the `ExploitationPenaltyHelper`
    and re-normalizes them. As a consequence, the per-phase forward functions of this class return
    **probabilities**, not logits, and `_sample_actions_from_logits` is overridden accordingly.

    NOTE(review): any inherited code path that interprets the output of the forward functions as
    logits would receive probabilities instead; presumably this policy is meant to be used for
    sampling only -- confirm against `FewPhasePolicyBase`.
    """

    def __init__(
        self,
        reaction_forward_policy: ReactionForwardPolicy,
        exploitation_penalty_helper: ExploitationPenaltyHelper,
    ):
        """
        Args:
            reaction_forward_policy: the policy providing the logits and the shared embeddings.
            exploitation_penalty_helper: the helper providing the per-action weights used to
                re-weight the probabilities of the wrapped policy.
        """
        super().__init__()
        self.reaction_forward_policy = reaction_forward_policy
        self.exploitation_penalty_helper = exploitation_penalty_helper

        # Dispatch table from the action space type to the function handling that phase.
        # The invalid initial action space is handled exactly like the early termination one.
        self._action_space_type_to_forward_fn = {
            ReactionActionSpace0: self._forward_0,
            ReactionActionSpaceA: self._forward_a,
            ReactionActionSpaceB: self._forward_b,
            ReactionActionSpaceC: self._forward_c,
            ReactionActionSpaceEarlyTerminate: self._forward_early_terminate,
            ReactionActionSpace0Invalid: self._forward_early_terminate,
        }

    @property
    def hook_objects(self) -> List["TrainingHooksMixin"]:
        """The objects that should receive the training hooks: the wrapped policy and the helper."""
        return [self.reaction_forward_policy, self.exploitation_penalty_helper]

    @property
    def action_space_to_forward_fn(
        self,
    ) -> Dict[
        Type[TIndexedActionSpace],
        Callable[[List[TState], List[TIndexedActionSpace], TSharedEmbeddings], Tensor],
    ]:
        """The mapping from the action space type to the forward function built in `__init__`."""
        return self._action_space_type_to_forward_fn

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """
        Return the parameters of the wrapped reaction forward policy.

        Only the parameters of `reaction_forward_policy` are exposed; nothing is collected from
        `exploitation_penalty_helper`.

        Args:
            recurse: forwarded to the wrapped policy's `parameters` (previously silently ignored).

        Returns:
            an iterator over the parameters of the wrapped policy.
        """
        return self.reaction_forward_policy.parameters(recurse=recurse)

    def _calculate_augmented_probs(
        self,
        logits: Tensor,
        states: List[ReactionState],
        action_spaces: List[ReactionActionSpace],
    ) -> Tensor:
        """
        A helper function to calculate the augmented probabilities. The more exploited the action in a given state is,
        the (relatively) lower its weight will be.

        Args:
            logits: the logits returned by the wrapped policy; the softmax is taken over dim 1.
            states: the states the logits were computed for.
            action_spaces: the action spaces corresponding to the states.

        Returns:
            the re-weighted probabilities, normalized to sum to one over the last dimension.
        """
        probs = torch.softmax(logits, dim=1)
        weights = self.exploitation_penalty_helper.compute_weights(
            states, action_spaces, action_space_size=probs.shape[1]
        )
        weighted_probs = probs * weights
        # Multiplying by the weights breaks the normalization, so it has to be restored.
        augmented_probs = weighted_probs / weighted_probs.sum(dim=-1, keepdim=True)
        # The helper's temperature is advanced after every computation; it is reset in
        # `on_start_sampling`.
        self.exploitation_penalty_helper.set_next_temperature()
        return augmented_probs

    def _forward_0(
        self,
        states: List[ReactionState0],
        action_spaces: List[ReactionActionSpace0],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Return the penalized probabilities for the `ReactionActionSpace0` phase."""
        logits = self.reaction_forward_policy._forward_0(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_a(
        self,
        states: List[ReactionStateA],
        action_spaces: List[ReactionActionSpaceA],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Return the penalized probabilities for the `ReactionActionSpaceA` phase."""
        logits = self.reaction_forward_policy._forward_a(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_b(
        self,
        states: List[ReactionStateB],
        action_spaces: List[ReactionActionSpaceB],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Return the penalized probabilities for the `ReactionActionSpaceB` phase."""
        logits = self.reaction_forward_policy._forward_b(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_c(
        self,
        states: List[ReactionStateC],
        action_spaces: List[ReactionActionSpaceC],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """Return the penalized probabilities for the `ReactionActionSpaceC` phase."""
        logits = self.reaction_forward_policy._forward_c(states, action_spaces, shared_embeddings)
        return self._calculate_augmented_probs(logits, states, action_spaces)

    def _forward_early_terminate(
        self,
        states: List[ReactionState],
        action_spaces: List[ReactionActionSpaceEarlyTerminate],
        shared_embeddings: SharedEmbeddings,
    ) -> Tensor:
        """
        Return a tensor of ones of the shape (len(states), 1), i.e. a single action with
        probability one for every state. No penalty is applied in this phase.

        `self.device` is not defined in this class; presumably it is provided by the base class.
        """
        return torch.ones((len(states), 1), device=self.device, dtype=torch.float32)

    def get_shared_embeddings(
        self, states: List[ReactionState], action_spaces: List[ReactionActionSpace]
    ) -> SharedEmbeddings:
        """Delegate the computation of the shared embeddings to the wrapped policy."""
        return self.reaction_forward_policy.get_shared_embeddings(states, action_spaces)

    def _sample_actions_from_logits(
        self, logits: Tensor, action_spaces: List[TIndexedActionSpace]
    ) -> List[TAction]:
        """
        This method overrides the parent method to use probabilities instead of logits.

        Args:
            logits: probabilities of the shape (N, max_num_actions).
                Called logits for consistency with the parent method.
            action_spaces: the list of action spaces of the length N.

        Returns:
            the list of sampled actions.
        """
        sampled = Categorical(probs=logits).sample()  # logits are actually probabilities
        # Convert all the indices to Python ints at once instead of calling `.item()` per element.
        action_indices = sampled.tolist()
        return [
            action_space.get_action_at_idx(idx)
            for action_space, idx in zip(action_spaces, action_indices)
        ]

    def on_start_sampling(self, iteration_idx: int, recursive: bool = True) -> Dict[str, Any]:
        """
        Reset the temperature of the exploitation penalty helper and run the parent hook.

        Args:
            iteration_idx: the index of the current iteration, passed to the parent hook.
            recursive: passed to the parent hook.

        Returns:
            whatever the parent `on_start_sampling` returns.
        """
        self.exploitation_penalty_helper.reset_temperature()
        return super().on_start_sampling(iteration_idx, recursive)
