# Public criterion classes of this module. The module-level loss functions
# (dt_dsm, dt_dsdm, dt_dsdm_err, ct_dsm, ct_dsdm, ct_dsdm_err) are not listed
# here, so they are not exported by ``from ... import *``.
__all__ = ["DTDSM", "DTDSDM", "DTDSDMErr", "CTDSDM", "CTDSDMErr", "CTDSM"]


import torch
from .base import NaiveCriterion
import core.func as func
import logging


def dt_dsm(x0, wrapper, schedule):
    """Discrete-time denoising score matching loss for an eps-prediction model.

    Draws ``(n, eps, xn)`` from ``schedule.sample(x0)``, predicts the noise
    with ``wrapper(xn, n)`` and reduces the residual with ``func.sos``
    (presumably a per-sample sum of squares -- see ``core.func``).

    Returns:
        ``func.sos(eps - wrapper(xn, n))``.
    """
    timestep, noise, noisy = schedule.sample(x0)
    residual = noise - wrapper(noisy, timestep)
    return func.sos(residual)


def dt_dsdm(x0, wrapper, schedule):
    """Discrete-time joint loss for the noise mean and a second-order statistic.

    Draws ``(n, eps, xn)`` from ``schedule.sample(x0)`` and dispatches on
    ``wrapper.typ`` to decide what the second model output is trained on:

    * ``'eps_eps2'`` -- the second head is regressed onto ``eps**2``.
    * ``'eps_epsc'`` -- the second head is regressed onto
      ``eps**2 - eps_pred**2`` (it is added back to ``eps_pred**2`` first).
    * ``'eps_hes'`` -- the second head is regressed onto a Rademacher-probe
      estimate of a diagonal built from the Jacobian of the implied score.
    * ``'eps_epsc_3k'`` / ``'eps_epsc_blockcirc'`` -- the second head returns
      the factors of a structured covariance ``Sigma`` (see the helpers).
    * ``'eps_hes_blockcirc'`` -- a structured operator is fitted against a
      vector-Jacobian probe (see the helper).

    Args:
        x0: clean data batch, forwarded to ``schedule.sample``.
        wrapper: model wrapper called as ``wrapper(xn, n)``; must expose ``typ``
            (and ``model_`` with ``DCT``/``IDCT``/``DCT2`` for the blockcirc types).
        schedule: discrete schedule exposing ``sample`` and ``cum_betas``.

    Returns:
        ``(dsm_loss, second_loss)`` where ``dsm_loss = func.sos(eps - eps_pred)``.
        For the ``eps_epsc_3k`` and ``*_blockcirc`` types the second loss omits a
        non-negative term and can therefore be negative.

    Raises:
        NotImplementedError: if ``wrapper.typ`` is not one of the types above.
    """
    n, eps, xn = schedule.sample(x0)
    # For ablations, ``schedule.sub_sample(x0)`` can be swapped in here to draw
    # timesteps from a subset only.
    typ = wrapper.typ
    if typ == 'eps_eps2':
        eps_pred, eps2_pred = wrapper(xn, n)
    elif typ == 'eps_epsc':
        eps_pred, epsc_pred = wrapper(xn, n)
        eps2_pred = epsc_pred + eps_pred.pow(2)
    elif typ == 'eps_hes':
        return _dt_dsdm_eps_hes(wrapper, schedule, n, eps, xn)
    elif typ == 'eps_epsc_3k':
        return _dt_dsdm_eps_epsc_3k(wrapper, schedule, n, eps, xn)
    elif typ == 'eps_epsc_blockcirc':
        return _dt_dsdm_eps_epsc_blockcirc(wrapper, schedule, n, eps, xn)
    elif typ == 'eps_hes_blockcirc':
        return _dt_dsdm_eps_hes_blockcirc(wrapper, schedule, n, eps, xn)
    else:
        raise NotImplementedError(f"dt_dsdm: unsupported wrapper.typ={typ!r}")
    return func.sos(eps - eps_pred), func.sos(eps.pow(2) - eps2_pred)


