from typing import Protocol, Tuple
from dataclasses import dataclass

import jax, jax.numpy as jnp
from diffrax import (
    diffeqsolve,
    ODETerm,
    WeaklyDiagonalControlTerm,
    MultiTerm,
    UnsafeBrownianPath,
    DirectAdjoint,
    Euler,
    Heun,
    ReversibleHeun,
)
from numpyro.distributions import Normal

class ScoreModel(Protocol):
    """Structural type for a score network called as ``model(t, x, *args)``."""

    def __call__(self, t: jax.Array, x: jax.Array, *args) -> jax.Array:
        """Return the model's score estimate at time ``t`` and point ``x``.

        Any extra positional ``args`` (e.g. conditioning inputs) are supplied
        by the caller.
        """

@dataclass(frozen=True)
class DDPMBetaSchedule:
    """Linear noise schedule ``beta(t) = beta_min + (beta_max - beta_min) * t``."""

    # Noise rate at t = 0.
    beta_min: float = 0.1
    # Noise rate at t = 1.
    beta_max: float = 20.0

    def beta(self, t: jax.Array):
        """Return the instantaneous noise rate at time ``t``."""
        slope = self.beta_max - self.beta_min
        return self.beta_min + slope * t

    def integrate_beta(self, t: jax.Array):
        """Return the closed-form integral of ``beta`` over ``[0, t]``."""
        slope = self.beta_max - self.beta_min
        linear_term = self.beta_min * t
        quadratic_term = slope * t**2 / 2
        return linear_term + quadratic_term

