import copy
from typing import Dict, Iterable, Tuple

import jax.numpy as jnp
from jax.scipy.special import digamma, polygamma
import jax


def mse(r):
    """Return the mean of the element-wise squared residuals ``r``."""
    squared = r * r
    return jnp.mean(squared)

@jax.jit
def t_objective(r, nu: float, lam: float, eps=1e-16):
    """Mean Student-t loss of the residuals ``r``.

    Evaluates ``mean(0.5 * (nu + 1) * log1p(lam * r**2 / nu))``, i.e. the
    part of the Student-t negative log-likelihood that depends on ``r``.

    Args:
        r: Residual array; its dtype is used for ``nu`` and ``lam``.
        nu: Degrees of freedom.
        lam: Precision (inverse squared scale).
        eps: Accepted for signature compatibility; not used in the computation.

    Returns:
        Scalar array holding the mean loss.
    """
    nu = jnp.asarray(nu, dtype=r.dtype)
    lam = jnp.asarray(lam, dtype=r.dtype)
    scaled_sq = (lam * jnp.square(r)) / nu
    # log1p stays accurate when scaled_sq is below machine epsilon, where
    # log(1 + x) would round to exactly 0 and silence small residuals.
    loss = 0.5 * (nu + 1.0) * jnp.log1p(scaled_sq)
    return jnp.mean(loss)


def estep_w(r, nu: float, lam: float, eps=1e-16):
    """Return the element-wise weights ``(nu + 1) / (nu + lam * r**2)``.

    ``eps`` is accepted for signature compatibility and is not used.
    """
    numerator = nu + 1.0
    denominator = nu + lam * jnp.square(r)
    return numerator / denominator


def estep_log_eta(r, nu: float, lam: float, eps=1e-16):
    """Return ``digamma((nu + 1) / 2) - log((nu + lam * r**2) / 2)`` element-wise.

    ``eps`` is accepted for signature compatibility and is not used.
    """
    digamma_term = digamma(0.5 * (nu + 1.0))
    log_term = jnp.log(0.5 * (nu + lam * jnp.square(r)))
    return digamma_term - log_term


class ObjectiveHandler:
    """Base interface for per-term loss handlers.

    Subclasses implement ``term_loss``; ``maybe_update_hyperparams`` is an
    optional hook whose default leaves the state untouched.
    """

    def term_loss(self, term: str, resid, state):
        """Return the loss for ``term`` given its residuals; must be overridden."""
        raise NotImplementedError()

    def maybe_update_hyperparams(self, *, state, residuals, step: int):
        """Hook for hyperparameter updates; the default returns ``state`` as-is."""
        unchanged = state
        return unchanged


class FixedSNLL(ObjectiveHandler):
    """Student-t loss with fixed, per-term ``nu`` and ``lam`` values."""

    def __init__(self, nu_map: Dict[str, float], lam_map: Dict[str, float]):
        """Store the per-term ``nu`` and ``lam`` lookup tables."""
        self.nu, self.lam = nu_map, lam_map

    def term_loss(self, term, resid, state):
        """Return ``t_objective`` of ``resid`` using this term's fixed ``nu``/``lam``."""
        return t_objective(resid, nu=self.nu[term], lam=self.lam[term])


class WeightedMSE(ObjectiveHandler):
    """Weighted squared-error loss using ``estep_w`` weights held constant for autodiff."""

    def __init__(self, nu_map: Dict[str, float], lam_map: Dict[str, float]):
        """Store the per-term ``nu`` and ``lam`` lookup tables."""
        self.nu, self.lam = nu_map, lam_map

    def term_loss(self, term, resid, state):
        """Return ``0.5 * lam * mean(w * resid**2)`` with gradients blocked through ``w``."""
        lam = self.lam[term]
        # The weights are computed from the residuals but excluded from differentiation.
        weights = jax.lax.stop_gradient(
            estep_w(resid, nu=self.nu[term], lam=lam)
        )
        half_lam = 0.5 * lam
        return half_lam * jnp.mean(weights * jnp.square(resid))