def _dt_dsdm_eps_hes(wrapper, schedule, n, eps, xn):
    """Diagonal-Hessian branch of ``dt_dsdm`` (``wrapper.typ == 'eps_hes'``).

    The implied score is ``-eps_pred / sqrt(cum_betas[n])``. Its vector-Jacobian
    product with a Rademacher probe ``v`` is rescaled by ``cum_betas[n]``, and
    ``v * (H v)`` (detached) is the regression target for ``diaghes_pred``.
    """
    # xn must require grad so the vector-Jacobian product below can be taken w.r.t. it.
    xn.requires_grad_(True)
    eps_pred, diaghes_pred = wrapper(xn, n)

    b = xn.shape[0]
    score = func.stp(-1. / schedule.cum_betas[n] ** 0.5, eps_pred)
    v = torch.bernoulli(torch.ones_like(score) * 0.5) * 2 - 1  # Rademacher (+/-1) probe
    Sv = torch.bmm(score.view(b, 1, -1), v.view(b, -1, 1)).sum()
    # retain_graph=True is required: both returned losses still depend on the
    # forward graph through eps_pred / diaghes_pred and are backpropagated later.
    Hv = torch.autograd.grad(Sv, xn, retain_graph=True, create_graph=False)[0]
    Hv = func.stp(schedule.cum_betas[n], Hv)  # reweight Hv for numerical stability
    est_diag_H = (v * Hv).detach()

    return func.sos(eps - eps_pred), func.sos(est_diag_H - diaghes_pred)


def _dt_dsdm_eps_epsc_3k(wrapper, schedule, n, eps, xn):
    """Kronecker-structured covariance branch of ``dt_dsdm`` (``'eps_epsc_3k'``).

    ``Sigma = diag(L) + J kron D1 kron D2``. Per the einsum below, ``J``, ``D1``
    and ``D2`` act on feature axes 1, 2 and 3 respectively. Each factor is
    symmetrised as ``M @ M^T``. The loss is
    ``||Sigma||_F^2 - 2 (eps^T Sigma eps - eps_pred^T Sigma eps_pred)``. This is
    ``||Sigma - (eps eps^T - eps_pred eps_pred^T)||_F^2`` without its
    Sigma-independent term. It is weighted by ``cum_betas[n]``.
    """
    eps_pred, (L, J, D1, D2) = wrapper(xn, n)

    # Make every Kronecker factor symmetric PSD.
    J = torch.bmm(J, J.transpose(1, 2))
    D1 = torch.bmm(D1, D1.transpose(1, 2))
    D2 = torch.bmm(D2, D2.transpose(1, 2))

    feat_dims = (-1, -2, -3)

    def _diag_dot(l, j, d1, d2):
        # <l, diag(j kron d1 kron d2)>: the diagonal of a Kronecker product is
        # the Kronecker product of the diagonals.
        kron_diag = torch.kron(torch.diag(j), torch.kron(torch.diag(d1), torch.diag(d2)))
        return torch.sum(l.reshape(-1) * kron_diag)

    # ||Sigma||_F^2 = ||J||^2 ||D1||^2 ||D2||^2 + ||L||^2 + 2 <L, diag(J kron D1 kron D2)>
    model_norm_terms = (
        torch.sum(J * J, dim=(-1, -2)) * torch.sum(D1 * D1, dim=(-1, -2)) * torch.sum(D2 * D2, dim=(-1, -2))
        + torch.sum(L * L, dim=feat_dims)
        + 2 * torch.vmap(_diag_dot)(L, J, D1, D2)
    )

    def _quad_form(v):
        # Per-sample v^T Sigma v.
        kron_v = torch.einsum('mij,mab,mcd,mjbd->miac', J, D1, D2, v)
        return torch.sum(v * (kron_v + L * v), dim=feat_dims)

    total_loss = model_norm_terms - 2 * (_quad_form(eps) - _quad_form(eps_pred))
    return func.sos(eps - eps_pred), func.stp(schedule.cum_betas[n], total_loss)


