import torch
from torch import Tensor

from compression_autoencoder.policies.policy import Policy


class StochasticMCPPolicy(Policy):
    """Gaussian (stochastic) policy for MountainCar Continuous.

    The output of the base ``Policy.predict`` is split into two equal halves
    along its last dimension. The first half is used as the mean of a normal
    distribution. The second half is used as the log of its standard deviation.
    """

    def predict(
        self, x: Tensor, weights: Tensor | None = None, deterministic: bool = False
    ) -> Tensor:
        """Produce an action from ``x``, either sampled or the distribution mean.

        Args:
            x: Input forwarded unchanged to ``Policy.predict``.
            weights: Optional weights forwarded unchanged to ``Policy.predict``.
            deterministic: If true, return the mean and draw no random numbers.

        Returns:
            The mean half of the network output when ``deterministic`` is set.
            Otherwise, a sample ``mean + exp(log_std) * eps`` with
            ``eps ~ N(0, 1)`` drawn element-wise.
        """
        raw_output = super().predict(x, weights)

        # NOTE(review): assumes the last dimension has even size; torch.chunk
        # would otherwise split unevenly (or yield a single chunk for size 1).
        mean, log_std = torch.chunk(raw_output, 2, dim=-1)

        if not deterministic:
            std = log_std.exp()
            # Reparameterised sample: randomness is isolated in the standard
            # normal noise, so gradients flow through both mean and std.
            noise = torch.randn_like(std)
            return mean + std * noise

        return mean
