# Names exported by a star-import of this module: the discrete-timestep base
# class and its two concrete samplers defined below.
__all__ = ["DTDPM", "DDPM", "DDIM"]

import numpy as np
import logging
import core.func as func
from core.utils import global_device
import torch


class DTDPM(object):  # diffusion with discrete timesteps
    r"""Base class for diffusion models defined on discrete timesteps.

        E[xs|xt] = E[ E[xs|xt,x0] |xt] = E[xs|xt,x0=E[x0|xt]] in DDPM or DDIM forward process

    This class converts the raw outputs of ``wrapper`` into estimates of x0,
    eps and Cov[x0|xt]. Subclasses provide ``q_posterior_mean`` and
    ``_predict_cov_prev``, which raise ``NotImplementedError`` here.
    Scalar summaries of the latest covariance estimate are written to
    ``self.statistics``.
    """

    # Wrapper types whose call returns a ``(prediction, auxiliary)`` pair that
    # ``_predict_cov_x0`` knows how to turn into Cov[x0|xt].
    _PAIRED_COV_TYPS = ('eps_hes', 'eps_eps2', 'eps_epsc')
    # Wrapper types whose call returns a single tensor but for which
    # ``_predict_cov_x0`` still has a covariance estimate.
    _SINGLE_COV_TYPS = ('eps', 'epshesest')

    def __init__(self, wrapper, schedule, clip_x0: bool, clip_cov_x0=None, avg_cov=False):
        """Store the model wrapper and cache the schedule's coefficient tables.

        Args:
            wrapper: callable ``wrapper(xt, t)`` with a ``typ`` attribute naming
                what the model predicts; ``typ`` must be one of the asserted values.
            schedule: object providing ``N``, ``betas``, ``alphas``, ``cum_betas``,
                ``cum_alphas``, ``skip_alphas``, ``skip_betas`` and ``tilde_beta``.
            clip_x0: if true, predicted x0 is clamped to [-1, 1].
            clip_cov_x0: if true, predicted Cov[x0|xt] is clamped to [0, 1];
                falls back to ``clip_x0`` when None.
            avg_cov: if true, the covariance estimate is replaced by
                ``func.mean_flat(cov, keepdim=True)``.
        """
        assert wrapper.typ in ['eps', 'x0', 'epshesest', 'eps_hes', 'eps_eps2', 'eps_epsc', 'eps_iddpm', 'eps_hes_blockcirc', 'eps_epsc_blockcirc']
        self.wrapper = wrapper

        self.schedule = schedule
        self.N = schedule.N
        self.betas = schedule.betas
        self.alphas = schedule.alphas
        self.cum_betas = schedule.cum_betas
        self.cum_alphas = schedule.cum_alphas
        self.skip_alphas = schedule.skip_alphas
        self.skip_betas = schedule.skip_betas
        self.tilde_beta = schedule.tilde_beta

        self.clip_x0 = clip_x0
        self.clip_cov_x0 = clip_x0 if clip_cov_x0 is None else clip_cov_x0
        self.avg_cov = avg_cov

        logging.info("DTDPM with clip_x0={} clip_cov_x0={} avg_cov={}".format(self.clip_x0, self.clip_cov_x0, self.avg_cov))
        self.statistics = {}

    def predict_x0_eps(self, xt, t):  # estimate E[x0|xt], E[eps|xt] w.r.t. q
        """Run the wrapper once and return ``(x0_pred, eps_pred)``.

        Raises:
            NotImplementedError: if ``wrapper.typ`` is not a supported type.
        """
        typ = self.wrapper.typ
        # 'epshesest' returns a single eps tensor, exactly like 'eps'; it is
        # accepted here so this method agrees with ``_predict_x0_eps``.
        if typ in ('eps', 'x0', 'epshesest'):
            pred = self.wrapper(xt, t)
        elif typ.startswith('eps_'):
            pred, _ = self.wrapper(xt, t)  # the auxiliary output is not needed for the mean
        else:
            raise NotImplementedError
        return self._predict_x0_eps(pred, xt, t)

    def _predict_x0_eps(self, pred, xt, t):  # pred: the direct output of the first model
        """Convert the model's primary output into ``(x0_pred, eps_pred)``.

        x0 and eps are related through ``cum_alphas[t]`` and ``cum_betas[t]``;
        whichever one the model predicts, the other is solved for. x0 is
        clamped to [-1, 1] when ``self.clip_x0`` is set (eps is not recomputed
        after clamping).
        """
        typ = self.wrapper.typ
        if typ == 'eps' or typ == 'epshesest' or typ.startswith('eps_'):
            eps_pred = pred
            x0_pred = self.cum_alphas[t] ** -0.5 * xt - (self.cum_betas[t] / self.cum_alphas[t]) ** 0.5 * eps_pred
        elif typ == 'x0':
            x0_pred = pred
            eps_pred = - (self.cum_alphas[t] / self.cum_betas[t]) ** 0.5 * x0_pred + (1. / self.cum_betas[t] ** 0.5) * xt
        else:
            raise NotImplementedError
        if self.clip_x0:
            x0_pred = x0_pred.clamp(-1., 1.)
        return x0_pred, eps_pred

    def _cov_model_outputs(self, xt, t):
        """Call the wrapper and return ``(pred1, pred2)`` for covariance estimation.

        ``pred2`` is None for wrapper types that return a single tensor.

        Raises:
            NotImplementedError: if the wrapper type has no Cov[x0|xt] estimate.
        """
        typ = self.wrapper.typ
        if typ in self._PAIRED_COV_TYPS:
            pred1, pred2 = self.wrapper(xt, t)
        elif typ in self._SINGLE_COV_TYPS:
            pred1 = self.wrapper(xt, t)
            pred2 = None
        else:
            raise NotImplementedError
        return pred1, pred2

    def predict_cov_x0(self, xt, t, ms_eps=None):  # estimate Cov[x0|xt], for typ = "optimal"
        """Estimate Cov[x0|xt].

        When ``ms_eps`` is given, the estimate depends only on ``t`` and
        ``ms_eps[t]`` and ``xt`` is not used; otherwise the wrapper is
        evaluated at ``(xt, t)``.
        """
        if ms_eps is not None:
            return self._e_cov_x0(t, ms_eps)

        pred1, pred2 = self._cov_model_outputs(xt, t)
        return self._predict_cov_x0(pred1, pred2, xt, t)

    def _predict_cov_x0(self, pred1, pred2, xt, t):  # for typ = "optimal"
        """Turn wrapper outputs into an estimate of Cov[x0|xt].

        Records the mean of the estimate in ``self.statistics['cov_x0']`` and,
        when clipping is on, the mean of the clamped estimate in
        ``self.statistics['cov_x0_clip']``.

        Note: the 'epshesest' branch leaves torch's global grad mode disabled
        on exit, as the original implementation did.
        """
        typ = self.wrapper.typ
        if typ == 'eps_hes':
            diaghes_pred = pred2 / self.cum_betas[t]
            cov_x0_pred = (self.cum_betas[t] / self.cum_alphas[t]) * (1 + self.cum_betas[t] * diaghes_pred)
        elif typ == 'epshesest':
            mc_num = 100  # number of random probe vectors averaged
            torch.set_grad_enabled(True)
            try:
                b = xt.size(0)
                # Differentiate w.r.t. a detached copy so the caller's tensor
                # does not have its requires_grad flag flipped as a side effect.
                xt_req = xt.detach().requires_grad_(True)
                diag_hessian = torch.zeros_like(xt)
                for _ in range(mc_num):
                    eps = self.wrapper(xt_req, t)
                    score = (-1. / self.cum_betas[t] ** 0.5) * eps
                    # random +/-1 probe; v * d(score . v)/d(xt) averages to the
                    # diagonal of d(score)/d(xt)
                    v = torch.bernoulli(torch.ones_like(diag_hessian) * 0.5) * 2 - 1
                    Sv = torch.bmm(score.view(b, 1, -1), v.view(b, -1, 1)).sum()
                    Hv = torch.autograd.grad(Sv, xt_req, retain_graph=False)[0]
                    diag_hessian += (v * Hv).detach()
                diaghes_pred = diag_hessian / mc_num
                cov_x0_pred = (self.cum_betas[t] / self.cum_alphas[t]) * (1 + self.cum_betas[t] * diaghes_pred)
            finally:
                # important, otherwise, the gradient will be accumulated;
                # done in ``finally`` so an exception cannot leave grad mode on
                torch.set_grad_enabled(False)
        elif typ == 'eps_eps2':
            eps_pred, eps2_pred = pred1, pred2
            cov_x0_pred = (self.cum_betas[t] / self.cum_alphas[t]) * (eps2_pred - eps_pred.pow(2))
        elif typ == 'eps_epsc':
            cov_x0_pred = (self.cum_betas[t] / self.cum_alphas[t]) * pred2
        elif typ == 'eps' and pred2 is None:
            # a heuristic estimate: cov[v|w] ≈ beta/alpha (1 - beta s(w)^2) = beta/alpha (1 - eps(w)^2)
            cov_x0_pred = (self.cum_betas[t] / self.cum_alphas[t]) * (1 - pred1.pow(2))
        else:
            raise NotImplementedError
        if self.avg_cov:
            cov_x0_pred = func.mean_flat(cov_x0_pred, keepdim=True)
        self.statistics['cov_x0'] = cov_x0_pred.mean().item()
        if self.clip_cov_x0:
            cov_x0_pred = cov_x0_pred.clamp(0., 1.)
            self.statistics['cov_x0_clip'] = cov_x0_pred.mean().item()
        return cov_x0_pred

    def _e_cov_x0(self, t, ms_eps):  # estimate E Cov[x0|xt]
        """Return ``cum_betas[t] / cum_alphas[t] * (1 - ms_eps[t])`` as a tensor.

        The value is wrapped in a tensor on ``global_device()`` when it is not
        one already, logged into ``self.statistics`` and optionally clamped to
        [0, 1].
        """
        cov_x0_pred = self.cum_betas[t] / self.cum_alphas[t] * (1. - ms_eps[t])
        if not isinstance(cov_x0_pred, torch.Tensor):
            cov_x0_pred = torch.tensor(cov_x0_pred, device=global_device())
        self.statistics['cov_x0'] = cov_x0_pred.item()
        if self.clip_cov_x0:
            cov_x0_pred = cov_x0_pred.clamp(0., 1.)
            self.statistics['cov_x0_clip'] = cov_x0_pred.item()
        return cov_x0_pred

    def predict_x0_eps_cov_x0(self, xt, t, ms_eps=None):  # for typ = "optimal"
        """Return ``(x0_pred, eps_pred, cov_x0_pred)``.

        Without ``ms_eps`` the wrapper is called once and its outputs are
        reused for both the mean and the covariance estimate.
        """
        if ms_eps is not None:
            x0_pred, eps_pred = self.predict_x0_eps(xt, t)
            cov_x0_pred = self.predict_cov_x0(None, t, ms_eps)
            return x0_pred, eps_pred, cov_x0_pred

        pred1, pred2 = self._cov_model_outputs(xt, t)
        x0_pred, eps_pred = self._predict_x0_eps(pred1, xt, t)
        cov_x0_pred = self._predict_cov_x0(pred1, pred2, xt, t)
        return x0_pred, eps_pred, cov_x0_pred

    def q_posterior_mean(self, x0, s, t, xt=None, eps=None):  # E[xs|xt,x0] w.r.t. q
        """Posterior mean of xs given xt and x0; provided by subclasses."""
        raise NotImplementedError

    def predict_xprev(self, xt, s, t):  # estimate E[xs|xt] w.r.t. q
        """Plug the predicted x0 / eps into ``q_posterior_mean``."""
        x0_pred, eps_pred = self.predict_x0_eps(xt, t)
        return self.q_posterior_mean(x0_pred, s, t, xt=xt, eps=eps_pred)

    def predict_cov_prev(self, xt, s, t, typ, ms_eps=None):  # estimate Cov[xs|xt] w.r.t. q
        """Estimate Cov[xs|xt]; Cov[x0|xt] is only computed for ``typ == 'optimal'``."""
        cov_x0_pred = self.predict_cov_x0(xt, t, ms_eps) if typ == 'optimal' else None
        return self._predict_cov_prev(s, t, typ, cov_x0_pred)

    def _predict_cov_prev(self, s, t, typ, cov_x0_pred=None):
        """Map Cov[x0|xt] to Cov[xs|xt]; provided by subclasses."""
        raise NotImplementedError

    def predict_xprev_cov_xprev(self, xt, s, t, typ, ms_eps=None):
        """Return ``(xprev_pred, sigma2)``: the estimated mean and covariance of xs given xt."""
        if typ == 'optimal':
            x0_pred, eps_pred, cov_x0_pred = self.predict_x0_eps_cov_x0(xt, t, ms_eps)
        else:
            x0_pred, eps_pred = self.predict_x0_eps(xt, t)
            cov_x0_pred = None
        xprev_pred = self.q_posterior_mean(x0_pred, s, t, xt=xt, eps=eps_pred)
        sigma2 = self._predict_cov_prev(s, t, typ, cov_x0_pred)
        return xprev_pred, sigma2

    def sample_xprev_cov_xprev(self, xt, s, t, n1, n2, N):
        """Return ``(xprev_pred, noise)`` for a 'blockcirc' wrapper.

        ``n1`` and ``n2`` are caller-supplied noise tensors. ``n1`` is mixed by
        ``J``, scaled by ``sqrt(D)`` and passed through
        ``wrapper.model_.IDCT``; ``n2`` is scaled element-wise using ``L``
        (clamped at zero). ``N`` is only used to pick the covariance
        coefficient via the ``t > int(N * 0.9)`` test.

        Raises:
            NotImplementedError: if the wrapper is not a 'blockcirc' type. The
                ``L`` / ``J`` / ``D`` factors only exist for those wrappers, so
                the original code failed with an unbound-name error instead.
        """
        if 'blockcirc' not in self.wrapper.typ:
            raise NotImplementedError

        pred1, aux = self.wrapper(xt, t)
        x0_pred, eps_pred = self._predict_x0_eps(pred1, xt, t)
        xprev_pred = self.q_posterior_mean(x0_pred, s, t, xt=xt, eps=eps_pred)

        sigma2_small = self.tilde_beta(s, t)
        if t > int(N * 0.9):
            cov_coeff = self.cum_alphas[s] * self.skip_betas[s, t] ** 2 / self.cum_betas[t] ** 2
        else:
            cov_coeff = self.skip_betas[s, t] ** 2 / self.skip_alphas[s, t] / self.cum_betas[t]

        L, J, _, D = aux
        L = L.clamp(min=0.)

        n1 = torch.einsum('bij,bjmn->bimn', J, n1)
        n1 = self.wrapper.model_.IDCT(
            torch.sqrt(D)[:, None, :, :] * n1
        )

        n1 = n1 * cov_coeff ** 0.5
        n2 = n2 * (sigma2_small + cov_coeff * L) ** 0.5

        return xprev_pred, n1 + n2

    def blockcirc_kl(self, xt, s, t, mu_0, var_0, N):
        """KL divergence from N(mu_0, diag(var_0)) to the model's Gaussian for xs.

        The model covariance is assembled as a dense matrix per batch element
        from the wrapper's ``L``, ``J`` and ``D`` outputs and factorised with
        Cholesky. Vectors are flattened with the last two axes transposed, as
        in the original code. Returns one value per batch element:
        ``0.5 * (logdet + trace + distance - numel(mu_0[0]))``.

        NOTE(review): ``self.DCT``, ``self.IDCT`` and ``self.DCT2`` are not set
        anywhere in this class (the sampling path uses
        ``self.wrapper.model_.IDCT``) -- confirm where they are attached.

        Raises:
            NotImplementedError: if the wrapper is not a 'blockcirc' type.
        """
        if 'blockcirc' not in self.wrapper.typ:
            raise NotImplementedError

        pred1, aux = self.wrapper(xt, t)
        L, J, _, D = aux

        var_0 = var_0 if isinstance(var_0, torch.Tensor) else torch.tensor(var_0).to(mu_0)

        x0_pred, eps_pred = self._predict_x0_eps(pred1, xt, t)
        xprev_pred = self.q_posterior_mean(x0_pred, s, t, xt=xt, eps=eps_pred)

        J_diag = torch.diagonal(J, dim1=-2, dim2=-1)
        diag = L + torch.einsum('bi,bmn->bimn', J_diag, self.DCT2(D))

        sigma2_small = self.tilde_beta(s, t)
        sigma2_threshold = self.cum_alphas[t] / self.cum_betas[t]
        # shrink entries whose diagonal exceeds the threshold back down to it
        scale = torch.where(
            diag > sigma2_threshold,
            sigma2_threshold / diag,
            1.
        )
        scale = scale ** 0.5
        scale = scale.transpose(-1, -2).reshape(diag.size(0), -1)

        # NOTE(review): the original wrapped the DCT product in a second
        # ``torch.kron`` call with a single argument, which always raises a
        # TypeError. kron(J, <DCT product>) gives the 3*h*h square matrix the
        # shape comment below describes -- confirm this is the intended form.
        var_p = torch.vmap(lambda L, J, D: torch.kron(
            J, self.DCT.kron @ (D.t().reshape(-1)[:, None] * self.IDCT.kron)
        ) + torch.diag(L.transpose(-1, -2).reshape(-1)))(L, J, D)  # B, 3*h*h, 3*h*h

        identity = torch.eye(var_p.size(1), device=var_p.device)
        if t < int(N * 0.95):
            cov_coeff = self.skip_betas[s, t] ** 2 / self.skip_alphas[s, t] / self.cum_betas[t]
            var_p = sigma2_small * identity + cov_coeff * (scale[:, :, None] * var_p * scale[:, None, :])
        else:
            cov_coeff = self.cum_alphas[s] * self.skip_betas[s, t] ** 2 / self.cum_betas[t] ** 2
            var_p = sigma2_small * identity + cov_coeff * var_p
        var_p = torch.linalg.cholesky(var_p)  # from here on var_p is the lower Cholesky factor

        # tr(S_2^-1 S_1): column sums of squares of the inverse factor give diag(S_2^-1)
        trace_s2 = torch.linalg.solve_triangular(var_p, identity, upper=False)
        trace_s2 = (trace_s2 ** 2).sum(1)
        trace_s2 = trace_s2.reshape(*mu_0.shape).transpose(-1, -2)
        trace = (var_0 * trace_s2).flatten(1).sum(1)

        # log(det S_2) - log(det S_1)
        eval_re = torch.diagonal(var_p, dim1=-2, dim2=-1)
        eval_re = eval_re.reshape(*mu_0.shape).transpose(-1, -2)
        logdet = (2 * eval_re.log() - var_0.log()).flatten(1).sum(1)

        # (m1-m2)^T S_2^-1 (m1-m2), via two triangular solves with the factor
        v = (mu_0 - xprev_pred).transpose(-1, -2).flatten(1)
        cond_v = torch.linalg.solve_triangular(
            var_p.transpose(-1, -2),
            torch.linalg.solve_triangular(
                var_p, v[:, :, None],
                upper=False
            ), upper=True
        )[:, :, 0]
        distance = (v * cond_v).sum(1)

        return 0.5 * (
            logdet + trace + distance - torch.numel(mu_0[0])
        )