def _dt_dsdm_eps_epsc_blockcirc(wrapper, schedule, n, eps, xn):
    """Block-structured covariance branch of ``dt_dsdm`` (``'eps_epsc_blockcirc'``).

    ``Sigma v = DCT(D * IDCT(J v)) + L * v``, with ``J`` mixing feature axis 1
    and ``DCT``/``IDCT`` provided by ``wrapper.model_``. Same loss form as the
    ``eps_epsc_3k`` branch: ``||Sigma||_F^2 - 2 (eps^T Sigma eps - eps_pred^T Sigma eps_pred)``,
    weighted by ``cum_betas[n]``.
    """
    eps_pred, (L, _, J, D) = wrapper(xn, n)

    feat_dims = (-1, -2, -3)
    J_diag = torch.diagonal(J, dim1=-2, dim2=-1)
    J_T = torch.transpose(J, 1, 2)

    # ||D||_F^2 ||J||_F^2 + ||L||^2 + 2 * sum_c J_cc <DCT2(D), L_c>
    model_norm_terms = (
        (torch.linalg.matrix_norm(D) ** 2) * (torch.linalg.matrix_norm(J) ** 2)
        + torch.linalg.vector_norm(L, dim=feat_dims) ** 2
        + 2 * torch.sum(
            J_diag * torch.sum(
                wrapper.model_.DCT2(D)[:, None, :, :] * L,
                axis=(-1, -2),
            ),
            axis=-1,
        )
    )

    def _quad_form(v):
        # Per-sample v^T Sigma v; the einsum with J^T computes J @ v along axis 1.
        vJ_T = torch.einsum('bimn,bij->bjmn', v, J_T)
        spatial = wrapper.model_.DCT(D[:, None, :, :] * wrapper.model_.IDCT(vJ_T))
        return torch.sum(v * spatial, dim=feat_dims) + torch.sum(v * v * L, dim=feat_dims)

    total_loss = model_norm_terms - 2 * (_quad_form(eps) - _quad_form(eps_pred))
    return func.sos(eps - eps_pred), func.stp(schedule.cum_betas[n], total_loss)


def _dt_dsdm_eps_hes_blockcirc(wrapper, schedule, n, eps, xn):
    """Block-structured Jacobian-matching branch of ``dt_dsdm`` (``'eps_hes_blockcirc'``).

    With a Rademacher probe ``noise`` and ``JTn`` the vector-Jacobian product of
    ``sqrt(cum_betas[n]) * eps_pred`` with ``noise``, the loss is
    ``||Sn||^2 - 2 <Sn, noise> + 2 <Sn, JTn>``, i.e.
    ``||Sn - (noise - JTn)||^2`` without its S-independent term. It is weighted
    by ``cum_betas[n]``.
    """
    # xn must require grad so the vector-Jacobian product below can be taken w.r.t. it.
    xn.requires_grad_(True)
    eps_pred, (L, _, J, D) = wrapper(xn, n)

    feat_dims = (-1, -2, -3)
    noise = torch.bernoulli(torch.ones_like(xn) * 0.5) * 2 - 1  # Rademacher (+/-1) probe

    # Reverse-mode VJP. A forward-mode torch.func.jvp variant was tried previously
    # and noted as using more memory.
    b = xn.shape[0]
    score = func.stp(schedule.cum_betas[n] ** 0.5, eps_pred)
    Sv = torch.bmm(score.view(b, 1, -1), noise.view(b, -1, 1)).sum()
    # retain_graph=True: the returned losses are backpropagated through the same graph later.
    JTn = torch.autograd.grad(Sv, xn, retain_graph=True, create_graph=False)[0]

    # NOTE(review): the eps_epsc_blockcirc branch applies DCT(D * IDCT(J v)), while
    # this one applies IDCT(D * DCT(J n)). An earlier note also questioned J vs J^T
    # here ('bimn,bij->bjmn'). Confirm which matches the model's intended operator.
    Sn = wrapper.model_.IDCT(D[:, None, :, :] * wrapper.model_.DCT(
        torch.einsum('bij,bjmn->bimn', J, noise)
    )) + L * noise

    total_loss = (
        torch.sum(Sn * Sn, dim=feat_dims)
        - 2 * torch.sum(Sn * noise, dim=feat_dims)
        + 2 * torch.sum(Sn * JTn.detach(), dim=feat_dims)
    )
    return func.sos(eps - eps_pred), func.stp(schedule.cum_betas[n], total_loss)


def dt_dsdm_err(x0, wrapper, schedule):
    """Discrete-time eps loss plus regression of the squared prediction error.

    The second head (``epsc_pred``) is regressed onto ``(eps - eps_pred)**2``.
    That target is detached, so this term sends no gradient into ``eps_pred``.

    Returns:
        ``(func.sos(eps - eps_pred), func.sos((eps - eps_pred).detach()**2 - epsc_pred))``.
    """
    timestep, noise, noisy = schedule.sample(x0)
    pred_noise, pred_err_sq = wrapper(noisy, timestep)
    residual = noise - pred_noise
    target = residual.detach().pow(2)
    return func.sos(residual), func.sos(target - pred_err_sq)


