from typing import Any, Dict, Hashable, List, Tuple, TypeVar

import gin
import numpy as np
import torch
from torch import Tensor

from rgfn.api.training_hooks_mixin import TrainingHooksMixin
from rgfn.api.trajectories import TrajectoriesContainer
from rgfn.shared.policies.uniform_policy import TIndexedActionSpace

# Type variables for actions and states. Both are bound to `Hashable` because
# (state, action) tuples are used as dictionary keys by `ExploitationPenaltyHelper`.
THashableAction = TypeVar("THashableAction", bound=Hashable)
THashableState = TypeVar("THashableState", bound=Hashable)


@gin.configurable()
class ExploitationPenaltyHelper(TrainingHooksMixin):
    r"""
    A helper class that computes the weights for the actions based on the number of times the state-action pair was
    visited. The more exploited the state-action pair is, the lower the weight will be with respect to the other actions
    in the action space. The weights of the possible actions sum up to 1 and are computed with the following formula:
    w(a_i | s) = c(s, a_i)^{-t} / \sum_j c(s, a_j)^{-t}, where c(s, a) is the number of times the state-action pair
    (s, a) was visited (plus the epsilon), and t is the temperature parameter that may vary with the trajectory length.

    Note that with `epsilon` equal to 0 and a positive temperature, an unvisited pair gives 0^{-t} = inf, so the
    resulting weights are not finite.

    Params:
        epsilon: A small constant added to the count of the state-action pair.
        initial_temperature: The initial temperature parameter; `reset_temperature` restores the current temperature
            to this value (scaled by the linear decay multiplier).
        temperature_delta: The delta by which the temperature parameter is increased in every call of the
            `set_next_temperature` method (scaled by the linear decay multiplier).
        zeroth_temperature_at_iteration: If not None, the linear decay multiplier is set to
            `max(0, 1 - iteration_idx / zeroth_temperature_at_iteration)` at the end of every objective computation,
            so the temperature reaches zero at that iteration. If None, the multiplier stays at 1.
    """

    def __init__(
        self,
        epsilon: float = 1.0,
        initial_temperature: float = 0.5,
        temperature_delta: float = 0.25,
        zeroth_temperature_at_iteration: int | None = None,
    ):
        # Visit counts keyed by (state, action); a missing key means the pair was never visited.
        self.state_action_count: Dict[Tuple[THashableState, THashableAction], int] = {}
        self.epsilon = epsilon
        self.initial_temperature = initial_temperature
        self.current_temperature = initial_temperature
        self.temperature_delta = temperature_delta
        self.device = "cpu"
        # NOTE(review): not read or updated anywhere in this class; kept for any external users.
        self.last_update_idx = -1
        self.zeroth_temperature_at_iteration = zeroth_temperature_at_iteration
        # Scales both the initial temperature and the delta; updated in `on_end_computing_objective`.
        self.temperature_linear_decay_multiplier = 1.0

    def set_next_temperature(self):
        """Increase the current temperature by `temperature_delta` scaled by the linear decay multiplier."""
        self.current_temperature += (
            self.temperature_delta * self.temperature_linear_decay_multiplier
        )

    def reset_temperature(self):
        """Set the current temperature to `initial_temperature` scaled by the linear decay multiplier."""
        self.current_temperature = (
            self.initial_temperature * self.temperature_linear_decay_multiplier
        )

    def compute_weights(
        self,
        states: List[THashableState],
        action_spaces: List[TIndexedActionSpace],
        action_space_size: int,
    ) -> Tensor:
        """
        Compute the weights for the given states and action spaces. The more exploited the state-action pair is, the
        lower the weight will be with respect to the other actions in the action space.

        Args:
            states: The states to compute the weights for.
            action_spaces: The action spaces paired element-wise with `states`.
            action_space_size: The number of columns of the returned tensor.

        Returns:
            A float32 tensor of shape `(num_pairs, action_space_size)` placed on `self.device`, where `num_pairs` is
            the number of (state, action space) pairs. In every row the weights of the possible actions sum to 1
            and all the other entries are 0. A row whose action space has no possible actions is all zeros.
        """
        # `zip` below stops at the shorter input, so size the output accordingly.
        num_pairs = min(len(states), len(action_spaces))
        # Preallocating keeps the shape and dtype well-defined even for an empty batch
        # or for rows without any possible action.
        weights = np.zeros((num_pairs, action_space_size), dtype=np.float32)

        for row, (state, action_space) in enumerate(zip(states, action_spaces)):
            possible_action_indices = list(action_space.get_possible_actions_indices())
            if not possible_action_indices:
                continue

            possible_action_counts = np.array(
                [
                    self.state_action_count.get((state, action_space.get_action_at_idx(idx)), 0)
                    + self.epsilon
                    for idx in possible_action_indices
                ],
                dtype=np.float32,
            )
            possible_action_weights = possible_action_counts ** -self.current_temperature
            weights[row, possible_action_indices] = possible_action_weights / np.sum(
                possible_action_weights
            )

        return torch.as_tensor(weights, device=self.device)

    def on_end_computing_objective(
        self,
        iteration_idx: int,
        trajectories_container: TrajectoriesContainer,
        recursive: bool = True,
    ) -> Dict[str, float]:
        """
        Update the visit counts with the non-backward trajectories from the container and, if
        `zeroth_temperature_at_iteration` is set, refresh the linear decay multiplier.

        Args:
            iteration_idx: The index of the current iteration, used for the linear temperature decay.
            trajectories_container: The container to take the visited state-action pairs from.
            recursive: Unused by this method.

        Returns:
            An empty dictionary.
        """
        trajectories = trajectories_container.get_all_non_backward_trajectories()
        states = trajectories.get_non_last_states_flat()
        actions = trajectories.get_actions_flat()
        counts = self.state_action_count
        for state_action in zip(states, actions):
            counts[state_action] = counts.get(state_action, 0) + 1

        decay_horizon = self.zeroth_temperature_at_iteration
        if decay_horizon is not None:
            if decay_horizon == 0:
                # The temperature is meant to be zero from the very first iteration;
                # handled explicitly to avoid a division by zero.
                self.temperature_linear_decay_multiplier = 0.0
            else:
                self.temperature_linear_decay_multiplier = max(
                    0.0, 1.0 - iteration_idx / decay_horizon
                )

        return {}

    def on_start_sampling(self, iteration_idx: int, recursive: bool = True) -> Dict[str, Any]:
        """
        Reset the temperature before sampling starts.

        Returns:
            A dictionary with the freshly reset temperature under the "temperature" key.
        """
        self.reset_temperature()
        return {"temperature": self.current_temperature}