class EMStudentT(ObjectiveHandler):
    """Student-t handler whose per-term ``nu`` and ``lam`` are re-estimated periodically.

    The loss for a term is ``0.5 * lam * mean(w * resid**2)`` with ``w`` from
    ``estep_w``. Every ``update_freq`` steps ``maybe_update_hyperparams``
    refreshes ``lam`` in closed form and ``nu`` with Newton steps taken in
    ``log(nu)``, both read from and written to ``state.st_params``.
    """

    def __init__(
        self,
        terms: Iterable[str],
        priors: Dict,
        update_freq: int,
        nu_clip: Tuple[float, float] = (1.1, 200.0),
        newton_steps: int = 1,
        eps: float = 1e-8,
    ):
        """Configure the handler.

        Args:
            terms: Names of the loss terms whose hyperparameters are updated.
            priors: Mapping with optional keys ``a_lam``, ``b_lam``, ``a_nu``,
                ``b_nu`` (each defaults to 0.0).
            update_freq: Update period in steps; 0 disables updates.
            nu_clip: ``(min, max)`` pair, or a mapping with ``"min"`` and
                ``"max"`` keys, bounding ``nu`` after every Newton step.
            newton_steps: Number of Newton iterations per ``nu`` update.
            eps: Small constant guarding divisions and logarithms.
        """
        self.terms = list(terms)
        self.update_freq = int(update_freq)
        self.a_lam = float(priors.get("a_lam", 0.0))
        self.b_lam = float(priors.get("b_lam", 0.0))
        self.a_nu = float(priors.get("a_nu", 0.0))
        self.b_nu = float(priors.get("b_nu", 0.0))
        self.nu_min, self.nu_max = self._parse_nu_clip(nu_clip)
        self.newton_steps = int(newton_steps)
        self.eps = float(eps)

    @staticmethod
    def _parse_nu_clip(nu_clip):
        """Return ``(nu_min, nu_max)`` as floats from a pair or a min/max mapping."""
        if hasattr(nu_clip, "keys"):
            return float(nu_clip["min"]), float(nu_clip["max"])
        return float(nu_clip[0]), float(nu_clip[1])

    def term_loss(self, term, resid, state):
        """Return ``0.5 * lam * mean(w * resid**2)`` using the state's ``nu``/``lam``."""
        nu = state.st_params["nu"][term]
        lam = state.st_params["lam"][term]
        # NOTE(review): unlike WeightedMSE, gradients flow through w here;
        # confirm whether a stop_gradient is intended.
        w = estep_w(resid, nu=nu, lam=lam)
        return 0.5 * lam * jnp.mean(w * resid * resid)

    def _newton_nu(self, nu, n, s1):
        """Run ``newton_steps`` clipped Newton iterations on ``log(nu)``.

        ``n`` is the residual count and ``s1`` is ``sum(logeta - w)``, both
        held fixed across iterations.
        """
        nu_new = nu
        for _ in range(self.newton_steps):
            # Contribution of the prior on nu to the gradient and curvature.
            dlogp = (self.a_nu - 1.0) / (nu_new + self.eps) - self.b_nu
            ddlogp = -(self.a_nu - 1.0) / ((nu_new + self.eps) ** 2)

            g = (
                0.5 * n * (jnp.log(nu_new / 2.0 + self.eps) + 1.0 - digamma(nu_new / 2.0))
                + 0.5 * s1
                + dlogp
            )
            H = (
                0.5 * n * (1.0 / (nu_new + self.eps) - 0.5 * polygamma(1, nu_new / 2.0))
                + ddlogp
            )

            # Step in z = log(nu) so the iterate stays positive before clipping.
            z = jnp.log(nu_new + self.eps)
            z = z - (nu_new * g) / (nu_new * nu_new * H + self.eps)
            nu_new = jnp.clip(jnp.exp(z), self.nu_min, self.nu_max)
        return nu_new

    def maybe_update_hyperparams(self, *, state, residuals, step: int):
        """Re-estimate ``lam`` and ``nu`` for every registered term when due.

        Args:
            state: Object exposing ``st_params`` (``{"nu": {...}, "lam": {...}}``)
                and ``replace(st_params=...)``.
            residuals: Mapping from term name to its residual array.
            step: Current step; updates run when ``step % update_freq == 0``.

        Returns:
            ``state`` unchanged when no update is due, otherwise a copy built
            with ``state.replace`` carrying the refreshed parameters.
        """
        # update_freq == 0 would make the modulo raise; treat it as "never update".
        if self.update_freq == 0 or (step % self.update_freq) != 0:
            return state

        st = copy.deepcopy(state.st_params)

        for term in self.terms:
            r = residuals[term]
            nu = st["nu"][term]
            lam = st["lam"][term]

            w = estep_w(r, nu=nu, lam=lam)
            logeta = estep_log_eta(r, nu=nu, lam=lam)

            N = r.size
            sr2 = jnp.sum(w * r * r)

            lam_new = (0.5 * N + self.a_lam - 1.0) / (self.b_lam + 0.5 * sr2 + self.eps)
            # With very few residuals and weak priors the numerator can be <= 0;
            # keep lam strictly positive so later logs and weights stay finite.
            lam_new = jnp.maximum(lam_new, self.eps)

            # The weights are fixed during the nu update, so this sum is computed once.
            S1 = jnp.sum(logeta - w)
            nu_new = self._newton_nu(nu, N, S1)

            st["lam"][term] = lam_new
            st["nu"][term] = nu_new

        return state.replace(st_params=st)




