import math
import torch
import torch.nn.functional as F

from torch.distributions import Distribution
from torch.distributions import Beta, register_kl, kl_divergence


class HypersphericalUniform(Distribution):
    """Uniform distribution over unit vectors in ``dim``-dimensional space.

    Samples are drawn by normalizing standard Gaussian vectors to unit L2
    norm. The density is constant and equal to the reciprocal of the surface
    area of the unit sphere, ``2 * pi**(dim/2) / Gamma(dim/2)``.

    Args:
        dim: size of the event (last) dimension.
        device: device on which samples and densities are created.
        dtype: dtype of samples and densities.
        validate_args: forwarded to ``Distribution.__init__``.
    """

    arg_constraints = {}
    support = torch.distributions.constraints.real
    has_rsample = True

    def __init__(self, dim, device='cpu', dtype=torch.float32, validate_args=None):
        # No batch dimensions: a single distribution with a `dim`-sized event.
        super().__init__(
            batch_shape=torch.Size([]),
            event_shape=torch.Size([dim]),
            validate_args=validate_args,
        )

        self.dim = dim
        self.device = device
        self.dtype = dtype

    @property
    def mean(self):
        """Return the zero vector of length ``dim``."""
        return torch.zeros(self.dim, device=self.device, dtype=self.dtype)

    def log_prob(self, value):
        """Return the constant log-density, one entry per leading index of ``value``.

        Only the shape of ``value`` (minus its last axis) is used; its
        contents are not inspected.
        """
        leading_shape = value.shape[:-1]
        ones = torch.ones(leading_shape, device=self.device, dtype=self.dtype)
        return ones * (-self._log_surface_area())

    def _log_surface_area(self):
        """Return ``log(2) + (dim/2) * log(pi) - lgamma(dim/2)`` as a 0-d tensor."""
        half_dim = torch.tensor(self.dim, device=self.device, dtype=self.dtype) / 2
        return math.log(2) + half_dim * math.log(math.pi) - torch.lgamma(half_dim)

    def rsample(self, sample_shape=torch.Size()):
        """Draw unit-norm samples of shape ``sample_shape + (dim,)``."""
        out_shape = self._extended_shape(sample_shape)
        gaussian = torch.randn(out_shape, device=self.device, dtype=self.dtype)
        # Projecting an isotropic Gaussian onto the sphere gives a uniform direction.
        return F.normalize(gaussian, p=2, dim=-1)


