# %%
import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
import warnings
import torch
import snntorch as snn
from snntorch import surrogate
import numpy as np
import torch
import h5py
from sklearn.preprocessing import StandardScaler
import warnings
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import json
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from filterpy.kalman import KalmanFilter
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from scipy.signal import butter, sosfiltfilt
from sklearn.linear_model import Ridge
import math
from scipy.signal import correlate, correlation_lags
import torch.nn.functional as F

# =============================================================
# Evaluation configuration
# -------------------------------------------------------------
# If EVAL_MOVEMENT_ONLY is True, metrics are computed only on bins
# where cursor speed exceeds MOVEMENT_THRESHOLD_NORM (value in the
# *normalised* velocity space produced by StandardScaler).  If False,
# all time bins contribute to the reported correlation and R².
# NOTE(review): the code that consumes these constants is not in this
# part of the file; confirm the description against the evaluation code.
# -------------------------------------------------------------
EVAL_MOVEMENT_ONLY = False          # True: movement bins only; False: full-trajectory evaluation
MOVEMENT_THRESHOLD_NORM = 0.2      # ≈ 2 cm s⁻¹ after inverse transform
MOVEMENT_THRESHOLD_CM_S = 2.0      # threshold used directly on de-normalised velocities for plots
# =============================================================

def load_data(
    mat_file_path='zenodo_dataset/indy_20160419_01.mat',
    bin_width_s=0.064,   # 64 ms window to match Makin et al.
    stride_s=0.064,      # 64 ms stride for non-overlapping bins
    test_split_ratio=0.2,
    spike_processing='rate',
    normalize_spikes=False,
    verbose=True
):
    """
    Load one recording from an HDF5 (.mat) file, bin it, and split it chronologically.

    The file is expected to contain the datasets 't' (sample timestamps),
    'spikes' and 'cursor_pos'.  Spikes are turned into a per-sample 0/1 raster,
    cursor position is low-pass filtered and differentiated into velocity, and
    both are aggregated into bins (spike sum, velocity mean).  No lag between
    spikes and velocity is applied by this function.

    Args
    ----
    mat_file_path    : path of the HDF5 file to open with h5py.
    bin_width_s      : bin length in seconds (converted to samples with the
                       sampling rate estimated from 't').
    stride_s         : step between consecutive bin starts, in seconds.
    test_split_ratio : fraction of bins (taken from the END of the recording)
                       used as the test set; 0 disables the split.  When > 0 at
                       least one bin is assigned to the test set.
    spike_processing : 'binary' (count > 0), 'count' (raw counts),
                       'rate' (count / bin_width_s) or 'poisson'
                       (np.random.poisson draw with lam = count / bin_width_s;
                       uses the global NumPy RNG, so it is not seeded here).
    normalize_spikes : if True, min-max scale each unit over ALL bins before
                       the split (see the leakage note in the body).
    verbose          : print the parameters and normalisation notice.

    Returns
    -------
    dict with float32 torch tensors 'X_train', 'y_train', 'X_test', 'y_test'
    and their StandardScaler-normalised versions '*_norm' (scalers fit on the
    training part only), NumPy arrays 't_train', 't_test', 'X_binned_raw',
    'y_binned_raw', 't_binned_raw', the fitted 'spike_scaler' /
    'velocity_scaler' (None if fitting failed or there is no training data),
    and 'bin_width_s' / 'stride_s'.  When test_split_ratio == 0 the keys
    'X_processed_full', 'y_processed_full', 't_processed_full' and
    'X_raw_counts_lagged' are added.  If the recording is too short for a
    single bin, a dict of empty arrays with an extra 'error' key is returned
    and a RuntimeWarning is issued.

    Raises
    ------
    ValueError : if spike_processing is not one of the supported modes.
    """
    if verbose:
        print("=== load_data parameters ===")
        print(f" mat_file_path    = {mat_file_path}")
        print(f" bin_width_s      = {bin_width_s}")
        print(f" stride_s         = {stride_s}")
        print("  (Processing is now zero-lag with non-overlapping bins)")
        print(f" test_split_ratio = {test_split_ratio}")
        print(f" spike_processing = {spike_processing}")
        print(f" normalize_spikes = {normalize_spikes}")
        print("============================\n")

    with h5py.File(mat_file_path, 'r') as f:
        # Timestamps: force a column vector of shape (T, 1) regardless of storage layout.
        t_vec = f['t'][()]
        if t_vec.ndim == 1:
            t_vec = t_vec[:, None]
        elif t_vec.shape[0] < t_vec.shape[1]:
            t_vec = t_vec.T
        # Sampling rate estimated from the mean spacing of the timestamps.
        fs = 1.0 / np.mean(np.diff(t_vec.squeeze()))

        ds = f['spikes']
        if h5py.check_dtype(ref=ds.dtype) is not None:
            # 'spikes' holds HDF5 object references; each referenced dataset is
            # read as a list of spike times for one unit.
            refs = ds[()]
            # NOTE(review): for a 2-D reference array only row 0 (or column 0)
            # is read; any further rows/columns are ignored -- confirm this is intended.
            if refs.ndim == 2 and refs.shape[1] != 1:
                n_units, get_ref = refs.shape[1], lambda i: refs[0,i]
            elif refs.ndim == 2 and refs.shape[0] != 1:
                n_units, get_ref = refs.shape[0], lambda i: refs[i,0]
            else:
                refs = refs.ravel()
                n_units, get_ref = refs.shape[0], lambda i: refs[i]
            T_orig = t_vec.shape[0]
            spikes = np.zeros((T_orig, n_units), dtype=np.float32)
            for i in range(n_units):
                r = get_ref(i)
                if isinstance(r, h5py.Reference):
                    times = f[r][()].squeeze()
                    # Map each spike time to the first sample whose timestamp is >= it.
                    idx = np.searchsorted(t_vec.squeeze(), times, 'left')
                    # Spikes later than the last timestamp fall outside the raster and are dropped.
                    idx = idx[idx < T_orig]
                    # Several spikes landing on the same sample collapse to a single 1.0.
                    spikes[np.unique(idx), i] = 1.0
        else:
            # Dense array: orient it so that the longer axis is time (rows).
            spikes = ds[()]
            if spikes.shape[0] < spikes.shape[1]:
                spikes = spikes.T
            spikes = spikes.astype(np.float32)

        cp = f['cursor_pos'][()]
        if cp.shape[0] < cp.shape[1]:
            cp = cp.T
        
        # --- Filter cursor position before calculating velocity ---
        # Per Makin et al. 2019, apply a 4th-order, 10Hz, non-causal (zero-phase)
        # Butterworth filter to the cursor position data before differentiation.
        sos = butter(4, 10, 'low', fs=fs, output='sos')
        cp_filtered = sosfiltfilt(sos, cp, axis=0)
        # Velocity = time derivative of the filtered position, using the real timestamps.
        vel = np.gradient(cp_filtered, t_vec.squeeze(), axis=0)

    # --- Binning ---
    # Convert the window and stride from seconds to a whole number of samples.
    bin_width = int(round(bin_width_s * fs))
    stride = int(round(stride_s * fs))
    T = spikes.shape[0]
    # Only complete windows are kept; a trailing partial window is discarded.
    num_bins = (T - bin_width) // stride + 1

    # +++ Handle Case of Zero Bins +++
    if num_bins <= 0:
        warnings.warn(f"Not enough data points ({T}) for the given bin ({bin_width}) and stride ({stride}) in file {mat_file_path}. Returning empty data structure.", RuntimeWarning)
        # Return a dictionary indicating failure and providing empty structures where possible
        n_neurons = spikes.shape[1] # Get neuron count even if no bins
        return {
            'error': 'Insufficient data for binning',
            'X_binned_raw': np.empty((0, n_neurons), dtype=np.float32),
            'y_binned_raw': np.empty((0, vel.shape[1]), dtype=np.float32), # Use original vel shape for columns
            't_binned_raw': np.empty((0,), dtype=np.float64),
            # Include other keys expected by caller, filled with None or empty equivalents
            'X_train': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            't_train': np.empty((0,), dtype=np.float64),
            't_test':  np.empty((0,), dtype=np.float64),
            'X_train_norm': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train_norm': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test_norm':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test_norm':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'spike_scaler': None,
            'velocity_scaler': None,
            'bin_width_s': bin_width_s,
            'stride_s': stride_s
        }
    # +++ End Handle Case of Zero Bins +++

    # Window i covers samples [i*stride, i*stride + bin_width).
    spike_windows = np.stack([spikes[i*stride : i*stride+bin_width] for i in range(num_bins)])
    spikes_b_raw = spike_windows.sum(axis=1).astype(np.float32) # Raw spike counts per bin and unit

    vel_windows = np.stack([vel[i*stride : i*stride+bin_width] for i in range(num_bins)])
    vel_b_raw = vel_windows.mean(axis=1).astype(np.float32)    # Raw mean velocity in bins

    # Each bin is stamped with the time of its FIRST sample.
    t_binned = np.array([t_vec[i*stride][0] for i in range(num_bins)])

    # --- Spike Processing (Applied After Binning) ---
    if spike_processing == 'binary':
        X_processed = (spikes_b_raw > 0).astype(np.float32)
    elif spike_processing == 'count':
        X_processed = spikes_b_raw # Already counts
    elif spike_processing == 'rate':
        # Counts divided by the bin length in seconds (spikes per second).
        X_processed = (spikes_b_raw / bin_width_s).astype(np.float32)
    elif spike_processing == 'poisson':
        # Resample from a Poisson distribution whose mean is the per-bin rate;
        # draws come from the global NumPy RNG and therefore vary between calls.
        lam = spikes_b_raw / bin_width_s
        X_processed = np.random.poisson(lam).astype(np.float32)
    else:
        raise ValueError(f"Unknown spike_processing: {spike_processing}")

    # --- Spike Normalization (Optional, applied AFTER spike processing) ---
    if normalize_spikes:
        # Note: Normalizing here before splitting might leak test info
        # It's generally better to normalize after splitting, fitting scaler only on train
        mn = X_processed.min(axis=0, keepdims=True)
        mx = X_processed.max(axis=0, keepdims=True)
        # Units with a constant value get a range of 1.0 to avoid division by zero.
        rng = np.where(mx - mn > 0, mx - mn, 1.0)
        X_processed = (X_processed - mn) / rng
        if verbose:
            print("Applied spike normalization (min-max scaling). Consider normalizing after splitting.")

    # --- Internal Train/Test Split (if ratio > 0) ---
    y_processed = vel_b_raw # Use the zero-lag velocity
    B_total = X_processed.shape[0]
    n_test = 0
    if test_split_ratio > 0:
        # At least one bin goes to the test set whenever a split is requested.
        n_test = max(int(B_total * test_split_ratio), 1)
    n_train = B_total - n_test

    # Chronological split: the first n_train bins train, the remaining bins test.
    X_tr, X_te = X_processed[:n_train], X_processed[n_train:]
    y_tr, y_te = y_processed[:n_train], y_processed[n_train:]
    t_tr, t_te = t_binned[:n_train], t_binned[n_train:]

    # --- Fit Scalers (Only if train data exists) ---
    sx = None
    sy = None
    X_tr_norm = X_tr # Default to unnormalized if no scaling
    y_tr_norm = y_tr
    X_te_norm = X_te
    y_te_norm = y_te

    if n_train > 0:
        try:
            # Scaler statistics come from the training bins only.
            sx = StandardScaler().fit(X_tr)
            X_tr_norm = sx.transform(X_tr)
            if n_test > 0:
                X_te_norm = sx.transform(X_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform spike scaler: {e}", RuntimeWarning)
            sx = None # Ensure scaler is None if fitting failed

        try:
            sy = StandardScaler().fit(y_tr)
            y_tr_norm = sy.transform(y_tr)
            if n_test > 0:
                y_te_norm = sy.transform(y_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform velocity scaler: {e}", RuntimeWarning)
            sy = None # Ensure scaler is None if fitting failed
    # --- End Fit Scalers ---

    # --- Return Dictionary ---
    return_dict = {
        # Data after spike processing and the internal split (no lag is applied)
        'X_train': torch.tensor(X_tr, dtype=torch.float32),
        'y_train': torch.tensor(y_tr, dtype=torch.float32),
        'X_test':  torch.tensor(X_te, dtype=torch.float32),
        'y_test':  torch.tensor(y_te, dtype=torch.float32),
        't_train': t_tr,
        't_test':  t_te,
        # Same data after StandardScaler normalization fit on the training split
        'X_train_norm': torch.tensor(X_tr_norm, dtype=torch.float32),
        'y_train_norm': torch.tensor(y_tr_norm, dtype=torch.float32),
        'X_test_norm':  torch.tensor(X_te_norm, dtype=torch.float32),
        'y_test_norm':  torch.tensor(y_te_norm, dtype=torch.float32),
        # Raw binned data (before spike processing) - Useful for aggregation
        'X_binned_raw': spikes_b_raw, # Raw counts
        'y_binned_raw': vel_b_raw,    # Raw velocity
        't_binned_raw': t_binned,     # Timestamps for raw binned data
        # Scalers fit on the internal training split
        'spike_scaler': sx,
        'velocity_scaler': sy,
        # Original parameters for reference
        'bin_width_s': bin_width_s,
        'stride_s': stride_s
    }

    # Add the full processed (but un-split) data if no split was done
    if test_split_ratio == 0:
        return_dict['X_processed_full'] = torch.tensor(X_processed, dtype=torch.float32)
        return_dict['y_processed_full'] = torch.tensor(y_processed, dtype=torch.float32)
        return_dict['t_processed_full'] = t_binned
        # Raw counts again; the key says "lagged" but this function applies no lag,
        # so the array is identical to 'X_binned_raw'.
        return_dict['X_raw_counts_lagged'] = spikes_b_raw

    return return_dict

# %%
import random
import time
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from scipy.interpolate import interp1d
import numpy as np
import pandas as pd

class CausalBatcher:
    """
    Iterator over fixed-size batches of contiguous sequences cut from time-series data.

    Sequence ``k`` covers timesteps ``[k * stride, k * stride + sequence_length)``
    of both ``spike_data`` and ``targets``.  A ``stride`` smaller than
    ``sequence_length`` yields overlapping windows (training data augmentation).
    Only full batches are produced; left-over sequences are dropped.

    Args
    ----
    spike_data      : array-like or tensor, time on axis 0.  The data is copied
                      into a float32 tensor.
    targets         : array-like or tensor aligned with ``spike_data`` on axis 0.
    batch_size      : number of sequences per batch (positive).
    sequence_length : number of timesteps per sequence (positive).
    shuffle         : if True the sequence order is reshuffled with
                      ``np.random.shuffle`` at construction and on every ``iter()``.
    stride          : step between sequence starts; defaults to
                      ``sequence_length`` (non-overlapping).

    Raises
    ------
    ValueError : for non-positive sizes, when not even one full batch can be
                 built, or when ``targets`` is too short for the sequences that
                 will be served.
    """
    def __init__(self, spike_data, targets, batch_size=64, sequence_length=20, shuffle=False, stride=None):
        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.shuffle = shuffle
        self.stride = stride if stride is not None else sequence_length # Non-overlapping by default

        # Reject sizes that would otherwise surface as ZeroDivisionError or
        # negative sequence counts further down.
        for label, value in (
            ("batch_size", self.batch_size),
            ("sequence_length", self.sequence_length),
            ("stride", self.stride),
        ):
            if value <= 0:
                raise ValueError(f"{label} must be a positive integer, got {value}.")

        # Calculate how many full sequences we can make
        num_total_timesteps = self.spike_data.shape[0]
        if num_total_timesteps < self.sequence_length:
            self.num_sequences = 0
        else:
            self.num_sequences = (num_total_timesteps - self.sequence_length) // self.stride + 1

        # Calculate how many full batches we can make
        self.num_batches = self.num_sequences // self.batch_size

        if self.num_batches == 0:
            raise ValueError(
                f"Not enough data to create a single batch. "
                f"Have {self.num_sequences} sequences for seq_len={sequence_length} and stride={self.stride}, but batch_size is {self.batch_size}."
            )

        # Targets must reach the end of the last sequence that can be served.
        # Without shuffling only the first num_batches * batch_size sequences
        # are ever used; with shuffling any sequence may be drawn.
        served = self.num_sequences if self.shuffle else self.num_batches * self.batch_size
        needed = (served - 1) * self.stride + self.sequence_length
        if self.targets.shape[0] < needed:
            raise ValueError(
                f"targets has {self.targets.shape[0]} timesteps but the batched "
                f"sequences require at least {needed}."
            )

        self.indices = np.arange(self.num_sequences)
        if self.shuffle:
            np.random.shuffle(self.indices)

        # Initialised here so that next() works even before iter() is called.
        self.current_batch = 0

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``.

        Tensors are cloned explicitly because ``torch.tensor(tensor)`` emits a
        copy-construct UserWarning; the result is the same detached copy.
        """
        if torch.is_tensor(data):
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of full batches per epoch."""
        return self.num_batches

    def __iter__(self):
        """Restart the epoch (reshuffling the sequence order when enabled)."""
        self.current_batch = 0
        if self.shuffle: # Reshuffle for each epoch
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        """Return ``(batch_spikes, batch_targets)``, each stacked on a new batch axis 0."""
        if self.current_batch >= self.num_batches:
            raise StopIteration

        # Get the indices for the sequences in this batch
        start_idx = self.current_batch * self.batch_size
        end_idx = start_idx + self.batch_size
        batch_seq_indices = self.indices[start_idx:end_idx]

        # Get the actual start timestep for each sequence in the batch
        batch_timestep_starts = batch_seq_indices * self.stride

        # Build the batch of sequences
        batch_spikes = torch.stack([
            self.spike_data[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])
        batch_targets = torch.stack([
            self.targets[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])

        self.current_batch += 1
        return batch_spikes, batch_targets

class OnlineDataStream:
    """
    Iterator yielding one timestep at a time, in chronological order.

    Each iteration returns ``(x_t, y_t, should_reset)``.  When
    ``reset_every_n`` is set, ``should_reset`` is True on every timestep whose
    index is a positive multiple of ``reset_every_n`` (i.e. after each block of
    N samples), which callers can use to reset hidden states, e.g. at trial
    boundaries.

    Args
    ----
    spike_data    : array-like or tensor, time on axis 0; copied to float32.
    targets       : array-like or tensor with at least as many timesteps.
    reset_every_n : block length for the reset flag; None or 0 disables it.

    Raises
    ------
    ValueError : if ``targets`` has fewer timesteps than ``spike_data``
                 (iteration would otherwise fail part-way with IndexError).
    """
    def __init__(self, spike_data, targets, reset_every_n=None):
        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)
        if self.targets.shape[0] < self.spike_data.shape[0]:
            raise ValueError(
                f"targets has {self.targets.shape[0]} timesteps but spike_data "
                f"has {self.spike_data.shape[0]}."
            )
        self.reset_every_n = reset_every_n  # Reset hidden states every N timesteps (e.g., trial boundaries)
        self.current_idx = 0
        self.steps_since_reset = 0

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``.

        Tensors are cloned explicitly because ``torch.tensor(tensor)`` emits a
        copy-construct UserWarning; the result is the same detached copy.
        """
        if torch.is_tensor(data):
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of timesteps in the stream."""
        return self.spike_data.shape[0]

    def __iter__(self):
        """Rewind the stream to the first timestep."""
        self.current_idx = 0
        self.steps_since_reset = 0
        return self

    def __next__(self):
        """Return ``(x_t, y_t, should_reset)`` for the current timestep and advance."""
        if self.current_idx >= len(self.spike_data):
            raise StopIteration

        x_t = self.spike_data[self.current_idx]
        y_t = self.targets[self.current_idx]

        should_reset = False
        if self.reset_every_n and self.steps_since_reset >= self.reset_every_n:
            should_reset = True
            self.steps_since_reset = 0

        self.current_idx += 1
        # The sample that triggered a reset already counts as step 1 of the new block.
        self.steps_since_reset += 1

        return x_t, y_t, should_reset

class SNNRegression(nn.Module):
    """
    Three-layer spiking network with a recurrent first hidden layer.

    Layout: ``fc1`` (+ ``fc_rec`` fed by the previous step's layer-1 spikes)
    -> ``lif1`` -> ``fc2`` -> ``lif2`` -> ``fc3`` -> ``lif3``.  The value
    emitted per timestep is the membrane potential returned by ``lif3``,
    which is configured with ``reset_mechanism="none"``.
    """
    def __init__(self, input_size, hidden_size=128, output_size=2):
        super().__init__()
        surrogate_fn = surrogate.fast_sigmoid()
        half_hidden = hidden_size // 2

        # Layer 1: feedforward projection plus a bias-free recurrent projection
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Layer 2
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Readout layer
        self.fc3 = nn.Linear(half_hidden, output_size)
        readout_cfg = dict(
            beta=0.5,
            spike_grad=surrogate_fn,
            init_hidden=False,
            threshold=1.0,
            reset_mechanism="none",
        )
        self.lif3 = snn.Leaky(**readout_cfg)

        # Re-initialise every linear layer once all of them exist
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights; bias -0.1 for fc1/fc2 and 0 for other linear layers."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        wants_negative_bias = module is self.fc1 or module is self.fc2
        if wants_negative_bias:
            nn.init.constant_(module.bias, -0.1)
        else:
            nn.init.zeros_(module.bias)

    def forward(self, x, spk1_rec, mem1, mem2, mem3, need_traces: bool = False):
        """
        Run the network over a sequence, one timestep at a time.

        Args
        ----
        x           : [batch, T, features]
        spk1_rec    : layer-1 spikes from the previous step (recurrent input)
        mem1..mem3  : membrane potentials of lif1, lif2 and lif3
        need_traces : if True, also returns the per-timestep tensors used by
                      the Hebbian updater.

        Returns
        -------
        out_seq      : lif3 membrane potentials stacked along dim 1
        final_states : (spk1_rec, mem1, mem2, mem3) after the last timestep
        traces       : only when need_traces is True -- the detached tuple
                       (feedforward input, recurrent input, layer-1 spikes,
                       layer-2 spikes, mem1, mem2, mem3), each stacked along dim 1
        """
        # Unpacking also enforces that x is 3-dimensional.
        _, num_steps, _ = x.shape
        readout = []

        # One list per trace, in the order they are returned.
        trace_bufs = tuple([] for _ in range(7)) if need_traces else None

        for step in range(num_steps):
            inp = x[:, step, :]

            # All carried states are detached at every step, so no gradient
            # flows from one timestep into the previous one.
            spk1_rec, mem1, mem2, mem3 = (
                state.detach() for state in (spk1_rec, mem1, mem2, mem3)
            )
            prev_spk1 = spk1_rec

            # Layer 1: feedforward drive plus recurrent drive from the previous spikes
            drive1 = self.fc1(inp) + self.fc_rec(prev_spk1)
            spk1, mem1 = self.lif1(drive1, mem1)
            spk1_rec = spk1  # becomes the recurrent input of the next step

            # Layer 2
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            # Readout: only the membrane potential is used, the spikes are discarded
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

            if need_traces:
                step_values = (inp, prev_spk1, spk1, spk2, mem1, mem2, mem3)
                for buf, value in zip(trace_bufs, step_values):
                    buf.append(value.detach())

        out_seq = torch.stack(readout, dim=1)
        final_states = (spk1_rec, mem1, mem2, mem3)

        if not need_traces:
            return out_seq, final_states

        traces = tuple(torch.stack(buf, dim=1) for buf in trace_bufs)
        return out_seq, final_states, traces


def compute_correlation(pred, target):
    """
    Compute the Pearson correlation between predicted and target values.

    Both inputs are flattened, so multi-dimensional inputs are compared
    element-wise as one long vector.

    Args
    ----
    pred, target : torch tensors, NumPy arrays or plain sequences holding the
                   same number of elements.

    Returns
    -------
    float : the correlation coefficient.  0.0 is returned when fewer than two
            samples are available, when either input has (near-)zero variance,
            or when the inputs differ in size (a RuntimeWarning is issued in
            the latter case).  NaN values in the inputs propagate to the result.
    """
    flat = []
    for values in (pred, target):
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        # asarray lets plain lists/tuples through; ravel flattens any shape
        flat.append(np.asarray(values).ravel())
    pred_flat, target_flat = flat

    if pred_flat.size != target_flat.size:
        # Kept as a 0.0 result for backward compatibility, but no longer silent.
        warnings.warn(
            f"compute_correlation received inputs of different sizes "
            f"({pred_flat.size} vs {target_flat.size}); returning 0.0.",
            RuntimeWarning,
        )
        return 0.0

    # Correlation is undefined for fewer than two samples
    if pred_flat.size < 2:
        return 0.0

    # Check for zero variance
    if np.std(pred_flat) < 1e-10 or np.std(target_flat) < 1e-10:
        return 0.0

    # Off-diagonal element of the 2x2 correlation matrix
    return float(np.corrcoef(pred_flat, target_flat)[0, 1])

def compute_fvaf(y_true, y_pred):
    """
    Return the Fraction of Variance Accounted For (the R^2 score).

    Accepts torch tensors or NumPy arrays.  One-dimensional inputs are
    expected; anything with more dimensions is flattened after a
    RuntimeWarning.  If the target is (numerically) constant the result is
    1.0 when the prediction matches it and 0.0 otherwise.
    """
    # Bring both inputs to NumPy
    y_true, y_pred = (
        arr.detach().cpu().numpy() if torch.is_tensor(arr) else arr
        for arr in (y_true, y_pred)
    )

    needs_flatten = any(arr.ndim > 1 for arr in (y_true, y_pred))
    if needs_flatten:
        warnings.warn(f"FVAF expects 1D arrays, but got shapes {y_true.shape} and {y_pred.shape}. Flattening.", RuntimeWarning)
        y_true, y_pred = y_true.flatten(), y_pred.flatten()

    residual = y_true - y_pred
    centered = y_true - np.mean(y_true)
    residual_ss = np.sum(residual**2)
    total_ss = np.sum(centered**2)

    # Constant target: R^2 is undefined, fall back to a 1.0 / 0.0 verdict
    target_is_constant = total_ss < 1e-10
    if target_is_constant:
        prediction_matches = residual_ss < 1e-10
        return 1.0 if prediction_matches else 0.0

    return 1 - (residual_ss / total_ss)

def butter_lowpass(data, fs=1/0.05, cutoff=8, order=4):
    """
    Low-pass filter ``data`` along axis 0 with a Butterworth filter.

    The filter is designed as second-order sections and applied with
    ``sosfiltfilt`` (forward and backward pass).  ``fs`` and ``cutoff`` share
    the same frequency unit; the default ``fs`` of 20 corresponds to a 50 ms
    sample period.
    """
    sections = butter(N=order, Wn=cutoff, btype='low', fs=fs, output='sos')
    filtered = sosfiltfilt(sections, data, axis=0)
    return filtered

# %%
##########################################
# TWO-TIMESCALE META RL UPDATER WITH META-LEARNING
##########################################
class TwoScaleMetaRLWeightUpdaterFull:
    """
    This updater adapts all layers using two timescales and incorporates meta-learning.
    It uses a multi-timescale eligibility trace to integrate credit assignment over time, 
    providing both short-term and long-term memory of recent activity without 
    acausally looking into the future. The learning rate is modulated by a
    biologically-inspired reward signal based on performance.
    
    This approach integrates Hebbian principles with the needs of the BCI task.
    """
    def __init__(self, model, base_fast_lr=1e-3, base_slow_lr=1e-2, window_size=180, 
                 meta_lr=1e-3, tau_e_fast=0.12, tau_e_slow=0.7, online_mode=False, **kwargs):
        """
        Set up learning rates, integer RMS accumulators and eligibility traces.

        Args
        ----
        model        : network exposing the linear layers ``fc1``, ``fc2``,
                       ``fc3`` and ``fc_rec``; all state is allocated on the
                       device of its first parameter.
        base_fast_lr : base learning rate scaled by meta-parameter 'plasticity'.
        base_slow_lr : base learning rate scaled by meta-parameter 'sensitivity'.
        window_size  : stored as ``self.window_size``.
        meta_lr      : stored as ``self.meta_lr``.
        tau_e_fast   : time constant (s) of the fast eligibility trace.
        tau_e_slow   : time constant (s) of the slow eligibility trace.
        online_mode  : if True, shorter effective time constants and a larger
                       weight on the fast trace are used.
        **kwargs     : must contain ``inv_sqrt_LUT``.  The update code indexes
                       it with values clamped to 0..255, so it presumably needs
                       at least 256 entries -- confirm against the caller.

        Raises
        ------
        ValueError : if ``inv_sqrt_LUT`` is not supplied.
        """
        self.model = model
        self.device = next(model.parameters()).device
        self.base_fast_lr = base_fast_lr
        self.base_slow_lr = base_slow_lr
        self.meta_lr = meta_lr
        self.grad_scale = 0
        self.online_mode = online_mode

        # Number of output channels; read from the model instead of assuming 2
        # so that the per-channel accumulators match any output_size.
        n_out = model.fc3.out_features

        # RMS accumulators for error and spike normalization (Tweaks A & B)
        # Use per-unit, integer-based EMAs for hardware friendliness
        self.err2_sq_ema = torch.ones(model.fc2.out_features, dtype=torch.int32, device=self.device)
        self.err1_sq_ema = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
        self.spk1_rms    = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)

        self.meta_params = {
            'plasticity': 1.0,
            'sensitivity': 1.0
        }
        self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
        self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        # --- Hardware-friendly LUTs ---
        # Passed in from the main training script
        inv_sqrt_lut = kwargs.get('inv_sqrt_LUT', None)
        if inv_sqrt_lut is None:
            raise ValueError("inv_sqrt_LUT must be provided to the updater.")
        if torch.is_tensor(inv_sqrt_lut):
            # The LUT is indexed with tensors living on self.device, so it has
            # to be on the same device (no-op when it already is).
            inv_sqrt_lut = inv_sqrt_lut.to(self.device)
        self.inv_sqrt_LUT = inv_sqrt_lut
        # LUT for error-bucket based reward scaling (16 buckets)
        self.reward_LUT = torch.tensor(
            [15, 14, 13, 12, 11, 10, 9, 8, 8, 8, 7, 7, 6, 5, 4, 3], device=self.device)
        # Per-output-channel squared-error EMA (vx, vy for the default 2 outputs)
        self.out_rms = torch.ones(n_out, dtype=torch.int32, device=self.device)
        # One cap value per fc3 row: 6.0 scaled by 4096 (2**12)
        self.fc3_row_cap_q12 = torch.full((n_out, 1),
                                          int(6.0*4096),
                                          dtype=torch.int32,
                                          device=self.device)

        # --- Multi-Timescale Eligibility Trace Parameters ---
        # Per-step decay is exp(-dt / tau) with dt = 0.064 s (the 64 ms bin).
        # Adjust for online mode: faster decay for single-timestep updates
        if online_mode:
            # For online mode, use more aggressive mixing and faster decay
            self.decay_fast = math.exp(-0.064 / (tau_e_fast * 0.5))  # Faster decay
            self.decay_slow = math.exp(-0.064 / (tau_e_slow * 0.8))  # Slightly faster decay
            self.trace_mix_a = 0.8  # More weight on fast traces for responsiveness
        else:
            self.decay_fast = math.exp(-0.064 / tau_e_fast)
            self.decay_slow = math.exp(-0.064 / tau_e_slow)
            self.trace_mix_a = 0.5  # Fixed mixing parameter for fast and slow traces

        # Fast traces (one per weight matrix, same shape as the weights)
        self.e_fast_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_fast_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_fast_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_fast_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)

        # Slow traces
        self.e_slow_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_slow_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_slow_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_slow_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        # --- End Eligibility Trace ---

        self.window_size = window_size
        self.step_counter = 0
        self.cumulative_error = 0
        self.prev_cumulative_error = None

        self.max_loss_value = 5.0

        # Initialize gradient averages with zeros matching model layer shapes and device
        self.grad_fc1_avg = torch.zeros_like(self.model.fc1.weight.data, device=self.device)
        self.grad_fc2_avg = torch.zeros_like(self.model.fc2.weight.data, device=self.device)
        self.grad_fc3_avg = torch.zeros_like(self.model.fc3.weight.data, device=self.device)
        self.grad_fc_rec_avg = torch.zeros_like(self.model.fc_rec.weight.data, device=self.device)

        # Initialize momentum for gradient accumulation
        self.momentum = 0.9

    def reset_weights_with_xavier(self):
        """
        Re-draw the weights of fc1, fc2, fc3 and fc_rec (in that order) with
        Xavier/Glorot uniform initialization and zero any biases; intended for
        recovering after catastrophic errors.
        """
        # Resolve every layer first so a missing attribute fails before any reset
        layers = [getattr(self.model, layer_name) for layer_name in ("fc1", "fc2", "fc3", "fc_rec")]
        with torch.no_grad():
            for layer in layers:
                nn.init.xavier_uniform_(layer.weight)
                bias = layer.bias
                if bias is not None:
                    nn.init.zeros_(bias)
        print("Weights reset using Xavier/Glorot initialization for better convergence")

    def warm_start_rms_emas(self, warmup_data_x, warmup_data_y, num_warmup_steps=100):
        """
        Pre-load the output-error EMA (``self.out_rms``) from initial samples
        so that online learning does not start from a cold accumulator.

        Does nothing unless the updater is in online mode.  At most
        ``num_warmup_steps`` samples (or all of ``warmup_data_x`` if shorter)
        are each pushed through the model from fresh zero states; no weights
        are modified.
        """
        if not self.online_mode:
            return

        print(f"Warm-starting RMS EMAs with {num_warmup_steps} steps...")

        ema_shift = 2  # Faster adaptation for warmup
        steps_to_run = min(num_warmup_steps, len(warmup_data_x))

        with torch.no_grad():
            for step in range(steps_to_run):
                window = slice(step, step + 1)
                # [1, features] / [1, n_out] -> add a length-1 sequence axis
                x_seq = warmup_data_x[window].to(self.device).unsqueeze(1)
                y_seq = warmup_data_y[window].to(self.device).unsqueeze(1)

                # Temporary all-zero states for a batch of one
                spk1_rec = torch.zeros(1, self.model.fc1.out_features, device=self.device)
                mem1 = torch.zeros(1, self.model.fc1.out_features, device=self.device)
                mem2 = torch.zeros(1, self.model.fc2.out_features, device=self.device)
                mem3 = torch.zeros(1, self.model.fc3.out_features, device=self.device)

                prediction, _, _ = self.model(x_seq, spk1_rec, mem1, mem2, mem3, need_traces=True)
                error = y_seq - prediction

                # Squared error scaled by 4096, summed over batch and sequence
                # axes so one integer per output channel remains
                scaled_sq_error = ((error**2) * 4096).sum(dim=(0,1)).to(torch.int32)

                # Shift-based EMA: ema += (sample - ema) / 2**ema_shift
                self.out_rms -= self.out_rms >> ema_shift
                self.out_rms += scaled_sq_error >> ema_shift

        print("RMS EMA warm-start complete.")

    def fast_update_single_timestep(self, x_t, y_t, states):
        """
        Online single-timestep update optimized for real-time learning.
        Maintains persistent hidden states and applies scaled learning rates.
        """
        if not self.online_mode:
            raise ValueError("Single timestep update only available in online mode")
            
        # Reshape for model forward pass: [batch=1, seq=1, features]
        x_t_seq = x_t.unsqueeze(0).unsqueeze(0)  # [1, 1, features]
        y_t_seq = y_t.unsqueeze(0).unsqueeze(0)  # [1, 1, 2]
        
        device = self.device
        spk1_rec, mem1, mem2, mem3 = states

        # Single forward pass to get prediction and traces
        pred_seq, final_states, traces = self.model(
            x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss for single timestep
        mse_loss = F.mse_loss(pred_seq, y_t_seq)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation for single timestep
        with torch.no_grad():
            # Get single timestep data
            pre_ff_t = pre_ff_all[:, 0, :]    # [1, features]
            pre_rec_t = pre_rec_all[:, 0, :]  # [1, hidden]
            spk1_t = spk1_all[:, 0, :]        # [1, hidden]
            spk2_t = spk2_all[:, 0, :]        # [1, hidden//2]
            mem1_next = mem1_next_all[:, 0, :] # [1, hidden]
            mem2_next = mem2_next_all[:, 0, :] # [1, hidden//2]
            mem3_next = mem3_next_all[:, 0, :] # [1, 2]
            
            # Single timestep error
            output_error_t = y_t_seq[:, 0, :] - pred_seq[:, 0, :]  # [1, 2]
            
            # Calculate bucket-based reward with online-adapted EMAs
            abs_err_xy = torch.abs(output_error_t).mean(dim=0)  # [2]
            bucket_xy = (abs_err_xy * 8).clamp(0, 15).to(torch.int64)
            bucket = torch.max(bucket_xy)
            lr_scale = self.reward_LUT[bucket]
            
            # # Scale learning rate for single timestep (was calibrated for seq_len=10)
            effective_lr = self.fast_lr * lr_scale / 16 * 0.1  # 0.1 scale factor for single timestep
            # effective_lr = self.fast_lr * 0.1
            # Online-adapted EMA constants (faster adaptation)
            k_out = 2  # Faster than batch mode (was 4)
            k = 3      # Faster than batch mode (was 5)
            
            # Normalize output error per-channel
            sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
            self.out_rms -= self.out_rms >> k_out
            self.out_rms += sq_out >> k_out
            idx_out = (self.out_rms >> 8).clamp_(0, 255)
            norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
            
            # Backpropagate error through layers with online-adapted normalization
            err2_raw = torch.matmul(norm_err, self.model.fc3.weight)
            sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32)
            self.err2_sq_ema -= self.err2_sq_ema >> k
            self.err2_sq_ema += sq_err2 >> k
            idx = (self.err2_sq_ema >> 8).clamp_(0, 255)
            hidden2_error_t = err2_raw * self.inv_sqrt_LUT[idx]

            err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
            self.err1_sq_ema -= self.err1_sq_ema >> k
            self.err1_sq_ema += sq_err1 >> k
            idx = (self.err1_sq_ema >> 8).clamp_(0, 255)
            hidden1_error_t = err1_raw * self.inv_sqrt_LUT[idx]
            
            # Local sensitivities
            d_lif3_t = self.model.lif3.spike_grad(mem3_next)
            d_lif2_t = self.model.lif2.spike_grad(mem2_next)
            d_lif1_t = self.model.lif1.spike_grad(mem1_next)

            # Normalize spike activities
            sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
            self.spk1_rms += -(self.spk1_rms >> 2) + (sq_spk1 >> 2)  # Faster adaptation (k=2)
            idx = (self.spk1_rms >> 8).clamp_(0, 255)
            spk1_t = spk1_t * self.inv_sqrt_LUT[idx]

            # Update eligibility traces with single timestep Hebbian terms
            hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
            self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
            self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
            
            hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
            self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
            self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)

            hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
            self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
            self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
            
            hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
            self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
            self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

            # Apply updates using combined traces
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate for slow updates
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize weights
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
        
        return combined_loss.item(), final_states
    
    def update_single_timestep(self, x_t, y_t, states):
        """
        Main interface for online single-timestep updates.

        Runs one fast update via ``fast_update_single_timestep``, bounds the
        reported loss at ``self.max_loss_value``, and calls ``slow_update``
        once ``self.window_size`` steps have been accumulated.

        Args:
            x_t: Input for the current timestep; moved to ``self.device``.
            y_t: Target for the current timestep; moved to ``self.device``.
            states: Hidden states, forwarded unchanged to
                ``fast_update_single_timestep``.

        Returns:
            Tuple ``(bounded_loss, new_states)``. ``bounded_loss`` never
            exceeds ``self.max_loss_value``; when divergence is detected it
            equals ``self.max_loss_value``.
        """
        x_t, y_t = x_t.to(self.device), y_t.to(self.device)

        loss, new_states = self.fast_update_single_timestep(x_t, y_t, states)

        # Divergence has to be judged on the RAW loss. Comparing the value
        # already clamped to max_loss_value against 1.5 * max_loss_value can
        # never succeed, which made the reset below unreachable. A non-finite
        # loss (NaN/inf) is treated as divergence too, so that it cannot leak
        # into cumulative_error.
        diverged = not math.isfinite(loss) or loss > self.max_loss_value * 1.5
        if diverged:
            print(f"CATASTROPHIC ERROR DETECTED: {loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # The diverged step is deliberately excluded from the window stats.
            return self.max_loss_value, new_states

        bounded_loss = min(loss, self.max_loss_value)
        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss, new_states
              
    def fast_update(self, input_sequence, desired_sequence, initial_states):
        """
        Apply one fast, local (Hebbian-style) weight update for a whole sequence.

        The model is run once over ``input_sequence`` with ``need_traces=True``.
        For every timestep the output error is normalised, propagated backwards
        through ``fc3`` and ``fc2`` with per-layer normalisation, multiplied by
        the surrogate spike gradients, and accumulated into the persistent fast
        and slow eligibility traces. After the last timestep the combined traces
        are added to the weights of ``fc1``, ``fc2``, ``fc3`` and ``fc_rec``,
        folded into the running averages used by ``slow_update``, and every
        weight matrix is passed through ``renorm_row_``.

        Args:
            input_sequence: Model input; unpacked as ``(batch, seq_len, features)``.
            desired_sequence: Targets, indexed as ``[:, t, :]`` like the predictions.
            initial_states: 4-tuple ``(spk1_rec, mem1, mem2, mem3)`` forwarded to
                the model.

        Returns:
            Tuple ``(loss, x_corr, y_corr)`` where ``loss`` is the sequence MSE
            clamped to ``[0, 10]`` as a Python float, and ``x_corr`` / ``y_corr``
            are always ``0`` (placeholders).

        Side effects:
            Mutates the model weights and the updater's EMA buffers
            (``out_rms``, ``err2_sq_ema``, ``err1_sq_ema``, ``spk1_rms``),
            eligibility traces, ``grad_*_avg`` accumulators and ``grad_scale``.
        """
        # NOTE(review): batch_size and device are not used further down.
        batch_size, seq_len, _ = input_sequence.shape
        device = self.device

        # Unpack initial states for the sequence
        # NOTE(review): the unpacked names are unused; the tuple itself is
        # splatted into the model call below. The unpack does enforce length 4.
        spk1_rec, mem1, mem2, mem3 = initial_states

        # NOTE: Eligibility traces are NOT reset here. They persist and decay across batches
        # to maintain a continuous memory of activity, as per best practices.

        # SINGLE forward pass to get predictions AND traces for Hebbian update
        # (final_states is returned by the model but not used in this method).
        pred_sequence, final_states, traces = self.model(
            input_sequence, *initial_states, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss over the whole sequence; it is only reported, never
        # backpropagated (all weight changes below happen under no_grad).
        mse_loss = F.mse_loss(pred_sequence, desired_sequence)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation
        with torch.no_grad():
            # Calculate a 4-bit error bucket reward signal to modulate plasticity.
            output_error_t_for_reward = desired_sequence - pred_sequence
            abs_err_xy = torch.mean(torch.abs(output_error_t_for_reward), dim=(0,1))  # [2]
            bucket_xy  = (abs_err_xy * 8).clamp(0,15).to(torch.int64)      # 4-bit each
            bucket     = torch.max(bucket_xy)                              # int64 scalar
            lr_scale   = self.reward_LUT[bucket]
            # True division by 16 (not integer division); presumably the LUT
            # entries are expressed in sixteenths — confirm against reward_LUT.
            effective_lr = self.fast_lr * lr_scale / 16

            # NO second pass needed. Use the returned traces.
            for t in range(seq_len):
                # Get activations from the pre-computed traces
                pre_ff_t = pre_ff_all[:, t, :]
                pre_rec_t = pre_rec_all[:, t, :]
                spk1_t = spk1_all[:, t, :]
                spk2_t = spk2_all[:, t, :]
                mem1_next = mem1_next_all[:, t, :]
                mem2_next = mem2_next_all[:, t, :]
                mem3_next = mem3_next_all[:, t, :]
                
                # Propagate error backward from this time step's prediction
                output_error_t = desired_sequence[:, t, :] - pred_sequence[:, t, :]
                
                # Patch A: Normalize output error per-channel before backprop.
                # The squared error is scaled by 4096, summed over the batch and
                # folded into an integer EMA with shift k_out; the EMA's upper
                # bits (>> 8, clamped to 0..255) index inv_sqrt_LUT.
                k_out = 4           # EMA 1/16
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)  # [2]
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out       >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0,255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]   # element-wise
                
                # Tweak A: RMS-Rescale the local errors (Hardware-Friendly Version)
                k = 5 # EMA constant: 2^-5 = 0.03125
                
                # Error for hidden layer 2: project through fc3's weights, then
                # rescale with the same EMA + LUT scheme as above.
                err2_raw = torch.matmul(norm_err, self.model.fc3.weight)     # [B,H2]
                sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32) # Sum over batch
                self.err2_sq_ema -= self.err2_sq_ema >> k
                self.err2_sq_ema += sq_err2 >> k
                idx = (self.err2_sq_ema >> 8).clamp_(0,255)
                hidden2_error_t  = err2_raw * self.inv_sqrt_LUT[idx]

                # Error for hidden layer 1: project through fc2's weights.
                err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)    # [B,H1]
                sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err1_sq_ema -= self.err1_sq_ema >> k
                self.err1_sq_ema += sq_err1 >> k
                idx = (self.err1_sq_ema >> 8).clamp_(0,255)
                hidden1_error_t  = err1_raw * self.inv_sqrt_LUT[idx]
                
                # Get local sensitivities (surrogate gradients)
                d_lif3_t = self.model.lif3.spike_grad(mem3_next)
                d_lif2_t = self.model.lif2.spike_grad(mem2_next)
                d_lif1_t = self.model.lif1.spike_grad(mem1_next)

                # Tweak B: RMS-Rescale spike activities (Hardware-Friendly Version)
                # Note: spk1_t is already detached from previous ops
                # Only the layer-1 spikes are rescaled; spk2_t is used as-is below.
                sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.spk1_rms += - (self.spk1_rms >> 3) + (sq_spk1 >> 3) # k=3 for spikes
                idx = (self.spk1_rms >> 8).clamp_(0,255)
                spk1_t = spk1_t * self.inv_sqrt_LUT[idx]

                # Update both fast and slow eligibility traces with the new instantaneous Hebbian term
                # (each trace is decayed in place, then the outer product of
                # post-synaptic error*sensitivity and pre-synaptic activity is added).
                hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
                self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
                
                hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
                self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)

                hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
                self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
                
                # The recurrent weights share layer 1's error signal but use the
                # recurrent pre-synaptic activity.
                hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
                self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
                self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

                # NO state update at end of loop

            # Apply updates using the final state of the eligibility traces
            # The effective_lr is now calculated once for the whole sequence at the top
            
            # Combine fast and slow traces for the update
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate combined traces for the slow update path
            # (momentum-weighted running averages; grad_scale counts fast steps).
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize all weights after every fast step to prevent explosion/drift
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
            
        # Correlations are not computed here; zeros keep the 3-tuple return shape.
        x_corr=0
        y_corr=0
        return combined_loss.item(), x_corr, y_corr

    def _rms(self, x, eps=1e-5):
        return x / (torch.sqrt(torch.mean(x**2)) + eps)

    def renorm_row_(self, w):
       """
       Shrink rows of ``w`` in place when their squared-norm measure exceeds a cap.

       The per-row measure is ``sum(int32(w) * w)`` along dim 1. Rows above the
       cap are right-shifted (divided by a power of two) by
       ``ceil(log2(measure) - log2(cap))`` bits.

       NOTE(review): the "Q12" annotations suggest ``w`` is meant to hold
       fixed-point integers scaled by 4096 — confirm. For ordinary float
       weights with magnitude below 1 the int32 cast yields zeros, so the
       measure is 0 and the method would be a no-op.
       """
       # Measure per row: (w cast to int32) * w, summed over columns, kept as [rows, 1].
       l2 = torch.sum((w.to(torch.int32) * w), dim=1, keepdim=True)          # Q12
       # Two-row matrices use the dedicated self.fc3_row_cap_q12 (the name implies
       # this is the fc3 case); all others use a constant cap of int(6.0 * 4096).
       caps = (self.fc3_row_cap_q12 if w.shape[0] == 2
               else torch.full_like(l2, int(6.0 * 4096)))                    # scalar cap
       mask = l2 > caps
       if mask.any():
           # Number of bits to shift each offending row by.
           shift = (torch.log2(l2[mask].float()) - torch.log2(caps[mask].float())
                   ).ceil().to(torch.int32)
           # NOTE(review): l2[mask] is 1-D, so shift.squeeze(1) below looks like it
           # would raise for a dim that does not exist, and a 1-D shift would not
           # broadcast per-row against the selected [k, cols] rows — confirm this
           # branch has actually been exercised.
           w_int = w.data.to(torch.int32)
           w_int[mask.squeeze(1)] >>= shift.squeeze(1)                       # power-of-two divide
           # The WHOLE integer-cast matrix is copied back, so rows that were not
           # over the cap are also replaced by their int32-cast values.
           w.data.copy_(w_int.to(torch.float32))                                  # back to fp32 for PyTorch

    def normalize_weights(self, weights, max_norm=1.0):
        """Clip the L2 norm of every row of ``weights`` to at most ``max_norm``.

        Rows whose norm is already within the limit are returned (almost)
        unchanged; a ``1e-8`` term in the denominator guards all-zero rows.
        Not used by the fast update path any more; kept for other callers.
        """
        row_norms = weights.norm(p=2, dim=1, keepdim=True)
        clipped_norms = row_norms.clamp(max=max_norm)
        rescale = clipped_norms / (row_norms + 1e-8)
        return weights * rescale

    def slow_update(self):
        """
        Consolidation step run at the end of an accumulation window.

        In order: (1) if any fast steps were accumulated, divide the running
        trace averages by ``grad_scale`` and add their RMS-normalised values,
        scaled by ``slow_lr``, to the weights; (2) apply a light multiplicative
        weight decay; (3) adapt the ``plasticity`` / ``sensitivity``
        meta-parameters depending on whether the window's average loss improved
        on the previous window; (4) refresh ``fast_lr`` / ``slow_lr``; (5) clear
        the accumulators, window statistics and all eligibility traces.
        """
        layer_weights = (
            self.model.fc1.weight,
            self.model.fc2.weight,
            self.model.fc3.weight,
            self.model.fc_rec.weight,
        )
        trace_averages = (
            self.grad_fc1_avg,
            self.grad_fc2_avg,
            self.grad_fc3_avg,
            self.grad_fc_rec_avg,
        )

        with torch.no_grad():
            if self.grad_scale > 0:
                for average in trace_averages:
                    average.div_(self.grad_scale)
                for weight, average in zip(layer_weights, trace_averages):
                    weight.data.add_(self.slow_lr * self._rms(average))

            # Light weight decay on all layers (≈ 1e-5 per slow step).
            for weight in layer_weights:
                weight.mul_(0.999995)

        if self.step_counter > 0:
            current_avg_loss = self.cumulative_error / self.step_counter
        else:
            current_avg_loss = float('inf')
        current_avg_loss = min(current_avg_loss, self.max_loss_value)

        if self.prev_cumulative_error is not None:
            prev_loss = min(self.prev_cumulative_error, self.max_loss_value)
            # Grow both meta-parameters after an improvement, shrink otherwise.
            if current_avg_loss < prev_loss:
                factor = 1 + self.meta_lr
            else:
                factor = 1 - self.meta_lr
            for key in ('plasticity', 'sensitivity'):
                self.meta_params[key] = np.clip(self.meta_params[key] * factor, 0.2, 1.5)

        self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
        self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        self.prev_cumulative_error = current_avg_loss

        # Start the next window from a clean slate.
        for average in trace_averages:
            average.zero_()
        self.grad_scale = 0
        self.cumulative_error = 0
        self.step_counter = 0

        # Reset eligibility traces to prevent saturation over long runs.
        eligibility_traces = (
            self.e_fast_fc1, self.e_fast_fc2, self.e_fast_fc3, self.e_fast_rec,
            self.e_slow_fc1, self.e_slow_fc2, self.e_slow_fc3, self.e_slow_rec,
        )
        for trace in eligibility_traces:
            trace.zero_()

    def update(self, input_sequence, desired_sequence, initial_states):
        """
        Main interface for windowed (sequence) updates.

        Runs one ``fast_update`` over the sequence, bounds the reported loss
        at ``self.max_loss_value``, and calls ``slow_update`` once
        ``self.window_size`` calls have been accumulated.

        Args:
            input_sequence: Input batch; moved to ``self.device``.
            desired_sequence: Target batch; moved to ``self.device``.
            initial_states: Hidden states forwarded unchanged to ``fast_update``.

        Returns:
            The bounded loss as a number no larger than ``self.max_loss_value``;
            it equals ``self.max_loss_value`` when divergence is detected.
        """
        input_sequence, desired_sequence = input_sequence.to(self.device), desired_sequence.to(self.device)

        combined_loss, _, _ = self.fast_update(input_sequence, desired_sequence, initial_states)

        # Divergence has to be judged on the RAW loss. Comparing the value
        # already clamped to max_loss_value against 1.5 * max_loss_value can
        # never succeed, which made the reset below unreachable. A non-finite
        # loss (NaN/inf) is treated as divergence too, so that it cannot leak
        # into cumulative_error.
        diverged = not math.isfinite(combined_loss) or combined_loss > self.max_loss_value * 1.5
        if diverged:
            print(f"CATASTROPHIC ERROR DETECTED: {combined_loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # The diverged step is deliberately excluded from the window stats.
            return self.max_loss_value

        bounded_loss = min(combined_loss, self.max_loss_value)
        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss


def train_snn_windowed(model, rl_updater, train_loader, val_loader, num_epochs, patience=10):
    """
    Train the SNN on fixed-length windows using the local-learning updater.

    Every batch starts from all-zero hidden states. After each epoch the model
    is scored on ``val_loader`` using the prediction at the last timestep of
    each window; the mean of the x and y correlations drives best-model
    tracking and early stopping. Epochs with an empty validation set are
    reported and skipped without touching the early-stopping counters.

    Args:
        model: Network exposing ``fc1``/``fc2``/``fc3`` with ``out_features``.
        rl_updater: Object whose ``update(inputs, targets, states)`` returns a loss.
        train_loader: Iterable of ``(input_seq, target_seq)`` batches; must support ``len``.
        val_loader: Iterable of ``(input_seq, target_seq)`` validation batches.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.

    Returns:
        ``model``, with the best validation weights loaded when any were saved.
    """
    device = next(model.parameters()).device
    print(f"Starting SNN Causal Windowed Training for {num_epochs} epochs on {device}...")

    def _zero_states(n_rows):
        # (spk1_rec, mem1, mem2, mem3) for a batch of n_rows windows.
        return (
            torch.zeros(n_rows, model.fc1.out_features, device=device),
            torch.zeros(n_rows, model.fc1.out_features, device=device),
            torch.zeros(n_rows, model.fc2.out_features, device=device),
            torch.zeros(n_rows, model.fc3.out_features, device=device),
        )

    best_val_metric = -np.inf  # Pearson correlation can be negative
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0.0
        for input_seq, target_seq in train_loader:
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            initial_states = _zero_states(input_seq.size(0))
            running_train_loss += rl_updater.update(input_seq, target_seq, initial_states)

        avg_train_loss = running_train_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        pred_chunks, target_chunks = [], []
        with torch.no_grad():
            for input_seq, target_seq in val_loader:
                input_seq = input_seq.to(device)
                target_seq = target_seq.to(device)
                pred_seq, _ = model(input_seq, *_zero_states(input_seq.size(0)))
                # Use final time-step to align with read-out
                pred_chunks.append(pred_seq[:, -1, :].cpu())
                target_chunks.append(target_seq[:, -1, :].cpu())

            if not pred_chunks:
                print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: N/A (empty val set)")
                continue

            val_preds = torch.cat(pred_chunks)
            val_targets = torch.cat(target_chunks)
            val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
            val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
            avg_val_corr = (val_corr_x + val_corr_y) / 2.0

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

        # ---- best-model bookkeeping / early stopping ----
        improved = avg_val_corr > best_val_metric
        if improved:
            best_val_metric = avg_val_corr
            best_model_state = {
                key: value.cpu().clone() for key, value in model.state_dict().items()
            }
            print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
            epochs_no_improve = 0
            continue

        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    if best_model_state:
        restored = {key: value.to(device) for key, value in best_model_state.items()}
        model.load_state_dict(restored)
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

def train_snn_online(model, rl_updater, train_stream, val_stream, num_epochs, bin_width_s,
                     correlation_csv_path, patience=10, eval_every_n=500,
                     reset_states_every_n=None):
    """
    Train the SNN fully online, one timestep at a time.

    Hidden states persist across timesteps (unless a reset is requested). For
    every training step the model's prediction is taken BEFORE the weights are
    updated and appended to a 60-second sliding window; once more than 10
    samples are available, the window's x/y correlations are recorded next to
    the wall-clock time. The per-step records are written to
    ``correlation_csv_path`` at the end.

    Args:
        model: SNN model to train.
        rl_updater: Online-enabled updater (``online_mode`` must be truthy).
        train_stream: Iterable yielding ``(x_t, y_t, should_reset)`` for training.
        val_stream: Iterable yielding ``(x, y, should_reset)`` for validation;
            at most 1000 steps are consumed per evaluation.
        num_epochs: Number of passes over ``train_stream``.
        bin_width_s: Width of each time bin in seconds (sizes the sliding window).
        correlation_csv_path: Destination for the per-timestep correlation CSV.
        patience: Number of evaluation intervals without improvement before stopping.
        eval_every_n: Evaluate on the validation stream every N training steps.
        reset_states_every_n: Reset hidden states every N steps (None = never).

    Returns:
        ``model``, with the best validation weights loaded when any were saved.

    Raises:
        ValueError: If ``rl_updater.online_mode`` is falsy.
    """
    if not rl_updater.online_mode:
        raise ValueError("rl_updater must be in online_mode for online training")

    device = next(model.parameters()).device
    print(f"Starting SNN Online Training for {num_epochs} epochs on {device}...")
    print(f"Evaluation every {eval_every_n} steps, reset states every {reset_states_every_n or 'never'} steps")

    import time
    from collections import deque

    def _fresh_states():
        # (spk1_rec, mem1, mem2, mem3) for a single stream (batch size 1).
        widths = (
            model.fc1.out_features,
            model.fc1.out_features,
            model.fc2.out_features,
            model.fc3.out_features,
        )
        return tuple(torch.zeros(1, width, device=device) for width in widths)

    # Sliding window for the timestep-wise correlation trace.
    sliding_window_duration_s = 60
    sliding_window_size = int(round(sliding_window_duration_s / bin_width_s))
    print(f"Using a sliding window of {sliding_window_duration_s}s ({sliding_window_size} timesteps) for correlation.")
    preds_history = deque(maxlen=sliding_window_size)
    targets_history = deque(maxlen=sliding_window_size)
    correlation_results = []

    best_val_metric = -np.inf
    best_model_state = None
    epochs_no_improve = 0  # counts evaluation intervals, not epochs

    # Persistent hidden states - these carry over from one timestep to the next.
    states = _fresh_states()

    # Warm-start the updater's RMS EMAs when the stream exposes its raw data.
    if hasattr(train_stream, 'spike_data') and len(train_stream.spike_data) > 100:
        rl_updater.warm_start_rms_emas(
            train_stream.spike_data[:100],
            train_stream.targets[:100],
            num_warmup_steps=50
        )

    global_step_count = 0
    start_time = time.time()  # wall-clock reference for the CSV
    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0

        for x_t, y_t, should_reset in train_stream:
            global_step_count += 1

            # Reset on request (e.g. trial boundaries) or on the periodic schedule.
            if should_reset or (reset_states_every_n and (global_step_count - 1) % reset_states_every_n == 0):
                states = _fresh_states()

            # Prediction BEFORE this timestep's weight update.
            model.eval()
            with torch.no_grad():
                x_t_seq = x_t.unsqueeze(0).unsqueeze(0).to(device)
                pred_t_seq, _ = model(x_t_seq, *states)
                pred_t = pred_t_seq[0, 0, :].cpu()
            model.train()

            preds_history.append(pred_t.numpy())
            targets_history.append(y_t.cpu().numpy())

            # Single-timestep update; the returned states feed the next step.
            loss, states = rl_updater.update_single_timestep(x_t, y_t, states)
            running_train_loss += loss

            # Sliding-window correlation (zeros until enough samples exist).
            corr_x, corr_y, avg_corr = 0.0, 0.0, 0.0
            if len(preds_history) > 10:
                preds_arr = np.array(preds_history)
                targets_arr = np.array(targets_history)
                corr_x = compute_correlation(preds_arr[:, 0], targets_arr[:, 0])
                corr_y = compute_correlation(preds_arr[:, 1], targets_arr[:, 1])
                avg_corr = (corr_x + corr_y) / 2.0

            elapsed_time = time.time() - start_time
            correlation_results.append({
                'timestep': global_step_count,
                'time_s': elapsed_time,
                'avg_corr': avg_corr,
                'corr_x': corr_x,
                'corr_y': corr_y,
                'loss': loss
            })

            # Everything below is the periodic validation.
            if global_step_count % eval_every_n != 0:
                continue

            model.eval()
            val_preds, val_targets = [], []
            val_states = _fresh_states()  # validation never reuses training states

            with torch.no_grad():
                for val_index, (x_val, y_val, should_reset_val) in enumerate(val_stream):
                    if should_reset_val or (reset_states_every_n and val_index % reset_states_every_n == 0):
                        val_states = _fresh_states()

                    x_val_seq = x_val.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, features]
                    pred_seq, val_states = model(x_val_seq, *val_states)
                    val_preds.append(pred_seq[0, 0, :].cpu())  # [2]
                    val_targets.append(y_val.cpu())

                    # Cap each evaluation at 1000 validation steps.
                    if val_index + 1 >= 1000:
                        break

            if val_preds:
                val_preds = torch.stack(val_preds)
                val_targets = torch.stack(val_targets)

                val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
                val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
                avg_val_corr = (val_corr_x + val_corr_y) / 2.0

                avg_train_loss = running_train_loss / eval_every_n
                print(f"Epoch [{epoch+1}/{num_epochs}], Step {global_step_count}, "
                      f"Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

                if avg_val_corr > best_val_metric:
                    best_val_metric = avg_val_corr
                    best_model_state = {
                        key: value.cpu().clone() for key, value in model.state_dict().items()
                    }
                    print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping triggered after {global_step_count} steps in epoch {epoch+1}.")
                        print(f"  (Patience: {patience} eval intervals = {patience * eval_every_n} steps)")
                        break

                running_train_loss = 0.0

            model.train()

        # Propagate an early stop out of the epoch loop.
        if epochs_no_improve >= patience:
            break

    if best_model_state:
        restored = {key: value.to(device) for key, value in best_model_state.items()}
        model.load_state_dict(restored)
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    if correlation_results:
        results_df = pd.DataFrame(correlation_results)
        results_df.to_csv(correlation_csv_path, index=False)
        print(f"Saved timestep-wise correlations to {correlation_csv_path}")

    return model

# %%
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from filterpy.kalman import KalmanFilter
def train_test_kalman_filter(train_rates, train_vel, test_rates, test_vel, dt=0.050):
    """
    Trains and applies a Kalman Filter for neural decoding.

    A Ridge regression maps firing rates to velocity; its (bias-corrected)
    output is used as the noisy measurement of a constant-velocity Kalman
    filter with state ``[px, py, vx, vy]``. Measurement and process noise are
    both estimated from the decoder's residuals on the training set.

    Args:
        train_rates: Training firing rates, 2-D ``(timesteps, neurons)``;
            array or torch tensor.
        train_vel: Training velocities, 2-D with exactly 2 columns ``(vx, vy)``;
            array or torch tensor.
        test_rates: Test firing rates, same layout as ``train_rates``.
        test_vel: Unused here; accepted so callers can keep passing it
            (evaluation happens outside this function).
        dt: Time step in seconds of the constant-velocity model. Only the
            position rows of the transition matrix and the position process
            noise depend on it. Defaults to the previously hard-coded 50 ms.

    Returns:
        predicted_velocities (np.array): KF predicted velocities for test data,
            shape ``(n_test_timesteps, 2)``.
        kf (KalmanFilter): The trained Kalman Filter object.

    Raises:
        ValueError: If ``train_vel`` does not have exactly 2 columns.
    """
    print("\n--- Setting up Kalman Filter ---")

    def _as_numpy(values):
        # Tensors are detached and moved to host memory; arrays pass through.
        return values.detach().cpu().numpy() if torch.is_tensor(values) else values

    train_rates = _as_numpy(train_rates)
    train_vel = _as_numpy(train_vel)
    test_rates = _as_numpy(test_rates)
    # test_vel is not used for KF training/prediction, only evaluation later

    n_timesteps_test = test_rates.shape[0]
    n_outputs = train_vel.shape[1]  # Should be 2 for [vx, vy]

    if n_outputs != 2:
        raise ValueError("This KF implementation assumes 2D velocity (vx, vy)")

    # --- 1. Train linear decoder velocity = W * rates + b ---
    print("Fitting linear velocity decoder (Ridge)")
    lin_decoder = Ridge(alpha=1e-3, fit_intercept=True)
    lin_decoder.fit(train_rates, train_vel)

    # Use decoder to obtain noisy velocity measurements
    vel_meas_train = lin_decoder.predict(train_rates)
    vel_meas_test = lin_decoder.predict(test_rates)

    # --- 1b. Remove systematic bias between decoder output and true velocity ---
    bias = vel_meas_train.mean(axis=0) - train_vel.mean(axis=0)
    vel_meas_train -= bias
    vel_meas_test -= bias

    # --- 2. Dimensions and dynamics ---
    dim_x = 4           # [px, py, vx, vy]
    dim_z = 2           # measurement provides velocity only

    # Constant-velocity model. Named explicitly rather than `F` so the local
    # does not shadow the module-level `torch.nn.functional as F` alias.
    transition_matrix = np.array([[1, 0, dt, 0],
                                  [0, 1, 0, dt],
                                  [0, 0, 1, 0],
                                  [0, 0, 0, 1]], dtype=np.float32)

    observation_matrix = np.array([[0, 0, 1, 0],
                                   [0, 0, 0, 1]], dtype=np.float32)  # we observe velocity

    # --- 3. Measurement noise R from training residuals ---
    print("Estimating R from decoder residuals")
    residuals_meas = train_vel - vel_meas_train
    R = np.cov(residuals_meas.T) + np.eye(dim_z) * 1e-6

    # --- 4. Process noise Q, also derived from the residual velocity variance ---
    q_vel = np.var(residuals_meas, axis=0).mean()
    q_pos = (dt ** 2) * q_vel
    Q = np.diag([q_pos, q_pos, q_vel, q_vel]) + 1e-6 * np.eye(dim_x)

    # --- 5. Initialize Kalman Filter ---
    print("Initializing Kalman Filter...")
    kf = KalmanFilter(dim_x=dim_x, dim_z=dim_z)
    # Start at position 0 with the first training velocity sample, giving the
    # 4-element state vector required by the 4x4 transition matrix.
    initial_state = np.zeros(dim_x, dtype=np.float32)
    initial_state[2:] = train_vel[0]
    kf.x = initial_state
    kf.P = np.eye(dim_x) * 500.     # High initial uncertainty
    kf.F = transition_matrix
    kf.H = observation_matrix
    kf.R = R
    kf.Q = Q

    # --- 6. Run the filter on test data ---
    print(f"Running Kalman Filter for {n_timesteps_test} test steps...")
    predicted_velocities = np.zeros((n_timesteps_test, 2))
    start_time_kf = time.time()
    for t in range(n_timesteps_test):
        z = vel_meas_test[t]  # 1-D vector length 2
        kf.predict()
        kf.update(z)
        predicted_velocities[t] = kf.x[2:].copy()  # store vx, vy only
    end_time_kf = time.time()
    print(f"Kalman Filter processing time: {end_time_kf - start_time_kf:.2f} seconds")
    print("--- Kalman Filter Setup and Run Complete ---")
    return predicted_velocities, kf


# %%
# +++ Re-inserting LSTM Definitions +++
class LSTMRegression(nn.Module):
    """Stacked LSTM followed by a linear read-out of the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of regression outputs.
        dropout: Dropout probability passed to ``nn.LSTM``; forced to 0 when
            ``num_layers`` is 1.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM applies dropout only between stacked layers, so a non-zero
        # value is meaningless (and triggers a warning) for a single layer.
        lstm_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of sequences to one prediction per sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the LSTM
            output at the final time step.
        """
        # Leaving the initial state unspecified makes nn.LSTM start from zero
        # hidden/cell states that already match x's device and dtype.  The
        # previous hand-built float32 zeros broke models cast to another
        # dtype (e.g. ``.double()``) and were allocated on CPU first.
        out, _ = self.lstm(x)
        # Decode the hidden state of the last time step
        return self.fc(out[:, -1, :])

def train_lstm_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=5):
    """Train ``model`` with optional validation-based early stopping.

    Each epoch runs one pass over ``train_loader``.  When ``val_loader`` is
    provided and its dataset is non-empty, the model is also evaluated on it;
    the weights from the epoch with the lowest validation loss are snapshotted
    and restored before returning, and training stops once ``patience``
    consecutive epochs fail to improve.  Without validation data the model is
    returned with its last-epoch weights.

    Args:
        model: Module called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting ``len``.
        val_loader: Same as ``train_loader`` for validation, or ``None``.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device that each batch is moved to.
        num_epochs: Maximum number of training epochs.
        patience: Number of consecutive non-improving validation epochs that
            triggers early stopping.

    Returns:
        Tuple ``(model, train_losses, val_losses)`` where the loss lists hold
        the per-epoch mean batch loss (``val_losses`` is empty when no
        validation was run).
    """
    import copy  # local import: only needed to snapshot the best weights

    print(f"Starting LSTM training for {num_epochs} epochs on {device}...")
    has_validation = bool(val_loader) and len(val_loader.dataset) > 0
    best_val_loss = float('inf')
    epochs_no_improve = 0
    train_losses = []
    val_losses = []
    # Filled in only when a validation epoch sets a new best; stays None when
    # there is no validation so the final weights are kept as-is.
    best_model_state = None
    best_epoch = 0

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)  # Add sequence dimension if needed
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if not has_validation:
            # If no validation, just print training loss
            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}")
            continue

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                if inputs.dim() == 2:
                    inputs = inputs.unsqueeze(1)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_val_loss += loss.item()
        avg_val_loss = running_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

        # Check for improvement and early stopping based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() hands back references to the live parameter
            # tensors, which later optimizer steps overwrite in place; a deep
            # copy is required to actually freeze this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1
            print(f"  New best validation loss: {best_val_loss:.6f}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    print("Finished LSTM Training.")
    # Restore the best validation weights, if any were recorded
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state from epoch {best_epoch}.")

    return model, train_losses, val_losses
# +++ End Re-inserted LSTM Definitions +++


def create_sequences(data_x, data_y, seq_len=10):
    """Build sliding windows (stride 1) over ``data_x`` with aligned targets.

    Window ``k`` covers rows ``k .. k + seq_len - 1`` of ``data_x`` and is
    paired with ``data_y[k + seq_len - 1]``, i.e. the target at the window's
    final step.  Returns the stacked windows and targets as NumPy arrays; if
    ``data_x`` has fewer than ``seq_len`` rows both arrays are empty.
    """
    n_windows = len(data_x) - seq_len + 1
    starts = range(n_windows)
    windows = [data_x[start:start + seq_len] for start in starts]
    # Target is the last velocity in the window
    targets = [data_y[start + seq_len - 1] for start in starts]
    return np.array(windows), np.array(targets)


def run_within_session_comparison(num_sessions=10, use_online_snn=False):
    """
    Performs a more robust, within-session decoder comparison.

    This function iterates through a number of recording sessions (.mat files).
    For each session, it performs a standard chronological train/test split
    (e.g., first 80% for training, last 20% for testing). This avoids the
    confound of trying to decode across different days with mismatched neurons.
    
    Args:
        num_sessions: Number of sessions to test
        use_online_snn: If True, runs online SNN training instead of windowed training
    """
    mode_str = "Online" if use_online_snn else "Windowed"
    print(f"\n===== Running Within-Session Comparison ({num_sessions} Sessions, {mode_str} SNN) =====")
    basepath = 'zenodo_dataset'
    # basepath = os.path.expanduser('~/scratch/zenodo_dataset')
    neurons_to_use = 96  # Limit number of neurons
    snn_epochs = 20       # Number of passes over train data for SNN
    snn_patience = 10    # Early stopping patience for SNN
    lstm_epochs = 50     # Number of epochs for LSTM
    lstm_patience = 10   # Early stopping patience for LSTM

    all_results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        all_mat_files = [f for f in os.listdir(basepath) if f.endswith('.mat')]
        if len(all_mat_files) < num_sessions:
            print(f"Error: Found only {len(all_mat_files)} files in {basepath}, need at least {num_sessions}.")
            return
        selected_files = random.sample(all_mat_files, num_sessions)
        # selected_files = ['indy_20161212_02.mat']
    except FileNotFoundError:
        print(f"Error: Basepath not found: {basepath}")
        return
    except Exception as e:
        print(f"Error listing files: {e}")
        return

    for session_idx, filename in enumerate(selected_files):
        print(f"\n--- Session {session_idx + 1}/{num_sessions} ({filename}) ---")
        
        # --- 1. Set fixed data processing parameters based on best practices ---
        print("  Using fixed parameters: Bin Width=0.050s, Stride=0.050s (no overlap)")
        best_params = {
            'bin_width': 0.050, # 50ms bin
            'lag_s': 0.100,     # Default causal lag, will be tuned per session
            'stride_s': 0.050   # No overlap in data loading
        }
        
        # --- 2. Load Data with Optimal Parameters ---
        try:
            data = load_data(
                mat_file_path=os.path.join(basepath, filename),
                bin_width_s=best_params['bin_width'],
                stride_s=best_params['stride_s'], # Use non-overlapping windows
                spike_processing='count',
                test_split_ratio=0, # Load full session, split manually
                verbose=False
            )
            if 'error' in data:
                print(f"  Skipping {filename} due to load error: {data['error']}")
                continue
            
            X_full = data.get('X_processed_full').numpy()
            y_full = data.get('y_processed_full').numpy()
            
            if X_full.shape[1] > neurons_to_use:
                X_full = X_full[:, :neurons_to_use]
            
            input_size = X_full.shape[1]
            
        except Exception as e:
            print(f"  Error loading session file {filename}: {e}")
            continue

        # --- 3a. Fixed Zero-Lag Setting ---
        best_params['lag_s'] = 0.0
        print("  Using zero causal lag (0 ms) - skipping lag search.")

        # --- 3b. Preprocessing & Alignment (with best lag) ---
        lag_s = best_params['lag_s']
        bin_width_s = best_params['bin_width']
        lag_bins = int(round(lag_s / bin_width_s))
        
        X_lagged_counts = X_full[:-lag_bins] if lag_bins > 0 else X_full
        y_lagged = y_full[lag_bins:] if lag_bins > 0 else y_full
        
        # b) For binary input we keep the 0/1 spike indicator - no division by bin width
        X_rate_lagged = X_lagged_counts.astype(np.float32)

        # c) Manual chronological train/val/test split on all data types (70/15/15)
        n_total = X_rate_lagged.shape[0]
        n_train = int(n_total * 0.70)
        n_val   = int(n_total * 0.15)
        # Remainder goes to test
        n_test  = n_total - n_train - n_val

        X_train_rate = X_rate_lagged[:n_train]
        y_train_raw  = y_lagged[:n_train]

        X_val_rate   = X_rate_lagged[n_train:n_train + n_val]
        y_val_raw    = y_lagged[n_train:n_train + n_val]

        X_test_rate  = X_rate_lagged[n_train + n_val:]
        y_test_raw   = y_lagged[n_train + n_val:]

        # d) Do NOT normalise spikes - use raw spike counts per 50 ms bin
        print("  Using raw spike counts (no normalisation)...")
        spike_scaler = None  # Placeholder so later code can check if needed
        X_train_norm = X_train_rate.astype(np.float32)
        X_val_norm   = X_val_rate.astype(np.float32)
        X_test_norm  = X_test_rate.astype(np.float32)

        # --- Prepare NORMALISED spike RATES (Hz) for KF & LSTM ---
        bin_width = bin_width_s  #  e.g., 0.050 s
        X_train_rate_hz = X_train_rate.astype(np.float32) / bin_width
        X_val_rate_hz   = X_val_rate.astype(np.float32)   / bin_width
        X_test_rate_hz  = X_test_rate.astype(np.float32)  / bin_width

        spike_rate_scaler = StandardScaler().fit(X_train_rate_hz)
        X_train_rates_norm = spike_rate_scaler.transform(X_train_rate_hz)
        X_val_rates_norm   = spike_rate_scaler.transform(X_val_rate_hz)
        X_test_rates_norm  = spike_rate_scaler.transform(X_test_rate_hz)

        # --- Sanity check: show a few raw spike-count rows before batching ---
        print("    Sample spike-count matrix (first 4 bins, first 12 neurons):")
        print(X_train_rate[:4, :12].astype(int))
        print("    Counts min / max in this session:", int(X_train_rate.min()), int(X_train_rate.max()))
        # histogram of counts over ALL bins / neurons (0-6 spikes)
        hist = np.bincount(X_train_rate.astype(int).ravel(), minlength=7)[:7]
        print("Global count histogram 0-6:", hist)

        # Show a quick sanity sample of the rate-normalised matrix used by KF / LSTM
        print("    Sample rate-NORMALISED matrix (first 4 bins, first 12 neurons):")
        print(np.round(X_train_rates_norm[:4, :12], 3))

        # e) Normalize velocity targets (for all models, fit on train)
        velocity_scaler = StandardScaler().fit(y_train_raw)
        y_train_norm = velocity_scaler.transform(y_train_raw)
        y_val_norm   = velocity_scaler.transform(y_val_raw)
        y_test_norm  = velocity_scaler.transform(y_test_raw)
        
        # DEBUG: Check for problematic dataset characteristics
        vel_std_raw = np.std(y_train_raw, axis=0)
        vel_mean_raw = np.mean(y_train_raw, axis=0)
        vel_max_raw = np.max(np.abs(y_train_raw), axis=0)
        spike_std = np.std(X_train_norm, axis=0).mean()
        spike_max = np.max(X_train_norm)
        
        print(f"  Dataset diagnostics for {filename}:")
        print(f"    Velocity std: X={vel_std_raw[0]:.4f}, Y={vel_std_raw[1]:.4f}")
        print(f"    Velocity max: X={vel_max_raw[0]:.4f}, Y={vel_max_raw[1]:.4f}")
        print(f"    Spike mean std: {spike_std:.4f}, max: {spike_max:.1f}")
        print(f"    Train samples: {len(y_train_raw)}, Val: {len(y_val_raw)}, Test: {len(y_test_raw)}")
        
        if vel_std_raw[0] < 0.1 or vel_std_raw[1] < 0.1:
            print(f"    WARNING: Very low velocity variance detected!")
        if spike_max < 0.1:
            print(f"    WARNING: Very low spike activity detected!")

        # --- 4. Run Models ---
        session_results = {'session': filename}

        # --- SNN ---
        training_mode = "Online" if use_online_snn else "Windowed"
        print(f"\n  --- Running SNN {training_mode} Training ---")
        try:
            snn_val_split_ratio = 0.15
            n_snn_train_full = X_train_norm.shape[0] # Use normalized rate data
            n_snn_val = int(n_snn_train_full * snn_val_split_ratio)
            n_snn_train_split = n_snn_train_full - n_snn_val
            
            X_train_snn_final = X_train_norm[:n_snn_train_split]
            y_train_snn_norm_final = y_train_norm[:n_snn_train_split]
            X_val_snn = X_train_norm[n_snn_train_split:]
            y_val_snn_norm = y_train_norm[n_snn_train_split:]

            # --- SETUP FOR HARDWARE-FRIENDLY UPDATER ---
            # 1. Create the reciprocal square root LUT with no flat tail
            vals = torch.linspace(0.25, 8.0, 256, device=device)
            inv_sqrt_LUT = 1.0 / torch.sqrt(vals)
            # Sanity check: ensure LUT is monotonically decreasing
            assert (inv_sqrt_LUT[1:] < inv_sqrt_LUT[:-1]).all()
            
            snn_model = SNNRegression(input_size=input_size, hidden_size=256, output_size=2).to(device)
            
            # Define common parameters for both training modes
            seq_len = 10
            batch_size = 32
            overlap_stride = max(1, seq_len // 2)  # 50% overlap
            
            if use_online_snn:
                # Adaptive learning rate based on dataset characteristics
                spike_activity_factor = min(spike_max / 2.0, 1.0)  # Scale by max spike activity
                velocity_variance_factor = min((vel_std_raw[0] + vel_std_raw[1]) / 100.0, 2.0)  # Scale by velocity variance
                adaptive_lr_scale = max(spike_activity_factor * velocity_variance_factor, 0.1)
                
                # Use 2e-3 as base instead of 3e-4 (learned from dead neuron recovery)
                base_online_lr = 2e-3
                base_fast_lr_adaptive = base_online_lr * adaptive_lr_scale
                print(f"    Adaptive LR scaling: spike_factor={spike_activity_factor:.3f}, vel_factor={velocity_variance_factor:.3f}")
                print(f"    Base fast LR: {base_online_lr:.1e} → {base_fast_lr_adaptive:.1e}")
                
                # Online SNN training
                snn_updater = TwoScaleMetaRLWeightUpdaterFull(
                    snn_model, 
                    base_fast_lr=base_fast_lr_adaptive,  # Adaptive LR based on dataset
                    base_slow_lr=2e-4,  # Keep slow LR fixed
                    window_size=50, 
                    meta_lr=0.01,
                    online_mode=True,   # Enable online mode
                    inv_sqrt_LUT=inv_sqrt_LUT
                )
                
                # Create online data streams (reset every 200 timesteps to simulate trials)
                train_stream = OnlineDataStream(X_train_snn_final, y_train_snn_norm_final, reset_every_n=200)
                val_stream = OnlineDataStream(X_val_snn, y_val_snn_norm, reset_every_n=200)
                
                # For online training, patience should be based on evaluation intervals (steps)
                # eval_every_n=500, so patience=20 means 20*500=10,000 steps without improvement
                online_patience = 20  # More generous for online mode since we evaluate more frequently
                
                trained_snn_model = train_snn_online(
                    snn_model, snn_updater, train_stream, val_stream, 
                    num_epochs=snn_epochs,
                    bin_width_s=bin_width_s,
                    correlation_csv_path=f"online_correlation_{filename.replace('.mat', '.csv')}",
                    patience=online_patience, eval_every_n=500, reset_states_every_n=200
                )
            else:
                # Windowed SNN training (original)
                # Create sequence datasets (50% overlap for more training samples)
                try:
                    train_loader = CausalBatcher(
                        X_train_snn_final, y_train_snn_norm_final,
                        batch_size=batch_size, sequence_length=seq_len,
                        shuffle=True, stride=overlap_stride)
                    val_loader = CausalBatcher(
                        X_val_snn, y_val_snn_norm,
                        batch_size=batch_size, sequence_length=seq_len,
                        shuffle=False, stride=overlap_stride)
                except ValueError as e:
                    print(f"  Could not create data loaders, skipping SNN for this session: {e}")
                    session_results['snn_corr_x'] = None
                    session_results['snn_corr_y'] = None
                    continue # Skip to the next model in the session

                snn_updater = TwoScaleMetaRLWeightUpdaterFull(
                    snn_model, 
                    base_fast_lr=3e-3, 
                    base_slow_lr=1e-3, 
                    window_size=50, 
                    meta_lr=0.01,
                    online_mode=False,  # Windowed mode
                    inv_sqrt_LUT=inv_sqrt_LUT
                )
                # Tweak learning rate for stability with new updates
                snn_updater.fast_lr *= 0.7

                trained_snn_model = train_snn_windowed(
                    snn_model, snn_updater, train_loader, val_loader, num_epochs=snn_epochs, patience=snn_patience
                )
            

            print("  Re-evaluating SNN with movement mask for fair comparison...")
            trained_snn_model.eval()
            
            try:
                # Use non-overlapping windows for testing
                test_loader = CausalBatcher(
                    X_test_norm, y_test_norm,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=False, stride=None)
            except ValueError as e:
                print(f"  Could not create test loader, skipping SNN evaluation: {e}")
                session_results['snn_corr_x'] = None
                session_results['snn_corr_y'] = None
                continue
            
            snn_preds_norm_list, snn_targets_norm_list = [], []
            
            with torch.no_grad():
                for input_seq, target_seq in test_loader:
                    input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                    batch_size_test = input_seq.size(0)
                    test_spk1_rec = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device) # RESTORED
                    test_mem1 = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                    test_mem2 = torch.zeros(batch_size_test, trained_snn_model.fc2.out_features, device=device)
                    test_mem3 = torch.zeros(batch_size_test, trained_snn_model.fc3.out_features, device=device)
                    
                    pred_seq, _ = trained_snn_model(input_seq, test_spk1_rec, test_mem1, test_mem2, test_mem3)
                    # Keep only the last step to align with read-out training
                    snn_preds_norm_list.append(pred_seq[:, -1, :].cpu())
                    snn_targets_norm_list.append(target_seq[:, -1, :].cpu())

            snn_preds_normalized = torch.cat(snn_preds_norm_list).numpy()
            snn_targets_normalized = torch.cat(snn_targets_norm_list).numpy()
            
            # --- Option (c): apply an 8 Hz low-pass filter to both signals ---
            fs_lp = 1.0 / data['bin_width_s']  # sampling frequency in Hz (e.g. 20 Hz for 50 ms bins)
            snn_preds_normalized   = butter_lowpass(snn_preds_normalized,   fs=fs_lp, cutoff=8, order=4)
            snn_targets_normalized = butter_lowpass(snn_targets_normalized, fs=fs_lp, cutoff=8, order=4)

            # ----- Compute metrics on NORMALIZED data (robust to gain/offset) -----
            vel_norm_snn = np.linalg.norm(snn_targets_normalized, axis=1)
            if EVAL_MOVEMENT_ONLY:
                move_mask_snn = vel_norm_snn > MOVEMENT_THRESHOLD_NORM
            else:
                # Full-trajectory evaluation: keep every bin
                move_mask_snn = np.ones_like(vel_norm_snn, dtype=bool)
            if np.sum(move_mask_snn) > 10:
                snn_corr_x = compute_correlation(snn_targets_normalized[move_mask_snn, 0], snn_preds_normalized[move_mask_snn, 0])
                snn_corr_y = compute_correlation(snn_targets_normalized[move_mask_snn, 1], snn_preds_normalized[move_mask_snn, 1])
                snn_r2_x   = compute_fvaf(  snn_targets_normalized[move_mask_snn, 0], snn_preds_normalized[move_mask_snn, 0])
                snn_r2_y   = compute_fvaf(  snn_targets_normalized[move_mask_snn, 1], snn_preds_normalized[move_mask_snn, 1])
            else:
                snn_corr_x = snn_corr_y = snn_r2_x = snn_r2_y = 0

            # ---- Denormalise ONLY for plots ----
            snn_preds_denorm   = velocity_scaler.inverse_transform(snn_preds_normalized)
            snn_targets_denorm = velocity_scaler.inverse_transform(snn_targets_normalized)

       
            # Recompute movement mask in *denormalised* units for plotting only
            vel_norm_snn = np.linalg.norm(snn_targets_denorm, axis=1)
            if EVAL_MOVEMENT_ONLY:
                move_mask_snn = vel_norm_snn > MOVEMENT_THRESHOLD_NORM
            else:
                # Full-trajectory evaluation: keep every bin
                move_mask_snn = np.ones_like(vel_norm_snn, dtype=bool)
            # Do NOT overwrite snn_corr_x / snn_r2_x etc.; those are from the
            # earlier evaluation step.

            if np.sum(move_mask_snn) > 10:
                snn_corr_x = compute_correlation(snn_targets_denorm[move_mask_snn, 0], snn_preds_denorm[move_mask_snn, 0])
                snn_corr_y = compute_correlation(snn_targets_denorm[move_mask_snn, 1], snn_preds_denorm[move_mask_snn, 1])
                # Compute R^2 (FVAF) for the same movement-only bins
                snn_r2_x   = compute_fvaf(  snn_targets_denorm[move_mask_snn, 0], snn_preds_denorm[move_mask_snn, 0])
                snn_r2_y   = compute_fvaf(  snn_targets_denorm[move_mask_snn, 1], snn_preds_denorm[move_mask_snn, 1])
            else:
                snn_corr_x = snn_corr_y = snn_r2_x = snn_r2_y = 0
                snn_corr_x_norm, snn_corr_y_norm = 0, 0
                snn_r2_x_norm,  snn_r2_y_norm  = 0, 0

            # Store a copy so they aren't overwritten later
            snn_corr_x_norm, snn_corr_y_norm = snn_corr_x, snn_corr_y
            snn_r2_x_norm,  snn_r2_y_norm  = snn_r2_x,  snn_r2_y

            session_results['snn_corr_x'] = snn_corr_x_norm
            session_results['snn_corr_y'] = snn_corr_y_norm
            session_results['snn_r2_x']   = snn_r2_x_norm
            session_results['snn_r2_y']   = snn_r2_y_norm

            eval_label = "Movement Only" if EVAL_MOVEMENT_ONLY else "All Bins"
            print(f"  SNN Test Corr ({eval_label}): X={snn_corr_x_norm:.4f}, Y={snn_corr_y_norm:.4f}")
            print(f"  SNN Test  R²  ({eval_label}): X={snn_r2_x_norm:.4f}, Y={snn_r2_y_norm:.4f}")
            
            
            # -------------------------------------------------
            # 0. Parameters
            # -------------------------------------------------
            bins_per_sec    = int(round(1.0 / data['bin_width_s']))   # ≈15-16 bins/s
            seconds_to_plot = 30
            subset_len      = seconds_to_plot * bins_per_sec
            
            # -------------------------------------------------
            # 1. Choose a contiguous 30-s window that includes movement
            # -------------------------------------------------
            move_idx = np.where(move_mask_snn)[0]                     # indices where |v| > 1 cm/s
            
            if len(move_idx) == 0:
                raise RuntimeError("No movement bins found (‖v‖ > 1 cm/s).")
            
            # Pick the first movement bin that leaves enough room for a full window;
            # otherwise fall back to the last possible window.
            start_bin = 0
            for idx in move_idx:
                if idx + subset_len < len(snn_targets_denorm):
                    start_bin = idx
                    break
            else:
                start_bin = max(0, len(snn_targets_denorm) - subset_len)
            
            end_bin = start_bin + subset_len
            
            # -------------------------------------------------
            # 2. Slice data
            # -------------------------------------------------
            gt_slice   = snn_targets_denorm[start_bin:end_bin]
            pred_slice = snn_preds_denorm[start_bin:end_bin]
            hold_slice = ~move_mask_snn[start_bin:end_bin]
            t          = np.arange(start_bin, end_bin) * data['bin_width_s']
            
            # Ensure all vectors are the same length to avoid plotting errors
            min_len = min(len(gt_slice), len(pred_slice), len(t))
            gt_slice   = gt_slice[:min_len]
            pred_slice = pred_slice[:min_len]
            hold_slice = hold_slice[:min_len]
            t          = t[:min_len]
            
            # -------------------------------------------------
            # 3. Plot with hold periods shaded in gray
            # -------------------------------------------------
            import matplotlib.pyplot as plt
            
            fig, ax = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
            
            labels = ['v_x', 'v_y']
            for i in range(2):
                ax[i].plot(t, gt_slice[:, i],  label=f'True {labels[i]}', color='royalblue', linewidth=2)
                ax[i].plot(t, pred_slice[:, i], label=f'Pred {labels[i]}', color='firebrick', linewidth=1.6, alpha=0.8)
            
                # Shade hold periods
                ax[i].fill_between(t, -1e3, 1e3, where=hold_slice,
                                   facecolor='lightgray', alpha=0.25, step='mid', label='Hold')
            
                ax[i].set_ylabel(f'{labels[i]} (cm/s)')
                ax[i].grid(True)
                ax[i].legend(loc='upper right')
            
            ax[1].set_xlabel('Time (s)')
            
            fig.suptitle(f'SNN velocity decoding - {seconds_to_plot} s window starting at bin {start_bin}')
            fig.tight_layout(rect=[0, 0, 1, 0.95])
            
            # -------------------------------------------------
            # 4. Save the figure
            # -------------------------------------------------
            out_path = f'snn_velocity_timeseries_slice_{filename}.png'
            fig.savefig(out_path, dpi=300)
            print(f'Time-series velocity figure saved --> {out_path}')

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"  ERROR during SNN run: {e}")
            session_results['snn_corr_x'] = None
            session_results['snn_corr_y'] = None

        # --- Kalman Filter ---
        # session_results['kf_corr_x'] = None
        # # session_results['kf_corr_y'] = None
        # print("\n  --- Running Kalman Filter ---")
        # try:
        #     # KF uses NORMALIZED rates but UNNORMALIZED velocities for training H
        #     kf_pred, _ = train_test_kalman_filter(
        #         X_train_rates_norm, y_train_raw,
        #         X_test_rates_norm, y_test_raw
        #     )
        #     # Mask out hold periods for evaluation
        #     vel_norm = np.linalg.norm(y_test_raw, axis=1)
        #     move_mask = vel_norm > 1.0 # 1 cm/s threshold
        #     if np.sum(move_mask) > 10:
        #         kf_corr_x = compute_correlation(y_test_raw[move_mask, 0], kf_pred[move_mask, 0])
        #         kf_corr_y = compute_correlation(y_test_raw[move_mask, 1], kf_pred[move_mask, 1])
        #     else:
        #         kf_corr_x, kf_corr_y = 0, 0

        #     # Compute R² for completeness
        #     if np.sum(move_mask) > 10:
        #         kf_r2_x = compute_fvaf(y_test_raw[move_mask, 0], kf_pred[move_mask, 0])
        #         kf_r2_y = compute_fvaf(y_test_raw[move_mask, 1], kf_pred[move_mask, 1])
        #     else:
        #         kf_r2_x = kf_r2_y = 0

        #     session_results['kf_corr_x'] = kf_corr_x
        #     session_results['kf_corr_y'] = kf_corr_y
        #     session_results['kf_r2_x']   = kf_r2_x
        #     session_results['kf_r2_y']   = kf_r2_y
        #     print(f"  KF Test Corr (Movement Only): X={kf_corr_x:.4f}, Y={kf_corr_y:.4f}")
        # except Exception as e:
        #     print(f"  ERROR during KF run: {e}")
        #     session_results['kf_corr_x'] = None
        #     session_results['kf_corr_y'] = None

        # --- LSTM (ENABLED) ---
        # print("\n  --- Running LSTM ---")
        # try:
        #     seq_len = 10
        #     # Use the explicit train / val splits prepared earlier (rate-normalised)
        #     X_train_lstm_norm = X_train_rates_norm
        #     X_val_lstm_norm   = X_val_rates_norm
        #     y_train_lstm_norm = y_train_norm
        #     y_val_lstm_norm   = y_val_norm

        #     X_train_seq, y_train_seq = create_sequences(X_train_lstm_norm, y_train_lstm_norm, seq_len)
        #     X_val_seq,   y_val_seq   = create_sequences(X_val_lstm_norm,   y_val_lstm_norm,   seq_len)
        #     X_test_seq,  y_test_seq  = create_sequences(X_test_rates_norm,      y_test_norm,       seq_len)

        #     lstm_train_dataset = TensorDataset(torch.tensor(X_train_seq, dtype=torch.float32), torch.tensor(y_train_seq, dtype=torch.float32))
        #     lstm_val_dataset   = TensorDataset(torch.tensor(X_val_seq,   dtype=torch.float32), torch.tensor(y_val_seq,   dtype=torch.float32))
        #     lstm_test_dataset  = TensorDataset(torch.tensor(X_test_seq,  dtype=torch.float32), torch.tensor(y_test_seq,  dtype=torch.float32))

        #     batch_size_lstm = 64
        #     lstm_train_loader = DataLoader(lstm_train_dataset, batch_size=batch_size_lstm, shuffle=True)
        #     lstm_val_loader   = DataLoader(lstm_val_dataset,   batch_size=batch_size_lstm, shuffle=False)
        #     lstm_test_loader  = DataLoader(lstm_test_dataset,  batch_size=batch_size_lstm, shuffle=False)

        #     lstm_model = LSTMRegression(input_size=input_size, hidden_size=256, num_layers=2, output_size=2, dropout=0.3).to(device)
        #     criterion_lstm = nn.MSELoss()
        #     optimizer_lstm = torch.optim.Adam(lstm_model.parameters(), lr=1e-3, weight_decay=1e-5)

        #     trained_lstm_model, _, _ = train_lstm_model(
        #         lstm_model, lstm_train_loader, lstm_val_loader, criterion_lstm,
        #         optimizer_lstm, device, num_epochs=lstm_epochs, patience=lstm_patience
        #     )

        #     trained_lstm_model.eval()
        #     lstm_preds_norm = []
        #     with torch.no_grad():
        #         for inputs, _ in lstm_test_loader:
        #             inputs = inputs.to(device)
        #             outputs = trained_lstm_model(inputs)
        #             lstm_preds_norm.append(outputs.cpu().numpy())

        #     if lstm_preds_norm:
        #         lstm_preds_normalized = np.concatenate(lstm_preds_norm, axis=0)
        #         lstm_preds_denorm = velocity_scaler.inverse_transform(lstm_preds_normalized)

        #         true_y_for_eval = y_test_raw[seq_len-1:][:len(lstm_preds_denorm)]
        #         vel_norm_lstm = np.linalg.norm(true_y_for_eval, axis=1)
        #         move_mask_lstm = vel_norm_lstm > 1.0  # cm/s threshold
        #         if np.sum(move_mask_lstm) > 10:
        #             lstm_corr_x = compute_correlation(true_y_for_eval[move_mask_lstm, 0], lstm_preds_denorm[move_mask_lstm, 0])
        #             lstm_corr_y = compute_correlation(true_y_for_eval[move_mask_lstm, 1], lstm_preds_denorm[move_mask_lstm, 1])
        #         else:
        #             lstm_corr_x = lstm_corr_y = 0
        #     else:
        #         lstm_corr_x = lstm_corr_y = 0

        #     lstm_r2_x = compute_fvaf(true_y_for_eval[:,0], lstm_preds_denorm[:,0]) if lstm_preds_norm else 0
        #     lstm_r2_y = compute_fvaf(true_y_for_eval[:,1], lstm_preds_denorm[:,1]) if lstm_preds_norm else 0

        #     session_results['lstm_corr_x'] = lstm_corr_x
        #     session_results['lstm_corr_y'] = lstm_corr_y
        #     session_results['lstm_r2_x']   = lstm_r2_x
        #     session_results['lstm_r2_y']   = lstm_r2_y
        #     print(f"  LSTM Test Corr (Movement Only): X={lstm_corr_x:.4f}, Y={lstm_corr_y:.4f}")
        # except Exception as e:
        #     print(f"  ERROR during LSTM run: {e}")
        #     session_results['lstm_corr_x'] = None
        #     session_results['lstm_corr_y'] = None

        all_results.append(session_results)
        print(f"--- Finished Session {session_idx + 1} ---")

    # --- 4. Aggregate and Report Results ---
    print("\n===== Within-Session Comparison Final Results =====")
    if not all_results:
        print("No sessions completed successfully.")
        return
        
    results_df = pd.DataFrame(all_results)
    results_df.to_csv("results_within_session_comparison.csv", index=False)
    print("Full results saved to results_within_session_comparison.csv")

    print("\n--- Average Test Corr & R² --- ")
    avg_results = {}
    for model_key in ['snn', 'kf', 'lstm']:
        x_key = f'{model_key}_corr_x'
        y_key = f'{model_key}_corr_y'
        x_r2_key = f'{model_key}_r2_x'
        y_r2_key = f'{model_key}_r2_y'
        if (x_key in results_df.columns and y_key in results_df.columns and
            x_r2_key in results_df.columns and y_r2_key in results_df.columns and
            results_df[x_key].notna().any()):
             mean_x = results_df[x_key].mean(); std_x = results_df[x_key].std()
             mean_y = results_df[y_key].mean(); std_y = results_df[y_key].std()
             mean_r2x = results_df[x_r2_key].mean(); std_r2x = results_df[x_r2_key].std()
             mean_r2y = results_df[y_r2_key].mean(); std_r2y = results_df[y_r2_key].std()
             print(f"Model: {model_key.upper()}")
             print(f"  X-Corr: Mean={mean_x:.4f}, Std={std_x:.4f}")
             print(f"  Y-Corr: Mean={mean_y:.4f}, Std={std_y:.4f}")
             print(f"  X-R² : Mean={mean_r2x:.4f}, Std={std_r2x:.4f}")
             print(f"  Y-R² : Mean={mean_r2y:.4f}, Std={std_r2y:.4f}")
             avg_results[f'{model_key}_x'] = mean_x
             avg_results[f'{model_key}_y'] = mean_y
        else:
             print(f"Model: {model_key.upper()} - Results missing or all NaN.")

    # --- Plotting Comparison ---
    try:
        labels = []
        means = []
        if 'snn_x' in avg_results: labels.extend(['SNN_X', 'SNN_Y']); means.extend([avg_results['snn_x'], avg_results['snn_y']])
        if 'kf_x' in avg_results: labels.extend(['KF_X', 'KF_Y']); means.extend([avg_results['kf_x'], avg_results['kf_y']])
        if 'lstm_x' in avg_results: labels.extend(['LSTM_X', 'LSTM_Y']); means.extend([avg_results['lstm_x'], avg_results['lstm_y']])
        
        if means:
            x = np.arange(len(labels))
            width = 0.6
            fig, ax = plt.subplots(figsize=(10, 6))
            rects = ax.bar(x, means, width, label='Mean Corr')
            ax.set_ylabel('Corr')
            ax.set_title(f'Average Model Performance Across {len(all_results)} Sessions')
            ax.set_xticks(x)
            ax.set_xticklabels(labels)
            ax.legend()
            ax.bar_label(rects, padding=3, fmt='%.3f')
            ax.set_ylim(min(0, min(means) * 1.15 if means else 0), max(means) * 1.15 if means else 1)
            fig.tight_layout()
            plt.savefig("within_session_comparison_plot.png")
            plt.show()
            print("Comparison plot saved to within_session_comparison_plot.png")
        else:
            print("Skipping plot generation due to missing results.")
    except Exception as e:
        print(f"Error generating comparison plot: {e}")

    print("===== Experiment Complete =====")


