import os
import json
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.ndimage import sobel
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
# -------------------------------------------------------------------------------
# Error metric functions (using PyTorch)
# -------------------------------------------------------------------------------
def compute_relative_error(predicted, true, p=2):
    """
    Per-sample relative error ``||predicted - true||_p / ||true||_p``.

    Each sample (index along dim 0) is flattened to a vector before the
    p-norms are taken, so the result holds one value per sample.
    No guard is applied against a zero norm of ``true``.
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."
    batch = predicted.size(0)
    flat_pred = predicted.reshape(batch, -1)
    flat_true = true.reshape(batch, -1)
    numerator = torch.norm(flat_pred - flat_true, p=p, dim=1)
    denominator = torch.norm(flat_true, p=p, dim=1)
    return numerator / denominator

def compute_relative_error_componentwise(predicted, true, p=2):
    """
    Per-sample p-norm of the element-wise relative error.

    Every element's error ``predicted - true`` is first divided by the
    matching element of ``true``; the p-norm of those ratios is then taken
    over each flattened sample. Elements where ``true`` is zero are not
    guarded and yield inf/nan.
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."
    batch = predicted.size(0)
    flat_pred = predicted.reshape(batch, -1)
    flat_true = true.reshape(batch, -1)
    ratios = (flat_pred - flat_true) / flat_true
    return torch.norm(ratios, p=p, dim=1)

def compute_mean_absolute_error(predicted, true):
    """
    Mean absolute error of each sample, averaged over every non-batch dim.
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."
    reduce_dims = tuple(range(1, predicted.dim()))
    abs_diff = torch.abs(predicted - true)
    return torch.mean(abs_diff, dim=reduce_dims)

def compute_mean_squared_error(predicted, true):
    """
    Mean squared error of each sample, averaged over every non-batch dim.
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."
    reduce_dims = tuple(range(1, predicted.dim()))
    squared_diff = (predicted - true) ** 2
    return torch.mean(squared_diff, dim=reduce_dims)

def compute_vrmse(predicted, true, eps=1e-7):
    """
    Compute the Variance-Scaled Root Mean Squared Error (VRMSE) per sample.

    VRMSE = sqrt(MSE(predicted, true) / (Var(true) + eps)), where both the
    MSE and the variance are taken over all non-batch dimensions of each
    sample. The variance uses torch's default (unbiased) estimator, matching
    the previous ``torch.std(...) ** 2`` formulation.

    Args:
        predicted (torch.Tensor): Predicted tensor of shape (B, ...).
        true (torch.Tensor): Ground truth tensor of shape (B, ...).
        eps (float): Small epsilon to avoid division by zero.

    Returns:
        torch.Tensor: VRMSE per sample (shape: [B]).
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."

    # Reduce over everything except the batch dimension.
    dims = tuple(range(1, predicted.dim()))

    # Per-sample MSE, shape [B].
    mse = torch.mean((predicted - true) ** 2, dim=dims)

    # Per-sample variance of the ground truth, shape [B]. The reduced dims
    # must NOT be kept: a (B, 1, ..., 1) variance would broadcast against the
    # (B,) MSE into a (B, ..., B) cross-product instead of a per-sample ratio.
    true_var = torch.var(true, dim=dims)

    return torch.sqrt(mse / (true_var + eps))

# def compute_coefficient_error_rate(predicted, true):
#     """
#     Compute the coefficient error rate for Darcy Flow.
#     """
#     assert predicted.shape == true.shape, "Predicted and true must have the same shape."
#     print("predicted shape:", predicted.shape)
#     predicted = predicted.to('cpu')
#     true = true.to('cpu')
#     predicted[predicted>7.5] = 12 # a is binary
#     predicted[predicted<=7.5] = 3
#     error_rate_a = 1 - torch.sum(predicted==true) / (128*128) #this is for one sample
#     print("Predicted (after thresholding):", predicted)
#     print("Error rate:", error_rate_a)

#     # Compute per-sample error rate
#     total_pixels = predicted[0].numel()  # or H * W
#     print(total_pixels)
#     error_rate_per_sample = 1 - (predicted == true).sum(dim=(1, 2)) / total_pixels
#     print("error_rate per sample shape:", error_rate_per_sample.shape) # BX
#     avg_error_rate = error_rate_per_sample.mean().item()

#     print("Average error rate:", avg_error_rate)
#     return error_rate_per_sample

def compute_coefficient_error_rate(predicted, true, threshold=7.5, low_value=3, high_value=12):
    """
    Compute the per-sample coefficient error rate for Darcy Flow.

    The prediction is binarised (values above ``threshold`` become
    ``high_value``, values at or below it become ``low_value``) and compared
    element-wise with ``true``. The error rate of a sample is the fraction of
    pixels where the binarised prediction differs from ``true``.

    Supports input shapes (B, H, W) and (B, 1, H, W). The caller's tensors are
    not modified: the prediction is cloned before thresholding.

    Args:
        predicted (torch.Tensor): Predicted coefficient field.
        true (torch.Tensor): Ground-truth coefficient field, same shape.
        threshold (float): Decision boundary between the two coefficient values.
        low_value (float): Value assigned to entries <= ``threshold``.
        high_value (float): Value assigned to entries > ``threshold``.

    Returns:
        torch.Tensor: Error rate per sample (shape: [B]). Take ``.mean()`` of
        the result to obtain the batch-average error rate.
    """
    assert predicted.shape == true.shape, "Predicted and true must have the same shape."

    # Work on a CPU copy so the in-place thresholding never touches the input.
    predicted = predicted.to('cpu').clone()
    true = true.to('cpu')

    # If there's a singleton channel dimension (B, 1, H, W), squeeze it out.
    if predicted.ndim == 4 and predicted.shape[1] == 1:
        predicted = predicted.squeeze(1)
        true = true.squeeze(1)

    # Build both masks before writing, so the result does not depend on how
    # the replacement values compare with the threshold. NaN entries match
    # neither mask and therefore always count as errors.
    high_mask = predicted > threshold
    low_mask = predicted <= threshold
    predicted[high_mask] = high_value
    predicted[low_mask] = low_value

    # Total pixels per sample (H * W).
    total_pixels = predicted.shape[1] * predicted.shape[2]

    correct_per_sample = (predicted == true).sum(dim=(1, 2))
    return 1 - correct_per_sample.float() / total_pixels


def calculate_metrics(predicted, true):
    """
    Evaluate every error metric defined in this module on one batch.

    Returns a dictionary mapping the metric name to the tensor produced by
    the corresponding ``compute_*`` function.
    """
    metrics = {}
    metrics["relative_error_l2"] = compute_relative_error(predicted, true, p=2)
    metrics["componentwise_relative_error_l2"] = compute_relative_error_componentwise(predicted, true, p=2)
    metrics["relative_error_l1"] = compute_relative_error(predicted, true, p=1)
    metrics["componentwise_relative_error_l1"] = compute_relative_error_componentwise(predicted, true, p=1)
    metrics["mean_absolute_error"] = compute_mean_absolute_error(predicted, true)
    metrics["mean_squared_error"] = compute_mean_squared_error(predicted, true)
    metrics["vrmse"] = compute_vrmse(predicted, true)
    return metrics

def accumulate_metrics(metrics_dict, metrics):
    """
    Append the batch mean of each metric to its running list.

    Args:
    - metrics_dict (dict): Accumulator; must already hold a list for every
      key present in ``metrics``. It is modified in place.
    - metrics (dict): Metric tensors of the current batch.

    Returns:
    - The same ``metrics_dict`` object, after the update.
    """
    for name, batch_values in metrics.items():
        batch_mean = batch_values.mean().item()
        metrics_dict[name].append(batch_mean)
    return metrics_dict

# -------------------------------------------------------------------------------
# Residual functions (using PyTorch)
# -------------------------------------------------------------------------------

