#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
gen_da.py

"""

from typing import Any, Dict, Optional, Tuple

import chex
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import numpy as np

from gen_da.denoiser import (
    Denoiser,
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,
)
from gen_da.dpm_solver_plus_plus_2s import Sampler
import gen_da.samplers_utils as utils



# @chex.dataclass(frozen=True, eq=True)
# class SamplerConfig:
#     """
#     Configures the sampler used to draw samples.
#     max_noise_level: The highest noise level used at the start of the
#     sequence of reverse diffusion steps.
#     min_noise_level: The lowest noise level used at the end of the sequence of
#     reverse diffusion steps.
#     num_noise_levels: Determines the number of noise levels used and hence the
#     number of reverse diffusion steps performed.
#     rho: Parameter affecting the spacing of noise steps. Higher values will
#     concentrate noise steps more around zero.
#     stochastic_churn_rate: S_churn from the paper. This controls the rate
#     at which noise is re-injected/'churned' during the sampling algorithm.
#     If this is set to zero then we are performing deterministic sampling
#     as described in Algorithm 1.
#     churn_max_noise_level: Maximum noise level at which stochastic churn
#     occurs. S_max from the paper. Only used if stochastic_churn_rate > 0.
#     churn_min_noise_level: Minimum noise level at which stochastic churn
#     occurs. S_min from the paper. Only used if stochastic_churn_rate > 0.
#     noise_level_inflation_factor: This can be used to set the actual amount of
#     noise injected higher than what the denoiser is told has been added.
#     The motivation is to compensate for a tendency of L2-trained denoisers
#     to remove slightly too much noise / blur too much. S_noise from the
#     paper. Only used if stochastic_churn_rate > 0.
#     guidance_scale: Classifier-free guidance scale used at sampling time.
#     """
#     max_noise_level: float = 80.0
#     min_noise_level: float = 0.03
#     num_noise_levels: int = 20
#     rho: float = 7.
#     # Stochastic sampler settings.
#     # stochastic_churn_rate: float = 40.0
#     stochastic_churn_rate: float = 2.5
#     churn_min_noise_level: float = 0.75
#     churn_max_noise_level: float = float('inf')
#     noise_level_inflation_factor: float = 1.05
@chex.dataclass
class SamplerConfig:
    """Settings passed to the ``Sampler`` used to draw samples.

    NOTE(review): the per-field descriptions below are taken from the
    commented-out legacy config kept directly above this class (same default
    values under longer names). Confirm them against
    ``gen_da.dpm_solver_plus_plus_2s.Sampler``, which is what actually
    consumes these values.
    """
    # Spacing of the noise steps; per the legacy note, higher values
    # concentrate noise steps more around zero.
    rho: float = 7.0
    # Number of noise levels, i.e. reverse-diffusion steps (legacy name:
    # num_noise_levels).
    num_steps: int = 20
    # Lowest / highest noise level of the schedule. The legacy defaults were
    # 0.03 and 80.0; presumably the (0.275695 / 0.5) factor rescales them to
    # this dataset's statistics — TODO confirm.
    sigma_min: float = 0.03 * (0.275695 / 0.5)
    sigma_max: float = 80.0 * (0.275695 / 0.5)
    # Stochastic churn rate (legacy name: stochastic_churn_rate). The default
    # is non-zero, so churn is ENABLED by default; per the legacy note a value
    # of zero gives deterministic sampling.
    s_churn: float = 2.5
    # Noise-level window inside which churn is applied (legacy names:
    # churn_min_noise_level / churn_max_noise_level).
    s_min: float = 0.75
    s_max: float = float('inf')
    # Inflation factor for the re-injected noise (legacy name:
    # noise_level_inflation_factor).
    s_noise: float = 1.05
    # Classifier-free guidance scale at sampling. 0=uncond, 1=conditional only (default), >1 to push towards cond
    guidance_scale: float = 1.0

