import os
import re
import json
import click
import tqdm
import time
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
import h5py
import scipy.io
import random
import gc
from torch_utils import misc
from training import utils
from einops import rearrange
import matplotlib.pyplot as plt
import torch.nn.functional as F
from training.datasets import *
from torch_utils import distributed as dist
from torchvision.transforms import ToPILImage
from training import evaluation_utils
from training.evaluation_utils import calculate_metrics, save_samples_to_pdf, compute_coefficient_error_rate, plot_process, plot_dps_losses 
from scipy.interpolate import Rbf

#----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2), deterministic variant (noise churn disabled).

def deterministic_edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Deterministic 2nd-order Heun sampler (EDM Algorithm 2 with no churn).

    No noise is re-injected between steps, so the result depends only on
    ``latents`` (and the network). ``randn_like`` and the ``S_*`` arguments are
    accepted for signature compatibility with the stochastic sampler but are
    not used.

    Args:
        net: Denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            ``net(x, sigma_labels, class_labels)``; ``sigma_labels`` holds one
            noise level per batch element.
        latents: Initial noise; scaled by the largest noise level to form x_0.
        class_labels: Passed through unchanged to every network call.
        num_steps: Number of sampling steps (network is called 2*num_steps-1 times).
        sigma_min, sigma_max, rho: Parameters of the rho-spaced noise schedule,
            clamped to the range the network supports.

    Returns:
        The float64 sample after the final step (noise level 0).
    """
    # Clamp the requested noise range to what the network supports.
    lo = max(sigma_min, net.sigma_min)
    hi = min(sigma_max, net.sigma_max)

    # rho-spaced noise levels from hi down to lo, then a terminal level of 0.
    ramp = torch.arange(num_steps, dtype=torch.float64, device=latents.device) / (num_steps - 1)
    hi_root, lo_root = hi ** (1 / rho), lo ** (1 / rho)
    sigmas = (hi_root + ramp * (lo_root - hi_root)) ** rho
    sigmas = torch.cat([net.round_sigma(sigmas), torch.zeros_like(sigmas[:1])])

    def denoise(x, sigma):
        # Broadcast the scalar noise level to one label per batch element.
        labels = sigma.expand(x.shape[0]).to(device=x.device, dtype=x.dtype)
        return net(x, labels, class_labels).to(torch.float64)

    last_step = num_steps - 1
    x = latents.to(torch.float64) * sigmas[0]
    for step, (s_cur, s_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):
        # Euler predictor.
        slope = (x - denoise(x, s_cur)) / s_cur
        x_euler = x + (s_next - s_cur) * slope
        if step == last_step:
            # s_next is 0 here, so no corrector evaluation is possible.
            x = x_euler
            continue
        # Heun corrector: average the slopes at both ends of the step.
        slope_next = (x_euler - denoise(x_euler, s_next)) / s_next
        x = x + (s_next - s_cur) * (0.5 * slope + 0.5 * slope_next)
    return x

#----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper. Adjusted to be deterministic (noise churn disabled), based on Chris's cond diffusion repo. #TODO

def deterministic_ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Deterministic version of the EDM generalized ablation sampler.

    Combines a choice of ODE solver ('euler' or 'heun'), time-step
    discretization ('vp', 've', 'iddpm', 'edm'), noise schedule sigma(t)
    ('vp', 've', 'linear') and signal scaling s(t) ('vp', 'none'). The
    stochastic "churn" step is commented out below, so ``t_hat == t_cur`` and
    ``x_hat == x_cur`` on every step; ``randn_like`` and the ``S_*`` arguments
    are accepted but unused.

    Args:
        net: Denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            ``net(x, sigma, class_labels)``. NOTE(review): ``sigma`` is passed
            as a single 0-dim tensor here, not expanded per batch element as in
            the other samplers in this file -- confirm the network accepts it.
        latents: Initial noise; scaled by ``sigma(t_0) * s(t_0)`` to form x_0.
        class_labels: Passed through unchanged to every network call.
        num_steps: Number of sampling steps.
        sigma_min, sigma_max: Noise-level range; ``None`` selects the
            discretization-specific default, and the result is clamped to the
            range the network supports.
        rho: Exponent of the 'edm' discretization.
        solver, discretization, schedule, scaling: Sampler configuration (see
            the asserts below for the allowed values).
        epsilon_s: Smallest time for the 'vp' discretization / VP betas.
        C_1, C_2, M: Parameters of the 'iddpm' discretization.
        alpha: Heun intermediate point ``t_hat + alpha * h``; the combination
            weights are ``1 - 1/(2*alpha)`` and ``1/(2*alpha)``.

    Returns:
        The state after the final step (time ``t_N = 0``).
    """
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules: sigma(t), its time
    # derivative, and the inverse t(sigma).
    # NOTE: vp_sigma_deriv refers to the enclosing-scope name `sigma`, which is
    # only bound further down. It is only used when schedule == 'vp', in which
    # case `sigma` is the matching vp_sigma(...) closure.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP. They are chosen so that, with these
    # betas, vp_sigma(1) == sigma_max and vp_sigma(epsilon_s) == sigma_min.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level (num_steps values, high -> low).
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        # Build the M+1 candidate noise levels by backward recursion, keep those
        # inside [sigma_min, sigma_max], then pick num_steps evenly spaced ones.
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Deterministic: no noise is added, so the step starts exactly at (x_cur, t_cur).
        x_hat = x_cur
        t_hat = t_cur


        # # Increase noise temporarily.
        # gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        # t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        # x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        # NOTE(review): unlike the other samplers, the network output is not
        # cast to float64 here; it is promoted when combined with float64 x_hat.
        h = t_next - t_hat
        denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction. The last step (t_next == 0) always uses the
        # plain Euler update: the Heun branch divides by sigma(t_prime), which is 0
        # there when alpha == 1.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels)   #.to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next
