from dataclasses import dataclass

import jax.numpy as jnp
from jax.scipy.special import digamma, polygamma, gammaln


@dataclass
class EMConfig:
    """Hyper-parameters for the Student-t EM updates in this module."""

    # Number of Newton iterations per nu M-step (m_step_nu_flat).
    newton_steps_nu: int = 2
    # Clip bounds applied to the degrees of freedom nu after each Newton step.
    nu_min: float = 1.0
    nu_max: float = 1e3
    # Small constant keeping the Newton denominator away from zero.
    eps_denom: float = 1e-8
    # Gamma(shape=a_lam, rate=b_lam) prior on the precision lambda.
    # NOTE(review): em / em_train take a_lam and b_lam as explicit arguments and
    # the code visible here never reads these two fields — confirm whether any
    # caller relies on them.
    a_lam: float = 1e5
    b_lam: float = 1e5


def e_step(residuals, lam, nu):
    """Posterior moments of the latent precision scales eta_i.

    The variational posterior is q(eta_i) = Gamma(shape=(nu + 1) / 2,
    rate=(nu + lam * r_i**2) / 2).

    Returns:
        (E[eta_i], E[log eta_i]), each with the shape of ``residuals``.
    """
    sq = residuals ** 2
    rate_x2 = nu + lam * sq  # twice the Gamma rate parameter
    expected_eta = (nu + 1.0) / rate_x2
    expected_log_eta = digamma((nu + 1.0) / 2.0) - jnp.log(rate_x2 / 2.0)
    return expected_eta, expected_log_eta


def m_step_lambda(residuals, w, a_lam, b_lam):
    """MAP update of the precision lambda under a Gamma(a_lam, b_lam) prior.

    ``b_lam`` is a rate parameter. ``w`` holds E[eta_i] from the E-step.

    Returns:
        (N/2 + a_lam - 1) / (b_lam + 0.5 * sum_i w_i * r_i**2)
    """
    n_obs = residuals.shape[0]
    shape_part = (n_obs / 2.0) + a_lam - 1.0
    rate_part = b_lam + 0.5 * jnp.sum(w * residuals ** 2)
    return shape_part / rate_part


def nu_newton_step_flat(nu, sum_elog_eta, sum_w, N, cfg):
    """One safeguarded Newton ascent step on nu, taken in zeta = log(nu) space.

    Maximizes the expected complete-data log-likelihood of the latent scales
    under their Gamma(nu/2, nu/2) prior:

        L(nu) = N * [(nu/2) log(nu/2) - lgamma(nu/2)]
                + (nu/2 - 1) * sum_elog_eta - (nu/2) * sum_w

    Args:
        nu: current degrees of freedom (> 0).
        sum_elog_eta: sum_i E[log eta_i] from the E-step.
        sum_w: sum_i E[eta_i] from the E-step.
        N: number of observations.
        cfg: object providing ``eps_denom``, ``nu_min`` and ``nu_max``.

    Returns:
        The updated nu, clipped to [cfg.nu_min, cfg.nu_max].
    """
    half_nu = nu / 2.0
    psi = digamma(half_nu)
    tri = polygamma(1, half_nu)

    dL_dnu = (N / 2.0) * (jnp.log(half_nu) + 1.0 - psi) \
        + 0.5 * sum_elog_eta - 0.5 * sum_w

    # Strictly negative for nu > 0: trigamma(x) > 1/x + 1/(2x^2) gives
    # 0.5 * trigamma(nu/2) > 1/nu.
    d2L_dnu2 = (N / 2.0) * (1.0 / nu - 0.5 * tri)

    # Chain rule for zeta = log(nu).
    dL_dz = nu * dL_dnu
    d2L_dz2 = (nu ** 2) * d2L_dnu2 + nu * dL_dnu

    # The nu * dL_dnu term makes d2L_dz2 positive when the gradient is large and
    # positive. A raw Newton step would then move *against* the gradient
    # (towards a minimum). Fall back to the always-negative nu^2 * d2L_dnu2
    # curvature in that case. Also force the denominator to be <= -eps: adding
    # +eps to a tiny negative curvature could flip its sign or hit zero.
    curvature = jnp.where(d2L_dz2 < 0.0, d2L_dz2, (nu ** 2) * d2L_dnu2)
    curvature = jnp.minimum(curvature, -cfg.eps_denom)

    zeta_new = jnp.log(nu) - dL_dz / curvature
    nu_new = jnp.exp(zeta_new)

    return jnp.clip(nu_new, cfg.nu_min, cfg.nu_max)