# Training-time noise-level sampling and classifier-free-guidance dropout settings.
@chex.dataclass(frozen=True, eq=True)
class NoiseConfig:
    """Noise settings read by ``GenSynth.loss`` during training.

    The three ``training_*`` fields are forwarded to
    ``utils.rho_inverse_cdf`` as ``rho``, ``max_value`` and ``min_value`` to
    draw one noise level per sample.
    """
    training_noise_level_rho: float = 7.0
    # NOTE(review): presumably 88.0 / 0.02 rescaled by the same
    # (0.275695 / 0.5) factor used in SamplerConfig — confirm the origin of
    # that factor.
    training_max_noise_level: float = 88.0 * (0.275695 / 0.5)
    training_min_noise_level: float = 0.02 * (0.275695 / 0.5)
    # Probability to drop observations (set obs_mask and U_field_guiding to zero) for CFG training
    cfg_dropout_prob: float = 0.10


@chex.dataclass(frozen=True, eq=True)
class LossConfig:
    """Options for the graph gradient-weighted MSE ("Graph-GMSE") weighting.

    These values are read by ``GenSynth._gmse_weights_for_batch``.
    NOTE(review): in ``GenSynth.loss`` the branch that would apply those
    weights is commented out, so ``use_graph_gmse`` is not consulted there.
    """
    # Turn on/off gradient-weighted MSE.
    use_graph_gmse: bool = True
    # Channel reduction mode; only "mean" is mentioned. This field is not
    # referenced elsewhere in this file.
    base_channel_reduction: str = "mean"
    # Graph-GMSE params (analogs of sigma, gamma, Co in the paper)
    # Number of graph-blur iterations; more iterations ≈ stronger blur.
    blur_iters: int = 4
    # Blend factor per blur iteration, 0<alpha<1; 1.0 ≈ full neighbor mean,
    # 0.0 ≈ no blur.
    blur_alpha: float = 0.7
    # Contrast exponent applied to the blurred weights.
    gamma: float = 1.0
    # Non-zero lower bound of the weights, Co in [0,1].
    offset_co: float = 0.2
    # Gradient estimation
    # Weight from ||U|| instead of per-component differences.
    use_speed_magnitude: bool = True
    # Divide edge differences by the edge length.
    normalize_by_edge_length: bool = True
    # Boundary handling
    # Clamp non-fluid (boundary) nodes to the baseline Co.
    floor_on_boundaries: bool = True



@chex.dataclass(frozen=True, eq=True)
class CheckPoint:
    """Immutable bundle of model parameters plus the configs stored with them.

    NOTE(review): nothing in this file constructs or reads a ``CheckPoint``;
    the field roles below follow from their names and types only — confirm
    against the checkpoint save/load code.
    """
    # Free-text metadata.
    description: str
    license: str
    # Parameter tree keyed by name.
    params: dict[str, Any]
    # Configs of the same types that GenSynth.__init__ accepts.
    denoiser_architecture_config: DenoiserArchitectureConfig
    sampler_config: SamplerConfig
    noise_config: NoiseConfig
    noise_encoder_config: NoiseEncoderConfig
    angle_encoder_config: AngleEncoderConfig


def mse_loss(
    pred: jnp.ndarray,
    tgt: jnp.ndarray
) -> jnp.ndarray:
    """
    Per-sample mean squared error between a prediction and its target.

    Args:
        pred: Predicted values, shape (batch, ..., channels).
        tgt: Target values, broadcast-compatible with ``pred``.

    Returns:
        The squared difference averaged over every axis except the leading
        (batch) one, i.e. an array of shape (batch,).
    """
    squared_error = (pred - tgt) ** 2

    # Reduce everything but axis 0 so one value per sample remains.
    reduce_axes = tuple(range(1, squared_error.ndim))
    return squared_error.mean(axis=reduce_axes)