# Proposed EDM sampler (Algorithm 2) from the EDM paper, extended with optional PDE-residual conditioning.
def edm_sampler(
    net, latents, class_labels=None, masks=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, training_mode=None,
    pde_residual_mode=False, pde_loss_function=None, compute_pde_loss_fn=None,
    pde_direction=None, guided_pde_residual_mode=False, device=torch.device('cuda'), dataset_obj=None,
    normalize_pde_residual=False, sigma_data=0.5, pass_masks_to_model=False,
    track_losses=False, loss_type="mse", re_paint=False, test_uncond=False, pde_res_tracking=False, **kwargs
):
    """EDM stochastic Heun sampler (Algorithm 2) with optional PDE-residual conditioning.

    Args:
        net: Denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            ``net(x, sigma_labels, class_labels, ...)``. In ``pde_residual_mode`` it
            also receives ``pde_residual=``, and ``mask=`` when
            ``pass_masks_to_model`` is set.
        latents: Initial noise; scaled by the largest noise level to form x_0.
        class_labels: Conditioning tensor. The guided unified residual treats it
            as ``mask * images`` (the observed values).
        masks: Observation masks used by the guided residual, mask passing and RePaint.
        num_steps, sigma_min, sigma_max, rho: EDM time-step discretization.
        S_churn, S_min, S_max, S_noise: EDM stochasticity ("churn") controls.
        training_mode: ``'conditional'`` or ``'unified'``; required when
            ``pde_residual_mode`` is set.
        pde_residual_mode: Feed a PDE residual of the current state to the network.
        pde_loss_function, pde_direction, device: Forwarded to ``compute_pde_loss_fn``.
        compute_pde_loss_fn: Callable evaluating the PDE residual.
        guided_pde_residual_mode: (unified only) evaluate the residual on the state
            with observed entries taken from ``class_labels``.
        dataset_obj: Provides ``denorm_output`` / ``denorm_tensor`` / ``denorm_input``
            used for the residual diagnostics in physical units.
        normalize_pde_residual: Standardize the residual with
            ``get_normalized_pde_residual`` (always with ``sigma_data=0.5``).
        sigma_data: Not used by this sampler; the residual normalization
            uses a fixed 0.5.
        pass_masks_to_model: Pass ``mask=masks`` to the network.
        track_losses: Enables the loss history (placeholder: it stays empty).
        loss_type: Not used by this sampler.
        re_paint: After each step, overwrite observed entries (``masks == 1``) with
            ``class_labels`` noised to the next noise level.
        test_uncond: With ``pass_masks_to_model``, zero both labels and masks.
        pde_res_tracking: Record per-step residual statistics, residual images and
            the matching denormalized predictions; requires ``pde_residual_mode``.
        **kwargs: Ignored.

    Returns:
        ``(x, pde_residual_stats, loss_history)``, where ``x`` is the float64
        sample, ``pde_residual_stats`` is a dict (or ``None`` unless
        ``pde_res_tracking``) and ``loss_history`` is a list (or ``None`` unless
        ``track_losses``).

    Raises:
        ValueError: If ``pde_residual_mode`` is set with an unsupported
            ``training_mode``, or ``pde_res_tracking`` is set without
            ``pde_residual_mode``.
    """
    if pde_residual_mode and training_mode not in ('conditional', 'unified'):
        raise ValueError(
            f"pde_residual_mode requires training_mode 'conditional' or 'unified', got {training_mode!r}")
    if pde_res_tracking and not pde_residual_mode:
        raise ValueError("pde_res_tracking requires pde_residual_mode=True")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    loss_history = [] if track_losses else None

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    def _conditioning_residual(x):
        """PDE residual of the (normalized) state ``x``, as fed to the network."""
        if training_mode == 'conditional':
            residual = compute_pde_loss_fn(pde_loss_function, pde_direction, x.to(torch.float32), class_labels.to(torch.float32), device)
        elif guided_pde_residual_mode:
            # Observed entries come from class_labels (mask * images); the rest from x.
            x32 = x.to(torch.float32)
            labels32 = class_labels.to(torch.float32)
            pde_input = (1 - masks[:, 0:1]) * x32[:, 0:1] + labels32[:, 0:1]
            pde_output = (1 - masks[:, 1:2]) * x32[:, 1:2] + labels32[:, 1:2]
            residual = compute_pde_loss_fn(pde_loss_function, pde_direction, pde_input, pde_output, device=device, training_mode=training_mode)
        else:
            residual = compute_pde_loss_fn(pde_loss_function, pde_direction, x[:, 0:1].to(torch.float32), x[:, 1:2].to(torch.float32), device, training_mode)
        residual = residual.unsqueeze(1)
        # The residual is deliberately not masked. Normalization mirrors training.
        if normalize_pde_residual:
            residual = get_normalized_pde_residual(residual, sigma_data=0.5)
        return residual

    def _denoise_with_residual(x, sigma_labels):
        """Evaluate the PDE-conditioned denoiser at ``x``.

        Returns ``(denoised, pde_res_denorm, x_denorm)``. The last two are the
        residual of the denormalized field and that field, used only for
        diagnostics.
        """
        pde_residual = _conditioning_residual(x)
        if pass_masks_to_model and test_uncond:
            # Unconditional probe: drop both the observations and the mask.
            out = net(x, sigma_labels, torch.zeros_like(class_labels), mask=torch.zeros_like(masks), pde_residual=pde_residual)
        elif pass_masks_to_model:
            out = net(x, sigma_labels, class_labels, mask=masks, pde_residual=pde_residual)
        else:
            out = net(x, sigma_labels, class_labels, pde_residual=pde_residual)
        if training_mode == 'conditional':
            x_denorm = dataset_obj.denorm_output(x)
            labels_denorm = dataset_obj.denorm_output(class_labels)
            res_denorm = compute_pde_loss_fn(pde_loss_function, pde_direction, x_denorm.to(torch.float32), labels_denorm.to(torch.float32), device)
        else:
            x_denorm = dataset_obj.denorm_tensor(x)
            res_denorm = compute_pde_loss_fn(pde_loss_function, pde_direction, x_denorm[:, 0:1].to(torch.float32), x_denorm[:, 1:2].to(torch.float32), device, training_mode)
        return out.to(torch.float64), res_denorm.unsqueeze(1), x_denorm

    if pde_res_tracking:
        pde_residual_stats = {
            'timestep_stats': [],        # per-step, per-sample residual statistics
            'timestep_residuals': [],    # per-step denormalized residual images
            'timestep_predictions': [],  # per-step denormalized predicted fields
            'num_timesteps': num_steps,  # one entry is recorded per loop iteration
        }

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next
        # Denormalized field stored in 'timestep_predictions' for this step (reset
        # every step so a previous step's tensor can never leak in).
        pred_images_denorm = None

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        noise_labels = t_hat.expand(x_hat.shape[0]).to(device=x_hat.device, dtype=x_hat.dtype)

        # Euler step.
        if pde_residual_mode:
            denoised, pde_res_denorm, _ = _denoise_with_residual(x_hat, noise_labels)
        else:
            denoised = net(x_hat, noise_labels, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next == 0).
        if i < num_steps - 1:
            noise_labels_next = t_next.expand(x_next.shape[0]).to(device=x_next.device, dtype=x_next.dtype)
            if pde_residual_mode:
                # The corrector is evaluated at (x_next, t_next) in every branch,
                # including the unconditional probe.
                denoised, pde_res_denorm, pred_images_denorm = _denoise_with_residual(x_next, noise_labels_next)
            else:
                denoised = net(x_next, noise_labels_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        if re_paint and masks is not None:
            # RePaint: noise the known data (class_labels) to level t_next and keep
            # it where mask == 1; keep the model's prediction where mask == 0.
            x_next_known = class_labels.to(torch.float64) + randn_like(x_next) * t_next
            x_next = masks * x_next_known + (1 - masks) * x_next

        if track_losses and pde_residual_mode and dataset_obj is not None:
            # TODO: per-step losses are not recorded yet, so loss_history stays
            # empty; pde_res is computed but not consumed.
            if training_mode == 'unified':
                pred_images_denorm = dataset_obj.denorm_tensor(x_next, device)
                pde_res = compute_pde_loss_fn(
                    pde_loss_function, pde_direction,
                    pred_images_denorm[:, 0:1].to(torch.float32),
                    pred_images_denorm[:, 1:2].to(torch.float32),
                    device, training_mode,
                )
            else:  # conditional
                pred_images_denorm = dataset_obj.denorm_output(x_next)
                # NOTE(review): the other conditional paths denormalize class_labels
                # with denorm_output; confirm which is intended here.
                class_labels_denorm = dataset_obj.denorm_input(class_labels)
                pde_res = compute_pde_loss_fn(
                    pde_loss_function, pde_direction,
                    pred_images_denorm.to(torch.float32),
                    class_labels_denorm.to(torch.float32),
                    device,
                )

        if pde_res_tracking:
            # Per-sample statistics over all non-batch dimensions.
            batch_stats = [
                {
                    'mean': flat.mean().item(),
                    'min': flat.min().item(),
                    'max': flat.max().item(),
                    'norm': torch.norm(flat).item(),
                }
                for flat in (sample.flatten() for sample in pde_res_denorm)
            ]
            pde_residual_stats['timestep_stats'].append({
                'timestep': i,
                'batch_stats': batch_stats,
            })
            pde_residual_stats['timestep_residuals'].append({
                'timestep': i,
                'pde_residual_images': pde_res_denorm.detach().cpu().clone(),  # CPU copy to save GPU memory
            })
            if pred_images_denorm is None:
                # Nothing recorded this step (e.g. the last step): use the final state.
                pred_images_denorm = (dataset_obj.denorm_tensor(x_next) if training_mode == 'unified'
                                      else dataset_obj.denorm_output(x_next))
            pde_residual_stats['timestep_predictions'].append({
                'timestep': i,
                'pred_images': pred_images_denorm.detach().cpu().clone(),
            })

    return x_next, pde_residual_stats if pde_res_tracking else None, loss_history if track_losses else None

# Helper used by the EDM samplers to normalize the PDE residual passed to the network.
def get_normalized_pde_residual(residual, sigma_data=0.5):
    """Standardize a PDE residual per sample/channel and rescale it to ``sigma_data``.

    Mean and (unbiased) standard deviation are taken over dims 2 and 3, i.e.
    the trailing two axes of a 4-D tensor; ``1e-8`` is added to the std so a
    constant residual maps to zeros instead of dividing by zero. This is the
    same normalization used in training.

    Args:
        residual: Tensor with at least 4 dims (statistics over dims 2 and 3).
        sigma_data: Target standard deviation of the output.

    Returns:
        Tensor of the same shape as ``residual``.
    """
    spatial_dims = [2, 3]
    centered = residual - residual.mean(dim=spatial_dims, keepdim=True)
    scale = residual.std(dim=spatial_dims, keepdim=True) + 1e-8
    return centered / scale * sigma_data
#----------------------------------------------------------------------------
def edm_dps_sampler(
    net, latents, class_labels=None, masks=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, training_mode="unified",
    pde_residual_mode=False, pde_loss_function=None, compute_pde_loss_fn=None,
    pde_direction=None, guided_pde_residual_mode=False, device=torch.device('cuda'), dataset_obj=None, 
    normalize_pde_residual=False, sigma_data=0.5, pass_masks_to_model=False, test_uncond=False,
    # DPS specific parameters
    obs_guidance_weight=[10000,10000], pde_guidance_weight=1, loss_type='mse', **kwargs):
    """
    EDM-based sampler with DPS-style guidance specifically for unified models.
    
    This combines EDM sampling with DPS-style gradient-based guidance to solve PDEs.
    """
    # Ensure we're using unified mode
    assert training_mode == "unified", "DPS sampler is only implemented for unified models"
    # print("Using EDM-DPS sampler loss_type:", loss_type)
    # print("obs_guidance_weight:", obs_guidance_weight, "pde_guidance_weight:", pde_guidance_weight)
    
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Initialize PDE residual tracking
    pde_residual_stats = {
        'timestep_stats': [],
        'timestep_residuals': [],
        'timestep_predictions': [],
        'num_timesteps': num_steps
    }

    # Determine which channel is observed based on pde_direction
    if pde_direction == 'forward':
        # a is observed (channel 0), u is predicted (channel 1)
        obs_channel = 0
        pred_channel = 1
    else:  # inverse
        # u is observed (channel 1), a is predicted (channel 0)
        obs_channel = 1
        pred_channel = 0
    
    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    intermediates = []
    loss_history = []

    # for i, (t_cur, t_next) in enumerate(tqdm.tqdm(zip(t_steps[:-1], t_steps[1:]), total=num_steps)): # 0, ..., N-1
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): 
        x_cur = x_next.detach().clone()
        x_cur.requires_grad_(True)

        sigma_t = net.round_sigma(t_cur)
        noise_labels = sigma_t.expand(latents.shape[0]).to(device=x_cur.device, dtype=x_cur.dtype)
        
        # Euler step
        if pde_residual_mode:
            if guided_pde_residual_mode:
                # print("Computing guided PDE residual at step ", i)
                # print("Forcing PDE residual computation in unified mode")
                pde_input = (1-masks[:, 0:1]) * x_cur.to(torch.float32)[:, 0:1] + class_labels.to(torch.float32)[:, 0:1] # labels is mask * images
                pde_output = (1-masks[:, 1:2]) * x_cur.to(torch.float32)[:, 1:2] + class_labels.to(torch.float32)[:, 1:2]
                pde_residual = compute_pde_loss_fn(pde_loss_function, pde_direction, pde_input, pde_output, device=device, training_mode=training_mode).unsqueeze(1)
            else:
                pde_residual = compute_pde_loss_fn(pde_loss_function, pde_direction, x_cur[:,0:1].to(torch.float32), x_cur[:,1:2].to(torch.float32), device, training_mode).unsqueeze(1)
                # Then apply normalization if requested (same as training)
            if normalize_pde_residual:
                # print("Normalizing PDE residual at step ", i)
                pde_residual = get_normalized_pde_residual(pde_residual, sigma_data=0.5)
            if pass_masks_to_model:
                if test_uncond:
                    # print("testing unconditionally")
                # print("Running unconditional sampling at step ", i)
                    masks_uncond = torch.zeros_like(masks)
                    class_labels_uncond = torch.zeros_like(class_labels)
                    x_N = net(x_cur, noise_labels, class_labels_uncond, mask=masks_uncond, pde_residual=pde_residual).to(torch.float64)
                else:
                    x_N = net(x_cur, noise_labels, class_labels, mask=masks, pde_residual=pde_residual).to(torch.float64)
            else:
                x_N = net(x_cur, noise_labels, class_labels, pde_residual=pde_residual).to(torch.float64)
            
            # Create detached copies for visualization/metrics without affecting gradient computation
            with torch.no_grad():
                x_cur_vis = x_cur.detach().clone()
                x_cur_denorm = dataset_obj.denorm_tensor(x_cur_vis)
                pde_res_denorm = compute_pde_loss_fn(pde_loss_function, pde_direction, x_cur_denorm[:, 0:1].to(torch.float32), x_cur_denorm[:, 1:2].to(torch.float32), device, training_mode).unsqueeze(1)
        else: #never used
            if pass_masks_to_model:
                x_N = net(x_cur, noise_labels, class_labels, mask=masks).to(torch.float64)
            else:
                x_N = net(x_cur, noise_labels, class_labels).to(torch.float64)
        d_cur = (x_cur - x_N) / sigma_t
        x_next = x_cur + (t_next - sigma_t) * d_cur

        # 2nd order correction
        if i < num_steps - 1:
            noise_labels_next = t_next.expand(latents.shape[0]).to(device=x_next.device, dtype=x_next.dtype)
            
            if pde_residual_mode:
                if guided_pde_residual_mode:
                    # print("Computing guided PDE residual at step ", i)
                    # print("Forcing PDE residual computation in unified mode")
                    pde_input = (1-masks[:, 0:1]) * x_next.to(torch.float32)[:, 0:1] + class_labels.to(torch.float32)[:, 0:1] # labels is mask * images
                    pde_output = (1-masks[:, 1:2]) * x_next.to(torch.float32)[:, 1:2] + class_labels.to(torch.float32)[:, 1:2]
                    pde_residual = compute_pde_loss_fn(pde_loss_function, pde_direction, pde_input, pde_output, device=device, training_mode=training_mode).unsqueeze(1)
                else:
                    pde_residual = compute_pde_loss_fn(pde_loss_function, pde_direction, x_next[:,0:1].to(torch.float32), x_next[:,1:2].to(torch.float32), device, training_mode).unsqueeze(1)
                # Then apply normalization if requested (same as training)
                if normalize_pde_residual:
                    # print("Normalizing PDE residual at step ", i)
                    pde_residual = get_normalized_pde_residual(pde_residual, sigma_data=0.5)
                if pass_masks_to_model:
                    # print("Passing mask to modeler at step ", i)
                    if test_uncond:
                        # print("testing unconditionally")
                    # print("Running unconditional sampling at step ", i)
                        masks_uncond = torch.zeros_like(masks)
                        class_labels_uncond = torch.zeros_like(class_labels)
                        x_N = net(x_next, noise_labels, class_labels_uncond, mask=masks_uncond, pde_residual=pde_residual).to(torch.float64)
                    else:
                        x_N = net(x_next, noise_labels_next, class_labels, mask=masks, pde_residual=pde_residual).to(torch.float64)
                else:
                    x_N = net(x_next, noise_labels_next, class_labels, pde_residual=pde_residual).to(torch.float64)
                
                # Create detached copies for visualization/metrics without affecting gradient computation
                with torch.no_grad():
                    x_next_vis = x_next.detach().clone()
                    x_next_denorm = dataset_obj.denorm_tensor(x_next_vis)
                    pde_res_denorm = compute_pde_loss_fn(pde_loss_function, pde_direction, x_next_denorm[:, 0:1].to(torch.float32), x_next_denorm[:, 1:2].to(torch.float32), device, training_mode).unsqueeze(1)
            else: # never used
                if pass_masks_to_model:
                    x_N = net(x_next, noise_labels_next, class_labels, mask=masks).to(torch.float64)
                else:
                    x_N = net(x_next, noise_labels_next, class_labels).to(torch.float64)
                    
            d_prime = (x_next - x_N) / t_next
            x_next = x_cur + (t_next - sigma_t) * (0.5 * d_cur + 0.5 * d_prime)

        # Apply guidance (DPS style)
        # First denormalize predictions for loss calculations
        x_N_denorm = dataset_obj.denorm_tensor(x_N,x_N.device)
        x_next_denorm = dataset_obj.denorm_tensor(x_next.detach(), x_next.device)
        class_labels_denorm = dataset_obj.denorm_tensor(class_labels, x_next.device)
        
        # Initialize gradient update
        update = torch.zeros_like(x_cur)
        
        if i < num_steps - 1:
            step_losses = []
            active_losses = []
            
            # Get sigma-dependent coefficient based on current step
            coef_obs = get_dps_coef(i, "obs", t_steps[i].item())
            coef_pde = get_dps_coef(i, "pde", t_steps[i].item())

            # Calculate observation loss (full or sparse) 
            # its assuming single channel solution, to run on partial obs from both channels, update code
            # pred = x_N_denorm[:, obs_channel:obs_channel+1]
            # target = class_labels[:, obs_channel:obs_channel+1]
            # mask = masks[:, obs_channel:obs_channel+1] # mask should handle full or sparse
            # obs_loss = compute_loss(pred, target, loss_type, mask=mask)
            # step_losses.append(obs_loss.detach())
            
            # if coef_obs != 0 and obs_guidance_weight != 0:
            #     active_losses.append((obs_loss.sum(), coef_obs * obs_guidance_weight))
            
            # Replace the single channel observation loss calculation with a loop over channels
            obs_losses = []
            for ch in range(2):  # 2 channels (0 and 1)
                pred_ch = x_N_denorm[:, ch:ch+1]
                target_ch = class_labels_denorm[:, ch:ch+1]
                # pred_ch = x_N[:, ch:ch+1]
                # target_ch = class_labels[:, ch:ch+1]
                mask_ch = masks[:, ch:ch+1]
                
                loss_ch = compute_loss(pred_ch, target_ch, loss_type, mask=mask_ch)
                obs_losses.append(loss_ch)
                step_losses.append(loss_ch.detach())
                # Get channel-specific coefficient and weight
                ch_weight = obs_guidance_weight[ch] if isinstance(obs_guidance_weight, list) else obs_guidance_weight
                if coef_obs != 0 and ch_weight != 0:
                    active_losses.append((loss_ch.sum(), coef_obs * ch_weight))
            
            # Calculate PDE residual loss if requested
            # Calculate PDE residual between the two channels
            pde_input = (1-masks[:, 0:1]) * x_N_denorm.to(torch.float32)[:, 0:1] + class_labels_denorm.to(torch.float32)[:, 0:1] # labels is mask * images
            pde_output = (1-masks[:, 1:2]) * x_N_denorm.to(torch.float32)[:, 1:2] + class_labels_denorm.to(torch.float32)[:, 1:2]
            # pde_input = (1-masks[:, 0:1]) * x_N.to(torch.float32)[:, 0:1] + class_labels.to(torch.float32)[:, 0:1] # labels is mask * images
            # pde_output = (1-masks[:, 1:2]) * x_N.to(torch.float32)[:, 1:2] + class_labels.to(torch.float32)[:, 1:2]
            pde_res = compute_pde_loss_fn(
                pde_loss_function, pde_direction,
                pde_input.to(torch.float32), 
                pde_output.to(torch.float32),
                device, training_mode
            )
            
            # if normalize_pde_residual:
            #     pde_res = get_normalized_pde_residual(pde_res, sigma_data)
            pde_loss = compute_loss(pde_res.unsqueeze(1), torch.zeros_like(pde_res).unsqueeze(1), loss_type, mask=None)
            # print("PDE residual stats at step ", i, ": mean=", pde_res.mean().item(), "min=", pde_res.min().item(), "max=", pde_res.max().item(), "norm=", torch.norm(pde_res).item())
            # print("pde loss shape: ", pde_loss.shape)
            pde_loss = pde_loss.sum(dim=1, keepdim=True)
            # print("pde loss after sum shape: ", pde_loss.shape)
            step_losses.append(pde_loss.detach())
            
            if coef_pde != 0 and pde_guidance_weight != 0:
                active_losses.append((pde_loss.sum(), coef_pde * pde_guidance_weight))
        
            # Apply gradients
            for idx, (loss, weight) in enumerate(active_losses):
                # print("Applying guidance for loss ", idx, " with weight ", weight)
                flag_retain_graph = idx < len(active_losses) - 1
                grad = torch.autograd.grad(loss, x_cur, retain_graph=flag_retain_graph)[0]
                # grad = grad / (torch.norm(grad) + 1e-8)
                # grad_norm = torch.norm(grad).item()
                # print(f"Step {i}: Observation Gradient Norm = {grad_norm}")
                # if grad_norm == 0.0:
                #     print("--> WARNING: Gradient is ZERO. Computation graph is likely broken.")
                update = update + weight * grad
            
            # Record loss history
            if step_losses:
                # Store more detailed loss information including observation loss by channel and PDE loss
                timestep_loss_data = {
                    'timestep': i,
                    'obs_losses': obs_losses,  # List of per-channel observation losses
                    'pde_loss': pde_loss.detach(),
                    'obs_coef': coef_obs,
                    'pde_coef': coef_pde,
                    'sigma_t': t_steps[i].item()
                }
                loss_history.append(timestep_loss_data)
                    
        # Apply update
        x_next = x_next - update
        if x_next.isnan().any():
                print(f"\nStep {i}: NaN detected!")
                break
        
        # Track PDE residuals for visualization
        if pde_residual_mode:
            # Calculate PDE residual for visualization and statistics
            x_next_denorm = dataset_obj.denorm_tensor(x_next)
            pde_res_denorm = compute_pde_loss_fn(
                pde_loss_function, pde_direction, 
                x_next_denorm[:, 0:1].to(torch.float32), 
                x_next_denorm[:, 1:2].to(torch.float32), 
                device, training_mode
            ).unsqueeze(1)
            
            # Compute statistics across spatial dimensions for each sample in batch
            batch_size = pde_res_denorm.shape[0]
            timestep_stats = []
            for b in range(batch_size):
                sample_residual = pde_res_denorm[b].flatten()
                sample_stats = {
                    'mean': sample_residual.mean().item(),
                    'min': sample_residual.min().item(),
                    'max': sample_residual.max().item(),
                    'norm': torch.norm(sample_residual).item()
                }
                timestep_stats.append(sample_stats)
            
            pde_residual_stats['timestep_stats'].append({
                'timestep': i,
                'batch_stats': timestep_stats
            })
            
            # Store the actual PDE residual images for visualization
            pde_residual_stats['timestep_residuals'].append({
                'timestep': i,
                'pde_residual_images': pde_res_denorm.cpu().clone()
            })
            
            # Store the predicted images for visualization
            pde_residual_stats['timestep_predictions'].append({
                'timestep': i,
                'pred_images': x_next_denorm.detach().cpu().clone()
            })
        
        # Save intermediate results for visualization if requested
        if i % 10 == 0 or i == num_steps - 1:
            intermediates.append(dataset_obj.denorm_tensor(x_next.detach()).cpu())

    # Return final denormalized result
    x_final = x_next.detach()
    pred = dataset_obj.denorm_tensor(x_final)
    
    return x_final, pde_residual_stats, loss_history

