from typing import Protocol, Tuple

import jax
import jax.numpy as jnp


from .types import Batch, Rng, ndarray


class EpsModel(Protocol):
    """Structural type of the noise-prediction callable used by this module.

    Any callable with this signature can be passed as ``model_fn`` to the
    samplers or as ``network`` to :func:`loss`.
    """

    def __call__(
        self,
        t: ndarray,
        yt: ndarray,
        x: ndarray,
        mask: ndarray,
        *,
        key: Rng,
    ) -> ndarray:
        """Predict the noise contained in ``yt`` at timestep ``t``.

        Args:
            t: Diffusion timestep(s).
            yt: Noisy outputs at timestep ``t``.
            x: Inputs associated with ``yt``.
            mask: Per-point mask; in this module a zero entry marks a point
                that is taken into account.
            key: Random key for any stochasticity inside the model.

        Returns:
            The noise estimate, compared element-wise against the sampled
            noise in :func:`loss`.
        """
        ...


def expand_to(a, b):
    """Append trailing singleton axes to ``a`` until it has ``b.ndim`` axes.

    Args:
        a: Array to reshape.
        b: Array whose number of dimensions is the target rank.

    Returns:
        ``a`` reshaped to ``a.shape + (1, ..., 1)``. When ``a`` already has at
        least as many dimensions as ``b`` its shape is left unchanged.
    """
    missing_axes = b.ndim - a.ndim
    # A non-positive count produces an empty list, i.e. no axes are appended.
    padded_shape = (*a.shape, *([1] * missing_axes))
    return a.reshape(padded_shape)


def cosine_schedule(beta_start, beta_end, timesteps, s=0.008, **kwargs):
    """Build a cosine-shaped beta schedule rescaled to ``[beta_start, beta_end]``.

    A squared-cosine curve is evaluated on ``timesteps + 1`` evenly spaced
    points and normalised by its first value. Betas are derived from the ratio
    of consecutive curve values, clipped to ``[0.0001, 0.9999]``, min-max
    normalised to ``[0, 1]`` and finally mapped affinely onto
    ``[beta_start, beta_end]``.

    Args:
        beta_start: Value assigned to the smallest beta of the schedule.
        beta_end: Value assigned to the largest beta of the schedule.
        timesteps: Number of diffusion steps (length of the returned array).
        s: Offset added inside the cosine argument.
        **kwargs: Accepted and ignored.

    Returns:
        Array of ``timesteps`` betas.
    """
    grid = jnp.linspace(0, timesteps, timesteps + 1)
    angle = ((grid / timesteps) + s) / (1 + s) * jnp.pi * 0.5
    cos_sq = jnp.cos(angle) ** 2
    cumulative = cos_sq / cos_sq[0]

    # Per-step beta from the ratio of consecutive cumulative values.
    raw_betas = jnp.clip(1 - (cumulative[1:] / cumulative[:-1]), 0.0001, 0.9999)

    # Min-max normalise, then stretch onto the requested range.
    unit_betas = (raw_betas - raw_betas.min()) / (raw_betas.max() - raw_betas.min())
    return unit_betas * (beta_end - beta_start) + beta_start