class DDPM(DTDPM):
    """DDPM variant: posterior mean mixes x0 and xt; variance is 'small', 'big' or 'optimal'."""

    def q_posterior_mean(self, x0, s, t, xt=None, eps=None):  # E[xs|xt,x0] w.r.t. q
        """Return the weighted sum of ``x0`` and ``xt`` giving E[xs|xt,x0].

        ``xt`` is required; ``eps`` is accepted for interface compatibility and ignored.
        """
        assert xt is not None
        weight_x0 = self.skip_betas[s, t] * self.cum_alphas[s] ** 0.5 / self.cum_betas[t]
        weight_xt = self.skip_alphas[s, t] ** 0.5 * self.cum_betas[s] / self.cum_betas[t]
        return weight_x0 * x0 + weight_xt * xt

    def _predict_cov_prev(self, s, t, typ, cov_x0_pred=None):
        """Return the variance of xs given xt as a tensor.

        ``typ`` selects ``tilde_beta(s, t)`` ('small'), ``skip_betas[s, t]``
        ('big'), or the small variance plus a multiple of ``cov_x0_pred``
        ('optimal'). Both reference variances are always recorded in
        ``self.statistics``.

        Raises:
            NotImplementedError: for any other ``typ``.
        """
        lower_var = self.tilde_beta(s, t)
        self.statistics['sigma2_small'] = lower_var
        self.statistics['sigma2_big'] = self.skip_betas[s, t]

        if typ == 'optimal':
            coeff_cov_x0 = self.cum_alphas[s] * self.skip_betas[s, t] ** 2 / self.cum_betas[t] ** 2
            offset = coeff_cov_x0 * cov_x0_pred
            sigma2 = lower_var + offset
            self.statistics['coeff_cov_x0'] = coeff_cov_x0.item()
            self.statistics['offset'] = offset.mean().item()
        elif typ == 'big':
            sigma2 = self.skip_betas[s, t]
        elif typ == 'small':
            sigma2 = lower_var
        else:
            raise NotImplementedError

        if isinstance(sigma2, torch.Tensor):
            return sigma2
        # schedule entries may be plain scalars; promote them to a tensor
        return torch.tensor(sigma2, device=global_device())

    def predict_xprev_cov_xprev(self, xt, s, t, typ, ms_eps=None):
        """Return ``(xprev_pred, sigma2)``.

        For an 'eps_iddpm' wrapper with ``typ == 'optimal'`` and no ``ms_eps``,
        the wrapper's second output interpolates, in log space, between
        ``tilde_beta`` and ``skip_betas[s, t]``. Every other case is delegated
        to the base class.
        """
        learned_variance = (
            self.wrapper.typ == "eps_iddpm" and typ == "optimal" and ms_eps is None
        )
        if not learned_variance:
            return super().predict_xprev_cov_xprev(xt, s, t, typ, ms_eps)

        raw_eps, model_var_values = self.wrapper(xt, t)
        x0_pred, eps_pred = self._predict_x0_eps(raw_eps, xt, t)
        xprev_pred = self.q_posterior_mean(x0_pred, s, t, xt=xt, eps=eps_pred)

        # for s == 0 the lower bound is taken from tilde_beta(1, 2) instead
        if s > 0:
            lower_var = self.tilde_beta(s, t)
        else:
            lower_var = self.tilde_beta(1, 2)
        min_log = np.log(lower_var)
        max_log = np.log(self.skip_betas[s, t])

        # map the model output from [-1, 1] to an interpolation weight in [0, 1]
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        return xprev_pred, model_log_variance.exp()