@dataclass(frozen=True)
class VPSDE:
    """Variance-preserving (VP) SDE for score-based generative modelling.

    Forward process on t in [eps, 1]:

        dx = -1/2 * beta(t) * x dt + sqrt(beta(t)) dW

    with ``beta`` given by ``beta_schedule``. Sampling integrates either the
    reverse-time SDE or the probability-flow ODE with diffrax, from t=1
    (standard-normal prior) down to t=eps. ``loss`` is the denoising
    score-matching objective used to train the score model.
    """

    # Smallest time used for training and integration. At t=0 the marginal
    # std sqrt(1 - exp(-B(0))) is exactly 0, so t=0 is avoided. This name has
    # no annotation, so it is a class attribute, not a dataclass field or
    # constructor argument.
    eps = 1e-3
    beta_schedule: DDPMBetaSchedule

    @staticmethod
    def create(beta_min: float = 0.1, beta_max: float = 20.0):
        """Build a VPSDE with a ``DDPMBetaSchedule(beta_min, beta_max)``."""
        beta_schedule = DDPMBetaSchedule(beta_min, beta_max)
        return VPSDE(beta_schedule)

    # ----------------- SDE fundamental -----------------

    def sde(self, t: jax.Array, x: jax.Array):
        """Return ``(drift, diffusion)`` of the forward SDE at time ``t``."""
        beta = self.beta_schedule.beta(t)
        drift = -1/2 * beta * x
        diffusion = jnp.sqrt(beta)
        return drift, diffusion

    def marginal_prob(self, t: jax.Array, x_0: jax.Array):
        """Mean and std of the Gaussian perturbation kernel p_t(x_t | x_0).

        With B(t) = int_0^t beta(s) ds:

            mean = x_0 * exp(-1/2 * B(t)),   std = sqrt(1 - exp(-B(t)))

        Args:
            t: Scalar time (array or Python float), or a 1-D array holding one
                time per example.
            x_0: Clean data. When ``t`` is 1-D, ``x_0`` must have at least two
                dims and a leading batch axis of the same length as ``t``.

        Returns:
            ``(mean, std)``. For 1-D ``t``, ``std`` has shape
            ``(batch, 1, ..., 1)`` so that it broadcasts against ``x_0``.

        Raises:
            ValueError: if a batched ``t`` is incompatible with ``x_0``.
        """
        t = jnp.asarray(t)
        beta_integral = self.beta_schedule.integrate_beta(t)
        mean_coeff = jnp.exp(-1/2 * beta_integral)
        std = jnp.sqrt(1 - jnp.exp(-beta_integral))
        if t.ndim == 0:
            return x_0 * mean_coeff, std
        if t.ndim != 1 or x_0.ndim < 2 or x_0.shape[0] != t.shape[0]:
            raise ValueError(
                "expected t of shape (batch,) and x_0 of shape (batch, ...); "
                f"got t.shape={t.shape}, x_0.shape={x_0.shape}"
            )
        # Add one singleton axis per non-batch axis of x_0 so the per-example
        # coefficients broadcast (for 2-D x_0 this equals `[:, None]`).
        expand = (slice(None),) + (None,) * (x_0.ndim - 1)
        return x_0 * mean_coeff[expand], std[expand]

    def reverse_sde(self, t: jax.Array, x: jax.Array, score: jax.Array):
        """Return ``(drift, diffusion)`` of the reverse-time SDE.

        drift = f(x, t) - g(t)^2 * score, with f = -1/2 * beta * x and
        g^2 = beta. The samplers below integrate it with decreasing t.
        """
        beta = self.beta_schedule.beta(t)
        drift = -1/2 * beta * x - beta * score
        diffusion = jnp.sqrt(beta)
        return drift, diffusion

    def ode(self, t: float, x: jax.Array, score: jax.Array):
        """Drift of the probability-flow ODE: f(x, t) - 1/2 * g(t)^2 * score."""
        beta = self.beta_schedule.beta(t)
        drift = -1/2 * beta * x - 1/2 * beta * score
        return drift

    # ----------------- Helpers -----------------

    @staticmethod
    def _batched_score(model: ScoreModel, t, x: jax.Array, args) -> jax.Array:
        """Evaluate ``model(t, x_i, *args)`` for each slice ``x_i`` along axis 0."""
        return jax.vmap(lambda x_i: model(t, x_i, *args))(x)

    def _log_prob_dynamics(self, model: ScoreModel):
        """Build the vector field of the augmented state ``(x, log-density term)``.

        The returned ``ode_fn(t, (x, _), args)`` yields ``(drift, divergence)``:
        the probability-flow ODE drift and its exact divergence (trace of the
        Jacobian from ``jax.jacrev``). Each vector along the last axis of
        ``x`` is differentiated on its own. The double vmap plus the
        ``ndim == 1`` assertion imply ``x`` has shape
        ``(n_samples, n_points, dim)``; the divergence then has shape
        ``(n_samples, n_points)``.
        """
        def ode_fn(t, y, args):
            def drift_fn(x, args):
                # The model is called on a batch of one point.
                score = model(t, x[None], *args).squeeze(0)
                drift = self.ode(t, x, score)
                # Return drift twice: once to differentiate, once as aux output.
                return drift, drift

            def value_and_divergence(x, args):
                assert x.ndim == 1
                jacobian, drift = jax.jacrev(drift_fn, has_aux=True)(x, args)
                return drift, jnp.trace(jacobian)

            x, _ = y
            # NOTE(review): the inner vmap maps `args` over axis 0 as well
            # (one slice per point), whereas the samplers pass `args` to the
            # model unchanged. Confirm the intended convention when `args` is
            # non-empty.
            return jax.vmap(lambda x: jax.vmap(value_and_divergence)(x, args))(x)

        return ode_fn

    # ----------------- Sampling -----------------

    def sample(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, stochastic: bool = True) -> jax.Array:
        """Draw one sample of shape ``shape`` (SDE sampler if ``stochastic``, else ODE)."""
        return self.sample_multiple(key, model, shape, *args, n_samples=1, stochastic=stochastic).squeeze(0)

    def sample_multiple(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int, stochastic: bool = True) -> jax.Array:
        """Draw ``n_samples`` samples; returns shape ``(n_samples, *shape)``."""
        if stochastic:
            return self.sample_multi_sde(key, model, shape, *args, n_samples=n_samples)
        else:
            return self.sample_multi_ode(key, model, shape, *args, n_samples=n_samples)

    def sample_multi_sde(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int) -> jax.Array:
        """Sample by integrating the reverse-time SDE from t=1 down to t=eps.

        Starts from x_1 ~ N(0, I) of shape ``(n_samples, *shape)``. Uses Heun
        with a fixed step of 0.01. Returns the state at t=eps, with shape
        ``(n_samples, *shape)``.
        """
        x_key, noise_key = jax.random.split(key)
        x = jax.random.normal(x_key, (n_samples, *shape))
        brownian_motion = UnsafeBrownianPath(shape=(n_samples, *shape), key=noise_key)

        def drift_fn(t, x, args):
            score = self._batched_score(model, t, x, args)
            drift, _ = self.reverse_sde(t, x, score)
            return drift

        def diffusion_fn(t, x, args):
            # The diffusion sqrt(beta(t)) does not depend on the score, so the
            # model is not evaluated here.
            diffusion = jnp.sqrt(self.beta_schedule.beta(t))
            return jnp.broadcast_to(diffusion, (n_samples, *shape))

        sol = diffeqsolve(
            MultiTerm(ODETerm(drift_fn), WeaklyDiagonalControlTerm(diffusion_fn, brownian_motion)),
            Heun(),
            1.0, self.eps, -0.01, x, args,
            adjoint=DirectAdjoint(),
        )

        # Only the final time is saved, so `ys` has a leading axis of length 1.
        return sol.ys.squeeze(0)

    def sample_multi_ode(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int) -> jax.Array:
        """Sample by integrating the probability-flow ODE from t=1 down to t=eps."""
        x = jax.random.normal(key, (n_samples, *shape))

        def ode_fn(t, x, args):
            score = self._batched_score(model, t, x, args)
            return self.ode(t, x, score)

        sol = diffeqsolve(
            ODETerm(ode_fn),
            Heun(),
            1.0, self.eps, -0.01, x, args,
        )

        return sol.ys.squeeze(0)

    def sample_with_log_prob(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, stochastic: bool = False) -> Tuple[jax.Array, jax.Array]:
        """Draw one sample and its log-density.

        Note: the default here is the ODE sampler (``stochastic=False``),
        unlike ``sample_multi_with_log_prob``.
        """
        x, log_prob = self.sample_multi_with_log_prob(key, model, shape, *args, n_samples=1, stochastic=stochastic)
        return x.squeeze(0), log_prob.squeeze(0)

    def sample_multi_with_log_prob(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int, stochastic: bool = True) -> Tuple[jax.Array, jax.Array]:
        """Draw ``n_samples`` samples together with their log-densities."""
        if stochastic:
            return self.sample_multi_sde_with_log_prob(key, model, shape, *args, n_samples=n_samples)
        else:
            return self.sample_multi_ode_with_log_prob(key, model, shape, *args, n_samples=n_samples)

    def sample_multi_sde_with_log_prob(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int) -> Tuple[jax.Array, jax.Array]:
        """Sample with the reverse SDE, then evaluate the ODE-model log-density.

        Each sample is pushed forward from t=eps to t=1 through the
        probability-flow ODE while accumulating the divergence of its drift.
        By the instantaneous change of variables,
        d/dt log p_t(x_t) = -div f(x_t, t), so

            log p_eps(x) = log N(x_1; 0, I) + int_eps^1 div f dt.

        ``shape`` must be ``(n_points, dim)`` (see ``_log_prob_dynamics``).
        Returns ``(x, log_prob)`` with ``log_prob`` of shape
        ``(n_samples, n_points)``.
        """
        x = self.sample_multi_sde(key, model, shape, *args, n_samples=n_samples)

        sol = diffeqsolve(
            ODETerm(self._log_prob_dynamics(model)),
            Heun(),
            self.eps, 1.0, 0.01, (x, jnp.zeros(x.shape[:-1])), args,
        )

        x_prior, divergence_integral = sol.ys
        x_prior = x_prior.squeeze(0)
        # Integrated forward, so this holds +int_eps^1 div f dt.
        divergence_integral = divergence_integral.squeeze(0)

        log_prob_prior = Normal().log_prob(x_prior).sum(-1)

        return x, log_prob_prior + divergence_integral

    def sample_multi_ode_with_log_prob(self, key: jax.Array, model: ScoreModel, shape: Tuple[int, ...], *args, n_samples: int) -> Tuple[jax.Array, jax.Array]:
        """Sample with the probability-flow ODE, tracking the exact log-density.

        Starts at x_1 ~ N(0, I) and integrates (x, divergence) from t=1 down
        to t=eps. The log-density of the result is

            log p_eps(x_eps) = log N(x_1; 0, I) + int_eps^1 div f dt.

        ``shape`` must be ``(n_points, dim)`` (see ``_log_prob_dynamics``).
        Returns ``(x, log_prob)`` with ``log_prob`` of shape
        ``(n_samples, n_points)``.
        """
        x = jax.random.normal(key, (n_samples, *shape))
        log_prob_prior = Normal().log_prob(x).sum(-1)

        sol = diffeqsolve(
            ODETerm(self._log_prob_dynamics(model)),
            ReversibleHeun(),
            1.0, self.eps, -0.01, (x, jnp.zeros_like(log_prob_prior)), args,
        )

        x, divergence_integral = sol.ys
        x = x.squeeze(0)
        # Integrated from t=1 down to eps, so this holds
        # int_1^eps div f dt = -int_eps^1 div f dt; hence the subtraction.
        divergence_integral = divergence_integral.squeeze(0)

        return x, log_prob_prior - divergence_integral

    # ----------------- Training -----------------

    def loss(self, key: jax.Array, x_0: jax.Array, model: ScoreModel, *args):
        """Denoising score-matching loss, weighted by std(t)^2.

        Draws t ~ U[eps, 1] per example, perturbs ``x_0`` with the marginal
        kernel (x_t = mean + std * z), and returns mean((std * score + z)^2).
        This is minimised when score = -z / std, the score of the kernel.
        """
        batch_size = x_0.shape[0]
        t_key, noise_key = jax.random.split(key)
        t = jax.random.uniform(t_key, (batch_size,), minval=self.eps, maxval=1)
        z = jax.random.normal(noise_key, x_0.shape)
        mean, std = self.marginal_prob(t, x_0)
        x_noisy = mean + std * z
        score = model(t, x_noisy, *args)
        loss = jnp.square(score * std + z)
        return loss.mean()
