import os
import math
import argparse
import math

from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from models import (
    AnchorGaussianMixtureND,
    SetEncoderND,
)
from train_nflow_arqs import NFlowsConditionalARQS  # AR RQS wrapper from training script


# ---------------------------------------------------------------------------
# Data loading: mirror train_nflow_arqs.make_dataloaders, but no DataLoader
# ---------------------------------------------------------------------------

def load_anchors_from_npz(
    data_path: str,
    traj_start: int = 80,
    n_traj: int = 20,
    time_step=None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load anchor positions for the selected time steps from a NumPy file.

    The file is read with ``np.load``:
      - a plain ``.npy`` array is used directly;
      - for an ``.npz`` archive the ``positions`` entry is used.
    Either way the array must have shape (E, T, M, D). Values are returned
    exactly as stored, i.e. they are NOT normalized.

    Args:
      data_path: path of the file to load.
      traj_start, n_traj: kept for backward compatibility; currently unused.
      time_step:
        - None: every time step 0..T-1
        - int: single time step (e.g. -1 for last)
        - list/tuple/np.ndarray of ints: multiple time steps, e.g. [0, 100, -1]

    Returns:
      anchors_all: (E, S, M, D) numpy array, S = number of selected time steps
      time_idx:    (S,) numpy array of the resolved time indices (0-based)

    Raises:
      KeyError:   an ``.npz`` archive has no ``positions`` entry.
      ValueError: the loaded array is not 4-dimensional.
      IndexError: a requested time step is outside [-T, T-1].
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at trusted files.
    loaded = np.load(data_path, allow_pickle=True)

    if hasattr(loaded, "files"):
        # .npz archive: pull out the positions array and release the handle.
        with loaded:
            if "positions" not in loaded.files:
                raise KeyError(
                    f"'positions' not found in {data_path}; "
                    f"available entries: {list(loaded.files)}"
                )
            input_data = loaded["positions"]
    else:
        input_data = loaded

    if input_data.ndim != 4:
        raise ValueError(
            f"Expected positions of shape (E, T, M, D), got shape {input_data.shape}"
        )
    E, T, M, D = input_data.shape

    # Normalize the request to a 1D integer array of (possibly negative) indices.
    if time_step is None:
        time_idx = np.arange(T, dtype=int)
    elif isinstance(time_step, (int, np.integer)):
        time_idx = np.array([int(time_step)], dtype=int)
    else:
        time_idx = np.array(list(time_step), dtype=int)

    # Resolve negative indices (e.g. -1 -> T-1) and check bounds.
    resolved = np.where(time_idx < 0, time_idx + T, time_idx)
    out_of_range = (resolved < 0) | (resolved >= T)
    if out_of_range.any():
        first_bad = int(np.flatnonzero(out_of_range)[0])
        raise IndexError(
            f"time_step {time_idx[first_bad]} (resolved to {resolved[first_bad]}) "
            f"is out of range [0, {T-1}]"
        )
    time_idx = resolved.astype(int)

    anchors_all = input_data[:, time_idx, :, :]  # (E, S, M, D)

    print(f"Using time indices (0-based): {time_idx.tolist()}")
    return anchors_all, time_idx

# ---------------------------------------------------------------------------
# Rebuild model + scale from checkpoint (RQS or ARQS)
# ---------------------------------------------------------------------------

def build_model_from_checkpoint(
    ckpt_path: str,
    device: torch.device,
) -> Tuple[torch.nn.Module, dict, dict]:
    """
    Rebuild the trained NFlowsConditionalARQS model from a checkpoint.

    Hyper-parameters are read from the top level of the checkpoint and, if
    absent there, from its stored ``args``; the listed defaults apply when
    neither has the key (z_dim=4, hidden_dim=256, n_layers=8, rqs_K=16,
    rqs_B=5.0).

    Args:
      ckpt_path: path of the checkpoint file.
      device:    device the checkpoint and the model are mapped to.

    Returns:
      model:              nn.Module with loaded weights, in eval mode
      ckpt:               the raw checkpoint dict
      normalization_info: the ``normalization_info`` entry of the checkpoint
                          (or of its ``args``); None when it is missing.
                          Callers in this file read ``mean`` and ``std`` from it.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    train_args = ckpt.get("args") or {}
    if not hasattr(train_args, "get"):
        # Presumably an argparse.Namespace-like object; expose it as a dict.
        train_args = vars(train_args)

    def _cfg(key, default=None):
        """Look a setting up in the checkpoint first, then in its stored args."""
        if key in ckpt:
            return ckpt[key]
        return train_args.get(key, default)

    dim = ckpt["dim"]
    # model_type is only reported; the same ARQS wrapper is always built below.
    model_type = ckpt.get("model_type", "ARQS").upper()
    z_dim = _cfg("z_dim", 4)
    hidden = _cfg("hidden_dim", 256)
    n_layers = _cfg("n_layers", 8)
    rqs_K = _cfg("rqs_K", 16)
    rqs_B = _cfg("rqs_B", 5.0)

    print(f"Checkpoint loaded from: {ckpt_path}")
    print(f"dim={dim}, z_dim={z_dim}, hidden_dim={hidden}, n_layers={n_layers}, "
          f"rqs_K={rqs_K}, rqs_B={rqs_B}, model_type={model_type}")

    # Shared DeepSet encoder
    set_encoder = SetEncoderND(
        in_dim=dim,
        hidden_dim=hidden,
        z_dim=z_dim,
    ).to(device)

    # nflows-based autoregressive RQS flow (same wrapper as training)
    model = NFlowsConditionalARQS(
        dim=dim,
        z_dim=z_dim,
        hidden_dim=hidden,
        n_layers=n_layers,
        K=rqs_K,
        B=rqs_B,
        set_encoder=set_encoder,
    ).to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # Normalization statistics saved at training time; returned untouched so
    # each caller can apply (x - mean) / std itself.
    normalization_info = _cfg("normalization_info", None)
    return model, ckpt, normalization_info


# ---------------------------------------------------------------------------
# Evaluation utilities: pairwise energy, per-time-step KL, and 2D visualizations (similar in spirit to test_old.py)
# ---------------------------------------------------------------------------
def pairwise_energy_step_force(X, a=4, b=0.1):
    """
    Sum a pairwise interaction potential over all unordered point pairs.

    For every pair (i, j), i < j, at Euclidean distance r the contribution is

        (1 / a) * log(cosh(a * (1 - r))) + b * (1 - r)

    Args:
      X: array-like of shape (..., N, D); the last two axes are points x coords.
      a, b: potential parameters.

    Returns:
      A Python float when X is exactly 2-D, otherwise an array with the
      leading (batch) shape of X.

    Raises:
      ValueError: X has fewer than 2 dimensions.
    """
    points = np.asarray(X)
    if points.ndim < 2:
        raise ValueError("X must have shape (..., N, D) with at least 2 dimensions")

    leading_shape = points.shape[:-2]
    n_points = points.shape[-2]

    # Full (N, N) distance matrix, then keep the strict upper triangle so each
    # unordered pair is counted once.
    deltas = points[..., :, None, :] - points[..., None, :, :]
    dist = np.linalg.norm(deltas, axis=-1)
    rows, cols = np.triu_indices(n_points, k=1)
    pair_dist = dist[..., rows, cols]

    gap = 1.0 - pair_dist
    # log(cosh(z)) written as |z| + log1p(exp(-2|z|)) - log(2) to avoid overflow.
    abs_arg = np.abs(a * gap)
    log_cosh = abs_arg + np.log1p(np.exp(-2.0 * abs_arg)) - np.log(2.0)

    total = np.sum((1.0 / a) * log_cosh + b * gap, axis=-1)

    return total if leading_shape else float(total)


def cmpt_energy(
    ckpt_path: str,
    data_path: str,
    exp_id: int = 0,
    time_steps: list = None,
    n_samples: int = 300,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
):
    """
    Plot pairwise interaction energy over time for one experiment.

    Three curves are drawn against the resolved time indices:
      - "Simulation":     energy of the stored anchors,
      - "Target samples": energy of samples from AnchorGaussianMixtureND,
      - "Flow samples":   energy of samples from the trained flow.
    Samples are drawn in standardized coordinates and mapped back with
    ``x * std + mean`` before the energy is evaluated.

    Args:
      ckpt_path:  checkpoint passed to build_model_from_checkpoint.
      data_path:  data file passed to load_anchors_from_npz.
      exp_id:     experiment index, must be in [0, E-1].
      time_steps: time steps to evaluate; defaults to [0, 10, 20, 50, -1].
      n_samples:  number of samples drawn per time step from target and flow.
      output_dir: root directory; the plot goes into a sub-directory derived
                  from epsilon, deepset_pool, z_dim and n_traj.
      device_id:  CUDA device index; CPU is used when negative or CUDA is
                  unavailable.

    Raises:
      ValueError: exp_id is out of range.
    """
    if time_steps is None:
        time_steps = [0, 10, 20, 50, -1]

    device = torch.device(
        f"cuda:{device_id}"
        if (device_id >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)

    def _cfg(key, default=None):
        # Top-level checkpoint entry first, then the stored training args.
        return ckpt.get(key, ckpt.get("args", {}).get(key, default))

    dim = ckpt["dim"]
    epsilon = _cfg("epsilon")

    anchors_all, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_steps,
    )

    E, S, M, D = anchors_all.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"
    if not (0 <= exp_id < E):
        raise ValueError(f"exp_id must be in [0, {E-1}], got {exp_id}")

    mean = normalization_info.get("mean").squeeze()
    std = normalization_info.get("std").squeeze()

    # Only the requested experiment is standardized and moved to the device.
    anchors_exp = anchors_all[exp_id]  # (S, M, D), original coordinates
    anchors_scaled = (
        torch.from_numpy((anchors_exp - mean) / std).float().to(device)
    )  # (S, M, D), standardized

    model.eval()
    target = AnchorGaussianMixtureND(
        anchors=anchors_scaled,
        epsilon=epsilon,
    )

    with torch.no_grad():
        samples_from_target = target.sample(n_samples)  # (S, n_samples, D)
        samples_from_flow = model.sample(
            n_samples,
            anchors_scaled.contiguous(),
        )  # (S, n_samples, D)

    # Back to original coordinates for the energy evaluation.
    samples_target_denorm = samples_from_target.cpu().numpy() * std + mean
    samples_flow_denorm = samples_from_flow.cpu().numpy() * std + mean

    # The simulation energy uses the stored anchors directly instead of a
    # float32 standardize/de-standardize round trip.
    # NOTE(review): the energy is a sum over point pairs, so the sample curves
    # are only on the same scale as the simulation when n_samples == M — confirm.
    input_energy = pairwise_energy_step_force(anchors_exp)
    target_energy = pairwise_energy_step_force(samples_target_denorm)
    flow_energy = pairwise_energy_step_force(samples_flow_denorm)

    out_dir = os.path.join(
        output_dir,
        f"epsilon{epsilon}",
        f"deepset_pool_{_cfg('deepset_pool')}",
        f"Z{_cfg('z_dim', 4)}",
        f"n_traj_{_cfg('n_traj')}"
    )
    os.makedirs(out_dir, exist_ok=True)

    t_axis = np.array(time_idx, dtype=int)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(t_axis, input_energy, label="Simulation", marker="o", markersize=2)
    ax.plot(t_axis, target_energy, label="Target samples", marker="o", markersize=2)
    ax.plot(t_axis, flow_energy, label="Flow samples", marker="o", markersize=2)
    ax.set_xlabel("Time step")
    ax.set_ylabel("Pairwise interaction energy")
    ax.set_title(f"Energy comparison (exp {exp_id}, epsilon={epsilon})")
    ax.legend()

    out_path = os.path.join(
        out_dir,
        f"energy_comparison_exp{exp_id}.png",
    )
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    print(f"[Energy] Saved energy comparison plot to {out_path}")



def evaluate_KL(
    ckpt_path: str,
    data_path: str,
    time_step=None,    
    kl_samples: int = 2000,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
):
    """
    Monte-Carlo estimate of KL(target || model), one value per time step.

    For every selected time step the estimate averages
    ``log p(x) - log q(x)`` over ``kl_samples`` target samples from each of the
    E experiments. The target is AnchorGaussianMixtureND built from the
    standardized anchors with the ``epsilon`` stored in the checkpoint; the
    model is the flow rebuilt by build_model_from_checkpoint.

    Results are printed and written to ``kl_summary.txt`` inside a
    sub-directory of ``output_dir`` derived from epsilon, deepset_pool, z_dim
    and n_traj.

    Args:
      ckpt_path:  checkpoint passed to build_model_from_checkpoint.
      data_path:  data file passed to load_anchors_from_npz.
      time_step:  time step selection forwarded to load_anchors_from_npz
                  (None selects every time step).
      kl_samples: number of target samples per (experiment, time step).
      output_dir: root output directory.
      device_id:  CUDA device index; CPU is used when negative or CUDA is
                  unavailable.
    """
    # Device
    device = torch.device(
        f"cuda:{device_id}"
        if (device_id >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)

    def _cfg(key, default=None):
        # Top-level checkpoint entry first, then the stored training args.
        return ckpt.get(key, ckpt.get("args", {}).get(key, default))

    dim = ckpt["dim"]
    epsilon = _cfg("epsilon")

    anchors_all, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_step,
    )

    # Apply the same standardization as in training.
    mean = normalization_info.get("mean").squeeze()
    std = normalization_info.get("std").squeeze()
    anchors_all = (anchors_all - mean) / std
    anchors_all = torch.from_numpy(anchors_all).float().to(device)  # (E, S, M, D)

    E, S, M, D = anchors_all.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    output_dir = os.path.join(
        output_dir,
        f"epsilon{epsilon}",
        f"deepset_pool_{_cfg('deepset_pool')}",
        f"Z{_cfg('z_dim', 4)}",
        f"n_traj_{_cfg('n_traj')}",
    )
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, "kl_summary.txt")

    # ----------------------------------------------------------------------
    # Estimate KL(target || model) via Monte Carlo
    # ----------------------------------------------------------------------
    # The target mixture is built from the STANDARDIZED anchors, so its samples
    # and log p are in standardized coordinates — the same space in which
    # model.log_prob evaluates log q. KL is invariant under the (invertible)
    # standardization, so
    #
    #   KL(p || q) ≈ E_p[log p(x) - log q(x)]
    #
    # is computed directly there. No Jacobian term is added: adding
    # sum(log(1/std)) to log q alone would shift every estimate by sum(log std).
    model.eval()
    with open(summary_path, "w") as f:
        f.write(f"evalute KL for {S} time steps\n")

    with torch.no_grad():
        for s in range(S):
            diff_sum = 0.0
            n_total = 0
            for e in tqdm(range(E), desc="Estimating global KL"):
                anchors_e = anchors_all[e, s]  # (M, D)

                target = AnchorGaussianMixtureND(
                    anchors=anchors_e.unsqueeze(0),  # (1, M, D)
                    epsilon=epsilon
                )

                # x ~ p(x) and log p(x), both in standardized coordinates
                samples = target.sample(kl_samples, exp_idx=0)  # (kl_samples, D)
                log_p = target.log_prob(samples, exp_idx=0)     # (kl_samples,)

                # log q(x) from the flow, same coordinates
                log_q = model.log_prob(
                    samples.unsqueeze(0),    # (1, B, D)
                    anchors_e.unsqueeze(0),  # (1, M, D)
                ).squeeze(0)                 # (B,)

                diff_sum += (log_p - log_q).sum().item()
                n_total += samples.shape[0]

            mean_kl = diff_sum / n_total
            print(f"\n time step: {time_idx[s]}, KL(target || model): {mean_kl:.6f}\n")
            with open(summary_path, "a") as f:
                f.write(f"Time step {time_idx[s]}: KL(target || model): {mean_kl:.6f}\n")

    

    
def viz_samples(
    ckpt_path: str,
    data_path: str,
    exp_id: int = 0,    
    time_steps: list = None,
    kl_samples: int = 2000,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
):
    """
    Scatter-plot flow samples against the simulated anchors for one experiment.

    One panel per entry of ``time_steps`` shows ``kl_samples`` samples drawn
    from the flow (gray) over the anchors of that time step (black). Both are
    plotted in standardized coordinates, using the first two coordinates.
    The figure is saved as ``flow_distribution.png`` in a sub-directory of
    ``output_dir`` derived from epsilon, deepset_pool, z_dim and n_traj.

    Args:
      ckpt_path:  checkpoint passed to build_model_from_checkpoint.
      data_path:  data file passed to load_anchors_from_npz.
      exp_id:     index of the experiment to plot.
      time_steps: time indices to plot (negative values count from the end);
                  defaults to [0, 10, 20, 50, -1].
      kl_samples: number of flow samples per time step.
      output_dir: root output directory.
      device_id:  CUDA device index; CPU is used when negative or CUDA is
                  unavailable.
    """
    if time_steps is None:
        time_steps = [0, 10, 20, 50, -1]

    # Device
    device = torch.device(
        f"cuda:{device_id}"
        if (device_id >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)

    def _cfg(key, default=None):
        # Top-level checkpoint entry first, then the stored training args.
        return ckpt.get(key, ckpt.get("args", {}).get(key, default))

    dim = ckpt["dim"]
    epsilon = _cfg("epsilon")

    anchors_all, _ = load_anchors_from_npz(
        data_path=data_path,
    )  # (E, T, M, D), original coordinates

    E, S, M, D = anchors_all.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    output_dir = os.path.join(
        output_dir,
        f"epsilon{epsilon}",
        f"deepset_pool_{_cfg('deepset_pool')}",
        f"Z{_cfg('z_dim', 4)}",
        f"n_traj_{_cfg('n_traj')}",
    )
    os.makedirs(output_dir, exist_ok=True)

    print()
    print(f"anchors_all shape: {tuple(anchors_all.shape)}")

    # Select the experiment and time steps BEFORE standardizing and moving to
    # the device, so only the plotted slices are transferred.
    step_index = np.asarray(time_steps, dtype=int)
    mean = normalization_info.get("mean").squeeze()
    std = normalization_info.get("std").squeeze()
    anchors_sel = (anchors_all[exp_id][step_index] - mean) / std  # (T, M, D)
    anchors_sel = torch.from_numpy(anchors_sel).float().to(device)
    print(f"anchors_all shape: {tuple(anchors_sel.shape)}")

    model.eval()
    # Sampling under no_grad keeps the result detached, so .numpy() below is
    # valid (same pattern as cmpt_energy).
    with torch.no_grad():
        samples_from_flow = model.sample(kl_samples, anchors_sel.contiguous())  # (T, kl_samples, D)

    # Everything stays in standardized coordinates for plotting.
    anchors_np = anchors_sel.cpu().numpy()
    predict_samples = samples_from_flow.cpu().numpy()

    n_cols = len(time_steps)
    n_rows = 1  # only position, no velocity
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        constrained_layout=True,
        sharex=False,
        squeeze=False,
    )

    # Shared y-range across all panels
    pos_y_min = min(anchors_np[:, :, 1].min(), predict_samples[:, :, 1].min())
    pos_y_max = max(anchors_np[:, :, 1].max(), predict_samples[:, :, 1].max())

    for s_idx, t in enumerate(time_steps):
        ax = axes[0, s_idx]
        ax.scatter(
            predict_samples[s_idx, :, 0],
            predict_samples[s_idx, :, 1],
            label="Flow",
            alpha=0.5,
            color="gray",
            s=0.5,
        )
        ax.scatter(
            anchors_np[s_idx, :, 0],
            anchors_np[s_idx, :, 1],
            label="simulation",
            alpha=1.0,
            color='black',
            s=5,
        )
        ax.set_title(f"Exp {exp_id}, t={t}: position")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.legend(fontsize='small', ncol=2)
        ax.grid(alpha=0.3)
        ax.set_aspect('equal', adjustable='box')
        ax.set_ylim(pos_y_min, pos_y_max)

    # title for the entire figure
    fig.suptitle(f"Flow distribution samples (epsilon={epsilon})", fontsize=12)

    out_path = os.path.join(output_dir, "flow_distribution.png")
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

        



def viz_flow_density_heatmap(
    ckpt_path: str,
    data_path: str,
    exp_id: int = 0,
    time_steps: list = None,
    grid_size: int = 200,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
):
    """Visualize the learned flow density as 2D heatmaps over multiple time steps.

    For the selected experiment and every requested time step (negative
    indices allowed) this function:

      - loads the corresponding anchors in original coordinates,
      - applies the same standardization used during training,
      - evaluates the flow log-density on a common 2D grid and converts it to
        original-space log-density with the standardization Jacobian,
      - plots one heatmap per time step in a single figure with a shared
        color scale.

    Args:
      ckpt_path:  checkpoint passed to build_model_from_checkpoint.
      data_path:  data file passed to load_anchors_from_npz.
      exp_id:     experiment index, must be in [0, E-1].
      time_steps: time steps to plot; defaults to [0, 10, 20, 50, -1].
      grid_size:  number of grid points along each axis.
      output_dir: root directory; the figure goes into a sub-directory derived
                  from epsilon, deepset_pool, z_dim and n_traj.
      device_id:  CUDA device index; CPU is used when negative or CUDA is
                  unavailable.

    Raises:
      NotImplementedError: the model dimension is not 2.
      ValueError: exp_id is out of range.
    """

    # Device
    device = torch.device(
        f"cuda:{device_id}" if (device_id >= 0 and torch.cuda.is_available()) else "cpu"
    )

    # Load model + normalization
    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)
    dim = ckpt["dim"]
    if dim != 2:
        raise NotImplementedError(
            f"viz_flow_density_heatmap currently supports dim=2 only, got dim={dim}"
        )

    if time_steps is None:
        # Default to same style as viz_samples
        time_steps = [0, 10, 20, 50, -1]

    def _cfg(key, default=None):
        # Top-level checkpoint entry first, then the stored training args.
        return ckpt.get(key, ckpt.get("args", {}).get(key, default))

    # Output directory consistent with other evaluation utilities
    out_dir = os.path.join(
        output_dir,
        f"epsilon{_cfg('epsilon')}",
        f"deepset_pool_{_cfg('deepset_pool')}",
        f"Z{_cfg('z_dim', 4)}",
        f"n_traj_{_cfg('n_traj')}"
    )
    os.makedirs(out_dir, exist_ok=True)

    # Load anchors for the requested time steps
    anchors_all, time_idx = load_anchors_from_npz(
        data_path=data_path,
        time_step=time_steps,
    )  # anchors_all: (E, S, M, D)

    E, S, M, D = anchors_all.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    if not (0 <= exp_id < E):
        raise ValueError(f"exp_id must be in [0, {E-1}], got {exp_id}")

    # Original (unstandardized) anchors for this experiment
    anchors_orig = anchors_all[exp_id]  # (S, M, D)

    mean = normalization_info.get("mean").squeeze()  # (D,)
    std = normalization_info.get("std").squeeze()    # (D,)

    # Global plotting bounds from the original anchors across all selected
    # times, padded by 10% of the extent (or 0.1 when the extent is zero).
    x_lo = anchors_orig[:, :, 0].min()
    x_hi = anchors_orig[:, :, 0].max()
    y_lo = anchors_orig[:, :, 1].min()
    y_hi = anchors_orig[:, :, 1].max()
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo
    pad_x = 0.1 * span_x if span_x > 0 else 0.1
    pad_y = 0.1 * span_y if span_y > 0 else 0.1
    x_min, x_max = x_lo - pad_x, x_hi + pad_x
    y_min, y_max = y_lo - pad_y, y_hi + pad_y

    xs = np.linspace(x_min, x_max, grid_size)
    ys = np.linspace(y_min, y_max, grid_size)
    Xg, Yg = np.meshgrid(xs, ys)
    grid_points = np.stack([Xg.ravel(), Yg.ravel()], axis=-1)  # (G, 2)

    # Standardize grid for the model, same as training
    grid_scaled = (grid_points - mean) / std  # (G, 2)
    grid_scaled_t = torch.from_numpy(grid_scaled).float().to(device).unsqueeze(0)  # (1, G, 2)

    # Standardize all selected anchor sets once
    anchors_scaled = (
        torch.from_numpy((anchors_orig - mean) / std).float().to(device)
    )  # (S, M, D)

    # Change-of-variables term: the grid is in original coordinates, so
    # log q(x) = log q_s((x - mean) / std) + sum_d log(1 / std_d).
    logdet_scale = float(np.log(1.0 / std).sum())

    model.eval()

    # Evaluate log-density for each selected time step
    log_q_grids = []
    with torch.no_grad():
        for s in range(S):
            log_q_scaled = model.log_prob(
                grid_scaled_t,                 # (1, G, 2)
                anchors_scaled[s].unsqueeze(0),  # (1, M, 2)
            ).squeeze(0)                       # (G,)

            log_q = log_q_scaled.cpu().numpy() + logdet_scale  # (G,)
            log_q_grids.append(log_q.reshape(grid_size, grid_size))

    # Use common color scale across time steps for comparability
    all_vals = np.stack(log_q_grids, axis=0)
    vmin = all_vals.min()
    vmax = all_vals.max()

    fig, axes = plt.subplots(
        1,
        S,
        figsize=(4 * S, 4),
        constrained_layout=True,
        squeeze=False,
    )

    for s in range(S):
        ax = axes[0, s]
        im = ax.imshow(
            log_q_grids[s],
            origin="lower",
            extent=[x_min, x_max, y_min, y_max],
            aspect="equal",
            cmap="viridis",
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(f"t={int(time_idx[s])}")
        ax.set_xlabel("x")
        if s == 0:
            ax.set_ylabel("y")

    # Add a single colorbar for the whole figure (all panels share vmin/vmax)
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8, label="log density")
    fig.suptitle(f"Flow log-density heatmaps (exp {exp_id})", fontsize=12)

    out_path = os.path.join(
        out_dir,
        f"flow_logdensity_heatmaps_exp{exp_id}.png",
    )
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

    print(f"[Heatmap] Saved flow log-density heatmaps to {out_path}")




def plot_deepset_z_over_time(
    ckpt_path: str,
    data_path: str,
    num_experiments_to_plot: int = 3,
    max_dims_to_plot: int = None,
    output_dir: str = "evaluation_nflow",
    device_id: int = 0,
):
    """
    Plot DeepSet latent Z over time, per dimension, for several experiments.

    Layout:
      - 4 columns of dimensions
      - n_rows = ceil(num_dims / 4)
      - Each subplot: one dimension, lines = experiments over time

    Args:
      ckpt_path:  checkpoint passed to build_model_from_checkpoint.
      data_path:  data file passed to load_anchors_from_npz (all time steps
                  are loaded).
      num_experiments_to_plot: the first min(this, E) experiments are drawn.
      max_dims_to_plot: cap on the number of latent dimensions plotted;
                  None plots all of them.
      output_dir: root directory; the figure goes into a sub-directory derived
                  from epsilon, deepset_pool, z_dim and n_traj.
      device_id:  CUDA device index; CPU is used when negative or CUDA is
                  unavailable.
    """
    # Device
    device = torch.device(
        f"cuda:{device_id}"
        if (device_id >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    model, ckpt, normalization_info = build_model_from_checkpoint(ckpt_path, device)
    dim = ckpt["dim"]
    print(f"normalization_info: {normalization_info}")

    def _cfg(key, default=None):
        # Top-level checkpoint entry first, then the stored training args.
        return ckpt.get(key, ckpt.get("args", {}).get(key, default))

    output_dir = os.path.join(
        output_dir,
        f"epsilon{_cfg('epsilon')}",
        f"deepset_pool_{_cfg('deepset_pool')}",
        f"Z{_cfg('z_dim', 4)}",
        f"n_traj_{_cfg('n_traj')}",
    )
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Load full trajectories (all timesteps)
    # ------------------------------------------------------------------
    anchors, time_idx = load_anchors_from_npz(data_path=data_path, time_step=None)
    print(f"Loaded time_idx shape: {time_idx.shape}")

    E, T, M, D = anchors.shape
    assert D == dim, f"Data dim={D} but model dim={dim}"

    print(f"[Z-plot] Loaded positions from {data_path} with shape {anchors.shape}")

    # Only the first num_lines experiments are drawn, so only those are
    # standardized, moved to the device and encoded. At least one experiment
    # is always encoded so the latent width can be read off the output.
    num_lines = min(num_experiments_to_plot, E)
    n_encode = max(num_lines, 1)

    # Apply same scaling as in training
    mean = normalization_info.get("mean").squeeze()
    std = normalization_info.get("std").squeeze()
    anchors_scaled = (
        torch.from_numpy((anchors[:n_encode] - mean) / std).float().to(device)
    )  # (n_encode, T, M, D)

    # ------------------------------------------------------------------
    # Run DeepSet encoder over time: Z(e,t,:) = set_encoder(anchors_e_t)
    # ------------------------------------------------------------------
    model.eval()
    set_encoder = model.set_encoder
    with torch.no_grad():
        Z = np.stack(
            [
                set_encoder(anchors_scaled[e_idx].contiguous()).cpu().numpy()  # (T, z_dim)
                for e_idx in range(anchors_scaled.shape[0])
            ],
            axis=0,
        )  # (n_encode, T, z_dim)

    z_dim = Z.shape[-1]

    # ------------------------------------------------------------------
    # Decide which dimensions to plot
    # ------------------------------------------------------------------
    num_dims = z_dim if max_dims_to_plot is None else min(z_dim, max_dims_to_plot)
    t_axis = np.arange(T)

    # ------------------------------------------------------------------
    # Subplot layout: n_rows x 4 columns
    # ------------------------------------------------------------------
    n_cols = 4
    n_rows = math.ceil(num_dims / n_cols)

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        sharex=True,
        constrained_layout=True,
        squeeze=False,  # always a 2D (n_rows, n_cols) array of axes
    )

    # Plot each dimension in its own subplot
    for d_idx in range(num_dims):
        ax = axes[d_idx // n_cols, d_idx % n_cols]

        for e_idx in range(num_lines):
            ax.plot(t_axis, Z[e_idx, :, d_idx], label=f"Exp {e_idx}", linewidth=1)

        ax.set_title(f"z[{d_idx}]")
        ax.grid(alpha=0.3)

    # Turn off any unused axes (if num_dims is not a multiple of 4)
    for empty_idx in range(num_dims, n_rows * n_cols):
        axes[empty_idx // n_cols, empty_idx % n_cols].axis("off")

    # Label only bottom row with x-axis
    for col in range(n_cols):
        axes[n_rows - 1, col].set_xlabel("Time step")

    out_path = os.path.join(output_dir, f"deepset_z_Z{z_dim}.png")
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

    print(f"[Z-plot] Saved DeepSet Z per-dimension trajectory grid to {out_path}")

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_cli_args():
    """Declare the command-line options of this evaluation script and parse them."""
    cli = argparse.ArgumentParser(description="Evaluate trained conditional NF (RQS/ARQS)")

    cli.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="Path to checkpoint (e.g. trained_nflow/ARQS/epsilon0.1/deepsetPool_mean/200_traj/best_model_Z4.pth)",
    )
    cli.add_argument(
        "--data_path",
        type=str,
        default="generate_dataset/data/trajectories.npy",
        help="NPZ with positions + velocities",
    )
    cli.add_argument(
        "--time_step",
        type=int,
        nargs="+",
        default=[0, 10, 20, 50, -1],
        help="List of time steps to evaluate, e.g. --time_step 0 100 -1",
    )
    cli.add_argument("--num_experiments_to_plot", type=int, default=10)
    cli.add_argument("--kl_samples", type=int, default=300)
    cli.add_argument("--output_dir", type=str, default="evaluation_nflow_gmm2")
    cli.add_argument("--device", type=int, default=3)

    return cli.parse_args()


if __name__ == "__main__":
    args = _parse_cli_args()

    # Options shared by every evaluation entry point in this file.
    common = dict(
        ckpt_path=args.ckpt_path,
        data_path=args.data_path,
        output_dir=args.output_dir,
        device_id=args.device,
    )

    # Other evaluations, currently disabled. Uncomment the one(s) to run:
    #
    # evaluate_KL(
    #     time_step=np.arange(0, 119, 5),
    #     kl_samples=args.kl_samples,
    #     **common,
    # )
    #
    # viz_samples(
    #     exp_id=9,
    #     time_steps=np.arange(0, 119, 5),
    #     kl_samples=args.kl_samples,
    #     **common,
    # )
    #
    # # Heatmap visualization for the same time steps as viz_samples
    # viz_flow_density_heatmap(
    #     exp_id=9,
    #     time_steps=np.arange(0, 119, 5),
    #     grid_size=200,
    #     **common,
    # )
    #
    # cmpt_energy(
    #     exp_id=9,
    #     time_steps=np.arange(0, 119, 1),
    #     n_samples=args.kl_samples,
    #     **common,
    # )

    # Active evaluation: DeepSet latent trajectories.
    plot_deepset_z_over_time(
        num_experiments_to_plot=args.num_experiments_to_plot,
        max_dims_to_plot=None,  # or e.g. 4 if you only want first 4 dims
        **common,
    )
