from typing import Generic, TypeVar

from flax import nnx
from jax import Array, lax, numpy as jnp, random as jrandom
import numpy as np

from offline.diffusion.modules.unconditional import UnconditionalDDPM
from offline.modules.policy import Policy


# Type parameter for the diffusion model held by an InpaintPolicy; the bound
# restricts it to UnconditionalDDPM and its subclasses.
DiffusionT = TypeVar("DiffusionT", bound=UnconditionalDDPM)


def get_schedule(jump_length: int, jump_samples: int, steps: int):
    """Build the RePaint resampling schedule as an int32 array.

    The schedule counts down from ``steps - 1`` to ``0``. Every timestep in
    ``range(0, steps - jump_length, jump_length)`` is a jump point: the first
    ``jump_samples - 1`` times the countdown lands on it, the schedule climbs
    back up by ``jump_length`` consecutive timesteps before counting down
    again.

    Args:
        jump_length: Number of timesteps to climb back up at a jump point.
        jump_samples: Number of times each jump point is traversed in total.
        steps: Number of diffusion steps.

    Returns:
        One-dimensional ``np.int32`` array with the visited timesteps.
    """
    # Remaining number of jumps allowed at each jump point.
    jumps_left = dict.fromkeys(
        range(0, steps - jump_length, jump_length), jump_samples - 1
    )

    schedule = []
    current = steps
    while current >= 1:
        current -= 1
        schedule.append(current)

        if jumps_left.get(current, 0) <= 0:
            continue

        # Consume one jump and walk forward jump_length timesteps.
        jumps_left[current] -= 1
        schedule.extend(range(current + 1, current + jump_length + 1))
        current += jump_length

    return np.asarray(schedule, dtype=np.int32)


