# Conditional Classes and functions
import inox
import inox.nn as nn
from priors.diffusion import *
from utils import *

class ConditionalDDPM(nn.Module):
    r"""Conditional DDPM sampler for the reverse SDE.

    Each reverse step from time :math:`t` to time :math:`s < t` is

    .. math:: x_s = x_t - \tau (x_t - f(x_t, y)) + \sigma_s \sqrt{\tau} \epsilon

    where :math:`\tau = 1 - \frac{\sigma_s^2}{\sigma_t^2}` and
    :math:`\epsilon` is standard Gaussian noise.

    Arguments:
        model: A denoiser model :math:`f(x_t, y) \approx E[x | x_t, y]`,
            called as ``model(xt, sigma_t, y)``.
        sde: The forward SDE. A default :class:`VESDE` is created when omitted.
        kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(self, model: nn.Module, sde: VESDE = None, **kwargs):
        super().__init__()

        self.model = model
        self.sde = VESDE() if sde is None else sde

    @inox.jit
    def __call__(self, xt: Array, t: Array, y: Array, steps: int = 64, key: Array = None) -> Array:
        r"""Runs the reverse process from time `t` down to 0, then denoises once.

        Arguments:
            xt: The state at time `t`.
            t: The starting time. `None` is treated as 1.0.
            y: The conditioning input, forwarded unchanged to the model.
            steps: The number of equally sized reverse steps.
            key: The PRNG key, split into one sub-key per step.

        Returns:
            The output of the model applied to the final state at
            :math:`\sigma(0)`.
        """

        start = 1.0 if t is None else t
        dt = jnp.asarray(start / steps)

        # One (time, key) pair per step; times run from `start` down to `dt`,
        # so the last step lands exactly on time 0.
        schedule = (
            jnp.linspace(start, dt, steps),
            jax.random.split(key, steps),
        )

        def reverse_step(x, t_and_key):
            t_i, key_i = t_and_key
            x_next = self.step(x, t_i, y, t_i - dt, key_i)
            return x_next, None

        x0, _ = jax.lax.scan(reverse_step, xt, schedule)

        return self.model(x0, self.sde.sigma(0.0), y)

    @inox.jit
    def step(self, xt: Array, t: Array, y: Array, s: Array, key: Array) -> Array:
        r"""Performs a single reverse step from time `t` to time `s`.

        Arguments:
            xt: The state at time `t`.
            t: The current time.
            y: The conditioning input, forwarded unchanged to the model.
            s: The target time.
            key: The PRNG key used to draw :math:`\epsilon`.

        Returns:
            The state at time `s`.
        """

        sigma_s = self.sde.sigma(s)
        sigma_t = self.sde.sigma(t)
        tau = 1 - (sigma_s / sigma_t) ** 2

        noise = jax.random.normal(key, xt.shape)
        denoised = self.model(xt, sigma_t, y)

        return xt - tau * (xt - denoised) + sigma_s * jnp.sqrt(tau) * noise

class ConditionalPosteriorDenoiser(PosteriorDenoiser):
    r"""Posterior denoiser model for a Gaussian observation.

    .. math:: p(y | x) = N(y | Ax, \Sigma_y)

    This subclass only adapts the call signature of :class:`PosteriorDenoiser`
    to the conditional convention ``model(xt, sigma_t, y, key)``. All
    constructor arguments are forwarded to the parent unchanged.

    Arguments:
        model: A denoiser model :math:`f(x_t) \approx E[x | x_t]`.
        A: The forward model :math:`A`.
        y: An observation.
        cov_y: The observation covariance :math:`\Sigma_y`.
        cov_x: The hidden covariance :math:`\Sigma_x`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @inox.jit
    def __call__(self, xt: Array, sigma_t: Array, y: Array, key: Array = None) -> Array:
        r"""Evaluates the parent posterior denoiser.

        Arguments:
            xt: The noisy state.
            sigma_t: The noise level.
            y: The conditioning input. It is accepted but NOT forwarded to the
                parent call, so it currently has no effect here.
            key: An optional PRNG key, forwarded to the parent call.
        """

        # NOTE(review): `y` is dropped on purpose by the current code; whether
        # the wrapped model should receive it depends on PosteriorDenoiser,
        # which is defined elsewhere - confirm before relying on conditioning.
        return super().__call__(xt, sigma_t, key)
        