class PowerSpherical(Distribution):
    """Power Spherical distribution on the unit hypersphere.

    For unit vectors ``x`` the density is proportional to
    ``(1 + <loc, x>) ** scale``. Sampling draws the component along ``loc``
    from a transformed Beta distribution, draws the orthogonal part uniformly
    on the lower-dimensional sphere, and reflects the result so that the first
    basis vector is mapped onto ``loc``.

    Args:
        loc: mean direction; the last axis is the event axis. It is
            L2-normalized along that axis before being stored.
        scale: positive concentration. Its shape is used as the batch shape,
            so it should match ``loc.shape[:-1]`` (or be a 0-d tensor).
        validate_args: forwarded to ``Distribution.__init__``.
    """

    arg_constraints = {'loc': torch.distributions.constraints.real,
                       'scale': torch.distributions.constraints.positive}
    support = torch.distributions.constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Parameters are set before super().__init__ so argument validation
        # (when enabled) can read them.
        self.loc = F.normalize(loc, p=2, dim=-1)
        self.scale = scale
        self.dim = loc.shape[-1]
        super().__init__(
            batch_shape=scale.shape,
            event_shape=torch.Size([loc.shape[-1]]),
            validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch_shape + (dim,)``."""
        shape = self._extended_shape(sample_shape)
        alpha = self.scale + (self.dim - 1) / 2
        beta = torch.tensor((self.dim - 1) / 2, device=self.loc.device, dtype=self.loc.dtype)

        # z ~ Beta(alpha, beta) on (0, 1); t = 2z - 1 is the component along loc.
        beta_dist = Beta(alpha, beta)
        z = beta_dist.rsample(sample_shape)
        t = 2 * z - 1

        # Uniform direction in the (dim - 1)-dimensional orthogonal complement.
        v_shape = shape[:-1] + (self.dim - 1,)
        v = torch.randn(v_shape, device=self.loc.device, dtype=self.loc.dtype)
        v = F.normalize(v, p=2, dim=-1)

        t_unsqueezed = t.unsqueeze(-1)
        # Clamp keeps sqrt (and its gradient) finite when |t| is close to 1.
        v_scaled = v * torch.sqrt((1 - t_unsqueezed**2).clamp(min=1e-6))
        y = torch.cat([t_unsqueezed, v_scaled], dim=-1)

        return self._householder_rotation(y)

    def _householder_rotation(self, y):
        """Reflect ``y`` with the Householder map that sends e1 to ``loc``."""
        e1 = torch.zeros_like(self.loc)
        e1[..., 0] = 1.0
        u = e1 - self.loc
        u_norm = torch.norm(u, p=2, dim=-1, keepdim=True)
        # When loc already equals e1 the reflection axis is degenerate;
        # the mask makes the map the identity in that case.
        mask = (u_norm > 1e-6).to(y.dtype)
        u_normalized = u / (u_norm + 1e-6)
        dot_prod = torch.sum(u_normalized * y, dim=-1, keepdim=True)
        x = y - 2 * dot_prod * u_normalized

        return mask * x + (1 - mask) * y

    def log_prob(self, value):
        """Return the log-density of ``value``.

        Args:
            value: tensor whose last axis has size ``dim``; leading axes must
                broadcast against the batch shape.

        Returns:
            Tensor with the event axis reduced away, i.e. the broadcast of
            ``value.shape[:-1]`` with the batch shape.
        """
        # Reduce the event axis WITHOUT keepdim: a trailing singleton axis
        # would mis-broadcast against `scale`, whose shape is the batch shape.
        dot_prod = torch.sum(self.loc * value, dim=-1)
        return self.scale * torch.log((1 + dot_prod).clamp(min=1e-6)) + self._log_normalization()

    def _log_normalization(self):
        """Return the additive log-normalization constant of the density.

        With ``alpha = scale + (dim - 1) / 2`` and ``beta = (dim - 1) / 2``,
        the integral of ``(1 + <loc, x>) ** scale`` over the sphere equals
        ``2**(alpha + beta) * pi**beta * Gamma(alpha) / Gamma(alpha + beta)``;
        this method returns the negative logarithm of that quantity.
        """
        beta = (self.dim - 1) / 2
        alpha = self.scale + beta
        return (torch.lgamma(alpha + beta) -
                torch.lgamma(alpha) -
                (alpha + beta) * math.log(2) -
                beta * math.log(math.pi))

@register_kl(PowerSpherical, HypersphericalUniform)
def kl_power_spherical_uniform(q, p):
    """Return KL(q || p) for a PowerSpherical ``q`` and a HypersphericalUniform ``p``.

    Both densities depend on a point only through its component along
    ``q.loc``, so the divergence is computed between the Beta distributions
    of that (rescaled) component: ``Beta(scale + (dim-1)/2, (dim-1)/2)`` for
    ``q`` and ``Beta((dim-1)/2, (dim-1)/2)`` for ``p``.

    Args:
        q: the PowerSpherical distribution.
        p: the HypersphericalUniform distribution.

    Returns:
        Tensor with the shape of ``q.scale``.

    Raises:
        ValueError: if the two distributions have different dimensions.
    """
    # Explicit raise rather than `assert`: assertions are removed under
    # `python -O`, which would let a dimension mismatch pass silently.
    if q.dim != p.dim:
        raise ValueError(f"Dimension mismatch: q.dim={q.dim} but p.dim={p.dim}")
    beta_param = torch.tensor((q.dim - 1) / 2, device=q.scale.device, dtype=q.scale.dtype)
    alpha_param = q.scale + beta_param
    q_beta = Beta(alpha_param, beta_param)
    p_beta = Beta(beta_param, beta_param)
    return kl_divergence(q_beta, p_beta)