from flax import nnx
from jax import Array, lax, numpy as jnp, random as jrandom, vmap

from offline.diffusion.modules.unconditional import UnconditionalDDPM
from offline.diffusion.repaint.modules import InpaintPolicy


class MCG(UnconditionalDDPM):
    """DDPM sampler conditioned with Manifold Constrained Gradients (MCG).

    The underlying diffusion model is built over vectors of size
    ``sample_dim + ground_truth_dim``. The first ``sample_dim`` entries form
    the unknown part that is generated; the remaining entries form the known
    part, which at every reverse step is replaced by a noised copy of the
    ground truth (RePaint, Lugmayr et al. 2022,
    https://arxiv.org/abs/2201.09865). The unknown part is additionally
    corrected with the gradient, w.r.t. the current sample x_t, of the squared
    error between the ground truth and the known part of the predicted x_0
    (Chung et al. 2022, https://arxiv.org/abs/2206.00941).
    """

    def __init__(
        self,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        ground_truth_dim: int,
        rngs: nnx.Rngs,
        sample_dim: int,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Build the parent DDPM over the concatenated (unknown, known) vector.

        Args:
            beta_schedule: noise schedule, forwarded to the parent.
            clip_sample: forwarded to the parent; when ``self.clip_sample``
                is true the predicted x_0 is clipped to [-1, 1].
            diffusion_steps: number of diffusion steps.
            ground_truth_dim: size of the known (conditioning) part.
            rngs: forwarded to the parent.
            sample_dim: size of the unknown (generated) part.
            temperature: scale applied to the standard deviation of the
                noise added at each reverse step.
            time_dim: forwarded to the parent.
            **kwargs: forwarded to ``UnconditionalDDPM``.
        """
        super().__init__(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            rngs=rngs,
            sample_dim=ground_truth_dim + sample_dim,
            temperature=temperature,
            time_dim=time_dim,
            **kwargs
        )
        # From here on ``sample_dim`` is only the size of the unknown part;
        # the full size is read back from ``self.noise_predictor.sample_dim``.
        # NOTE(review): this overwrites any ``sample_dim`` the parent stored;
        # confirm no inherited method expects the full size.
        self.sample_dim = sample_dim

    def mcg(self, ground_truth: Array, state: tuple[Array, int]):
        """Sample the unknown part conditioned on ``ground_truth``.

        Runs the whole reverse chain, from ``t = diffusion_steps - 1`` down to
        ``t = 0``, starting from standard Gaussian noise.

        Args:
            ground_truth: known part; the last axis holds the features and all
                leading axes are treated as batch axes.
            state: ``(key, offset)``. The initial noise uses
                ``fold_in(key, offset)`` and step ``t`` uses
                ``fold_in(key, offset + t)``.

        Returns:
            ``(sample, state)``: the unknown part of the final sample, of shape
            ``ground_truth.shape[:-1] + (sample_dim,)``, and the state with
            ``offset`` advanced by ``diffusion_steps`` so the next call draws
            fresh noise.
        """
        key, offset = state

        initial_sample = jrandom.normal(
            jrandom.fold_in(key, offset),
            ground_truth.shape[:-1] + (self.noise_predictor.sample_dim,),
        )
        final_sample = self.mcg_backward_step(
            alpha_prod=self.alphas_cumprod[1:],
            alpha_prod_prev=self.alphas_cumprod[:-1],
            ground_truth=ground_truth,
            key=key,
            offset=offset,
            sample=initial_sample,
            time=jnp.arange(self.diffusion_steps),
        )

        # Keep only the generated (unknown) part.
        unknown_part = final_sample[..., : self.sample_dim]
        return unknown_part, (key, offset + self.diffusion_steps)

    @nnx.scan(
        in_axes=(None, 0, 0, None, None, None, nnx.Carry, 0),
        out_axes=nnx.Carry,
        reverse=True,
    )
    def mcg_backward_step(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        ground_truth: Array,
        key: Array,
        offset: int,
        sample: Array,
        time: Array,
    ):
        """One reverse step x_t -> x_{t-1}, scanned over the diffusion steps.

        ``alpha_prod``, ``alpha_prod_prev`` and ``time`` are scanned (one entry
        per step, iterated in reverse order); ``sample`` is the carry;
        ``ground_truth``, ``key`` and ``offset`` are broadcast to every step.

        Args:
            alpha_prod: this step's entry of ``alphas_cumprod[1:]`` (ᾱ_t).
            alpha_prod_prev: this step's entry of ``alphas_cumprod[:-1]``
                (ᾱ_{t-1}).
            ground_truth: known part, last axis = features.
            key: base PRNG key, folded with ``offset + time`` for the noise.
            offset: PRNG offset of the current ``mcg`` call.
            sample: current x_t; the first ``sample_dim`` features are the
                unknown part, the rest the known part.
            time: current step index ``t``.

        Returns:
            x_{t-1}, with the same shape as ``sample``.
        """
        original_shape = sample.shape
        # Flatten all batch axes so rows can be vmapped below.
        sample = jnp.reshape(sample, (-1, original_shape[-1]))
        ground_truth = jnp.reshape(ground_truth, (-1, ground_truth.shape[-1]))

        input_time = jnp.full(sample.shape[:-1] + (1,), time)

        # 1. Compute alphas and betas.
        beta_prod = 1 - alpha_prod
        beta_prod_prev = 1 - alpha_prod_prev
        alpha = alpha_prod / alpha_prod_prev
        beta = 1 - alpha
        sqrt_alpha_prod_prev = jnp.sqrt(alpha_prod_prev)

        # 2. Per-row gradient of the known-part reconstruction error w.r.t.
        # x_t (argnums=3 is ``sample``), with the unknown part of the
        # predicted x_0 as auxiliary output. The predicted x_0 must be
        # recovered with alpha_prod (ᾱ_t), the same cumulative product that
        # beta_prod was derived from; using alpha_prod_prev here would mix
        # two time steps in formula (15) of Ho et al. (2020).
        grad_fn = vmap(
            nnx.grad(self.mcg_predicted_error, argnums=3, has_aux=True),
            in_axes=(None, None, 0, 0, 0),
        )
        grads, pred_original_sample = grad_fn(
            alpha_prod, beta_prod, ground_truth, sample, input_time
        )

        # 3. Clip "predicted x_0".
        if self.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # 4. Standard Gaussian noise for this step, zero at the final step
        # (time == 0). It is drawn in the sample's dtype so that both
        # branches of lax.cond return the same type for non-float32 samples.
        noise = lax.cond(
            time > 0,
            lambda x: jrandom.normal(
                jrandom.fold_in(key, offset + time), x.shape, x.dtype
            ),
            jnp.zeros_like,
            sample,
        )

        # 5. Compute coefficients for pred_original_sample x_0 and
        # current sample x_t. See formula (7) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        pred_original_sample_coeff = (sqrt_alpha_prod_prev * beta) / beta_prod
        current_sample_coeff = jnp.sqrt(alpha) * beta_prod_prev / beta_prod

        # 6. Compute predicted previous sample µ_t for the unknown part, add
        # noise with standard deviation temperature * sqrt(beta), then apply
        # the MCG gradient correction (unit step size).
        # See formula (7) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        prev_unknown_part = (
            pred_original_sample_coeff * pred_original_sample
            + current_sample_coeff * sample[:, : self.sample_dim]
        )
        std_dev = self.temperature * jnp.sqrt(beta)
        prev_unknown_part = (
            prev_unknown_part + std_dev * noise[:, : self.sample_dim]
        )
        prev_unknown_part = prev_unknown_part - grads[:, : self.sample_dim]

        # 7. Algorithm 1 Line 5
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5
        # refers to formula (8a) of the same paper, which tells to sample from
        # a Gaussian distribution with mean
        # "(alpha_prod_prev ** 0.5) * original_image" and variance
        # "(1 - alpha_prod_prev)". This means that the standard Gaussian
        # distribution "noise" should be scaled by the square root of the
        # variance (as it is done here), however Algorithm 1 Line 5 tells to
        # scale by the variance.
        prev_known_part = (
            sqrt_alpha_prod_prev * ground_truth
            + jnp.sqrt(beta_prod_prev) * noise[:, self.sample_dim :]
        )

        # 8. Algorithm 1 Line 8
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        pred_prev_sample = jnp.concatenate(
            (prev_unknown_part, prev_known_part), axis=-1
        )

        return jnp.reshape(pred_prev_sample, original_shape)

    def mcg_predicted_error(
        self,
        alpha_prod: Array,
        beta_prod: Array,
        ground_truth: Array,
        sample: Array,
        time: Array,
    ):
        """Squared reconstruction error of the known part of predicted x_0.

        As called from ``mcg_backward_step``, this is vmapped over rows, so
        ``sample`` is a single unbatched vector. Its first ``sample_dim``
        entries are the unknown part and the rest are the known part.

        Args:
            alpha_prod: cumulative product of alphas for the current step.
            beta_prod: ``1 - alpha_prod``.
            ground_truth: known part for this row.
            sample: current x_t for this row (the differentiated argument).
            time: time input forwarded to the noise predictor.

        Returns:
            ``(error, x0_unknown)``: the summed squared error between
            ``ground_truth`` and the known part of the predicted x_0, and the
            unknown part of the predicted x_0 (auxiliary output).
        """
        pred_noise = self.noise_predictor(sample, time)

        # Compute predicted original sample from predicted noise, also
        # called "predicted x_0", formula (15) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        pred_original_sample = (
            sample - jnp.sqrt(beta_prod) * pred_noise
        ) / jnp.sqrt(alpha_prod)

        predicted_error = jnp.sum(
            jnp.square(ground_truth - pred_original_sample[self.sample_dim :])
        )
        return predicted_error, pred_original_sample[: self.sample_dim]


class MCGPolicy(InpaintPolicy[MCG]):
    """Policy that samples actions from an observation-conditioned MCG model.

    Observations are used as the known (ground-truth) part of the diffused
    vector and actions are the unknown part that is generated.
    """

    def __init__(
        self,
        action_dim: int,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Create the underlying ``MCG`` sampler.

        ``observation_dim`` becomes the ground-truth (known) size and
        ``action_dim`` the generated (unknown) size. Every other argument,
        including ``**kwargs``, is forwarded unchanged.
        """
        self.diffusion = MCG(
            ground_truth_dim=observation_dim,
            sample_dim=action_dim,
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            rngs=rngs,
            temperature=temperature,
            time_dim=time_dim,
            **kwargs
        )

    def __call__(self, observations: Array, state: tuple[Array, int]):
        """Sample actions conditioned on ``observations``.

        Returns:
            ``(actions, state, info)``: the sampled actions, the updated
            ``(key, offset)`` state returned by ``MCG.mcg``, and an empty
            ``info`` dict.
        """
        sampled_actions, next_state = self.diffusion.mcg(observations, state)
        info = {}
        return sampled_actions, next_state, info
