from typing import Protocol, Tuple
from dataclasses import dataclass

import numpy as np
import jax, jax.numpy as jnp
import optax
from numpyro.distributions import Normal

class DiffusionModel(Protocol):
    """Structural type for the network callables consumed by the samplers and loss.

    Any callable with a compatible signature satisfies this protocol; no
    explicit subclassing is required.
    """

    def __call__(self, t: jax.Array, x: jax.Array, *args) -> jax.Array:
        """Evaluate the model at timestep(s) ``t`` on input ``x``.

        Args:
            t: Timestep index or array of indices.
            x: Current (noisy) sample(s).
            *args: Optional extra positional inputs (e.g. conditioning).

        Returns:
            The model output as an array.
        """

@dataclass(frozen=True)
class BetaScheduleCoefficients:
    """Per-timestep coefficient tables derived from a beta schedule.

    Every field is an array with one entry per timestep, produced by
    :meth:`from_beta` and placed on the default JAX device. The fields cover
    the forward process ``q(x_t | x_0)`` and the posterior
    ``q(x_{t-1} | x_t, x_0)``.
    """
    betas: jax.Array
    alphas: jax.Array
    alphas_cumprod: jax.Array
    alphas_cumprod_prev: jax.Array
    sqrt_alphas_cumprod: jax.Array
    sqrt_one_minus_alphas_cumprod: jax.Array
    log_one_minus_alphas_cumprod: jax.Array
    sqrt_recip_alphas_cumprod: jax.Array
    sqrt_recipm1_alphas_cumprod: jax.Array
    posterior_variance: jax.Array
    posterior_log_variance_clipped: jax.Array
    posterior_mean_coef1: jax.Array
    posterior_mean_coef2: jax.Array

    @staticmethod
    def from_beta(betas: np.ndarray):
        """Build the full coefficient table from a 1-D array of betas.

        Args:
            betas: Per-timestep noise variances. Any array-like is accepted and
                converted with ``np.asarray``.

        Returns:
            A ``BetaScheduleCoefficients`` whose fields are JAX arrays.
        """
        betas = np.asarray(betas)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1 for the step before t=0.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        one_minus_cumprod = 1. - alphas_cumprod

        # Posterior q(x_{t-1} | x_t, x_0). The variance is exactly 0 at the
        # first timestep, so it is floored before taking the log.
        posterior_variance = betas * (1. - alphas_cumprod_prev) / one_minus_cumprod

        # Constructed by keyword so a change in field order cannot silently
        # mis-assign coefficients.
        coefficients = dict(
            betas=betas,
            alphas=alphas,
            alphas_cumprod=alphas_cumprod,
            alphas_cumprod_prev=alphas_cumprod_prev,
            # Forward process q(x_t | x_0) and related quantities.
            sqrt_alphas_cumprod=np.sqrt(alphas_cumprod),
            sqrt_one_minus_alphas_cumprod=np.sqrt(one_minus_cumprod),
            log_one_minus_alphas_cumprod=np.log(one_minus_cumprod),
            sqrt_recip_alphas_cumprod=np.sqrt(1. / alphas_cumprod),
            sqrt_recipm1_alphas_cumprod=np.sqrt(1. / alphas_cumprod - 1),
            posterior_variance=posterior_variance,
            posterior_log_variance_clipped=np.log(np.maximum(posterior_variance, 1e-20)),
            posterior_mean_coef1=betas * np.sqrt(alphas_cumprod_prev) / one_minus_cumprod,
            posterior_mean_coef2=(1. - alphas_cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod,
        )
        return BetaScheduleCoefficients(**jax.device_put(coefficients))

    @staticmethod
    def vp_beta_schedule(timesteps: int, beta_min: float = 0.1, beta_max: float = 10.):
        """Return the "vp" beta schedule as a NumPy array of length ``timesteps``.

        For ``t = 1..T`` the per-step alpha is
        ``exp(-beta_min / T - 0.5 * (beta_max - beta_min) * (2t - 1) / T**2)``
        and ``beta_t = 1 - alpha_t``.

        Args:
            timesteps: Number of diffusion steps ``T``.
            beta_min: Lower rate parameter of the schedule.
            beta_max: Upper rate parameter of the schedule.
        """
        t = np.arange(1, timesteps + 1)
        T = timesteps
        alpha = np.exp(-beta_min / T - 0.5 * (beta_max - beta_min) * (2 * t - 1) / T ** 2)
        return 1 - alpha

    @staticmethod
    def cosine_beta_schedule(timesteps: int, s: float = 0.008):
        """Return a cosine beta schedule as a NumPy array of length ``timesteps``.

        The cumulative alpha follows ``cos^2(((t / T + s) / (1 + s)) * pi / 2)``
        normalised to 1 at ``t = 0``; betas are the successive ratios, clipped
        to ``[0, 0.999]``.

        Args:
            timesteps: Number of diffusion steps ``T``.
            s: Offset added to the normalised time before the cosine.
        """
        t = np.arange(0, timesteps + 1) / timesteps
        alphas_cumprod = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
        alphas_cumprod /= alphas_cumprod[0]
        betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
        return np.clip(betas, 0, 0.999)


@dataclass(frozen=True)
class BetaSchedule:
    """A beta schedule with derived alpha quantities computed on access.

    Only ``beta`` is stored; ``alpha``, ``alpha_bar`` and ``alpha_bar_prev``
    are recomputed from it each time the property is read.
    """
    beta: jax.Array

    @property
    def alpha(self):
        """Per-step ``1 - beta``."""
        complement = 1. - self.beta
        return complement

    @property
    def alpha_bar(self):
        """Cumulative product of ``alpha`` along the timestep axis."""
        per_step = self.alpha
        return jnp.cumprod(per_step, axis=0)

    @property
    def alpha_bar_prev(self):
        """``alpha_bar`` shifted right by one step, with a leading 1."""
        shifted = self.alpha_bar[:-1]
        return jnp.append(1., shifted)

    @staticmethod
    def vp_beta_schedule(T: int, beta_min: float = 0.1, beta_max: float = 20.0):
        """Build the "vp" schedule with ``T`` steps as a ``BetaSchedule``.

        For ``t = 1..T`` the per-step alpha is
        ``exp(-beta_min / T - 0.5 * (beta_max - beta_min) * (2t - 1) / T**2)``
        and ``beta_t = 1 - alpha_t``. The betas are computed with NumPy and
        then moved to the default JAX device.
        """
        steps = np.arange(1, T + 1)
        step_alpha = np.exp(-beta_min / T - 0.5 * (beta_max - beta_min) * (2 * steps - 1) / T ** 2)
        step_beta = 1 - step_alpha
        return BetaSchedule(jax.device_put(step_beta))

@dataclass(frozen=True)
class GaussianDiffusion:
    """Discrete-time Gaussian diffusion: forward noising, training loss and samplers.

    Attributes:
        num_timesteps: Number of diffusion steps. Samplers iterate
            ``t = num_timesteps - 1, ..., 0``.
        beta_schedule: Noise schedule read by ``q_sample`` and ``p_mean_std``.
    """
    num_timesteps: int
    beta_schedule: BetaSchedule

    @staticmethod
    def vp_diffusion(T: int, beta_min: float = 0.1, beta_max: float = 20.0):
        """Build a ``T``-step diffusion from ``BetaSchedule.vp_beta_schedule``."""
        beta_schedule = BetaSchedule.vp_beta_schedule(T, beta_min=beta_min, beta_max=beta_max)
        return GaussianDiffusion(T, beta_schedule)

    def ddim_beta_schedule(self):
        """Return the coefficient table read by ``p_mean_variance``.

        NOTE(review): the table is built from
        ``BetaScheduleCoefficients.vp_beta_schedule(self.num_timesteps)`` with
        that function's own defaults, not from ``self.beta_schedule`` —
        confirm this is intended.
        """
        # Evaluate eagerly so the table is built as concrete constants even
        # when this is reached while tracing (e.g. inside ``lax.scan``).
        with jax.ensure_compile_time_eval():
            betas = BetaScheduleCoefficients.vp_beta_schedule(self.num_timesteps)
            return BetaScheduleCoefficients.from_beta(betas)

    # ----------------- Sampling -----------------

    def p_mean_std(self, t: int, x: jax.Array, noise_pred: jax.Array):
        """Mean and standard deviation of the reverse step at timestep ``t``.

        Args:
            t: Timestep index into ``beta_schedule`` (may be a traced integer).
            x: Current sample ``x_t``.
            noise_pred: Model noise prediction for ``x``.

        Returns:
            ``(model_mean, model_std)`` with ``model_std = sqrt(beta_t)``.
        """
        B = self.beta_schedule
        model_mean = 1 / jnp.sqrt(B.alpha[t]) * (x - B.beta[t] / jnp.sqrt(1 - B.alpha_bar[t]) * noise_pred)
        # variance bounded between [(1-alpha_bar_prev) / (1-alpha) * beta, beta]
        model_std = jnp.sqrt(B.beta[t])
        return model_mean, model_std

    def p_transition(self, t: int, x: jax.Array, noise_pred: jax.Array):
        """Deterministic transition used by the ODE samplers.

        Returns ``tanh(noise_pred)``. ``t`` and ``x`` are accepted for
        interface compatibility but are not used.

        NOTE(review): this is not the probability-flow ODE step of
        arXiv:2011.13456,
        ``(x - 0.5 * beta_t / sqrt(1 - alpha_bar_t) * noise_pred) / sqrt(alpha_t)``.
        That formula was unreachable dead code behind this return; restore it
        deliberately if the ODE update is wanted.
        """
        return jnp.tanh(noise_pred)

    def p_sample(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, stochastic: bool = True) -> jax.Array:
        """Draw a single sample of shape ``shape``.

        Delegates to ``p_sample_multi_ddim``.

        NOTE(review): ``stochastic`` is currently ignored.
        """
        return self.p_sample_multi_ddim(key, model, shape, *args, n_samples=1)

    def p_sample_multi(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int, stochastic: bool = True) -> jax.Array:
        """Draw ``n_samples`` samples via the SDE or ODE sampler.

        Returns an array of shape ``(n_samples, *shape)``.
        """
        sampler = self.p_sample_multi_sde if stochastic else self.p_sample_multi_ode
        return sampler(key, model, shape, *args, n_samples=n_samples)

    def p_sample_multi_sde(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int) -> jax.Array:
        """Ancestral sampling: ``x <- mean + std * noise`` for each step.

        The model is vmapped over the leading sample axis and receives
        ``*args`` unchanged. No noise is added at ``t == 0``.
        """
        x_key, noise_key = jax.random.split(key)
        x = jax.random.normal(x_key, (n_samples, *shape))
        noise = jax.random.normal(noise_key, (self.num_timesteps, n_samples, *shape))
        timesteps = jnp.arange(self.num_timesteps)[::-1]

        def body_fn(x, step):
            t, step_noise = step
            noise_pred = jax.vmap(lambda x_i: model(t, x_i, *args))(x)
            model_mean, model_std = self.p_mean_std(t, x, noise_pred)
            return model_mean + (t > 0) * model_std * step_noise, None

        x, _ = jax.lax.scan(body_fn, x, (timesteps, noise))
        return x

    def p_sample_multi_ode(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int) -> jax.Array:
        """Deterministic sampling that applies ``p_transition`` at every step.

        The model is vmapped over the leading sample axis and receives
        ``*args`` unchanged. Returns shape ``(n_samples, *shape)``.
        """
        x = jax.random.normal(key, (n_samples, *shape))
        timesteps = jnp.arange(self.num_timesteps)[::-1]

        def body_fn(x, t):
            noise_pred = jax.vmap(lambda x_i: model(t, x_i, *args))(x)
            return self.p_transition(t, x, noise_pred), None

        x, _ = jax.lax.scan(body_fn, x, timesteps)
        return x

    def p_sample_multi_ddim(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int) -> jax.Array:
        """Deterministic sampler built on ``p_mean_variance``.

        Starts from standard-normal noise of shape ``shape`` and, for each
        ``t`` from ``num_timesteps - 1`` down to 0, replaces ``x`` with
        ``tanh(model_mean)``. ``*args`` are forwarded to ``model``.

        NOTE(review): ``n_samples`` is accepted for signature parity with the
        other samplers but is ignored; the result has shape ``shape`` with no
        leading sample axis, which ``p_sample`` relies on.
        """
        # Keep the two-way split so the initial noise for a given key is
        # unchanged; the second key belonged to a per-step noise term that
        # this sampler does not add.
        x_key, _ = jax.random.split(key)
        x = jax.random.normal(x_key, shape)
        timesteps = jnp.arange(self.num_timesteps)[::-1]

        def body_fn(x, t):
            noise_pred = model(t, x, *args)
            model_mean, _ = self.p_mean_variance(t, x, noise_pred)
            return jnp.tanh(model_mean), None

        x, _ = jax.lax.scan(body_fn, x, timesteps)
        return x

    def p_mean_variance(self, t: int, x: jax.Array, noise_pred: jax.Array):
        """Return ``(model_mean, model_log_variance)`` for timestep ``t``.

        ``model_mean`` is ``tanh(noise_pred)``; ``x`` is not used.
        ``model_log_variance`` is ``posterior_log_variance_clipped[t]`` from
        ``ddim_beta_schedule()``.
        """
        model_mean = jnp.tanh(noise_pred)
        model_log_variance = self.ddim_beta_schedule().posterior_log_variance_clipped[t]
        return model_mean, model_log_variance

    def p_sample_with_log_prob(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, stochastic: bool = False) -> Tuple[jax.Array, jax.Array]:
        """Draw one sample and its log-probability (sample axis squeezed out)."""
        x, log_prob = self.p_sample_multi_with_log_prob(key, model, shape, *args, n_samples=1, stochastic=stochastic)
        return x.squeeze(0), log_prob.squeeze(0)

    def p_sample_multi_with_log_prob(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int, stochastic: bool = True) -> Tuple[jax.Array, jax.Array]:
        """Draw ``n_samples`` samples together with their log-probabilities.

        Raises:
            NotImplementedError: If ``stochastic`` is true (the default here);
                only the deterministic ODE sampler is supported.
        """
        if stochastic:
            # A bit tricky to implement this
            raise NotImplementedError("log-probabilities are only implemented for stochastic=False")
        return self.p_sample_multi_ode_with_log_prob(key, model, shape, *args, n_samples=n_samples)

    def p_sample_multi_ode_with_log_prob(self, key: jax.Array, model: DiffusionModel, shape: Tuple[int, ...], *args, n_samples: int) -> Tuple[jax.Array, jax.Array]:
        """ODE sampling with log-probability from the change-of-variables formula.

        The log-probability is the standard-normal log-density of the initial
        noise (summed over the last axis) minus ``log|det J|``, where ``J`` is
        the Jacobian of the final sample with respect to the initial noise.

        The Jacobian is vmapped twice: over the sample axis, then over the
        leading axis of ``shape`` together with the leading axis of every
        entry of ``args``.
        NOTE(review): assumes ``shape`` is ``(batch, dim)`` so each Jacobian
        is a square ``(dim, dim)`` matrix — TODO confirm.
        """
        x = jax.random.normal(key, (n_samples, *shape))
        timesteps = jnp.arange(self.num_timesteps)[::-1]

        def sample_fn(x_init, *cond):
            def body_fn(x_t, t):
                noise_pred = model(t, x_t, *cond)
                return self.p_transition(t, x_t, noise_pred), None

            x_final, _ = jax.lax.scan(body_fn, x_init, timesteps)
            # Returned twice: once as the output to differentiate, once as the
            # auxiliary value so the sample itself is also available.
            return x_final, x_final

        log_prob_prior = Normal().log_prob(x).sum(-1)

        jac_fn = jax.jacrev(sample_fn, has_aux=True)
        jac, x = jax.vmap(lambda x_s: jax.vmap(jac_fn)(x_s, *args))(x)

        log_prob = log_prob_prior - jnp.linalg.slogdet(jac).logabsdet

        return x, log_prob

    # ----------------- Training -----------------

    def q_sample(self, t: int, x_0: jax.Array, noise: jax.Array):
        """Forward-diffuse ``x_0`` to timestep ``t`` using the given ``noise``.

        Returns ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        """
        B = self.beta_schedule
        return jnp.sqrt(B.alpha_bar[t]) * x_0 + jnp.sqrt(1 - B.alpha_bar[t]) * noise

    def p_loss(self, key: jax.Array, model: DiffusionModel, t: jax.Array, x_0: jax.Array):
        """Mean ``optax.l2_loss`` between the model output and the injected noise.

        Args:
            key: PRNG key used to draw the noise.
            model: Called as ``model(t, x_noisy)`` on the whole batch.
            t: 1-D array of timesteps, one per row of ``x_0``.
            x_0: Batch of clean samples; the leading axis is the batch axis.

        Returns:
            Scalar loss averaged over all elements.
        """
        assert t.ndim == 1 and t.shape[0] == x_0.shape[0]

        noise = jax.random.normal(key, x_0.shape)
        # Each example is noised at its own timestep.
        x_noisy = jax.vmap(self.q_sample)(t, x_0, noise)
        noise_pred = model(t, x_noisy)
        loss = optax.l2_loss(noise_pred, noise)
        return loss.mean()