# Helper functions
def get_dps_coef(cur_step, obs_type, sigma_t):
    """
    Return the noise-level-dependent multiplier applied to a DPS guidance weight.

    Mirrors FunDPS's ``get_coef``:
    - noise level above 1.0: PDE-residual guidance is disabled (0) and every
      other observation type gets full strength (1);
    - noise level at or below 1.0: every guidance type is scaled by ``sigma_t``.

    ``cur_step`` is accepted for interface compatibility and is not used.
    """
    if sigma_t > 1.0:
        # High-noise regime: the PDE residual of a very noisy estimate is not
        # trusted, while data fidelity is enforced at full weight.
        return 0 if obs_type == "pde" else 1
    # Low-noise regime: guidance strength decays together with the noise level.
    return sigma_t



def compute_loss(pred, target, loss_type='mse', mask=None, huber_delta=1.0):
    """Compute a masked, pixel-averaged loss between ``pred`` and ``target``.

    The residual ``pred - target`` (multiplied by ``mask`` when one is given) is
    reduced over the last two dimensions and divided by the number of
    contributing pixels:

    - ``H * W`` of ``pred`` when ``mask`` is None;
    - ``mask.sum`` over its last two dimensions otherwise (the count of observed
      pixels for a 0/1 mask), clamped to at least 1.

    Args:
        pred: Prediction tensor; its last two dimensions are reduced.
        target: Target tensor, broadcastable against ``pred``.
        loss_type: ``'mse'``, ``'l1'``, ``'l2'`` (square root of the MSE) or
            ``'huber'``.
        mask: Optional weighting tensor, broadcastable against ``pred``.
        huber_delta: Transition point between the quadratic and linear parts
            of the Huber loss (only used when ``loss_type == 'huber'``).

    Returns:
        Tensor equal to ``pred`` with its last two dimensions reduced away.

    Raises:
        ValueError: If ``loss_type`` is not supported.
    """
    if mask is None:
        diff = pred - target
        n_obs = torch.tensor(pred.shape[-2] * pred.shape[-1], device=pred.device)
    else:
        diff = (pred - target) * mask
        # Clamp so a sample with an all-zero mask yields 0 instead of 0/0 = NaN.
        n_obs = mask.sum(dim=[-1, -2]).clamp(min=1.0)

    if loss_type == 'mse':
        return torch.sum(diff ** 2, dim=[-1, -2]) / n_obs
    if loss_type == 'l1':
        return torch.sum(torch.abs(diff), dim=[-1, -2]) / n_obs
    if loss_type == 'l2':
        return torch.sqrt(torch.sum(diff ** 2, dim=[-1, -2]) / n_obs)
    if loss_type == 'huber':
        # Use the (masked) residual like the other losses, so observation
        # guidance respects its target and mask. For a residual passed as
        # ``pred`` with a zero ``target`` and no mask (PDE guidance) this is
        # simply the Huber loss of that residual.
        abs_error = torch.abs(diff)
        huber_pixel_loss = torch.where(
            abs_error < huber_delta,
            0.5 * abs_error ** 2,
            huber_delta * (abs_error - 0.5 * huber_delta),
        )
        return torch.sum(huber_pixel_loss, dim=[-1, -2]) / n_obs
    raise ValueError(f"Unsupported loss type: {loss_type}")

#----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.

def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Generalized EDM ablation sampler (superset of the samplers in the EDM paper).

    Args:
        net: Denoiser called as ``net(x, sigma, class_labels)``; it must also
            expose ``sigma_min``, ``sigma_max`` and ``round_sigma``.
        latents: Initial noise; it is scaled by ``sigma(t_0) * s(t_0)`` before
            the first step.
        class_labels: Conditioning passed through to ``net`` unchanged.
        randn_like: Noise source used for stochastic churn; it is called on
            every step, even when the churn is zero.
        num_steps: Number of noise levels. ``(num_steps - 1)`` is used as a
            divisor below, so ``num_steps`` must be at least 2.
        sigma_min, sigma_max: Noise range. None selects a per-discretization
            default; both are then clipped to what ``net`` supports.
        rho: Exponent of the 'edm' time-step discretization.
        solver: 'euler' or 'heun'.
        discretization: 'vp', 've', 'iddpm' or 'edm' (choice of noise levels).
        schedule: 'vp', 've' or 'linear' (sigma(t) mapping).
        scaling: 'vp' or 'none' (signal scaling s(t)).
        epsilon_s: Smallest VP time (t = epsilon_s maps to sigma_min).
        C_1, C_2, M: Constants of the 'iddpm' discretization.
        alpha: Position of the Heun intermediate point (h * alpha).
        S_churn, S_min, S_max, S_noise: Stochastic churn controls.

    Returns:
        The sample after the final step, as a float64 tensor.
    """
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    # NOTE: vp_sigma_deriv refers to the name `sigma`, which is only assigned in
    # the "Define noise level schedule" section below (late binding). It is only
    # called when schedule == 'vp', where `sigma` is the matching vp_sigma.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    # These solve vp_sigma(epsilon_s) == sigma_min and vp_sigma(1) == sigma_max.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        # Build the M+1 iDDPM noise levels backwards from u[M] = 0, then pick
        # num_steps of those that fall inside [sigma_min, sigma_max].
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # x_hat = x_cur
        # t_hat = t_cur


        # Increase noise temporarily.
        # randn_like is called unconditionally; when t_hat equals t_cur the
        # noise term is multiplied by zero but the RNG is still advanced.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction.
        # The last step lands on t_N = 0; with alpha == 1 the Heun point would be
        # t = 0, where the sigma(...) divisions below break, so Euler is used.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next
#----------------------------------------------------------------------------
# Wrapper for torch.Generator that allows specifying a different random seed
# for each sample in a minibatch.

class StackedRandomGenerator:
    """Minibatch RNG that owns one ``torch.Generator`` per sample.

    Row ``k`` of every tensor drawn from this object comes from the generator
    seeded with ``seeds[k]`` (reduced modulo 2**32). A sample's random numbers
    therefore do not depend on which other seeds share its minibatch.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = []
        for seed in seeds:
            gen = torch.Generator(device)
            gen.manual_seed(int(seed) % 2**32)
            self.generators.append(gen)

    def randn(self, size, **kwargs):
        """Draw standard-normal noise of shape ``size``; ``size[0]`` must equal the number of seeds."""
        assert size[0] == len(self.generators)
        per_sample_shape = size[1:]
        draws = [torch.randn(per_sample_shape, generator=gen, **kwargs) for gen in self.generators]
        return torch.stack(draws)

    def randn_like(self, input):
        """Per-sample counterpart of ``torch.randn_like`` (same shape, dtype, layout, device)."""
        like = dict(dtype=input.dtype, layout=input.layout, device=input.device)
        return self.randn(input.shape, **like)

    def randint(self, *args, size, **kwargs):
        """Per-sample counterpart of ``torch.randint``; ``size[0]`` must equal the number of seeds."""
        assert size[0] == len(self.generators)
        draws = []
        for gen in self.generators:
            draws.append(torch.randint(*args, size=size[1:], generator=gen, **kwargs))
        return torch.stack(draws)