def get_darcy_loss_fno(pde_input, pde_output, device=torch.device('cpu'), D=1):
    """
    Finite-difference residual of the Darcy equation -div(a * grad(u)) = 1.

    pde_input (a) and pde_output (u) are batched as [B, 1, H, W]. The grid
    spacing is D / (H - 1) and is used for both spatial axes.
    Central differences are applied twice, which removes two cells on every
    side; the result is zero-padded back, so the returned residual has shape
    [B, H, W] with a zero border of width two.

    The ``device`` argument is accepted but not used by this function.
    """
    grid_size = pde_output.size(2)
    field = pde_output.squeeze(1)  # (B, H, W)
    coeff = pde_input.squeeze(1)

    spacing = D / (grid_size - 1)
    two_h = 2 * spacing

    # First derivatives of u on the interior: (B, H-2, W-2)
    grad_0 = (field[:, 2:, 1:-1] - field[:, :-2, 1:-1]) / two_h
    grad_1 = (field[:, 1:-1, 2:] - field[:, 1:-1, :-2]) / two_h

    # Flux a * grad(u), with the coefficient cropped to the same interior
    interior_coeff = coeff[:, 1:-1, 1:-1]
    flux_0 = interior_coeff * grad_0
    flux_1 = interior_coeff * grad_1

    # Divergence of the flux: (B, H-4, W-4)
    div_0 = (flux_0[:, 2:, 1:-1] - flux_0[:, :-2, 1:-1]) / two_h
    div_1 = (flux_1[:, 1:-1, 2:] - flux_1[:, 1:-1, :-2]) / two_h
    Du = -(div_0 + div_1)

    # Constant unit forcing term
    residual = Du - torch.ones_like(Du)

    # Pad (left, right, top, bottom) by two to get back to [B, H, W]
    return F.pad(residual, pad=(2, 2, 2, 2), mode='constant', value=0)

def get_darcy_loss_diffusion_pde(pde_input, pde_output, device=torch.device('cuda')):
    """
    Residual of the Darcy Flow equation, div(a * grad(u)) + 1.

    pde_input (a) and pde_output (u) are batched as [B, 1, H, W]; both are
    cast to float and moved to ``device``. Derivatives are unit-spacing
    central differences implemented as zero-padded convolutions, so the
    spatial size is preserved. Returns a tensor of shape [B, H, W].
    """
    coeff = pde_input.to(torch.float).to(device)
    field = pde_output.to(torch.float).to(device)

    # Central-difference stencils along the last (x) and second-to-last (y) axis
    kernel_x = torch.tensor([[-1, 0, 1]], dtype=torch.float, device=device).view(1, 1, 1, 3) / 2
    kernel_y = torch.tensor([[-1], [0], [1]], dtype=torch.float, device=device).view(1, 1, 3, 1) / 2

    # Flux components a * du/dx and a * du/dy
    flux_x = coeff * F.conv2d(field, kernel_x, padding=(0, 1))
    flux_y = coeff * F.conv2d(field, kernel_y, padding=(1, 0))

    # Divergence of the flux
    divergence = F.conv2d(flux_x, kernel_x, padding=(0, 1)) + F.conv2d(flux_y, kernel_y, padding=(1, 0))

    residual = divergence + 1.0
    return residual.squeeze(1)

def get_burger_loss_diffusion_pde(pde_output, device=torch.device('cuda')):
    """
    Return the residual of the viscous Burgers' equation,
    u_t + u * u_x - 0.01 * u_xx.

    ``pde_output`` must contain exactly 128 * 128 values; it is interpreted
    as a single (time, space) field with time along the first of the two
    axes. Derivatives are unit-spacing central differences implemented as
    zero-padded convolutions.

    Args:
        pde_output (torch.Tensor): Field u with 128 * 128 elements.
        device (torch.device): Device on which the residual is computed.

    Returns:
        torch.Tensor: Residual of shape (128, 128) on ``device``.
    """
    # Move the field to the same device as the stencils below; otherwise
    # conv2d fails whenever the input lives on a different device.
    u = pde_output.to(device=device, dtype=torch.float)

    # reshape (rather than view) so non-contiguous inputs are accepted too.
    u = u.reshape(1, 1, 128, 128)

    deriv_t = torch.tensor([[-1], [0], [1]], dtype=torch.float, device=device).view(1, 1, 3, 1) / 2
    deriv_x = torch.tensor([[-1, 0, 1]], dtype=torch.float, device=device).view(1, 1, 1, 3) / 2
    u_t = F.conv2d(u, deriv_t, padding=(1, 0))
    u_x = F.conv2d(u, deriv_x, padding=(0, 1))
    u_xx = F.conv2d(u_x, deriv_x, padding=(0, 1))

    pde_loss = u_t + u * u_x - 0.01 * u_xx
    pde_loss = pde_loss.squeeze()

    return pde_loss

def get_helmholtz_loss_diffusion_pde(pde_input, pde_output, device=torch.device('cuda')):
    """
    Residual of the Helmholtz equation, laplacian(u) + u - a.

    The grid side S is read from dim 2 of pde_output and both tensors are
    viewed as [-1, 1, S, S]. The Laplacian is the 5-point stencil with
    spacing 1 / (S - 1); u is zero-padded by one cell so the output keeps the
    full S x S extent. Returns a tensor of shape [B, S, S].

    The ``device`` argument is accepted but not used by this function.
    """
    side = pde_output.size(2)
    spacing = 1 / (side - 1)
    source = pde_input.view(-1, 1, side, side)
    field = pde_output.view(-1, 1, side, side)

    # Zero-pad one cell on every side so each point has four neighbours
    padded = torch.nn.functional.pad(field, (1, 1, 1, 1), 'constant', 0)
    above = padded[:, :, :-2, 1:-1]
    below = padded[:, :, 2:, 1:-1]
    before = padded[:, :, 1:-1, :-2]
    after = padded[:, :, 1:-1, 2:]
    laplacian = (above + below + before + after - 4 * field) / spacing**2

    residual = laplacian + field - source
    return residual.squeeze(1)

def get_poisson_loss_diffusion_pde(pde_input, pde_output, device=torch.device('cuda')):
    """
    Residual of the Poisson equation, laplacian(u) - a.

    The grid side S is read from dim 2 of pde_output and both tensors are
    viewed as [-1, 1, S, S]. The Laplacian is the 5-point stencil with
    spacing 1 / (S - 1); u is zero-padded by one cell so the output keeps the
    full S x S extent. Returns a tensor of shape [B, S, S].

    The ``device`` argument is accepted but not used by this function.
    """
    side = pde_output.size(2)
    spacing = 1 / (side - 1)
    source = pde_input.view(-1, 1, side, side)
    field = pde_output.view(-1, 1, side, side)

    # Zero-pad one cell on every side so each point has four neighbours
    padded = torch.nn.functional.pad(field, (1, 1, 1, 1), 'constant', 0)
    above = padded[:, :, :-2, 1:-1]
    below = padded[:, :, 2:, 1:-1]
    before = padded[:, :, 1:-1, :-2]
    after = padded[:, :, 1:-1, 2:]
    laplacian = (above + below + before + after - 4 * field) / spacing**2

    residual = laplacian - source
    return residual.squeeze(1)
    
def get_ns_nonbounded_loss_diffusion_pde(pde_input, pde_output, device=torch.device('cuda')):
    """
    Residual used for the non-bounded NS case.

    Computes the sum of the unit-spacing central differences of pde_output
    along its last two axes (zero-padded convolutions, so the spatial size
    is preserved) and drops the channel dim. pde_output is expected as
    [B, 1, H, W]; the result has shape [B, H, W].

    pde_input is cast and moved to ``device`` but does not enter the result.
    """
    a = pde_input.to(device=device, dtype=torch.float)
    field = pde_output.to(device=device, dtype=torch.float)

    # Central-difference stencils along the last and second-to-last axis
    kernel_x = torch.tensor([[-1, 0, 1]], dtype=torch.float, device=device).view(1, 1, 1, 3) / 2
    kernel_y = torch.tensor([[-1], [0], [1]], dtype=torch.float, device=device).view(1, 1, 3, 1) / 2

    d_dx = F.conv2d(field, kernel_x, padding=(0, 1))
    d_dy = F.conv2d(field, kernel_y, padding=(1, 0))

    return (d_dx + d_dy).squeeze(1)

def get_ns_bounded_loss_diffusion_pde(pde_input, pde_output, device=torch.device('cuda')):
    """
    Residual used for the bounded NS case.

    Computes the sum of the unit-spacing central differences of pde_output
    along its last two axes (zero-padded convolutions, so the spatial size
    is preserved). pde_output is expected as [B, 1, H, W]; only the channel
    dim is squeezed, so the result keeps the batch dim: [B, H, W].

    pde_input is cast and moved to ``device`` but does not enter the result.
    """
    a = pde_input.to(device=device, dtype=torch.float)
    field = pde_output.to(device=device, dtype=torch.float)

    # Central-difference stencils along the last and second-to-last axis
    kernel_x = torch.tensor([[-1, 0, 1]], dtype=torch.float, device=device).view(1, 1, 1, 3) / 2
    kernel_y = torch.tensor([[-1], [0], [1]], dtype=torch.float, device=device).view(1, 1, 3, 1) / 2

    d_dx = F.conv2d(field, kernel_x, padding=(0, 1))
    d_dy = F.conv2d(field, kernel_y, padding=(1, 0))

    return (d_dx + d_dy).squeeze(1)

