
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

def cosine_beta_schedule(T, s=0.008):
    # follow Nichol & Dhariwal 2021 cosine schedule
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(1e-5, 0.999)

class Diffusion1D(nn.Module):
    """
    Simple scheduler that provides q-sampling utilities for noise-prediction training.

    Precomputes cosine-schedule coefficients for ``T`` steps and registers them as
    buffers. Every buffer holds one entry per timestep, so valid timesteps are
    ``0 .. T-1``.

    Args:
        dim: feature dimension; stored as ``self.dim`` (not read by this class's methods).
        T: number of diffusion steps.
        device: device the buffers are created on.
    """
    def __init__(self, dim, T=64, device='cpu'):
        super().__init__()
        self.dim = dim
        self.T = T
        betas = cosine_beta_schedule(T)
        alphas = 1.0 - betas
        # alpha_bar_t = prod_{k <= t} alpha_k
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas.to(device))
        self.register_buffer('alphas', alphas.to(device))
        self.register_buffer('alphas_cumprod', alphas_cumprod.to(device))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).to(device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod).to(device))

    @staticmethod
    def _gather(coeffs, t, x):
        """
        Index per-timestep ``coeffs`` at ``t`` and reshape for broadcasting against ``x``.

        The gathered values go on the leading (batch) axis. A singleton axis is added
        for every remaining axis of ``x``, e.g. (B,) -> (B, 1, 1) when ``x`` is 3-D.
        This supports inputs of any rank rather than only (B, D).
        """
        return coeffs[t].reshape(-1, *([1] * (x.dim() - 1)))

    def q_sample(self, x0, t, noise=None):
        """
        Diffuse clean data to timestep ``t`` in closed form:

            x_t = sqrt(alpha_bar_t) x0 + sqrt(1 - alpha_bar_t) noise

        Args:
            x0: clean samples; the first axis is the batch axis, followed by any
                number of feature axes (including none).
            t: (B,) integer timesteps in [0, T-1]; a single timestep broadcasts
               over the whole batch.
            noise: optional noise shaped like ``x0``; drawn with ``randn_like`` if None.

        Returns:
            Tuple ``(x_t, noise)``: the noised sample and the noise used to make it.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_ab = self._gather(self.sqrt_alphas_cumprod, t, x0)
        sqrt_omab = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x0)
        return sqrt_ab * x0 + sqrt_omab * noise, noise

    def predict_x0(self, x_t, t, eps_hat):
        """
        Estimate x0 from the noisy ``x_t`` and the predicted noise ``eps_hat`` at ``t``.

        Inverts ``q_sample``: x0 = (x_t - sqrt(1 - alpha_bar_t) eps_hat) / sqrt(alpha_bar_t).
        Accepts the same input ranks as ``q_sample``.
        """
        sqrt_ab = self._gather(self.sqrt_alphas_cumprod, t, x_t)
        sqrt_omab = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x_t)
        # Small epsilon keeps the division finite if sqrt(alpha_bar_t) is (near) zero.
        x0_hat = (x_t - sqrt_omab * eps_hat) / (sqrt_ab + 1e-8)
        return x0_hat

def dirichlet_kl(alpha_q: torch.Tensor, alpha_p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Closed-form KL( Dir(alpha_q) || Dir(alpha_p) ), one value per row.

    Concentration vectors lie along the last axis (documented shape (B, K)). The
    result drops that axis, so a (B, K) input yields a (B,) tensor.

        KL = lgamma(sum q) - sum lgamma(q) - [lgamma(sum p) - sum lgamma(p)]
             + sum (q - p) * (digamma(q) - digamma(sum q))

    ``eps`` is added to both concentration tensors before evaluation, presumably to
    keep lgamma/digamma finite when a concentration is exactly zero.
    """
    q = alpha_q + eps
    p = alpha_p + eps
    q_total = q.sum(dim=-1)
    p_total = p.sum(dim=-1)

    # Log normalising constants of each Dirichlet (up to sign convention).
    log_norm_q = torch.lgamma(q_total) - torch.lgamma(q).sum(dim=-1)
    log_norm_p = torch.lgamma(p_total) - torch.lgamma(p).sum(dim=-1)

    # E_q[log x_k] = digamma(q_k) - digamma(sum q); unsqueeze to broadcast over K.
    expected_log = torch.digamma(q) - torch.digamma(q_total).unsqueeze(-1)
    cross = ((q - p) * expected_log).sum(dim=-1)

    return log_norm_q - log_norm_p + cross

def mse(x, y):
    """Return the mean, over all elements, of the squared difference x - y."""
    residual = x - y
    squared = residual ** 2
    return squared.mean()
