import os
import math
import argparse

from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from models import (
    AnchorGaussianMixtureND,
    SetEncoderND,
)
from train_nflow_arqs import NFlowsConditionalARQS  # AR RQS wrapper from training script


# ---------------------------------------------------------------------------
# Data loading: mirror train.make_dataloaders, but no DataLoader
# ---------------------------------------------------------------------------

def _should_record(time_step: int) -> bool:
    """Return whether the 1-based position of ``time_step`` is at most 300."""
    one_based = time_step + 1
    return one_based <= 300


def _load_trajectories_from_npz(npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read the ``trajectories`` and ``types`` arrays from an NPZ archive.

    Args:
      npz_path: path to an ``.npz`` file containing both keys.

    Returns:
      (trajectories, types) exactly as stored in the archive.

    Raises:
      KeyError: if either key is missing from the archive.
    """
    # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays, which
    # can execute arbitrary code -- only point this at trusted dataset files.
    # The context manager guarantees the archive is closed even when a key
    # lookup fails (the previous explicit close() was skipped on error).
    with np.load(npz_path, allow_pickle=True) as data:
        trajectories = data["trajectories"]
        types = data["types"]
    return trajectories, types


def load_anchors_from_npz(
    data_path: str,
    time_step=None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read trajectories from ``data_path`` and keep only the requested time steps.

    Args:
      data_path: NPZ file passed to _load_trajectories_from_npz.
      time_step:
        - None: every index in range(T) accepted by _should_record
        - int: a single time step (negative values count from the end)
        - iterable of ints: several time steps, e.g. [0, 100, -1]

    Returns:
      anchors_all: (E, S, M, D) slice of the trajectories, S = number of selected steps
      types:       the ``types`` array from the file, returned unchanged
      time_idx:    (S,) int array of the resolved, non-negative time indices

    Raises:
      ValueError: if ``time_step`` is None and _should_record selects nothing.
      IndexError: if a requested step resolves outside [0, T-1].
    """
    anchors, types = _load_trajectories_from_npz(data_path)  # (E, T, M, D)
    _, T, _, _ = anchors.shape

    # No explicit request: fall back to the recording policy.
    if time_step is None:
        time_step = list(filter(_should_record, range(T)))
        if len(time_step) == 0:
            raise ValueError("No time steps selected by _should_record().")

    # Coerce scalar or iterable input into a 1D integer array.
    if isinstance(time_step, (int, np.integer)):
        requested = np.array([int(time_step)], dtype=int)
    else:
        requested = np.array(list(time_step), dtype=int)

    def _resolve(ts):
        # Python-style negative indexing: -1 refers to the last step T-1.
        ts_res = T + ts if ts < 0 else ts
        if ts_res < 0 or ts_res >= T:
            raise IndexError(
                f"time_step {ts} (resolved to {ts_res}) is out of range [0, {T-1}]"
            )
        return ts_res

    time_idx = np.array([_resolve(ts) for ts in requested], dtype=int)

    anchors_all = anchors[:, time_idx]  # (E, S, M, D)
    return anchors_all, types, time_idx


def normalize_anchors(
    anchors: np.ndarray,
    normalization_info: dict,
) -> np.ndarray:
    """Linearly rescale ``anchors`` so that ``min`` maps to -1 and ``max`` to +1.

    ``normalization_info["min"]`` / ``["max"]`` are broadcast along the last
    axis of ``anchors``. A zero range is replaced by 1 to avoid dividing by zero.
    """
    rank = anchors.ndim
    lower = _reshape_norm_values(normalization_info["min"], rank)
    upper = _reshape_norm_values(normalization_info["max"], rank)
    span = upper - lower
    span[span == 0] = 1.0
    shifted = anchors - lower
    return 2.0 * shifted / span - 1.0


def denormalize_anchors(
    anchors: np.ndarray,
    normalization_info: dict,
) -> np.ndarray:
    """Map values from the [-1, 1] scaling back to original coordinates.

    Applies the inverse of the min/max rescaling, using
    ``normalization_info["min"]`` / ``["max"]`` broadcast along the last axis.
    A zero range is replaced by 1, matching the forward transform.
    """
    rank = anchors.ndim
    lower = _reshape_norm_values(normalization_info["min"], rank)
    upper = _reshape_norm_values(normalization_info["max"], rank)
    span = upper - lower
    span[span == 0] = 1.0
    unit = anchors + 1.0
    return unit * span / 2.0 + lower


def _reshape_norm_values(values: np.ndarray, ndim: int) -> np.ndarray:
    """Flatten ``values`` and reshape to (1, ..., 1, n) with ``ndim`` axes.

    The result broadcasts against the last axis of an ``ndim``-dimensional array.
    """
    vector = np.ravel(np.asarray(values))
    leading_ones = (1,) * (ndim - 1)
    return vector.reshape(leading_ones + (vector.size,))

# ---------------------------------------------------------------------------
# Rebuild model + normalization from checkpoint
# ---------------------------------------------------------------------------

def build_model_from_checkpoint(
    ckpt_path: str,
    device: torch.device,
) -> Tuple[torch.nn.Module, dict, dict]:
    """
    Rebuild the flow model stored in a checkpoint and fetch its normalization info.

    Hyperparameters are read from the top level of the checkpoint dict first,
    then from ``ckpt["args"]``, and finally fall back to built-in defaults.

    Returns:
      model:  NFlowsConditionalARQS with loaded weights, in eval mode, on ``device``
      ckpt:   the raw checkpoint dict
      normalization_info: the checkpoint's "normalization_info" entry
                          (None when the checkpoint does not provide one)
    """
    # weights_only=False makes torch.load unpickle arbitrary Python objects,
    # so this must only be used on trusted checkpoint files.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    def hparam(key, fallback):
        # Top-level entry wins; otherwise look inside ckpt["args"], then use the fallback.
        return ckpt.get(key, ckpt.get("args", {}).get(key, fallback))

    dim = ckpt["dim"]
    model_type = ckpt.get("model_type", "ARQS").upper()
    z_dim = hparam("z_dim", 4)
    hidden_dim = hparam("hidden_dim", 256)
    n_layers = hparam("n_layers", 8)
    rqs_K = hparam("rqs_K", 16)
    rqs_B = hparam("rqs_B", 5.0)
    deepset_pool = hparam("deepset_pool", "mean")

    print(f"Checkpoint loaded from: {ckpt_path}")
    print(f"dim={dim}, z_dim={z_dim}, hidden_dim={hidden_dim}, n_layers={n_layers}, "
          f"rqs_K={rqs_K}, rqs_B={rqs_B}, model_type={model_type}")

    # DeepSet encoder that turns an anchor set into the flow's context vector.
    set_encoder = SetEncoderND(
        in_dim=dim,
        hidden_dim=hidden_dim,
        z_dim=z_dim,
        pool=deepset_pool,
    )
    set_encoder = set_encoder.to(device)

    # Autoregressive RQS flow wrapper imported from the training script.
    model = NFlowsConditionalARQS(
        dim=dim,
        z_dim=z_dim,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        K=rqs_K,
        B=rqs_B,
        set_encoder=set_encoder,
    )
    model = model.to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # Min/max normalization saved with the checkpoint (same lookup order as above).
    normalization_info = hparam("normalization_info", None)
    return model, ckpt, normalization_info


# ---------------------------------------------------------------------------
# Global KL evaluation + 2D viz (similar spirit to test_old.py)
# ---------------------------------------------------------------------------



    
def viz_samples(
    model, ckpt, normalization_info,
    data_path: str,
    exp_id: int = 0,
    time_steps: "list[int] | tuple[int, ...] | None" = (0, 10, 20, 50, -1),
    kl_samples: int = 2000,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
    particle_types: tuple[int, int] = (1, 2),
):
    """
    Scatter-plot flow samples against simulated particle positions for one experiment.

    For every particle type, the anchors of experiment ``exp_id`` at the selected
    time steps are normalized, encoded with ``model.set_encoder`` and used as
    context to draw ``kl_samples`` points from ``model.flow``. Samples are mapped
    back to original coordinates and drawn in a 2-row grid with one column per
    time step (top: samples + anchors, bottom: samples only). The figure is
    written to ``<output_dir>/flow_distribution_overlay.png``.

    Args:
      model: module exposing ``set_encoder`` and ``flow.sample(n, context=...)``.
      ckpt: checkpoint dict; ``"dim"`` is required, ``"epsilon"`` (top level or
        under ``"args"``) is only shown in the figure title.
      normalization_info: dict with ``"min"`` / ``"max"`` used for scaling.
      data_path: NPZ file read through load_anchors_from_npz.
      exp_id: index of the experiment to plot.
      time_steps: time steps to plot (negative values count from the end);
        None is forwarded to load_anchors_from_npz and selects its default set.
      kl_samples: number of flow samples drawn per time step and type.
      output_dir: directory for the figure; created if missing.
      device_id: unused; kept for backward compatibility. Computation runs on
        the device that holds the model parameters.
      particle_types: type labels to plot.

    Raises:
      ValueError: if a requested particle type is absent from the experiment.
    """
    # Use the model's own device rather than a module-level `device` global,
    # which only exists when this file is executed as a script.
    device = next(model.parameters()).device

    dim = ckpt["dim"]
    epsilon = ckpt.get("epsilon", ckpt.get("args", {}).get("epsilon", None))
    anchors_all, types, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_steps,
    )

    anchors_scaled = normalize_anchors(anchors_all, normalization_info)
    anchors_scaled_t = torch.from_numpy(anchors_scaled).float().to(device)  # (E, S, M, D)

    E, S, M, D = anchors_scaled_t.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    print()
    print(f"anchors_all shape: {anchors_scaled_t.shape}")

    model.eval()

    anchors_scaled_t = anchors_scaled_t[exp_id]  # (S, M, D)
    anchors_orig = anchors_all[exp_id]  # (S, M, D) in original coordinates
    types_exp = types[exp_id]  # (M,)
    print(f"anchors_all shape: {anchors_scaled_t.shape}")

    type_colors = {
        "particles": {
            1: "#1f77b4",
            2: "#ff7f0e",
        },
        "samples": {
            1: "#bde19e",
            2: "#fd84cf",
        },
    }

    type_data = {}
    # Sampling is inference only, so skip building an autograd graph.
    with torch.no_grad():
        for particle_type in particle_types:
            mask_np = types_exp == particle_type
            if not np.any(mask_np):
                raise ValueError(f"No particles of type {particle_type} in experiment {exp_id}.")

            mask_t = torch.from_numpy(mask_np).to(device)
            anchors_type_scaled_t = anchors_scaled_t[:, mask_t, :]  # (S, M_t, D)

            z_ctx = model.set_encoder(anchors_type_scaled_t)
            samples_from_flow = model.flow.sample(kl_samples, context=z_ctx)  # (S, kl_samples, D)

            type_data[particle_type] = {
                "anchors": anchors_orig[:, mask_np, :],  # (S, M_t, D), original coordinates
                # Back to original coordinates for plotting.
                "samples": denormalize_anchors(
                    samples_from_flow.detach().cpu().numpy(),
                    normalization_info,
                ),
                "particle_color": type_colors["particles"].get(particle_type, "black"),
                "sample_color": type_colors["samples"].get(particle_type, "gray"),
            }

    n_cols = len(time_idx)
    n_rows = 2  # row 0: samples + anchors, row 1: samples only
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        constrained_layout=True,
        sharex=False,
        squeeze=False,  # always a (n_rows, n_cols) array, even for one column
    )

    # One shared y-range for every panel so columns are visually comparable.
    y_values = [
        entry[key][:, :, 1]
        for entry in type_data.values()
        for key in ("anchors", "samples")
    ]
    pos_y_min = min(arr.min() for arr in y_values)
    pos_y_max = max(arr.max() for arr in y_values)

    panels = ((True, "position"), (False, "samples only"))
    for s_idx, t in enumerate(time_idx):
        for row, (with_anchors, caption) in enumerate(panels):
            ax = axes[row, s_idx]
            for particle_type in particle_types:
                data = type_data[particle_type]
                ax.scatter(
                    data["samples"][s_idx, :, 0],
                    data["samples"][s_idx, :, 1],
                    label=f"Flow type {particle_type}",
                    alpha=1.0,
                    color=data["sample_color"],
                    s=0.5,
                )
                if with_anchors:
                    ax.scatter(
                        data["anchors"][s_idx, :, 0],
                        data["anchors"][s_idx, :, 1],
                        label=f"simulation type {particle_type}",
                        alpha=1.0,
                        color=data["particle_color"],
                        s=5,
                    )
            ax.set_title(f"Exp {exp_id}, t={t}: {caption}")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            ax.grid(alpha=0.3)
            ax.set_aspect('equal', adjustable='box')
            ax.set_ylim(pos_y_min, pos_y_max)

    fig.suptitle(f"Flow distribution samples (epsilon={epsilon})", fontsize=12)

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "flow_distribution_overlay.png")
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

        



