from ray.rllib.utils.exploration.exploration import Exploration
import torch
import torch.distributions as D

class BetaTempExploration(Exploration):
    """Temperature-controlled exploration for Beta-backed action distributions.

    When exploring, the Beta parameters of the policy's action distribution
    are power-tempered (``p_T(x) ∝ p(x)^(1/T)``) and the action is sampled
    from the tempered distribution. ``T == 1`` leaves the concentrations
    unchanged; larger ``T`` moves both concentrations towards 1.

    NOTE(review): the stochastic branch samples directly from
    ``torch.distributions.Beta``, so actions lie in (0, 1). If the wrapped
    action distribution rescales its own samples, the stochastic and
    deterministic branches may be on different scales -- confirm against the
    distribution class in use.
    """

    def __init__(self, action_space, *, framework,
                 policy_config=None, model=None,
                 num_workers=0, worker_index=0,
                 temperature=1.0, **kwargs):
        """Store the temperature and initialise the base ``Exploration``.

        Args:
            action_space: Forwarded to the base class.
            framework: Forwarded to the base class.
            policy_config: Forwarded to the base class.
            model: Forwarded to the base class.
            num_workers: Forwarded to the base class.
            worker_index: Forwarded to the base class.
            temperature: Tempering temperature ``T``; converted to ``float``.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # No `policy` argument is accepted or forwarded; only the arguments
        # listed below reach the base class.
        super().__init__(action_space,
                         framework=framework,
                         policy_config=policy_config,
                         model=model,
                         num_workers=num_workers,
                         worker_index=worker_index)
        self.temperature = float(temperature)

    @staticmethod
    def _get_alpha_beta(action_distribution):
        """Return ``(alpha, beta)`` from a Beta-backed action distribution.

        Expects ``action_distribution.dist`` to expose ``.concentration1``
        (alpha) and ``.concentration0`` (beta), as
        ``torch.distributions.Beta`` does.

        Raises:
            AttributeError: If ``.dist`` is missing/``None`` or lacks either
                concentration attribute.
        """
        dist = getattr(action_distribution, "dist", None)
        if dist is None or not hasattr(dist, "concentration1") or not hasattr(dist, "concentration0"):
            raise AttributeError("Expected a Beta-backed distribution with `.dist.concentration1/0`.")
        return dist.concentration1, dist.concentration0

    @staticmethod
    def _make_tempered_beta(alpha, beta, T: float):
        """Build the power-tempered Beta distribution.

        Power tempering: p_T(x) ∝ p(x)^(1/T).
        For Beta(alpha, beta): alpha'=(alpha-1)/T + 1; beta'=(beta-1)/T + 1.
        Both tempered concentrations are clamped to at least 1e-6 so they
        stay valid (strictly positive) Beta parameters.

        Whenever the parameters have at least one dimension, the last
        dimension is treated as the action (event) dimension and wrapped with
        ``Independent`` so ``log_prob`` sums over it. This includes a last
        dimension of size 1, so ``[B, 1]`` parameters give ``[B]`` log-probs
        rather than ``[B, 1]``.

        Args:
            alpha: Tensor of Beta ``concentration1`` values.
            beta: Tensor of Beta ``concentration0`` values.
            T: Temperature; the caller must pass a positive value.

        Returns:
            ``D.Independent(D.Beta(...), 1)`` for parameters with one or more
            dimensions, otherwise a plain ``D.Beta`` for 0-dim parameters.
        """
        eps = 1e-6
        alpha_t = torch.clamp((alpha - 1.0) / T + 1.0, min=eps)
        beta_t = torch.clamp((beta - 1.0) / T + 1.0, min=eps)

        base = D.Beta(alpha_t, beta_t)
        # Always reduce over the trailing (action) dimension when it exists,
        # so the log-prob has one entry per batch row even for 1-D actions.
        if alpha_t.dim() > 0:
            return D.Independent(base, 1)
        return base

    def get_exploration_action(self, *, action_distribution, timestep, explore=True, **kwargs):
        """Return an ``(action, logp)`` pair for the given action distribution.

        Args:
            action_distribution: Distribution object. The deterministic branch
                calls its ``deterministic_sample()`` and ``logp(action)``; the
                stochastic branch reads ``.dist.concentration1/0``.
            timestep: Unused.
            explore: If false, take the deterministic branch; otherwise sample
                from the tempered Beta.
            **kwargs: Ignored.

        Returns:
            Tuple ``(action, logp)``. A 0-dim ``logp`` is expanded to shape
            ``[1]`` (RLlib expects vector logp).

        Raises:
            AttributeError: In the stochastic branch, if the distribution is
                not Beta-backed (see ``_get_alpha_beta``).
        """
        # Deterministic path: use RLlib's own deterministic_sample/logp.
        # T == 1.0 is deliberately NOT routed here: tempering with T == 1
        # leaves the concentrations unchanged, so exploring at the default
        # temperature must still sample rather than act greedily.
        if not explore:
            action = action_distribution.deterministic_sample()
            logp = action_distribution.logp(action)
            if logp.dim() == 0:
                logp = logp.unsqueeze(0)
            return action, logp

        # Stochastic with temperature: temper the Beta concentration parameters and sample
        alpha, beta = self._get_alpha_beta(action_distribution)
        # Guard against zero/negative temperatures before dividing by T.
        T = max(self.temperature, 1e-8)
        tempered = self._make_tempered_beta(alpha, beta, T)

        action = tempered.sample()
        logp = tempered.log_prob(action)

        # Ensure shapes are [B] (RLlib expects vector logp)
        if logp.dim() == 0:
            logp = logp.unsqueeze(0)

        return action, logp


class SoftmaxTempExploration(Exploration):
    """Temperature-scaled categorical exploration for discrete actions.

    When exploring, the action distribution's logits are divided by the
    temperature and an action is sampled from the resulting categorical
    distribution. When not exploring, the distribution's own
    ``deterministic_sample()`` is used.
    """

    def __init__(self, action_space, *, framework,
                 policy_config=None, model=None,
                 num_workers=0, worker_index=0,
                 temperature=1.0, **kwargs):
        """Store the temperature and initialise the base ``Exploration``.

        Args:
            action_space: Forwarded to the base class.
            framework: Forwarded to the base class.
            policy_config: Forwarded to the base class.
            model: Forwarded to the base class.
            num_workers: Forwarded to the base class.
            worker_index: Forwarded to the base class.
            temperature: Softmax temperature ``T``; converted to ``float``.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # No `policy` argument is accepted or forwarded; only the arguments
        # listed below reach the base class.
        super().__init__(
            action_space,
            framework=framework,
            policy_config=policy_config,
            model=model,
            num_workers=num_workers,
            worker_index=worker_index,
        )
        self.temperature = float(temperature)

    @staticmethod
    def _get_logits(action_distribution):
        """Return the logits tensor held by ``action_distribution``.

        Tries, in order: ``.inputs``, ``.logits``, then ``.dist.logits``.

        Raises:
            AttributeError: If none of those attributes exists.
        """
        # RLlib 2.7 TorchCategorical usually has `.inputs`; fall back to
        # `.logits`, then `.dist.logits`.
        if hasattr(action_distribution, "inputs"):
            return action_distribution.inputs
        if hasattr(action_distribution, "logits"):
            return action_distribution.logits
        if hasattr(action_distribution, "dist") and hasattr(action_distribution.dist, "logits"):
            return action_distribution.dist.logits
        raise AttributeError("Could not find logits on action_distribution")

    def get_exploration_action(self, *, action_distribution, timestep, explore=True, **kwargs):
        """Return an ``(action, logp)`` pair for the given action distribution.

        Args:
            action_distribution: Distribution object. The deterministic branch
                calls its ``deterministic_sample()`` and ``logp(action)``; the
                stochastic branch reads its logits via ``_get_logits``.
            timestep: Unused.
            explore: If false, take the deterministic branch; otherwise sample
                from the temperature-scaled categorical.
            **kwargs: Ignored.

        Returns:
            Tuple ``(action, logp)``. In both branches a 0-dim tensor ``logp``
            is expanded to shape ``[1]`` (RLlib expects vector logp).

        Raises:
            AttributeError: In the stochastic branch, if no logits can be
                found on ``action_distribution``.
        """
        # Deterministic path: use RLlib's own deterministic_sample/logp
        if not explore:
            action = action_distribution.deterministic_sample()
            logp = action_distribution.logp(action)
            # Same shape normalisation as the stochastic branch, so callers
            # get a vector logp whichever branch ran.
            if torch.is_tensor(logp) and logp.dim() == 0:
                logp = logp.unsqueeze(0)
            return action, logp

        # Stochastic with temperature: rescale logits and sample
        logits = self._get_logits(action_distribution)
        # Guard against zero/negative temperatures before dividing by T.
        T = max(self.temperature, 1e-8)
        scaled_logits = logits / T

        # Build a temp-scaled categorical and sample
        dist = D.Categorical(logits=scaled_logits)
        action = dist.sample()

        # Log-prob under the temp-scaled dist
        logp = dist.log_prob(action)

        # Ensure shapes are [B] (RLlib expects vector logp)
        if logp.dim() == 0:
            logp = logp.unsqueeze(0)

        return action, logp