class EMStudentTNewtonLambda(ObjectiveHandler):
    """Student-t handler that updates ``lam`` by Newton steps on ``log(lam)``.

    ``lam`` minimises the negative log-posterior ``_F_lam`` of the Student-t
    likelihood (with ``lam`` as precision) plus the ``a_lam``/``b_lam`` prior
    terms; ``nu`` is then refreshed with clipped Newton steps on ``log(nu)``.
    Both are read from and written to ``state.st_params``.
    """

    def __init__(
        self,
        terms: Iterable[str],
        priors: Dict,
        update_freq: int,
        nu_clip: Tuple[float, float] = (1.1, 200.0),
        nu_newton_steps: int = 1,
        lam_newton_steps: int = 1,
        lam_step_init: float = 1.0,
        lam_backtrack: bool = True,
        lam_bt_steps: int = 10,
        lam_bt_shrink: float = 0.5,
        eps: float = 1e-8,
    ):
        """Configure the handler.

        Args:
            terms: Names of the loss terms whose hyperparameters are updated.
            priors: Mapping with optional keys ``a_lam``, ``b_lam``, ``a_nu``,
                ``b_nu`` (each defaults to 0.0).
            update_freq: Update period in steps; 0 disables updates.
            nu_clip: ``(min, max)`` pair, or a mapping with ``"min"`` and
                ``"max"`` keys, bounding ``nu`` after every Newton step.
            nu_newton_steps: Newton iterations per ``nu`` update.
            lam_newton_steps: Newton iterations per ``lam`` update.
            lam_step_init: Initial step length for the ``lam`` Newton step.
            lam_backtrack: Whether to backtrack until ``_F_lam`` does not increase.
            lam_bt_steps: Maximum number of backtracking trials.
            lam_bt_shrink: Factor applied to the step length after a failed trial.
            eps: Small constant guarding divisions and logarithms.
        """
        self.terms = list(terms)
        self.update_freq = int(update_freq)
        self.a_lam = float(priors.get("a_lam", 0.0))
        self.b_lam = float(priors.get("b_lam", 0.0))
        self.a_nu = float(priors.get("a_nu", 0.0))
        self.b_nu = float(priors.get("b_nu", 0.0))
        self.nu_min, self.nu_max = self._parse_nu_clip(nu_clip)
        self.nu_newton_steps = int(nu_newton_steps)
        self.lam_newton_steps = int(lam_newton_steps)
        self.lam_step_init = float(lam_step_init)
        self.lam_backtrack = bool(lam_backtrack)
        self.lam_bt_steps = int(lam_bt_steps)
        self.lam_bt_shrink = float(lam_bt_shrink)
        self.eps = float(eps)

    @staticmethod
    def _parse_nu_clip(nu_clip):
        """Return ``(nu_min, nu_max)`` as floats from a pair or a min/max mapping."""
        if hasattr(nu_clip, "keys"):
            return float(nu_clip["min"]), float(nu_clip["max"])
        return float(nu_clip[0]), float(nu_clip[1])

    def term_loss(self, term, resid, state):
        """Return ``0.5 * lam * mean(w * resid**2)`` using the state's ``nu``/``lam``."""
        nu = state.st_params["nu"][term]
        lam = state.st_params["lam"][term]
        # NOTE(review): unlike WeightedMSE, gradients flow through w here;
        # confirm whether a stop_gradient is intended.
        w = estep_w(resid, nu=nu, lam=lam, eps=self.eps)
        return 0.5 * lam * jnp.mean(w * resid * resid)

    def _F_lam(self, r, nu, lam):
        """Negative log-posterior of ``lam`` (up to terms constant in ``lam``).

        Includes the ``-(N/2) * log(lam)`` normaliser of the Student-t density;
        without it the objective would be minimised by shrinking ``lam`` to 0.
        """
        r2 = r * r
        t = 1.0 + (lam * r2) / (nu + self.eps)
        L = (
            0.5 * (nu + 1.0) * jnp.sum(jnp.log(t + self.eps))
            - 0.5 * r.size * jnp.log(lam + self.eps)
        )
        prior = -(self.a_lam - 1.0) * jnp.log(lam + self.eps) + self.b_lam * lam
        return L + prior

    def _dF_dlam(self, r, nu, lam):
        """First derivative of ``_F_lam`` with respect to ``lam``."""
        r2 = r * r
        denom = (nu + lam * r2) + self.eps
        dL = (
            0.5 * (nu + 1.0) * jnp.sum(r2 / denom)
            - 0.5 * r.size / (lam + self.eps)
        )
        dprior = -(self.a_lam - 1.0) / (lam + self.eps) + self.b_lam
        return dL + dprior

    def _d2F_dlam2(self, r, nu, lam):
        """Second derivative of ``_F_lam`` with respect to ``lam``."""
        r2 = r * r
        denom = (nu + lam * r2) + self.eps
        d2L = (
            -0.5 * (nu + 1.0) * jnp.sum((r2 * r2) / (denom * denom))
            + 0.5 * r.size / ((lam + self.eps) ** 2)
        )
        d2prior = (self.a_lam - 1.0) / ((lam + self.eps) ** 2)
        return d2L + d2prior

    def _newton_update_log_lam(self, r, nu, lam):
        """Take one (optionally backtracked) Newton step on ``u = log(lam)``.

        Returns the new ``lam``; when backtracking is enabled and no trial
        step lowers ``_F_lam``, the (floored) input ``lam`` is returned.
        """
        lam = jnp.maximum(lam, self.eps)
        Fp = self._dF_dlam(r, nu, lam)
        Fpp = self._d2F_dlam2(r, nu, lam)
        # Chain rule for u = log(lam): dF/du = lam F', d2F/du2 = lam^2 F'' + lam F'.
        gu = lam * Fp
        hu = lam * lam * Fpp + lam * Fp
        u = jnp.log(lam + self.eps)
        eta = self.lam_step_init

        if not self.lam_backtrack:
            return jnp.exp(u - eta * gu / (hu + self.eps))

        F0 = self._F_lam(r, nu, lam)
        lam_new = lam

        for _ in range(self.lam_bt_steps):
            lam_try = jnp.exp(u - eta * gu / (hu + self.eps))
            F_try = self._F_lam(r, nu, lam_try)
            # bool() needs a concrete value, so this loop is meant to run
            # eagerly rather than under jit tracing.
            if bool(F_try <= F0):
                lam_new = lam_try
                break
            eta *= self.lam_bt_shrink

        return lam_new

    def _newton_nu(self, nu, n, s1):
        """Run ``nu_newton_steps`` clipped Newton iterations on ``log(nu)``.

        ``n`` is the residual count and ``s1`` is ``sum(logeta - w)``, both
        held fixed across iterations.
        """
        nu_new = nu
        for _ in range(self.nu_newton_steps):
            # Contribution of the prior on nu to the gradient and curvature.
            dlogp = (self.a_nu - 1.0) / (nu_new + self.eps) - self.b_nu
            ddlogp = -(self.a_nu - 1.0) / ((nu_new + self.eps) ** 2)

            g = (
                0.5 * n * (jnp.log(nu_new / 2.0 + self.eps) + 1.0 - digamma(nu_new / 2.0))
                + 0.5 * s1
                + dlogp
            )
            H = (
                0.5 * n * (1.0 / (nu_new + self.eps) - 0.5 * polygamma(1, nu_new / 2.0))
                + ddlogp
            )

            # Step in z = log(nu) so the iterate stays positive before clipping.
            z = jnp.log(nu_new + self.eps)
            z = z - (nu_new * g) / (nu_new * nu_new * H + self.eps)
            nu_new = jnp.clip(jnp.exp(z), self.nu_min, self.nu_max)
        return nu_new

    def maybe_update_hyperparams(self, *, state, residuals, step: int):
        """Re-estimate ``lam`` (Newton) and then ``nu`` for every registered term.

        Args:
            state: Object exposing ``st_params`` (``{"nu": {...}, "lam": {...}}``)
                and ``replace(st_params=...)``.
            residuals: Mapping from term name to its residual array.
            step: Current step; updates run when ``step % update_freq == 0``.

        Returns:
            ``state`` unchanged when no update is due, otherwise a copy built
            with ``state.replace`` carrying the refreshed parameters.
        """
        # update_freq == 0 would make the modulo raise; treat it as "never update".
        if self.update_freq == 0 or (step % self.update_freq) != 0:
            return state

        st = copy.deepcopy(state.st_params)

        for term in self.terms:
            r = residuals[term]
            nu = st["nu"][term]
            lam = st["lam"][term]

            lam_new = lam
            for _ in range(self.lam_newton_steps):
                lam_new = self._newton_update_log_lam(r, nu, lam_new)

            # The nu update uses weights evaluated at the refreshed lam.
            w = estep_w(r, nu=nu, lam=lam_new, eps=self.eps)
            logeta = estep_log_eta(r, nu=nu, lam=lam_new, eps=self.eps)

            # The weights are fixed during the nu update, so this sum is computed once.
            S1 = jnp.sum(logeta - w)
            nu_new = self._newton_nu(nu, r.size, S1)

            st["lam"][term] = lam_new
            st["nu"][term] = nu_new

        return state.replace(st_params=st)
    