def viz_flow_density_heatmap(
    model, ckpt, normalization_info,
    data_path: str,
    exp_id: int = 0,
    time_steps: list = None,
    grid_size: int = 200,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
    particle_types: tuple[int, int] = (1, 2),
):
    """Visualize the learned flow log-density as 2D heatmaps over several time steps.

    For the selected experiment and each particle type this function:

      - loads the anchors at the requested time steps (negative indices allowed),
      - normalizes anchors and a common 2D grid with ``normalization_info``,
      - evaluates ``model.log_prob_with_context`` on the grid, conditioned on
        ``model.set_encoder`` applied to the anchors of that time step,
      - adds the log-Jacobian of the normalization so values are log-densities
        in original coordinates,
      - saves one figure per type with one heatmap per time step to
        ``<output_dir>/flow_logdensity_heatmaps_exp{exp_id}_type{type}.png``.

    Args:
      model: module exposing ``set_encoder`` and ``log_prob_with_context``.
      ckpt: checkpoint dict; only ``"dim"`` is read.
      normalization_info: dict with ``"min"`` / ``"max"`` used for scaling.
      data_path: NPZ file read through load_anchors_from_npz.
      exp_id: index of the experiment to plot.
      time_steps: time steps to plot; None selects [0, 10, 20, 50, -1].
      grid_size: number of grid points per axis.
      output_dir: directory for the figures; created if missing.
      device_id: unused; kept for backward compatibility. Computation runs on
        the device that holds the model parameters.
      particle_types: type labels to plot.

    Raises:
      NotImplementedError: if the checkpoint dimension is not 2.
      ValueError: if ``exp_id`` is out of range or a particle type is absent.
    """
    # Use the model's own device rather than a module-level `device` global,
    # which only exists when this file is executed as a script.
    device = next(model.parameters()).device

    dim = ckpt["dim"]
    if dim != 2:
        raise NotImplementedError(
            f"viz_flow_density_heatmap currently supports dim=2 only, got dim={dim}"
        )

    if time_steps is None:
        time_steps = [0, 10, 20, 50, -1]

    anchors_all, types, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_steps,
    )  # anchors_all: (E, S, M, D)

    E, S, M, D = anchors_all.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    if not (0 <= exp_id < E):
        raise ValueError(f"exp_id must be in [0, {E-1}], got {exp_id}")

    anchors_orig = anchors_all[exp_id]  # (S, M, D), original coordinates
    types_exp = types[exp_id]  # (M,)

    # Per-dimension range of the normalization; needed for the Jacobian below.
    min_val = np.asarray(normalization_info["min"]).reshape(-1)
    max_val = np.asarray(normalization_info["max"]).reshape(-1)
    scale = max_val - min_val
    scale[scale == 0] = 1.0

    # x_scaled = 2 * (x - min) / scale - 1, so
    # log q_orig(x) = log q_scaled(x_scaled) + sum_d log(2 / scale_d).
    logdet_scale = float(np.log(2.0 / scale).sum())

    # Common plotting bounds: anchor extent over all selected times plus a 10% margin.
    x_min = anchors_orig[:, :, 0].min()
    x_max = anchors_orig[:, :, 0].max()
    y_min = anchors_orig[:, :, 1].min()
    y_max = anchors_orig[:, :, 1].max()
    dx = x_max - x_min
    dy = y_max - y_min
    margin_x = 0.1 * dx if dx > 0 else 0.1
    margin_y = 0.1 * dy if dy > 0 else 0.1
    x_min -= margin_x
    x_max += margin_x
    y_min -= margin_y
    y_max += margin_y

    xs = np.linspace(x_min, x_max, grid_size)
    ys = np.linspace(y_min, y_max, grid_size)
    Xg, Yg = np.meshgrid(xs, ys)
    grid_points = np.stack([Xg.ravel(), Yg.ravel()], axis=-1)  # (G, 2)

    # Scale grid and anchors with the shared helper (same formula as training scaling).
    grid_scaled = normalize_anchors(grid_points, normalization_info)  # (G, 2)
    grid_scaled_t = torch.from_numpy(grid_scaled).float().to(device).unsqueeze(0)  # (1, G, 2)
    anchors_scaled = normalize_anchors(anchors_orig, normalization_info)  # (S, M, D)

    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    for particle_type in particle_types:
        mask_np = types_exp == particle_type
        if not np.any(mask_np):
            raise ValueError(f"No particles of type {particle_type} in experiment {exp_id}.")

        # Evaluate the log-density grid for each selected time step.
        log_q_grids = []
        with torch.no_grad():
            for s in range(S):
                anchors_scaled_t = (
                    torch.from_numpy(anchors_scaled[s][mask_np]).float().to(device).unsqueeze(0)
                )  # (1, M_t, D)

                z_ctx = model.set_encoder(anchors_scaled_t)
                log_q_scaled = model.log_prob_with_context(
                    grid_scaled_t,
                    z_ctx,
                ).squeeze(0)

                log_q = log_q_scaled.cpu().numpy() + logdet_scale
                log_q_grids.append(log_q.reshape(grid_size, grid_size))

        # Shared color scale across time steps for comparability.
        all_vals = np.stack(log_q_grids, axis=0)
        vmin = all_vals.min()
        vmax = all_vals.max()

        n_cols = S
        fig, axes = plt.subplots(
            1,
            n_cols,
            figsize=(4 * n_cols, 4),
            constrained_layout=True,
            squeeze=False,  # always a (1, n_cols) array
        )

        for s in range(S):
            ax = axes[0, s]
            im = ax.imshow(
                log_q_grids[s],
                origin="lower",
                extent=[x_min, x_max, y_min, y_max],
                aspect="equal",
                cmap="viridis",
                vmin=vmin,
                vmax=vmax,
            )
            ts_resolved = int(time_idx[s])
            ax.set_title(f"t={ts_resolved}")
            ax.set_xlabel("x")
            if s == 0:
                ax.set_ylabel("y")

        # One colorbar for the whole figure (all panels share vmin/vmax).
        fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8, label="log density")
        fig.suptitle(
            f"Flow log-density heatmaps (exp {exp_id}, type {particle_type})",
            fontsize=12,
        )

        out_path = os.path.join(
            output_dir,
            f"flow_logdensity_heatmaps_exp{exp_id}_type{particle_type}.png",
        )
        fig.savefig(out_path, dpi=300)
        plt.close(fig)

        print(f"[Heatmap] Saved flow log-density heatmaps to {out_path}")