def ct_dsdm(x0, wrapper, sde, t_init=1e-5):
    """Continuous-time joint loss for the noise mean and its second moment.

    ``'eps_eps2'`` wrappers predict ``eps**2`` directly. ``'eps_epsc'`` wrappers
    predict ``eps**2 - eps_pred**2``, which is converted back to a second moment
    before the comparison.

    Args:
        x0: clean data batch.
        wrapper: model wrapper called as ``wrapper(xt, t)``; must expose ``typ``.
        sde: continuous-time process exposing ``sample(x0, t_init)``.
        t_init: forwarded to ``sde.sample``.

    Raises:
        NotImplementedError: for any other ``wrapper.typ``.
    """
    t, noise, xt = sde.sample(x0, t_init)
    kind = wrapper.typ
    if kind not in ('eps_eps2', 'eps_epsc'):
        raise NotImplementedError
    pred_noise, second = wrapper(xt, t)
    if kind == 'eps_eps2':
        pred_second_moment = second
    else:
        pred_second_moment = second + pred_noise.pow(2)
    return func.sos(noise - pred_noise), func.sos(noise.pow(2) - pred_second_moment)


def ct_dsm(x0, wrapper, sde, t_init=1e-5):
    """Continuous-time denoising score matching loss for an eps-prediction model.

    NOTE(review): this reduces with ``func.mos``, while every other loss in this
    module uses ``func.sos``. Confirm the differing reduction is intended.
    """
    t, noise, xt = sde.sample(x0, t_init)
    return func.mos(noise - wrapper(xt, t))


def ct_dsdm_err(x0, wrapper, sde, t_init=1e-5):
    """Continuous-time eps loss plus regression of the squared prediction error.

    The squared residual is detached before it is used as the target for
    ``epsc_pred``, so only the first term trains ``eps_pred``.
    """
    t, noise, xt = sde.sample(x0, t_init)
    pred_noise, pred_err_sq = wrapper(xt, t)
    residual = noise - pred_noise
    target = residual.detach().pow(2)
    return func.sos(residual), func.sos(target - pred_err_sq)


class DTDSM(NaiveCriterion):
    """Discrete-time DSM criterion; the objective is ``dt_dsm``.

    Only plain noise-prediction wrappers (``wrapper.typ == "eps"``) are accepted.
    """

    def __init__(self, schedule, wrapper, **kwargs):
        assert wrapper.typ == "eps"
        super().__init__(wrapper, **kwargs)
        self.schedule = schedule

    def objective(self, v, **kwargs):
        # ``v`` is the clean batch handed to dt_dsm as x0; extra kwargs are unused.
        loss = dt_dsm(v, self.wrapper, self.schedule)
        return loss


class DTDSDM(NaiveCriterion):
    """Discrete-time criterion ``dsm_obj + ratio * ddm_obj`` built on ``dt_dsdm``.

    Args:
        schedule: discrete schedule passed to ``dt_dsdm``.
        ratio: fixed weight of the second-order term, or the string
            ``"adaptive"`` to recompute it every call from the batch means.
        wrapper: model wrapper whose ``typ`` must be one supported by ``dt_dsdm``.

    The batch means of both terms and the ratio used are written to
    ``self.statistics`` on every call.
    """

    def __init__(self, schedule, ratio, wrapper, **kwargs):
        assert wrapper.typ in ["eps_hes", "eps_eps2", "eps_epsc", "eps_epsc_blockcirc", "eps_hes_blockcirc", "eps_epsc_3k"]
        super().__init__(wrapper, **kwargs)
        self.schedule = schedule
        self.ratio = ratio
        logging.info(f'DTDSDM with ratio={self.ratio} wrapper.typ={wrapper.typ}')

    def objective(self, v, **kwargs):
        dsm_obj, ddm_obj = dt_dsdm(v, self.wrapper, self.schedule)
        with torch.no_grad():
            self.statistics["dsm_obj"] = dsm_obj_mean = dsm_obj.mean().item()
            self.statistics["ddm_obj"] = ddm_obj_mean = ddm_obj.mean().item()
            if self.ratio == "adaptive":
                # For eps_epsc_3k / eps_epsc_blockcirc / eps_hes_blockcirc, dt_dsdm
                # drops a non-negative term from the second loss, so its mean can be
                # negative. A signed ratio would then flip the sign of the covariance
                # term and *maximise* it; balance magnitudes instead. (A mean of
                # exactly 0 still raises ZeroDivisionError, as before.)
                ratio = dsm_obj_mean / abs(ddm_obj_mean)
            else:
                ratio = self.ratio
            self.statistics["ratio"] = ratio
        return dsm_obj + ratio * ddm_obj


