#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
sampler_utils.py

Utilities for diffusion sampling.

Provides both:
  - xarray‐based noise/noise‐churn (for top‐level functions), and
  - raw‐jax‐array versions (for use inside lax.fori_loop).
"""

import jax
import jax.numpy as jnp
import xarray as xr
from typing import Any, Tuple

# -----------------------------------------------------------------------------
# Xarray‐based version (unmodified from before): for direct Dataset calls
# -----------------------------------------------------------------------------

def gaussian_white_noise_like(
    template: xr.Dataset,
    key: jnp.ndarray
) -> Tuple[xr.Dataset, jnp.ndarray]:
    """
    Draw standard-normal noise laid out exactly like `template`.

    The dataset is stacked into a single array (data variables along a
    "variable" dimension), noise of the same shape and dtype is sampled,
    and the result is unstacked again, reusing the stacked array's dims
    and coords.

    Args:
        template: Dataset whose layout the noise should mirror.
        key: JAX PRNG key; it is split once and never used directly.

    Returns:
        (noise_ds, next_key): the noise dataset and the unused half of the
        split, to be threaded into the next random call.
    """
    # One half of the split draws the sample, the other is handed back.
    draw_key, carry_key = jax.random.split(key)

    stacked = template.to_array()  # dims e.g. ('variable','batch','node',...)
    samples = jax.random.normal(draw_key, stacked.shape, dtype=stacked.dtype)

    noise_da = xr.DataArray(
        data=samples,
        dims=stacked.dims,
        coords=stacked.coords,
    )
    # Split the leading "variable" dimension back into data variables.
    return noise_da.to_dataset(dim="variable"), carry_key


def apply_stochastic_churn(
    x: xr.Dataset,
    noise_level: jnp.ndarray,
    stochastic_churn_rate: jnp.ndarray,
    noise_level_inflation_factor: jnp.ndarray,
    key: jnp.ndarray,
) -> Tuple[xr.Dataset, jnp.ndarray, jnp.ndarray]:
    """
    Raise the noise level of `x` from sigma to sigma*(1+rate) (xarray version).

    Fresh Gaussian noise is added to `x` with standard deviation
    sqrt(new_sigma**2 - sigma**2) * noise_level_inflation_factor.

    Args:
        x: Dataset to perturb.
        noise_level: Current noise level (sigma).
        stochastic_churn_rate: Relative increase applied to sigma.
        noise_level_inflation_factor: Multiplier on the added noise's std.
        key: JAX PRNG key.

    Returns:
        (x_noisier, new_sigma, next_key).
    """
    churned_sigma = noise_level * (1.0 + stochastic_churn_rate)

    # Variance still to be injected; clamped so the sqrt never sees a
    # negative value.
    variance_gap = jnp.maximum(churned_sigma**2 - noise_level**2, 0.0)
    injected_std = jnp.sqrt(variance_gap) * noise_level_inflation_factor

    white_noise, carry_key = gaussian_white_noise_like(x, key)
    return x + white_noise * injected_std, churned_sigma, carry_key


# -----------------------------------------------------------------------------
# Raw‐jax‐array versions: for use entirely inside jax.lax.fori_loop
# -----------------------------------------------------------------------------

def gaussian_white_noise_like_arr(
    x_arr: jnp.ndarray,
    key: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Draw standard-normal noise with the shape and dtype of `x_arr`.

    Args:
        x_arr: Array whose shape and dtype the noise should match.
        key: JAX PRNG key; it is split once and never used directly.

    Returns:
        (noise_arr, next_key): the sample and the unused half of the split.
    """
    # One half of the split draws the sample, the other is handed back.
    draw_key, carry_key = jax.random.split(key)
    return (
        jax.random.normal(draw_key, x_arr.shape, dtype=x_arr.dtype),
        carry_key,
    )


