#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DPM-Solver++ 2S sampler for GenSynth CFD diffusion.
"""

import dataclasses
from typing import Any, Dict, Optional

import chex
import flax.nnx as nnx
import jax
import jax.numpy as jnp
from gen_da import casting
from gen_da import samplers_utils as utils


@chex.dataclass
class SamplerConfig:
    """Hyper-parameters for the ``Sampler`` defined in this module.

    Schedule fields (``rho``, ``num_steps``, ``sigma_min``, ``sigma_max``) are
    forwarded to ``samplers_utils.noise_schedule``. Churn fields follow the
    EDM convention used by ``Sampler``: when ``s_churn > 0`` and
    ``s_min <= sigma <= s_max``, the noise level is raised by a factor
    ``1 + min(s_churn / num_steps, sqrt(2) - 1)`` before each step.
    """

    rho: float = 7.0  # schedule curvature, forwarded as `rho` to noise_schedule
    num_steps: int = 20  # forwarded as `num_noise_levels` to noise_schedule
    sigma_min: float = 0.002  # forwarded as `min_noise_level` to noise_schedule
    sigma_max: float = 80.0  # forwarded as `max_noise_level` to noise_schedule
    s_churn: float = 0.0  # disabled by default
    # NOTE(review): churn is only applied when s_min <= sigma <= s_max. With the
    # default s_max = 0.0 no positive sigma is in range, so setting s_churn > 0
    # alone has no effect; also set s_max (e.g. float("inf")).
    s_min: float = 0.0
    s_max: float = 0.0
    s_noise: float = 1.0  # multiplier on the standard-normal churn noise
    # Classifier-free guidance scale. 1.0 = conditional only, 0.0 = unconditional, >1 stronger guidance
    guidance_scale: float = 1.0


class Sampler:
    """DPM-Solver++(2S) sampler with EDM preconditioning, churn, CFG and clamping.

    Starting from ``noisy_inputs``, the sampler walks down the noise schedule
    produced by ``utils.noise_schedule``. Each transition optionally injects
    EDM-style churn noise, then takes one second-order single-step
    DPM-Solver++ update (midpoint ``r = 1/2``, data-prediction form), and then
    re-imposes observed values via ``_apply_observation_clamp``. A final
    denoiser call at the last schedule level produces the sample.
    """

    def __init__(
        self,
        denoiser: nnx.Module,
        sampler_config: SamplerConfig,
        rngs: nnx.Rngs,
    ):
        """Store the denoiser, configuration and PRNG streams.

        Args:
            denoiser: Network called as ``denoiser(c_in * x, sigma, forcings)``;
                its output is mixed with ``x`` via the EDM ``c_skip``/``c_out``
                coefficients.
            sampler_config: Schedule, churn and guidance settings.
            rngs: PRNG streams; the ``noise`` stream supplies churn noise.
        """
        self._denoiser = denoiser
        self.cfg = sampler_config
        self.rngs = rngs
        # sigma_data for the EDM preconditioning (1.0 and 0.20 were tried before).
        # NOTE(review): presumably the std of the training data; confirm.
        self.sigma_data = 0.275695

    @staticmethod
    def _bcast(v: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Append trailing singleton axes to ``v`` so a (B,) vector broadcasts against ``x``."""
        v = jnp.asarray(v)
        return jnp.reshape(v, v.shape + (1,) * (x.ndim - v.ndim))

    # EDM preconditioning helpers
    def _c_out(self, sigma):
        return self.sigma_data * sigma / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)

    def _c_skip(self, sigma):
        return self.sigma_data ** 2 / (self.sigma_data ** 2 + sigma ** 2)

    def _c_in(self, sigma):
        return 1.0 / jnp.sqrt(self.sigma_data ** 2 + sigma ** 2)

    def _sigma_schedule(self, num_steps: int, rho: float, sigma_max: float, sigma_min: float):
        """Return the noise levels from ``utils.noise_schedule`` minus its last entry.

        The dropped last entry is expected to be the appended zero level; the
        sampler handles the final step separately with an explicit denoise.
        """
        sigmas = utils.noise_schedule(
            max_noise_level=sigma_max,
            min_noise_level=sigma_min,
            num_noise_levels=num_steps,
            rho=rho,
        )
        return sigmas[:-1]

    def _preconditioned_denoiser(self, x: jnp.ndarray, sigma: jnp.ndarray, forcings: Dict[str, Any]):
        """Plain EDM denoiser ``c_skip * x + c_out * F(c_in * x)`` (no CFG, no clamping)."""
        net_out = self._denoiser(self._bcast(self._c_in(sigma), x) * x, sigma, forcings)
        return self._bcast(self._c_skip(sigma), x) * x + self._bcast(self._c_out(sigma), x) * net_out

    @staticmethod
    def _unconditional_forcings(forcings: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of ``forcings`` with observation entries zeroed (when present)."""
        uncond = dict(forcings)
        for key in ("obs_mask", "U_field_guiding"):
            if key in uncond:
                uncond[key] = jnp.zeros_like(jnp.asarray(uncond[key], jnp.float32))
        return uncond

    def _denoise_step(self, x, sigma, forcings):
        """Return the EDM-preconditioned denoised estimate ``D(x; sigma)``.

        The network is evaluated on the observation-clamped input (conditional
        branch). When ``guidance_scale != 1`` an unconditional branch is also
        evaluated on the unclamped input with the observation forcings zeroed,
        and the two network outputs are blended as
        ``uncond + s * (cond - uncond)``.

        Args:
            x: Current state at noise level ``sigma``.
            sigma: Per-batch noise levels, shape (B,).
            forcings: Conditioning dict forwarded to the denoiser.
        """
        # Project to boundary values BEFORE denoising (keeps inputs
        # in-distribution). The projection is provided by the denoiser's
        # `parent`, when it exists and exposes `project_boundaries`.
        model = getattr(self._denoiser, "parent", None)
        if model is not None and hasattr(model, "project_boundaries"):
            x = model.project_boundaries(x, forcings)

        c_in = self._bcast(self._c_in(sigma), x)
        c_out = self._bcast(self._c_out(sigma), x)
        c_skip = self._bcast(self._c_skip(sigma), x)

        scale = getattr(self.cfg, "guidance_scale", 1.0)
        if scale == 1.0:
            # Conditional path only; it sees the clamped observations.
            x_cond = self._apply_observation_clamp(x, forcings)
            eps = self._denoiser(c_in * x_cond, sigma, forcings)
        else:
            # Unconditional path: no observations and no clamp.
            eps_uncond = self._denoiser(c_in * x, sigma, self._unconditional_forcings(forcings))
            # Conditional path: inputs clamped to the observed values.
            x_cond = self._apply_observation_clamp(x, forcings)
            eps_cond = self._denoiser(c_in * x_cond, sigma, forcings)
            eps = eps_uncond + scale * (eps_cond - eps_uncond)

        # NOTE(review): the skip term uses the unclamped x even though the
        # conditional network saw x_cond; __call__ re-clamps after each update.
        return c_skip * x + c_out * eps

    def _apply_observation_clamp(self, x: jnp.ndarray, forcings: Dict[str, Any]) -> jnp.ndarray:
        """Blend ``x`` toward ``U_field_guiding`` where ``obs_mask`` is set.

        Returns ``x`` unchanged unless both ``obs_mask`` and ``U_field_guiding``
        are present in ``forcings``. The mask gets a trailing axis appended, so
        it is expected to have one fewer axis than ``x`` (e.g. (B, N) vs
        (B, N, C)).
        """
        obs_mask = forcings.get("obs_mask", None)
        obs_vals = forcings.get("U_field_guiding", None)
        if obs_mask is None or obs_vals is None:
            return x
        mask = jnp.asarray(obs_mask, jnp.float32)[..., None]
        vals = jnp.asarray(obs_vals, jnp.float32)
        return x * (1.0 - mask) + vals * mask

    def _churn(self, x, sigma_cur):
        """Apply EDM stochastic churn; return ``(x_hat, sigma_hat)`` with ``sigma_hat >= sigma_cur``."""
        cfg = self.cfg
        gamma = 0.0
        if cfg.s_churn > 0.0:
            in_range = (sigma_cur >= cfg.s_min) & (sigma_cur <= cfg.s_max)
            gamma = jnp.where(in_range, min(cfg.s_churn / cfg.num_steps, 2**0.5 - 1), 0.0)
        sigma_hat = sigma_cur * (1 + gamma)

        # A key is drawn every step, even when gamma == 0 (then the added term
        # is exactly zero), so the `noise` stream always advances the same way.
        key = self.rngs.noise()
        eps = jax.random.normal(key, shape=x.shape) * cfg.s_noise
        x = x + eps * jnp.sqrt(sigma_hat**2 - sigma_cur**2)
        return x, sigma_hat

    def _dpmpp_2s_step(self, x, sigma_from, sigma_to, forcings):
        """One DPM-Solver++(2S) step (r = 1/2) from ``sigma_from`` to ``sigma_to``.

        With alpha = 1 (sigma parameterisation) and lambda = -log(sigma), the
        data-prediction update reduces to ratios of noise levels:

            sigma_s = sqrt(sigma * sigma_next)            (log-space midpoint)
            x_s     = (sigma_s / sigma) x + (1 - sigma_s / sigma) D(x, sigma)
            x_next  = (sigma_next / sigma) x + (1 - sigma_next / sigma) D(x_s, sigma_s)

        Args:
            x: State at noise level ``sigma_from``.
            sigma_from: Per-batch current noise levels, shape (B,).
            sigma_to: Per-batch target noise levels, shape (B,).
            forcings: Conditioning dict forwarded to the denoiser.
        """
        denoised = self._denoise_step(x, sigma_from, forcings)

        sigma_mid = jnp.sqrt(sigma_from * sigma_to)
        ratio_mid = self._bcast(sigma_mid / sigma_from, x)
        x_mid = ratio_mid * x + (1.0 - ratio_mid) * denoised
        denoised_mid = self._denoise_step(x_mid, sigma_mid, forcings)

        ratio_to = self._bcast(sigma_to / sigma_from, x)
        return ratio_to * x + (1.0 - ratio_to) * denoised_mid

    def __call__(
        self,
        noisy_inputs: jnp.ndarray,
        forcings: Dict[str, Any],
        rngs: Optional[nnx.Rngs] = None,
    ) -> jnp.ndarray:
        """Draw a sample by integrating from the first to the last schedule level.

        Args:
            noisy_inputs: Starting state with batch on axis 0. NOTE(review):
                presumably already at the first schedule level
                (~ sigma_max * N(0, I)); confirm with the caller.
            forcings: Conditioning dict for the denoiser; optional keys
                ``obs_mask`` / ``U_field_guiding`` drive observation clamping.
            rngs: If given, replaces ``self.rngs`` before sampling.

        Returns:
            The final denoised sample, clamped to observed values where available.
        """
        if rngs is not None:
            self.rngs = rngs

        cfg = self.cfg
        batch = noisy_inputs.shape[0]
        sigmas = self._sigma_schedule(cfg.num_steps, cfg.rho, cfg.sigma_max, cfg.sigma_min)

        def per_batch(level):
            return level * jnp.ones((batch,), jnp.float32)

        x = noisy_inputs
        # Iterate over the schedule actually returned; JAX would silently clamp
        # out-of-range indices if its length differed from num_steps.
        for i in range(len(sigmas) - 1):
            x, sigma_hat = self._churn(x, sigmas[i])
            x = self._dpmpp_2s_step(x, per_batch(sigma_hat), per_batch(sigmas[i + 1]), forcings)
            # Re-impose observed values after every update.
            x = self._apply_observation_clamp(x, forcings)

        # Final denoise at the last schedule level, then a safety clamp.
        x = self._denoise_step(x, per_batch(sigmas[-1]), forcings)
        return self._apply_observation_clamp(x, forcings)

    # def sample_with_frames(self, noisy_inputs: jnp.ndarray, forcings: Dict[str, Any]) -> jnp.ndarray:
    #     """
    #     Returns all frames: (T, B, N, C) where T = num_steps (+1 if you want include the initial noise).
    #     Uses nnx.fori_loop under the hood for speed.
    #     """
    #     cfg = self.cfg
    #     B, N, C = noisy_inputs.shape
    #     sigmas = self._sigma_schedule(cfg.num_steps, cfg.rho, cfg.sigma_max, cfg.sigma_min)  # (K,), K=num_steps-1+1?
    #     K = sigmas.shape[0]

    #     def init_state():
    #         x0 = self._apply_observation_clamp(noisy_inputs, forcings)  # (B,N,C)
    #         # Pre-allocate (K,B,N,C); we’ll fill step-by-step
    #         frames = jnp.zeros((K, B, N, C), dtype=x0.dtype).at[0].set(x0)
    #         return (0, x0, frames)

    #     def body_fun(i, carry):
    #         # i: 0..K-2 perform 2S updates; last index already filled
    #         idx, x, frames = carry
    #         sigma_i = sigmas[i]   * jnp.ones((B,), jnp.float32)
    #         sigma_j = sigmas[i+1] * jnp.ones((B,), jnp.float32)

    #         # clamp before denoise
    #         x = self._apply_observation_clamp(x, forcings)
    #         D_i = self._denoise_step(x, sigma_i, forcings)

    #         h = jnp.log(sigma_j) - jnp.log(sigma_i)
    #         x_mid = x + h[:, None, None] * (D_i - x) / (jnp.log(self.sigma_data + sigma_i) - jnp.log(self.sigma_data))
    #         x_mid = self._apply_observation_clamp(x_mid, forcings)
    #         D_mid = self._denoise_step(x_mid, sigma_j, forcings)

    #         x_next = x + 0.5 * (D_i + D_mid - 2 * x)
    #         x_next = self._apply_observation_clamp(x_next, forcings)

    #         frames = frames.at[i+1].set(x_next)
    #         return (i+1, x_next, frames)

    #     # Run loop
    #     _, _, out_frames = nnx.fori_loop(0, K-1, body_fun, init_state())
    #     # Optionally do a final refine; here sigmas already includes the last level, so K frames recorded.
    #     return out_frames  # (K,B,N,C)