# def get_pde_loss_function(pde_name):
#     """
#     Retrieve the appropriate PDE loss function based on the PDE name.
    
#     Args:
#         pde_name (str): The name of the PDE (e.g., 'darcy', 'poisson', 'helmholtz', etc.)
    
#     Returns:
#         function: The corresponding PDE loss function.
#     """
#     # Function mapping for different PDE solvers
#     pde_loss_functions = {
#         'burgers': get_burger_loss,
#         'darcy': get_darcy_loss,
#         'poisson': get_poisson_loss,
#         'helmholtz': get_helmholtz_loss,
#         'ns-nonbounded': get_ns_nonbounded_loss,
#         'ns-bounded': get_ns_bounded_loss
#     }

#     # Convert PDE name to lowercase and retrieve function
#     pde_name = pde_name.lower()
#     if pde_name in pde_loss_functions:
#         return pde_loss_functions[pde_name]
#     else:
#         raise ValueError(f"Unknown PDE type: {pde_name}. Available options: {list(pde_loss_functions.keys())}")

# def dataset_to_type(dataset):
#     map = {"darcy_flow":"Darcy",
#             "helmholtz":"Helmholtz",
#             "poisson":"Poisson",}
#     return map[dataset]

def compute_pde_loss(pde_loss_fn, pde_direction, images_pred_denorm, labels_denorm, device=torch.device("cpu"), training_mode="conditional"):
    """
    Compute the PDE residual by routing the tensors to ``pde_loss_fn``.

    The two tensors are assigned to the PDE input/output roles as follows:

    - ``training_mode="conditional"``, ``pde_direction="forward"``:
      input = ``labels_denorm``, output = ``images_pred_denorm``.
    - ``training_mode="conditional"``, ``pde_direction="inverse"``:
      input = ``images_pred_denorm``, output = ``labels_denorm``.
    - ``training_mode="unified"`` (``pde_direction`` is ignored):
      input = ``images_pred_denorm``, output = ``labels_denorm``.

    Args:
        pde_loss_fn (callable): Called as
            ``pde_loss_fn(pde_input=..., pde_output=..., device=...)``.
        pde_direction (str): ``"forward"`` or ``"inverse"`` (case-sensitive).
        images_pred_denorm (torch.Tensor): Model-predicted denormalized output.
        labels_denorm (torch.Tensor): Denormalized conditioning tensor.
        device (torch.device): Device forwarded to ``pde_loss_fn``.
        training_mode (str): ``"conditional"`` or ``"unified"``.

    Returns:
        Whatever ``pde_loss_fn`` returns (the PDE residual).

    Raises:
        ValueError: If ``pde_direction`` (in conditional mode) or
            ``training_mode`` is not one of the supported values.
    """
    if training_mode == "conditional":
        if pde_direction == "forward":
            pde_input = labels_denorm   # Conditioning variable (e.g., permeability in Darcy flow)
            pde_output = images_pred_denorm  # Model output (predicted field)
        elif pde_direction == "inverse":
            pde_input = images_pred_denorm  # Model prediction serves as input
            pde_output = labels_denorm  # Label tensor serves as output
        else:
            raise ValueError(f"Unknown PDE direction: {pde_direction}. Must be 'forward' or 'inverse'.")
    elif training_mode == "unified":
        pde_input = images_pred_denorm  # first argument is the pde input
        pde_output = labels_denorm  # second argument is the pde output
    else:
        # Without this branch an unsupported mode would surface later as an
        # opaque UnboundLocalError on pde_input.
        raise ValueError(f"Unknown training mode: {training_mode}. Must be 'conditional' or 'unified'.")

    # Compute the PDE residual loss
    return pde_loss_fn(pde_input=pde_input, pde_output=pde_output, device=device)

