import os, sys
# os.environ["JAX_PLATFORMS"] = "cpu"
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ["JAX_LOG_COMPILES"] = "0"
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_analytical_sol_latency_estimator=false"
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from soap_jax import soap
import optax
import shutil
from tqdm import tqdm

from nnx_models.utils_lora import (add_lora_to_model, merge_lora_params, reset_lora_params) 
from nnx_models.mlp import MLP 
import flax.nnx as nnx

import orbax.checkpoint as ocp
from typing import Any, Sequence, Union, Callable, Tuple
from flax.core.frozen_dict import FrozenDict
from pathlib import Path
import absl.logging
from collections import deque
import numpy as np 
absl.logging.set_verbosity(absl.logging.FATAL)
jax.config.update("jax_default_matmul_precision", "highest")
from collections import deque

absl.logging.set_verbosity(absl.logging.FATAL)

# %%%%%%%%%%%%%%%% Jax CFD imports %%%%%%%%%%%%%%%%%
import jax_cfd.base as cfd
import jax_cfd.base.grids as grids
import jax_cfd.spectral as spectral
#####################################################

# Shared type aliases used in annotations throughout this module.
# (Each alias was previously defined twice in a row; one definition suffices.)
Array = Union[np.ndarray, jax.Array]       # host (NumPy) or device (JAX) array
IntOrSequence = Union[int, Sequence[int]]  # a single int or a sequence of ints
PyTree = Any                               # arbitrary nested container of arrays