class GaussianDiffusion:
    """Discrete-time Gaussian diffusion with a bridge toward a terminal state.

    The process is defined by a per-step noise schedule ``betas``. From it the
    constructor derives ``alphas = 1 - betas`` and ``alpha_bars``, the
    cumulative product of ``alphas``.
    """

    betas: ndarray
    alphas: ndarray
    alpha_bars: ndarray

    def __init__(self, betas):
        """Store the schedule and precompute ``alphas`` and ``alpha_bars``.

        Args:
            betas: 1-D array of per-step noise levels.
        """
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alpha_bars = jnp.cumprod(self.alphas)

    def _snr_ratio(self, t: ndarray, like: ndarray) -> ndarray:
        """Return ``SNR(T) / SNR(t)``, shaped to broadcast against ``like``.

        ``SNR(s)`` is ``alpha_bar_s / (1 - alpha_bar_s)`` and ``T`` is the last
        index of the schedule.
        """
        alpha_bar_t = expand_to(self.alpha_bars[t], like)
        alpha_bar_T = expand_to(self.alpha_bars[-1], like)
        snr_t = alpha_bar_t / (1 - alpha_bar_t)
        snr_T = alpha_bar_T / (1 - alpha_bar_T)
        return snr_T / snr_t

    def _ddpm_mean(self, noise: ndarray, yt: ndarray, t: ndarray) -> ndarray:
        """Return the DDPM reverse-step mean given the predicted ``noise``."""
        beta_t = expand_to(self.betas[t], yt)
        alpha_t = expand_to(self.alphas[t], yt)
        alpha_bar_t = expand_to(self.alpha_bars[t], yt)
        return (1.0 / jnp.sqrt(alpha_t)) * (
            yt - (beta_t / jnp.sqrt(1.0 - alpha_bar_t)) * noise
        )

    def _add_step_noise(self, key: Rng, mean: ndarray, t: ndarray) -> ndarray:
        """Add ``sqrt(beta_t)``-scaled Gaussian noise to ``mean`` when ``t > 0``."""
        beta_t = expand_to(self.betas[t], mean)
        # Multiplying by (t > 0) zeroes the noise on the final step (t == 0).
        z = (t > 0) * jax.random.normal(key, shape=mean.shape, dtype=mean.dtype)
        return mean + jnp.sqrt(beta_t) * z

    def bridge_sample(self, key: Rng, z0: ndarray, zT: ndarray, t: ndarray) -> Tuple[ndarray, ndarray]:
        """Sample from the Gaussian bridge between ``z0`` and ``zT`` at step ``t``.

        With ``w = SNR(T) / SNR(t)`` and ``SNR(s) = alpha_bar_s / (1 - alpha_bar_s)``
        the sample is drawn as::

            mean = w * zT + sqrt(alpha_bar_t) * (1 - w) * z0
            std  = sqrt((1 - alpha_bar_t) * (1 - w))
            zt   = mean + std * noise,   noise ~ N(0, I)

        At the last schedule index ``w == 1``, so the sample equals ``zT`` and
        carries no noise.

        Args:
            key: JAX random key used to draw the noise.
            z0: Start point of the bridge; the noise has this shape.
            zT: Terminal point of the bridge; must broadcast against ``z0``.
            t: Integer timestep(s) indexing the schedule.

        Returns:
            Tuple ``(zt, noise)`` of the bridge sample and the standard-normal
            noise that was used to produce it.
        """
        alpha_bar_t = expand_to(self.alpha_bars[t], z0)
        w = self._snr_ratio(t, z0)

        mu_t = w * zT + jnp.sqrt(alpha_bar_t) * z0 * (1 - w)
        std_t = jnp.sqrt((1 - alpha_bar_t) * (1 - w))

        noise = jax.random.normal(key, z0.shape)
        zt = mu_t + std_t * noise
        return zt, noise

    def ddpm_backward_step(self, key: Rng, noise: ndarray, yt: ndarray, t: ndarray) -> ndarray:
        """Perform one plain DDPM reverse step ``t -> t-1`` (no bridge term).

        Args:
            key: JAX random key for the step noise.
            noise: Noise predicted by the model for ``yt``.
            yt: Noisy sample at timestep ``t``.
            t: Current timestep.

        Returns:
            The sample at timestep ``t - 1``.
        """
        return self._add_step_noise(key, self._ddpm_mean(noise, yt, t), t)

    def dbpm_backward_step(
        self,
        key: Rng,
        noise: ndarray,
        yt: ndarray,
        t: ndarray,
        y_T: ndarray,
        *,
        bridge_scale: float = 0.1,
    ) -> ndarray:
        """Perform one reverse step with an SNR-weighted bridge correction.

        The mean is the DDPM reverse-step mean plus the correction
        ``-bridge_scale * (w / sqrt(alpha_t)) * y_T`` where
        ``w = SNR(T) / SNR(t)`` is computed from the cumulative ``alpha_bars``,
        the same definition used by :meth:`bridge_sample`.

        NOTE(review): the correction is proportional to ``-y_T`` itself, not to
        a difference such as ``y_T - yt``; confirm this matches the intended
        derivation.

        Args:
            key: JAX random key for the step noise.
            noise: Noise predicted by the model for ``yt``.
            yt: Noisy sample at timestep ``t``.
            t: Current timestep.
            y_T: Terminal state of the bridge; must broadcast against ``yt``.
            bridge_scale: Multiplier on the bridge correction term.

        Returns:
            The sample at timestep ``t - 1``.
        """
        alpha_t = expand_to(self.alphas[t], yt)

        # 1. Standard DDPM denoising mean.
        mean_ddpm = self._ddpm_mean(noise, yt, t)

        # 2. Bridge correction, weighted by the SNR ratio and rescaled by
        #    1/sqrt(alpha_t) like the DDPM mean.
        w = self._snr_ratio(t, yt)
        bridge_strength = w / jnp.sqrt(alpha_t)
        correction = -bridge_scale * bridge_strength * y_T

        mu_tilde = mean_ddpm + correction
        return self._add_step_noise(key, mu_tilde, t)

    def sample(self, key, x, mask, *, model_fn: EpsModel, output_dim: int = 1):
        """Draw an unconditional sample by running the full reverse chain.

        Starts from standard-normal noise of shape ``(len(x), output_dim)`` and
        applies :meth:`ddpm_backward_step` for every timestep from the last
        schedule index down to 0.

        NOTE(review): this uses the plain DDPM step, i.e. no bridge correction
        toward ``x``; confirm that is intended for models trained with
        :meth:`bridge_sample`.

        Args:
            key: JAX random key.
            x: Inputs at which to sample; indexed along its first axis.
            mask: Per-point mask for ``x``, or ``None`` to use an all-zero mask.
            model_fn: Noise-prediction model.
            output_dim: Size of the last axis of the sample.

        Returns:
            Array of shape ``(len(x), output_dim)`` holding the final sample.
        """
        key, ykey = jax.random.split(key)
        yT = jax.random.normal(ykey, (len(x), output_dim))

        if mask is None:
            mask = jnp.zeros_like(x[:, 0])

        @jax.jit
        def scan_fn(y, inputs):
            t, step_key = inputs
            mkey, rkey = jax.random.split(step_key)
            noise_hat = model_fn(t, y, x, mask, key=mkey)
            y = self.ddpm_backward_step(key=rkey, noise=noise_hat, yt=y, t=t)
            return y, None

        ts = jnp.arange(len(self.betas))[::-1]
        keys = jax.random.split(key, len(ts))
        # The scan emits no per-step outputs, so only the final carry matters.
        y0, _ = jax.lax.scan(scan_fn, yT, (ts, keys))
        return y0

    def conditional_sample(
        self,
        key,
        x,
        mask,
        *,
        x_context,
        y_context,
        mask_context,
        model_fn: EpsModel,
        num_inner_steps: int = 5,
        method: str = "repaint",
    ):
        """Sample targets at ``x`` conditioned on context points.

        For every timestep from the last schedule index down to 1 the method
        runs ``num_inner_steps`` inner iterations followed by one reverse step.
        Each inner iteration takes a reverse step ``t -> t-1`` and then
        re-noises the target with one forward step ``t-1 -> t``. Every reverse
        step draws a fresh bridge sample of the context at the current
        timestep, concatenates it with the current target, applies
        :meth:`dbpm_backward_step` with ``y_T`` set to the concatenated inputs
        and keeps only the target rows.

        Args:
            key: JAX random key.
            x: Target inputs.
            mask: Mask for ``x``, or ``None`` to use an all-zero mask.
            x_context: Context inputs.
            y_context: Context outputs.
            mask_context: Mask for the context, or ``None`` for all zeros.
            model_fn: Noise-prediction model.
            num_inner_steps: Number of inner iterations per timestep.
            method: Currently unused; accepted for interface compatibility.

        Returns:
            The target sample after the last reverse step.
        """
        if mask is None:
            mask = jnp.zeros_like(x[:, 0])

        if mask_context is None:
            mask_context = jnp.zeros_like(x_context[:, 0])

        key, ykey = jax.random.split(key)
        x_augmented = jnp.concatenate([x_context, x], axis=0)
        mask_augmented = jnp.concatenate([mask_context, mask], axis=0)
        num_context = len(x_context)

        def denoise_target(y_target, t, fkey, mkey, bkey):
            # Reverse step t -> t-1 on [context, target]; return the target rows.
            yt_context = self.bridge_sample(fkey, y_context, x_context, t)[0]
            y_augmented = jnp.concatenate([yt_context, y_target], axis=0)
            noise_hat = model_fn(t, y_augmented, x_augmented, mask_augmented, key=mkey)
            y_prev = self.dbpm_backward_step(
                key=bkey, noise=noise_hat, yt=y_augmented, t=t, y_T=x_augmented
            )
            return y_prev[num_context:]

        @jax.jit
        def repaint_inner(yt_target, inputs):
            t, key = inputs
            key, fkey, mkey, bkey = jax.random.split(key, 4)
            y = denoise_target(yt_target, t, fkey, mkey, bkey)
            # One step forward: t-1 -> t.
            z = jax.random.normal(key, shape=y.shape)
            beta_t_minus_1 = expand_to(self.betas[t - 1], y)
            y = jnp.sqrt(1.0 - beta_t_minus_1) * y + jnp.sqrt(beta_t_minus_1) * z
            return y, None

        @jax.jit
        def repaint_outer(y, inputs):
            t, key = inputs
            # Inner loop at a fixed timestep t.
            key, ikey = jax.random.split(key)
            inner_ts = jnp.ones((num_inner_steps,), dtype=jnp.int32) * t
            inner_keys = jax.random.split(ikey, num_inner_steps)
            y, _ = jax.lax.scan(repaint_inner, y, (inner_ts, inner_keys))

            # Step backward: t -> t-1.
            key, fkey, mkey, bkey = jax.random.split(key, 4)
            y = denoise_target(y, t, fkey, mkey, bkey)
            return y, None

        ts = jnp.arange(len(self.betas))[::-1]
        keys = jax.random.split(key, len(ts))

        # Initial target state: the inputs themselves plus standard-normal noise.
        yT_target = x + jax.random.normal(ykey, (len(x), y_context.shape[-1]))

        # t == 0 is excluded, so ``t - 1`` in the inner forward step is always valid.
        y, _ = jax.lax.scan(repaint_outer, yT_target, (ts[:-1], keys[:-1]))
        return y