# Update the main execution block
if __name__ == "__main__":
    # Script entry point: run the within-session comparison once per enabled
    # SNN training mode, then summarise the SNN test correlations from the
    # CSV that the run writes.
    #
    # NOTE: only the WINDOWED mode is enabled in the mode list below; the
    # banner and closing notes still mention ONLINE mode because the loop is
    # meant to be extended with ("ONLINE", True).
    import traceback  # local import: only needed for error reporting here

    results_csv = "results_within_session_comparison.csv"
    corr_cols = ['snn_corr_x', 'snn_corr_y']
    banner = "=" * 80

    print(banner)
    print("ONLINE vs WINDOWED SNN COMPARISON")
    print(banner)

    # Test with a smaller number of sessions for demonstration
    test_sessions = 1
    seed = 42

    for mode_name, use_online in [("WINDOWED", False)]:
        print(f"\n{'='*50}")
        print(f"Testing {mode_name} SNN Training")
        print(f"{'='*50}")

        # Re-seed every RNG before each mode so modes are compared from the
        # same starting state.
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        # Remember the state of the results file before the run. The run
        # returns without writing the CSV when no session completes, and a
        # file left over from a previous run must not be reported as the
        # outcome of this one.
        mtime_before = (
            os.stat(results_csv).st_mtime_ns if os.path.exists(results_csv) else None
        )

        try:
            run_within_session_comparison(num_sessions=test_sessions, use_online_snn=use_online)

            mtime_after = (
                os.stat(results_csv).st_mtime_ns if os.path.exists(results_csv) else None
            )
            if mtime_after is None or mtime_after == mtime_before:
                # File missing or untouched: this run produced no results.
                print(f"No valid results for {mode_name} mode")
                continue

            # Load and display results
            df = pd.read_csv(results_csv)
            has_cols = all(col in df.columns for col in corr_cols)
            # Require both columns and at least one non-NaN value in each;
            # this also rejects an empty frame.
            if has_cols and df[corr_cols].notna().any().all():
                mean_corr_x = df['snn_corr_x'].mean()
                mean_corr_y = df['snn_corr_y'].mean()
                print(f"\n{mode_name} SNN Results:")
                print(f"  Mean X Correlation: {mean_corr_x:.4f}")
                print(f"  Mean Y Correlation: {mean_corr_y:.4f}")
                print(f"  Average: {(mean_corr_x + mean_corr_y)/2:.4f}")
            else:
                print(f"No valid results for {mode_name} mode")
        except Exception as e:
            # Top-level boundary: keep going with the next mode, but do not
            # hide where the failure came from.
            print(f"Error running {mode_name} mode: {e}")
            traceback.print_exc()

    print(f"\n{banner}")
    print("COMPARISON COMPLETE")
    print("Key differences between modes:")
    print("- WINDOWED: Uses 10-timestep sequences, batch updates")
    print("- ONLINE: Uses single timesteps, persistent hidden states")
    print("- Online mode should show lower but more stable learning")
    print("- Online mode simulates real-time BCI operation")
    print(banner)
    
    # Original multi-seed run (commented out for now to focus on demo)
    # combined_dfs = []
    # for seed in range(10):
    #     print(f"\n##############################\n# Seed {seed} run\n##############################")
    #     np.random.seed(seed)
    #     random.seed(seed)
    #     torch.manual_seed(seed)
    #     run_within_session_comparison(num_sessions=10)

    #     try:
    #         df_seed = pd.read_csv("results_within_session_comparison.csv")
    #         df_seed['seed'] = seed
    #         corr_cols = [c for c in df_seed.columns if 'corr_' in c]
    #         df_seed = df_seed[['session', 'seed'] + corr_cols]
    #         combined_dfs.append(df_seed)
    #     except Exception as e:
    #         print(f"Could not load results for seed {seed}: {e}")

    # if combined_dfs:
    #     big_df = pd.concat(combined_dfs, ignore_index=True)
    #     big_df.to_csv("results_all_seeds_correlations.csv", index=False)
    #     print("Aggregated correlations saved to results_all_seeds_correlations.csv")
    # else:
    #     print("No aggregated results were produced.")