# -------------------------------------------------------------------------------
# Visualisation Functions
# -------------------------------------------------------------------------------
def save_samples_to_pdf_old(samples, pdf_file_path, pde_direction='forward', pde_print_name="None", common_scale=False, transform_error=False, cmap='viridis', rows_per_page=10):
    """
    Save a list of sample dictionaries to a multi-page PDF.

    Each sample dictionary must have the following keys:
        - 'id'
        - 'pde_input'
        - 'pde_output'
        - 'prediction'
        - 'difference'
        - 'pde_residual_img'
        - 'rel_error_l2'

    Image entries may be torch.Tensors (a leading singleton dim of a 3-D
    tensor is squeezed) or anything np.array / imshow accepts.

    Column ordering (5 columns, or 7 when transform_error is True):
    For forward:
        [PDE Input, PDE Output, Prediction, Difference, PDE Residual]
    For any other direction (treated as inverse):
        [PDE Output, PDE Input, Prediction, Difference, PDE Residual]
    With transform_error=True two more columns are appended:
        [Transformed Difference, Transformed PDE Residual], i.e. log1p(|x|)
        of the difference / residual, drawn with the 'inferno' colormap.

    The function always computes a common color scale for the two columns that are compared:
        - For forward: the scale for PDE Output (col 1) and Prediction (col 2) is computed per sample.
        - For inverse: the scale for PDE Input (col 1) and Prediction (col 2) is computed per sample.
    This is done irrespective of the value of common_scale.

    Additionally, if common_scale is True, a global common scale is computed for the
    "Difference" (col 3) and "PDE Residual" (col 4) columns over all samples
    (and for the two transformed columns when transform_error is True).

    A text annotation (showing sample ID, relative L2 error, PDE name, and direction, both title-cased)
    is placed below each row, centered under column 2.

    Each page holds at most rows_per_page samples, one sample per row.

    The pdf_file_path should be the full path (including filename) where the PDF will be saved.
    """

    num_samples = len(samples)
    num_cols = 5  # 5 columns: [col0, col1, col2, col3, col4]
    # The transformed columns get their own colormap only when they are shown.
    error_cmap = 'inferno' if transform_error else cmap

    if pde_direction.lower() == "forward":
        col_titles = ["PDE Input", "PDE Output", "Prediction", "Difference", "PDE Residual"]
    else:
        col_titles = ["PDE Output", "PDE Input", "Prediction", "Difference", "PDE Residual"]
    # transform_error applies log1p(|x|) to the error maps (see below).
    # If transform_error is True, add two more columns for transformed difference and residual.
    if transform_error:
        num_cols+=2
        col_titles +=["Transformed Difference", "Transformed PDE Residual"]
    


    # If common_scale is True, compute global min/max for "Difference" and "PDE Residual"
    if common_scale:
        diff_vals = []
        res_vals = []
        for sample in samples:
            diff = sample["difference"].cpu().numpy() if isinstance(sample["difference"], torch.Tensor) else np.array(sample["difference"])
            res = sample["pde_residual_img"].cpu().numpy() if isinstance(sample["pde_residual_img"], torch.Tensor) else np.array(sample["pde_residual_img"])
            diff_vals.append(diff)
            res_vals.append(res)
        common_diff_vmin = min(np.min(arr) for arr in diff_vals)
        common_diff_vmax = max(np.max(arr) for arr in diff_vals)
        common_res_vmin = min(np.min(arr) for arr in res_vals)
        common_res_vmax = max(np.max(arr) for arr in res_vals)

        # If transform_error is True, also compute scales for the transformed error maps.
        if transform_error:
            transformed_diff_vals = []
            transformed_res_vals = []
            for sample in samples:
                diff = (sample["difference"].cpu().numpy() if isinstance(sample["difference"], torch.Tensor)
                        else np.array(sample["difference"]))
                res = (sample["pde_residual_img"].cpu().numpy() if isinstance(sample["pde_residual_img"], torch.Tensor)
                        else np.array(sample["pde_residual_img"]))
                transformed_diff = np.log1p(np.abs(diff))
                transformed_res = np.log1p(np.abs(res))
                transformed_diff_vals.append(transformed_diff)
                transformed_res_vals.append(transformed_res)
            common_trans_diff_vmin = min(np.min(arr) for arr in transformed_diff_vals)
            common_trans_diff_vmax = max(np.max(arr) for arr in transformed_diff_vals)
            common_trans_res_vmin = min(np.min(arr) for arr in transformed_res_vals)
            common_trans_res_vmax = max(np.max(arr) for arr in transformed_res_vals)
        else:
            common_trans_diff_vmin = common_trans_diff_vmax = common_trans_res_vmin = common_trans_res_vmax = None
    else:
        # No global scale requested: None makes every imshow below autoscale.
        common_diff_vmin = common_diff_vmax = common_res_vmin = common_res_vmax = None
        common_trans_diff_vmin = common_trans_diff_vmax = common_trans_res_vmin = common_trans_res_vmax = None


    with PdfPages(pdf_file_path) as pdf:
        # One figure (= one PDF page) per chunk of rows_per_page samples.
        for start_idx in range(0, num_samples, rows_per_page):
            end_idx = min(start_idx + rows_per_page, num_samples)
            current_samples = samples[start_idx:end_idx]
            nrows = len(current_samples)
            fig, axs = plt.subplots(nrows, num_cols, figsize=(num_cols * 3, nrows * 3))
            fig.tight_layout(pad=3.0)

            # plt.subplots returns a 1-D array for a single row; wrap it so
            # axs[row][col] indexing works uniformly.
            if nrows == 1:
                axs = [axs]

            for row, sample in enumerate(current_samples):
                # Text annotation that is drawn below the row (see end of loop body).
                annotation = (f"ID: {sample['id']} | Relative L2 Error: {sample['rel_error_l2']:.4f} | "
                            f"PDE: {pde_print_name.title()} | Direction: {pde_direction.title()}")
                
                # Compute per-sample common scale for the pair of images that must share scale.
                # For forward: use PDE Output and Prediction.
                # For inverse: use PDE Input and Prediction.
                if pde_direction.lower() == "forward":
                    img_for_common_1 = sample["pde_output"]
                else:
                    img_for_common_1 = sample["pde_input"]
                img_for_common_2 = sample["prediction"]
                # Convert to numpy arrays if necessary.
                if isinstance(img_for_common_1, torch.Tensor):
                    if img_for_common_1.ndim == 3 and img_for_common_1.shape[0] == 1:
                        img_for_common_1 = img_for_common_1.squeeze(0)
                    img_for_common_1 = img_for_common_1.cpu().numpy()
                else:
                    img_for_common_1 = np.array(img_for_common_1)
                if isinstance(img_for_common_2, torch.Tensor):
                    if img_for_common_2.ndim == 3 and img_for_common_2.shape[0] == 1:
                        img_for_common_2 = img_for_common_2.squeeze(0)
                    img_for_common_2 = img_for_common_2.cpu().numpy()
                else:
                    img_for_common_2 = np.array(img_for_common_2)
                common_pair_vmin = min(np.min(img_for_common_1), np.min(img_for_common_2))
                common_pair_vmax = max(np.max(img_for_common_1), np.max(img_for_common_2))
                
                # Determine ordering of images based on PDE direction.
                if pde_direction.lower() == "forward":
                    # Order: [PDE Input, PDE Output, Prediction, Difference, PDE Residual]
                    images = [sample["pde_input"], sample["pde_output"], sample["prediction"],
                            sample["difference"], sample["pde_residual_img"]]
                else:
                    # Order: [PDE Output, PDE Input, Prediction, Difference, PDE Residual]
                    images = [sample["pde_output"], sample["pde_input"], sample["prediction"],
                            sample["difference"], sample["pde_residual_img"]]


                # Append the log1p(|x|) versions so that images has num_cols entries.
                # NOTE(review): the column loop below recomputes these for
                # col 5 / col 6, so the appended values are only used as the
                # initial `img = images[col]` lookup.
                if transform_error:
                    diff_val = sample["difference"]
                    if isinstance(diff_val, torch.Tensor):
                        diff_val = diff_val.cpu().numpy()
                    transformed_diff = np.log1p(np.abs(diff_val))
                    res_val = sample["pde_residual_img"]
                    if isinstance(res_val, torch.Tensor):
                        res_val = res_val.cpu().numpy()
                    transformed_res = np.log1p(np.abs(res_val))
                    images.append(transformed_diff)
                    images.append(transformed_res)

                for col in range(num_cols):
                    ax = axs[row][col]
                    img = images[col]
                    # For each column, fetch the appropriate image.
                    # For columns 0-4, use the "images" list (which is set based on pde_direction)
                    # If transform_error is True, then for col 5 use transformed difference and for col 6 use transformed residual.
                    
                    if col < 5:
                        img = images[col]
                    else:
                        # For col 5 (transformed difference) and col 6 (transformed residual)
                        if col == 5:
                            img = sample["difference"]
                            img = np.log1p(np.abs(img.cpu().numpy())) if isinstance(img, torch.Tensor) else np.log1p(np.abs(np.array(img)))
                        elif col == 6:
                            img = sample["pde_residual_img"]
                            img = np.log1p(np.abs(img.cpu().numpy())) if isinstance(img, torch.Tensor) else np.log1p(np.abs(np.array(img)))

                    # Convert torch.Tensor to numpy if necessary (cols 5/6 are already numpy).
                    if col < 5 and isinstance(img, torch.Tensor):
                        if img.ndim == 3 and img.shape[0] == 1:
                            img = img.squeeze(0)
                        img = img.cpu().numpy()
                    # For columns 1 and 2, use the per-sample common scale computed above.
                    if col in [1, 2]:
                        im = ax.imshow(img, cmap=cmap, vmin=common_pair_vmin, vmax=common_pair_vmax)
                    # For the Difference column (col 3), use the global common scale if enabled.
                    elif col == 3:
                        if common_scale and common_diff_vmin is not None and common_diff_vmax is not None:
                            im = ax.imshow(img, cmap=cmap, vmin=common_diff_vmin, vmax=common_diff_vmax)
                        else:
                            im = ax.imshow(img, cmap=cmap)
                    # For the PDE Residual column (col 4), use the global common scale if enabled.
                    elif col == 4:
                        if common_scale and common_res_vmin is not None and common_res_vmax is not None:
                            im = ax.imshow(img, cmap=cmap, vmin=common_res_vmin, vmax=common_res_vmax)
                        else:
                            im = ax.imshow(img, cmap=cmap)
                    # Transformed Difference (col 5): error colormap, global scale if enabled.
                    elif transform_error and col == 5:
                        if common_scale and common_trans_diff_vmin is not None and common_trans_diff_vmax is not None:
                            im = ax.imshow(img, cmap=error_cmap, vmin=common_trans_diff_vmin, vmax=common_trans_diff_vmax)
                        else:
                            im = ax.imshow(img, cmap=error_cmap)
                    # Transformed PDE Residual (col 6): error colormap, global scale if enabled.
                    elif transform_error and col == 6:
                        if common_scale and common_trans_res_vmin is not None and common_trans_res_vmax is not None:
                            im = ax.imshow(img, cmap=error_cmap, vmin=common_trans_res_vmin, vmax=common_trans_res_vmax)
                        else:
                            im = ax.imshow(img, cmap=error_cmap)
                    # Column 0 (the non-compared field): autoscaled.
                    else:
                        im = ax.imshow(img, cmap=cmap)
                    ax.axis('off')
                    # Attach a colorbar to each image.
                    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=8)
                    # Set column titles for the first row.
                    if row == 0:
                        ax.set_title(col_titles[col], fontsize=10)
                
                # Add text annotation below the row (centered under column 2).
                axs[row][2].text(0.5, -0.25, annotation, transform=axs[row][2].transAxes,
                                ha='center', va='top', fontsize=12, color='black')
            pdf.savefig(fig)
            # Close the figure once it is written to the PDF.
            plt.close(fig)
    print(f"PDF saved to: {pdf_file_path}")


