import torch
import torch.nn.functional as F

def beta_pdf(x, alpha, beta, eps=1e-8):
    """
    Evaluate the Beta(alpha, beta) probability density function at ``x``.

    The density is computed in log space as
    ``(alpha - 1) * log(x) + (beta - 1) * log(1 - x) - log B(alpha, beta)``
    and exponentiated at the end, with ``log B`` built from ``torch.lgamma``.

    Args:
        x (Tensor): Input values, expected in [0, 1]. Values are clamped into
            the open interval before taking logarithms. Non-floating-point
            tensors are converted to the default floating dtype.
        alpha (float or Tensor): Alpha parameter of the Beta distribution.
        beta (float or Tensor): Beta parameter of the Beta distribution.
        eps (float): Margin kept between ``x`` and the endpoints 0 and 1.
            It is widened when necessary so that it is actually representable
            in ``x``'s dtype.

    Returns:
        Tensor: PDF values for each element of ``x``. ``alpha`` and ``beta``
        are detached, so no gradient flows into them.
    """
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())

    alpha = _as_beta_param(alpha, x)
    beta = _as_beta_param(beta, x)

    # `1 - eps` rounds to exactly 1.0 when eps is below the dtype's machine
    # epsilon (e.g. eps=1e-8 in float32), which would leave log(1 - x) = -inf
    # at x == 1. Likewise a tiny eps can underflow to 0 in half precision.
    # Widen the margins to the nearest values the dtype can represent.
    info = torch.finfo(x.dtype)
    lower = max(eps, info.tiny)
    upper = 1.0 - max(eps, info.eps)
    x = x.clamp(lower, upper)

    log_pdf = (alpha - 1) * torch.log(x) + (beta - 1) * torch.log(1 - x)
    log_beta_fn = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    return torch.exp(log_pdf - log_beta_fn)


def _as_beta_param(value, like):
    """Return ``value`` as a detached floating-point tensor compatible with ``like``."""
    if isinstance(value, torch.Tensor):
        value = value.detach()
        # lgamma needs a floating dtype; leave floating tensors untouched.
        return value if value.is_floating_point() else value.to(like.dtype)
    # Python scalars (including ints) adopt x's dtype and device.
    return torch.tensor(value, dtype=like.dtype, device=like.device)


def estimate_beta(sims, weights, eps=1e-8, *, min_var=0.01,
                  max_concentration=100.0, min_param=0.1, max_param=50.0):
    """
    Estimate alpha and beta for the Beta distribution from weighted similarity scores.

    Uses moment matching: with weighted mean ``m`` and weighted variance ``v``,
    ``c = m * (1 - m) / v - 1``, ``alpha = m * c`` and ``beta = (1 - m) * c``.
    Several clamps keep the estimate bounded when the scores are nearly
    constant or the mean sits at an endpoint.

    Args:
        sims (Tensor): Similarity scores in (0, 1), shape [N]
        weights (Tensor): Weights for each score, shape [N]
        eps (float): Small constant for numerical stability. It is added to
            the total weight and used to keep the mean away from 0 and 1.
        min_var (float): Floor applied to the weighted variance.
        max_concentration (float): Upper bound on ``c`` (which equals
            ``alpha + beta`` before the final clamp).
        min_param (float): Lower bound applied to both alpha and beta.
        max_param (float): Upper bound applied to both alpha and beta.

    Returns:
        alpha (Tensor): Estimated alpha parameter (detached)
        beta (Tensor): Estimated beta parameter (detached)
    """
    # Weighted mean; eps keeps the division finite when all weights are zero.
    total_weight = weights.sum() + eps
    mean = (sims * weights).sum() / total_weight

    # Weighted variance around the unclamped mean. total_weight already
    # includes eps, so it is used as-is to match the mean's normalisation.
    var = (((sims - mean) ** 2) * weights).sum() / total_weight

    mean = mean.clamp(eps, 1 - eps)

    # Safeguard: a near-zero variance would make the concentration explode.
    var = var.clamp(min=min_var)

    common = mean * (1 - mean) / var - 1.0

    # Safeguard: cap the concentration so alpha and beta cannot blow up.
    common = common.clamp(max=max_concentration)

    alpha = mean * common
    beta = (1 - mean) * common

    # Final safeguard: keep both parameters in a reasonable range. When the
    # floored variance is at least mean * (1 - mean), `common` is <= 0 and
    # both parameters end up at `min_param`.
    alpha = alpha.clamp(min=min_param, max=max_param)
    beta = beta.clamp(min=min_param, max=max_param)

    return alpha.detach(), beta.detach()