def apply_stochastic_churn_arr(
    x_arr: jnp.ndarray,
    sigma: jnp.ndarray,
    churn_rate: jnp.ndarray,
    inflation: jnp.ndarray,
    key: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Raise the noise level of `x_arr` from sigma to sigma*(1+rate).

    Fresh Gaussian noise is added to `x_arr` with standard deviation
    sqrt(new_sigma**2 - sigma**2) * inflation.

    Args:
        x_arr: Array to perturb.
        sigma: Current noise level.
        churn_rate: Relative increase applied to sigma.
        inflation: Multiplier on the added noise's standard deviation.
        key: JAX PRNG key.

    Returns:
        (x_noisier_arr, new_sigma, next_key).
    """
    churned_sigma = sigma * (1.0 + churn_rate)

    # Variance still to be injected; clamped so the sqrt never sees a
    # negative value.
    variance_gap = jnp.maximum(churned_sigma**2 - sigma**2, 0.0)
    injected_std = jnp.sqrt(variance_gap) * inflation

    white_noise, carry_key = gaussian_white_noise_like_arr(x_arr, key)

    # injected_std is combined with the noise through ordinary broadcasting.
    return x_arr + white_noise * injected_std, churned_sigma, carry_key


def rho_inverse_cdf(
    min_value: float,
    max_value: float,
    rho: float,
    cdf: Any
) -> Any:
    """
    Map `cdf` onto [min_value, max_value], linearly in the 1/rho-power domain.

    Computes
        (min_value**(1/rho) + cdf * (max_value**(1/rho) - min_value**(1/rho))) ** rho
    so that cdf=0 gives min_value and cdf=1 gives max_value.

    Args:
        min_value: Value returned for cdf == 0.
        max_value: Value returned for cdf == 1.
        rho: Exponent controlling the curvature of the mapping.
        cdf: Scalar or array of interpolation positions.

    Returns:
        The interpolated value(s), with the same structure as `cdf`.
    """
    inv_rho = 1 / rho
    low_root = min_value**inv_rho
    high_root = max_value**inv_rho
    return (low_root + cdf * (high_root - low_root)) ** rho


def noise_schedule(
    max_noise_level: float = 80.0,
    min_noise_level: float = 0.002,
    num_noise_levels: int = 30,
    rho: float = 7.0,
) -> jnp.ndarray:
    """
    Build a descending schedule of noise levels with a trailing zero.

    `num_noise_levels` values run from `max_noise_level` down to
    `min_noise_level` (spaced via `rho_inverse_cdf`), followed by a final
    0.0, so the result has length num_noise_levels + 1.
    """
    # cdf runs 1 -> 0 so that the first level is the maximum.
    descending_cdf = jnp.linspace(1.0, 0.0, num_noise_levels)
    positive_levels = rho_inverse_cdf(
        min_value=min_noise_level,
        max_value=max_noise_level,
        rho=rho,
        cdf=descending_cdf,
    )
    terminal_zero = jnp.zeros((1,), dtype=positive_levels.dtype)
    return jnp.concatenate([positive_levels, terminal_zero])



# --------------------------------------------------------------------- #
# EDM / Karras noise schedule                                           #
# --------------------------------------------------------------------- #
def edm_noise_schedule(
    max_noise_level: float = 80.0,    # sigma_max in the paper
    min_noise_level: float = 0.002,   # sigma_min in the paper
    num_noise_levels: int = 30,       # N, number of solver steps
    rho: float = 7.0,                 # rho, shape parameter
    dtype=jnp.float32,                # jax dtype of the returned array
) -> jnp.ndarray:
    """
    Build the sigma schedule [sigma_0, ..., sigma_{N-1}, 0.0].

    The N = num_noise_levels positive levels are spaced linearly in the
    1/rho-power domain, starting at `max_noise_level` and ending at
    `min_noise_level`; a terminal 0.0 is appended, so the returned array
    has length num_noise_levels + 1 and the requested `dtype`.
    """
    exponent = 1.0 / rho
    root_of_max = max_noise_level ** exponent
    root_of_min = min_noise_level ** exponent

    # Fraction of the way from sigma_max (0.0) to sigma_min (1.0), inclusive.
    fraction = jnp.linspace(0.0, 1.0, num_noise_levels, dtype=dtype)

    # Interpolate between the roots, then undo the 1/rho power.
    positive_levels = (root_of_max + fraction * (root_of_min - root_of_max)) ** rho

    terminal_zero = jnp.zeros((1,), dtype=dtype)
    return jnp.concatenate([positive_levels, terminal_zero])



def stochastic_churn_rate_schedule(
    noise_levels: jnp.ndarray,
    stochastic_churn_rate: float = 0.0,
    churn_min_noise_level: float = 0.05,
    churn_max_noise_level: float = 50.0,
) -> jnp.ndarray:
    """
    Compute the per-step churn rate for every entry of `noise_levels`
    except the last.

    The total `stochastic_churn_rate` is divided evenly over the
    N = len(noise_levels) - 1 steps and capped at sqrt(2) - 1. Steps whose
    noise level lies outside
    [churn_min_noise_level, churn_max_noise_level] get a rate of 0.

    Args:
        noise_levels: 1-D schedule of noise levels; its last entry gets no
            churn rate.
        stochastic_churn_rate: Total churn budget shared across all steps.
        churn_min_noise_level: Lowest noise level at which churn is applied.
        churn_max_noise_level: Highest noise level at which churn is applied.

    Returns:
        float32 array of length len(noise_levels) - 1 (empty when
        `noise_levels` has fewer than two entries).
    """
    # Clamp the divisor to 1: a schedule with a single entry has zero steps,
    # and dividing by that used to raise ZeroDivisionError even though the
    # result is simply an empty array.
    num_steps = max(noise_levels.shape[0] - 1, 1)
    per_step = jnp.minimum(stochastic_churn_rate / num_steps, jnp.sqrt(2) - 1)

    step_levels = noise_levels[:-1]
    in_churn_range = (step_levels >= churn_min_noise_level) & (
        step_levels <= churn_max_noise_level)
    return per_step * in_churn_range.astype(jnp.float32)


def tree_where(
    cond: jnp.ndarray,
    xs: Any,
    ys: Any
) -> Any:
    """
    Leaf-wise `jnp.where` over two trees.

    `cond` is a plain array (not a tree) and is applied unchanged to every
    pair of corresponding leaves from `xs` and `ys`.

    Returns:
        A tree shaped like `xs` whose leaves are jnp.where(cond, x, y).
    """
    def _select(x_leaf, y_leaf):
        return jnp.where(cond, x_leaf, y_leaf)

    return jax.tree_util.tree_map(_select, xs, ys)
