# %%
import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
import warnings
import torch
import snntorch as snn
from snntorch import surrogate
import numpy as np
import torch
import h5py
from sklearn.preprocessing import StandardScaler
import warnings
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import json
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from filterpy.kalman import KalmanFilter
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from scipy.signal import butter, sosfiltfilt
from sklearn.linear_model import Ridge
import math
from scipy.signal import correlate, correlation_lags
import torch.nn.functional as F

# =============================================================
# Evaluation configuration
# -------------------------------------------------------------
# If EVAL_MOVEMENT_ONLY is True, metrics are computed only on bins
# where cursor speed exceeds MOVEMENT_THRESHOLD_NORM (a value in the
# *normalised* velocity space produced by StandardScaler).  If False,
# all time bins contribute to the reported correlation and R².
# NOTE(review): the code that reads these flags is not in this part of
# the file; the description above is the author's stated intent.
# -------------------------------------------------------------
EVAL_MOVEMENT_ONLY = False          # False = full-trajectory evaluation; True = movement bins only
MOVEMENT_THRESHOLD_NORM = 0.2      # stated as ≈ 2 cm s⁻¹ after inverse transform; depends on the fitted velocity scaler — TODO confirm
MOVEMENT_THRESHOLD_CM_S = 2.0      # threshold used directly on de-normalised velocities for plots
# =============================================================

