from typing import Callable

from flax import nnx
from jax import Array, lax, numpy as jnp, random as jrandom
from optax import squared_error

from offline.diffusion.modules.base import DDPM
from offline.diffusion.modules.utils import FourierFeatures
from offline.modules.mlp import MLP


class ConditionalNoisePredictor(nnx.Module):
    """Noise-prediction network conditioned on an extra input and a timestep.

    The timestep is passed through ``FourierFeatures`` and a small MLP; the
    resulting features are concatenated with the sample and the conditions
    along the last axis and fed to the main MLP, which outputs ``sample_dim``
    features.
    """

    def __init__(
        self,
        condition_dim: int,
        num_layers: int,
        rngs: nnx.Rngs,
        sample_dim: int,
        dropout: float = 0,
        hidden_features: int = 256,
        layer_norm: bool = False,
        nonlinearity: Callable[[Array], Array] | str = "mish",
        time_dim: int = 16,
        time_features: int = 32,
        **kwargs
    ):
        """Build the main MLP, the time embedding and the time MLP.

        Args:
            condition_dim: Size of the last axis of ``conditions``.
            num_layers: ``num_layers`` of the main MLP.
            rngs: Random number generators handed to both MLPs.
            sample_dim: Size of the last axis of ``samples``; also the
                output width of the main MLP.
            dropout: ``dropout`` of the main MLP (not passed to the time
                MLP).
            hidden_features: ``hidden_features`` of the main MLP.
            layer_norm: Forwarded to both MLPs.
            nonlinearity: Forwarded to both MLPs.
            time_dim: Argument of ``FourierFeatures`` and input width of the
                time MLP.
            time_features: Hidden and output width of the time MLP. The main
                MLP's input width is derived from it, so the two always
                agree. Defaults to 32, the previously hard-coded value.
            **kwargs: Extra keyword arguments forwarded to both MLPs.
        """
        # The main MLP is created first so that ``rngs`` is consumed in the
        # same order as before the ``time_features`` parameter existed.
        self.mlp = MLP(
            dropout=dropout,
            hidden_features=hidden_features,
            # Input is the concatenation built in ``__call__``.
            in_features=condition_dim + sample_dim + time_features,
            layer_norm=layer_norm,
            nonlinearity=nonlinearity,
            num_layers=num_layers,
            out_features=sample_dim,
            rngs=rngs,
            **kwargs
        )
        self.sample_dim = sample_dim
        self.time_embedding = FourierFeatures(time_dim)
        self.time_mlp = MLP(
            hidden_features=time_features,
            in_features=time_dim,
            layer_norm=layer_norm,
            nonlinearity=nonlinearity,
            num_layers=3,
            out_features=time_features,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, samples: Array, conditions: Array, time: Array):
        """Predict the noise for ``samples`` given ``conditions`` and ``time``.

        Args:
            samples: Noisy samples; last axis is expected to have size
                ``sample_dim``.
            conditions: Conditioning inputs; last axis is expected to have
                size ``condition_dim``.
            time: Timesteps fed to the time embedding. Callers in this file
                pass an array whose last axis has size 1 — TODO confirm
                against ``FourierFeatures``.

        Returns:
            The output of the main MLP applied to the concatenated inputs.
        """
        time_features = self.time_embedding(time)
        time_features = self.time_mlp(time_features)
        # All three inputs must share their leading axes to be concatenated.
        inputs = jnp.concatenate((samples, conditions, time_features), axis=-1)
        return self.mlp(inputs)