class DDIM(DTDPM):
    """DDIM variant: the injected noise level is ``eta ** 2 * tilde_beta(s, t)``."""

    def __init__(self, wrapper, schedule, clip_x0: bool, eta: float, clip_cov_x0=None, avg_cov=False):
        """Initialise the base class, then store and log ``eta``."""
        super().__init__(wrapper, schedule, clip_x0, clip_cov_x0=clip_cov_x0, avg_cov=avg_cov)
        self.eta = eta
        logging.info("DDIM with eta={}".format(eta))

    def q_posterior_mean(self, x0, s, t, xt=None, eps=None):  # E[xs|xt,x0] w.r.t. q
        """Return E[xs|xt,x0] as a combination of ``x0`` and ``eps``.

        ``eps`` is required; ``xt`` is accepted for interface compatibility and
        ignored. For reference, eps relates to xt and x0 by
        eps = (xt - cum_alphas[t] ** 0.5 * x0) / cum_betas[t] ** 0.5.
        """
        assert eps is not None
        lamb2 = self.eta ** 2 * self.tilde_beta(s, t)

        weight_x0 = self.cum_alphas[s] ** 0.5
        weight_eps = (self.cum_betas[s] - lamb2) ** 0.5
        return weight_x0 * x0 + weight_eps * eps

    def _predict_cov_prev(self, s, t, typ, cov_x0_pred=None):
        """Return the variance of xs given xt as a tensor.

        'small' gives ``eta ** 2 * tilde_beta(s, t)``; 'optimal' adds a
        multiple of ``cov_x0_pred`` to it.

        Raises:
            NotImplementedError: for any other ``typ``.
        """
        lamb2 = self.eta ** 2 * self.tilde_beta(s, t)

        if typ == 'optimal':
            root_alpha_s = self.cum_alphas[s] ** 0.5
            root_shrunk = ((self.cum_betas[s] - lamb2) * self.cum_alphas[t] / self.cum_betas[t]) ** 0.5
            coeff_cov_x0 = (root_alpha_s - root_shrunk) ** 2
            sigma2 = lamb2 + coeff_cov_x0 * cov_x0_pred
        elif typ == 'small':
            sigma2 = lamb2
        else:
            raise NotImplementedError

        if isinstance(sigma2, torch.Tensor):
            return sigma2
        # schedule entries may be plain scalars; promote them to a tensor
        return torch.tensor(sigma2, device=global_device())