def prepare_image(img, transform=False, sobelize=False, binarize=False, grad=False, highpass=False, zoom=False):
    """
    Convert an image to a NumPy array and apply optional visual filters.

    A torch.Tensor is detached, moved to the CPU and converted to NumPy; a
    3-D tensor with a leading singleton dim is squeezed to 2-D first. Other
    inputs are used as they are. The enabled filters are then applied in this
    fixed order, each one operating on the output of the previous one:

    Args:
        img: torch.Tensor or NumPy array holding the image.
        transform (bool): Replace the image with ``log1p(|img|)``.
        sobelize (bool): Replace the image with its Sobel gradient magnitude.
        binarize (bool): Threshold ``|img|`` at its own mean -> {0.0, 1.0}.
            (Applied after ``highpass`` despite its position in the signature.)
        grad (bool): Replace the image with its ``np.gradient`` magnitude.
        highpass (bool): Subtract a Gaussian-smoothed (sigma=2) copy.
        zoom (bool): Replace the image with a mosaic of its four 8-pixel
            border bands. The stacking requires a square image.

    Returns:
        numpy.ndarray: The processed image.
    """
    if isinstance(img, torch.Tensor):
        if img.ndim == 3 and img.shape[0] == 1:
            img = img.squeeze(0)
        # detach() so tensors that still track gradients can be converted.
        img = img.detach().cpu().numpy()
    if transform:
        img = np.log1p(np.abs(img))
    if sobelize:
        img = np.hypot(sobel(img, axis=0), sobel(img, axis=1))
    if grad:
        gx, gy = np.gradient(img)
        img = np.hypot(gx, gy)
    if highpass:
        # Imported here because the module-level imports only bring in
        # `sobel` from scipy.ndimage; without it this branch raised NameError.
        from scipy.ndimage import gaussian_filter

        smoothed = gaussian_filter(img, sigma=2)
        img = img - smoothed
    if binarize:
        img = (np.abs(img) > np.mean(np.abs(img))).astype(float)
    if zoom:
        # Extract a band from all sides (top, bottom, left, right)
        band = 8
        top = img[:band, :]
        bottom = img[-band:, :]
        left = img[:, :band]
        right = img[:, -band:]

        # Combine into a single image by tiling side by side
        img = np.block([
            [top, top],
            [bottom, bottom],
            [left.T, right.T]  # transpose for horizontal alignment
        ])
    return img
  
def get_col_titles(pde_direction, transform_error=False, extra_viz=None):
    """
    Build the list of column titles for the sample grid.

    The first two titles depend on the direction ("forward", compared
    case-insensitively, puts the PDE input first; anything else puts the PDE
    output first). Titles for the transformed columns and for every requested
    extra visualisation are appended in a fixed order.
    """
    if pde_direction.lower() == "forward":
        titles = ["PDE Input", "PDE Output"]
    else:
        titles = ["PDE Output", "PDE Input"]
    titles.extend(["Prediction", "Difference", "PDE Residual"])

    if transform_error:
        titles.extend(["Transformed Difference", "Transformed PDE Residual"])

    if not extra_viz:
        return titles

    # (option name, titles it contributes), in display order
    optional_columns = (
        ('sobel', ['Sobel Difference', 'Sobel PDE Residual']),
        ('binary', ['Binary Difference', 'Binary PDE Residual']),
        ('discrete', ['Discrete Difference', 'Discrete PDE Residual']),
        ('discrete2', ['Discrete2 PDE Residual']),
    )
    for option, extra_titles in optional_columns:
        if option in extra_viz:
            titles.extend(extra_titles)
    return titles

def get_common_scale(values):
    """
    Return the global ``(min, max)`` over a collection of arrays.

    Args:
        values: Iterable of array-likes accepted by ``np.min`` / ``np.max``.
            One-shot iterables (e.g. generators) are supported.

    Returns:
        tuple: ``(smallest value, largest value)`` across all arrays.

    Raises:
        ValueError: If ``values`` is empty.
    """
    # Materialise once: the data is scanned twice (min and max), which would
    # exhaust a generator after the first pass.
    arrays = list(values)
    return min(np.min(v) for v in arrays), max(np.max(v) for v in arrays)


def compute_common_scales(samples, transform_error=False, extra_viz=None):
    """
    Compute shared (vmin, vmax) ranges across all samples for each error column type.

    Always returns the 'diff' and 'res' ranges; adds the transformed, Sobel,
    binary and discrete ranges when the matching option is enabled.
    """
    def paired_scales(diff_name, res_name, **prep_flags):
        # One range for the difference images and one for the residual images,
        # both prepared with the same prepare_image flags.
        diff_imgs = [prepare_image(entry['difference'], **prep_flags) for entry in samples]
        res_imgs = [prepare_image(entry['pde_residual_img'], **prep_flags) for entry in samples]
        return {
            diff_name: get_common_scale(diff_imgs),
            res_name: get_common_scale(res_imgs),
        }

    scales = paired_scales("diff", "res")

    if transform_error:
        scales.update(paired_scales("trans_diff", "trans_res", transform=True))

    if not extra_viz:
        return scales

    if 'sobel' in extra_viz:
        scales.update(paired_scales('sobel_diff', 'sobel_res', sobelize=True))
    if 'binary' in extra_viz:
        scales.update(paired_scales('binary_diff', 'binary_res', binarize=True))
    if 'discrete' in extra_viz:
        # The discrete columns show the unfiltered images under a different colormap.
        scales.update(paired_scales('discrete_diff', 'discrete_res'))

    return scales

def extract_column_images(sample, pde_direction, transform_error, extra_viz):
    """
    Return the list of images for one sample row, in column order.

    The first five images are the PDE input/output (ordered by direction),
    the prediction, the difference and the PDE residual; the rest follow the
    enabled optional visualizations.
    """
    if pde_direction.lower() == 'forward':
        ordered_keys = ['pde_input', 'pde_output']
    else:
        ordered_keys = ['pde_output', 'pde_input']
    ordered_keys.extend(['prediction', 'difference', 'pde_residual_img'])

    images = [prepare_image(sample[name]) for name in ordered_keys]

    def error_pair(**prep_flags):
        # Difference and residual images prepared with the same flags.
        return [
            prepare_image(sample["difference"], **prep_flags),
            prepare_image(sample["pde_residual_img"], **prep_flags),
        ]

    if transform_error:
        images.extend(error_pair(transform=True))

    if extra_viz:
        if 'sobel' in extra_viz:
            images.extend(error_pair(sobelize=True))
        if 'binary' in extra_viz:
            images.extend(error_pair(binarize=True))
        if 'discrete' in extra_viz:
            images.extend(error_pair())
        if 'discrete2' in extra_viz:
            images.append(prepare_image(sample['pde_residual_img']))

    return images

def _error_column_style(key, default_cmap, error_cmap):
    """Return (scale_key, cmap) for a column identified by its normalized title `key`."""
    if 'diff' in key:
        suffix = 'diff'
    elif 'residual' in key:
        suffix = 'res'
    else:
        # Not an error column: no shared scale entry, default colormap.
        return None, default_cmap

    if 'sobel' in key:
        return 'sobel_' + suffix, error_cmap
    if 'transformed' in key:
        return 'trans_' + suffix, error_cmap
    if 'binary' in key:
        return 'binary_' + suffix, 'binary'
    if suffix == 'res' and 'discrete2' in key:
        # Second discrete residual view: same scale entry, different colormap.
        return 'discrete_res', 'Set1'
    if 'discrete' in key:
        return 'discrete_' + suffix, 'Accent'
    return suffix, default_cmap