def load_data(
    mat_file_path='zenodo_dataset/indy_20160419_01.mat',
    bin_width_s=0.064,   # 64 ms window to match Makin et al.
    stride_s=0.064,      # 64 ms stride for non-overlapping bins
    test_split_ratio=0.2,
    spike_processing='rate',
    normalize_spikes=False,
    verbose=True
):
    """Load one HDF5 ``.mat`` session and return binned spikes and cursor velocity.

    Processing order: read ``t``, ``spikes`` and ``cursor_pos``; rasterise
    spike times onto the sample grid of ``t``; low-pass filter the cursor
    position and differentiate it to velocity; bin both signals with a
    sliding window (zero lag between spikes and velocity); apply
    ``spike_processing``; split chronologically into train/test; fit
    ``StandardScaler`` objects on the training part only.

    Args:
        mat_file_path: path of the HDF5-format ``.mat`` file.
        bin_width_s: window length in seconds (converted to samples via the
            sampling rate estimated from ``t``).
        stride_s: hop between consecutive windows, in seconds.
        test_split_ratio: fraction of bins (taken from the end of the
            recording) used as the test set; ``0`` disables the split.
        spike_processing: ``'binary'``, ``'count'``, ``'rate'`` or
            ``'poisson'``.
        normalize_spikes: if True, min-max scale each unit over ALL bins
            before splitting (this uses test-set statistics).
        verbose: print the parameters and normalisation notice.

    Returns:
        dict with raw/normalised train and test tensors, bin timestamps, the
        raw binned arrays, the fitted scalers (or ``None``) and the binning
        parameters.  When the recording is too short for a single bin the
        dict additionally carries an ``'error'`` key and empty arrays.  When
        ``test_split_ratio == 0`` the un-split processed data is included too.

    Raises:
        ValueError: if ``spike_processing`` is unknown, or if ``bin_width_s``
            or ``stride_s`` is shorter than one sample.
    """
    if verbose:
        print("=== load_data parameters ===")
        print(f" mat_file_path    = {mat_file_path}")
        print(f" bin_width_s      = {bin_width_s}")
        print(f" stride_s         = {stride_s}")
        print("  (Processing is now zero-lag with non-overlapping bins)")
        print(f" test_split_ratio = {test_split_ratio}")
        print(f" spike_processing = {spike_processing}")
        print(f" normalize_spikes = {normalize_spikes}")
        print("============================\n")

    with h5py.File(mat_file_path, 'r') as f:
        # Time vector, forced to a column of shape (T, 1).
        t_vec = f['t'][()]
        if t_vec.ndim == 1:
            t_vec = t_vec[:, None]
        elif t_vec.shape[0] < t_vec.shape[1]:
            t_vec = t_vec.T
        fs = 1.0 / np.mean(np.diff(t_vec.squeeze()))

        ds = f['spikes']
        if h5py.check_dtype(ref=ds.dtype) is not None:
            # 'spikes' is an array of object references, one spike-time array each.
            refs = ds[()]
            if refs.ndim == 2 and refs.shape[1] != 1:
                # Only row 0 of the reference grid is read.
                n_units, get_ref = refs.shape[1], lambda i: refs[0,i]
            elif refs.ndim == 2 and refs.shape[0] != 1:
                n_units, get_ref = refs.shape[0], lambda i: refs[i,0]
            else:
                refs = refs.ravel()
                n_units, get_ref = refs.shape[0], lambda i: refs[i]
            T_orig = t_vec.shape[0]
            spikes = np.zeros((T_orig, n_units), dtype=np.float32)
            for i in range(n_units):
                r = get_ref(i)
                if isinstance(r, h5py.Reference):
                    # atleast_1d: a unit with a single spike squeezes to a 0-d array.
                    times = np.atleast_1d(f[r][()].squeeze())
                    idx = np.searchsorted(t_vec.squeeze(), times, 'left')
                    idx = idx[idx < T_orig]
                    # Raster is binary per sample: several spikes landing on the
                    # same sample are recorded as a single 1.0.
                    spikes[np.unique(idx), i] = 1.0
        else:
            # 'spikes' is already a dense array; orient it as (time, units).
            spikes = ds[()]
            if spikes.shape[0] < spikes.shape[1]:
                spikes = spikes.T
            spikes = spikes.astype(np.float32)

        cp = f['cursor_pos'][()]
        if cp.shape[0] < cp.shape[1]:
            cp = cp.T

        # --- Filter cursor position before calculating velocity ---
        # Per Makin et al. 2019, apply a 4th-order, 10Hz, non-causal (zero-phase)
        # Butterworth filter to the cursor position data before differentiation.
        sos = butter(4, 10, 'low', fs=fs, output='sos')
        cp_filtered = sosfiltfilt(sos, cp, axis=0)
        vel = np.gradient(cp_filtered, t_vec.squeeze(), axis=0)

    # --- Binning ---
    bin_width = int(round(bin_width_s * fs))
    stride = int(round(stride_s * fs))
    if bin_width < 1 or stride < 1:
        # A zero stride would divide by zero below; a zero-width bin would
        # produce all-NaN velocity means.
        raise ValueError(
            f"bin_width_s={bin_width_s} and stride_s={stride_s} must each cover at least "
            f"one sample at fs={fs:.3f} Hz (got bin_width={bin_width}, stride={stride} samples)."
        )
    T = spikes.shape[0]
    num_bins = (T - bin_width) // stride + 1

    # +++ Handle Case of Zero Bins +++
    if num_bins <= 0:
        warnings.warn(f"Not enough data points ({T}) for the given bin ({bin_width}) and stride ({stride}) in file {mat_file_path}. Returning empty data structure.", RuntimeWarning)
        # Return a dictionary indicating failure and providing empty structures where possible
        n_neurons = spikes.shape[1] # Get neuron count even if no bins
        return {
            'error': 'Insufficient data for binning',
            'X_binned_raw': np.empty((0, n_neurons), dtype=np.float32),
            'y_binned_raw': np.empty((0, vel.shape[1]), dtype=np.float32), # Use original vel shape for columns
            't_binned_raw': np.empty((0,), dtype=np.float64),
            # Include other keys expected by caller, filled with None or empty equivalents
            'X_train': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            't_train': np.empty((0,), dtype=np.float64),
            't_test':  np.empty((0,), dtype=np.float64),
            'X_train_norm': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train_norm': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test_norm':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test_norm':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'spike_scaler': None,
            'velocity_scaler': None,
            'bin_width_s': bin_width_s,
            'stride_s': stride_s
        }
    # +++ End Handle Case of Zero Bins +++

    # Windows are stacked as (num_bins, bin_width, channels) and reduced over axis 1.
    spike_windows = np.stack([spikes[i*stride : i*stride+bin_width] for i in range(num_bins)])
    spikes_b_raw = spike_windows.sum(axis=1).astype(np.float32) # Raw spike counts

    vel_windows = np.stack([vel[i*stride : i*stride+bin_width] for i in range(num_bins)])
    vel_b_raw = vel_windows.mean(axis=1).astype(np.float32)    # Raw mean velocity in bins

    # Each bin is stamped with the time of its FIRST sample.
    t_binned = np.array([t_vec[i*stride][0] for i in range(num_bins)])

    # --- Spike Processing (Applied After Binning) ---
    if spike_processing == 'binary':
        X_processed = (spikes_b_raw > 0).astype(np.float32)
    elif spike_processing == 'count':
        X_processed = spikes_b_raw # Already counts
    elif spike_processing == 'rate':
        X_processed = (spikes_b_raw / bin_width_s).astype(np.float32)
    elif spike_processing == 'poisson':
        # Resample each bin from a Poisson distribution whose mean is the bin's rate.
        lam = spikes_b_raw / bin_width_s
        X_processed = np.random.poisson(lam).astype(np.float32)
    else:
        raise ValueError(f"Unknown spike_processing: {spike_processing}")

    # --- Spike Normalization (Optional, applied AFTER processing) ---
    if normalize_spikes:
        # Note: Normalizing here before splitting might leak test info
        # It's generally better to normalize after splitting, fitting scaler only on train
        mn = X_processed.min(axis=0, keepdims=True)
        mx = X_processed.max(axis=0, keepdims=True)
        rng = np.where(mx - mn > 0, mx - mn, 1.0)  # constant units divide by 1 instead of 0
        X_processed = (X_processed - mn) / rng
        if verbose:
            print("Applied spike normalization (min-max scaling). Consider normalizing after splitting.")

    # --- Internal Train/Test Split (if ratio > 0) ---
    y_processed = vel_b_raw # Use the zero-lag velocity
    B_total = X_processed.shape[0]
    n_test = 0
    if test_split_ratio > 0:
        # At least one test bin, and never more than exist: without the upper
        # clamp a ratio > 1 makes n_train negative and the slices below wrap.
        n_test = min(max(int(B_total * test_split_ratio), 1), B_total)
    n_train = B_total - n_test

    # Chronological split: the test set is the tail of the recording.
    X_tr, X_te = X_processed[:n_train], X_processed[n_train:]
    y_tr, y_te = y_processed[:n_train], y_processed[n_train:]
    t_tr, t_te = t_binned[:n_train], t_binned[n_train:]

    # --- Fit Scalers (Only if train data exists) ---
    sx = None
    sy = None
    X_tr_norm = X_tr # Default to unnormalized if no scaling
    y_tr_norm = y_tr
    X_te_norm = X_te
    y_te_norm = y_te

    if n_train > 0:
        try:
            sx = StandardScaler().fit(X_tr)
            X_tr_norm = sx.transform(X_tr)
            if n_test > 0:
                X_te_norm = sx.transform(X_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform spike scaler: {e}", RuntimeWarning)
            sx = None # Ensure scaler is None if fitting failed

        try:
            sy = StandardScaler().fit(y_tr)
            y_tr_norm = sy.transform(y_tr)
            if n_test > 0:
                y_te_norm = sy.transform(y_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform velocity scaler: {e}", RuntimeWarning)
            sy = None # Ensure scaler is None if fitting failed
    # --- End Fit Scalers ---

    # --- Return Dictionary ---
    return_dict = {
        # Data after spike processing and the internal split
        'X_train': torch.tensor(X_tr, dtype=torch.float32),
        'y_train': torch.tensor(y_tr, dtype=torch.float32),
        'X_test':  torch.tensor(X_te, dtype=torch.float32),
        'y_test':  torch.tensor(y_te, dtype=torch.float32),
        't_train': t_tr,
        't_test':  t_te,
        # Same data after standardisation with scalers fit on the train split
        'X_train_norm': torch.tensor(X_tr_norm, dtype=torch.float32),
        'y_train_norm': torch.tensor(y_tr_norm, dtype=torch.float32),
        'X_test_norm':  torch.tensor(X_te_norm, dtype=torch.float32),
        'y_test_norm':  torch.tensor(y_te_norm, dtype=torch.float32),
        # Raw binned data (before spike processing) - Useful for aggregation
        'X_binned_raw': spikes_b_raw, # Raw counts
        'y_binned_raw': vel_b_raw,    # Raw velocity
        't_binned_raw': t_binned,     # Timestamps for raw binned data
        # Scalers fit on the internal training split
        'spike_scaler': sx,
        'velocity_scaler': sy,
        # Original parameters for reference
        'bin_width_s': bin_width_s,
        'stride_s': stride_s
    }

    # Add the full processed (but un-split) data if no split was done
    if test_split_ratio == 0:
        return_dict['X_processed_full'] = torch.tensor(X_processed, dtype=torch.float32)
        return_dict['y_processed_full'] = torch.tensor(y_processed, dtype=torch.float32)
        return_dict['t_processed_full'] = t_binned
        # Raw counts; no lag is applied in this function, the key name is kept
        # only so existing callers keep working.
        return_dict['X_raw_counts_lagged'] = spikes_b_raw

    return return_dict

# %%
import random
import time
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from scipy.interpolate import interp1d
import numpy as np
import pandas as pd

class CausalBatcher:
    """
    Iterates over fixed-length, contiguous windows of a time series in batches.

    Window ``k`` covers timesteps ``[k * stride, k * stride + sequence_length)``,
    so ``stride < sequence_length`` yields overlapping windows (usable as
    training-data augmentation).  Only full batches are produced: windows left
    over after the last full batch are dropped.  With ``shuffle=True`` the
    window order is re-drawn with ``np.random.shuffle`` at construction and at
    the start of every iteration.

    Each item is a pair ``(batch_spikes, batch_targets)`` whose first two
    dimensions are ``(batch_size, sequence_length)``.

    Raises:
        ValueError: if ``batch_size``, ``sequence_length`` or ``stride`` is not
            positive, or if the data cannot fill a single batch.
    """
    def __init__(self, spike_data, targets, batch_size=64, sequence_length=20, shuffle=False, stride=None):
        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.shuffle = shuffle
        self.stride = stride if stride is not None else sequence_length # Non-overlapping by default

        # Reject sizes that would otherwise surface as ZeroDivisionError below.
        for label, value in (("batch_size", self.batch_size),
                             ("sequence_length", self.sequence_length),
                             ("stride", self.stride)):
            if value <= 0:
                raise ValueError(f"{label} must be a positive integer, got {value}.")

        # Calculate how many full sequences we can make
        num_total_timesteps = self.spike_data.shape[0]
        if num_total_timesteps < self.sequence_length:
            self.num_sequences = 0
        else:
            self.num_sequences = (num_total_timesteps - self.sequence_length) // self.stride + 1

        # Calculate how many full batches we can make
        self.num_batches = self.num_sequences // self.batch_size

        if self.num_batches == 0:
            raise ValueError(
                f"Not enough data to create a single batch. "
                f"Have {self.num_sequences} sequences for seq_len={sequence_length} and stride={self.stride}, but batch_size is {self.batch_size}."
            )

        self.indices = np.arange(self.num_sequences)
        if self.shuffle:
            np.random.shuffle(self.indices)

        # Initialised here so next() also works before the first iter() call.
        self.current_batch = 0

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``."""
        if torch.is_tensor(data):
            # torch.tensor(existing_tensor) emits a UserWarning; detach + clone
            # is the supported way to copy an existing tensor.
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of full batches per epoch."""
        return self.num_batches

    def __iter__(self):
        """Restart from the first batch, reshuffling the window order if enabled."""
        self.current_batch = 0
        if self.shuffle: # Reshuffle for each epoch
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        """Return the next ``(batch_spikes, batch_targets)`` pair."""
        if self.current_batch >= self.num_batches:
            raise StopIteration

        # Get the indices for the sequences in this batch
        start_idx = self.current_batch * self.batch_size
        end_idx = start_idx + self.batch_size
        batch_seq_indices = self.indices[start_idx:end_idx]

        # Get the actual start timestep for each sequence in the batch
        batch_timestep_starts = batch_seq_indices * self.stride

        # Build the batch of sequences
        batch_spikes = torch.stack([
            self.spike_data[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])
        batch_targets = torch.stack([
            self.targets[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])

        self.current_batch += 1
        return batch_spikes, batch_targets

class OnlineDataStream:
    """
    Yields one ``(x_t, y_t, should_reset)`` triple per timestep, in order.

    ``should_reset`` is True on every ``reset_every_n``-th sample after the
    first (i.e. at indices ``n, 2n, ...``) so a consumer can clear its hidden
    state, e.g. at trial boundaries; it is always False when ``reset_every_n``
    is ``None`` or 0.  Iterating again restarts from the first timestep.

    Raises:
        ValueError: if ``targets`` has fewer timesteps than ``spike_data``.
    """
    def __init__(self, spike_data, targets, reset_every_n=None):
        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)
        # Fail at construction instead of with an IndexError part-way through the stream.
        if self.targets.shape[0] < self.spike_data.shape[0]:
            raise ValueError(
                f"targets has {self.targets.shape[0]} timesteps but spike_data has "
                f"{self.spike_data.shape[0]}; every sample needs a target."
            )
        self.reset_every_n = reset_every_n  # Reset hidden states every N timesteps (e.g., trial boundaries)
        self.current_idx = 0
        self.steps_since_reset = 0

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``."""
        if torch.is_tensor(data):
            # torch.tensor(existing_tensor) emits a UserWarning; detach + clone
            # is the supported way to copy an existing tensor.
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of timesteps in the stream."""
        return self.spike_data.shape[0]

    def __iter__(self):
        """Rewind to the first timestep and clear the reset counter."""
        self.current_idx = 0
        self.steps_since_reset = 0
        return self

    def __next__(self):
        """Return the sample at the current position and advance by one step."""
        if self.current_idx >= len(self.spike_data):
            raise StopIteration

        x_t = self.spike_data[self.current_idx]
        y_t = self.targets[self.current_idx]

        should_reset = False
        if self.reset_every_n and self.steps_since_reset >= self.reset_every_n:
            should_reset = True
            # Zeroed before the increment below, so the flagged sample counts
            # as step 1 of the new segment.
            self.steps_since_reset = 0

        self.current_idx += 1
        self.steps_since_reset += 1

        return x_t, y_t, should_reset

class SNNRegression(nn.Module):
    """Spiking regression network with one recurrent hidden layer.

    Layout: ``fc1`` (+ ``fc_rec`` fed by the previous step's layer-1 spikes)
    -> ``lif1`` -> ``fc2`` -> ``lif2`` -> ``fc3`` -> ``lif3``.  The output
    neuron layer is built with ``reset_mechanism="none"`` and the network's
    prediction at each step is its membrane potential, not its spikes.
    """

    def __init__(self, input_size, hidden_size=128, output_size=2):
        super().__init__()
        half_hidden = hidden_size // 2
        surrogate_fn = surrogate.fast_sigmoid()

        # First hidden layer: feedforward drive plus bias-free recurrence.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Second hidden layer at half the width.
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Readout layer; its membrane is never reset.
        self.fc3 = nn.Linear(half_hidden, output_size)
        self.lif3 = snn.Leaky(
            beta=0.5,
            spike_grad=surrogate_fn,
            init_hidden=False,
            threshold=1.0,
            reset_mechanism="none",
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights; bias -0.1 on fc1/fc2, zero on any other linear layer."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        if module is self.fc1 or module is self.fc2:
            nn.init.constant_(module.bias, -0.1)
        else:
            nn.init.zeros_(module.bias)

    def forward(self, x, spk1_rec, mem1, mem2, mem3, need_traces: bool = False):
        """
        Run the network over a sequence.

        Args
        ----
        x           : [batch, T, features]
        spk1_rec    : layer-1 spikes carried in from the previous call
        mem1..mem3  : membrane potentials carried in from the previous call
        need_traces : if True, additionally return the per-step tensors
                      listed under Returns.

        Returns
        -------
        out_seq      : readout membrane potential at every step, stacked on dim 1
        final_states : (spk1_rec, mem1, mem2, mem3) after the last step
        traces       : only when ``need_traces`` — a 7-tuple, each stacked on
                       dim 1 and detached: feedforward input, recurrent input
                       (previous layer-1 spikes), layer-1 spikes, layer-2
                       spikes, and the updated mem1, mem2, mem3.
        """
        # Unpacking doubles as a check that x is 3-dimensional.
        _, num_steps, _ = x.shape
        readout = []

        # One list per trace, in the order documented above.
        trace_bufs = tuple([] for _ in range(7)) if need_traces else None

        for step in range(num_steps):
            step_input = x[:, step, :]

            # All carried state is detached at the top of every step, so each
            # step's graph is cut off from the previous step's.
            mem1 = mem1.detach()
            mem2 = mem2.detach()
            spk1_rec = spk1_rec.detach()
            mem3 = mem3.detach()

            recurrent_input = spk1_rec

            # Layer 1: feedforward plus recurrent drive.
            drive1 = self.fc1(step_input) + self.fc_rec(recurrent_input)
            spk1, mem1 = self.lif1(drive1, mem1)
            spk1_rec = spk1  # becomes the recurrent input of the next step

            # Layer 2.
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            # Readout: the spikes are discarded, the membrane is the output.
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

            if need_traces:
                step_values = (step_input, recurrent_input, spk1, spk2, mem1, mem2, mem3)
                for buf, value in zip(trace_bufs, step_values):
                    buf.append(value.detach())

        out_seq = torch.stack(readout, dim=1)
        final_states = (spk1_rec, mem1, mem2, mem3)

        if not need_traces:
            return out_seq, final_states

        traces = tuple(torch.stack(buf, dim=1) for buf in trace_bufs)
        return out_seq, final_states, traces


def compute_correlation(pred, target):
    """
    Compute the Pearson correlation between predicted and target values.

    Accepts tensors, numpy arrays or plain sequences; multi-dimensional
    inputs are flattened, so all elements are pooled into one coefficient.

    Returns:
        float: the correlation coefficient, or 0.0 when it is undefined —
        fewer than two samples, (near-)zero variance in either input, or
        inputs of different sizes (a RuntimeWarning is emitted for the
        size mismatch).
    """
    def _flatten(values):
        # Tensors must be moved off-graph and off-device before numpy sees them.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.float64).ravel()

    pred = _flatten(pred)
    target = _flatten(target)

    if pred.size != target.size:
        warnings.warn(
            f"compute_correlation got inputs of different sizes ({pred.size} vs {target.size}); returning 0.0.",
            RuntimeWarning,
        )
        return 0.0

    # Correlation needs at least two paired samples.
    if pred.size < 2:
        return 0.0

    # Check for zero variance (the coefficient would be 0/0)
    if np.std(pred) < 1e-10 or np.std(target) < 1e-10:
        return 0.0

    return float(np.corrcoef(pred, target)[0, 1])

def compute_fvaf(y_true, y_pred):
    """
    Compute the Fraction of Variance Accounted For (FVAF), i.e. the R^2 score
    ``1 - SS_res / SS_tot``.

    Accepts tensors, numpy arrays or plain sequences.  1D inputs are expected;
    anything with more dimensions is flattened with a RuntimeWarning.

    Returns:
        float: the score.  Empty input returns 0.0.  When the target is
        (near-)constant the score is 1.0 if the prediction matches it and
        0.0 otherwise.

    Raises:
        ValueError: if the two inputs do not contain the same number of values.
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)

    if y_true.ndim > 1 or y_pred.ndim > 1:
        warnings.warn(f"FVAF expects 1D arrays, but got shapes {y_true.shape} and {y_pred.shape}. Flattening.", RuntimeWarning)
    y_true = y_true.ravel()
    y_pred = y_pred.ravel()

    # Without this check a size-1 array would silently broadcast against a longer one.
    if y_true.size != y_pred.size:
        raise ValueError(
            f"FVAF needs inputs of equal size, got {y_true.size} and {y_pred.size}."
        )

    # No samples: report no explained variance rather than a perfect score.
    if y_true.size == 0:
        return 0.0

    ss_res = np.sum((y_true - y_pred)**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)

    # Constant target: R^2 is undefined, fall back to an exact-match indicator.
    if ss_tot < 1e-10:
        return 1.0 if ss_res < 1e-10 else 0.0

    return float(1 - (ss_res / ss_tot))

def butter_lowpass(data, fs=1/0.05, cutoff=8, order=4):
    """Low-pass filter ``data`` along axis 0 with a forward-backward Butterworth filter.

    ``fs`` and ``cutoff`` are in the same frequency unit; the default ``fs``
    corresponds to one sample every 0.05 s.  The filter is designed as
    second-order sections and applied with ``sosfiltfilt``.
    """
    lowpass_sections = butter(N=order, Wn=cutoff, btype='low', output='sos', fs=fs)
    filtered = sosfiltfilt(lowpass_sections, data, axis=0)
    return filtered

# %%
##########################################
# TWO-TIMESCALE META RL UPDATER WITH META-LEARNING
##########################################
class TwoScaleMetaRLWeightUpdaterFull:
    """
    This updater adapts all layers using configurable timescales and incorporates meta-learning.
    It uses eligibility traces to integrate credit assignment over time, 
    providing memory of recent activity without acausally looking into the future. 
    The learning rate is modulated by a biologically-inspired reward signal based on performance.
    
    This approach integrates Hebbian principles with the needs of the BCI task.
    """
    def __init__(self, model, base_fast_lr=1e-3, base_slow_lr=1e-2, window_size=180, 
                 meta_lr=1e-3, tau_e_fast=0.12, tau_e_slow=0.7, online_mode=False,
                 use_hebbian=True,  # ABLATION: Enable/disable Hebbian learning
                 use_meta_learning=True,  # ABLATION: Enable/disable meta-learning
                 rms_mode='with_rms',  # NEW ABLATION: 'with_rms', 'without_rms', 'partial_rms'
                 **kwargs):
        self.model = model
        self.device = next(model.parameters()).device
        self.base_fast_lr = base_fast_lr
        self.base_slow_lr = base_slow_lr
        self.meta_lr = meta_lr
        self.grad_scale = 0
        self.online_mode = online_mode
        self.use_hebbian = use_hebbian  # FIX: Use the parameter, not hardcoded False!
        self.use_meta_learning = use_meta_learning  # Meta-learning ablation parameter
        self.rms_mode = rms_mode  # NEW: RMS normalization ablation parameter
        
        # Print ablation status for clarity
        learning_type = "Hebbian (error × d_lif × activity)" if self.use_hebbian else "Delta Rule (error × activity)"
        meta_status = "WITH meta-learning" if self.use_meta_learning else "WITHOUT meta-learning (fixed LR)"
        rms_status = f"RMS mode: {rms_mode.upper()}"
        print(f"Initializing updater with {learning_type} {meta_status} {rms_status}")

        # RMS accumulators for error and spike normalization (Tweaks A & B)
        # Use per-unit, integer-based EMAs for hardware friendliness
        # Initialize based on RMS mode
        if self.rms_mode in ['with_rms', 'partial_rms']:
            self.err2_sq_ema = torch.ones(model.fc2.out_features, dtype=torch.int32, device=self.device)
            self.err1_sq_ema = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
            self.spk1_rms    = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
        else:  # without_rms mode
            self.err2_sq_ema = None
            self.err1_sq_ema = None
            self.spk1_rms = None
        
        # Meta-learning parameters - only used if use_meta_learning=True
        self.meta_params = {
            'plasticity': 1.0, 
            'sensitivity': 1.0
        } 
        
        # Set learning rates based on meta-learning mode
        if self.use_meta_learning:
            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']
        else:
            # Fixed learning rates when meta-learning is disabled
            self.fast_lr = self.base_fast_lr
            self.slow_lr = self.base_slow_lr

        # --- Hardware-friendly LUTs ---
        # Conditional initialization based on RMS mode
        if self.rms_mode in ['with_rms', 'partial_rms']:
            # Passed in from the main training script
            self.inv_sqrt_LUT = kwargs.get('inv_sqrt_LUT', None)
            if self.inv_sqrt_LUT is None:
                raise ValueError("inv_sqrt_LUT must be provided to the updater when using RMS normalization.")
            self.out_rms = torch.ones(2, dtype=torch.int32, device=self.device)  # vx, vy
        else:  # without_rms mode
            self.inv_sqrt_LUT = None
            self.out_rms = None
            
        # LUT for error-bucket based reward scaling (always used)
        self.reward_LUT = torch.tensor(
       [15,14,13,12,11,10,9,8,8,8,7,7,6,5,4,3], device=self.device)
        self.fc3_row_cap_q12 = torch.full((2, 1),
                                     int(6.0*4096),
                                     dtype=torch.int32,
                                     device=self.device)

        # --- Standard Dual-Timescale Eligibility Trace Parameters ---
        # Use standard dual timescale system (not being ablated in this study)
        if online_mode:
            self.decay_fast = math.exp(-0.064 / (tau_e_fast * 0.5))  # Faster decay
            self.decay_slow = math.exp(-0.064 / (tau_e_slow * 0.8))  # Slightly faster decay
            self.trace_mix_a = 0.8  # More weight on fast traces for responsiveness
        else:
            self.decay_fast = math.exp(-0.064 / tau_e_fast)
            self.decay_slow = math.exp(-0.064 / tau_e_slow)
            self.trace_mix_a = 0.5  # Fixed mixing parameter for fast and slow traces

        # Initialize eligibility traces (dual-timescale)
        self.e_fast_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_fast_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_fast_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_fast_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        
        self.e_slow_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_slow_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_slow_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_slow_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        # --- End Eligibility Trace ---

        self.window_size = window_size
        self.step_counter = 0
        self.cumulative_error = 0
        self.prev_cumulative_error = None
        
        self.max_loss_value = 5.0
        
        # Initialize gradient averages with zeros matching model layer shapes and device
        self.grad_fc1_avg = torch.zeros_like(self.model.fc1.weight.data, device=self.device)
        self.grad_fc2_avg = torch.zeros_like(self.model.fc2.weight.data, device=self.device)
        self.grad_fc3_avg = torch.zeros_like(self.model.fc3.weight.data, device=self.device)
        self.grad_fc_rec_avg = torch.zeros_like(self.model.fc_rec.weight.data, device=self.device) # RESTORED
        
        # Initialize momentum for gradient accumulation
        self.momentum = 0.9

    def reset_weights_with_xavier(self):
        """
        Reset weights using Xavier/Glorot initialization for better convergence
        after catastrophic errors.
        """
        with torch.no_grad():
            for module in [self.model.fc1, self.model.fc2, self.model.fc3, self.model.fc_rec]: # RESTORED self.model.fc_rec
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        print("Weights reset using Xavier/Glorot initialization for better convergence")

    def warm_start_rms_emas(self, warmup_data_x, warmup_data_y, num_warmup_steps=100):
        """
        Warm-start the RMS EMAs with some initial data to avoid cold-start issues in online mode.
        """
        if not self.online_mode:
            return
            
        print(f"Warm-starting RMS EMAs with {num_warmup_steps} steps...")
        with torch.no_grad():
            for i in range(min(num_warmup_steps, len(warmup_data_x))):
                x_t = warmup_data_x[i:i+1].to(self.device)  # [1, features]
                y_t = warmup_data_y[i:i+1].to(self.device)  # [1, 2]
                
                # Quick forward pass to get some initial error statistics
                x_t_seq = x_t.unsqueeze(1)  # [1, 1, features]
                y_t_seq = y_t.unsqueeze(1)  # [1, 1, 2]
                
                # Initialize temporary states
                batch_size = 1
                spk1_rec = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem1 = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem2 = torch.zeros(batch_size, self.model.fc2.out_features, device=self.device)
                mem3 = torch.zeros(batch_size, self.model.fc3.out_features, device=self.device)
                
                pred_seq, _, _ = self.model(x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True)
                output_error = y_t_seq - pred_seq
                
                # Update output RMS - need to sum over batch AND sequence dims to get [2] shape
                k_out = 2  # Faster adaptation for warmup
                sq_out = ((output_error**2) * 4096).sum(dim=(0,1)).to(torch.int32)  # Sum over batch and seq dims
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                
        print("RMS EMA warm-start complete.")

    def fast_update_single_timestep(self, x_t, y_t, states):
        """
        Apply one online plasticity step for a single timestep.

        The method runs the model forward on one sample, converts the absolute
        output error into a reward bucket that scales the learning rate,
        propagates the error through ``fc3`` and ``fc2`` (with optional
        shift-based RMS normalisation selected by ``self.rms_mode``), folds the
        resulting outer-product terms into the fast and slow eligibility
        traces, and adds the blended trace to the fc1/fc2/fc3/fc_rec weights
        in place. The blended trace is also accumulated for ``slow_update``.

        Args:
            x_t: Input for this timestep. Two leading singleton dimensions are
                added before the forward pass (assumes a 1-D ``[features]``
                tensor -- TODO confirm against the caller).
            y_t: Target for this timestep, expanded the same way (assumes
                ``[2]`` -- TODO confirm).
            states: 4-tuple ``(spk1_rec, mem1, mem2, mem3)`` passed to the
                model as its recurrent/membrane state.

        Returns:
            tuple: ``(loss, final_states)`` where ``loss`` is the MSE between
            prediction and target clamped to ``[0, 10]`` as a Python float and
            ``final_states`` is the state object returned by the model.

        Raises:
            ValueError: If ``self.online_mode`` is falsy.
        """
        if not self.online_mode:
            raise ValueError("Single timestep update only available in online mode")
            
        # Reshape for model forward pass: [batch=1, seq=1, features]
        x_t_seq = x_t.unsqueeze(0).unsqueeze(0)  # [1, 1, features]
        y_t_seq = y_t.unsqueeze(0).unsqueeze(0)  # [1, 1, 2]
        
        # NOTE(review): `device` is not referenced again in this method.
        device = self.device
        spk1_rec, mem1, mem2, mem3 = states

        # Single forward pass to get prediction and traces
        # NOTE(review): this call runs outside torch.no_grad(); whether the
        # returned states carry autograd history depends on the model --
        # confirm they are detached before being fed back on the next step.
        pred_seq, final_states, traces = self.model(
            x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss for single timestep (reported only; never back-propagated here)
        mse_loss = F.mse_loss(pred_seq, y_t_seq)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation for single timestep
        with torch.no_grad():
            # Get single timestep data (index 0 along the sequence axis)
            pre_ff_t = pre_ff_all[:, 0, :]    # [1, features]
            pre_rec_t = pre_rec_all[:, 0, :]  # [1, hidden]
            spk1_t = spk1_all[:, 0, :]        # [1, hidden]
            spk2_t = spk2_all[:, 0, :]        # [1, hidden//2]
            mem1_next = mem1_next_all[:, 0, :] # [1, hidden]
            mem2_next = mem2_next_all[:, 0, :] # [1, hidden//2]
            mem3_next = mem3_next_all[:, 0, :] # [1, 2]
            
            # Single timestep error (target minus prediction)
            output_error_t = y_t_seq[:, 0, :] - pred_seq[:, 0, :]  # [1, 2]
            
            # Reward bucket: |error| * 8 clamped to 0..15 per channel; the
            # larger of the channel buckets indexes reward_LUT.
            abs_err_xy = torch.abs(output_error_t).mean(dim=0)  # [2]
            bucket_xy = (abs_err_xy * 8).clamp(0, 15).to(torch.int64)
            bucket = torch.max(bucket_xy)
            lr_scale = self.reward_LUT[bucket]
            
            # Scale learning rate for single timestep (was calibrated for seq_len=10)
            effective_lr = self.fast_lr * lr_scale / 16 * 0.1  # 0.1 scale factor for single timestep
            
            # RMS normalization (conditional on rms_mode).
            # Each EMA below is updated with shifts: ema -= ema >> k; ema += new >> k.
            # The EMA shifted right by 8 and clamped to 0..255 indexes inv_sqrt_LUT
            # (presumably an inverse-square-root table -- confirm where it is built).
            if self.rms_mode == 'with_rms':
                # Online-adapted EMA constants (faster adaptation)
                k_out = 2  # Faster than batch mode (was 4)
                k = 3      # Faster than batch mode (was 5)
                
                # Normalize output error per-channel
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0, 255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                
                # Backpropagate error through layers with online-adapted normalization
                err2_raw = torch.matmul(norm_err, self.model.fc3.weight)
                sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err2_sq_ema -= self.err2_sq_ema >> k
                self.err2_sq_ema += sq_err2 >> k
                idx = (self.err2_sq_ema >> 8).clamp_(0, 255)
                hidden2_error_t = err2_raw * self.inv_sqrt_LUT[idx]

                err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err1_sq_ema -= self.err1_sq_ema >> k
                self.err1_sq_ema += sq_err1 >> k
                idx = (self.err1_sq_ema >> 8).clamp_(0, 255)
                hidden1_error_t = err1_raw * self.inv_sqrt_LUT[idx]
            elif self.rms_mode == 'partial_rms':
                # Only normalize output error, not hidden errors
                k_out = 2
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0, 255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                
                # Backpropagate without normalization
                hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            else:  # without_rms
                # No normalization at all
                norm_err = output_error_t
                hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            
            # Local sensitivities: each layer's surrogate gradient evaluated at
            # its post-step membrane value.
            d_lif3_t = self.model.lif3.spike_grad(mem3_next)
            d_lif2_t = self.model.lif2.spike_grad(mem2_next)
            d_lif1_t = self.model.lif1.spike_grad(mem1_next)

            # Normalize spike activities (conditional on RMS mode)
            if self.rms_mode in ['with_rms', 'partial_rms']:
                sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.spk1_rms += -(self.spk1_rms >> 2) + (sq_spk1 >> 2)  # Faster adaptation (k=2)
                idx = (self.spk1_rms >> 8).clamp_(0, 255)
                spk1_t = spk1_t * self.inv_sqrt_LUT[idx]
            # else: no spike normalization in without_rms mode

            # Compute weight updates using either Hebbian or Delta rule.
            # Every term is an outer product: (post-side error)^T @ (pre-side activity).
            if self.use_hebbian:
                # Original Hebbian updates with surrogate gradients
                hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
            else:
                # Delta rule updates (no surrogate gradients)
                hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)

            # Update eligibility traces with computed terms (dual-timescale):
            # trace <- decay * trace + new term, applied in place.
            self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
            self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
            self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
            self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
            
            self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
            self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
            self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
            self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

            # Apply updates using combined dual-timescale traces
            # (trace_mix_a weights the fast trace, 1 - trace_mix_a the slow one).
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate for slow updates (momentum-weighted running average;
            # grad_scale counts how many fast steps contributed).
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize weights
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
        
        return combined_loss.item(), final_states
    
    def update_single_timestep(self, x_t, y_t, states):
        """
        Main interface for online single-timestep updates.

        Moves the sample to ``self.device``, runs
        ``fast_update_single_timestep``, and then either resets the weights
        (catastrophic loss) or accumulates the loss and triggers
        ``slow_update`` once ``window_size`` steps have been collected.

        Args:
            x_t: Input tensor for this timestep.
            y_t: Target tensor for this timestep.
            states: Model state tuple forwarded to the fast update.

        Returns:
            tuple: ``(bounded_loss, new_states)``. ``bounded_loss`` is the
            step loss capped at ``self.max_loss_value``; a non-finite loss is
            reported as ``self.max_loss_value``.
        """
        x_t, y_t = x_t.to(self.device), y_t.to(self.device)

        loss, new_states = self.fast_update_single_timestep(x_t, y_t, states)

        # The catastrophic check must look at the RAW loss: once the loss has
        # been capped at max_loss_value it can never exceed 1.5 * max_loss_value,
        # which made the reset branch unreachable. NaN/inf also count as
        # catastrophic, because min(nan, cap) stays NaN and would otherwise be
        # added to cumulative_error.
        loss_is_finite = math.isfinite(loss)
        catastrophic = (not loss_is_finite) or loss > self.max_loss_value * 1.5
        bounded_loss = min(loss, self.max_loss_value) if loss_is_finite else self.max_loss_value

        if catastrophic:
            print(f"CATASTROPHIC ERROR DETECTED: {loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # The step is not counted towards the slow-update window.
            return bounded_loss, new_states

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss, new_states
              
    def fast_update(self, input_sequence, desired_sequence, initial_states):
        """
        Novel biologically plausible learning rule for a SEQUENCE using an eligibility trace.
        It integrates Hebbian updates over the entire window before applying.

        One forward pass yields predictions and per-timestep traces. For every
        timestep the output error is propagated through ``fc3``/``fc2``
        (optionally RMS-normalised per ``self.rms_mode``) and the resulting
        outer-product terms are folded into the fast and slow eligibility
        traces. After the loop, the blended traces are added to the
        fc1/fc2/fc3/fc_rec weights once, scaled by a reward-dependent learning
        rate, and accumulated for ``slow_update``.

        Args:
            input_sequence: Tensor unpacked as ``(batch, seq_len, features)``.
            desired_sequence: Target tensor compared element-wise with the
                model prediction (assumes ``(batch, seq_len, 2)`` -- TODO confirm).
            initial_states: State tuple ``(spk1_rec, mem1, mem2, mem3)``
                forwarded to the model.

        Returns:
            tuple: ``(loss, x_corr, y_corr)`` where ``loss`` is the sequence
            MSE clamped to ``[0, 10]`` as a Python float. ``x_corr`` and
            ``y_corr`` are always ``0`` (placeholders).
        """
        # NOTE(review): `batch_size` and `device` are not referenced again below.
        batch_size, seq_len, _ = input_sequence.shape
        device = self.device

        # Unpack initial states for the sequence
        spk1_rec, mem1, mem2, mem3 = initial_states

        # NOTE: Eligibility traces are NOT reset here. They persist and decay across batches
        # to maintain a continuous memory of activity, as per best practices.

        # SINGLE forward pass to get predictions AND traces for Hebbian update
        # (`final_states` is returned by the model but not used in this method).
        pred_sequence, final_states, traces = self.model(
            input_sequence, *initial_states, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss over the whole sequence (reported only; never back-propagated here)
        mse_loss = F.mse_loss(pred_sequence, desired_sequence)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation
        with torch.no_grad():
            # Calculate a 4-bit error bucket reward signal to modulate plasticity.
            output_error_t_for_reward = desired_sequence - pred_sequence
            abs_err_xy = torch.mean(torch.abs(output_error_t_for_reward), dim=(0,1))  # [2]
            bucket_xy  = (abs_err_xy * 8).clamp(0,15).to(torch.int64)      # 4-bit each
            bucket     = torch.max(bucket_xy)                              # int64 scalar
            lr_scale   = self.reward_LUT[bucket]
            effective_lr = self.fast_lr * lr_scale / 16 # `/` is true division: the result is not an integer quotient

            # NO second pass needed. Use the returned traces.
            for t in range(seq_len):
                # Get activations from the pre-computed traces
                pre_ff_t = pre_ff_all[:, t, :]
                pre_rec_t = pre_rec_all[:, t, :]
                spk1_t = spk1_all[:, t, :]
                spk2_t = spk2_all[:, t, :]
                mem1_next = mem1_next_all[:, t, :]
                mem2_next = mem2_next_all[:, t, :]
                mem3_next = mem3_next_all[:, t, :]
                
                # Propagate error backward from this time step's prediction
                output_error_t = desired_sequence[:, t, :] - pred_sequence[:, t, :]
                
                # RMS normalization (conditional on rms_mode).
                # Each EMA below is updated with shifts: ema -= ema >> k; ema += new >> k.
                # The EMA shifted right by 8 and clamped to 0..255 indexes inv_sqrt_LUT
                # (presumably an inverse-square-root table -- confirm where it is built).
                if self.rms_mode == 'with_rms':
                    # Patch A: Normalize output error per-channel before backprop
                    k_out = 4           # EMA 1/16
                    sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)  # [2]
                    self.out_rms -= self.out_rms >> k_out
                    self.out_rms += sq_out       >> k_out
                    idx_out = (self.out_rms >> 8).clamp_(0,255)
                    norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]   # element-wise
                    
                    # Tweak A: RMS-Rescale the local errors (Hardware-Friendly Version)
                    k = 5 # EMA constant: 2^-5 = 0.03125
                    
                    err2_raw = torch.matmul(norm_err, self.model.fc3.weight)     # [B,H2]
                    sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32) # Sum over batch
                    self.err2_sq_ema -= self.err2_sq_ema >> k
                    self.err2_sq_ema += sq_err2 >> k
                    idx = (self.err2_sq_ema >> 8).clamp_(0,255)
                    hidden2_error_t  = err2_raw * self.inv_sqrt_LUT[idx]

                    err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)    # [B,H1]
                    sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                    self.err1_sq_ema -= self.err1_sq_ema >> k
                    self.err1_sq_ema += sq_err1 >> k
                    idx = (self.err1_sq_ema >> 8).clamp_(0,255)
                    hidden1_error_t  = err1_raw * self.inv_sqrt_LUT[idx]
                elif self.rms_mode == 'partial_rms':
                    # Only normalize output error, not hidden errors
                    k_out = 4
                    sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                    self.out_rms -= self.out_rms >> k_out
                    self.out_rms += sq_out >> k_out
                    idx_out = (self.out_rms >> 8).clamp_(0, 255)
                    norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                    
                    # Backpropagate without normalization
                    hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                    hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                else:  # without_rms
                    # No normalization at all
                    norm_err = output_error_t
                    hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                    hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                
                # Get local sensitivities (surrogate gradients)
                d_lif3_t = self.model.lif3.spike_grad(mem3_next)
                d_lif2_t = self.model.lif2.spike_grad(mem2_next)
                d_lif1_t = self.model.lif1.spike_grad(mem1_next)

                # Tweak B: RMS-Rescale spike activities (conditional on RMS mode)
                if self.rms_mode in ['with_rms', 'partial_rms']:
                    # Note: spk1_t is already detached from previous ops
                    sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                    self.spk1_rms += - (self.spk1_rms >> 3) + (sq_spk1 >> 3) # k=3 for spikes
                    idx = (self.spk1_rms >> 8).clamp_(0,255)
                    spk1_t = spk1_t * self.inv_sqrt_LUT[idx]
                # else: no spike normalization in without_rms mode

                # Update both fast and slow eligibility traces with the new instantaneous Hebbian term.
                # Every term is an outer product: (post-side error)^T @ (pre-side activity).
                if self.use_hebbian:
                    # Hebbian updates with surrogate gradients
                    hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                    hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                    hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                    hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
                else:
                    # Delta rule updates (no surrogate gradients)
                    hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                    hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                    hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                    hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)
                
                # Update eligibility traces (dual-timescale):
                # trace <- decay * trace + new term, applied in place.
                self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
                self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
                self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
                self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
                
                self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
                self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
                self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
                self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

                # NO state update at end of loop

            # Apply updates using the final state of the eligibility traces
            # The effective_lr is now calculated once for the whole sequence at the top
            
            # Combine traces for the update (dual-timescale)
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate combined traces for the slow update path
            # (grad_scale counts how many fast steps contributed).
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize all weights after every fast step to prevent explosion/drift
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
            
        # Correlations are not computed here; zeros keep the 3-tuple return shape.
        x_corr=0
        y_corr=0
        return combined_loss.item(), x_corr, y_corr

    def _rms(self, x, eps=1e-5):
        """Return ``x`` divided by its root-mean-square over all elements plus ``eps``."""
        root_mean_square = x.pow(2).mean().sqrt()
        return x / (root_mean_square + eps)

    def renorm_row_(self, w):
        """
        Shrink, in place, every row of ``w`` whose energy exceeds its cap.

        The row metric is ``sum(int32(w) * w)`` along dim 1. Rows above the cap
        are truncated to int32 and arithmetically right-shifted by
        ``ceil(log2(metric) - log2(cap))`` bits (a power-of-two divide); all
        other rows are left untouched.

        The cap is ``self.fc3_row_cap_q12`` when ``w`` has exactly two rows
        (a scalar or a per-row value is accepted) and ``6.0 * 4096`` otherwise.

        NOTE(review): the metric is compared against a Q12-scaled cap, but
        ``int32(w) * w`` is not a Q12 quantity for float weights -- confirm
        the intended scale of the trigger.

        Args:
            w: 2-D weight tensor (or Parameter) modified through ``w.data``.
        """
        l2 = torch.sum((w.to(torch.int32) * w), dim=1, keepdim=True)          # [rows, 1]
        if w.shape[0] == 2:
            caps = torch.as_tensor(self.fc3_row_cap_q12, dtype=l2.dtype, device=l2.device)
            # Accept a scalar cap as well as a per-row cap of shape [2] / [2, 1].
            caps = caps.reshape(-1, 1).expand_as(l2)
        else:
            caps = torch.full_like(l2, int(6.0 * 4096))                      # scalar cap

        mask = l2 > caps
        if not mask.any():
            return

        rows = mask.squeeze(1)
        # Index with the row mask so the shift keeps shape [K, 1] and broadcasts
        # across the columns of each offending row.
        shift = (torch.log2(l2[rows].float()) - torch.log2(caps[rows].float())
                 ).ceil().to(torch.int32)
        shrunk = w.data[rows].to(torch.int32) >> shift                       # power-of-two divide
        # Write back only the offending rows; rows under the cap keep their
        # fractional values instead of being truncated to integers.
        w.data[rows] = shrunk.to(w.dtype)

    def normalize_weights(self, weights, max_norm=1.0):
        """
        Return ``weights`` with each row's L2 norm limited to ``max_norm``.

        Not used by the fast update path any more; retained for other callers.
        Rows already within the limit come back scaled by
        ``norm / (norm + 1e-8)``, i.e. essentially unchanged.
        """
        row_norms = weights.norm(p=2, dim=1, keepdim=True)
        capped_norms = row_norms.clamp(max=max_norm)
        ratio = capped_norms / (row_norms + 1e-8)
        return weights * ratio

    def slow_update(self):
        """
        Consolidate the accumulated fast-step statistics and start a new window.

        Steps, in order:
          1. If any fast steps were accumulated, divide each running gradient
             average by ``grad_scale`` and add its RMS-normalised value, scaled
             by ``slow_lr``, to the matching layer weights.
          2. Multiply all four weight matrices by ``0.999995``.
          3. When meta-learning is enabled and a previous window loss exists,
             grow or shrink the ``plasticity`` / ``sensitivity`` meta-parameters
             (clipped to ``[0.2, 1.5]``) and recompute ``fast_lr`` / ``slow_lr``.
          4. Store the window's average loss, then zero the gradient averages,
             the counters and every eligibility trace.
        """
        layers = (self.model.fc1, self.model.fc2, self.model.fc3, self.model.fc_rec)

        with torch.no_grad():
            if self.grad_scale > 0:
                averaged_grads = (self.grad_fc1_avg, self.grad_fc2_avg,
                                  self.grad_fc3_avg, self.grad_fc_rec_avg)
                for grad_avg in averaged_grads:
                    grad_avg /= self.grad_scale
                for layer, grad_avg in zip(layers, averaged_grads):
                    layer.weight.data += self.slow_lr * self._rms(grad_avg)

        # Light weight decay on every layer, once per slow step (about 1e-5).
        with torch.no_grad():
            for layer in layers:
                layer.weight.mul_(0.999995)

        if self.step_counter > 0:
            window_loss = self.cumulative_error / self.step_counter
        else:
            window_loss = float('inf')
        window_loss = min(window_loss, self.max_loss_value)

        # Meta-learning adaptation: only when enabled and a previous window exists.
        # Otherwise the learning rates stay at their current values.
        if self.use_meta_learning and self.prev_cumulative_error is not None:
            previous_loss = min(self.prev_cumulative_error, self.max_loss_value)
            if window_loss < previous_loss:
                factor = 1 + self.meta_lr
            else:
                factor = 1 - self.meta_lr
            for key in ('plasticity', 'sensitivity'):
                self.meta_params[key] = np.clip(self.meta_params[key] * factor, 0.2, 1.5)

            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        self.prev_cumulative_error = window_loss

        # Start a fresh accumulation window.
        for accumulator in (self.grad_fc1_avg, self.grad_fc2_avg,
                            self.grad_fc3_avg, self.grad_fc_rec_avg):
            accumulator.zero_()
        self.grad_scale = 0
        self.cumulative_error = 0
        self.step_counter = 0

        # Clear both trace timescales so they cannot saturate over long runs.
        for trace in (self.e_fast_fc1, self.e_fast_fc2, self.e_fast_fc3, self.e_fast_rec,
                      self.e_slow_fc1, self.e_slow_fc2, self.e_slow_fc3, self.e_slow_rec):
            trace.zero_()

    def update(self, input_sequence, desired_sequence, initial_states):
        """
        Main interface for windowed (sequence) updates.

        Moves the batch to ``self.device``, runs ``fast_update``, and then
        either resets the weights (catastrophic loss) or accumulates the loss
        and triggers ``slow_update`` once ``window_size`` calls have been
        collected.

        Args:
            input_sequence: Input tensor for the window.
            desired_sequence: Target tensor for the window.
            initial_states: Model state tuple forwarded to the fast update.

        Returns:
            float: The window loss capped at ``self.max_loss_value``; a
            non-finite loss is reported as ``self.max_loss_value``.
        """
        input_sequence, desired_sequence = input_sequence.to(self.device), desired_sequence.to(self.device)

        combined_loss, _, _ = self.fast_update(input_sequence, desired_sequence, initial_states)

        # The catastrophic check must look at the RAW loss: once the loss has
        # been capped at max_loss_value it can never exceed 1.5 * max_loss_value,
        # which made the reset branch unreachable. NaN/inf also count as
        # catastrophic, because min(nan, cap) stays NaN and would otherwise be
        # added to cumulative_error.
        loss_is_finite = math.isfinite(combined_loss)
        catastrophic = (not loss_is_finite) or combined_loss > self.max_loss_value * 1.5
        bounded_loss = min(combined_loss, self.max_loss_value) if loss_is_finite else self.max_loss_value

        if catastrophic:
            print(f"CATASTROPHIC ERROR DETECTED: {combined_loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # The call is not counted towards the slow-update window.
            return bounded_loss

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss


def train_snn_windowed(model, rl_updater, train_loader, val_loader, num_epochs, patience=10):
    """
    Train the SNN on fixed windows with early stopping on validation correlation.

    Each training batch starts from zeroed states and is handed to
    ``rl_updater.update``. After every epoch the model is evaluated on
    ``val_loader`` using the final timestep of each window; the mean of the
    x/y correlations is the model-selection metric.

    Args:
        model: Network exposing ``fc1``/``fc2``/``fc3`` with ``out_features``.
        rl_updater: Object whose ``update(inputs, targets, states)`` returns a loss.
        train_loader: Iterable of ``(input_seq, target_seq)`` batches (needs ``len``).
        val_loader: Iterable of ``(input_seq, target_seq)`` batches.
        num_epochs: Maximum number of epochs.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        The model, with the best validation weights loaded when any were recorded.
    """
    device = next(model.parameters()).device
    print(f"Starting SNN Causal Windowed Training for {num_epochs} epochs on {device}...")

    def zero_states(n_rows):
        # (spk1_rec, mem1, mem2, mem3); the first two share fc1's width.
        widths = (model.fc1.out_features, model.fc1.out_features,
                  model.fc2.out_features, model.fc3.out_features)
        return tuple(torch.zeros(n_rows, width, device=device) for width in widths)

    best_val_metric = -np.inf  # Pearson correlation can be negative
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0.0
        for input_seq, target_seq in train_loader:
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            fresh_states = zero_states(input_seq.size(0))
            running_train_loss += rl_updater.update(input_seq, target_seq, fresh_states)

        avg_train_loss = running_train_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for input_seq, target_seq in val_loader:
                input_seq = input_seq.to(device)
                target_seq = target_seq.to(device)
                pred_seq, _ = model(input_seq, *zero_states(input_seq.size(0)))
                # Only the last timestep of each window is scored.
                val_preds.append(pred_seq[:, -1, :].cpu())
                val_targets.append(target_seq[:, -1, :].cpu())

            if not val_preds:
                print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: N/A (empty val set)")
                continue

            val_preds = torch.cat(val_preds)
            val_targets = torch.cat(val_targets)
            val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
            val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
            avg_val_corr = (val_corr_x + val_corr_y) / 2.0

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

        # ---- model selection / early stopping ----
        if avg_val_corr > best_val_metric:
            best_val_metric = avg_val_corr
            best_model_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
            print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
            epochs_no_improve = 0
            continue

        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    if best_model_state:
        model.load_state_dict({name: tensor.to(device) for name, tensor in best_model_state.items()})
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

def train_snn_online(model, rl_updater, train_stream, val_stream, num_epochs, patience=10, 
                     eval_every_n=500, reset_states_every_n=None):
    """
    Train the SNN one timestep at a time with hidden states carried between steps.

    Validation runs every ``eval_every_n`` training steps on at most 1000
    samples of ``val_stream``; the mean x/y correlation drives best-model
    tracking and early stopping. If both prediction channels have collapsed to
    (near) zero variance after more than 1000 steps, the layer weights are
    re-initialised and the fast learning rate is multiplied by 5.

    Args:
        model: SNN model to train
        rl_updater: Online-enabled updater (must have online_mode=True)
        train_stream: Iterable yielding ``(x_t, y_t, should_reset)`` for training
        val_stream: Iterable yielding ``(x_t, y_t, should_reset)`` for validation
        num_epochs: Number of passes over ``train_stream``
        patience: Number of evaluation intervals without improvement before stopping
        eval_every_n: Evaluate on validation every N steps
        reset_states_every_n: Reset hidden states every N steps (None = never reset)

    Returns:
        The model, with the best validation weights loaded when any were recorded.

    Raises:
        ValueError: If ``rl_updater.online_mode`` is falsy.
    """
    if not rl_updater.online_mode:
        raise ValueError("rl_updater must be in online_mode for online training")

    device = next(model.parameters()).device
    print(f"Starting SNN Online Training for {num_epochs} epochs on {device}...")
    print(f"Evaluation every {eval_every_n} steps, reset states every {reset_states_every_n or 'never'} steps")

    def blank_states():
        # Zeroed (spk1_rec, mem1, mem2, mem3) for a batch of one.
        widths = (model.fc1.out_features, model.fc1.out_features,
                  model.fc2.out_features, model.fc3.out_features)
        return tuple(torch.zeros(1, width, device=device) for width in widths)

    def periodic_reset_due(counter):
        # Truthy when a reset interval is configured and `counter` lands on it.
        return reset_states_every_n and counter % reset_states_every_n == 0

    best_val_metric = -np.inf
    best_model_state = None
    epochs_no_improve = 0

    # Training states persist across timesteps (and across epochs).
    states = blank_states()

    # Warm-start the RMS EMAs when the stream exposes enough raw data and RMS is on.
    has_warmup_data = hasattr(train_stream, 'spike_data') and len(train_stream.spike_data) > 100
    if has_warmup_data and rl_updater.rms_mode in ['with_rms', 'partial_rms']:
        rl_updater.warm_start_rms_emas(
            train_stream.spike_data[:100],
            train_stream.targets[:100],
            num_warmup_steps=50
        )

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        step_count = 0

        for x_t, y_t, should_reset in train_stream:
            # Stream-requested reset (e.g. trial boundary) or periodic reset.
            if should_reset or periodic_reset_due(step_count):
                states = blank_states()

            loss, states = rl_updater.update_single_timestep(x_t, y_t, states)
            running_train_loss += loss
            step_count += 1

            # Everything below only runs on evaluation steps.
            if step_count % eval_every_n != 0:
                continue

            model.eval()
            val_preds, val_targets = [], []
            val_states = blank_states()

            with torch.no_grad():
                val_step_count = 0
                for x_val, y_val, should_reset_val in val_stream:
                    if should_reset_val or periodic_reset_due(val_step_count):
                        val_states = blank_states()

                    x_val_seq = x_val.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, features]
                    pred_seq, val_states = model(x_val_seq, *val_states)
                    val_preds.append(pred_seq[0, 0, :].cpu())  # [2]
                    val_targets.append(y_val.cpu())
                    val_step_count += 1

                    # Cap the evaluation length.
                    if val_step_count >= 1000:
                        break

            if val_preds:
                val_preds = torch.stack(val_preds)      # [N, 2]
                val_targets = torch.stack(val_targets)  # [N, 2]

                # Per-channel spread and mean, used to detect collapsed outputs.
                pred_std_x = val_preds[:, 0].std().item()
                pred_std_y = val_preds[:, 1].std().item()
                target_std_x = val_targets[:, 0].std().item()
                target_std_y = val_targets[:, 1].std().item()
                pred_mean_x = val_preds[:, 0].mean().item()
                pred_mean_y = val_preds[:, 1].mean().item()
                target_mean_x = val_targets[:, 0].mean().item()
                target_mean_y = val_targets[:, 1].mean().item()

                spreads = (pred_std_x, pred_std_y, target_std_x, target_std_y)
                if any(spread < 1e-6 for spread in spreads):
                    print(f"  ⚠️  ZERO VARIANCE DETECTED:")
                    print(f"     Pred std: X={pred_std_x:.6f}, Y={pred_std_y:.6f}")
                    print(f"     Target std: X={target_std_x:.6f}, Y={target_std_y:.6f}")
                    print(f"     Pred mean: X={pred_mean_x:.6f}, Y={pred_mean_y:.6f}")
                    print(f"     Target mean: X={target_mean_x:.6f}, Y={target_mean_y:.6f}")

                    # Recovery path: both prediction channels are flat late in training.
                    outputs_dead = pred_std_x < 1e-6 and pred_std_y < 1e-6
                    if outputs_dead and step_count > 1000:
                        print(f"  🔧 ATTEMPTING DEAD NEURON RECOVERY:")
                        print(f"     Increasing learning rate by 5x and reinitializing model...")

                        for layer in [model.fc1, model.fc2, model.fc3, model.fc_rec]:
                            nn.init.xavier_uniform_(layer.weight)
                            if hasattr(layer, 'bias') and layer.bias is not None:
                                nn.init.zeros_(layer.bias)

                        rl_updater.fast_lr *= 5.0
                        rl_updater.base_fast_lr *= 5.0
                        print(f"     New fast LR: {rl_updater.fast_lr:.1e}")

                        # Drop the carried training states along with the old weights.
                        states = blank_states()

                val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
                val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
                avg_val_corr = (val_corr_x + val_corr_y) / 2.0

                avg_train_loss = running_train_loss / eval_every_n
                print(f"Epoch [{epoch+1}/{num_epochs}], Step {step_count}, "
                      f"Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

                # Best-model tracking; the counter is in evaluation intervals.
                if avg_val_corr > best_val_metric:
                    best_val_metric = avg_val_corr
                    best_model_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
                    print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping triggered after {step_count} steps in epoch {epoch+1}.")
                        print(f"  (Patience: {patience} eval intervals = {patience * eval_every_n} steps)")
                        break

                running_train_loss = 0.0  # new interval

            model.train()  # back to training mode

        # Leave the epoch loop as well once patience is exhausted.
        if epochs_no_improve >= patience:
            break

    if best_model_state:
        model.load_state_dict({name: tensor.to(device) for name, tensor in best_model_state.items()})
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

# %%
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from filterpy.kalman import KalmanFilter
def train_test_kalman_filter(train_rates, train_vel, test_rates, test_vel, dt=0.050):
    """
    Trains and applies a Kalman Filter for neural decoding.

    A Ridge regression from the input rates to velocity is fitted on the
    training data. Its predictions are then used as the (noisy) velocity
    measurements of a constant-velocity Kalman filter whose state is
    [px, py, vx, vy]; only the velocity part of the state is observed.

    Args:
        train_rates: Array or tensor of training inputs, one row per time bin.
        train_vel: Array or tensor of training velocities with exactly two
            columns (vx, vy).
        test_rates: Array or tensor of test inputs, one row per time bin.
        test_vel: Not used by this function (evaluation happens elsewhere);
            kept in the signature for callers.
        dt (float): Time step in seconds between consecutive bins. It enters
            the transition matrix and the position process noise. Defaults to
            0.050, the value that used to be hard-coded here; pass the bin
            stride actually used when the data were binned.

    Returns:
        predicted_velocities (np.array): KF predicted velocities for test
            data, one (vx, vy) row per test time bin.
        kf (KalmanFilter): The trained Kalman Filter object.

    Raises:
        ValueError: If train_vel does not have two columns, or dt is not
            strictly positive.
    """
    print("\n--- Setting up Kalman Filter ---")
    # Ensure inputs are numpy arrays
    if torch.is_tensor(train_rates):
        train_rates = train_rates.detach().cpu().numpy()
    if torch.is_tensor(train_vel):
        train_vel = train_vel.detach().cpu().numpy()
    if torch.is_tensor(test_rates):
        test_rates = test_rates.detach().cpu().numpy()
    # test_vel is not used for KF training/prediction, only evaluation later

    n_timesteps_test = test_rates.shape[0]
    n_outputs = train_vel.shape[1]  # Should be 2 for [vx, vy]

    if n_outputs != 2:
        raise ValueError("This KF implementation assumes 2D velocity (vx, vy)")
    if dt <= 0:
        raise ValueError("dt must be a positive time step in seconds")

    # --- 1. Train linear decoder: velocity = rates @ coef.T + intercept ---
    print("Fitting linear velocity decoder (Ridge)")
    lin_decoder = Ridge(alpha=1e-3, fit_intercept=True)
    lin_decoder.fit(train_rates, train_vel)

    # Use decoder to obtain noisy velocity measurements
    vel_meas_train = lin_decoder.predict(train_rates)
    vel_meas_test = lin_decoder.predict(test_rates)

    # --- 1b. Remove systematic bias between decoder output and true velocity ---
    bias = vel_meas_train.mean(axis=0) - train_vel.mean(axis=0)
    vel_meas_train -= bias
    vel_meas_test -= bias

    # --- 2. Dimensions and dynamics ---
    dim_x = 4           # [px, py, vx, vy]
    dim_z = 2           # measurement provides velocity only

    # Constant-velocity model: position integrates velocity over one step.
    # (Named `transition` rather than `F` so it does not shadow the
    # module-level `torch.nn.functional as F` alias inside this function.)
    transition = np.array([[1, 0, dt, 0],
                           [0, 1, 0, dt],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

    observation = np.array([[0, 0, 1, 0],
                            [0, 0, 0, 1]], dtype=np.float32)  # we observe velocity

    # --- 3. Measurement noise R from training residuals ---
    print("Estimating R from decoder residuals")
    residuals_meas = train_vel - vel_meas_train
    # Small diagonal term keeps R positive definite.
    R = np.cov(residuals_meas.T) + np.eye(dim_z) * 1e-6

    # --- 4. Process noise Q based on residual velocity variance ---
    q_vel = np.var(residuals_meas, axis=0).mean()
    q_pos = (dt ** 2) * q_vel
    Q = np.diag([q_pos, q_pos, q_vel, q_vel]) + 1e-6 * np.eye(dim_x)

    # --- 5. Initialize Kalman Filter ---
    print("Initializing Kalman Filter...")
    kf = KalmanFilter(dim_x=dim_x, dim_z=dim_z)
    # Initialise state vector [px, py, vx, vy].
    # We set the starting position to 0 and initialise the velocity with the
    # first sample from the training data.  This yields a 4-element state
    # vector, as required by the 4×4 transition matrix.
    initial_state = np.zeros(dim_x, dtype=np.float32)
    initial_state[2:] = train_vel[0]
    kf.x = initial_state
    kf.P = np.eye(dim_x) * 500.     # High initial uncertainty
    kf.F = transition
    kf.H = observation
    kf.R = R
    kf.Q = Q

    # --- 6. Run the filter on test data ---
    print(f"Running Kalman Filter for {n_timesteps_test} test steps...")
    predicted_velocities = np.zeros((n_timesteps_test, 2))
    start_time_kf = time.time()
    for t in range(n_timesteps_test):
        kf.predict()
        kf.update(vel_meas_test[t])  # 1-D measurement vector of length 2
        predicted_velocities[t] = kf.x[2:]  # store vx, vy only
    end_time_kf = time.time()
    print(f"Kalman Filter processing time: {end_time_kf - start_time_kf:.2f} seconds")
    print("--- Kalman Filter Setup and Run Complete ---")
    return predicted_velocities, kf


# %%
# +++ Re-inserting LSTM Definitions +++
class LSTMRegression(nn.Module):
    """Stacked LSTM whose output at the final time step feeds a linear readout."""

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Dropout is only passed through when there is more than one layer;
        # a single-layer LSTM is always built with dropout 0.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map x of shape (batch, seq_len, input_size) to (batch, output_size)."""
        # Every forward pass starts from all-zero hidden and cell states.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        zero_hidden = torch.zeros(state_shape, device=x.device)
        zero_cell = torch.zeros(state_shape, device=x.device)
        sequence_out, _ = self.lstm(x, (zero_hidden, zero_cell))
        # Only the last time step of the top layer is decoded.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

def train_lstm_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=5):
    """Basic training loop for the LSTM model with early stopping (handles val_loader=None).

    Each epoch runs one pass over ``train_loader``. When a non-empty
    ``val_loader`` is given, the mean validation loss is computed after every
    epoch; training stops once it has failed to improve for ``patience``
    consecutive epochs, and the weights from the best validation epoch are
    restored before returning. Without validation the model is returned as
    it is after the last epoch.

    Args:
        model: Module to train; called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        val_loader: Optional loader with a ``dataset`` attribute, or None.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to.
        num_epochs (int): Maximum number of epochs.
        patience (int): Non-improving validation epochs tolerated before stopping.

    Returns:
        tuple: ``(model, train_losses, val_losses)`` where the two lists hold
        one mean loss per completed epoch (``val_losses`` is empty when no
        validation is performed).
    """
    print(f"Starting LSTM training for {num_epochs} epochs on {device}...")
    has_validation = bool(val_loader) and len(val_loader.dataset) > 0
    best_val_loss = float('inf')
    epochs_no_improve = 0
    train_losses = []
    val_losses = []
    # No snapshot until validation finds an improvement; this guarantees that
    # a run without validation is never rolled back to earlier weights.
    best_model_state = None
    best_epoch = 0

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)  # Add sequence dimension if needed
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if not has_validation:
            # If no validation, just print training loss
            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}")
            continue

        # Validation phase
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                if inputs.dim() == 2:
                    inputs = inputs.unsqueeze(1)
                outputs = model(inputs)
                running_val_loss += criterion(outputs, targets).item()
        avg_val_loss = running_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

        # Check for improvement and early stopping based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors,
            # so they must be cloned; otherwise the "best" snapshot silently
            # follows later optimizer steps and restoring it does nothing.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            best_epoch = epoch + 1
            print(f"  New best validation loss: {best_val_loss:.6f}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    print("Finished LSTM Training.")
    # Restore the best validated weights (only exists when validation ran)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state from epoch {best_epoch}.")

    return model, train_losses, val_losses
# +++ End Re-inserted LSTM Definitions +++


def create_sequences(data_x, data_y, seq_len=10):
    """Build sliding windows (stride 1) over data_x, each paired with the
    data_y entry at the window's last index.

    Returns a pair of numpy arrays (windows, targets). Both are empty when
    data_x has fewer than seq_len entries.
    """
    window_starts = range(len(data_x) - seq_len + 1)
    windows = [data_x[start:start + seq_len] for start in window_starts]
    # Target is the value aligned with the final step of each window
    targets = [data_y[start + seq_len - 1] for start in window_starts]
    return np.array(windows), np.array(targets)

def run_rms_ablation(num_sessions=10):
    """
    Clean ablation study: Different RMS normalization configurations with IDENTICAL training setup.
    Only the rms_mode parameter differs between runs.

    For every mode the RNG seeds are reset, session files are sampled from
    ``~/scratch/zenodo_dataset``, an SNN is trained online on each session and
    evaluated on the final part of that session. A summary comparing the modes
    is printed at the end.

    Args:
        num_sessions (int): Number of ``.mat`` session files to sample per
            mode. If fewer files are available, all of them are used.

    Returns:
        dict: Maps each mode name ('WITH_RMS', 'PARTIAL_RMS', 'WITHOUT_RMS')
        to a list of per-session dicts with keys 'session', 'rms_mode',
        'corr_x', 'corr_y' and 'avg_corr'. Sessions that fail to load or
        raise during training/evaluation are reported and left out.
    """
    print("\n" + "="*80)
    print("RMS NORMALIZATION ABLATION STUDY")
    print("="*80)
    
    results = {}
    
    # Test all RMS modes
    rms_modes = [
        ("WITH_RMS", 'with_rms'),        # Full RMS normalization (current default)
        ("PARTIAL_RMS", 'partial_rms'),  # Only output error normalization  
        ("WITHOUT_RMS", 'without_rms')   # No RMS normalization at all
    ]
    
    for mode_name, rms_mode in rms_modes:
        print(f"\n{'='*50}")
        print(f"Testing {mode_name} RMS Mode")
        if rms_mode == 'with_rms':
            print("  Using full RMS normalization (errors + spikes)")
        elif rms_mode == 'partial_rms':
            print("  Using partial RMS normalization (output error + spikes only)")
        elif rms_mode == 'without_rms':
            print("  Using no RMS normalization")
        print(f"{'='*50}")
        
        # Set seeds for reproducibility (reset per mode so every mode starts
        # from the same RNG state and samples the same sessions)
        np.random.seed(42)
        random.seed(42)
        torch.manual_seed(42)
        basepath = os.path.expanduser('~/scratch/zenodo_dataset')
        # os.listdir returns entries in arbitrary order; sort them so the
        # seeded sample below selects the same sessions on every machine.
        all_mat_files = sorted(f for f in os.listdir(basepath) if f.endswith('.mat'))
        
        # random.sample raises ValueError when asked for more items than
        # exist, so fall back to using every available session instead.
        n_selected = min(num_sessions, len(all_mat_files))
        if n_selected < num_sessions:
            print(f"  Only {len(all_mat_files)} .mat files found in {basepath}; using {n_selected} sessions")
        selected_files = random.sample(all_mat_files, n_selected)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        session_results = []
        
        for session_idx, filename in enumerate(selected_files):
            print(f"\n--- Session {session_idx + 1} ({filename}) with {mode_name} ---")
            
            # Best-effort per session: a failure is reported and the loop
            # moves on to the next session.
            try:
                data = load_data(
                    mat_file_path=os.path.join(basepath, filename),
                    bin_width_s=0.050,
                    stride_s=0.050,
                    spike_processing='count',
                    test_split_ratio=0,
                    verbose=False
                )
                
                if 'error' in data:
                    print(f"  Skipping {filename} due to load error: {data['error']}")
                    continue
                
                X_full = data.get('X_processed_full').numpy()
                y_full = data.get('y_processed_full').numpy()
                
                # Keep at most the first 96 input columns
                if X_full.shape[1] > 96:
                    X_full = X_full[:, :96]
                
                input_size = X_full.shape[1]
                
                # IDENTICAL preprocessing: chronological 70 / 15 / 15 split
                n_total = X_full.shape[0]
                n_train = int(n_total * 0.70)
                n_val = int(n_total * 0.15)
                
                # The n_val bins between the training and test portions are
                # not used in this function; the SNN validation data is taken
                # from the tail of the training portion further below.
                X_train_norm = X_full[:n_train]
                y_train_raw = y_full[:n_train]
                X_test_norm = X_full[n_train + n_val:]
                y_test_raw = y_full[n_train + n_val:]
                
                # IDENTICAL normalization (scaler fitted on training velocities only)
                velocity_scaler = StandardScaler().fit(y_train_raw)
                y_train_norm = velocity_scaler.transform(y_train_raw)
                y_test_norm = velocity_scaler.transform(y_test_raw)
                
                # IDENTICAL model setup
                snn_model = SNNRegression(input_size=input_size, hidden_size=256, output_size=2).to(device)
                
                # LUT setup (conditional on RMS mode)
                if rms_mode in ['with_rms', 'partial_rms']:
                    vals = torch.linspace(0.25, 8.0, 256, device=device)
                    inv_sqrt_LUT = 1.0 / torch.sqrt(vals)
                else:
                    inv_sqrt_LUT = None
                
                # IDENTICAL training setup: last 15% of the training portion
                # is held out for validation during online training
                snn_val_split_ratio = 0.15
                n_snn_train_full = X_train_norm.shape[0]
                n_snn_val = int(n_snn_train_full * snn_val_split_ratio)
                n_snn_train_split = n_snn_train_full - n_snn_val
                
                X_train_snn_final = X_train_norm[:n_snn_train_split]
                y_train_snn_norm_final = y_train_norm[:n_snn_train_split]
                X_val_snn = X_train_norm[n_snn_train_split:]
                y_val_snn_norm = y_train_norm[n_snn_train_split:]
                
                # Adaptive learning rate based on dataset characteristics,
                # clamped to the range [1e-4, 1e-2]
                spike_activity = np.mean(X_train_snn_final)
                velocity_variance = np.var(y_train_snn_norm_final)
                adaptive_lr = min(max(3e-3 * (1 + spike_activity * 0.1) * (1 + velocity_variance * 0.5), 1e-4), 1e-2)
                
                # ONLY DIFFERENCE between runs: rms_mode (and its LUT)
                snn_updater = TwoScaleMetaRLWeightUpdaterFull(
                    snn_model, 
                    base_fast_lr=adaptive_lr, 
                    base_slow_lr=1e-6, 
                    window_size=50, 
                    meta_lr=0.01,
                    online_mode=True,               # CHANGED: online mode for all
                    use_hebbian=True,               # IDENTICAL: use Hebbian learning for all
                    use_meta_learning=True,         # IDENTICAL: use meta-learning for all
                    rms_mode=rms_mode,              # ONLY DIFFERENCE
                    inv_sqrt_LUT=inv_sqrt_LUT
                )
                
                # IDENTICAL data loading (for evaluation only)
                seq_len = 10
                batch_size = 32
                overlap_stride = max(1, seq_len // 2)
                
                # Online data streams
                train_stream = OnlineDataStream(X_train_snn_final, y_train_snn_norm_final)
                val_stream = OnlineDataStream(X_val_snn, y_val_snn_norm)
                
                # IDENTICAL training - switched to online mode
                trained_snn_model = train_snn_online(
                    snn_model, snn_updater, train_stream, val_stream, 
                    num_epochs=15, patience=20, eval_every_n=500
                )
                
                # IDENTICAL evaluation
                test_loader = CausalBatcher(
                    X_test_norm, y_test_norm,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=False, stride=overlap_stride)
                
                trained_snn_model.eval()
                snn_preds_norm_list, snn_targets_norm_list = [], []
                
                with torch.no_grad():
                    for input_seq, target_seq in test_loader:
                        input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                        # Fresh zero state for every test batch
                        batch_size_test = input_seq.size(0)
                        test_spk1_rec = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem1 = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem2 = torch.zeros(batch_size_test, trained_snn_model.fc2.out_features, device=device)
                        test_mem3 = torch.zeros(batch_size_test, trained_snn_model.fc3.out_features, device=device)
                        
                        pred_seq, _ = trained_snn_model(input_seq, test_spk1_rec, test_mem1, test_mem2, test_mem3)
                        # Only the last time step of each sequence is scored
                        snn_preds_norm_list.append(pred_seq[:, -1, :].cpu())
                        snn_targets_norm_list.append(target_seq[:, -1, :].cpu())

                snn_preds_normalized = torch.cat(snn_preds_norm_list).numpy()
                snn_targets_normalized = torch.cat(snn_targets_norm_list).numpy()
                
                # IDENTICAL evaluation (in normalised velocity space)
                snn_corr_x = compute_correlation(snn_targets_normalized[:, 0], snn_preds_normalized[:, 0])
                snn_corr_y = compute_correlation(snn_targets_normalized[:, 1], snn_preds_normalized[:, 1])
                
                session_results.append({
                    'session': filename,
                    'rms_mode': mode_name,
                    'corr_x': snn_corr_x,
                    'corr_y': snn_corr_y,
                    'avg_corr': (snn_corr_x + snn_corr_y) / 2
                })
                
                print(f"  {mode_name} Results: X={snn_corr_x:.4f}, Y={snn_corr_y:.4f}, Avg={((snn_corr_x + snn_corr_y)/2):.4f}")
                
            except Exception as e:
                print(f"  ERROR in {mode_name} run: {e}")
        
        results[mode_name] = session_results
    
    # Compare results
    print("\n" + "="*80)
    print("RMS NORMALIZATION ABLATION RESULTS SUMMARY")
    print("="*80)
    
    for mode, mode_results in results.items():
        if mode_results:
            avg_corr_x = np.mean([r['corr_x'] for r in mode_results])
            avg_corr_y = np.mean([r['corr_y'] for r in mode_results])  
            avg_total = np.mean([r['avg_corr'] for r in mode_results])
            print(f"{mode:12s}: X={avg_corr_x:.4f}, Y={avg_corr_y:.4f}, Avg={avg_total:.4f}")
    
    # Detailed analysis
    print("\n" + "="*50)
    print("ANALYSIS")
    print("="*50)
    
    # Only modes with at least one successful session can be compared;
    # averaging an empty list would yield NaN.
    reduced_modes = [m for m in ['PARTIAL_RMS', 'WITHOUT_RMS'] if results.get(m)]
    
    if results.get('WITH_RMS') and reduced_modes:
        with_rms_avg = np.mean([r['avg_corr'] for r in results['WITH_RMS']])
        reduced_avgs = {m: np.mean([r['avg_corr'] for r in results[m]]) for m in reduced_modes}
        
        # Compare each reduced RMS mode to full RMS
        for mode in reduced_modes:
            reduced_avg = reduced_avgs[mode]
            diff = with_rms_avg - reduced_avg
            print(f"WITH_RMS vs {mode}: {diff:+.4f} ({diff/reduced_avg*100:+.1f}% relative)")
        
        # Find best reduced RMS mode
        best_reduced_mode = max(reduced_modes, key=lambda m: reduced_avgs[m])
        best_reduced_avg = reduced_avgs[best_reduced_mode]
        
        print(f"\nBest reduced RMS mode: {best_reduced_mode} ({best_reduced_avg:.4f})")
        rms_improvement = with_rms_avg - best_reduced_avg
        print(f"Full RMS improvement: {rms_improvement:+.4f} ({rms_improvement/best_reduced_avg*100:+.1f}% relative)")
        
        if abs(rms_improvement) < 0.01:
            print("→ Full RMS normalization provides minimal benefit (hardware simplification recommended)")
        elif rms_improvement > 0.02:
            print("→ Full RMS normalization provides significant benefit (justifies complexity)")
        else:
            print("→ Full RMS normalization provides modest benefit (design trade-off)")
    else:
        print("Not enough successful runs to compare RMS modes.")
    
    return results

# Script entry point: when this file is executed directly (not imported),
# run the RMS normalization ablation on 10 sessions and keep the returned
# per-mode results in the module-level name `ablation_results`.
if __name__ == "__main__":
    print("Running RMS Normalization Ablation...")
    ablation_results = run_rms_ablation(num_sessions=10)