class ConditionalDenoiserLoss(nn.Module):
    r"""Weighted denoising loss for a conditional denoiser model.

    .. math:: \lambda_t || f(x_t, y) - x ||^2

    where :math:`\lambda_t = \frac{1}{\sigma_t^2} + 1`. The squared error is
    averaged over the last axis, weighted, and averaged over the remaining axes.

    Arguments:
        sde: The forward SDE. A default :class:`VESDE` is created when omitted.
    """

    def __init__(self, sde: VESDE = None):
        self.sde = VESDE() if sde is None else sde

    @inox.jit
    def __call__(
        self,
        model: ConditionalDenoiser,
        x0,
        z,
        t,
        y_cond,
        key: Array = None,
    ) -> Array:
        r"""
        Arguments:
            model: The denoiser, called as ``model(xt, sigma_t, y_cond, key)``.
            x0: The clean samples, used as regression targets.
            z: The noise vectors passed to the SDE to build :math:`x_t`.
            t: The times at which the samples are perturbed.
            y_cond: The conditioning input forwarded to the model.
            key: An optional PRNG key forwarded to the model.

        Returns:
            The scalar loss.
        """

        sigma_t = self.sde.sigma(t)
        weight = 1 / sigma_t**2 + 1

        # Perturb the clean samples, then ask the model to recover them.
        noisy = self.sde(x0, z, t)
        denoised = model(noisy, sigma_t, y_cond, key)
        residual = denoised - x0

        per_sample = jnp.mean(residual**2, axis=-1)

        return jnp.mean(weight * per_sample)

class ConditionalGaussianDenoiser(GaussianDenoiser):
    r"""Denoiser model for a Gaussian random variable, with a conditional call signature.

    .. math:: p(x) = N(x | \mu_x, \Sigma_x)

    The constructor forwards every argument to :class:`GaussianDenoiser`.

    Arguments:
        mu_x: The mean :math:`\mu_x`.
        cov_x: The covariance :math:`\Sigma_x`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @inox.jit
    def __call__(self, xt: Array, sigma_t: Array, y_cond: Array, key: Array = None) -> Array:
        r"""Evaluates the parent Gaussian denoiser.

        Arguments:
            xt: The noisy state.
            sigma_t: The noise level.
            y_cond: The conditioning input. It is accepted but currently unused.
            key: An optional PRNG key, forwarded to the parent call.
        """

        # TODO: How to take advantage of knowing y_cond?
        return super().__call__(xt=xt, sigma_t=sigma_t, key=key)

def sample_any_conditional(
    model: nn.Module,
    shape: Sequence[int],
    shard: bool = False,
    y_cond: Array = None,
    cov_y: Union[Array, DPLR] = None,
    A: Callable[[Array], Array] = None,
    y: Array = None,
    key: Array = None,
    sampler: str = 'ddpm',
    steps: int = 64,
    rtol: float = 1e-3,
    maxiter: int = 1,
    method: str = 'cg',
    verbose: bool = False,
    **kwargs,
) -> Array:
    r"""Samples from :math:`q(x | y_{cond})` or :math:`q(x | y_{cond}, A, y)`.

    Arguments:
        model: The denoiser model. If it exposes a `cov_x` attribute, that
            value is handed to the posterior wrapper.
        shape: The shape of the initial noise, and hence of the sample.
        shard: Whether to pass the initial noise through `distribute`.
        y_cond: The conditioning input given to the sampler as `y`.
        cov_y: The observation covariance :math:`\Sigma_y`.
        A: The forward model :math:`A`. The model is wrapped in a
            :class:`ConditionalPosteriorDenoiser` only when both `A` and `y`
            are provided.
        y: An observation.
        key: The PRNG key. It is split once so that the initial noise and the
            sampler use independent keys.
        sampler: The sampler name, one of `'ddpm'` or `'pc'`.
        steps: The number of sampling steps.
        rtol, maxiter, method, verbose: Forwarded to
            :class:`ConditionalPosteriorDenoiser`.
        kwargs: Forwarded to the sampler constructor.

    Returns:
        The sample returned by the sampler.

    Raises:
        NotImplementedError: If `sampler` is `'ddim'`.
        ValueError: If `sampler` is not a known sampler name.
    """

    if A is not None and y is not None:
        # Read `cov_x` from the bare model, before it is wrapped.
        cov_x = getattr(model, 'cov_x', None)

        # TODO: y_cond is being completely ignored
        model = ConditionalPosteriorDenoiser(
            model=model,
            A=A,
            y=y,
            cov_y=cov_y,
            cov_x=cov_x,
            rtol=rtol,
            maxiter=maxiter,
            method=method,
            verbose=verbose,
        )

    if sampler == 'ddpm':
        solver = ConditionalDDPM(model, **kwargs)
    elif sampler == 'pc':
        solver = ConditionalPredictorCorrector(model, **kwargs)
    elif sampler == 'ddim':
        # TODO: Implement this
        raise NotImplementedError("Conditional version not implemented.")
    else:
        # Fail here rather than later with an AttributeError on a `str`.
        raise ValueError(
            f"Unknown sampler {sampler!r}, expected one of 'ddpm', 'ddim' or 'pc'."
        )

    # A PRNG key must be consumed only once: use separate sub-keys for the
    # initial noise and for the noise drawn inside the sampler.
    key_init, key_sampler = jax.random.split(key)

    z = jax.random.normal(key_init, shape)

    if shard:
        z = distribute(z)  # TODO

    # The initial state is the SDE perturbation of a zero signal at t = 1;
    # any `mu_x` attribute of the model is not used to centre it.
    x1 = solver.sde(0.0, z, 1.0)

    return solver(x1, t=1.0, y=y_cond, steps=steps, key=key_sampler)
