import torch
import torch.nn as nn
from torchtyping import TensorType

from gfn.modules import DiscretePolicyEstimator
from gfn.preprocessors import Preprocessor
from gfn.states import DiscreteStates
from torch.distributions import Categorical


class TemperaturePolicyEstimator(DiscretePolicyEstimator):
    """Discrete policy estimator whose sampling settings live on the instance.

    The ``temperature``, ``sf_bias`` and ``epsilon`` values used to build the
    action distribution are read from instance attributes, not from the
    arguments passed to ``to_probability_distribution``. Presumably this lets
    code holding the estimator adjust them (e.g. anneal them during training)
    without threading them through every sampling call — confirm with callers.

    Attributes:
        temperature: Forwarded as ``temperature`` to
            ``DiscretePolicyEstimator.to_probability_distribution``.
        sf_bias: Forwarded as ``sf_bias``.
        epsilon: Forwarded as ``epsilon``.
    """

    def __init__(
            self,
            module: nn.Module,
            n_actions: int,
            preprocessor: Preprocessor | None = None,
            is_backward: bool = False,
            *,
            temperature: float = 1.0,
            sf_bias: float = 0.0,
            epsilon: float = 0.0,
    ):
        """Create the estimator and initialise its sampling settings.

        Args:
            module: Forwarded to ``DiscretePolicyEstimator``.
            n_actions: Forwarded to ``DiscretePolicyEstimator``.
            preprocessor: Forwarded to ``DiscretePolicyEstimator``; may be None.
            is_backward: Forwarded to ``DiscretePolicyEstimator``.
            temperature: Initial value of ``self.temperature`` (default 1.0).
            sf_bias: Initial value of ``self.sf_bias`` (default 0.0).
            epsilon: Initial value of ``self.epsilon`` (default 0.0).
        """
        super().__init__(module, n_actions, preprocessor, is_backward=is_backward)
        # Public, mutable attributes: changing them after construction affects
        # every subsequent call to to_probability_distribution.
        self.temperature = temperature
        self.sf_bias = sf_bias
        self.epsilon = epsilon

    def to_probability_distribution(
            self,
            states: DiscreteStates,
            module_output: TensorType["batch_shape", "output_dim", float],
            temperature: float = 1.0,
            sf_bias: float = 0.0,
            epsilon: float = 0.0,
    ) -> Categorical:
        """Build the parent's action distribution using the instance settings.

        Args:
            states: Forwarded unchanged to the parent implementation.
            module_output: Forwarded unchanged to the parent implementation.
            temperature: Ignored; ``self.temperature`` is used instead.
            sf_bias: Ignored; ``self.sf_bias`` is used instead.
            epsilon: Ignored; ``self.epsilon`` is used instead.

        Returns:
            The ``Categorical`` returned by
            ``DiscretePolicyEstimator.to_probability_distribution``.
        """
        # The per-call keyword arguments exist only so this override accepts
        # the same keywords as the parent; the instance attributes always win,
        # so any values passed here are silently dropped.
        return super().to_probability_distribution(
            states,
            module_output,
            temperature=self.temperature,
            sf_bias=self.sf_bias,
            epsilon=self.epsilon,
        )