class RepaintDDPM(UnconditionalDDPM):
    """DDPM sampler that inpaints with the RePaint resampling schedule.

    Samples are the concatenation, along the last axis, of an unknown part
    (``sample_dim`` features, generated) followed by a known part
    (``ground_truth_dim`` features, given at sampling time). Sampling follows
    Algorithm 1 of Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865.
    """

    def __init__(
        self,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        ground_truth_dim: int,
        jump_length: int,
        jump_samples: int,
        rngs: nnx.Rngs,
        sample_dim: int,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Initialise the parent DDPM and precompute the RePaint schedule.

        Args:
            beta_schedule: Forwarded to the parent constructor.
            clip_sample: Forwarded to the parent constructor.
            diffusion_steps: Forwarded to the parent constructor.
            ground_truth_dim: Number of features of the known part.
            jump_length: Jump length of the RePaint schedule.
            jump_samples: Number of traversals of each jump point.
            rngs: Forwarded to the parent constructor.
            sample_dim: Number of features of the unknown (generated) part.
            temperature: Forwarded to the parent constructor.
            time_dim: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor and to
                ``set_repaint_timesteps`` (which ignores them).
        """
        super().__init__(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            rngs=rngs,
            sample_dim=ground_truth_dim + sample_dim,
            temperature=temperature,
            time_dim=time_dim,
            **kwargs
        )
        # The parent constructor received the concatenated width; this
        # attribute holds the width of the unknown part only.
        self.sample_dim = sample_dim
        self.set_repaint_timesteps(
            jump_length=jump_length, jump_samples=jump_samples, **kwargs
        )

    def repaint(self, ground_truth: Array, state: tuple[Array, int]):
        """Sample the unknown part conditioned on ``ground_truth``.

        Args:
            ground_truth: Known part. Its last axis is the feature axis; the
                leading axes are reused as the leading axes of the sample.
            state: ``(key, offset)`` pair. Every random draw made here uses
                ``fold_in(key, i)`` with a distinct integer ``i >= offset``.

        Returns:
            A pair made of the first ``sample_dim`` features of the final
            sample and the new ``(key, offset)`` state, whose offset is
            advanced past every index consumed by this call.
        """
        key, offset = state

        # Index ``offset`` seeds the initial noise, so the per-step draws
        # start at ``offset + 1``. Starting them at ``offset`` would reuse
        # the same PRNG key with the same shape, making the noise added by
        # the first backward step identical to the initial sample.
        step_offset = offset + 1

        carry = self.repaint_process(
            alpha_prod=self.repaint_alphas_cumprod.value,
            alpha_prod_prev=self.repaint_alphas_cumprod_prev.value,
            backwards=self.backwards.value,
            ground_truth=ground_truth,
            key=key,
            offset=step_offset + jnp.arange(self.repaint_steps),
            sample=jrandom.normal(
                jrandom.fold_in(key, offset),
                ground_truth.shape[:-1] + (self.noise_predictor.sample_dim,),
            ),
            time=self.repaint_timesteps.value,
        )

        # Keep only the generated part; the known part is discarded.
        carry = carry[..., : self.sample_dim]
        return carry, (key, step_offset + self.repaint_steps)

    @nnx.scan(
        in_axes=(None, 0, 0, 0, None, None, 0, nnx.Carry, 0),
        out_axes=nnx.Carry,
    )
    def repaint_process(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        backwards: Array,
        ground_truth: Array,
        key: Array,
        offset: Array,
        sample: Array,
        time: Array,
    ):
        """Run one schedule entry; scanned over the whole schedule.

        ``alpha_prod``, ``alpha_prod_prev``, ``backwards``, ``offset`` and
        ``time`` are scanned along their first axis, ``ground_truth`` and
        ``key`` are shared by all iterations and ``sample`` is the carry.

        Returns:
            The sample after a backward (denoising) step when ``backwards``
            is true, otherwise after a forward (re-noising) step.
        """
        sample = lax.cond(
            backwards,
            lambda: self.repaint_backward_step(
                alpha_prod=alpha_prod,
                alpha_prod_prev=alpha_prod_prev,
                ground_truth=ground_truth,
                key=key,
                offset=offset,
                sample=sample,
                time=time,
            ),
            lambda: self.repaint_undo_step(
                key=jrandom.fold_in(key, offset), sample=sample, time=time
            ),
        )
        return sample

    def repaint_backward_step(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        ground_truth: Array,
        key: Array,
        offset: Array,
        sample: Array,
        time: Array,
    ):
        """Denoise the unknown part by one step and re-noise the known part.

        Args:
            alpha_prod: Cumulative alpha product for the current step.
            alpha_prod_prev: Cumulative alpha product for the step reached
                after denoising.
            ground_truth: Known part of the sample.
            key: Base PRNG key; the noise uses ``fold_in(key, offset)``.
            offset: Index folded into ``key`` for this step.
            sample: Current sample (unknown part followed by known part).
            time: Timestep given to the noise predictor; no noise is added
                when it is ``0``.

        Returns:
            The concatenation of the denoised unknown part and the known
            part noised to the level given by ``alpha_prod_prev``.
        """
        input_time = jnp.full(sample.shape[:-1] + (1,), time)
        pred_noise = self.noise_predictor(sample, input_time)
        pred_noise = pred_noise[..., : self.sample_dim]

        # 1. compute alphas, betas
        beta_prod = 1 - alpha_prod
        beta_prod_prev = 1 - alpha_prod_prev
        alpha = alpha_prod / alpha_prod_prev
        beta = 1 - alpha
        sqrt_alpha_prod_prev = jnp.sqrt(alpha_prod_prev)

        # 2. compute predicted original sample from predicted noise also
        # called "predicted x_0" of formula (15) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        unknown_part = sample[..., : self.sample_dim]
        pred_original_sample = (
            unknown_part - jnp.sqrt(beta_prod) * pred_noise
        ) / jnp.sqrt(alpha_prod)

        # 3. Clip the unknown part of "predicted x_0"
        if self.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # 4. Draw noise for the whole sample (unknown and known parts); it
        # is all zeros at the last timestep.
        noise = lax.cond(
            time > 0,
            lambda x: jrandom.normal(jrandom.fold_in(key, offset), x.shape),
            jnp.zeros_like,
            sample,
        )

        # 5. Compute coefficients for pred_original_sample x_0 and
        # current sample x_t. See formula (7) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        pred_original_sample_coeff = (sqrt_alpha_prod_prev * beta) / beta_prod
        current_sample_coeff = jnp.sqrt(alpha) * beta_prod_prev / beta_prod

        # 6. Compute predicted previous sample µ_t
        # See formula (7) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        prev_unknown_part = (
            pred_original_sample_coeff * pred_original_sample
            + current_sample_coeff * unknown_part
        )

        # 7. Add the temperature-scaled noise to the unknown part, except at
        # the last timestep.
        std_dev = self.temperature * jnp.sqrt(beta)
        prev_unknown_part = lax.cond(
            time > 0,
            lambda x: x + std_dev * noise[..., : self.sample_dim],
            lambda x: x,
            prev_unknown_part,
        )

        # 8. Algorithm 1 Line 5
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5
        # refers to formula (8a) of the same paper, which tells to sample from
        # a Gaussian distribution with mean
        # "(alpha_prod_prev ** 0.5) * original_image" and variance
        # "(1 - alpha_prod_prev)". This means that the standard Gaussian
        # distribution "noise" should be scaled by the square root of the
        # variance (as it is done here), however Algorithm 1 Line 5 tells to
        # scale by the variance.
        prev_known_part = (
            sqrt_alpha_prod_prev * ground_truth
            + jnp.sqrt(beta_prod_prev) * noise[..., self.sample_dim :]
        )

        # 9. Algorithm 1 Line 8
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        pred_prev_sample = jnp.concatenate(
            (prev_unknown_part, prev_known_part), axis=-1
        )

        return pred_prev_sample

    def repaint_undo_step(self, key: Array, sample: Array, time: Array):
        """Apply one forward diffusion (re-noising) step to ``sample``.

        Args:
            key: PRNG key used to draw the noise.
            sample: Current sample (unknown and known parts together).
            time: Index into ``self.betas``. For forward steps the schedule
                stores the timestep the step starts from.

        Returns:
            ``sqrt(1 - beta) * sample + sqrt(beta) * noise``.
        """
        # NOTE(review): assumes self.betas[time] is the beta of the
        # transition starting at schedule entry `time` — confirm against the
        # parent class.
        beta = self.betas[time]
        noise = jrandom.normal(key, sample.shape)

        # 10. Algorithm 1 Line 10 in Lugmayr et al. (2022)
        # https://arxiv.org/abs/2201.09865
        sample = jnp.sqrt(1 - beta) * sample + jnp.sqrt(beta) * noise

        return sample

    def set_repaint_timesteps(
        self, *, jump_length: int, jump_samples: int, **kwargs
    ):
        """Precompute the per-step tables consumed by ``repaint``.

        Sets ``backwards`` (whether each schedule entry is a denoising
        step), ``repaint_steps`` (schedule length), ``repaint_timesteps``
        and the cumulative alpha products gathered for every entry.

        Args:
            jump_length: Jump length of the RePaint schedule.
            jump_samples: Number of traversals of each jump point.
            **kwargs: Ignored.
        """
        del kwargs

        timesteps = get_schedule(
            jump_length=jump_length,
            jump_samples=jump_samples,
            steps=self.diffusion_steps,
        )
        # Timestep preceding each schedule entry; the first entry is treated
        # as coming from one step above it.
        timesteps_last = np.insert(timesteps[:-1], 0, timesteps[0] + 1)
        # An entry is a backward step when the timestep decreases.
        self.backwards = nnx.Variable(jnp.asarray(timesteps < timesteps_last))
        self.repaint_steps = int(timesteps.size)

        # If t < t_last, we need t. Otherwise, we need t_last.
        timesteps = np.minimum(timesteps, timesteps_last)
        self.repaint_timesteps = nnx.Variable(jnp.asarray(timesteps))

        # NOTE(review): the largest index used below is diffusion_steps, so
        # this presumably expects self.alphas_cumprod to hold
        # diffusion_steps + 1 entries — confirm against the parent class.
        prev_timesteps = timesteps
        timesteps = timesteps + 1
        self.repaint_alphas_cumprod = nnx.Variable(
            self.alphas_cumprod[timesteps]
        )
        self.repaint_alphas_cumprod_prev = nnx.Variable(
            self.alphas_cumprod[prev_timesteps]
        )


# pylint: disable = abstract-method
class InpaintPolicy(Generic[DiffusionT], Policy[tuple[Array, int]]):
    """Generic base for policies that hold a diffusion model.

    ``Policy`` is parameterised with ``tuple[Array, int]``, the same type as
    the ``state`` argument of ``RepaintPolicy.__call__``.
    """

    # Diffusion model used by the policy; subclasses fix its concrete type
    # through the DiffusionT parameter.
    diffusion: DiffusionT


class RepaintPolicy(InpaintPolicy[RepaintDDPM]):
    """Policy that obtains actions by RePaint sampling given observations.

    Observations are passed to the diffusion model as the known part and
    actions are the part it generates.
    """

    def __init__(
        self,
        action_dim: int,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        jump_length: int,
        jump_samples: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Create the underlying ``RepaintDDPM``.

        Args:
            action_dim: Passed to ``RepaintDDPM`` as ``sample_dim``.
            observation_dim: Passed to ``RepaintDDPM`` as
                ``ground_truth_dim``.
            beta_schedule, clip_sample, diffusion_steps, jump_length,
            jump_samples, rngs, temperature, time_dim, **kwargs: Forwarded
                unchanged to ``RepaintDDPM``.
        """
        self.diffusion = RepaintDDPM(
            # Dimensions: generated part first, known part second.
            sample_dim=action_dim,
            ground_truth_dim=observation_dim,
            # Diffusion process and RePaint schedule.
            beta_schedule=beta_schedule,
            diffusion_steps=diffusion_steps,
            jump_length=jump_length,
            jump_samples=jump_samples,
            # Sampling behaviour.
            clip_sample=clip_sample,
            temperature=temperature,
            # Network construction.
            time_dim=time_dim,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: Array, state: tuple[Array, int]):
        """Sample actions for ``observations``.

        Returns:
            A triple ``(actions, new_state, {})`` where the first two items
            come from ``RepaintDDPM.repaint`` and the last is an empty dict.
        """
        sampled_actions, next_state = self.diffusion.repaint(
            observations, state
        )
        return sampled_actions, next_state, {}