def loss(
    process: GaussianDiffusion,
    network: EpsModel,
    batch: Batch,
    key: Rng,
    *,
    num_timesteps: int,
    loss_type: str = "l1",
):
    """Compute the noise-prediction loss averaged over a batch.

    For every batch element a timestep is drawn, ``y_target`` is noised with
    ``process.bridge_sample`` (using ``x_target`` as the terminal state) and
    the network's noise estimate is compared with the true noise. Points whose
    mask entry is non-zero do not contribute.

    Args:
        process: Diffusion process providing ``bridge_sample``.
        network: Noise-prediction model.
        batch: Batch with ``x_target``, ``y_target`` and ``mask_target``
            (``mask_target`` may be ``None``, meaning all points count).
        key: JAX random key.
        num_timesteps: Number of diffusion steps; timesteps are drawn from
            ``[0, num_timesteps)``.
        loss_type: ``"l1"`` for absolute error or ``"l2"`` for squared error.

    Returns:
        Scalar mean of the per-example losses.

    Raises:
        ValueError: If ``loss_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if loss_type == "l1":

        def loss_metric(a, b):
            return jnp.abs(a - b)

    elif loss_type == "l2":

        def loss_metric(a, b):
            return (a - b) ** 2

    else:
        raise ValueError(f"Unknown loss type {loss_type}")

    def loss_fn(key, t, y, x, mask):
        # Use independent keys for the bridge noise and the network so the
        # model's internal randomness is not correlated with its target.
        noise_key, network_key = jax.random.split(key)

        yt, noise = process.bridge_sample(noise_key, y, x, t)
        noise_hat = network(t, yt, x, mask, key=network_key)

        per_point = jnp.sum(loss_metric(noise, noise_hat), axis=1)  # [N,]
        per_point = per_point * (1.0 - mask)
        num_points = len(mask) - jnp.count_nonzero(mask)
        # Guard against a fully masked example: the numerator is 0 there, so
        # the example contributes 0 instead of NaN.
        return jnp.sum(per_point) / jnp.maximum(num_points, 1)

    batch_size = len(batch.x_target)

    key, tkey = jax.random.split(key)
    # Low-discrepancy sampling over t to reduce variance: one uniform draw in
    # each of ``batch_size`` equal-width strata of [0, num_timesteps).
    stratum_width = num_timesteps / batch_size
    t = jax.random.uniform(tkey, (batch_size,), minval=0, maxval=stratum_width)
    t = t + stratum_width * jnp.arange(batch_size)
    t = t.astype(jnp.int32)

    keys = jax.random.split(key, batch_size)

    if batch.mask_target is None:
        # Consider all points.
        mask_target = jnp.zeros_like(batch.x_target[..., 0])
    else:
        mask_target = batch.mask_target

    losses = jax.vmap(loss_fn)(keys, t, batch.y_target, batch.x_target, mask_target)
    return jnp.mean(losses)