@jax.jit
def max_abs(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the largest absolute element-wise difference between ``x`` and ``y``."""
    abs_delta = jnp.abs(x - y)
    return jnp.max(abs_delta)

@jax.jit
def l2_norm(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return ``jnp.linalg.norm`` of the difference ``x - y``."""
    delta = x - y
    return jnp.linalg.norm(delta)

@jax.jit
def mae(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean absolute error between ``x`` and ``y``."""
    abs_delta = jnp.abs(x - y)
    return jnp.mean(abs_delta)

@jax.jit
def mse(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the mean squared error between ``x`` and ``y``."""
    delta = x - y
    return jnp.mean(delta ** 2)

@jax.jit
def pearson_corr(x: jax.Array, y: jax.Array) -> jax.Array:
    """Return the Pearson correlation coefficient of the flattened ``x`` and ``y``."""
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    # Off-diagonal entry of the 2x2 correlation matrix is corr(x, y).
    corr_matrix = jnp.corrcoef(flat_x, flat_y)
    return corr_matrix[0, 1]


class StreamingTemporalSelector:
    """
    Stateful jump selector for consecutive simulation windows.

    The running maximum step difference (``maxx_diff``), the running sum of
    step differences (``curr_diff_sum``) and the number of accumulated steps
    (``count``) persist between calls to ``select_jump`` so that the adaptive
    factor reflects the whole history rather than a single window.
    """
    def __init__(self, corr_threshold=0.9, window_size=5):
        # Running statistics; they are seeded on the first select_jump call.
        self.initialized = False
        self.count = 0
        self.curr_diff_sum = 0.0
        self.maxx_diff = 0.0
        # Selection parameters (window_size is stored but not read by select_jump).
        self.window_size = window_size
        self.corr_threshold = corr_threshold

    def select_jump(self, window_real):
        """
        Pick how far to jump inside ``window_real``.

        ``window_real[0]`` is the current state and ``window_real[1:]`` are
        the candidate frames. Candidates are scanned in order and the scan
        stops at the first one that fails either criterion; the index of the
        last accepted candidate is returned (at least 1).
        """
        n_frames = window_real.shape[0]

        # Difference of the first step of this window; also the warm-up baseline.
        first_step = max_abs(window_real[1], window_real[0])
        if not self.initialized:
            self.maxx_diff = first_step
            self.curr_diff_sum = first_step
            self.count = 1
            self.initialized = True

        # Adaptive factor from the history: sqrt(max step diff / mean step diff).
        running_mean = self.curr_diff_sum / self.count
        factor = jnp.sqrt(self.maxx_diff / (running_mean + 1e-8)).item()

        anchor = window_real[0]
        best_idx = 1  # minimum jump
        for cand in range(1, n_frames):
            ratio = max_abs(window_real[cand], anchor) / (first_step + 1e-8)
            corr = pearson_corr(window_real[cand], anchor)

            # Accept only while the deviation ratio stays within the factor
            # and the candidate remains correlated with the anchor frame.
            keep_going = ratio <= factor and corr >= self.corr_threshold
            if not keep_going:
                break
            best_idx = cand

        # Fold the difference between the chosen frame and its predecessor
        # into the running statistics used by the next call.
        taken_diff = max_abs(window_real[best_idx], window_real[best_idx - 1])
        self.maxx_diff = max(self.maxx_diff, taken_diff)
        self.curr_diff_sum += taken_diff
        self.count += 1

        return best_idx

def loss(model, x, y):
    """Return the mean squared error between ``model(x)`` and the targets ``y``."""
    residual = model(x) - y
    return jnp.mean(residual**2)

@nnx.jit(static_argnames='filter')
def train_step(model, optimizer, x, y, filter=nnx.Param):
    """
    Run one optimisation step and return the loss value of that step.

    Gradients of ``loss(model, x, y)`` are taken with respect to the model
    variables selected by ``filter`` (e.g. ``nnx.Param`` or ``nnx.LoRAParam``)
    and handed to ``optimizer.update``.
    """
    # Differentiate argument 0 (the model), restricted to the chosen variable type.
    diff_state = nnx.DiffState(0, filter=filter)
    value_and_grad_fn = nnx.value_and_grad(loss, argnums=diff_state)
    step_loss, grads = value_and_grad_fn(model, x, y)
    optimizer.update(grads)
    return step_loss

# @nnx.jit(static_argnames='scheduler')
def define_optim(model, scheduler, filter_param=nnx.Param):
    """
    Build an ``nnx.Optimizer`` for ``model`` using the SOAP transformation.

    ``scheduler`` is passed to SOAP as its learning rate and ``filter_param``
    selects which model variables the optimizer updates (``wrt``).
    """
    # SOAP is configured to refresh its preconditioner on every step.
    soap_tx = soap(learning_rate=scheduler, precondition_frequency=1)
    return nnx.Optimizer(model, soap_tx, wrt=filter_param)

@jax.jit
def eval_step(pred, vorticity_step, min_v, max_v):
    """
    Compute reconstruction metrics for one snapshot.

    ``pred`` and ``vorticity_step`` are compared both as given and after
    rescaling by ``max_v - min_v`` (the callers in this file pass fields
    normalised as ``(w - min_v) / (max_v - min_v)``).

    Only arrays are passed in, so plain ``jax.jit`` is used, consistent with
    the other metric helpers in this module.

    Returns:
        dict with keys ``mse_norm``, ``l2_rel_norm``, ``l2_rel_original``,
        ``l2_error``, ``max_abs`` and ``psnr`` (scalar arrays).
    """
    fact = max_v - min_v

    # Error in the space the inputs are given in.
    err_norm = pred - vorticity_step
    # Error after undoing the scaling; the ``+ min_v`` offset cancels in the difference.
    err_orig = pred * fact - (vorticity_step * fact)
    # Reference field mapped back to its original range.
    ref_orig = vorticity_step * fact + min_v

    mse_norm = jnp.mean(err_norm**2)
    l2_rel_norm = jnp.linalg.norm(err_norm) / jnp.linalg.norm(vorticity_step)
    l2_rel_original = jnp.linalg.norm(err_orig) / jnp.linalg.norm(ref_orig)
    l2_error = jnp.linalg.norm(err_orig)
    # Named ``max_abs_err`` so it does not shadow the module-level ``max_abs`` helper.
    max_abs_err = jnp.max(jnp.abs(err_orig))
    # PSNR uses the squared maximum of the de-normalised reference as the peak signal.
    psnr = 10 * jnp.log10(jnp.max(ref_orig)**2 / jnp.mean(err_orig**2))
    return {
        "mse_norm": mse_norm,
        "l2_rel_norm": l2_rel_norm,
        "l2_rel_original": l2_rel_original,
        "l2_error": l2_error,
        "max_abs": max_abs_err,
        "psnr": psnr
    }

def define_scheduler(init_lr, len_coords, batch_size, epochs):
    """
    Return an optax cosine-decay learning-rate schedule.

    The schedule starts at ``init_lr`` and decays over
    ``(len_coords // batch_size) * epochs`` steps with ``alpha=1e-2``.
    """
    steps_per_epoch = len_coords // batch_size  # full batches only
    total_steps = steps_per_epoch * epochs
    return optax.cosine_decay_schedule(
        init_value=init_lr,
        decay_steps=total_steps,
        alpha=1e-2
    )

@nnx.jit
def compiled_merge_and_reset(model):
    """Apply ``merge_lora_params`` and then ``reset_lora_params`` to ``model`` under ``nnx.jit``."""
    # Order matters: merge first, then reset.
    for lora_op in (merge_lora_params, reset_lora_params):
        lora_op(model)

########################### In-situ Enstrophy Flux based Temporal Selector #################################
def temporal_selection_enstrophy_flux(trajectory: jax.Array,
                                      queue_size: int,
                                      enstrophy_weight: jax.Array,
                                      corr_metric: Callable[[jax.Array, jax.Array], jax.Array] = pearson_corr,
                                      corr_treshold: float = 0.9) -> jax.Array:
    """
    Select time indices of ``trajectory`` using changes in ``enstrophy_weight``.

    A look-ahead window of at most ``queue_size`` candidate indices is kept.
    Whenever the window is full (or has reached the last frame), candidates
    are scanned in order and accepted while
      * the absolute change of ``enstrophy_weight`` relative to the last
        selected index, divided by the change over the first step after that
        index, stays within ``sqrt(max step change / mean step change)``, and
      * ``corr_metric`` between the candidate frame and the last selected
        frame is at least ``corr_treshold``.
    The last accepted candidate (at least one step ahead) becomes the next
    selected index.

    Returns:
        Array of selected time indices, starting with 0.
    """
    def flux_between(a, b):
        # Absolute change of the enstrophy weight between two time indices.
        return jnp.abs(enstrophy_weight[a] - enstrophy_weight[b])

    anchor = 0                         # last selected index
    window = deque(maxlen=queue_size)  # candidate indices ahead of the anchor
    n_frames = trajectory.shape[0]
    window.append(1)

    peak_flux = flux_between(1, 0)       # largest single-step change seen so far
    flux_in_window = flux_between(1, 0)  # sum of single-step changes inside the window
    kept = [0]
    newest = 1                           # most recent index pushed into the window

    while anchor < n_frames - 1:
        window_ready = len(window) == queue_size or window[-1] == n_frames - 1
        if window_ready:
            mean_flux = flux_in_window / len(window)
            # Budget for how far the enstrophy may drift before a new snapshot is taken.
            factor = jnp.sqrt(peak_flux / (mean_flux + 1e-8)).item()

            chosen = anchor + 1
            # Change over the first step after the anchor; used to normalise.
            base_flux = flux_between(chosen, anchor)

            for cand in window:
                flux_ratio = flux_between(cand, anchor) / (base_flux + 1e-8)
                corr_val = corr_metric(trajectory[cand], trajectory[anchor])

                # Stop at the first candidate that fails either criterion.
                accepted = flux_ratio <= factor and corr_val >= corr_treshold
                if not accepted:
                    break
                chosen = cand

            anchor = chosen
            kept.append(chosen)

            # Drop candidates at or behind the new anchor and remove their
            # single-step contribution from the window sum.
            while window and window[0] <= anchor:
                expired = window.popleft()
                flux_in_window -= flux_between(expired, max(0, expired - 1))

        upcoming = newest + 1
        if upcoming < n_frames:
            window.append(upcoming)
            newest = upcoming
            step_flux = flux_between(window[-1], window[-1] - 1)
            peak_flux = max(peak_flux, step_flux)
            flux_in_window += step_flux

    return jnp.array(kept)
######################################################################################################

########################### Static Regulator In-situ Enstrophy Flux based Temporal Selector #################################

def temporal_selection_enstrophy_flux_regulator_ablation(trajectory: jax.Array,
                                                         queue_size: int,
                                                         enstrophy_weight: jax.Array,
                                                         corr_metric: Callable = pearson_corr,
                                                         corr_treshold: float = 0.9) -> jax.Array:
    """
    Binary (all-or-nothing) variant of the enstrophy-based temporal selector.

    The windowing and running statistics are the same as in
    ``temporal_selection_enstrophy_flux``, but only the far edge of the
    window is tested: if it satisfies both the enstrophy-ratio and the
    correlation criterion the selector jumps to that edge, otherwise it
    advances by a single step.

    Returns:
        Array of selected time indices, starting with 0.
    """
    def flux_between(a, b):
        # Absolute change of the enstrophy weight between two time indices.
        return jnp.abs(enstrophy_weight[a] - enstrophy_weight[b])

    anchor = 0                         # last selected index
    window = deque(maxlen=queue_size)  # candidate indices ahead of the anchor
    n_frames = trajectory.shape[0]
    window.append(1)

    peak_flux = flux_between(1, 0)       # largest single-step change seen so far
    flux_in_window = flux_between(1, 0)  # sum of single-step changes inside the window
    kept = [0]
    newest = 1                           # most recent index pushed into the window

    while anchor < n_frames - 1:
        window_ready = len(window) == queue_size or window[-1] == n_frames - 1
        if window_ready:
            mean_flux = flux_in_window / len(window)
            factor = jnp.sqrt(peak_flux / (mean_flux + 1e-8)).item()

            # --- BINARY SELECTION LOGIC ---
            single_step = anchor + 1
            base_flux = flux_between(single_step, anchor)

            # Only the last element of the window (its far edge) is examined.
            far_edge = window[-1]
            edge_ratio = flux_between(far_edge, anchor) / (base_flux + 1e-8)
            edge_corr = corr_metric(trajectory[far_edge], trajectory[anchor])

            # Take the full-window jump if it passes both checks, else stride 1.
            take_edge = edge_ratio <= factor and edge_corr >= corr_treshold
            chosen = far_edge if take_edge else single_step
            # ------------------------------

            anchor = chosen
            kept.append(chosen)

            # Drop candidates at or behind the new anchor and remove their
            # single-step contribution from the window sum.
            while window and window[0] <= anchor:
                expired = window.popleft()
                flux_in_window -= flux_between(expired, max(0, expired - 1))

        upcoming = newest + 1
        if upcoming < n_frames:
            window.append(upcoming)
            newest = upcoming
            step_flux = flux_between(window[-1], window[-1] - 1)
            peak_flux = max(peak_flux, step_flux)
            flux_in_window += step_flux

    return jnp.array(kept)
###################################################################################################################

############################# In-Situ Adaptive Temporal Neural Field Training for 2D Kolmogorov flows #################################
def _print_loss_stats(prefix: str, loss_stats: dict) -> None:
    """Print the reconstruction metrics in ``loss_stats``, one per line, after ``prefix``."""
    print(f"{prefix}: MSE: {loss_stats['mse_norm']:.2e}")
    print(f"{prefix}: L2 Rel normalized slice: {loss_stats['l2_rel_norm']:.2e}")
    print(f"{prefix}: L2 Rel original: {loss_stats['l2_rel_original']:.2e}")
    print(f"{prefix}: Max Abs Error: {loss_stats['max_abs']:.2e}")
    print(f"{prefix}: PSNR: {loss_stats['psnr']:.2f} dB")


def _queue_snapshot_save(checkpointer, json_checkpointer, root_dir, name, state, loss_stats) -> None:
    """Replace ``root_dir/name`` (if present) and queue saves of ``loss_stats`` and ``state`` under it."""
    target = os.path.join(root_dir, name)
    if os.path.exists(target):
        shutil.rmtree(target)
    json_checkpointer.save(os.path.join(target, "loss_stats"), loss_stats)
    checkpointer.save(os.path.join(target, "state"), state)


def generate_streaming_adaptive_grid_with_nf(config: dict, train_config: dict, nf_model=None):
    """
    Run the 2D spectral Navier-Stokes simulation and fit ``nf_model`` in situ
    to adaptively selected vorticity snapshots.

    The initial snapshot is fitted with all ``nnx.Param`` variables; LoRA
    parameters are then added to the model and only those are trained on
    every subsequently selected snapshot. A snapshot is selected when the
    current frame has drifted too far from the last selected one (deviation
    ratio above the adaptive factor, or correlation below
    ``config["corr_threshold"]``), or when ``window_size`` steps have passed.

    Args:
        config: simulation settings. Keys read: ``ns``, ``tf``,
            ``outer_steps``, ``window_size``, ``domain``, ``max_velocity``,
            ``viscosity``, ``anti_aliasing``, ``mol``, ``seed`` and
            ``corr_threshold`` (required).
        train_config: training settings. Keys read: ``batch_size``,
            ``first_slice_epochs``, ``lora_epochs``, ``lora_rank`` and
            ``save_file`` (``None`` disables checkpointing).
        nf_model: the neural-field model to train; must be provided.

    Returns:
        List of selected outer-step indices, starting with 0.

    Raises:
        ValueError: if ``nf_model`` is ``None``.
    """
    if nf_model is None:
        raise ValueError("nf_model must be provided")

    # --- 1. SETUP ---
    ns = config.get("ns", 128)
    tf = config.get("tf", 25.0)
    outer_steps = config.get("outer_steps", 1000)
    W = config.get("window_size", 5)
    corr_threshold = config["corr_threshold"]
    batch_size = train_config["batch_size"]
    save_root = train_config["save_file"]

    grid = grids.Grid((ns, ns), domain=config.get("domain", ((0, 2*np.pi), (0, 2*np.pi))))
    dt_sim = cfd.equations.stable_time_step(config.get('max_velocity'), 0.5, config.get('viscosity'), grid)

    equation = spectral.equations.NavierStokes2D(
        viscosity=config.get('viscosity'), grid=grid, smooth=config.get('anti_aliasing', True))
    step_fn = config.get('mol')(equation, dt_sim)

    inner_steps = int((tf / dt_sim) / outer_steps)
    outer_step_fn = cfd.funcutils.repeated(step_fn, inner_steps)
    # One outer step per call; the returned trajectory holds a single frame.
    rollout_fn = jax.jit(cfd.funcutils.trajectory(outer_step_fn, 1, start_with_input=False))

    # Real-space (x, y) sample points for the neural field, flattened to (ns*ns, 2).
    # NOTE(review): the spacing is hard-coded to a 2*pi-periodic box and ignores
    # config["domain"] -- confirm this is intended.
    axis_pts = jnp.arange(ns) * 2 * jnp.pi / ns
    coords = jnp.stack(jnp.meshgrid(axis_pts, axis_pts, indexing='ij'), axis=-1).reshape(-1, 2)
    n_points = coords.shape[0]

    # --- 2. INITIALIZATION ---
    v0 = cfd.initial_conditions.filtered_velocity_field(
        jax.random.PRNGKey(config.get("seed")), grid, config.get('max_velocity'), 4)
    current_vort_hat = jnp.fft.rfftn(cfd.finite_differences.curl_2d(v0).data)
    v0_real = jnp.fft.irfftn(current_vort_hat, s=(ns, ns))
    # A single global min/max normalisation, taken from the initial condition.
    min_v = jnp.min(v0_real)
    max_v = jnp.max(v0_real)
    v1_real = jnp.fft.irfftn(outer_step_fn(current_vort_hat), s=(ns, ns))

    # Seed the selector statistics with the very first step difference.
    initial_step_diff = max_abs(v1_real, v0_real)
    maxx_diff = initial_step_diff

    final_time_indices = [0]

    print("Starting In-Situ Adaptive Neural Field Training...")

    checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
    json_checkpointer = ocp.AsyncCheckpointer(ocp.JsonCheckpointHandler())

    try:
        # --- Warm-up: fit the full model to the first snapshot (t=0) ---
        scheduler = define_scheduler(init_lr=1e-3,
                                     len_coords=n_points,
                                     batch_size=batch_size,
                                     epochs=train_config["first_slice_epochs"])

        permute_idx = jax.jit(lambda coords, key: jax.random.permutation(key, coords.shape[0]))

        optim = define_optim(nf_model, scheduler, filter_param=nnx.Param)
        key = jax.random.PRNGKey(0)
        norm_data = (v0_real.reshape(-1, 1) - min_v) / (max_v - min_v)

        # Warm-up trains on full batches only; a trailing partial batch is skipped.
        full_batches = n_points // batch_size
        for epoch in tqdm(range(train_config["first_slice_epochs"]), desc='Lora compression (Warmup)'):
            # Use a fresh subkey for the shuffle; the carried key is only ever split.
            key, perm_key = jax.random.split(key)
            perm_idx = permute_idx(coords, perm_key)
            loss_avg = 0.0
            for batch in range(0, full_batches * batch_size, batch_size):
                batch_idx = perm_idx[batch:batch + batch_size]
                loss_step = train_step(nf_model, optim, coords[batch_idx], norm_data[batch_idx], nnx.Param)
                loss_avg += loss_step.item()

            if epoch % 50 == 0:
                # max(..., 1) avoids a ZeroDivisionError when batch_size > n_points.
                print(f"Epoch {epoch}, loss: {loss_avg / max(full_batches, 1):.2e}")

        loss_stats = eval_step(nf_model(coords), norm_data, min_v, max_v)
        loss_stats = {k: v.item() for k, v in loss_stats.items()}
        _print_loss_stats("Warmup Loss Stats at t=0", loss_stats)

        if save_root is not None:
            _queue_snapshot_save(checkpointer, json_checkpointer, save_root,
                                 "full_model_t=0", nnx.state(nf_model), loss_stats)
            # The base model is about to be modified (LoRA is added), so wait here.
            checkpointer.wait_until_finished()
            json_checkpointer.wait_until_finished()
            print(f"Saved initial model checkpoint at: {save_root}/full_model_t=0")

        # --- Switch to LoRA-only training for all later snapshots ---
        scheduler = define_scheduler(init_lr=1e-2,
                                     len_coords=n_points,
                                     batch_size=batch_size,
                                     epochs=train_config["lora_epochs"])

        add_lora_to_model(nf_model, lora_rank=train_config["lora_rank"])

        # A fresh optimizer (and hence a restarted schedule) is built per snapshot.
        def_opt = nnx.jit(lambda model: define_optim(model, scheduler, filter_param=nnx.LoRAParam))

        # Statistics of the window being accumulated and of the previous window.
        curr_stats = {"curr_diff_sum": 0.0, "count": 0}
        prev_stats = {"curr_diff_sum": initial_step_diff, "count": 1}

        prev_vort = v0_real
        last_selected_vort = v0_real
        diff0 = 0.0

        # --- 3. ADAPTIVE IN-SITU LOOP ---
        for current_t_idx in tqdm(range(1, outer_steps), desc='Adaptive Temporal Selection with NF Training'):
            # A. Advance the simulation by one outer step.
            _, traj_hat = rollout_fn(current_vort_hat)
            traj_real = jnp.fft.irfftn(traj_hat, s=(ns, ns), axes=(1, 2))
            frame = traj_real[0]

            # B. Temporal selection.
            curr_mean = prev_stats["curr_diff_sum"] / prev_stats["count"]
            factor = jnp.sqrt(maxx_diff / (curr_mean + 1e-8)).item()

            norm_val = max_abs(frame, last_selected_vort)
            corr_val = pearson_corr(frame, last_selected_vort)

            # Difference to the previous frame; computed once and reused below.
            step_diff = max_abs(frame, prev_vort)
            curr_stats["curr_diff_sum"] += step_diff
            maxx_diff = max(float(maxx_diff), float(step_diff))
            curr_stats["count"] += 1
            if curr_stats["count"] == 1:
                diff0 = curr_stats["curr_diff_sum"]

            apply_lora = None

            if (norm_val / (diff0 + 1e-8)) > factor or corr_val < corr_threshold:
                # The current frame drifted too far: fall back to the previous frame.
                last_selected_vort = prev_vort
                prev_stats = curr_stats.copy()
                curr_stats["count"] = 1
                curr_stats["curr_diff_sum"] = step_diff
                diff0 = step_diff
                selected_idx = current_t_idx - 1
                # The previous frame may already be the last selected one (right
                # after a window-end selection, or at the first step). Do not
                # record it twice or retrain on the same snapshot.
                if final_time_indices[-1] != selected_idx:
                    final_time_indices.append(selected_idx)
                    apply_lora = prev_vort
            elif curr_stats["count"] == W or current_t_idx == outer_steps - 1:
                # Window exhausted (or last step): select the current frame.
                last_selected_vort = frame
                prev_stats = curr_stats.copy()
                curr_stats["count"] = 0
                curr_stats["curr_diff_sum"] = 0.0
                final_time_indices.append(current_t_idx)
                apply_lora = frame

            # --- C. IN-SITU NEURAL FIELD TRAINING ---
            if apply_lora is not None:
                optim = def_opt(nf_model)

                norm_data = (apply_lora.reshape(-1, 1) - min_v) / (max_v - min_v)
                for epoch in tqdm(range(train_config["lora_epochs"]), desc='Lora compression'):
                    key, perm_key = jax.random.split(key)
                    perm_idx = permute_idx(coords, perm_key)
                    # Unlike the warm-up, the trailing partial batch is included here.
                    for batch in range(0, n_points, batch_size):
                        batch_idx = perm_idx[batch:batch + batch_size]
                        train_step(nf_model, optim, coords[batch_idx], norm_data[batch_idx], nnx.LoRAParam)

                loss_stats = eval_step(nf_model(coords), norm_data, min_v, max_v)
                loss_stats = {k: v.item() for k, v in loss_stats.items()}
                _print_loss_stats(f"Lora Loss Stats at t={current_t_idx}", loss_stats)

                if save_root is not None:
                    # NOTE(review): the directory is labelled with current_t_idx even
                    # when the trained snapshot is the previous frame -- confirm that
                    # downstream loaders expect this labelling.
                    _queue_snapshot_save(checkpointer, json_checkpointer, save_root,
                                         f"lora_model_t={current_t_idx}",
                                         nnx.state(nf_model, nnx.LoRAParam), loss_stats)
                    print(f"Saved model checkpoint at: {save_root}/lora_model_t={current_t_idx}")

            # D. Carry the simulation state to the next step.
            prev_vort = frame
            current_vort_hat = traj_hat[0]
            if current_t_idx % 100 == 0:
                print(f"Time: {current_t_idx}/{outer_steps} | Model Trained on {len(final_time_indices)} snapshots")
    finally:
        # Always release the async checkpointers, even if training failed.
        checkpointer.close()
        json_checkpointer.close()

    return final_time_indices


if __name__ == '__main__':
    # --- 4. EXECUTION ---
    # Simulation parameters.
    grid_size = 1024
    box = ((0, 2 * jnp.pi), (0, 2 * jnp.pi))
    viscosity = 1e-3
    peak_velocity = 5
    final_time = 25.0
    n_outer_steps = 1000
    n_snapshots = 1000

    config = {
        "ns": grid_size,
        "domain": box,
        "viscosity": viscosity,
        "max_velocity": peak_velocity,
        "tf": final_time,
        "outer_steps": n_outer_steps,
        "seed": 42,
        'temporal_snapshots': n_snapshots,
        "vorticity_temporal_slice": slice(0, n_outer_steps, n_outer_steps // n_snapshots),
        "anti_aliasing": True,
        "return_real_space": True,
        "mol": spectral.time_stepping.crank_nicolson_rk4,
        "precision": jnp.float32,
        "data_save_dir": "..",
        "window_size": 5,
        "corr_threshold": 0.9,
    }

    # Neural-field training parameters; "save_file": None disables checkpointing.
    training_config = {
        "first_slice_epochs": 401,
        "lora_epochs": 50,
        "lora_rank": 16,
        "batch_size": 100000,
        "save_file": None,
    }

    # Coordinate MLP: (x, y) -> vorticity.
    model = MLP(
        input_dim=2,
        output_dim=1,
        hidden_dim=256,
        fourier_emb_scale=7.0,
        num_hidden_layers=6
    )

    # Expected to pick up around 50% of the snapshots with window size 5.
    snapshots = generate_streaming_adaptive_grid_with_nf(config, training_config, model)

    # %%%%%%%%%%%%%%%% PLOTTER (in case needed) %%%%%%%%%%%%%%%%%
    # import matplotlib.pyplot as plt
    # import seaborn as sns
    # def plot_combined_ns_analysis(vorticity, time_full, selected_indices, snapshot_indices, window_size=5):
    #     """
    #     Combines Ground Truth snapshots with Temporal Selection analysis.
    #     JAX-compatible indexing version.
    #     """
    #     nx = vorticity.shape[1]
        
    #     # Ensure indices are JAX/NumPy arrays for indexing
    #     sel_idx_arr = jnp.array(selected_indices)
    #     snap_idx_arr = jnp.array(snapshot_indices)
        
    #     # Calculate the strides
    #     strides = np.diff(selected_indices)
    #     time_midpoints = time_full[sel_idx_arr[:-1]]
    #     num_snaps = len(snapshot_indices)
        
    #     fig = plt.figure(figsize=(18, 12), dpi=200)
    #     gs = fig.add_gridspec(3, num_snaps, height_ratios=[2, 2.0, 3.5], hspace=0.3)

    #     # --- TOP ROW: GT Snapshots ---
    #     for i, idx in enumerate(snapshot_indices):
    #         ax_snap = fig.add_subplot(gs[0, i])
            
    #         # Real space conversion
    #         real_vort = vorticity[idx]
            
    #         im = ax_snap.imshow(
    #             real_vort, 
    #             aspect='equal', origin='lower', cmap=sns.cm.icefire,
    #             extent=[0, 2*np.pi, 0, 2*np.pi]
    #         )
    #         ax_snap.set_title(f"$t = {time_full[idx]:.2f}$s", fontsize=18, fontweight='bold')
    #         ax_snap.axis('off')
            
    #         # Visual cue lines
    #         ax_pos = ax_snap.get_position()
    #         x_center = (ax_pos.x0 + ax_pos.x1) / 2
    #         fig.add_artist(plt.Line2D(
    #             [x_center, x_center], 
    #             [ax_pos.y0 - 0.01, 0.45], 
    #             color='black', transform=fig.transFigure, alpha=0.15, linestyle='--'
    #         ))

    #     # --- MIDDLE ROW: Rug Plot ---
    #     ax_rug = fig.add_subplot(gs[1, :])
    #     # FIX: Use jnp.array for indexing
    #     ax_rug.vlines(time_full[sel_idx_arr], 0, 12, color='tab:blue', alpha=0.3, linewidth=0.3)
        
    #     # FIX: Use jnp.array for indexing
    #     # ax_rug.scatter(time_full[snap_idx_arr], [1.1]*num_snaps, 
    #     #                color='black', marker='v', s=100, zorder=5, clip_on=False)
        
    #     ax_rug.set_xlim(float(time_full[0]), float(time_full[-1]))
    #     ax_rug.set_yticks([])
    #     ax_rug.spines['top'].set_visible(False)
    #     ax_rug.spines['right'].set_visible(False)
    #     ax_rug.spines['left'].set_visible(False)
    #     ax_rug.set_title("Adaptive Temporal Sampling", loc='center', fontsize=28, pad=15)

    #     # --- BOTTOM ROW: Stride Evolution ---
    #     ax_stride = fig.add_subplot(gs[2, :])
    #     ax_stride.scatter(time_midpoints, strides, alpha=0.8, s=40, c=strides, cmap='viridis_r')
        
    #     window_avg = 50
    #     if len(strides) > window_avg:
    #         rolling_stride = np.convolve(strides, np.ones(window_avg)/window_avg, mode='valid')
    #         ax_stride.plot(time_midpoints[window_avg-1:], rolling_stride, color='tab:red', lw=2, label='Mean Stride')

    #     ax_stride.set_ylabel("Step Jump ($\Delta t$)", fontsize=28)
    #     ax_stride.set_xlabel("Simulation Time ($t$)", fontsize=28)
    #     ax_stride.set_xlim(float(time_full[0]), float(time_full[-1]))
    #     ax_stride.set_yticks([1, window_size])
    #     ax_stride.set_yticklabels(['1 (Dense)', f'{window_size} (Sparse)'], fontsize=22)
    #     ax_stride.tick_params(axis='x', labelsize=20)
    #     ax_stride.grid(axis='y', linestyle='--', alpha=0.3)
    #     ax_stride.legend(fontsize=22, loc='center right', frameon=False)

    #     # Colorbar
    #     cbar_ax = fig.add_axes([0.93, 0.65, 0.012, 0.23]) 
    #     fig.colorbar(im, cax=cbar_ax).set_label('Vorticity ($\omega$)', fontsize=26)

    #     return fig
    
    # indices = [200, 400, 600, 1000, 3500, 4200, -200, -100, -1]
    # #indices = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    # time_full = time_info[0] * jnp.arange(time_info[1]) * time_info[2] # dt_sim * jnp.arange(outer_steps) * inner_steps
    # fig = plot_combined_ns_analysis(snapshots, time_full, times, indices, window_size=5)
    # plt.show()