def m_step_nu_flat(nu, elog_eta, w, cfg):
    """M-step for nu: ``cfg.newton_steps_nu`` Newton steps starting from ``nu``.

    The sufficient statistics sum(E[log eta]) and sum(E[eta]) do not change
    between Newton iterations, so they are reduced once up front. With
    zero Newton steps, ``nu`` is returned unchanged.
    """
    total_elog_eta = jnp.sum(elog_eta)
    total_w = jnp.sum(w)
    n_obs = elog_eta.shape[0]

    current = nu
    for _step in range(cfg.newton_steps_nu):
        current = nu_newton_step_flat(current, total_elog_eta, total_w, n_obs, cfg)
    return current


def em(
    residuals,
    lam0,
    nu0,
    a_lam,
    b_lam,
    cfg,
    *,
    update_lam=False,
):
    """Run a single EM step: one E-step followed by the M-step for nu.

    Args:
        residuals: residual vector.
        lam0: current precision lambda.
        nu0: current degrees of freedom nu.
        a_lam, b_lam: Gamma(shape, rate) prior parameters for lambda. They are
            only used when ``update_lam`` is True.
        cfg: configuration forwarded to the nu M-step.
        update_lam: if True, also apply the MAP lambda update
            (m_step_lambda). Defaults to False, which keeps lambda fixed at lam0.

    Returns:
        (lam, nu) after the step.
    """
    w, elog_eta = e_step(residuals, lam0, nu0)
    lam = m_step_lambda(residuals, w, a_lam, b_lam) if update_lam else lam0
    nu = m_step_nu_flat(nu0, elog_eta, w, cfg)
    return lam, nu

def student_t_nll(residuals, lam, nu):
    """Negative log-likelihood of residuals under a zero-mean Student-t.

    Each residual is Student-t with ``nu`` degrees of freedom and precision
    ``lam`` (scale 1 / sqrt(lam)). The normalizing terms depend on ``lam`` and
    ``nu``, so they are included. Otherwise values computed with different
    (lam, nu) would not be comparable, e.g. across EM iterations.

    Returns:
        Scalar sum over all residuals of -log t(r_i; nu, scale=1/sqrt(lam)).
    """
    r2 = residuals ** 2
    n_obs = jnp.size(residuals)
    log_norm = (
        gammaln((nu + 1.0) / 2.0)
        - gammaln(nu / 2.0)
        - 0.5 * jnp.log(nu * jnp.pi)
        + 0.5 * jnp.log(lam)
    )
    data_term = jnp.sum(0.5 * (nu + 1.0) * jnp.log1p((lam * r2) / nu))
    return data_term - n_obs * log_norm