def save_samples_to_pdf(samples, pdf_file_path, pde_direction='forward', pde_print_name="None",
                        common_scale=False, transform_error=False, cmap='viridis',
                        error_cmap='inferno', extra_viz=None, rows_per_page=10):
    """
    Saves a multi-page PDF visualizing a batch of PDE-related samples with multiple image columns.

    Each row in the PDF corresponds to one sample. Columns can include:
        - PDE input and output
        - Model prediction
        - Difference and PDE residual
        - Optionally, transformed versions of difference/residual (log1p(abs))
        - Optionally, Sobel, binary and discrete views of difference/residual

    Args:
        samples (List[Dict]): A list of sample dictionaries. Each dict must include:
            'id', 'pde_input', 'pde_output', 'prediction', 'difference', 'pde_residual_img', 'rel_error_l2'
        pdf_file_path (str): Full output path for the saved PDF file.
        pde_direction (str): 'forward' (case-insensitive) or anything else for the inverse
            layout. Affects column ordering and labels.
        pde_print_name (str): Human-readable name of the PDE to show in annotations.
        common_scale (bool): Whether to use a colour range shared across all samples (True)
            or per-image scaling (False) for the error columns.
        transform_error (bool): Whether to include columns showing log-transformed versions
            of difference and residual.
        cmap (str): Colormap for standard heatmaps (e.g., input, output, prediction).
        error_cmap (str): Colormap for the transformed and Sobel columns.
        extra_viz: Optional visualizations to append, e.g. ['sobel', 'binary'].
            Recognized entries: 'sobel', 'binary', 'discrete', 'discrete2'.
        rows_per_page (int): Number of sample rows per PDF page.

    Behavior:
        - The layout is determined by the selected options (direction, transform, extra_viz).
        - Columns 1-2 (ground truth and prediction) share a per-sample vmin/vmax.
        - A summary line with ID, rel_error, PDE name, and direction is included for each sample.
        - Each image gets its own colorbar.

    Output:
        A multipage PDF file saved at `pdf_file_path`, with all configured visualizations.

    Example:
        save_samples_to_pdf(samples, 'output.pdf', pde_direction='inverse',
                            transform_error=True, extra_viz=['sobel'])
    """

    col_titles = get_col_titles(pde_direction, transform_error, extra_viz)
    num_cols = len(col_titles)
    num_samples = len(samples)

    common_scales = compute_common_scales(samples, transform_error, extra_viz) if common_scale else {}

    # Column ordering is decided case-insensitively, so the ground-truth image
    # shown in column 1 must be picked the same way.
    gt_key = "pde_output" if pde_direction.lower() == "forward" else "pde_input"

    with PdfPages(pdf_file_path) as pdf:
        for start_idx in range(0, num_samples, rows_per_page):
            end_idx = min(start_idx + rows_per_page, num_samples)
            current_samples = samples[start_idx:end_idx]
            nrows = len(current_samples)
            # squeeze=False always yields a 2-D axes array, even for a single row.
            fig, axs = plt.subplots(nrows, num_cols, figsize=(num_cols * 3, nrows * 3), squeeze=False)
            fig.tight_layout(pad=3.0)

            for row_idx, sample in enumerate(current_samples):
                row_images = extract_column_images(sample, pde_direction, transform_error, extra_viz)

                # Common range for prediction + GT
                gt_img = prepare_image(sample[gt_key])
                pred_img = prepare_image(sample["prediction"])
                common_vmin, common_vmax = get_common_scale([gt_img, pred_img])

                for col_idx, img in enumerate(row_images):
                    ax = axs[row_idx][col_idx]
                    scale_key = None
                    vmin = vmax = None
                    current_cmap = cmap

                    if col_idx in (1, 2):
                        # Ground truth and prediction use the per-sample common scale.
                        vmin, vmax = common_vmin, common_vmax
                    else:
                        key = col_titles[col_idx].lower().replace(" ", "_")
                        scale_key, current_cmap = _error_column_style(key, cmap, error_cmap)

                    if common_scale and scale_key in common_scales:
                        vmin, vmax = common_scales[scale_key]

                    im = ax.imshow(img, cmap=current_cmap, vmin=vmin, vmax=vmax)
                    ax.axis('off')
                    if row_idx == 0:
                        ax.set_title(col_titles[col_idx], fontsize=10)
                    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=8)

                annotation = f"ID: {sample['id']} | Relative L2 Error: {sample['rel_error_l2']:.4f} | PDE: {pde_print_name.title()} | Direction: {pde_direction.title()}"
                axs[row_idx][2].text(0.5, -0.25, annotation, transform=axs[row_idx][2].transAxes,
                                        ha='center', va='top', fontsize=12, color='black')
            pdf.savefig(fig)
            plt.close(fig)

    print(f"PDF saved to: {pdf_file_path}")