def plot_deepset_z_over_time(
    model, ckpt, normalization_info,
    data_path: str,
    num_experiments_to_plot: int = 3,
    max_dims_to_plot: int = None,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
    particle_types: tuple[int, int] = (1, 2),
):
    """
    Plot the DeepSet latent Z over time, one subplot per dimension, per particle type.

    For each particle type, every experiment's anchors are encoded with
    ``model.set_encoder`` at all default time steps of load_anchors_from_npz.
    The matching column of ``generate_data/dataset/macro_feature.npy`` (resolved
    relative to this file; column 0 for type 1, column 1 for type 2) is appended
    as an extra dimension titled "macro".

    Layout:
      - 4 columns of dimensions, n_rows = ceil(num_dims / 4)
      - each subplot: one dimension, one line per plotted experiment over time

    Figures are saved as ``<output_dir>/deepset_z_Z{z_dim}_type{type}.png``.

    Args:
      model: module exposing ``set_encoder``.
      ckpt: checkpoint dict; only ``"dim"`` is read.
      normalization_info: dict with ``"min"`` / ``"max"`` used for scaling.
      data_path: NPZ file read through load_anchors_from_npz.
      num_experiments_to_plot: maximum number of experiments drawn as lines.
      max_dims_to_plot: cap on plotted dimensions; None plots all of them.
      output_dir: directory for the figures; created if missing.
      device_id: unused; kept for backward compatibility. Computation runs on
        the device that holds the model parameters.
      particle_types: type labels to plot; only 1 and 2 are supported.

    Raises:
      ValueError: if the macro feature file has an unexpected shape or too few
        trajectories/time steps, if a particle type is not 1 or 2, or if a type
        is absent from an experiment.
    """
    # Use the model's own device rather than a module-level `device` global,
    # which only exists when this file is executed as a script.
    device = next(model.parameters()).device

    dim = ckpt["dim"]
    print(f"normalization_info: {normalization_info}")

    # time_step=None selects the default recording policy of load_anchors_from_npz.
    anchors, types, time_idx = load_anchors_from_npz(data_path=data_path, time_step=None)
    print(f"Loaded time_idx shape: {time_idx.shape}")

    E, T, M, D = anchors.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    print(f"[Z-plot] Loaded positions from {data_path} with shape {anchors.shape}")

    anchors_scaled = normalize_anchors(anchors, normalization_info)
    anchors_scaled = torch.from_numpy(anchors_scaled).float().to(device)  # (E, T, M, D)

    macro_path = os.path.join(
        os.path.dirname(__file__),
        "generate_data",
        "dataset",
        "macro_feature.npy",
    )
    # NOTE(security): allow_pickle=True can execute code from a crafted file;
    # only load trusted data.
    macro_feature = np.load(macro_path, allow_pickle=True)
    if macro_feature.ndim != 3 or macro_feature.shape[2] != 2:
        raise ValueError(
            f"macro_feature must have shape (n_traj, T, 2), got {macro_feature.shape}"
        )
    if macro_feature.shape[0] < E:
        raise ValueError(
            f"macro_feature has {macro_feature.shape[0]} trajectories, but data has {E}"
        )
    if macro_feature.shape[1] <= np.max(time_idx):
        raise ValueError(
            "macro_feature has fewer timesteps than requested by time_idx."
        )
    # Align the macro features with the loaded experiments and time steps.
    macro_feature = macro_feature[:E, time_idx, :]

    set_encoder = model.set_encoder
    set_encoder.eval()
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    for particle_type in particle_types:
        if particle_type not in (1, 2):
            raise ValueError(f"macro_feature only supports types 1 and 2, got {particle_type}")

        # Encode every experiment over time: Z[e, t, :] = set_encoder(anchors of this type).
        Z = []
        with torch.no_grad():
            for e_idx in range(E):
                mask_np = types[e_idx] == particle_type
                if not np.any(mask_np):
                    raise ValueError(
                        f"No particles of type {particle_type} in experiment {e_idx}."
                    )
                mask_t = torch.from_numpy(mask_np).to(device)
                anchors_e = anchors_scaled[e_idx].contiguous()  # (T, M, D)
                anchors_e = anchors_e[:, mask_t, :]             # (T, M_t, D)
                Z_e = set_encoder(anchors_e)                    # (T, z_dim)
                Z.append(Z_e.cpu().numpy())

        Z = np.stack(Z, axis=0)  # (E, T, z_dim)
        z_dim = Z.shape[-1]

        # Append this type's macro feature as one extra trailing dimension.
        macro_idx = 0 if particle_type == 1 else 1
        macro_type = macro_feature[..., macro_idx][..., None]
        Z = np.concatenate([Z, macro_type], axis=-1)
        z_dim_total = Z.shape[-1]

        if max_dims_to_plot is None:
            num_dims = z_dim_total
        else:
            num_dims = min(z_dim_total, max_dims_to_plot)

        num_lines = min(num_experiments_to_plot, E)
        t_axis = time_idx

        n_cols = 4
        n_rows = math.ceil(num_dims / n_cols)

        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(4 * n_cols, 4 * n_rows),
            sharex=True,
            constrained_layout=True,
            squeeze=False,  # always a (n_rows, n_cols) array, even for one row
        )

        # One subplot per dimension; one line per experiment.
        for d_idx in range(num_dims):
            ax = axes[d_idx // n_cols, d_idx % n_cols]

            for e_idx in range(num_lines):
                ax.plot(t_axis, Z[e_idx, :, d_idx], label=f"Exp {e_idx}", linewidth=1)

            if d_idx < z_dim:
                ax.set_title(f"z[{d_idx}]")
            else:
                ax.set_title("macro")
            ax.grid(alpha=0.3)

        # Hide unused slots when num_dims is not a multiple of n_cols.
        for empty_idx in range(num_dims, n_rows * n_cols):
            axes[empty_idx // n_cols, empty_idx % n_cols].axis("off")

        # Only the bottom row carries the x-axis label.
        for col in range(n_cols):
            axes[n_rows - 1, col].set_xlabel("Time step")

        out_path = os.path.join(
            output_dir,
            f"deepset_z_Z{z_dim}_type{particle_type}.png",
        )
        fig.savefig(out_path, dpi=300)
        plt.close(fig)

        print(f"[Z-plot] Saved DeepSet Z per-dimension trajectory grid to {out_path}")

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Command-line entry point: load the checkpoint once, then produce the
    # sample overlay, the optional density heatmaps, and the DeepSet Z plots.
    parser = argparse.ArgumentParser(description="Evaluate trained conditional NF (RQS/ARQS)")

    parser.add_argument(
        "--base_path",
        type=str,
        required=True,
        help="Directory to checkpoint", # e.g., trained_nflow/ARQS/epsilon0.01/deepsetPool_mean/n_layers_8/rqs_K_16/traj_200/
    )
    parser.add_argument("--z_dim", type=int, default=1, help="output dimension of deepset")

    parser.add_argument(
        "--data_path", type=str,
        default="generate_data/dataset/trajectories.npz",
        help="Path to the NPZ file with 'trajectories' and 'types'.",
    )
    parser.add_argument(
        "--time_step", type=int, nargs="+",
        default=[0, 10, 20, 50, -1],
        help="List of time steps to evaluate, e.g. --time_step 0 100 -1",
    )
    parser.add_argument(
        "--exp_id", type=int, default=9,
        help="Index of the experiment used for the sample/heatmap visualizations.",
    )
    parser.add_argument(
        "--heatmap", action="store_true",
        help="Also save flow log-density heatmaps (dim=2 checkpoints only).",
    )

    parser.add_argument("--num_experiments_to_plot", type=int, default=30)
    parser.add_argument("--kl_samples", type=int, default=1000)
    parser.add_argument("--device", type=int, default=0)

    args = parser.parse_args()

    # The checkpoint file name encodes the DeepSet latent size.
    ckpt_path = os.path.join(
        args.base_path,
        f"best_model_Z{args.z_dim}.pth",
    )

    output_dir = os.path.join(args.base_path, f"evaluation_Z{args.z_dim}")
    os.makedirs(output_dir, exist_ok=True)

    # `device` intentionally stays a module-level name: the visualization
    # helpers in this file may read it as a global.
    device = torch.device(
        f"cuda:{args.device}"
        if (args.device >= 0 and torch.cuda.is_available())
        else "cpu"
    )
    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)

    viz_samples(
        model, ckpt, normalization_info,
        data_path=args.data_path,
        exp_id=args.exp_id,
        time_steps=args.time_step,
        kl_samples=args.kl_samples,
        output_dir=output_dir,
        device_id=args.device,
    )

    # Optional: heatmaps for the same experiment and time steps as viz_samples.
    if args.heatmap:
        viz_flow_density_heatmap(
            model, ckpt, normalization_info,
            data_path=args.data_path,
            exp_id=args.exp_id,
            time_steps=args.time_step,
            grid_size=200,
            output_dir=output_dir,
            device_id=args.device,
        )

    plot_deepset_z_over_time(
        model, ckpt, normalization_info,
        data_path=args.data_path,
        num_experiments_to_plot=args.num_experiments_to_plot,
        max_dims_to_plot=None,  # or e.g. 4 if you only want first 4 dims
        output_dir=output_dir,
        device_id=args.device,
    )