def elbo(residuals, lam, nu, w, elog_eta, a_lam, b_lam, logp_theta=0.0):
    """Evidence lower bound of the Student-t scale-mixture model.

    The bound is taken with q(eta_i) = Gamma(alpha, beta_i), where
    alpha = (nu + 1) / 2 and beta_i = (nu + lam * r_i**2) / 2. ``w`` and
    ``elog_eta`` are E_q[eta] and E_q[log eta] (e.g. from e_step).

    Additive constants independent of lam and nu are omitted:
    -N/2 * log(2*pi) from the Gaussian term, and
    a_lam*log(b_lam) - lgamma(a_lam) from the Gamma(a_lam, rate=b_lam) prior.

    Returns:
        E_q[log p(r | eta, lam)] + E_q[log p(eta | nu)] + log p(lam)
        + logp_theta + H[q].
    """
    r2 = residuals ** 2
    n_obs = residuals.shape[0]
    half_nu = nu / 2.0
    total_elog = jnp.sum(elog_eta)
    total_w = jnp.sum(w)

    shape_q = (nu + 1.0) / 2.0
    rate_q = (nu + lam * r2) / 2.0

    # E_q[log N(r_i | 0, 1 / (lam * eta_i))]
    gauss_term = 0.5 * n_obs * jnp.log(lam) + 0.5 * total_elog - 0.5 * lam * jnp.sum(w * r2)

    # E_q[log Gamma(eta_i | nu/2, nu/2)]
    gamma_term = (
        n_obs * (half_nu * jnp.log(half_nu) - gammaln(half_nu))
        + (half_nu - 1.0) * total_elog
        - half_nu * total_w
    )

    prior_term = (a_lam - 1.0) * jnp.log(lam) - b_lam * lam

    # Entropy of Gamma(shape_q, rate_q_i), summed over observations.
    entropy = jnp.sum(
        shape_q - jnp.log(rate_q) + gammaln(shape_q) + (1.0 - shape_q) * digamma(shape_q)
    )

    return gauss_term + gamma_term + prior_term + logp_theta + entropy


def em_train(
    theta_init,
    lam0,
    nu0,
    residual_fn,
    theta_update_fn,
    a_lam,
    b_lam,
    cfg,
    max_iters=50,
    tol=1e-4,
):
    """Alternate theta, lambda and nu updates until the NLL stabilizes.

    Each iteration:
      1. updates theta with the E-step weights ``w`` for the current state,
      2. re-runs the E-step at the new theta and applies the lambda and nu
         M-steps,
      3. refreshes the E-step for the new (lambda, nu) and records NLL / ELBO.

    Stops when the relative NLL change is below ``tol`` or after ``max_iters``
    iterations. Assumes ``residual_fn`` is a pure function of theta, so its
    result is reused rather than recomputed for an unchanged theta.

    Returns:
        (theta, lam, nu, w, nll_hist, elbo_hist). ``w`` is the E-step weight
        vector for the returned (theta, lam, nu). The histories hold one entry
        per completed iteration and are empty when ``max_iters <= 0``.
    """
    theta = theta_init
    lam = lam0
    nu = nu0

    # Initial E-step. Each later iteration reuses the E-step computed at the
    # end of the previous one, which used the same (theta, lam, nu).
    residuals = residual_fn(theta)
    w, elog_eta = e_step(residuals, lam, nu)

    prev_nll = None
    nll_hist = []
    elbo_hist = []

    for _ in range(max_iters):
        theta = theta_update_fn(theta, lam, nu, w)

        # E-step at the new theta, then M-steps for lambda and nu.
        residuals = residual_fn(theta)
        w, elog_eta = e_step(residuals, lam, nu)
        lam = m_step_lambda(residuals, w, a_lam, b_lam)
        nu = m_step_nu_flat(nu, elog_eta, w, cfg)

        # theta is unchanged, so only the E-step needs refreshing for (lam, nu).
        w, elog_eta = e_step(residuals, lam, nu)

        curr_nll = student_t_nll(residuals, lam, nu)
        curr_elbo = elbo(residuals, lam, nu, w, elog_eta, a_lam, b_lam, logp_theta=0.0)

        nll_hist.append(curr_nll)
        elbo_hist.append(curr_elbo)

        # No previous value on the first iteration; skip the check explicitly
        # rather than relying on an inf/inf -> nan comparison.
        if prev_nll is not None:
            rel = jnp.abs(curr_nll - prev_nll) / (jnp.abs(prev_nll) + 1e-12)
            if rel < tol:
                break

        prev_nll = curr_nll

    return theta, lam, nu, w, jnp.array(nll_hist), jnp.array(elbo_hist)