class GenSynth(nnx.Module):
    """GenSynth diffusion model with eager-initialized multigraph denoiser."""

    def __init__(
        self,
        denoiser_architecture_config: DenoiserArchitectureConfig,
        sampler_config: SamplerConfig,
        noise_config: NoiseConfig,
        mesh,
        rngs: nnx.Rngs,
        *,
        # New required args for eager init:
        example_graph_structures: Dict[str, jnp.ndarray],
        target_channels: int,
        guiding_channels: int,
        # Optional encoders / extras:
        noise_encoder_config: Optional[NoiseEncoderConfig] = None,
        angle_encoder_config: Optional[AngleEncoderConfig] = None,
        precomputed_adj_mat: Optional[Any] = None,
        loss_config: Optional['LossConfig'] = None,
    ):
        """Build the denoiser and the sampler and store the training configs.

        Args:
            denoiser_architecture_config: Forwarded to ``Denoiser``.
            sampler_config: Forwarded to ``Sampler``.
            noise_config: Stored for ``loss``; ``loss`` raises if it is None.
            mesh: Forwarded to ``Denoiser``.
            rngs: Stored on the model and also passed to both ``Denoiser``
                and ``Sampler``.
            example_graph_structures: Forwarded to ``Denoiser``.
            target_channels: Forwarded to ``Denoiser``.
            guiding_channels: Forwarded to ``Denoiser``.
            noise_encoder_config: Optional, forwarded to ``Denoiser``.
            angle_encoder_config: Optional, forwarded to ``Denoiser``.
            precomputed_adj_mat: Optional, forwarded to ``Denoiser``.
            loss_config: Optional; a default ``LossConfig()`` is used when
                omitted.
        """
        self.rngs = rngs
        self._noise_config = noise_config
        # Falls back to the default loss settings when no config is supplied.
        self._loss_cfg = loss_config or LossConfig()

        # Build denoiser (all nets init here)
        self.denoiser = Denoiser(
            noise_encoder_config=noise_encoder_config,
            angle_encoder_config=angle_encoder_config,
            denoiser_architecture_config=denoiser_architecture_config,
            rngs=rngs,
            mesh=mesh,
            precomputed_adj_mat=precomputed_adj_mat,
            example_graph_structures=example_graph_structures,
            target_channels=target_channels,
            guiding_channels=guiding_channels,
        )

        # Sampler (DPM-Solver++ 2S, per the imported module name). It takes
        # the config object directly; an earlier variant that passed each
        # setting as a separate keyword (max_noise_level, min_noise_level,
        # num_noise_levels, rho, stochastic_churn_rate, churn_min_noise_level,
        # churn_max_noise_level, noise_level_inflation_factor) is superseded
        # by this call.
        self._sampler = Sampler(
            denoiser=self.denoiser,
            sampler_config=sampler_config,
            rngs=rngs,
        )

        # Constants for EDM preconditioning; sigma_data is taken from the
        # sampler so that training and sampling use the same value.
        self.sigma_data = self._sampler.sigma_data
        

    # preconditioning helpers
    def _c_in(self, sigma: jnp.ndarray) -> jnp.ndarray:
        """Input scaling factor: (sigma^2 + sigma_data^2) ** -0.5."""
        total_variance = sigma**2 + self.sigma_data**2
        return total_variance ** -0.5

    def _c_out(self, sigma: jnp.ndarray) -> jnp.ndarray:
        """Output scaling factor: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)."""
        numerator = sigma*self.sigma_data
        denominator = (sigma**2 + self.sigma_data**2) ** 0.5
        return numerator / denominator

    def _c_skip(self, sigma: jnp.ndarray) -> jnp.ndarray:
        """Skip-connection factor: sigma_data^2 / (sigma^2 + sigma_data^2)."""
        data_variance = self.sigma_data**2
        return data_variance / (sigma**2 + data_variance)

    def _loss_weighting(self, sigma: jnp.ndarray) -> jnp.ndarray:
        """Per-noise-level loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
        total_variance = sigma**2 + self.sigma_data**2
        squared_product = (sigma*self.sigma_data)**2
        return total_variance / squared_product

    def _get_fluid_mask(self, forcings: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Build a boolean mask that is True where a node has type 1 (fluid).

        ``forcings["graph_structures"]`` may be a single dict or a list/tuple
        of dicts; in the latter case only the first entry is read (the
        original note says graphs in a batch share one structure when
        z-bucketed). Any other container type raises ``ValueError``.

        Returns a mask with one entry per node, shape (N_nodes,).
        """
        gstructs = forcings["graph_structures"]

        if isinstance(gstructs, dict):
            graph = gstructs
        elif isinstance(gstructs, (list, tuple)):
            graph = gstructs[0]
        else:
            raise ValueError(f"Unexpected graph_structures type: {type(gstructs)}")

        # Types 2 and 3 (boundary nodes) compare unequal to 1 and become False.
        return jnp.asarray(graph["node_types"], jnp.int32) == 1

    def _apply_boundary_constraints(
        self,
        noisy_inputs: jnp.ndarray,
        clean_inputs: jnp.ndarray,
        forcings: Dict[str, jnp.ndarray],
    ) -> jnp.ndarray:
        """
        Keep noise on fluid nodes only.

        Fluid nodes take their value from ``noisy_inputs``; every other node
        takes its value from ``clean_inputs``.
        """
        # (N_nodes,) -> (1, N_nodes, 1) so the mask broadcasts over the
        # leading and trailing axes of the inputs.
        is_fluid = self._get_fluid_mask(forcings)[None, :, None]
        return jnp.where(is_fluid, noisy_inputs, clean_inputs)


    def _get_masks_all(self, forcings: Dict[str, jnp.ndarray]):
        """Return ``(graph, fluid, wall, bnd)``: the graph dict plus one
        boolean (N,) mask per node type (1, 2 and 3 respectively).

        When ``graph_structures`` is a list/tuple, its first entry is used.
        """
        graph = forcings["graph_structures"]
        if isinstance(graph, (list, tuple)):
            graph = graph[0]

        node_types = jnp.asarray(graph["node_types"], jnp.int32)  # (N,)
        fluid, wall, bnd = tuple(node_types == code for code in (1, 2, 3))
        return graph, fluid, wall, bnd

    # gen_da.py
    def _get_boundary_mask(self, forcings):
        """Return a boolean (N,) mask that is True for every non-fluid node.

        Node types are 1 = fluid, 2 = wall, 3 = boundary; everything that is
        not type 1 is treated as boundary here. When ``graph_structures`` is
        a list/tuple, its first entry is used.
        """
        graph = forcings["graph_structures"]
        if isinstance(graph, (list, tuple)):
            graph = graph[0]
        node_types = jnp.asarray(graph["node_types"], jnp.int32)
        return node_types != 1

    def project_boundaries(self, x, forcings):
        """Overwrite the non-fluid nodes of ``x`` (B, N, C) with boundary values.

        The replacement comes from ``forcings["boundary_values"]`` (cast to
        ``x.dtype``) when that key exists; otherwise zeros are used so the
        result never contains NaNs from a missing source.
        """
        non_fluid = self._get_boundary_mask(forcings)[None, :, None]  # (1,N,1)

        if "boundary_values" not in forcings:
            replacement = jnp.zeros_like(x)
        else:
            replacement = jnp.asarray(forcings["boundary_values"], x.dtype)  # (B,N,C)

        return jnp.where(non_fluid, replacement, x)


    def _preconditioned_denoiser(
        self,
        noisy_inputs: jnp.ndarray,
        noise_levels: jnp.ndarray,
        forcings: Dict[str, jnp.ndarray],
    ) -> jnp.ndarray:
        """Run the denoiser with input scaling, output scaling and a skip term.

        Computes ``denoiser(c_in * x) * c_out + x * c_skip``, where the three
        coefficients depend on ``noise_levels`` only.
        """
        # Per-sample coefficients (batch,) are reshaped to (batch, 1, ..., 1)
        # so they broadcast against the inputs.
        per_sample_shape = (-1,) + (1,) * (noisy_inputs.ndim - 1)

        scaled_inputs = noisy_inputs * self._c_in(noise_levels).reshape(per_sample_shape)
        raw = self.denoiser(noisy_inputs=scaled_inputs, noise_levels=noise_levels, forcings=forcings)

        out_scale = self._c_out(noise_levels).reshape(per_sample_shape)
        skip_scale = self._c_skip(noise_levels).reshape(per_sample_shape)

        return raw * out_scale + noisy_inputs * skip_scale



    def loss(
        self,
        clean_inputs: jnp.ndarray,
        forcings: Dict[str, jnp.ndarray],
    ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
        """Compute the diffusion loss for training with per-channel balancing.

        Steps: draw one noise level per sample from the rho schedule, add
        Gaussian noise (non-fluid nodes keep their clean values), apply
        classifier-free-guidance (CFG) dropout to the observation inputs,
        clamp the still-observed nodes to their observed values, run the
        preconditioned denoiser, then average the squared error over nodes
        and equally over channels.

        Observation clamping happens exactly once, AFTER the CFG dropout
        decision. Clamping before the dropout would leave the observed values
        inside the noisy input of dropped samples, so the supposedly
        unconditional branch would still see the observations.

        Args:
            clean_inputs: Clean targets, unpacked as (B, N, C).
            forcings: Conditioning dict. ``graph_structures`` is read by the
                boundary helper; ``obs_mask`` (B, N) and ``U_field_guiding``
                (B, N, C) are used when present. The dict is not mutated; a
                shallow copy with the post-dropout fields is passed on.

        Returns:
            ``(total_loss, aux)`` where ``aux`` holds ``mse_per_sample``
            (B,), ``mse_per_channel`` (C,), ``sigma`` (B,) and ``total_loss``.

        Raises:
            ValueError: If the model was built without a ``NoiseConfig``.
        """
        if self._noise_config is None:
            raise ValueError("Missing NoiseConfig for training.")

        B, N, _ = clean_inputs.shape

        # One independent key each for: noise level, Gaussian noise, CFG dropout.
        key_sigma, key_noise, key_drop = jax.random.split(self.rngs.noise(), 3)

        # --- Sample noise levels (EDM rho schedule) ---
        uniform_samples = jax.random.uniform(key_sigma, (B,))
        sigma = utils.rho_inverse_cdf(
            min_value=self._noise_config.training_min_noise_level,
            max_value=self._noise_config.training_max_noise_level,
            rho=self._noise_config.training_noise_level_rho,
            cdf=uniform_samples,
        )

        # --- Noisify clean inputs ---
        noise = jax.random.normal(key_noise, clean_inputs.shape)
        sigma_expanded = sigma.reshape((-1,) + (1,) * (clean_inputs.ndim - 1))
        noisy_unconstrained = clean_inputs + (noise * sigma_expanded)

        # Non-fluid nodes keep their clean values; only fluid nodes are noised.
        noisy = self._apply_boundary_constraints(
            noisy_inputs=noisy_unconstrained,
            clean_inputs=clean_inputs,
            forcings=forcings,
        )

        # A missing observation mask means "nothing observed".
        obs_mask = forcings.get('obs_mask', None)  # (B,N)
        obs_mask = jnp.zeros((B, N), jnp.float32) if obs_mask is None else jnp.asarray(obs_mask, jnp.float32)

        # -----------------------------
        # Classifier-free guidance dropout (training)
        # with probability p, drop observations for that sample
        # -----------------------------
        p = jnp.asarray(self._noise_config.cfg_dropout_prob, jnp.float32)
        # bernoulli per-sample: 1 means DROP (unconditional)
        drop_flags = jax.random.bernoulli(key_drop, p=p, shape=(B,)).astype(jnp.float32)  # (B,)
        keep_b1 = 1.0 - drop_flags[:, None]         # (B,1)
        keep_b11 = 1.0 - drop_flags[:, None, None]  # (B,1,1)

        # Observation inputs with dropped samples zeroed out.
        obs_mask_eff = obs_mask * keep_b1
        if 'U_field_guiding' in forcings:
            u_guid = jnp.asarray(forcings['U_field_guiding'], jnp.float32) * keep_b11
        else:
            u_guid = None

        # HARD CLAMP at observed nodes, using the post-dropout mask so that
        # dropped samples keep their plain noisy values (no observation leak).
        if u_guid is not None:
            clamp = obs_mask_eff[..., None]  # (B,N,1)
            noisy = noisy * (1.0 - clamp) + u_guid * clamp

        # Shallow copy of forcings with the updated fields; other entries
        # (like graph_structures) are passed through unchanged.
        forcings_mod = dict(forcings)
        forcings_mod['obs_mask'] = obs_mask_eff
        if u_guid is not None:
            forcings_mod['U_field_guiding'] = u_guid

        # --- Predict denoised target (preconditioned head handles sigma) ---
        pred = self._preconditioned_denoiser(
            noisy_inputs=noisy,
            noise_levels=sigma,
            forcings=forcings_mod,
        )

        # =======================
        # Channel-balanced MSE
        # =======================
        # Graph-GMSE weighting (per-node weights from
        # self._gmse_weights_for_batch, broadcast as W[None, :, None]) is
        # currently disabled here; the plain squared error is used.
        se = (pred - clean_inputs) ** 2  # (B, N, C)

        # Mean over spatial dims first (keep batch and channel)
        spatial_axes = tuple(range(1, se.ndim - 1))
        se_mean_spatial = jnp.mean(se, axis=spatial_axes)   # (B, C)

        # Equal-weight channels
        mse_per_sample = jnp.mean(se_mean_spatial, axis=-1)  # (B,)

        # EDM weighting by noise level
        weighted = mse_per_sample * self._loss_weighting(sigma)  # (B,)
        total_loss = jnp.mean(weighted)

        # For logging: per-channel mean over the batch.
        per_channel_mse = jnp.mean(se_mean_spatial, axis=0)  # (C,)

        return total_loss, {
            "mse_per_sample": mse_per_sample,    # (B,)
            "mse_per_channel": per_channel_mse,  # (C,)
            "sigma": sigma,                      # (B,)
            "total_loss": total_loss,
        }


    # inference
    def __call__(
        self,
        noisy_inputs: jnp.ndarray,
        noise_levels: jnp.ndarray,
        forcings: Dict[str, jnp.ndarray],
        **kwargs
    ) -> jnp.ndarray:
        """Call the underlying denoiser directly and return its output.

        Note that this does NOT apply the c_in / c_out / c_skip
        preconditioning; that is done by ``_preconditioned_denoiser``.
        Extra keyword arguments are forwarded to the denoiser unchanged.
        """
        return self.denoiser(
            noisy_inputs=noisy_inputs,
            noise_levels=noise_levels,
            forcings=forcings,
            **kwargs
        )
    
    def full_sampling(
        self,
        noisy_inputs: jnp.ndarray,
        forcings: Dict[str, jnp.ndarray],
    ) -> jnp.ndarray:
        """Full sampling procedure, returning JAX array.

        Delegates to the sampler built in ``__init__``, passing the model's
        own ``rngs``. The sampler output is returned as-is.
        """
        pred = self._sampler(noisy_inputs=noisy_inputs, forcings=forcings, rngs=self.rngs)
        # De-normalization is intentionally not applied here. A disabled
        # variant added forcings["case_center"] to the prediction and raised
        # ValueError("forcings must contain 'case_center' for
        # de-normalization.") when the key was missing.
        return pred


    def _get_oo_edges(self, forcings):
        """Return ``(senders, receivers, edge_lengths)`` for the o2o edges.

        Reads ``o2o_senders`` / ``o2o_receivers`` (cast to int32) from the
        graph structure; when ``graph_structures`` is a list/tuple its first
        entry is used.

        Edge lengths are ``|o2o_features[:, 2]|`` when the optional
        ``o2o_features`` entry exists, is at least 2-D and has at least three
        feature columns (expected layout after normalization: [dx, dy, d]).
        Otherwise every edge gets unit length.

        The presence check is done on the raw value BEFORE converting to an
        array: ``jnp.asarray(None)`` never yields ``None``, so checking after
        the conversion would make the unit-length fallback unreachable.
        """
        g = forcings["graph_structures"]
        g = g[0] if isinstance(g, (list, tuple)) else g
        send = jnp.asarray(g["o2o_senders"], jnp.int32)
        recv = jnp.asarray(g["o2o_receivers"], jnp.int32)

        raw_feats = g.get("o2o_features")
        if raw_feats is not None:
            feats = jnp.asarray(raw_feats, jnp.float32)
            if feats.ndim >= 2 and feats.shape[-1] >= 3:
                # Third column holds the (already normalized) edge length.
                return send, recv, jnp.abs(feats[:, 2])

        # Features missing or too narrow: fall back to unit length to be safe.
        return send, recv, jnp.ones((send.shape[0],), jnp.float32)

    def _node_types_mask(self, forcings):
        """Return ``(fluid_mask, non_fluid_mask)``, two boolean (N,) arrays.

        Node types are {1=fluid, 2=wall, 3=boundary}; the second mask is True
        for every type other than 1. When ``graph_structures`` is a
        list/tuple, its first entry is used.
        """
        graph = forcings["graph_structures"]
        if isinstance(graph, (list, tuple)):
            graph = graph[0]
        node_types = jnp.asarray(graph["node_types"], jnp.int32)
        return (node_types == 1), (node_types != 1)

    def _graph_neighbor_mean(self, x, send, recv):
        """Average, for each node, the values of ``x`` at its incoming senders.

        ``x`` holds one scalar per node, shape (N,). A node with no incoming
        edge gets 0 (its sum is 0 and the divisor is floored at 1).
        """
        node_zeros = jnp.zeros((x.shape[0],), x.dtype)
        in_degree = node_zeros.at[recv].add(1.0)
        incoming_sum = node_zeros.at[recv].add(x[send])
        return incoming_sum / jnp.maximum(in_degree, 1.0)

    def _graph_blur(self, w, send, recv, iters, alpha):
        """Smooth per-node values by repeated blending with the neighbor mean.

        Each pass applies ``w <- (1 - alpha) * w + alpha * mean(neighbors)``.
        ``iters`` is truncated to an int; zero or negative means no pass.
        """
        remaining = max(0, int(iters))
        while remaining > 0:
            neighbor_avg = self._graph_neighbor_mean(w, send, recv)
            w = (1.0 - alpha) * w + alpha * neighbor_avg
            remaining -= 1
        return w

    def _gmse_weights_for_batch(self, clean_inputs: jnp.ndarray, forcings: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Build per-node loss weights W of shape (N,) from the ground-truth flow.

        Only the FIRST sample of the batch is used (all samples are taken to
        share the same graph). Pipeline: node-wise mean of absolute edge
        differences -> graph blur -> gamma -> min-max to [0, 1] -> offset to
        [Co, 1] -> optional floor of non-fluid nodes at Co -> division by the
        mean, so the returned weights average to ~1 rather than lying in
        [Co, 1].
        """
        cfg = self._loss_cfg
        eps = 1e-6

        U = clean_inputs[0]  # (N, C)
        num_nodes = U.shape[0]
        send, recv, d_ij = self._get_oo_edges(forcings)

        # In-degree per node; shared by every scatter-mean below.
        node_zeros = jnp.zeros((num_nodes,), U.dtype)
        deg = node_zeros.at[recv].add(1.0)

        def node_mean_edge_diff(values):
            # |v_recv - v_send| per edge (optionally divided by edge length),
            # averaged over the edges arriving at each node.
            edge_diff = jnp.abs(values[recv] - values[send])
            if cfg.normalize_by_edge_length:
                edge_diff = edge_diff / (d_ij + eps)
            return node_zeros.at[recv].add(edge_diff) / jnp.maximum(deg, 1.0)

        if cfg.use_speed_magnitude:
            # Single scalar field: the vector norm at each node.
            node_disp = node_mean_edge_diff(jnp.linalg.norm(U, axis=-1))
        else:
            # One field per component, combined as a root-sum-of-squares.
            per_component = [node_mean_edge_diff(U[:, c]) for c in range(U.shape[-1])]
            node_disp = jnp.sqrt(jnp.sum(jnp.stack(per_component, axis=-1)**2, axis=-1))

        # Graph "Gaussian" blur (diffusion)
        w = self._graph_blur(node_disp, send, recv, cfg.blur_iters, cfg.blur_alpha)

        # Order per the paper: blur -> gamma -> normalize -> offset
        if cfg.gamma != 1.0:
            w = jnp.power(jnp.maximum(w, 0.0), cfg.gamma)
        lowest, highest = jnp.min(w), jnp.max(w)
        w = (w - lowest) / jnp.maximum(highest - lowest, 1e-6)  # to [0,1]
        w = (1.0 - cfg.offset_co) * w + cfg.offset_co            # to [Co,1]

        # Keep boundary nodes from dominating: pin them to the baseline Co.
        if cfg.floor_on_boundaries:
            _, boundary_mask = self._node_types_mask(forcings)
            w = jnp.where(boundary_mask, jnp.asarray(cfg.offset_co, w.dtype), w)

        # Rescale to mean ~1 so the loss magnitude stays comparable.
        return w / jnp.maximum(jnp.mean(w), 1e-6)  # (N,)

