from __future__ import annotations

import math

import torch
from torch.distributions import Normal


class TanhNormal:
    """Element-wise Normal distribution whose samples are squashed by ``tanh``.

    The underlying Gaussian is ``Normal(mean, exp(log_std))`` with ``log_std``
    clamped to ``[log_std_min, log_std_max]`` before exponentiation.

    Attributes:
        mean: Mean of the pre-squash Gaussian, stored as given.
        log_std: Clamped log standard deviation.
        std: ``exp(log_std)`` of the clamped value.
        dist: The pre-squash ``torch.distributions.Normal``.
    """

    def __init__(
        self,
        mean: torch.Tensor,
        log_std: torch.Tensor,
        *,
        log_std_min: float = -5.0,
        log_std_max: float = 2.0,
    ) -> None:
        """Build the pre-squash Gaussian.

        Args:
            mean: Mean of the Gaussian before squashing.
            log_std: Log standard deviation; values outside the clamp range
                are clipped to it.
            log_std_min: Lower clamp bound for ``log_std``.
            log_std_max: Upper clamp bound for ``log_std``.
        """
        self.mean = mean
        # Clamping keeps exp() away from both vanishing and exploding scales.
        self.log_std = log_std.clamp(log_std_min, log_std_max)
        self.std = self.log_std.exp()
        self.dist = Normal(self.mean, self.std)

    def rsample(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Draw a reparameterized sample and its log-probability.

        Returns:
            A pair ``(y, logp)`` where ``y = tanh(x)`` for ``x`` drawn with
            ``Normal.rsample`` and ``logp`` is the log-density of ``y``,
            summed over the last dimension (kept with size 1).
        """
        x = self.dist.rsample()
        y = torch.tanh(x)
        # Change of variables: log|dy/dx| = log(1 - tanh(x)^2).
        # Evaluated through the exact identity
        #     log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x))
        # because `1 - y**2` rounds to zero once tanh saturates, which would
        # make an epsilon-padded log return a constant instead of the true
        # (large negative) value and bias the log-probability.
        log_det = 2.0 * (
            math.log(2.0) - x - torch.nn.functional.softplus(-2.0 * x)
        )
        logp = self.dist.log_prob(x) - log_det
        return y, logp.sum(-1, keepdim=True)

    def mode(self) -> torch.Tensor:
        """Return ``tanh(mean)``, the squashed Gaussian mean.

        This is the usual deterministic output for a tanh-squashed Gaussian;
        it is the image of the Gaussian mode, which need not coincide with
        the maximum of the squashed density.
        """
        return torch.tanh(self.mean)