#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    """Parse a comma-separated list of integers and inclusive ``A-B`` ranges.

    Example: ``'1,2,5-10'`` -> ``[1, 2, 5, 6, 7, 8, 9, 10]``.

    Whitespace around items and around the range dash is ignored (so
    ``'1-3, 5-7'`` works). Empty items, e.g. from a trailing comma, are
    skipped. A list argument is returned unchanged (already parsed).

    Raises:
        ValueError: If an item is neither an integer nor an ``A-B`` range.
    """
    if isinstance(s, list):
        return s
    range_re = re.compile(r'^(\d+)\s*-\s*(\d+)$')
    values = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        m = range_re.match(part)
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            values.append(int(part))
    return values

#----------------------------------------------------------------------------

# Return the KImg count encoded in a checkpoint filename ("...-002000.pkl" -> 2000), or None if it has none.
def _checkpoint_kimg(path):
    try:
        return int(os.path.basename(path).split('-')[-1].split('.')[0])
    except ValueError:
        return None


# Plot one error curve against KImg, save it as outdir/plot_name, show it and release the figure.
def _plot_error_vs_kimg(kimg_values, errors, label, ylabel, outdir, plot_name):
    plt.figure(figsize=(10, 6))
    plt.plot(kimg_values, errors, label=label, marker='o')
    plt.xlabel("KImg (Thousands of Images Seen)")
    plt.ylabel(ylabel)
    plt.title("Model Evaluation on Test Set Across Checkpoints")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(outdir, plot_name))
    print(f"Plot saved to {os.path.join(outdir, plot_name)}")
    plt.show()
    plt.close()


@click.command()
@click.option('--checkpoint_dir',          help='Where to get the pickle files from', metavar='DIR',                type=str, required=True)
@click.option('--outdir',                  help='Where to save the output images', metavar='DIR',                   type=str, default=None)
@click.option("--resolution",              help="Desired resolution of noise (and therefore generated images",      type=int, default=None)
@click.option('--test_direction',          help='Direction to test for unified model (forward/inverse). Overrides dataset direction.', type=click.Choice(['forward', 'inverse']), default=None)
@click.option('--seeds',                   help='Random seeds (e.g. 1,2,5-10)', metavar='LIST',                     type=parse_int_list, default='0-63', show_default=True)
@click.option('--kimg_intervals',          help='Intervals in KImg to test model checkpoints (e.g. 500,1000,2000)', metavar='LIST', type=parse_int_list, default="500,1000,2000", show_default=True)
@click.option('--data',                    help='Path to the dataset', metavar='STR',                               type=str, required=True)
@click.option('--offset',                  help='Offset index to start evaluation from',                            type=int, default=0)
@click.option('--num',                     help='Number of samples to evaluate',                                    type=int, default=None)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT',                                type=click.IntRange(min=1), default=1, show_default=True)
@click.option('--num_workers',             help='Number of workers for data loader', metavar='INT',                             type=click.IntRange(min=0), default=4)
@click.option('--steps', 'num_steps',      help='Number of sampling steps', metavar='INT',                          type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--sigma_min',               help='Lowest noise level  [default: varies]', metavar='FLOAT',           type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max',               help='Highest noise level  [default: varies]', metavar='FLOAT',          type=click.FloatRange(min=0, min_open=True))
@click.option('--rho',                     help='Time step exponent', metavar='FLOAT',                              type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--S_churn', 'S_churn',      help='Stochasticity strength', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min',          help='Stoch. min noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max',          help='Stoch. max noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default='inf', show_default=True)
@click.option('--S_noise', 'S_noise',      help='Stoch. noise inflation', metavar='FLOAT',                          type=float, default=1, show_default=True)
@click.option('--viz_samples',             help='Number of sample sto vizualize', metavar='INT',                    type=click.IntRange(min=0), default=10)
@click.option('--solver',                  help='Ablate ODE solver', metavar='euler|heun',                          type=click.Choice(['euler', 'heun']))
@click.option('--disc', 'discretization',  help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
@click.option('--schedule',                help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear',           type=click.Choice(['vp', 've', 'linear']))
@click.option('--scaling',                 help='Ablate signal scaling s(t)', metavar='vp|none',                    type=click.Choice(['vp', 'none']))
@click.option('--test_mode',               help='Run test for full obs or sparse or noisy observations', type=click.Choice(['full', 'sparse', 'noisy']), default='full')
@click.option('--re_paint',               help='Use RePaint-style iterative sampling for sparse observations', type=bool, default=False)
@click.option('--test_uncond',            help='For conditional models, run unconditional sampling during testing', type=bool, default=False)
@click.option('--use_dps',                 help='Use DPS-style gradient-based sampling', type=bool, default=False)
@click.option('--obs_guidance_weight',     help='Weight for observation guidance in DPS, can be a single value or a comma-separated list for channel-specific weights', type=str, default='10000,10000')
@click.option('--pde_guidance_weight',     help='Weight for PDE guidance in DPS', type=float, default=0.0)
@click.option('--loss_type',                help='Loss type for DPS guidance', type=click.Choice(['mse', 'l1', 'l2', 'huber']), default='mse')
@click.option('--sparsity_ratio',          help='Ratio of elements to mask(0.97 for 3% observed)', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.0, show_default=True)
@click.option('--mask_filling_mode',       help='How to fill masked areas: noise, mean, nearest_neighbor, grid_based, or hybrid_grf', type=click.Choice(['noise', 'mean', 'nearest_neighbor', 'grid_based', 'hybrid_grf', 'additive_noise', 'zero_noise','zero','none','gaussian','gaussian_blend']), default='noise')
@click.option('--use_grf_latents',         help='Use GRF-based latents instead of standard Gaussian noise', type=bool, default=False)
@click.option('--rbf_noise', help='Use RBF-based noise', type=bool, default=False)
@click.option('--save_predictions', help='Save predicted fields at each timestep along with PDE residuals', type=bool, default=False)
@click.option('--noise_magnitude', help='Magnitude of noise to add for noisy observations', metavar='FLOAT', type=click.FloatRange(min=0), default=1.0, show_default=True)
@click.option('--normalizer_path', help='Path to the normalizer', type=str, default=None)
def main(checkpoint_dir, outdir, resolution, test_direction, data, seeds, max_batch_size, num_workers,
         viz_samples, kimg_intervals, offset, num, num_steps, test_mode,
         sparsity_ratio, mask_filling_mode, use_grf_latents=False, rbf_noise=False, device=torch.device('cuda'),
         save_predictions=False, noise_magnitude=1.0, normalizer_path=None, **sampler_kwargs):
    """Evaluate checkpoints saved at the specified interval and plot results.

    Reads checkpoint_dir/training_options.json, rebuilds the test dataset and
    noise sampler from it, evaluates the checkpoint closest to each KImg in
    kimg_intervals, and saves relative-error-vs-KImg plots into an output
    directory whose name encodes the evaluation settings. Options not named in
    the signature (sampler / DPS / ablation settings) arrive in sampler_kwargs
    and are forwarded to evaluate_checkpoint.
    """
    # dist.init()

    # Fix every RNG so repeated evaluations are reproducible.
    seed = 33
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Reuse the training configuration so data and sampler match training.
    with open(os.path.join(checkpoint_dir, "training_options.json"), "r") as f:
        config = dnnlib.EasyDict(json.load(f))

    # dist.print0('Loading dataset...')
    print("Loading dataset...")
    # Point the dataset at the test data; the PDE direction comes from the config.
    config.dataset_kwargs['path'] = data
    if num is not None:
        config.dataset_kwargs['num'] = num
        config.dataset_kwargs['offset'] = offset

    if normalizer_path is not None:
        config.dataset_kwargs['normalizer_path'] = normalizer_path
    else:
        normalizer_path = checkpoint_dir

    if resolution is not None:
        # The checkpoint-dir default normalizer is only applied when the
        # requested resolution equals the training resolution.
        if resolution == config.dataset_kwargs['resolution']:
            config.dataset_kwargs['normalizer_path'] = normalizer_path
        test_resolution = resolution
    else:
        test_resolution = config.dataset_kwargs['resolution']
    config.dataset_kwargs['resolution'] = test_resolution

    dataset_obj = dnnlib.util.construct_class_by_name(**config.dataset_kwargs) # subclass of training.dataset.Dataset
    if hasattr(dataset_obj, 'set_training'):
        dataset_obj.set_training(False)
    data_loader_kwargs = dnnlib.EasyDict(
        pin_memory=True,
        num_workers=num_workers,
    )
    if num_workers > 0:
        # Batches each worker loads ahead of time. DataLoader rejects this
        # option when num_workers == 0, which --num_workers allows.
        data_loader_kwargs['prefetch_factor'] = 2

    dist.print0("Loading noise sampler...")
    noise_sampler_kwargs = dnnlib.EasyDict(config.sampler_kwargs)
    noise_sampler_kwargs.n_in = dataset_obj.num_channels
    noise_sampler_kwargs.device = device
    noise_sampler_kwargs.Ln1 = test_resolution  # needs to work for the test resolution
    noise_sampler_kwargs.Ln2 = dataset_obj.resolution
    noise_sampler = dnnlib.util.construct_class_by_name(**noise_sampler_kwargs)

    effective_pde_direction = dataset_obj.pde_direction
    print(f"Dataset PDE direction: {effective_pde_direction}")
    training_mode = config.dataset_kwargs.get('training_mode', 'conditional')
    if training_mode == 'unified':
        # For unified models a user-supplied test_direction overrides the dataset's direction.
        effective_pde_direction = test_direction if test_direction else effective_pde_direction
        print(f"Using test direction: {effective_pde_direction} ({'specified by user' if test_direction else 'from dataset'})")
    elif training_mode == 'conditional':
        print(f"Using dataset direction: {effective_pde_direction}")
    else:
        # Only these two modes are handled below (mask passing, loss tracking).
        raise ValueError(
            f"Unsupported training_mode '{training_mode}' in training_options.json; "
            "expected 'conditional' or 'unified'.")

    normalize_pde_residual = config.loss_kwargs.get('normalize_pde_residual', False)
    sampler_kwargs["normalize_pde_residual"] = normalize_pde_residual

    # Process obs_guidance_weight if it's a string
    if sampler_kwargs.get('use_dps', False):
        if 'obs_guidance_weight' in sampler_kwargs and isinstance(sampler_kwargs['obs_guidance_weight'], str):
            try:
                # Parse comma-separated string into list of floats
                weights = [float(w) for w in sampler_kwargs['obs_guidance_weight'].split(',')]

                # Ensure we have at least 2 weights (one for each channel in unified model)
                if len(weights) < 2:
                    # If only one weight provided, duplicate it for both channels
                    weights = [weights[0], weights[0]]
                    print(f"Duplicating single weight {weights[0]} for both channels.")

                sampler_kwargs['obs_guidance_weight'] = weights
                print(f"Using observation weights: {weights}")
            except ValueError:
                # If parsing fails, use a default value
                print(f"Warning: Could not parse obs_guidance_weight '{sampler_kwargs['obs_guidance_weight']}'. Using default [10000, 10000].")
                sampler_kwargs['obs_guidance_weight'] = [10000, 10000]

    if training_mode == 'conditional':
        # Pass masks if the model was trained with a mask (the mask can be full
        # when sparse conditioning is off).
        pass_masks_to_model = config.network_kwargs.get('use_sparse_conditioning', False)
    else:
        # Unified models are always trained with masked conditioning.
        pass_masks_to_model = True
        sampler_kwargs['track_losses'] = True
    sampler_kwargs['pass_masks_to_model'] = pass_masks_to_model
    sampler_kwargs['guided_pde_residual_mode'] = bool(
        training_mode == 'unified' and config.loss_kwargs.get('guided_pde_residual_mode', False))

    # Build an output directory whose path encodes the evaluation settings.
    if not outdir:
        outdir = checkpoint_dir
    if offset is not None and num is not None:
        outdir = os.path.join(outdir, f"model_validation_offset{offset}_num{num}")
        outdir = f"{outdir}_steps{num_steps}"
    else:
        outdir = os.path.join(outdir, f"model_validation_steps{num_steps}")
    outdir = os.path.join(outdir, test_mode)
    if use_grf_latents:
        outdir = outdir + "_grf_latents"
    if rbf_noise:
        outdir = outdir + "_rbf_noise"
    use_dps = sampler_kwargs.get('use_dps', False)
    re_paint = sampler_kwargs.get('re_paint', False)
    test_uncond = sampler_kwargs.get('test_uncond', False)
    if re_paint:
        outdir = outdir + "_repaint"
    if test_uncond:
        outdir = outdir + "_uncond"
    if use_dps:
        outdir = outdir + "_dps"
        key = "dps"
        if 'obs_guidance_weight' in sampler_kwargs:
            obs_weights = sampler_kwargs['obs_guidance_weight']
            if isinstance(obs_weights, list):
                weights_str = '_'.join(str(w) for w in obs_weights)
                key = key + f"_obs{weights_str}"
            else:
                key = key + f"_obs{obs_weights}"
        if sampler_kwargs.get('pde_guidance_weight', 0) > 0:
            key = key + f"_pde{sampler_kwargs.get('pde_guidance_weight', 0)}"
        key = key + f"{sampler_kwargs.get('loss_type', 'mse')}_loss"
        outdir = os.path.join(outdir, key)
    if training_mode == 'unified':
        outdir = os.path.join(outdir, f"unified_{effective_pde_direction}_mode")
    if sparsity_ratio > 0:
        outdir = os.path.join(outdir, f"sparsity_ratio{int(sparsity_ratio * 100)}_{mask_filling_mode}")
    if not os.path.exists(outdir):
        os.makedirs(outdir)
        print("Created output directory: ", outdir)

    if not kimg_intervals:
        raise ValueError("--kimg_intervals must list at least one KImg value.")

    # Map KImg -> checkpoint path. Files without a KImg number are skipped; on
    # duplicate KImg the later file in directory order wins.
    kimg_checkpoints = {}
    for fname in os.listdir(checkpoint_dir):
        if not fname.endswith('.pkl'):
            continue
        ckpt = os.path.join(checkpoint_dir, fname)
        kimg = _checkpoint_kimg(ckpt)
        if kimg is None:
            print(f"Skipping checkpoint without a KImg number in its name: {ckpt}")
            continue
        kimg_checkpoints[kimg] = ckpt
    if not kimg_checkpoints:
        raise FileNotFoundError(f"No '<name>-<kimg>.pkl' checkpoints found in {checkpoint_dir}")
    # Ascending KImg order, so ties in the closest-checkpoint search go to the smaller KImg.
    kimg_checkpoints = dict(sorted(kimg_checkpoints.items()))

    kimg_values = []
    norm_relative_errors = []
    denorm_relative_errors = []

    print("Testing at intervals: ", kimg_intervals)

    for target_kimg in tqdm.tqdm(kimg_intervals, desc="Evaluating Checkpoints", unit="kimg"):
        # Use the available checkpoint closest to the requested KImg.
        closest_kimg = min(kimg_checkpoints, key=lambda k: abs(k - target_kimg))
        checkpoint_path = kimg_checkpoints[closest_kimg]

        # Fresh, non-shuffled pass over the test set for every checkpoint.
        dataset_iterator = iter(
            torch.utils.data.DataLoader(
                dataset=dataset_obj,
                batch_size=max_batch_size,
                **data_loader_kwargs
            )
        )

        print(f"Evaluating model closest to {target_kimg} KImg, actual checkpoint at {closest_kimg} KImg.")
        norm_mean_relative_error, denorm_mean_relative_error = evaluate_checkpoint(
            network_pkl=checkpoint_path, device=device, batch_size=max_batch_size, training_mode=training_mode,
            dataset_obj=dataset_obj, dataset_iterator=dataset_iterator, effective_pde_direction=effective_pde_direction,
            outdir=outdir, model_kimg=closest_kimg, sampler_kwargs=sampler_kwargs,
            plot=True, output_mat=False, viz_samples=viz_samples, debug=False, test_mode=test_mode,
            sparsity_ratio=sparsity_ratio, mask_filling_mode=mask_filling_mode, noise_magnitude=noise_magnitude,
            use_grf_latents=use_grf_latents, noise_sampler=noise_sampler, rbf_noise=rbf_noise, num_steps=num_steps,
            save_predictions=save_predictions, test_resolution=test_resolution)

        # Store the results
        kimg_values.append(closest_kimg)
        norm_relative_errors.append(norm_mean_relative_error)
        print(f"KImg: {closest_kimg}, Normalized Mean Relative Error: {norm_mean_relative_error}")
        denorm_relative_errors.append(denorm_mean_relative_error)
        print(f"KImg: {closest_kimg}, Denormalized Mean Relative Error: {denorm_mean_relative_error}")

    # Plot relative errors over KImg
    _plot_error_vs_kimg(kimg_values, norm_relative_errors,
                        label="Normalized Mean Relative Error on Test Set",
                        ylabel="Normalized Mean Relative Error",
                        outdir=outdir, plot_name="norm_relative_error_kimg.png")
    _plot_error_vs_kimg(kimg_values, denorm_relative_errors,
                        label="Denormalized Mean Relative Error on Test Set",
                        ylabel="Denormalized Mean Relative Error",
                        outdir=outdir, plot_name="denorm_relative_error_kimg.png")

    # Identify the best model
    best_index = np.argmin(denorm_relative_errors)
    best_kimg = kimg_values[best_index]
    best_error = denorm_relative_errors[best_index]
    print(f"Best model at {best_kimg} KImg with mean relative error: {best_error}")




def evaluate_checkpoint(network_pkl=None, net=None, device=None, batch_size=None, training_mode="conditional", dataset_obj=None, dataset_iterator=None,
                    effective_pde_direction=None, outdir=None, model_kimg=None, sampler_kwargs=None, plot=False, output_mat=False, 
                    viz_samples=None, debug=False, test_mode='full', sparsity_ratio=0.0, noise_magnitude=1.0, mask_filling_mode='noise', use_grf_latents=False,
                    noise_sampler=None, rbf_noise=False, num_steps=18, save_predictions=False, pde_res_tracking=False, test_resolution=None): 
    """Evaluate a checkpoint with sparse observations.
    Args:
        ...existing args...
        sparsity_ratio: Ratio of elements to mask (default: 0.0 for no masking)
        noise_magnitude: Magnitude of noise to add to masked regions (default: 1.0)
    """
    # Configure PyTorch memory management for FFT operations
    use_dps = sampler_kwargs.get('use_dps', False) if sampler_kwargs else False
    if use_dps:
        # Set environment variable to help with memory fragmentation
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        # Enable PyTorch memory profiling
        torch.cuda.reset_peak_memory_stats()
        
        # Force cudnn benchmark off for more memory stability
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        
        # Print initial memory stats
        print(f"Initial GPU Memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, "
              f"{torch.cuda.memory_reserved() / 1024**3:.2f} GB reserved")
    
    # Rank 0 goes first.
    # if dist.get_rank() != 0:
    #     return

    # Ensure either network_pkl or net is provided
    assert network_pkl is not None or net is not None, "Either 'network_pkl' or 'net' must be provided."

    # If network_pkl is given, load the model
    if network_pkl is not None:
        # dist.print0(f'Loading network from "{network_pkl}"...')
        with dnnlib.util.open_url(network_pkl) as f:
            net = pickle.load(f)['ema'].to(device)

    # Ensure network is properly loaded
    assert net is not None, "Failed to load network. Check 'network_pkl' or pass 'net' explicitly."
    print("Model type: ", net.model_type)
    net.img_resolution = test_resolution

    # dist.print0("Number of params: {}".format(misc.count_parameters(net)))
    torch.cuda.empty_cache()
    total_batches = len(dataset_obj) // batch_size
    # Metric storage for normalized and denormalized values
    metric_keys = [
                    "relative_error_l2", "componentwise_relative_error_l2", "relative_error_l1", 
                    "componentwise_relative_error_l1", "mean_absolute_error", "mean_squared_error", 
                    "pde_residuals_mean", "pde_residuals_norm", "vrmse",
                    # Additional PDE residual metrics
                    "pde_residuals_kurtosis", "pde_residuals_skewness", "pde_residuals_huber",
                    "pde_boundary1_mean", "pde_boundary2_mean",
                  ]
    if dataset_obj.name == "DiffusionPDEDarcyDataset" and effective_pde_direction == "inverse":
        metric_keys += ["coefficient_error_rate"]
    metrics_dict_norm = {key: [] for key in metric_keys}
    metrics_dict_denorm = {key: [] for key in metric_keys}

    random_reservoir, worst_reservoir = [], []
    sample_errors = []  # List to store sample-wise errors for sorting later
    total_samples = 0
    if pde_res_tracking:
        # Initialize PDE residual tracking for plotting
        pde_residual_data = {
            'timestep_means': [],
            'timestep_mins': [],
            'timestep_maxs': [],
            'timestep_norms': [],
            'num_timesteps': None,
            'all_batch_data': [],  # Store data from all batches for median and std calculation
            'sample_timestep_residuals': [],  # Store PDE residual images for each sample across timesteps
            'sample_mean_residuals': []  # Store mean residual for each sample across all timesteps
        }
    pde_predictions_data = {
        'sample_timestep_predictions': [],  # Store PDE predictions images for each sample across timesteps
    }
    
    if save_predictions:
        predictions_dir = os.path.join(outdir, f"{model_kimg}_kimg", "batch_predictions")
        os.makedirs(predictions_dir, exist_ok=True)
        print(f"Will save batch predictions to: {predictions_dir}")

    # for batch_idx, (images_real, labels) in enumerate(dataset_iterator):
    # breakpoint()
    for batch_idx, (images_real, labels,mask) in enumerate(tqdm.tqdm(dataset_iterator, desc="Processing Batches", unit="batch", total=total_batches)):
        batch_start_time = time.time()
        #ignore mask from dataloader for now
        # images_real, labels = next(dataset_iterator)
        # print("images shape: ", images.shape)
        # print("labels shape: ", labels.shape)
        # print(f"Processing batch {batch_idx} with shape:", images_real.shape)
        if training_mode == "conditional":
            labels = rearrange(labels, 'bs h w -> bs 1 h w')
            images_real = rearrange(images_real, 'bs h w -> bs 1 h w')
        # print("Training mode: ", training_mode)
        # print("images shape: ", images_real.shape)
        # print("labels shape: ", labels.shape)

        images_real = images_real.to(device, non_blocking=True).to(torch.float32) # / 127.5 - 1
        labels = labels.to(device, non_blocking=True).to(torch.float32)

        if test_mode in ['full']:
            if training_mode == "conditional":
                assert images_real.shape == labels.shape, f"Image and label shapes must match for conditional models in full observation mode. Got {images_real.shape} and {labels.shape}."
                # all observations available, no masking
                masks = torch.ones_like(labels).to(device)
            if training_mode == "unified":
                assert images_real.shape == labels.shape, f"Image and label shapes must match for unified models in full observation mode. Got {images_real.shape} and {labels.shape}."
                if effective_pde_direction == 'forward':
                    # For forward: a is observed (1), u is predicted (0)
                    m_a = torch.ones_like(images_real[:, 0:1]).to(device)
                    m_u = torch.zeros_like(images_real[:, 0:1]).to(device)
                else:  # inverse
                    # For inverse: u is observed (1), a is predicted (0)
                    m_a = torch.zeros_like(images_real[:, 0:1]).to(device)
                    m_u = torch.ones_like(images_real[:, 0:1]).to(device)
                    
                # Concatenate masks for unified model input
                masks = torch.cat([m_a, m_u], dim=1)
                labels = images_real * masks
            # print("Full observation mode: no masking applied.")
            # print("Labels min/max:", labels.min().item(), labels.max().item())
            # print("Masks min/max:", masks.min().item(), masks.max().item())
            # print("Number of observed elements per sample:", masks.view(masks.size(0), -1).sum(dim=1))
            # print("Masks shape:", masks.shape)
            # print("Masks channel 0 min/max:", masks[:,0:1].min().item(), masks[:,0:1].max().item())
            # print("Masks channel 1 min/max:", masks[:,1:2].min().item(), masks[:,1:2].max().item())
            # print("Number of observed elements per sample channel 0:", masks[:,0:1].view(masks.size(0), -1).sum(dim=1)  )
            # print("Number of observed elements per sample channel 1:", masks[:,1:2].view(masks.size(0), -1).sum(dim=1) )    
            # print("Labels shape:", labels.shape)
            # print("Labels channel 0 min/max:", labels[:,0:1].min().item(), labels[:,0:1].max().item())
            # print("Labels channel 1 min/max:", labels[:,1:2].min().item(), labels[:,1:2].max().item())
            # print("Number of elements per sample channel 0:", labels[:,0:1].view(labels.size(0), -1).sum(dim=1)  )
            # print("Number of elements per sample channel 1:", labels[:,1:2].view(labels.size(0), -1).sum(dim=1) )   

        # Apply sparsification if sparsity_ratio > 0
        elif sparsity_ratio > 0 and test_mode in ['noisy']:
            if training_mode == "conditional":
                # Add noise to a fraction of the labels and send to model
                labels = get_masked_noisy_labels(labels, sparsity_ratio, mask_filling_mode, noise_magnitude)
                # labels_denorm = dataset_obj.denorm_input(labels)
                # labels_denorm_noisy = get_masked_noisy_labels(labels_denorm, sparsity_ratio, mask_filling_mode, noise_magnitude)
                # labels = dataset_obj.norm_input(labels_denorm_noisy)
                masks = torch.ones_like(labels).to(device)
            elif training_mode == "unified":
                # For unified models, apply noise only to the observed component based on direction
                if effective_pde_direction == 'forward':
                    # Forward: Add noise to a (channel 0, observed), leave u (channel 1, predicted) alone
                    a_noisy = get_masked_noisy_labels(images_real[:, 0:1], sparsity_ratio, mask_filling_mode, noise_magnitude)
                    # images_real_denorm = dataset_obj.denorm_tensor(images_real,device)
                    # a_noisy_denorm = get_masked_noisy_labels(images_real_denorm[:, 0:1], sparsity_ratio, mask_filling_mode, noise_magnitude)
                    # a_noisy = dataset_obj.norm_tensor(torch.cat([a_noisy_denorm, u_zeros], dim=1))[:, 0:1]
                    u_zeros = torch.zeros_like(images_real[:, 1:2]).to(device)
                    labels = torch.cat([a_noisy, u_zeros], dim=1)
                    # Create masks with 1s for a (observed) and 0s for u (predicted)
                    m_a = torch.ones_like(images_real[:, 0:1]).to(device)
                    m_u = torch.zeros_like(images_real[:, 1:2]).to(device)
                else:  # inverse
                    # Inverse: Add noise to u (channel 1, observed), leave a (channel 0, predicted) alone
                    u_noisy = get_masked_noisy_labels(images_real[:, 1:2], sparsity_ratio, mask_filling_mode, noise_magnitude)
                    # images_real_denorm = dataset_obj.denorm_tensor(images_real, device=device)
                    # u_noisy_denorm = get_masked_noisy_labels(images_real_denorm[:, 1:2], sparsity_ratio, mask_filling_mode, noise_magnitude)
                    # u_noisy = dataset_obj.norm_tensor(torch.cat([u_noisy_denorm, u_noisy_denorm], dim=1))[:, 1:2]
                    a_zeros = torch.zeros_like(images_real[:, 0:1]).to(device)
                    labels = torch.cat([a_zeros, u_noisy], dim=1)
                    # Create masks with 0s for a (predicted) and 1s for u (observed)
                    m_a = torch.zeros_like(images_real[:, 0:1]).to(device)
                    m_u = torch.ones_like(images_real[:, 1:2]).to(device)
                # Concatenate masks for unified model input
                masks = torch.cat([m_a, m_u], dim=1)
            # print(f"Noisy observation mode: {sparsity_ratio*100}% of labels replaced with noisy values.")
            
        elif sparsity_ratio > 0 and test_mode in ['sparse']:
            if training_mode == "conditional":
                # Mask a fraction of the labels and generate masks and send to model
                masks = get_random_masks(labels, sparsity_ratio).to(device)
                labels = apply_masks(labels, masks, mode=mask_filling_mode).to(device)
                masks = torch.ones_like(labels).to(device)
            elif training_mode == "unified":
                # For unified models, apply sparsity only to the observed component based on direction
                if effective_pde_direction == 'forward':
                    # Forward: Sparsify a (channel 0, observed), leave u (channel 1, predicted) alone
                    a_masks = get_random_masks(images_real[:, 0:1], sparsity_ratio).to(device)
                    a_sparse = apply_masks(images_real[:, 0:1], a_masks, mode=mask_filling_mode).to(device)
                    u_zeros = torch.zeros_like(images_real[:, 1:2]).to(device)
                    u_masks = get_random_masks(images_real[:, 1:2], sparsity_ratio).to(device)
                    u_sparse = apply_masks(images_real[:, 1:2], u_masks, mode=mask_filling_mode).to(device)
                    # For labels: channel 0 has sparse a, channel 1 is all zeros (to be predicted)
                    labels = torch.cat([a_sparse, u_zeros], dim=1)
                    # Create masks with  channel 0 has 1s only where observations exist, channel 1 all zeros
                    m_u = torch.zeros_like(images_real[:, 1:2]).to(device)
                    masks = torch.cat([a_masks, m_u], dim=1)
                else:  # inverse
                    # Inverse: Sparsify u (channel 1, observed), leave a (channel 0, predicted) alone
                    u_masks = get_random_masks(images_real[:, 1:2], sparsity_ratio).to(device)
                    u_sparse = apply_masks(images_real[:, 1:2], u_masks, mode=mask_filling_mode).to(device)
                    a_zeros = torch.zeros_like(images_real[:, 0:1]).to(device)
                    labels = torch.cat([a_zeros, u_sparse], dim=1)
                    # Create masks with 0s for a (predicted) and 1s for u (observed)
                    m_a = torch.zeros_like(images_real[:, 1:2]).to(device)
                    masks = torch.cat([m_a, u_masks], dim=1)

            # print(f"Sparse observation mode: {mask_filling_mode} with {sparsity_ratio*100}% of elements masked.")
            # print("Sparse observation mode: ", mask_filling_mode, f"with {sparsity_ratio*100}% of elements masked.")
            # print("Labels min/max:", labels.min().item(), labels.max().item())
            # print("Masks min/max:", masks.min().item(), masks.max().item())
            # print("Number of observed elements per sample:", masks.view(masks.size(0), -1).sum(dim=1))
            
        if use_grf_latents:
            # Generate GRF-based latents instead of standard Gaussian noise
            latents = generate_grf_noise([batch_size, net.img_channels, test_resolution, test_resolution],
                           device=device, length_scale=0.2)
        elif rbf_noise:
            # print("Using RBF-based noise for latents")
            latents = noise_sampler.sample(batch_size).to(device)
        else:
            latents = torch.randn([batch_size, net.img_channels, test_resolution, test_resolution], device=device)
        
        sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
        have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
        sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler #TODO understand if detereministics or nomal sampler
        use_dps = sampler_kwargs.get('use_dps', False)

        # Set PyTorch memory allocation configuration for FFT operations
        if use_dps:
            # Set environment variable to help with memory fragmentation
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
            # Force garbage collection before sampling
            gc.collect()
            torch.cuda.empty_cache()

        if use_dps and training_mode == "unified":
            # print("DPS sampling with unified model")
            sampler_fn = edm_dps_sampler
            print("using dps")
            # print("sampler kwargs for dps:", sampler_kwargs)
            # print("sampler fn:", sampler_fn)
        # breakpoint()
        # print("Using sampler function: ", sampler_fn.__name__)
        # sampler_fn = edm_sampler # for now
        if net.model_type == "SongUNOResidual":
            # print("Updating sampler kwargs for PDE residual mode")
            sampler_kwargs.update({
                "pde_residual_mode": True,
                "pde_loss_function": dataset_obj.pde_loss_function,  # Assuming dataset object contains PDE type info
                "pde_direction": effective_pde_direction,
                "compute_pde_loss_fn": evaluation_utils.compute_pde_loss,  # Ensure function is passed
                "device": device,
                "dataset_obj": dataset_obj,
                "num_steps": num_steps,
                "masks": masks,
                "training_mode": training_mode
            })
        
        # Use context manager based on use_dps flag - this automatically handles enabling/disabling gradients
        with torch.set_grad_enabled(use_dps):
            net.eval()
            images_pred, batch_pde_stats, loss_history = sampler_fn(net, latents, labels, **sampler_kwargs)
            images_pred = images_pred.to(torch.float32)
            gc.collect()
            torch.cuda.empty_cache()
            # print(f"[DEBUG] Gradients enabled: {use_dps}, Torch grad enabled: {torch.is_grad_enabled()}")
            # After running edm_dps_sampler and getting results
            # Create a detached copy for visualization/metrics without affecting gradient computation
        with torch.no_grad():
            images_pred_vis = images_pred.clone()
            # Move batch_pde_stats data to CPU to free GPU memory
            if batch_pde_stats is not None:
                for key in batch_pde_stats:
                    if isinstance(batch_pde_stats[key], list) and key in ['timestep_stats', 'timestep_residuals', 'timestep_predictions']:
                        for item in batch_pde_stats[key]:
                            for k, v in item.items():
                                if isinstance(v, torch.Tensor):
                                    item[k] = v.detach().cpu()
        # if loss_history:
        #     # Plot losses regardless of which sampler was used
        #     loss_plot_path = os.path.join(outdir, f"{sampler_fn.__name__}_losses_sample_{batch_idx}.png")
        #     plot_dps_losses(loss_history, loss_plot_path)
                
        with torch.no_grad():
            # breakpoint()
            # Process and store PDE residual statistics
            if pde_res_tracking:
                if pde_residual_data['num_timesteps'] is None:
                    pde_residual_data['num_timesteps'] = batch_pde_stats['num_timesteps']
                    # Initialize aggregation lists
                    for t in range(pde_residual_data['num_timesteps']):
                        pde_residual_data['timestep_means'].append([])
                        pde_residual_data['timestep_mins'].append([])
                        pde_residual_data['timestep_maxs'].append([])
                        pde_residual_data['timestep_norms'].append([])
            
                # Store batch data for later median and std calculation
                pde_residual_data['all_batch_data'].append(batch_pde_stats)
            
                # Aggregate statistics across timesteps
                for timestep_data in batch_pde_stats['timestep_stats']:
                    t = timestep_data['timestep']
                    if t < pde_residual_data['num_timesteps']:  # Safety check
                        for sample_stats in timestep_data['batch_stats']:
                            pde_residual_data['timestep_means'][t].append(sample_stats['mean'])
                            pde_residual_data['timestep_mins'][t].append(sample_stats['min'])
                            pde_residual_data['timestep_maxs'][t].append(sample_stats['max'])
                            pde_residual_data['timestep_norms'][t].append(sample_stats['norm'])
            
                # Process and store PDE residual images for each sample across timesteps
                for timestep_img_data in batch_pde_stats['timestep_residuals']:
                    t = timestep_img_data['timestep'] 
                    pde_residual_images = timestep_img_data['pde_residual_images']  # Shape: [batch_size, 1, H, W]
                    
                    # Store each sample's PDE residual image at this timestep
                    for sample_idx in range(pde_residual_images.shape[0]):
                        # Calculate mean residual for this sample at this timestep for later ranking
                        sample_image = pde_residual_images[sample_idx:sample_idx+1]  # Keep batch dim
                        sample_mean_residual = sample_image.mean().item()
                        
                        # Store in format for later processing
                        sample_residual_data = {
                            'batch_idx': batch_idx,
                            'sample_idx': sample_idx,
                            'global_sample_idx': total_samples + sample_idx,  # Global index across all batches
                            'timestep': t,
                            'pde_residual_image': sample_image.cpu(),  # Move to CPU to save GPU memory
                            'mean_residual': sample_mean_residual
                        }
                        pde_residual_data['sample_timestep_residuals'].append(sample_residual_data)

            # breakpoint()

                # Process and store PDE predictions images for each sample across timesteps
                for timestep_img_data in batch_pde_stats['timestep_predictions']:
                    t = timestep_img_data['timestep'] 
                    pde_predictions_images = timestep_img_data['pred_images']  # Shape: [batch_size, 1, H, W]
                    # breakpoint()

                    # Store each sample's PDE predictions image at this timestep
                    for sample_idx in range(pde_predictions_images.shape[0]):
                        # Calculate mean prediction for this sample at this timestep for later ranking
                        sample_image = pde_predictions_images[sample_idx:sample_idx+1]  # Keep batch dim
                        sample_mean_prediction = sample_image.mean().item()
                        
                        # Store in format for later processing
                        sample_predictions_data = {
                            'batch_idx': batch_idx,
                            'sample_idx': sample_idx,
                            'global_sample_idx': total_samples + sample_idx,  # Global index across all batches
                            'timestep': t,
                            'pde_predictions_image': sample_image.cpu(),  # Move to CPU to save GPU memory
                            'mean_prediction': sample_mean_prediction
                        }
                        pde_predictions_data['sample_timestep_predictions'].append(sample_predictions_data)
            # breakpoint()

                # Store overall mean residual for each sample in this batch (for ranking samples)
                for sample_idx in range(images_pred.shape[0]):
                    # Calculate mean residual across all timesteps for this sample
                    sample_mean_residuals = []
                    for timestep_img_data in batch_pde_stats['timestep_residuals']:
                        pde_residual_images = timestep_img_data['pde_residual_images']
                        sample_image = pde_residual_images[sample_idx:sample_idx+1]
                        sample_mean_residuals.append(sample_image.mean().item())
                    
                    overall_sample_mean = np.mean(sample_mean_residuals)
                    pde_residual_data['sample_mean_residuals'].append({
                        'global_sample_idx': total_samples + sample_idx,
                        'batch_idx': batch_idx,
                        'sample_idx': sample_idx,
                        'mean_residual': overall_sample_mean
                    })

        sampling_time = time.time() 
        # Offload outputs to CPU ASAP
        images_real = images_real
        labels = labels
        images_pred = images_pred

        # Denormalize images
        if training_mode == "conditional":
            images_real_denorm = dataset_obj.denorm_output(images_real)
            images_pred_denorm = dataset_obj.denorm_output(images_pred)
            labels_denorm = dataset_obj.denorm_input(labels)
        if training_mode == "unified":
            images_real_denorm = dataset_obj.denorm_tensor(images_real)  
            images_pred_denorm = dataset_obj.denorm_tensor(images_pred)  
            labels_denorm = dataset_obj.denorm_tensor(labels)  

        # Save batch predictions if enabled
        if save_predictions:
            # Convert to numpy and move to CPU
            if training_mode == "conditional":
                predictions_np = images_pred_denorm.cpu().numpy()  # Shape: [batch_size, 1, H, W]
                labels_np = labels_denorm.cpu().numpy()           # Shape: [batch_size, 1, H, W]
            
                if effective_pde_direction == "forward":
                    # Forward: input=labels, output=images
                    pde_input_output_pred = np.concatenate([labels_np, predictions_np], axis=1)  # [batch_size, 2, H, W]
                else:  # inverse
                    # Inverse: input=images, output=labels  
                    pde_input_output_pred = np.concatenate([predictions_np, labels_np], axis=1)  # [batch_size, 2, H, W]
            if training_mode == "unified":
                predictions_np = images_pred_denorm.cpu().numpy()  # Shape: [batch_size, 2, H, W]
                pde_input_output_pred = predictions_np  # [batch_size, 2, H, W]
            # Save batch files
            pred_filename = f"batch_{batch_idx:04d}.npy"
            np.save(os.path.join(predictions_dir, pred_filename), pde_input_output_pred)

        # Iterate over each sample in the batch
        for i in range(images_real.shape[0]):
            sample_real = images_real[i].unsqueeze(0)
            sample_pred = images_pred[i].unsqueeze(0)
            sample_label = labels[i].unsqueeze(0)

            # Define the component to evaluate based on the training mode and direction
            if training_mode == "unified":
                # For unified models, we need to choose which component to evaluate based on direction
                if effective_pde_direction == "forward":
                    # Forward: We're predicting u (channel 1)
                    sample_pred_component = sample_pred[:, 1:2]
                    sample_real_component = sample_real[:, 1:2]
                else:  # inverse
                    # Inverse: We're predicting a (channel 0)
                    sample_pred_component = sample_pred[:, 0:1]
                    sample_real_component = sample_real[:, 0:1]
                # Compute metrics for the specific component we're evaluating
                metrics_norm = calculate_metrics(sample_pred_component, sample_real_component)
        
            if training_mode == "conditional":  # conditional mode
                # Use the entire sample for conditional models
                metrics_norm = calculate_metrics(sample_pred, sample_real)
            
            for key in list(metrics_norm.keys()):
                metrics_dict_norm[key].append(metrics_norm[key].mean().item())

            # Compute PDE residuals
            if training_mode == "conditional":
                pde_loss = evaluation_utils.compute_pde_loss(dataset_obj.pde_loss_function, effective_pde_direction,
                                        sample_pred, sample_label, device=device, training_mode=training_mode)
            if training_mode == "unified":
                # pde residual is calculated using predicted input channel and output channel
                pde_loss = evaluation_utils.compute_pde_loss(dataset_obj.pde_loss_function, effective_pde_direction,
                                        sample_pred[:,0:1], sample_pred[:,1:2], device=device, training_mode=training_mode)
            pde_res_norm = torch.norm(pde_loss).item()
            pde_res_mean = torch.mean(torch.abs(pde_loss)).item()

            # Flatten residual for moment computations
            pde_loss_flat = pde_loss.view(-1)
            # Skewness and kurtosis (Fisher definition: kurtosis of normal == 0)
            mean_r = pde_loss_flat.mean()
            std_r = pde_loss_flat.std(unbiased=False) + 1e-8
            centered = pde_loss_flat - mean_r
            skewness = (centered.pow(3).mean() / (std_r ** 3)).item()
            kurtosis = (centered.pow(4).mean() / (std_r ** 4)).item() - 3.0
            # Huber loss (delta=1.0)
            huber = torch.nn.functional.huber_loss(pde_loss, torch.zeros_like(pde_loss), delta=1.0, reduction='mean').item()
            # Boundary means (1-pixel and 2-pixel borders)
            # Extract 2D residual map regardless of dimensionality [H,W]
            if pde_loss.ndim == 4:
                res2d = pde_loss[0, 0]
                # breakpoint()
            elif pde_loss.ndim == 3:
                res2d = pde_loss[0]
                # breakpoint()
            elif pde_loss.ndim == 2:
                res2d = pde_loss
            else:
                res2d = pde_loss.view(1, -1)  # Fallback to avoid crash
                # breakpoint()
            h_r, w_r = res2d.shape[-2], res2d.shape[-1]

            # 1-pixel boundary mask
            b1 = torch.zeros(h_r, w_r, dtype=torch.bool, device=res2d.device)
            b1[0, :] = True; b1[-1, :] = True; b1[:, 0] = True; b1[:, -1] = True
            # 2-pixel boundary mask
            b2 = torch.zeros(h_r, w_r, dtype=torch.bool, device=res2d.device)
            b2[:2, :] = True; b2[-2:, :] = True; b2[:, :2] = True; b2[:, -2:] = True
            boundary1_mean = torch.mean(torch.abs(res2d[b1])).item()
            boundary2_mean = torch.mean(torch.abs(res2d[b2])).item()

            metrics_dict_norm["pde_residuals_norm"].append(pde_res_norm)
            metrics_dict_norm["pde_residuals_mean"].append(pde_res_mean)
            metrics_dict_norm["pde_residuals_skewness"].append(skewness)
            metrics_dict_norm["pde_residuals_kurtosis"].append(kurtosis)
            metrics_dict_norm["pde_residuals_huber"].append(huber)
            metrics_dict_norm["pde_boundary1_mean"].append(boundary1_mean)
            metrics_dict_norm["pde_boundary2_mean"].append(boundary2_mean)

            # Denormalize images
            # sample_real_denorm = dataset_obj.denorm_output(sample_real)
            # sample_pred_denorm = dataset_obj.denorm_output(sample_pred)
            # sample_label_denorm = dataset_obj.denorm_input(sample_label)
            sample_real_denorm = images_real_denorm[i].unsqueeze(0)
            sample_pred_denorm = images_pred_denorm[i].unsqueeze(0)
            sample_label_denorm = labels_denorm[i].unsqueeze(0)

            # Calculate denormalized metrics
            if training_mode == "unified":
                # For unified models, calculate metrics for the specific component we're evaluating
                if effective_pde_direction == "forward":
                    # Forward: We're predicting u (channel 1)
                    sample_pred_denorm_component = sample_pred_denorm[:, 1:2]
                    sample_real_denorm_component = sample_real_denorm[:, 1:2]
                    metrics_denorm = calculate_metrics(sample_pred_denorm_component, sample_real_denorm_component)
                else:  # inverse
                    # Inverse: We're predicting a (channel 0)
                    sample_pred_denorm_component = sample_pred_denorm[:, 0:1]
                    sample_real_denorm_component = sample_real_denorm[:, 0:1]
                    metrics_denorm = calculate_metrics(sample_pred_denorm_component, sample_real_denorm_component)

            if training_mode == "conditional":  # conditional mode
                metrics_denorm = calculate_metrics(sample_pred_denorm, sample_real_denorm)
                
            for key in list(metrics_norm.keys()):
                metrics_dict_denorm[key].append(metrics_denorm[key].mean().item())

            # Compute PDE residuals
            if training_mode == "conditional":
                pde_loss_denorm  = evaluation_utils.compute_pde_loss(dataset_obj.pde_loss_function, effective_pde_direction, 
                                                sample_pred_denorm, sample_label_denorm, device=device)
            if training_mode == "unified":
                # pde residual is calculated using predicted input channel and output channel
                pde_loss_denorm = evaluation_utils.compute_pde_loss(dataset_obj.pde_loss_function, effective_pde_direction,
                                        sample_pred_denorm[:,0:1], sample_pred_denorm[:,1:2], device=device, training_mode=training_mode)                                
            pde_res_norm = torch.norm(pde_loss_denorm).item()
            pde_res_mean = torch.mean(torch.abs(pde_loss_denorm)).item()

            # Flatten and compute moments
            pde_loss_denorm_flat = pde_loss_denorm.view(-1)
            mean_r_d = pde_loss_denorm_flat.mean()
            std_r_d = pde_loss_denorm_flat.std(unbiased=False) + 1e-8
            centered_d = pde_loss_denorm_flat - mean_r_d
            skewness_d = (centered_d.pow(3).mean() / (std_r_d ** 3)).item()
            kurtosis_d = (centered_d.pow(4).mean() / (std_r_d ** 4)).item() - 3.0
            huber_d = torch.nn.functional.huber_loss(pde_loss_denorm, torch.zeros_like(pde_loss_denorm), delta=1.0, reduction='mean').item()
            # Boundary means on denormalized residual map
            if pde_loss_denorm.ndim == 4:
                res2d_d = pde_loss_denorm[0, 0]
            elif pde_loss_denorm.ndim == 3:
                res2d_d = pde_loss_denorm[0]
            elif pde_loss_denorm.ndim == 2:
                res2d_d = pde_loss_denorm
            else:
                res2d_d = pde_loss_denorm.view(1, -1)
            h_d, w_d = res2d_d.shape[-2], res2d_d.shape[-1]
            b1_d = torch.zeros(h_d, w_d, dtype=torch.bool, device=res2d_d.device)
            b1_d[0, :] = True; b1_d[-1, :] = True; b1_d[:, 0] = True; b1_d[:, -1] = True
            b2_d = torch.zeros(h_d, w_d, dtype=torch.bool, device=res2d_d.device)
            b2_d[:2, :] = True; b2_d[-2:, :] = True; b2_d[:, :2] = True; b2_d[:, -2:] = True
            boundary1_mean_d = torch.mean(torch.abs(res2d_d[b1_d])).item()
            boundary2_mean_d = torch.mean(torch.abs(res2d_d[b2_d])).item()

            metrics_dict_denorm["pde_residuals_norm"].append(pde_res_norm)
            metrics_dict_denorm["pde_residuals_mean"].append(pde_res_mean)
            metrics_dict_denorm["pde_residuals_skewness"].append(skewness_d)
            metrics_dict_denorm["pde_residuals_kurtosis"].append(kurtosis_d)
            metrics_dict_denorm["pde_residuals_huber"].append(huber_d)
            metrics_dict_denorm["pde_boundary1_mean"].append(boundary1_mean_d)
            metrics_dict_denorm["pde_boundary2_mean"].append(boundary2_mean_d)
            if training_mode == "unified":
                if dataset_obj.name == "DiffusionPDEDarcyDataset" and effective_pde_direction == "inverse":
                    coefficient_error_rate = torch.mean(compute_coefficient_error_rate(sample_pred_denorm[:, 0:1] , sample_real_denorm[:, 0:1]))
                    metrics_dict_denorm["coefficient_error_rate"].append(coefficient_error_rate)
            else:
                if dataset_obj.name == "DiffusionPDEDarcyDataset" and effective_pde_direction == "inverse":
                    coefficient_error_rate = torch.mean(compute_coefficient_error_rate(sample_pred_denorm, sample_real_denorm))
                    metrics_dict_denorm["coefficient_error_rate"].append(coefficient_error_rate)
            
            metrics_time = time.time()
            denorm_rel_error_l2 = metrics_denorm["relative_error_l2"].mean().item()
            if training_mode == "unified":
                pde_input = sample_real_denorm[:, 0:1] 
                pde_output = sample_real_denorm[:, 1:2] 
                # prediction = sample_pred_denorm[:, 1:2] 
                if effective_pde_direction == "forward":
                    # Forward: input=labels, output=images
                    prediction = sample_pred_denorm[:, 1:2] 
                    difference = torch.abs(pde_output - prediction).squeeze(0).squeeze(0)
                else:  # inverse
                    # Inverse: input=images, output=labels
                    prediction = sample_pred_denorm[:, 0:1] 
                    difference = torch.abs(pde_input - prediction).squeeze(0).squeeze(0)
            if training_mode == "conditional":
                pde_input =  sample_label_denorm if effective_pde_direction == "forward" else sample_real_denorm
                pde_output = sample_real_denorm if effective_pde_direction == "forward" else sample_label_denorm
                prediction = sample_pred_denorm
                difference = torch.abs(sample_real_denorm - prediction).squeeze(0).squeeze(0)
            
            sample_details = {
                "id": total_samples+i,
                "pde_input":  pde_input.detach().cpu().squeeze(0),
                "pde_output": pde_output.detach().cpu().squeeze(0),
                "prediction": prediction.detach().cpu().squeeze(0),
                "difference": difference.detach().cpu(),
                "pde_residual_img": pde_loss_denorm.detach().cpu().squeeze(0),
                "rel_error_l2": denorm_rel_error_l2
            }

            # sample_details = {
            #     "id": total_samples+i,
            #     "pde_input":  sample_label_denorm.cpu().squeeze(0) if effective_pde_direction == "forward" else sample_real_denorm.cpu().squeeze(0),
            #     "pde_output": sample_real_denorm.cpu().squeeze(0) if effective_pde_direction == "forward" else sample_label_denorm.cpu().squeeze(0),
            #     "prediction": sample_pred_denorm.cpu().squeeze(0),
            #     "difference": torch.abs(sample_real_denorm.cpu() - sample_pred_denorm.cpu()).squeeze(0).squeeze(0),
            #     "pde_residual_img": pde_loss_denorm.cpu().squeeze(0).squeeze(0),
            #     "rel_error_l2": denorm_rel_error_l2
            # }
            sample_errors.append((denorm_rel_error_l2, sample_details))

        total_samples += images_real.shape[0]
        if batch_idx==15 and debug:
                print("Debugging mode, breaking after 15 batches")
                break
                
        # More aggressive memory cleanup
        del images_real, images_pred, labels, latents, sample_details
        if use_dps and 'images_pred_vis' in locals():
            del images_pred_vis
        if 'images_real_denorm' in locals():
            del images_real_denorm
        if 'images_pred_denorm' in locals():
            del images_pred_denorm
        if 'labels_denorm' in locals():
            del labels_denorm
        # Add these lines:
        if 'x_cur' in locals(): del x_cur
        if 'x_N' in locals(): del x_N 
        if 'x_next' in locals(): del x_next
        if 'pde_residual' in locals(): del pde_residual
        if 'pde_loss' in locals(): del pde_loss
        if 'pde_loss_denorm' in locals(): del pde_loss_denorm
        
        # Explicitly run garbage collection and empty CUDA cache after each batch when using DPS
        gc.collect()
        torch.cuda.empty_cache()
        
    # Sort by error to get the absolute worst samples
    sample_errors.sort(key=lambda x: x[0], reverse=True)
    worst_reservoir = [x[1] for x in sample_errors[:viz_samples]]

    # Randomly shuffle and select 10 samples for random reservoir
    random.shuffle(sample_errors)
    random_reservoir = [x[1] for x in sample_errors[:viz_samples]]  # Select 10 random samples


    results = {
    "normalized": {key: float(np.mean(metrics_dict_norm[key])) for key in metric_keys},
    "denormalized": {key: float(np.mean(metrics_dict_denorm[key])) for key in metric_keys},
    "total_num_samples": total_samples
}

    norm_rel_error_l2 = np.mean(metrics_dict_norm["relative_error_l2"])
    denorm_rel_error_l2 = np.mean(metrics_dict_denorm["relative_error_l2"])

    # Create a specific output directory for this checkpoint
    outdir = os.path.join(outdir, f"{model_kimg}_kimg")
    os.makedirs(outdir, exist_ok=True)
    output_json = os.path.join(outdir, f"{dataset_obj.print_name}_{effective_pde_direction}_eval_metrics.json")
    with open(output_json, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"Evaluation metrics saved to {output_json}")

    # Save visualizations (PDFs)
    output_pdf = os.path.join(outdir, f"{dataset_obj.print_name}_{effective_pde_direction}_visualization.pdf")
    output_pdf_common = os.path.join(outdir, f"{dataset_obj.print_name}_{effective_pde_direction}_visualization_common.pdf")
    worst_pdf_path = os.path.join(outdir, f"{dataset_obj.print_name}_{effective_pde_direction}_visualization_worst.pdf")
    worst_pdf_path_common = os.path.join(outdir, f"{dataset_obj.print_name}_{effective_pde_direction}_visualization_worst_common.pdf")

    # Ensure we have samples to visualize before attempting to create PDFs
    print(f"Random reservoir size: {len(random_reservoir)}")
    print(f"Worst reservoir size: {len(worst_reservoir)}")
    if len(random_reservoir) > 0:
        print("Checking tensor shapes in random_reservoir[0]:")
        for key, tensor in random_reservoir[0].items():
            if isinstance(tensor, torch.Tensor):
                print(f"  {key}: {tensor.shape}")
            elif isinstance(tensor, np.ndarray):
                print(f"  {key}: {tensor.shape}")
            else:
                print(f"  {key}: {type(tensor)}")
         # Check if any tensor has zero elements in first dimension
        zero_dim_tensors = []
        for sample in random_reservoir:
            for key, tensor in sample.items():
                if isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.shape[0] == 0:
                    zero_dim_tensors.append((key, tensor.shape))
        
        if zero_dim_tensors:
            print("Found tensors with zero elements in first dimension:")
            for key, shape in zero_dim_tensors:
                print(f"  {key}: {shape}")
    

    if len(random_reservoir) > 0:
        save_samples_to_pdf(random_reservoir, pdf_file_path=output_pdf, pde_direction=effective_pde_direction, 
        pde_print_name=dataset_obj.print_name, common_scale=False, transform_error=True, 
        cmap='viridis',  extra_viz=['sobel', 'binary', 'discrete', 'discrete2'])

    if len(worst_reservoir) > 0:
        save_samples_to_pdf(worst_reservoir, pdf_file_path=worst_pdf_path, pde_direction=effective_pde_direction, 
        pde_print_name=dataset_obj.print_name, common_scale=False, transform_error=True, 
        cmap='viridis', extra_viz=['sobel', 'binary', 'discrete', 'discrete2'])

    viz_time = time.time()
    # dist.print0(f'Processing model {model_kimg} complete. All data saved.')
    print(f'Processing model {model_kimg} complete. All data saved.')
    
    # Generate PDE residual evolution plots if data is available
    # if pde_residual_data['num_timesteps'] is not None:
    # breakpoint()
    if pde_res_tracking:
        plot_process(pde_residual_data, 
                    outdir, 
                    dataset_obj.print_name, 
                    effective_pde_direction, 
                    k=6, 
                    n_samples=4, 
                    pde_predictions_data=pde_predictions_data)
        # plot_pde_residual_evolution(pde_residual_data, outdir, model_kimg, 
        #                           dataset_obj.print_name, effective_pde_direction)
        
        # # Generate PDE residual timestep image plots
        # plot_pde_residual_timestep_images(pde_residual_data, outdir, k=5)
        # Generate simple process images: predictions + residuals across time
    
    torch.cuda.empty_cache()

    with open(os.path.join(outdir, "eval_done.txt"), "w") as f:
        f.write("Evaluation completed successfully!\n")
    print("Evaluation Done!")

    print(
        f"[Batch {batch_idx}] Time (s): Sampling = {sampling_time - batch_start_time:.2f}, "
        f"Metrics = {metrics_time - sampling_time:.2f}, "
        f"Visualization+Storage = {viz_time - metrics_time:.2f}, "
        f"Total = {viz_time - batch_start_time:.2f}"
    )




    return norm_rel_error_l2, denorm_rel_error_l2

def generate_grf_noise(shape, device, length_scale=0.1):
    """Sample Gaussian-random-field noise with a squared-exponential spectrum.

    Every ``(sample, channel)`` slice is drawn independently: complex white
    noise is multiplied in Fourier space by
    ``sqrt(exp(-|k|^2 * length_scale^2 / 2))`` (``k`` are integer FFT
    frequencies), brought back with ``ifft2``, and the real part is rescaled to
    unit (unbiased) standard deviation.

    Args:
        shape: ``(batch, channels, height, width)``.
        device: torch device on which all tensors are created.
        length_scale: larger values damp high frequencies more strongly,
            giving smoother fields.

    Returns:
        Float tensor of ``shape`` on ``device``.
    """
    n_batch, n_chan, rows, cols = shape

    # fftfreq(n, d=1/n) yields integer frequencies 0, 1, ..., -1.
    ky, kx = torch.meshgrid(
        torch.fft.fftfreq(rows, d=1 / rows).to(device),
        torch.fft.fftfreq(cols, d=1 / cols).to(device),
        indexing='ij',
    )
    k_sq = ky.pow(2) + kx.pow(2)
    # Amplitude filter = sqrt of the power spectrum exp(-k^2 * l^2 / 2);
    # it is loop-invariant, so compute it once.
    amplitude = torch.exp(-k_sq * (length_scale ** 2) / 2).sqrt()

    out = torch.zeros(shape, device=device)
    # Flattened walk over (b, c) in the same order as a nested b-then-c loop,
    # so the random draws line up slice by slice.
    for flat in range(n_batch * n_chan):
        b, c = divmod(flat, n_chan)
        spectrum = torch.complex(
            torch.randn(rows, cols, device=device),  # real part drawn first
            torch.randn(rows, cols, device=device),  # then imaginary part
        )
        field = torch.fft.ifft2(spectrum * amplitude).real
        out[b, c] = field / field.std()
    return out

def _require_grid_sparsity(sparsity_ratio, mode):
    """Raise ValueError unless ``sparsity_ratio < 1`` (needed by the grid-step formulas)."""
    # Both grid modes compute 1 / sqrt(1 - sparsity_ratio); for values >= 1 this is
    # inf/NaN and int() would otherwise fail with an opaque OverflowError/ValueError.
    if not sparsity_ratio < 1:
        raise ValueError(
            f"mask_filling_mode={mode!r} requires sparsity_ratio < 1, got {sparsity_ratio!r}"
        )


def _fill_grid_based(labels, sparsity_ratio):
    """Observe a regular grid and fill every other pixel with the max of nearby grid values."""
    _require_grid_sparsity(sparsity_ratio, 'grid_based')
    # One observed pixel per grid_step x grid_step tile keeps ~(1 - sparsity_ratio).
    grid_step = max(1, int(1 / np.sqrt(1 - sparsity_ratio)))
    observed = torch.zeros_like(labels, dtype=torch.bool)
    observed[:, :, ::grid_step, ::grid_step] = True

    # Unobserved pixels are set to -inf so the max-pool can only select observed
    # values; filling them with 0 would let the fill value win wherever all nearby
    # observations are negative. Every (2*grid_step+1)-wide window contains at
    # least one grid point, so no -inf reaches the output.
    sparse = labels.masked_fill(~observed, float('-inf'))
    filled = torch.nn.functional.max_pool2d(
        sparse,
        kernel_size=grid_step * 2 + 1,
        stride=1,
        padding=grid_step,
    )
    # Observed grid points keep their exact values.
    return torch.where(observed, labels, filled)


def _fill_hybrid_grf(labels, sparsity_ratio):
    """Random plus coarse-grid observations, RBF-interpolated per (sample, channel)."""
    _require_grid_sparsity(sparsity_ratio, 'hybrid_grf')
    observed = torch.rand_like(labels) > sparsity_ratio
    # A coarse regular grid guarantees some spatial coverage on top of the random picks.
    grid_step = max(8, int(4 / np.sqrt(1 - sparsity_ratio)))
    grid_mask = torch.zeros_like(labels)
    grid_mask[:, :, ::grid_step, ::grid_step] = 1.0
    observed = torch.logical_or(observed, grid_mask)

    filled = labels.clone()
    y_grid, x_grid = np.mgrid[0:labels.shape[2], 0:labels.shape[3]]
    for b in range(labels.shape[0]):
        for c in range(labels.shape[1]):
            obs_bc = observed[b, c]
            y_obs, x_obs = np.where(obs_bc.cpu().numpy())
            if len(y_obs) <= 10:
                # Too few points for a stable RBF fit: this field is left as-is
                # (i.e. fully observed).
                continue
            values = labels[b, c, y_obs, x_obs].cpu().numpy()
            try:
                # Multiquadric RBF with light smoothing; epsilon is capped at 5 pixels.
                rbf = Rbf(x_obs, y_obs, values, function='multiquadric',
                          epsilon=min(grid_step, 5.0), smooth=0.03)
                fill = torch.tensor(rbf(x_grid, y_grid), device=labels.device)
            except Exception:
                # Deliberate best effort: if the RBF system cannot be built or
                # solved (e.g. degenerate point layouts), use the observed mean.
                fill = torch.tensor(values.mean(), device=labels.device)
            filled[b, c] = torch.where(obs_bc, labels[b, c], fill)
    return filled


def _nearest_neighbor_fill_2d(field, observed, chunk_size=1024):
    """Fill unobserved pixels of a 2-D ``field`` with the value of the nearest observed pixel.

    Distances are exact integer squared Euclidean distances in pixel units; ties
    go to the first observed pixel in row-major order. Unobserved pixels are
    processed ``chunk_size`` at a time so the distance matrix stays bounded
    (``chunk_size x num_observed``). If nothing is observed, a copy of ``field``
    is returned unchanged.
    """
    filled = field.clone()
    obs_idx = torch.nonzero(observed)
    if obs_idx.shape[0] == 0:
        return filled
    obs_vals = field[observed]  # row-major order, matching torch.nonzero
    miss_idx = torch.nonzero(~observed)
    for start in range(0, miss_idx.shape[0], chunk_size):
        chunk = miss_idx[start:start + chunk_size]
        d2 = (chunk[:, None, :] - obs_idx[None, :, :]).pow(2).sum(dim=-1)
        nearest = d2.argmin(dim=1)
        filled[chunk[:, 0], chunk[:, 1]] = obs_vals[nearest]
    return filled


def _fill_random_mask(labels, sparsity_ratio, mask_filling_mode, noise_magnitude):
    """Hide each pixel independently with probability ~``sparsity_ratio`` and fill per mode."""
    observed = torch.rand_like(labels) > sparsity_ratio

    if mask_filling_mode == 'noise':
        # Hidden pixels are replaced by pure noise.
        noise = torch.randn_like(labels) * noise_magnitude
        return labels * observed + noise * (~observed)

    if mask_filling_mode == 'additive_noise':
        # ``observed`` is not used by this mode (it is still drawn above, which
        # advances the RNG). Noise is added to, not substituted for, the signal.
        noise_mask = torch.rand_like(labels) < sparsity_ratio
        noise = torch.randn_like(labels) * noise_magnitude
        return labels + noise * noise_mask

    if mask_filling_mode == 'mean':
        # NOTE(review): the fill is the mean over the whole tensor, hidden pixels included.
        mean_value = labels.mean().item()
        return labels * observed + mean_value * (~observed)

    if mask_filling_mode == 'nearest_neighbor':
        filled = labels.clone()
        for b in range(labels.shape[0]):
            for c in range(labels.shape[1]):
                filled[b, c] = _nearest_neighbor_fill_2d(labels[b, c], observed[b, c])
        return filled

    # Unknown mode: nothing is hidden; an unmodified copy is returned.
    return labels.clone()


def get_masked_noisy_labels(labels, sparsity_ratio, mask_filling_mode, noise_magnitude):
    """
    Sparsify ``labels`` and fill the unobserved pixels.

    Args:
        labels: ``[B, C, H, W]`` tensor of full fields; it is not modified in place.
        sparsity_ratio: fraction of pixels to hide. In the random modes a pixel is
            hidden when its uniform draw is ``<= sparsity_ratio``.
        mask_filling_mode: one of
            ``'grid_based'`` - keep a regular grid with step
            ``max(1, int(1/sqrt(1-sparsity_ratio)))`` and fill every other pixel
            with the maximum of the grid values within ``grid_step`` pixels;
            ``'hybrid_grf'`` - random plus coarse-grid observations, interpolated
            with a multiquadric RBF for every (sample, channel) field; fields with
            10 or fewer observations are left unchanged;
            ``'noise'`` - hidden pixels replaced by ``N(0, noise_magnitude**2)``;
            ``'additive_noise'`` - noise added (not substituted) to roughly a
            ``sparsity_ratio`` fraction of pixels;
            ``'mean'`` - hidden pixels set to the mean of the whole ``labels`` tensor;
            ``'nearest_neighbor'`` - hidden pixels copy the nearest observed pixel
            of the same (sample, channel) field.
            Any other value returns an unmodified copy of ``labels``.
        noise_magnitude: standard deviation of the noise in ``'noise'`` and
            ``'additive_noise'`` modes (ignored otherwise).

    Returns:
        A new tensor shaped like ``labels``.

    Raises:
        ValueError: for ``'grid_based'`` / ``'hybrid_grf'`` when ``sparsity_ratio >= 1``.
    """
    if mask_filling_mode == 'grid_based':
        return _fill_grid_based(labels, sparsity_ratio)
    if mask_filling_mode == 'hybrid_grf':
        return _fill_hybrid_grf(labels, sparsity_ratio)
    return _fill_random_mask(labels, sparsity_ratio, mask_filling_mode, noise_magnitude)

def get_random_masks(labels, sparsity_ratio):
    """Draw an independent keep-mask shaped like ``labels``.

    An entry is kept (1.0) when its uniform ``[0, 1)`` draw exceeds
    ``sparsity_ratio`` and dropped (0.0) otherwise.

    Returns:
        float32 tensor of 0.0 / 1.0 values with the shape and device of ``labels``.
    """
    uniform = torch.rand_like(labels)
    keep = uniform > sparsity_ratio
    return keep.float()

# NOTE: ``apply_masks`` is defined once, further down in this module. A duplicate
# definition that used to sit here was always shadowed by that later one at
# import time, so it could never be called and has been removed.




def _ensure_odd(k):
    return int(k) if int(k) % 2 == 1 else int(k) + 1

def gaussian_blur_torch(x, kernel_size=7, sigma=2.0, padding_mode='zeros'):
    """
    Blur every channel of ``x`` independently with a normalized 2-D Gaussian.

    Args:
        x: ``[B, C, H, W]`` floating-point tensor.
        kernel_size: kernel width; even values are rounded up to the next odd
            integer (``_ensure_odd``) so the kernel has a center tap.
        sigma: Gaussian standard deviation in pixels; must be positive.
        padding_mode: how the border is extended before convolving.
            ``'zeros'`` (default) pads with zeros, which pulls values near the
            border toward 0. ``'reflect'``, ``'replicate'`` and ``'circular'``
            are passed to ``torch.nn.functional.pad``; ``'reflect'`` requires
            ``kernel_size // 2`` to be smaller than ``H`` and ``W``.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.

    Raises:
        ValueError: if ``sigma`` is not positive (it would otherwise produce a
            NaN kernel) or ``padding_mode`` is not recognised.
    """
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    if padding_mode not in ('zeros', 'reflect', 'replicate', 'circular'):
        raise ValueError(f"unsupported padding_mode {padding_mode!r}")

    k = _ensure_odd(kernel_size)
    pad = k // 2

    # 1-D Gaussian, normalized to sum to 1.
    coords = torch.arange(k, device=x.device, dtype=x.dtype) - pad
    g = torch.exp(-(coords ** 2) / (2 * (sigma ** 2)))
    g = g / g.sum()

    # Separable 2-D kernel (outer product), renormalized against rounding.
    g2d = g[:, None] * g[None, :]
    g2d = (g2d / g2d.sum()).view(1, 1, k, k)            # [1,1,k,k]

    # Depthwise convolution: one copy of the kernel per channel.
    _, channels, _, _ = x.shape
    kernel = g2d.repeat(channels, 1, 1, 1)              # [C,1,k,k]
    if padding_mode == 'zeros':
        return F.conv2d(x, kernel, padding=pad, groups=channels)
    padded = F.pad(x, (pad, pad, pad, pad), mode=padding_mode)
    return F.conv2d(padded, kernel, groups=channels)

def mean_impute(labels, masks):
    """
    Replace masked-out entries with the mean of the observed ones.

    ``labels`` and ``masks`` are ``[B, C, H, W]``; ``masks`` holds 1 for observed
    and 0 for hidden entries. The mean is taken per sample and per channel over
    the last two dimensions. If a slice has no observed entries, its count is
    clamped to 1, so its hidden entries become 0.
    """
    kept = labels * masks
    spatial = (-2, -1)
    n_obs = torch.clamp(masks.sum(dim=spatial, keepdim=True), min=1.0)
    avg = kept.sum(dim=spatial, keepdim=True) / n_obs
    return kept + (1.0 - masks) * avg

# --- mask application ---
def apply_masks(
    labels, masks, mode='zero_noise', noise_scale=0.01,
    kernel_size=7, sigma=2.0
):
    """
    Hide the unobserved part of ``labels`` and fill it according to ``mode``.

    Args:
        labels: ``[B, C, H, W]`` tensor holding the full (unmasked) field.
        masks: ``[B, C, H, W]`` float tensor, 1 where a value is observed and 0
            where it must be hidden (``1 - masks`` is used, so bool masks are
            not supported).
        mode:
          - ``"zero"``: hidden entries become 0.
          - ``"mean"``: hidden entries take the per-sample, per-channel mean of
            the observed entries (``mean_impute``).
          - ``"noise"``: hidden entries are ``N(0, noise_scale**2)`` samples.
          - ``"zero_noise"``: hidden entries are ``N(0, (0.1*noise_scale)**2)``.
          - ``"gaussian"``: mean-impute, then Gaussian-blur the whole field
            (observed entries are smoothed as well).
          - ``"gaussian_blend"``: observed entries are kept exactly; each hidden
            entry is a Gaussian-weighted sum of nearby *observed* values, with
            the weight not covered by observations assigned to the observed mean.
          - anything else: prints a warning and returns ``labels`` unchanged.
        noise_scale: standard deviation used by the noise modes.
        kernel_size, sigma: Gaussian kernel parameters for the Gaussian modes.

    Returns:
        Tensor shaped like ``labels``.
    """
    if mode == "zero":
        return labels * masks

    elif mode == "mean":
        return mean_impute(labels, masks)

    elif mode == "noise":
        noise = torch.randn_like(labels) * noise_scale
        return labels * masks + noise * (1 - masks)

    elif mode == "zero_noise":
        small_noise = torch.randn_like(labels) * (noise_scale * 0.1)
        return labels * masks + small_noise * (1 - masks)

    elif mode == "gaussian":
        # mean-impute, then smooth to remove hard edges/ringing
        filled = mean_impute(labels, masks)
        return gaussian_blur_torch(filled, kernel_size=kernel_size, sigma=sigma)

    elif mode == "gaussian_blend":
        # Only observed values may contribute: ``labels`` still contains the
        # ground truth at hidden positions, so it is zeroed there first to
        # prevent it from leaking into the fill.
        observed = labels * masks
        # conf[p] = total Gaussian weight of observed entries around p.
        conf = gaussian_blur_torch(masks, kernel_size=kernel_size, sigma=sigma).clamp(0, 1)
        # Gaussian-weighted sum of observed neighbours (= conf * local average).
        local = gaussian_blur_torch(observed, kernel_size=kernel_size, sigma=sigma)
        # Remaining weight (1 - conf) goes to the observed mean.
        filled = mean_impute(labels, masks)
        blended = local + (1.0 - conf) * filled
        # Observed entries are kept exactly; hidden entries get the blend.
        return observed + blended * (1.0 - masks)

    else:
        print(f"Unknown mask filling mode '{mode}', defaulting to no fill.")
        return labels
#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Script entry point; ``main`` is expected to be defined elsewhere in this module.
    main()

#----------------------------------------------------------------------------