class DTDSDMErr(NaiveCriterion):
    """Discrete-time criterion ``dsm_obj + ratio * ddm_obj`` built on ``dt_dsdm_err``.

    ``ratio`` is a fixed weight or ``"adaptive"`` (batch-mean ratio
    ``dsm / ddm``). Both means and the ratio used go into ``self.statistics``.
    """

    def __init__(self, schedule, ratio, wrapper, **kwargs):
        assert wrapper.typ in ("eps_epsc",)
        super().__init__(wrapper, **kwargs)
        self.schedule, self.ratio = schedule, ratio
        logging.info(f'DTDSDMErr with ratio={self.ratio} wrapper.typ={wrapper.typ}')

    def objective(self, v, **kwargs):
        dsm_obj, ddm_obj = dt_dsdm_err(v, self.wrapper, self.schedule)
        with torch.no_grad():
            dsm_mean = dsm_obj.mean().item()
            self.statistics["dsm_obj"] = dsm_mean
            ddm_mean = ddm_obj.mean().item()
            self.statistics["ddm_obj"] = ddm_mean
            if self.ratio == "adaptive":
                weight = dsm_mean / ddm_mean
            else:
                weight = self.ratio
            self.statistics["ratio"] = weight
        return dsm_obj + weight * ddm_obj


class CTDSM(NaiveCriterion):
    """Continuous-time DSM criterion; the objective is ``ct_dsm``.

    Only plain noise-prediction wrappers (``wrapper.typ == 'eps'``) are accepted.
    """

    def __init__(self, sde, wrapper, **kwargs):
        assert wrapper.typ == 'eps'
        super().__init__(wrapper, **kwargs)
        self.sde = sde

    def objective(self, v, **kwargs):
        # ``v`` is the clean batch handed to ct_dsm as x0; extra kwargs are unused.
        return ct_dsm(v, self.wrapper, self.sde)


class CTDSDM(NaiveCriterion):
    """Continuous-time criterion ``dsm_obj + ratio * ddm_obj`` built on ``ct_dsdm``.

    ``ratio`` is a fixed weight or ``"adaptive"`` (batch-mean ratio
    ``dsm / ddm``). Both means and the ratio used go into ``self.statistics``.
    """

    def __init__(self, sde, ratio, wrapper, **kwargs):
        assert wrapper.typ in ("eps_eps2", "eps_epsc")
        super().__init__(wrapper, **kwargs)
        self.sde, self.ratio = sde, ratio
        logging.info(f'CTDSDM with ratio={self.ratio} wrapper.typ={wrapper.typ} sde={type(self.sde)}')

    def objective(self, v, **kwargs):
        dsm_obj, ddm_obj = ct_dsdm(v, self.wrapper, self.sde)
        with torch.no_grad():
            dsm_mean = dsm_obj.mean().item()
            self.statistics["dsm_obj"] = dsm_mean
            ddm_mean = ddm_obj.mean().item()
            self.statistics["ddm_obj"] = ddm_mean
            if self.ratio == "adaptive":
                weight = dsm_mean / ddm_mean
            else:
                weight = self.ratio
            self.statistics["ratio"] = weight
        return dsm_obj + weight * ddm_obj


class CTDSDMErr(NaiveCriterion):
    """Continuous-time criterion ``dsm_obj + ratio * ddm_obj`` built on ``ct_dsdm_err``.

    ``ratio`` is a fixed weight or ``"adaptive"`` (batch-mean ratio
    ``dsm / ddm``). Both means and the ratio used go into ``self.statistics``.
    """

    def __init__(self, sde, ratio, wrapper, **kwargs):
        assert wrapper.typ in ("eps_epsc",)
        super().__init__(wrapper, **kwargs)
        self.sde, self.ratio = sde, ratio
        logging.info(f"CTDSDMErr with ratio={self.ratio} wrapper.typ={wrapper.typ} sde={type(self.sde)}")

    def objective(self, v, **kwargs):
        dsm_obj, ddm_obj = ct_dsdm_err(v, self.wrapper, self.sde)
        with torch.no_grad():
            dsm_mean = dsm_obj.mean().item()
            self.statistics["dsm_obj"] = dsm_mean
            ddm_mean = ddm_obj.mean().item()
            self.statistics["ddm_obj"] = ddm_mean
            if self.ratio == "adaptive":
                weight = dsm_mean / ddm_mean
            else:
                weight = self.ratio
            self.statistics["ratio"] = weight
        return dsm_obj + weight * ddm_obj