class ConditionalDDPM(DDPM):
    """DDPM sampler whose noise predictor also receives a condition input.

    Calling the instance draws an initial Gaussian sample and runs the
    reverse (denoising) process over all diffusion steps with ``nnx.scan``.
    """

    def __init__(
        self,
        beta_schedule: str,
        clip_sample: bool,
        condition_dim: int,
        diffusion_steps: int,
        rngs: nnx.Rngs,
        sample_dim: int,
        temperature: float,
        time_dim: int,
        **kwargs
    ):
        """Initialise the base DDPM and build the conditional noise predictor.

        ``beta_schedule``, ``clip_sample``, ``diffusion_steps`` and
        ``temperature`` go to the base class; ``condition_dim``, ``rngs``,
        ``sample_dim``, ``time_dim`` and any ``**kwargs`` go to
        ``ConditionalNoisePredictor``.
        """
        super().__init__(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            temperature=temperature,
        )
        predictor = ConditionalNoisePredictor(
            condition_dim=condition_dim,
            rngs=rngs,
            sample_dim=sample_dim,
            time_dim=time_dim,
            **kwargs
        )
        self.noise_predictor = predictor

    def __call__(self, conditions: Array, state: tuple[Array, int]):
        """Sample from the model given ``conditions``.

        Args:
            conditions: Conditioning inputs; all axes but the last define
                the shape of the generated batch.
            state: ``(key, offset)`` pair. ``offset`` is folded into ``key``
                to derive the random keys used during sampling.

        Returns:
            ``(sample, (key, offset + diffusion_steps))``: the final carry of
            the reverse process and the state to use for the next call.
        """
        key, offset = state

        # Starting point of the reverse process: pure Gaussian noise, keyed
        # by ``offset`` (the per-step noise uses ``offset + t`` with t > 0).
        batch_shape = conditions.shape[:-1]
        initial_sample = jrandom.normal(
            jrandom.fold_in(key, offset),
            batch_shape + (self.noise_predictor.sample_dim,),
        )

        # ``alphas_cumprod`` comes from the base class; consecutive entries
        # are paired up as (previous, current) for every step.
        final_sample = self.backward_process(
            alpha_prod=self.alphas_cumprod[1:],
            alpha_prod_prev=self.alphas_cumprod[:-1],
            conditions=conditions,
            key=key,
            offset=offset,
            sample=initial_sample,
            time=jnp.arange(self.diffusion_steps),
        )

        next_state = (key, offset + self.diffusion_steps)
        return final_sample, next_state

    @nnx.scan(
        in_axes=(None, 0, 0, None, None, None, nnx.Carry, 0),
        out_axes=nnx.Carry,
        reverse=True,
    )
    def backward_process(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        conditions: Array,
        key: Array,
        offset: int,
        sample: Array,
        time: Array,
    ):
        """Run one reverse-diffusion step; scanned in reverse over all steps.

        ``alpha_prod``, ``alpha_prod_prev`` and ``time`` are scanned along
        their leading axis, ``sample`` is the carry, and the module,
        ``conditions``, ``key`` and ``offset`` are broadcast to every step.

        Returns:
            The new carry: the predicted previous sample.
        """
        # Give every batch element the current step as a trailing feature.
        step_input = jnp.full(sample.shape[:-1] + (1,), time)
        noise_hat = self.noise_predictor(sample, conditions, step_input)

        # Per-step alpha / beta, derived from the cumulative products.
        one_minus_alpha_prod = 1 - alpha_prod
        one_minus_alpha_prod_prev = 1 - alpha_prod_prev
        step_alpha = alpha_prod / alpha_prod_prev
        step_beta = 1 - step_alpha

        # "Predicted x_0", formula (15) of Ho et al. (2020),
        # https://arxiv.org/abs/2006.11239
        x0_hat = (
            sample - jnp.sqrt(one_minus_alpha_prod) * noise_hat
        ) / jnp.sqrt(alpha_prod)
        if self.clip_sample:
            x0_hat = jnp.clip(x0_hat, -1, 1)

        # Posterior mean µ_t as a weighted sum of x_0 and x_t, formula (7)
        # of Ho et al. (2020).
        x0_weight = (
            jnp.sqrt(alpha_prod_prev) * step_beta
        ) / one_minus_alpha_prod
        xt_weight = (
            jnp.sqrt(step_alpha)
            * one_minus_alpha_prod_prev
            / one_minus_alpha_prod
        )
        posterior_mean = x0_weight * x0_hat + xt_weight * sample

        def noisy(mean):
            # Noise helper is defined outside this class; each step gets its
            # own key via ``time + offset``.
            return self.add_noise_pred_prev_sample(
                beta=step_beta,
                key=jrandom.fold_in(key, time + offset),
                pred_prev_sample=mean,
            )

        def noiseless(mean):
            return mean

        # No noise is added on the step with ``time == 0``.
        return lax.cond(time > 0, noisy, noiseless, posterior_mean)


def conditional_diffusion_loss_fn(
    noise_predictor: ConditionalNoisePredictor,
    conditions: Array,
    noise: Array,
    noisy_samples: Array,
    timesteps: Array,
):
    """Return the mean squared error between predicted and target noise.

    Args:
        noise_predictor: Model called as
            ``noise_predictor(noisy_samples, conditions, timesteps)``.
        conditions: Conditioning inputs for the predictor.
        noise: Target noise the prediction is compared against.
        noisy_samples: Samples the predictor should denoise.
        timesteps: Timesteps passed to the predictor.

    Returns:
        The squared error averaged over all elements.
    """
    predicted_noise = noise_predictor(noisy_samples, conditions, timesteps)
    elementwise_error = squared_error(predicted_noise, noise)
    return jnp.mean(elementwise_error)