def plot_pde_residual_evolution(pde_residual_data, outdir, model_kimg, dataset_name, pde_direction):
    """
    Plot PDE residual evolution across timesteps and save the summary as JSON.

    Args:
        pde_residual_data: Dictionary with 'num_timesteps' and the per-timestep
            collections 'timestep_means', 'timestep_mins', 'timestep_maxs' and
            'timestep_norms', each indexable by timestep.
        outdir: Directory that receives the PDF plot and the JSON summary.
        model_kimg: Value shown in the title, used in the file names and stored in the JSON.
        dataset_name: Dataset name used in the title and file names.
        pde_direction: Direction label used in the title and file names.

    Returns:
        None. Returns early, writing nothing, when 'num_timesteps' is None.
    """
    if pde_residual_data['num_timesteps'] is None:
        print("No PDE residual data to plot")
        return

    num_timesteps = pde_residual_data['num_timesteps']
    timesteps = list(range(num_timesteps))

    # Aggregated statistics across all samples, one entry per timestep.
    mean_values = []
    min_values = []
    max_values = []
    norm_values = []
    median_values = []
    std_values = []

    for t in timesteps:
        means_t = pde_residual_data['timestep_means'][t]
        if len(means_t) > 0:
            # float() turns NumPy scalars into plain Python floats so that
            # json.dump below also works for non-float64 inputs.
            mean_values.append(float(np.mean(means_t)))
            min_values.append(float(np.min(pde_residual_data['timestep_mins'][t])))
            max_values.append(float(np.max(pde_residual_data['timestep_maxs'][t])))
            norm_values.append(float(np.mean(pde_residual_data['timestep_norms'][t])))
            median_values.append(float(np.median(means_t)))
            std_values.append(float(np.std(means_t)))
        else:
            # No data for this timestep: use 0 as a placeholder in every series.
            mean_values.append(0)
            min_values.append(0)
            max_values.append(0)
            norm_values.append(0)
            median_values.append(0)
            std_values.append(0)

    # Create the plot
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'PDE Residual Evolution - {dataset_name} {pde_direction} - {model_kimg} KImg', fontsize=16)

    # Single-series panels: (axes, values, line format, marker, title, y label)
    simple_panels = [
        (axes[0, 0], mean_values, 'b-', 'o', 'Mean PDE Residual', 'Mean Value'),
        (axes[0, 2], norm_values, 'm-', 'd', 'Mean Norm of PDE Residual', 'Norm Value'),
        (axes[1, 0], median_values, 'c-', 'p', 'Median PDE Residual', 'Median Value'),
        (axes[1, 1], std_values, 'orange', 'h', 'Std Dev of PDE Residual', 'Std Dev Value'),
    ]
    for ax, values, line_fmt, marker, title, ylabel in simple_panels:
        ax.plot(timesteps, values, line_fmt, marker=marker, linewidth=2, markersize=4)
        ax.set_title(title)
        ax.set_xlabel('Timestep')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    # Min and Max share one panel
    ax_minmax = axes[0, 1]
    ax_minmax.plot(timesteps, min_values, 'g-', marker='s', linewidth=2, markersize=4, label='Min')
    ax_minmax.plot(timesteps, max_values, 'r-', marker='^', linewidth=2, markersize=4, label='Max')
    ax_minmax.set_title('Min/Max PDE Residual')
    ax_minmax.set_xlabel('Timestep')
    ax_minmax.set_ylabel('Value')
    ax_minmax.legend()
    ax_minmax.grid(True, alpha=0.3)

    # Combined view (Mean with error bars)
    ax_combined = axes[1, 2]
    ax_combined.errorbar(timesteps, mean_values, yerr=std_values,
                         fmt='b-', marker='o', linewidth=2, markersize=4, capsize=3)
    ax_combined.fill_between(timesteps, min_values, max_values, alpha=0.2, color='gray', label='Min-Max Range')
    ax_combined.set_title('Mean ± Std Dev (with Min-Max Range)')
    ax_combined.set_xlabel('Timestep')
    ax_combined.set_ylabel('Value')
    ax_combined.legend()
    ax_combined.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save the plot; operate on `fig` explicitly rather than the implicit current figure.
    plot_filename = f"pde_residual_evolution_{dataset_name}_{pde_direction}_{model_kimg}kimg.pdf"
    plot_path = os.path.join(outdir, plot_filename)
    fig.savefig(plot_path, format='pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"PDE residual evolution plot saved to {plot_path}")

    # Also save the raw data as JSON for further analysis
    summary_data = {
        "timesteps": timesteps,
        "mean_values": mean_values,
        "min_values": min_values,
        "max_values": max_values,
        "norm_values": norm_values,
        "median_values": median_values,
        "std_values": std_values,
        "total_samples": sum(len(pde_residual_data['timestep_means'][t]) for t in timesteps),
        "model_kimg": model_kimg
    }

    json_filename = f"pde_residual_data_{dataset_name}_{pde_direction}_{model_kimg}kimg.json"
    json_path = os.path.join(outdir, json_filename)
    with open(json_path, 'w') as f:
        json.dump(summary_data, f, indent=4)

    print(f"PDE residual data saved to {json_path}")

def plot_pde_residual_timestep_images(pde_residual_data, output_dir, k=5, timesteps_to_plot=None):
    """
    Plot PDE residual images across timesteps for best and worst performing samples.

    Two PDFs are written to `output_dir`: one for the k samples with the lowest
    mean residual and one for the k samples with the highest.

    Args:
        pde_residual_data: Dictionary containing PDE residual tracking data. Reads
            'sample_timestep_images' (entries with 'global_sample_idx', 'timestep',
            'pde_residual_image', 'mean_residual'), 'sample_mean_residuals'
            (entries with 'global_sample_idx', 'mean_residual') and, when
            `timesteps_to_plot` is None, 'num_timesteps'.
        output_dir: Directory to save the plots
        k: Number of best/worst samples to visualize (default: 5). Values <= 0 plot nothing.
        timesteps_to_plot: List of timesteps to visualize. If None, select evenly spaced timesteps
    """
    image_entries = pde_residual_data['sample_timestep_images']
    # Check if we have residual image data
    if not image_entries or not pde_residual_data['sample_mean_residuals']:
        print("No PDE residual image data available for plotting")
        return

    # Sort samples by their mean residual performance
    sample_rankings = sorted(pde_residual_data['sample_mean_residuals'],
                             key=lambda x: x['mean_residual'])

    # Select k best (lowest residual) and k worst (highest residual) samples.
    # The k > 0 guard matters: sample_rankings[-0:] would select every sample.
    best_samples = sample_rankings[:k] if k > 0 else []
    worst_samples = sample_rankings[-k:] if k > 0 else []

    # Determine timesteps to plot
    if timesteps_to_plot is None:
        num_timesteps = pde_residual_data.get('num_timesteps')
        if not num_timesteps:
            print("Cannot select timesteps to plot: 'num_timesteps' is missing or zero")
            return
        # Roughly eight evenly spaced timesteps, always including the last one
        timesteps_to_plot = list(range(0, num_timesteps, max(1, num_timesteps // 8)))
        if num_timesteps - 1 not in timesteps_to_plot:
            timesteps_to_plot.append(num_timesteps - 1)
    if not timesteps_to_plot:
        print("No timesteps selected for plotting")
        return

    def find_entry(sample_idx, timestep):
        """Return the first stored entry for (sample, timestep), or None."""
        for entry in image_entries:
            if entry['global_sample_idx'] == sample_idx and entry['timestep'] == timestep:
                return entry
        return None

    def to_numpy(image):
        """Squeeze a stored residual tensor and convert it to a NumPy array."""
        # detach()/cpu() so tensors on an accelerator or tracking gradients convert too
        return image.squeeze().detach().cpu().numpy()

    def create_sample_plots(samples_list, title_prefix, filename_suffix):
        """Helper function to create plots for a list of samples"""
        if not samples_list:
            print(f"No {title_prefix.lower()} samples to plot")
            return

        # Calculate figure size based on number of samples and timesteps
        n_samples = len(samples_list)
        n_timesteps = len(timesteps_to_plot)
        fig_width = max(12, n_timesteps * 2)
        fig_height = max(8, n_samples * 1.5)

        # squeeze=False always returns a 2-D axes array, including the
        # single-sample and/or single-timestep cases.
        fig, axes = plt.subplots(n_samples, n_timesteps,
                                 figsize=(fig_width, fig_height), squeeze=False)

        for sample_idx, sample_info in enumerate(samples_list):
            global_sample_idx = sample_info['global_sample_idx']
            mean_residual = sample_info['mean_residual']

            # Look up each (sample, timestep) entry once and reuse it for both
            # the colour-range computation and the drawing below.
            row_entries = [find_entry(global_sample_idx, t) for t in timesteps_to_plot]
            row_images = [to_numpy(entry['pde_residual_image']) if entry is not None else None
                          for entry in row_entries]
            available_images = [img for img in row_images if img is not None]

            # Calculate individual min/max for this sample across all its timesteps
            if available_images:
                sample_vmin = min(img.min() for img in available_images)
                sample_vmax = max(img.max() for img in available_images)
                # Add small buffer to avoid edge cases
                vrange = sample_vmax - sample_vmin
                if vrange > 0:
                    sample_vmin -= vrange * 0.05
                    sample_vmax += vrange * 0.05
                else:
                    # If all values are the same, create a small range
                    sample_vmin -= 0.01
                    sample_vmax += 0.01
            else:
                sample_vmin, sample_vmax = 0, 1  # Default range if no images

            for timestep_idx, (timestep, entry, image) in enumerate(
                    zip(timesteps_to_plot, row_entries, row_images)):
                ax = axes[sample_idx, timestep_idx]

                if entry is None:
                    # No data available for this combination
                    ax.text(0.5, 0.5, 'No Data', ha='center', va='center',
                            transform=ax.transAxes)
                    ax.axis('off')
                    continue

                # Plot the PDE residual image with individual sample colormap range
                im = ax.imshow(image, cmap='viridis', aspect='auto',
                               vmin=sample_vmin, vmax=sample_vmax)
                ax.set_title(f't={timestep}\nres={entry["mean_residual"]:.4f}', fontsize=10)
                ax.axis('off')

                # One colorbar per sample row, attached when the last column has data
                if timestep_idx == n_timesteps - 1:
                    cbar = fig.colorbar(im, ax=axes[sample_idx, :], fraction=0.046, pad=0.04, aspect=20)
                    cbar.set_label(f'PDE Residual [{sample_vmin:.3f}, {sample_vmax:.3f}]', rotation=270, labelpad=15)

            # Add y-axis label for each sample.
            # NOTE(review): the axes were switched off above, which hides axis
            # labels, so this label is likely not rendered — confirm intent.
            axes[sample_idx, 0].set_ylabel(f'Sample {global_sample_idx}\n(mean={mean_residual:.4f})',
                                           fontsize=10, rotation=90, va='center')

        fig.suptitle(f'{title_prefix} Samples - PDE Residual Evolution Across Timesteps',
                     fontsize=16, y=0.98)
        fig.tight_layout()

        # Save the plot
        output_path = os.path.join(output_dir, f'pde_residual_timesteps_{filename_suffix}.pdf')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"Saved {title_prefix.lower()} samples timestep plot to {output_path}")

    # Create plots for best and worst samples
    create_sample_plots(best_samples, "Best", "best")
    create_sample_plots(worst_samples, "Worst", "worst")

    # Print summary statistics
    print("PDE Residual Timestep Visualization Summary:")
    print(f"- Total samples analyzed: {len(pde_residual_data['sample_mean_residuals'])}")
    print(f"- Timesteps visualized: {timesteps_to_plot}")
    print(f"- Best {k} samples mean residuals: {[s['mean_residual'] for s in best_samples]}")
    print(f"- Worst {k} samples mean residuals: {[s['mean_residual'] for s in worst_samples]}")


def plot_process(pde_residual_data, 
                 output_dir, 
                 dataset_name, 
                 pde_direction, 
                 k=5, 
                 n_samples=4, 
                 timesteps_to_plot=None, 
                 pde_predictions_data=None):
    """
    Plot process images: multiple rows (prediction and residual per selected sample) across timesteps.

    Rows = up to 2 * n_samples (pred_i, res_i), Columns = selected timesteps.
    The grid is saved as 'process_grid.png' inside `output_dir`.

    Args:
        pde_residual_data: Dictionary read for 'sample_timestep_residuals' (entries with
            'global_sample_idx', 'timestep', 'pde_residual_image') and, optionally,
            'num_timesteps' and 'sample_mean_residuals'.
        output_dir: Directory that receives the PNG.
        dataset_name: Dataset name shown in the figure title.
        pde_direction: Direction label shown in the figure title.
        k: Number of evenly spaced timesteps to select when `timesteps_to_plot` is None.
        n_samples: Maximum number of samples to plot (lowest mean residual first
            when 'sample_mean_residuals' is available).
        timesteps_to_plot: Explicit list of timesteps; overrides `k`.
        pde_predictions_data: Dictionary read for 'sample_timestep_predictions' (entries with
            'global_sample_idx', 'timestep', 'pde_predictions_image'). When None or
            empty, nothing is plotted.
    """
    residual_items = pde_residual_data.get('sample_timestep_residuals')
    # pde_predictions_data defaults to None, so guard before calling .get on it.
    prediction_items = (pde_predictions_data or {}).get('sample_timestep_predictions')
    if not residual_items or not prediction_items:
        print("No prediction/residual timestep data available for plotting")
        return

    num_timesteps = pde_residual_data.get('num_timesteps')
    if not num_timesteps:
        # Fall back to the number of distinct step indices implied by the data
        # (assumes 0-based integer timesteps), not the number of stored entries.
        num_timesteps = int(max(item['timestep'] for item in residual_items)) + 1

    if timesteps_to_plot is None:
        k = max(1, int(k))
        if k >= num_timesteps:
            timesteps_to_plot = list(range(num_timesteps))
        else:
            # Select k timesteps at evenly spaced percentages of the process
            percentages = np.linspace(0, 100, k)
            indices = np.round(percentages / 100 * (num_timesteps - 1)).astype(int)
            timesteps_to_plot = sorted(set(indices.tolist()))

    # Debug: Print basic data info
    print(f"Plotting process: {len(residual_items)} residuals, {len(prediction_items)} predictions")

    # Samples for which at least one residual image was recorded
    available_samples = {item['global_sample_idx'] for item in residual_items}

    # Select samples to plot. Filter by availability before truncating so that
    # unavailable top-ranked samples do not reduce the number of rows.
    if pde_residual_data.get('sample_mean_residuals'):
        ranked = sorted(pde_residual_data['sample_mean_residuals'], key=lambda x: x['mean_residual'])
        sample_indices = [s['global_sample_idx'] for s in ranked
                          if s['global_sample_idx'] in available_samples][:n_samples]
    else:
        sample_indices = sorted(available_samples)[:n_samples]

    def image_at(items, image_key, sample_idx, timestep):
        """Return the first matching image as a 2-D NumPy array, or None if absent."""
        for item in items:
            if item['global_sample_idx'] == sample_idx and item['timestep'] == timestep:
                img = item[image_key].squeeze()  # Remove batch dim
                if img.ndim == 3:  # If still has channel dim, take first channel
                    img = img[0]
                return img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
        return None

    # Each row is (title, is_residual, images) where images[c] belongs to
    # timesteps_to_plot[c] and is None when no data was recorded, so every
    # image stays under its own column title.
    rows = []
    for s_idx in sample_indices:
        preds_over_steps = [image_at(prediction_items, 'pde_predictions_image', s_idx, t)
                            for t in timesteps_to_plot]
        if any(img is not None for img in preds_over_steps):
            rows.append((f"pred sample {s_idx}", False, preds_over_steps))

        res_over_steps = [image_at(residual_items, 'pde_residual_image', s_idx, t)
                          for t in timesteps_to_plot]
        if any(img is not None for img in res_over_steps):
            rows.append((f"res sample {s_idx}", True, res_over_steps))

    if not rows:
        print("No row data assembled for process plot")
        return

    n_rows = len(rows)
    n_cols = len(timesteps_to_plot)
    fig = plt.figure(figsize=(3 * n_cols, 2.5 * n_rows))
    # One extra narrow column per row holds that row's colorbar.
    gs = fig.add_gridspec(n_rows, n_cols + 1, width_ratios=[1] * n_cols + [0.05], hspace=0.3, wspace=0.1, top=0.95, bottom=0.05, left=0.05, right=0.95)
    fig.suptitle(f"Process evolution - {dataset_name} {pde_direction}")

    # Plot row by row with appropriate color scales
    for r_idx, (row_title, is_residual, imgs) in enumerate(rows):
        present = [img for img in imgs if img is not None]
        vmin = min(img.min() for img in present)
        vmax = max(img.max() for img in present)

        if is_residual:
            # For residuals: use diverging colormap centered at 0
            vmax_abs = max(abs(vmin), abs(vmax))
            vmin_plot, vmax_plot = -vmax_abs, vmax_abs
            cmap = 'RdBu_r'  # Red-Blue diverging, reversed so red=positive
        else:
            # For predictions: use regular colormap with actual range
            vmin_plot, vmax_plot = vmin, vmax
            cmap = 'viridis'

        last_im = None
        for c, (timestep, img) in enumerate(zip(timesteps_to_plot, imgs)):
            ax = fig.add_subplot(gs[r_idx, c])
            if img is not None:
                last_im = ax.imshow(img, cmap=cmap, vmin=vmin_plot, vmax=vmax_plot)
                # Show percentage completion instead of raw timestep
                percentage = (timestep / (num_timesteps - 1)) * 100 if num_timesteps > 1 else 0
                ax.set_title(f"{percentage:.0f}%\n(t={timestep})")
            ax.axis('off')

        cax = fig.add_subplot(gs[r_idx, -1])
        if last_im is not None:
            cbar = fig.colorbar(last_im, cax=cax)
            cbar.set_label(f"{row_title}\n[{vmin:.1e}, {vmax:.1e}]", fontsize=8)

    out_path = os.path.join(output_dir, "process_grid.png")
    fig.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0.2)
    plt.close(fig)
    print(f"Saved process grid to {out_path}")
def plot_dps_losses(loss_history, save_path):
    """
    Plot loss curves for DPS sampling.

    Draws three vertically stacked panels sharing the timestep x-axis:
    per-channel observation losses (log scale), the PDE residual loss
    (log scale), and the guidance coefficients together with sigma.
    Which panels are populated is decided from the keys of the first
    entry of ``loss_history``; a panel whose keys are missing is left empty.

    Args:
        loss_history: List of dictionaries containing loss data for each timestep.
            Every entry must provide 'timestep'. Optional keys are
            'obs_losses' (indexable by channel), 'pde_loss',
            'obs_coef' together with 'pde_coef', and 'sigma_t'.
            Loss values may expose ``.mean()``/``.item()`` (they are then
            averaged) or be plain numbers.
        save_path: Path to save the plot

    Returns:
        None. Prints a message and returns early when ``loss_history`` is empty.
    """
    if not loss_history:
        print("No loss history to plot")
        return

    def _batch_mean(value):
        # Average array-like values and unwrap them to a Python scalar;
        # plain numbers (already reduced losses) pass through untouched.
        if hasattr(value, 'mean'):
            value = value.mean()
        if hasattr(value, 'item'):
            value = value.item()
        return value

    first = loss_history[0]

    # Extract the x-axis before creating the figure so a malformed history
    # fails without leaving an open figure behind.
    timesteps = [data['timestep'] for data in loss_history]

    # Create figure
    fig, axs = plt.subplots(3, 1, figsize=(12, 18), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1]})

    try:
        # Plot 1: Observation losses by channel
        first_obs = first['obs_losses'] if 'obs_losses' in first else None
        # Use an explicit length check: the truth value of a multi-element
        # tensor is ambiguous and would raise.
        if first_obs is not None and len(first_obs) > 0:
            num_channels = len(first_obs)
            for ch in range(num_channels):
                # Extract mean loss across batch for this channel
                channel_losses = [_batch_mean(data['obs_losses'][ch]) for data in loss_history]
                axs[0].semilogy(timesteps, channel_losses, label=f'Channel {ch} Loss', marker='o')

            axs[0].set_title('Observation Losses by Channel')
            axs[0].set_ylabel('Loss (log scale)')
            axs[0].legend()
            axs[0].grid(True)

        # Plot 2: PDE loss
        if 'pde_loss' in first:
            pde_losses = [_batch_mean(data['pde_loss']) for data in loss_history]
            axs[1].semilogy(timesteps, pde_losses, label='PDE Residual Loss', marker='o', color='red')
            axs[1].set_title('PDE Residual Loss')
            axs[1].set_ylabel('Loss (log scale)')
            axs[1].legend()
            axs[1].grid(True)

        # Plot 3: Coefficient values
        ax3 = axs[2]
        if 'obs_coef' in first and 'pde_coef' in first:
            obs_coefs = [data['obs_coef'] for data in loss_history]
            pde_coefs = [data['pde_coef'] for data in loss_history]

            ax3.plot(timesteps, obs_coefs, label='Observation Coefficient', marker='s')
            ax3.plot(timesteps, pde_coefs, label='PDE Coefficient', marker='d')
            # Sigma is only needed for this panel, so read it here and skip
            # the curve when the history does not record it.
            if 'sigma_t' in first:
                sigma_values = [data['sigma_t'] for data in loss_history]
                ax3.plot(timesteps, sigma_values, label='Sigma Value', marker='.', linestyle='--')
            ax3.set_title('Guidance Coefficients and Sigma Values')
            ax3.set_ylabel('Value')
            ax3.legend()
            ax3.grid(True)

        # The x-axis is shared, so label the bottom panel even when the
        # coefficient data is absent.
        ax3.set_xlabel('Timestep')

        # Operate on this figure explicitly rather than on pyplot's
        # implicit "current figure".
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, including when plotting or saving fails.
        plt.close(fig)

    print(f"Loss plot saved to {save_path}")