def _cfg_get(cfg, key, default=None):
    """Read ``key`` from a dict-style or attribute-style config; ``None`` counts as missing."""
    if cfg is None:
        return default
    if isinstance(cfg, dict):
        value = cfg.get(key)
    else:
        value = getattr(cfg, key, None)
        # Some config objects expose values only through a mapping-style .get().
        if value is None and callable(getattr(cfg, "get", None)):
            value = cfg.get(key, None)
    return default if value is None else value


def build_objectives(config, loss_keys):
    """Build the per-term objective table, EM handler and initial Student-t params.

    Reads ``config.objectives`` (dict- or attribute-style), with optional
    entries ``terms`` (term name -> mode) and ``student_t`` (``init.nu``,
    ``init.lam``, ``priors``, ``update_freq``, ``newton_steps``, ``nu_clip``).

    Args:
        config: Object whose optional ``objectives`` attribute holds the settings.
        loss_keys: Iterable of loss-term names; missing terms default to mode
            ``"mse"``, ``nu = 5.0`` and ``lam = 1.0``.

    Returns:
        ``(terms, handler, init_st_params)`` where ``terms`` maps each term to
        its mode, ``handler`` is ``None`` unless a term uses mode ``"em"`` or
        ``"emnewtonlam"`` (compared case-insensitively), and
        ``init_st_params`` is ``{"nu": {...}, "lam": {...}}``.
    """
    # Materialise once: loss_keys is walked several times and may be a one-shot iterator.
    loss_keys = list(loss_keys)

    obj = _cfg_get(config, "objectives", None)
    st_cfg = _cfg_get(obj, "student_t", None)
    init = _cfg_get(st_cfg, "init", None)

    terms = dict(_cfg_get(obj, "terms", None) or {})
    nu0 = dict(_cfg_get(init, "nu", None) or {})
    lam0 = dict(_cfg_get(init, "lam", None) or {})

    for k in loss_keys:
        terms.setdefault(k, "mse")
        nu0.setdefault(k, 5.0)
        lam0.setdefault(k, 1.0)

    modes = {k: str(terms[k]).lower() for k in loss_keys}
    em_modes = ("em", "emnewtonlam")
    em_terms = [k for k in loss_keys if modes[k] in em_modes]

    handler = None
    if em_terms:
        priors = dict(_cfg_get(st_cfg, "priors", None) or {})
        update_freq = int(_cfg_get(st_cfg, "update_freq", 1000))
        newton_steps = int(_cfg_get(st_cfg, "newton_steps", 1))

        # Normalise nu_clip to a (min, max) float pair, which is what the
        # handlers index into; accept either a pair or a min/max mapping.
        nu_clip = _cfg_get(st_cfg, "nu_clip", None)
        if nu_clip is None:
            nu_clip = (2.0, 50.0)
        elif hasattr(nu_clip, "keys"):
            nu_clip = (float(nu_clip["min"]), float(nu_clip["max"]))
        else:
            nu_clip = (float(nu_clip[0]), float(nu_clip[1]))

        # Only one handler is returned, so if any term asks for the Newton
        # lambda update it is used for all EM terms.
        if any(modes[k] == "emnewtonlam" for k in em_terms):
            handler = EMStudentTNewtonLambda(
                terms=em_terms,
                priors=priors,
                update_freq=update_freq,
                nu_clip=nu_clip,
                nu_newton_steps=newton_steps,
            )
        else:
            handler = EMStudentT(
                terms=em_terms,
                priors=priors,
                update_freq=update_freq,
                nu_clip=nu_clip,
                newton_steps=newton_steps,
            )

    init_st_params = {"nu": nu0, "lam": lam0}
    return terms, handler, init_st_params
