# %%
import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
import warnings
import torch
import snntorch as snn
from snntorch import surrogate
import numpy as np
import torch
import h5py
from sklearn.preprocessing import StandardScaler
import warnings
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import json
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from filterpy.kalman import KalmanFilter
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from scipy.signal import butter, sosfiltfilt
from sklearn.linear_model import Ridge
import math
from scipy.signal import correlate, correlation_lags
import torch.nn.functional as F

# =============================================================
# Evaluation configuration
# -------------------------------------------------------------
# If EVAL_MOVEMENT_ONLY is True, metrics are computed only on bins
# where cursor speed exceeds MOVEMENT_THRESHOLD_NORM (value in the
# *normalised* velocity space produced by StandardScaler).  If False,
# all time bins contribute to the reported correlation and R².
# The code that applies these thresholds is not in this part of the
# file; the descriptions here state the intended use.
# -------------------------------------------------------------
EVAL_MOVEMENT_ONLY = False          # True: score movement bins only; False: score the full trajectory
MOVEMENT_THRESHOLD_NORM = 0.2      # speed threshold in normalised units; said to be ≈ 2 cm s⁻¹ after inverse transform (depends on the fitted scaler — TODO confirm)
MOVEMENT_THRESHOLD_CM_S = 2.0      # threshold used directly on de-normalised velocities for plots
# =============================================================

def load_data(
    mat_file_path='zenodo_dataset/indy_20160419_01.mat',
    bin_width_s=0.064,   # 64 ms window to match Makin et al.
    stride_s=0.064,      # 64 ms stride for non-overlapping bins
    test_split_ratio=0.2,
    spike_processing='rate',
    normalize_spikes=False,
    verbose=True
):
    """
    Load one recording, bin spikes and cursor velocity, and split chronologically.

    The file is opened with h5py and must contain the datasets ``'t'``,
    ``'spikes'`` and ``'cursor_pos'``.  Processing steps, in order:

    1. The sampling rate ``fs`` is estimated as ``1 / mean(diff(t))``.
    2. Spikes are turned into a dense ``[time, unit]`` matrix.  When the
       ``'spikes'`` dataset holds HDF5 object references, each referenced
       dataset is treated as a vector of spike times and marked on the ``t``
       grid (several spikes falling on the same sample collapse to a single
       1.0).  For a 2-D reference array only the first row (or first column)
       is used.  Otherwise the dataset is read as an already-dense matrix.
    3. Cursor position is low-pass filtered (4th-order Butterworth, 10 Hz,
       zero-phase) and differentiated with ``np.gradient`` to get velocity.
    4. Both signals are cut into windows of ``bin_width_s`` seconds every
       ``stride_s`` seconds; spikes are summed and velocity is averaged per
       window.  Each bin's timestamp is the time of its first sample.
    5. ``spike_processing`` converts the counts (see Args).
    6. The last ``test_split_ratio`` fraction of bins becomes the test set;
       ``StandardScaler`` objects are fit on the training part only.

    No temporal lag between spikes and velocity is applied anywhere in this
    function.

    Args:
        mat_file_path: Path of the HDF5 file to read.
        bin_width_s: Window length in seconds.
        stride_s: Step between window starts in seconds.
        test_split_ratio: Fraction of bins held out at the end of the
            recording.  If > 0 at least one bin is held out; if 0 no split is
            made and extra ``*_full`` keys are added to the result.
        spike_processing: ``'binary'`` (count > 0), ``'count'`` (raw counts),
            ``'rate'`` (counts / bin_width_s) or ``'poisson'`` (a sample from
            ``np.random.poisson`` whose mean is counts / bin_width_s).
        normalize_spikes: If True, min-max scale each unit over *all* bins
            before the split (this uses test-set statistics).
        verbose: Print the parameters and normalisation notice.

    Returns:
        dict with the raw/normalised train and test tensors, bin timestamps,
        raw binned arrays, the fitted scalers (or None) and the binning
        parameters.  If the recording is too short for a single bin, a
        RuntimeWarning is issued and the dict contains an ``'error'`` key with
        empty arrays/tensors instead.

    Raises:
        ValueError: If ``spike_processing`` is not one of the known modes.
    """
    if verbose:
        print("=== load_data parameters ===")
        print(f" mat_file_path    = {mat_file_path}")
        print(f" bin_width_s      = {bin_width_s}")
        print(f" stride_s         = {stride_s}")
        print("  (Processing is now zero-lag with non-overlapping bins)")
        print(f" test_split_ratio = {test_split_ratio}")
        print(f" spike_processing = {spike_processing}")
        print(f" normalize_spikes = {normalize_spikes}")
        print("============================\n")

    with h5py.File(mat_file_path, 'r') as f:
        # Time vector: force a column layout [T, 1] whatever the stored orientation.
        t_vec = f['t'][()]
        if t_vec.ndim == 1:
            t_vec = t_vec[:, None]
        elif t_vec.shape[0] < t_vec.shape[1]:
            t_vec = t_vec.T
        # Sampling rate from the mean sample spacing.
        fs = 1.0 / np.mean(np.diff(t_vec.squeeze()))

        ds = f['spikes']
        if h5py.check_dtype(ref=ds.dtype) is not None:
            # Dataset of object references: one referenced spike-time vector per unit.
            refs = ds[()]
            # Pick an accessor for the reference of unit i.  For a genuine 2-D
            # reference array only row 0 (or column 0) is read.
            if refs.ndim == 2 and refs.shape[1] != 1:
                n_units, get_ref = refs.shape[1], lambda i: refs[0,i]
            elif refs.ndim == 2 and refs.shape[0] != 1:
                n_units, get_ref = refs.shape[0], lambda i: refs[i,0]
            else:
                refs = refs.ravel()
                n_units, get_ref = refs.shape[0], lambda i: refs[i]
            T_orig = t_vec.shape[0]
            spikes = np.zeros((T_orig, n_units), dtype=np.float32)
            for i in range(n_units):
                r = get_ref(i)
                if isinstance(r, h5py.Reference):
                    times = f[r][()].squeeze()
                    # Index of the first sample whose time is >= each spike time.
                    idx = np.searchsorted(t_vec.squeeze(), times, 'left')
                    # Drop spikes that fall after the last sample.
                    idx = idx[idx < T_orig]
                    # Binary raster: repeated indices mark the sample only once.
                    spikes[np.unique(idx), i] = 1.0
        else:
            # Dense matrix: orient it so that the longer axis comes first.
            spikes = ds[()]
            if spikes.shape[0] < spikes.shape[1]:
                spikes = spikes.T
            spikes = spikes.astype(np.float32)

        cp = f['cursor_pos'][()]
        if cp.shape[0] < cp.shape[1]:
            cp = cp.T

        # --- Filter cursor position before calculating velocity ---
        # Per Makin et al. 2019, apply a 4th-order, 10Hz, non-causal (zero-phase)
        # Butterworth filter to the cursor position data before differentiation.
        sos = butter(4, 10, 'low', fs=fs, output='sos')
        cp_filtered = sosfiltfilt(sos, cp, axis=0)
        vel = np.gradient(cp_filtered, t_vec.squeeze(), axis=0)

    # --- Binning ---
    # Convert window length and stride from seconds to samples.
    bin_width = int(round(bin_width_s * fs))
    stride = int(round(stride_s * fs))
    T = spikes.shape[0]
    # Number of complete windows that fit in the recording.
    num_bins = (T - bin_width) // stride + 1

    # +++ Handle Case of Zero Bins +++
    if num_bins <= 0:
        warnings.warn(f"Not enough data points ({T}) for the given bin ({bin_width}) and stride ({stride}) in file {mat_file_path}. Returning empty data structure.", RuntimeWarning)
        # Return a dictionary indicating failure and providing empty structures where possible
        n_neurons = spikes.shape[1] # Get neuron count even if no bins
        return {
            'error': 'Insufficient data for binning',
            'X_binned_raw': np.empty((0, n_neurons), dtype=np.float32),
            'y_binned_raw': np.empty((0, vel.shape[1]), dtype=np.float32), # Use original vel shape for columns
            't_binned_raw': np.empty((0,), dtype=np.float64),
            # Include other keys expected by caller, filled with None or empty equivalents
            'X_train': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            't_train': np.empty((0,), dtype=np.float64),
            't_test':  np.empty((0,), dtype=np.float64),
            'X_train_norm': torch.empty((0, n_neurons), dtype=torch.float32),
            'y_train_norm': torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'X_test_norm':  torch.empty((0, n_neurons), dtype=torch.float32),
            'y_test_norm':  torch.empty((0, vel.shape[1]), dtype=torch.float32),
            'spike_scaler': None,
            'velocity_scaler': None,
            'bin_width_s': bin_width_s,
            'stride_s': stride_s
        }
    # +++ End Handle Case of Zero Bins +++

    # Stack the windows as [num_bins, bin_width, units] and reduce over the window axis.
    spike_windows = np.stack([spikes[i*stride : i*stride+bin_width] for i in range(num_bins)])
    spikes_b_raw = spike_windows.sum(axis=1).astype(np.float32) # Raw spike counts

    vel_windows = np.stack([vel[i*stride : i*stride+bin_width] for i in range(num_bins)])
    vel_b_raw = vel_windows.mean(axis=1).astype(np.float32)    # Raw mean velocity in bins

    # Each bin is stamped with the time of its first sample.
    t_binned = np.array([t_vec[i*stride][0] for i in range(num_bins)])

    # --- Spike Processing (Applied After Binning) ---
    if spike_processing == 'binary':
        X_processed = (spikes_b_raw > 0).astype(np.float32)
    elif spike_processing == 'count':
        X_processed = spikes_b_raw # Already counts
    elif spike_processing == 'rate':
        X_processed = (spikes_b_raw / bin_width_s).astype(np.float32)
    elif spike_processing == 'poisson':
        # The Poisson mean is the per-bin rate (counts / bin_width_s), not the count.
        lam = spikes_b_raw / bin_width_s
        X_processed = np.random.poisson(lam).astype(np.float32)
    else:
        raise ValueError(f"Unknown spike_processing: {spike_processing}")

    # --- Spike Normalization (Optional, applied after spike processing) ---
    if normalize_spikes:
        # Note: Normalizing here before splitting might leak test info
        # It's generally better to normalize after splitting, fitting scaler only on train
        mn = X_processed.min(axis=0, keepdims=True)
        mx = X_processed.max(axis=0, keepdims=True)
        # Units with a constant value get a range of 1 to avoid division by zero.
        rng = np.where(mx - mn > 0, mx - mn, 1.0)
        X_processed = (X_processed - mn) / rng
        if verbose:
            print("Applied spike normalization (min-max scaling). Consider normalizing after splitting.")

    # --- Internal Train/Test Split (if ratio > 0) ---
    y_processed = vel_b_raw # Use the zero-lag velocity
    B_total = X_processed.shape[0]
    n_test = 0
    if test_split_ratio > 0:
        # Hold out at least one bin whenever a split is requested.
        n_test = max(int(B_total * test_split_ratio), 1)
    n_train = B_total - n_test

    # Chronological split: the first n_train bins train, the remainder test.
    X_tr, X_te = X_processed[:n_train], X_processed[n_train:]
    y_tr, y_te = y_processed[:n_train], y_processed[n_train:]
    t_tr, t_te = t_binned[:n_train], t_binned[n_train:]

    # --- Fit Scalers (Only if train data exists) ---
    sx = None
    sy = None
    X_tr_norm = X_tr # Default to unnormalized if no scaling
    y_tr_norm = y_tr
    X_te_norm = X_te
    y_te_norm = y_te

    if n_train > 0:
        try:
            sx = StandardScaler().fit(X_tr)
            X_tr_norm = sx.transform(X_tr)
            if n_test > 0:
                X_te_norm = sx.transform(X_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform spike scaler: {e}", RuntimeWarning)
            sx = None # Ensure scaler is None if fitting failed

        try:
            sy = StandardScaler().fit(y_tr)
            y_tr_norm = sy.transform(y_tr)
            if n_test > 0:
                y_te_norm = sy.transform(y_te) # Only transform test if it exists
        except ValueError as e:
            warnings.warn(f"Could not fit/transform velocity scaler: {e}", RuntimeWarning)
            sy = None # Ensure scaler is None if fitting failed
    # --- End Fit Scalers ---

    # --- Return Dictionary ---
    return_dict = {
        # Data after spike processing and the internal split (no lag is applied)
        'X_train': torch.tensor(X_tr, dtype=torch.float32),
        'y_train': torch.tensor(y_tr, dtype=torch.float32),
        'X_test':  torch.tensor(X_te, dtype=torch.float32),
        'y_test':  torch.tensor(y_te, dtype=torch.float32),
        't_train': t_tr,
        't_test':  t_te,
        # Same data after standardisation with scalers fit on the training part
        'X_train_norm': torch.tensor(X_tr_norm, dtype=torch.float32),
        'y_train_norm': torch.tensor(y_tr_norm, dtype=torch.float32),
        'X_test_norm':  torch.tensor(X_te_norm, dtype=torch.float32),
        'y_test_norm':  torch.tensor(y_te_norm, dtype=torch.float32),
        # Raw binned data (before spike processing) - Useful for aggregation
        'X_binned_raw': spikes_b_raw, # Raw counts
        'y_binned_raw': vel_b_raw,    # Raw velocity
        't_binned_raw': t_binned,     # Timestamps for raw binned data
        # Scalers fit on the internal training split
        'spike_scaler': sx,
        'velocity_scaler': sy,
        # Original parameters for reference
        'bin_width_s': bin_width_s,
        'stride_s': stride_s
    }

    # Add the full processed (but un-split) data if no split was done
    if test_split_ratio == 0:
        return_dict['X_processed_full'] = torch.tensor(X_processed, dtype=torch.float32)
        return_dict['y_processed_full'] = torch.tensor(y_processed, dtype=torch.float32)
        return_dict['t_processed_full'] = t_binned
        # Raw counts as well; the key name says "lagged" but no lag is applied here
        return_dict['X_raw_counts_lagged'] = spikes_b_raw

    return return_dict

# %%
import random
import time
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from scipy.interpolate import interp1d
import numpy as np
import pandas as pd

class CausalBatcher:
    """
    Creates batches of contiguous sequences from time-series data.
    Supports overlapping windows for training data augmentation.

    Windows of ``sequence_length`` timesteps start every ``stride`` timesteps.
    Windows are grouped ``batch_size`` at a time; windows that do not fill a
    complete final batch are dropped.  With ``shuffle=True`` the order of the
    windows (not the timesteps inside a window) is permuted with
    ``np.random.shuffle`` at construction and again at the start of every
    iteration.

    Args:
        spike_data: Array-like or tensor whose first axis is time.
        targets: Array-like or tensor aligned with ``spike_data`` on the first
            axis; must cover at least as many timesteps as ``spike_data``.
        batch_size: Number of windows per batch (positive).
        sequence_length: Number of timesteps per window (positive).
        shuffle: Whether to permute the window order.
        stride: Step between window starts; defaults to ``sequence_length``
            (non-overlapping windows).

    Raises:
        ValueError: If a size argument is not positive, if ``targets`` is
            shorter than ``spike_data``, or if not even one full batch can be
            formed.
    """
    def __init__(self, spike_data, targets, batch_size=64, sequence_length=20, shuffle=False, stride=None):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.shuffle = shuffle
        self.stride = stride if stride is not None else sequence_length # Non-overlapping by default

        # Reject sizes that would otherwise surface as ZeroDivisionError or
        # nonsensical window counts further down.
        if self.batch_size <= 0 or self.sequence_length <= 0 or self.stride <= 0:
            raise ValueError(
                f"batch_size, sequence_length and stride must be positive, got "
                f"batch_size={self.batch_size}, sequence_length={self.sequence_length}, stride={self.stride}."
            )

        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)

        # Windows are indexed from spike_data; shorter targets would yield
        # truncated target windows and a torch.stack failure mid-epoch.
        num_total_timesteps = self.spike_data.shape[0]
        if self.targets.shape[0] < num_total_timesteps:
            raise ValueError(
                f"targets has {self.targets.shape[0]} timesteps but spike_data has {num_total_timesteps}."
            )

        # Calculate how many full sequences we can make
        if num_total_timesteps < self.sequence_length:
            self.num_sequences = 0
        else:
            self.num_sequences = (num_total_timesteps - self.sequence_length) // self.stride + 1

        # Calculate how many full batches we can make
        self.num_batches = self.num_sequences // self.batch_size

        if self.num_batches == 0:
            raise ValueError(
                f"Not enough data to create a single batch. "
                f"Have {self.num_sequences} sequences for seq_len={sequence_length} and stride={self.stride}, but batch_size is {self.batch_size}."
            )

        # Defined here so that next() works even before iter() has been called.
        self.current_batch = 0

        self.indices = np.arange(self.num_sequences)
        if self.shuffle:
            np.random.shuffle(self.indices)

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 copy of *data* as a tensor."""
        if torch.is_tensor(data):
            # torch.tensor(tensor) warns; detach().clone() is the documented
            # way to copy an existing tensor.
            return data.detach().clone().to(dtype=torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of full batches per pass."""
        return self.num_batches

    def __iter__(self):
        """Restart from the first batch (reshuffling if enabled) and return self."""
        self.current_batch = 0
        if self.shuffle: # Reshuffle for each epoch
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        """Return the next ``(spikes, targets)`` pair, each ``[batch, seq, ...]``."""
        if self.current_batch >= self.num_batches:
            raise StopIteration

        # Get the indices for the sequences in this batch
        start_idx = self.current_batch * self.batch_size
        end_idx = start_idx + self.batch_size
        batch_seq_indices = self.indices[start_idx:end_idx]

        # Get the actual start timestep for each sequence in the batch
        batch_timestep_starts = (batch_seq_indices * self.stride).tolist()

        # Build the batch of sequences
        batch_spikes = torch.stack([
            self.spike_data[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])
        batch_targets = torch.stack([
            self.targets[ts:ts + self.sequence_length] for ts in batch_timestep_starts
        ])

        self.current_batch += 1
        return batch_spikes, batch_targets

class OnlineDataStream:
    """
    Creates a stream of single timestep samples for true online learning.
    Maintains proper chronological order and supports trial boundaries.

    Iterating yields ``(x_t, y_t, should_reset)`` for every timestep in order.
    ``should_reset`` is True on a sample once ``reset_every_n`` samples have
    been yielded since the start of the pass or since the previous reset
    (i.e. at indices n, 2n, ...); it is always False when ``reset_every_n``
    is None or 0.

    Args:
        spike_data: Array-like or tensor whose first axis is time.
        targets: Array-like or tensor aligned with ``spike_data`` on the first
            axis; must cover at least as many timesteps as ``spike_data``.
        reset_every_n: Period, in timesteps, of the reset flag.

    Raises:
        ValueError: If ``targets`` has fewer timesteps than ``spike_data``.
    """
    def __init__(self, spike_data, targets, reset_every_n=None):
        self.spike_data = self._as_float_tensor(spike_data)
        self.targets = self._as_float_tensor(targets)

        # The stream length follows spike_data; catch short targets here
        # instead of failing with an IndexError in the middle of a pass.
        if self.targets.shape[0] < self.spike_data.shape[0]:
            raise ValueError(
                f"targets has {self.targets.shape[0]} timesteps but spike_data has {self.spike_data.shape[0]}."
            )

        self.reset_every_n = reset_every_n  # Reset hidden states every N timesteps (e.g., trial boundaries)
        self.current_idx = 0
        self.steps_since_reset = 0

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 copy of *data* as a tensor."""
        if torch.is_tensor(data):
            # torch.tensor(tensor) warns; detach().clone() is the documented
            # way to copy an existing tensor.
            return data.detach().clone().to(dtype=torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of timesteps in the stream."""
        return self.spike_data.shape[0]

    def __iter__(self):
        """Rewind to the first timestep and return self."""
        self.current_idx = 0
        self.steps_since_reset = 0
        return self

    def __next__(self):
        """Return ``(x_t, y_t, should_reset)`` for the next timestep."""
        if self.current_idx >= len(self):
            raise StopIteration

        x_t = self.spike_data[self.current_idx]
        y_t = self.targets[self.current_idx]

        should_reset = False
        if self.reset_every_n and self.steps_since_reset >= self.reset_every_n:
            should_reset = True
            self.steps_since_reset = 0

        self.current_idx += 1
        self.steps_since_reset += 1

        return x_t, y_t, should_reset

class SNNRegression(nn.Module):
    """
    Spiking regression network built from three ``snn.Leaky`` layers.

    Layout: ``fc1`` (+ recurrent ``fc_rec`` fed by the previous step's layer-1
    spikes) -> ``lif1`` -> ``fc2`` -> ``lif2`` -> ``fc3`` -> ``lif3``.
    The output layer is created with ``reset_mechanism="none"`` and its
    membrane potential, not its spikes, is what ``forward`` returns.
    """
    def __init__(self, input_size, hidden_size=128, output_size=2):
        super().__init__()
        surrogate_fn = surrogate.fast_sigmoid()
        half_hidden = hidden_size // 2

        # Layer 1: feedforward drive plus a bias-free recurrent projection.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Layer 2: narrows to half the hidden width.
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Readout layer: membrane is never reset.
        self.fc3 = nn.Linear(half_hidden, output_size)
        readout_cfg = dict(
            beta=0.5,
            spike_grad=surrogate_fn,
            init_hidden=False,
            threshold=1.0,
            reset_mechanism="none",
        )
        self.lif3 = snn.Leaky(**readout_cfg)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise Linear weights; bias -0.1 for fc1/fc2, 0 otherwise."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        if module is self.fc1 or module is self.fc2:
            # Small negative bias on the two hidden projections.
            nn.init.constant_(module.bias, -0.1)
        else:
            nn.init.zeros_(module.bias)

    def forward(self, x, spk1_rec, mem1, mem2, mem3, need_traces: bool = False):
        """
        Run the network over a sequence.

        Args
        ----
        x           : [batch, T, features]
        spk1_rec    : layer-1 spikes from the previous step (recurrent input)
        mem1..mem3  : membrane potentials of the three LIF layers
        need_traces : if True, additionally return per-step detached tensors
                      (feedforward input, recurrent input, layer-1 spikes,
                      layer-2 spikes, and the three updated membranes), each
                      stacked along dim 1.

        Returns
        -------
        ``(out_seq, final_states)`` or ``(out_seq, final_states, traces)``
        where ``out_seq`` stacks the output-layer membrane over time and
        ``final_states`` is ``(spk1_rec, mem1, mem2, mem3)`` after the last step.
        """
        _, n_steps, _ = x.shape
        readout = []
        trace_log = [] if need_traces else None

        for step in range(n_steps):
            frame = x[:, step, :]

            # All carried state is detached at the start of every step, so
            # gradients never propagate across timesteps.
            mem1, mem2, mem3 = mem1.detach(), mem2.detach(), mem3.detach()
            spk1_rec = spk1_rec.detach()
            prev_spk1 = spk1_rec

            # Layer 1: feedforward input plus recurrent spikes from the last step.
            spk1, mem1 = self.lif1(self.fc1(frame) + self.fc_rec(prev_spk1), mem1)
            spk1_rec = spk1

            # Layer 2.
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            # Readout: keep the membrane, ignore the spikes.
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

            if trace_log is not None:
                trace_log.append((
                    frame.detach(),
                    prev_spk1.detach(),
                    spk1.detach(),
                    spk2.detach(),
                    mem1.detach(),
                    mem2.detach(),
                    mem3.detach(),
                ))

        out_seq = torch.stack(readout, dim=1)
        final_states = (spk1_rec, mem1, mem2, mem3)

        if trace_log is None:
            return out_seq, final_states

        # Transpose the per-step tuples into seven per-quantity time series.
        traces = tuple(torch.stack(series, dim=1) for series in zip(*trace_log))
        return out_seq, final_states, traces


def compute_correlation(pred, target):
    """
    Compute the Pearson correlation between predicted and target values.

    Accepts tensors, numpy arrays or plain sequences.  Multi-dimensional
    inputs are flattened first, so all channels are pooled into a single
    correlation.

    Returns:
        The correlation coefficient as a Python float.  Returns 0.0 when the
        correlation is undefined: fewer than two samples, inputs of different
        sizes, or (near-)zero variance in either input.  NaN values in the
        inputs propagate to a NaN result.
    """
    def _flatten(values):
        # Tensors may live on another device or carry autograd history.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.float64).ravel()

    pred = _flatten(pred)
    target = _flatten(target)

    # Undefined cases: nothing to correlate, or inputs that cannot be paired.
    if pred.size < 2 or pred.size != target.size:
        return 0.0

    # Check for zero variance (corrcoef would divide by zero)
    if np.std(pred) < 1e-10 or np.std(target) < 1e-10:
        return 0.0

    # For two 1-D inputs corrcoef always returns a 2x2 matrix.
    return float(np.corrcoef(pred, target)[0, 1])

def compute_fvaf(y_true, y_pred):
    """
    Fraction of Variance Accounted For: ``1 - SS_res / SS_tot`` (the R^2 score).

    Tensors are converted to numpy first.  Inputs are expected to be 1-D; if
    either has more than one dimension a RuntimeWarning is issued and both are
    flattened.  When the target has (near-)zero variance the result is 1.0
    for a (near-)perfect prediction and 0.0 otherwise.
    """
    truth = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    guess = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    if max(truth.ndim, guess.ndim) > 1:
        warnings.warn(f"FVAF expects 1D arrays, but got shapes {truth.shape} and {guess.shape}. Flattening.", RuntimeWarning)
        truth, guess = truth.flatten(), guess.flatten()

    residual_ss = np.square(truth - guess).sum()
    total_ss = np.square(truth - np.mean(truth)).sum()

    # Constant target: the ratio is undefined, fall back to an all-or-nothing score.
    target_is_constant = total_ss < 1e-10
    if target_is_constant:
        if residual_ss < 1e-10:
            return 1.0
        return 0.0

    return 1 - (residual_ss / total_ss)

def butter_lowpass(data, fs=1/0.05, cutoff=8, order=4):
    """
    Zero-phase Butterworth low-pass filter applied along axis 0 of *data*.

    ``fs`` and ``cutoff`` are in the same frequency units; the filter is
    designed as second-order sections and run forward and backward.
    """
    sections = butter(N=order, Wn=cutoff, btype='low', fs=fs, output='sos')
    filtered = sosfiltfilt(sections, data, axis=0)
    return filtered

# %%
##########################################
# TWO-TIMESCALE META RL UPDATER WITH META-LEARNING
##########################################
class TwoScaleMetaRLWeightUpdaterFull:
    """
    This updater adapts all layers using configurable timescales and incorporates meta-learning.
    It uses eligibility traces to integrate credit assignment over time, 
    providing memory of recent activity without acausally looking into the future. 
    The learning rate is modulated by a biologically-inspired reward signal based on performance.
    
    This approach integrates Hebbian principles with the needs of the BCI task.
    """
    def __init__(self, model, base_fast_lr=1e-3, base_slow_lr=1e-2, window_size=180, 
                 meta_lr=1e-3, tau_e_fast=0.12, tau_e_slow=0.7, online_mode=False,
                 use_hebbian=True,  # ABLATION: Enable/disable Hebbian learning
                 use_meta_learning=True,  # ABLATION: Enable/disable meta-learning
                 timescale_mode='dual',  # ABLATION: 'fast_only', 'slow_only', 'dual', 'medium_only'
                 use_fast_update=True,  # NEW ABLATION: Enable/disable fast updates
                 use_slow_update=True,  # NEW ABLATION: Enable/disable slow updates
                 **kwargs):
        self.model = model
        self.device = next(model.parameters()).device
        self.base_fast_lr = base_fast_lr
        self.base_slow_lr = base_slow_lr
        self.meta_lr = meta_lr
        self.grad_scale = 0
        self.online_mode = online_mode
        self.use_hebbian = use_hebbian  # FIX: Use the parameter, not hardcoded False!
        self.use_meta_learning = use_meta_learning  # Meta-learning ablation parameter
        self.timescale_mode = timescale_mode  # Timescale ablation parameter
        self.use_fast_update = use_fast_update  # NEW: Fast update ablation parameter
        self.use_slow_update = use_slow_update  # NEW: Slow update ablation parameter
        
        # Print ablation status for clarity
        learning_type = "Hebbian (error × d_lif × activity)" if self.use_hebbian else "Delta Rule (error × activity)"
        meta_status = "WITH meta-learning" if self.use_meta_learning else "WITHOUT meta-learning (fixed LR)"
        timescale_status = f"Timescale mode: {timescale_mode.upper()}"
        update_status = f"Updates: Fast={use_fast_update}, Slow={use_slow_update}"
        print(f"Initializing updater with {learning_type} {meta_status} {timescale_status} {update_status}")

        # RMS accumulators for error and spike normalization (Tweaks A & B)
        # Use per-unit, integer-based EMAs for hardware friendliness
        self.err2_sq_ema = torch.ones(model.fc2.out_features, dtype=torch.int32, device=self.device)
        self.err1_sq_ema = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
        self.spk1_rms    = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
        
        # Meta-learning parameters - only used if use_meta_learning=True
        self.meta_params = {
            'plasticity': 1.0, 
            'sensitivity': 1.0
        } 
        
        # Set learning rates based on meta-learning mode
        if self.use_meta_learning:
            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']
        else:
            # Fixed learning rates when meta-learning is disabled
            self.fast_lr = self.base_fast_lr
            self.slow_lr = self.base_slow_lr

        # --- Hardware-friendly LUTs ---
        # Passed in from the main training script
        self.inv_sqrt_LUT = kwargs.get('inv_sqrt_LUT', None)
        if self.inv_sqrt_LUT is None:
            raise ValueError("inv_sqrt_LUT must be provided to the updater.")
        # LUT for error-bucket based reward scaling
        self.reward_LUT = torch.tensor(
       [15,14,13,12,11,10,9,8,8,8,7,7,6,5,4,3], device=self.device)
        self.out_rms = torch.ones(2, dtype=torch.int32, device=self.device)  # vx, vy
        self.fc3_row_cap_q12 = torch.full((2, 1),
                                     int(6.0*4096),
                                     dtype=torch.int32,
                                     device=self.device)

        # --- Configurable Timescale Eligibility Trace Parameters ---
        # Set up timescales based on mode
        if timescale_mode == 'fast_only':
            # Only fast traces (τ=80ms)
            if online_mode:
                self.decay_fast = math.exp(-0.064 / (tau_e_fast * 0.5))
            else:
                self.decay_fast = math.exp(-0.064 / tau_e_fast)
            self.decay_slow = None
            self.trace_mix_a = 1.0  # 100% fast traces
        elif timescale_mode == 'slow_only':
            # Only slow traces (τ=800ms)
            if online_mode:
                self.decay_slow = math.exp(-0.064 / (tau_e_slow * 0.8))
            else:
                self.decay_slow = math.exp(-0.064 / tau_e_slow)
            self.decay_fast = None
            self.trace_mix_a = 0.0  # 100% slow traces
        elif timescale_mode == 'medium_only':
            # Medium timescale only (τ=400ms, middle ground)
            tau_e_medium = (tau_e_fast + tau_e_slow) / 2  # 400ms
            if online_mode:
                self.decay_fast = math.exp(-0.064 / (tau_e_medium * 0.65))
            else:
                self.decay_fast = math.exp(-0.064 / tau_e_medium)
            self.decay_slow = None
            self.trace_mix_a = 1.0  # Use "fast" traces for medium timescale
        elif timescale_mode == 'dual':
            # Original dual timescale system
            if online_mode:
                self.decay_fast = math.exp(-0.064 / (tau_e_fast * 0.5))  # Faster decay
                self.decay_slow = math.exp(-0.064 / (tau_e_slow * 0.8))  # Slightly faster decay
                self.trace_mix_a = 0.8  # More weight on fast traces for responsiveness
            else:
                self.decay_fast = math.exp(-0.064 / tau_e_fast)
                self.decay_slow = math.exp(-0.064 / tau_e_slow)
                self.trace_mix_a = 0.5  # Fixed mixing parameter for fast and slow traces
        else:
            raise ValueError(f"Unknown timescale_mode: {timescale_mode}")

        # Initialize eligibility traces based on mode
        self.e_fast_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_fast_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_fast_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_fast_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        
        # Only create slow traces if needed
        if self.decay_slow is not None:
            self.e_slow_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
            self.e_slow_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
            self.e_slow_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
            self.e_slow_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        else:
            self.e_slow_fc1 = self.e_slow_fc2 = self.e_slow_fc3 = self.e_slow_rec = None
        # --- End Eligibility Trace ---

        self.window_size = window_size
        self.step_counter = 0
        self.cumulative_error = 0
        self.prev_cumulative_error = None
        
        self.max_loss_value = 5.0
        
        # Initialize gradient averages with zeros matching model layer shapes and device
        self.grad_fc1_avg = torch.zeros_like(self.model.fc1.weight.data, device=self.device)
        self.grad_fc2_avg = torch.zeros_like(self.model.fc2.weight.data, device=self.device)
        self.grad_fc3_avg = torch.zeros_like(self.model.fc3.weight.data, device=self.device)
        self.grad_fc_rec_avg = torch.zeros_like(self.model.fc_rec.weight.data, device=self.device) # RESTORED
        
        # Initialize momentum for gradient accumulation
        self.momentum = 0.9

    def reset_weights_with_xavier(self):
        """
        Reset weights using Xavier/Glorot initialization for better convergence
        after catastrophic errors.
        """
        with torch.no_grad():
            for module in [self.model.fc1, self.model.fc2, self.model.fc3, self.model.fc_rec]: # RESTORED self.model.fc_rec
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        print("Weights reset using Xavier/Glorot initialization for better convergence")

    def warm_start_rms_emas(self, warmup_data_x, warmup_data_y, num_warmup_steps=100):
        """
        Warm-start the RMS EMAs with some initial data to avoid cold-start issues in online mode.
        """
        if not self.online_mode:
            return
            
        print(f"Warm-starting RMS EMAs with {num_warmup_steps} steps...")
        with torch.no_grad():
            for i in range(min(num_warmup_steps, len(warmup_data_x))):
                x_t = warmup_data_x[i:i+1].to(self.device)  # [1, features]
                y_t = warmup_data_y[i:i+1].to(self.device)  # [1, 2]
                
                # Quick forward pass to get some initial error statistics
                x_t_seq = x_t.unsqueeze(1)  # [1, 1, features]
                y_t_seq = y_t.unsqueeze(1)  # [1, 1, 2]
                
                # Initialize temporary states
                batch_size = 1
                spk1_rec = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem1 = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem2 = torch.zeros(batch_size, self.model.fc2.out_features, device=self.device)
                mem3 = torch.zeros(batch_size, self.model.fc3.out_features, device=self.device)
                
                pred_seq, _, _ = self.model(x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True)
                output_error = y_t_seq - pred_seq
                
                # Update output RMS - need to sum over batch AND sequence dims to get [2] shape
                k_out = 2  # Faster adaptation for warmup
                sq_out = ((output_error**2) * 4096).sum(dim=(0,1)).to(torch.int32)  # Sum over batch and seq dims
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                
        print("RMS EMA warm-start complete.")

    def fast_update_single_timestep(self, x_t, y_t, states):
        """
        Apply one online plasticity step for a single timestep.

        Runs the model on one sample, converts the output error into
        per-layer error signals, folds the resulting Hebbian (or delta-rule)
        terms into the eligibility traces and immediately applies a weight
        update, followed by row renormalisation of all four weight matrices.

        The whole computation runs under ``torch.no_grad()``: the weights are
        updated manually and the returned states are fed back into the next
        call, so building autograd history here would only chain every
        timestep's graph onto the previous one.

        Args:
            x_t: input for this timestep. Two leading singleton dimensions
                are added before the forward pass (presumably a 1-D
                ``[features]`` tensor - TODO confirm against the data stream).
            y_t: target for this timestep, reshaped the same way as ``x_t``.
            states: 4-tuple ``(spk1_rec, mem1, mem2, mem3)`` carried over
                from the previous timestep.

        Returns:
            ``(loss, new_states)``: the MSE of this timestep clamped to
            ``[0, 10]`` as a Python float, and the states returned by the
            model's forward pass.

        Raises:
            ValueError: if the updater is not in online mode, or if
                ``timescale_mode`` is not one of ``'fast_only'``,
                ``'medium_only'``, ``'slow_only'`` or ``'dual'``.
        """
        if not self.online_mode:
            raise ValueError("Single timestep update only available in online mode")
        # Fail before any trace/EMA is touched instead of with an
        # UnboundLocalError halfway through the update.
        if self.timescale_mode not in ('fast_only', 'medium_only', 'slow_only', 'dual'):
            raise ValueError(f"Unsupported timescale_mode: {self.timescale_mode!r}")

        # Reshape for model forward pass: [batch=1, seq=1, features]
        x_t_seq = x_t.unsqueeze(0).unsqueeze(0)  # [1, 1, features]
        y_t_seq = y_t.unsqueeze(0).unsqueeze(0)  # [1, 1, 2]

        spk1_rec, mem1, mem2, mem3 = states

        with torch.no_grad():
            # Single forward pass to get prediction and traces
            pred_seq, final_states, traces = self.model(
                x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True
            )
            (pre_ff_all, pre_rec_all, spk1_all, spk2_all,
             mem1_next_all, mem2_next_all, mem3_next_all) = traces

            # Loss for this single timestep, clamped to [0, 10]
            combined_loss = torch.clamp(F.mse_loss(pred_seq, y_t_seq), 0.0, 10.0)

            # Pick the only timestep out of each trace
            pre_ff_t = pre_ff_all[:, 0, :]
            pre_rec_t = pre_rec_all[:, 0, :]
            spk1_t = spk1_all[:, 0, :]
            spk2_t = spk2_all[:, 0, :]
            mem1_next = mem1_next_all[:, 0, :]
            mem2_next = mem2_next_all[:, 0, :]
            mem3_next = mem3_next_all[:, 0, :]

            # Single timestep error
            output_error_t = y_t_seq[:, 0, :] - pred_seq[:, 0, :]

            # Reward bucket: the larger of the two per-channel mean absolute
            # errors, quantised to 0..15, indexes the learning-rate LUT.
            abs_err_xy = torch.abs(output_error_t).mean(dim=0)
            bucket_xy = (abs_err_xy * 8).clamp(0, 15).to(torch.int64)
            bucket = torch.max(bucket_xy)
            lr_scale = self.reward_LUT[bucket]

            # Scale learning rate for single timestep (was calibrated for seq_len=10)
            effective_lr = self.fast_lr * lr_scale / 16 * 0.1

            # EMA shift constants; smaller shift = faster adaptation
            k_out = 2  # batch mode uses 4
            k = 3      # batch mode uses 5

            # Normalize output error per-channel. NOTE: the order of the
            # EMA updates below matters, each lookup uses the freshly
            # updated statistic.
            sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
            self.out_rms -= self.out_rms >> k_out
            self.out_rms += sq_out >> k_out
            idx_out = (self.out_rms >> 8).clamp_(0, 255)
            norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]

            # Propagate the error back through fc3 and fc2, rescaling each
            # stage with its own running statistic.
            err2_raw = torch.matmul(norm_err, self.model.fc3.weight)
            sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32)
            self.err2_sq_ema -= self.err2_sq_ema >> k
            self.err2_sq_ema += sq_err2 >> k
            idx = (self.err2_sq_ema >> 8).clamp_(0, 255)
            hidden2_error_t = err2_raw * self.inv_sqrt_LUT[idx]

            err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
            self.err1_sq_ema -= self.err1_sq_ema >> k
            self.err1_sq_ema += sq_err1 >> k
            idx = (self.err1_sq_ema >> 8).clamp_(0, 255)
            hidden1_error_t = err1_raw * self.inv_sqrt_LUT[idx]

            # Local sensitivities
            d_lif3_t = self.model.lif3.spike_grad(mem3_next)
            d_lif2_t = self.model.lif2.spike_grad(mem2_next)
            d_lif1_t = self.model.lif1.spike_grad(mem1_next)

            # Normalize spike activities (shift of 2 for the running statistic)
            sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
            self.spk1_rms += -(self.spk1_rms >> 2) + (sq_spk1 >> 2)
            idx = (self.spk1_rms >> 8).clamp_(0, 255)
            spk1_t = spk1_t * self.inv_sqrt_LUT[idx]

            # Outer products of post-synaptic error and pre-synaptic activity
            if self.use_hebbian:
                # Hebbian variant: error gated by the local sensitivities
                hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
            else:
                # Delta rule variant: no sensitivity gating
                hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)

            # Decay-and-accumulate the eligibility traces that are enabled
            if self.decay_fast is not None:
                self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
                self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
                self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
                self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)

            if self.decay_slow is not None:
                self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
                self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
                self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
                self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

            # Choose which traces drive the update (mode validated above)
            if self.timescale_mode == 'slow_only':
                e_comb_fc1 = self.e_slow_fc1
                e_comb_fc2 = self.e_slow_fc2
                e_comb_fc3 = self.e_slow_fc3
                e_comb_rec = self.e_slow_rec
            elif self.timescale_mode == 'dual':
                e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
                e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
                e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
                e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec
            else:  # 'fast_only' or 'medium_only'
                e_comb_fc1 = self.e_fast_fc1
                e_comb_fc2 = self.e_fast_fc2
                e_comb_fc3 = self.e_fast_fc3
                e_comb_rec = self.e_fast_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec

            # Accumulate for slow updates
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize weights
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)

        return combined_loss.item(), final_states
    
    def update_single_timestep(self, x_t, y_t, states):
        """
        Main interface for online single-timestep updates.

        Moves the sample to the updater's device, runs either the fast
        plasticity step or a plain evaluation forward pass, performs the
        catastrophic-loss check and drives the slow-update window counters.

        Args:
            x_t: input for this timestep (two leading singleton dimensions
                are added before it is handed to the model).
            y_t: target for this timestep.
            states: 4-tuple ``(spk1_rec, mem1, mem2, mem3)`` from the
                previous timestep.

        Returns:
            ``(bounded_loss, new_states)`` where ``bounded_loss`` is the
            timestep loss capped at ``self.max_loss_value``.
        """
        x_t, y_t = x_t.to(self.device), y_t.to(self.device)

        if self.use_fast_update:
            loss, new_states = self.fast_update_single_timestep(x_t, y_t, states)
        else:
            # Evaluation only: no weights change, so no autograd history is
            # needed, and the returned states must not drag a graph along
            # into the next timestep.
            with torch.no_grad():
                x_t_seq = x_t.unsqueeze(0).unsqueeze(0)
                y_t_seq = y_t.unsqueeze(0).unsqueeze(0)
                spk1_rec, mem1, mem2, mem3 = states
                pred_seq, new_states = self.model(x_t_seq, spk1_rec, mem1, mem2, mem3)
                loss = torch.clamp(F.mse_loss(pred_seq, y_t_seq), 0.0, 10.0).item()

        bounded_loss = min(loss, self.max_loss_value)

        # Compare the *unbounded* loss: the bounded value can never exceed
        # max_loss_value, so testing it against 1.5 * max_loss_value made
        # this branch unreachable. The loss is already clamped to 10.0
        # upstream, so the reset can only fire when
        # 1.5 * max_loss_value < 10.
        if loss > self.max_loss_value * 1.5:
            print(f"CATASTROPHIC ERROR DETECTED: {loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            return bounded_loss, new_states

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            if self.use_slow_update:
                self.slow_update()
            else:
                # Reset counters even if slow update is disabled
                self.step_counter = 0
                self.cumulative_error = 0

        return bounded_loss, new_states
              
    def fast_update(self, input_sequence, desired_sequence, initial_states):
        """
        Eligibility-trace plasticity step for a whole SEQUENCE.

        One forward pass produces the predictions and the per-timestep
        activity traces. For every timestep the output error is normalised,
        pushed back through fc3 and fc2, and the resulting Hebbian (or
        delta-rule) terms are folded into the eligibility traces. The
        weights are updated once, from the traces as they stand after the
        last timestep, and then row-renormalised.

        Eligibility traces are NOT reset here; they persist and decay across
        calls. Everything runs under ``torch.no_grad()`` because the weights
        are updated manually and nothing is ever back-propagated.

        Args:
            input_sequence: 3-D input tensor; the middle dimension is
                iterated as time.
            desired_sequence: targets, indexed the same way as the model's
                predictions.
            initial_states: states unpacked positionally into the model call
                (the callers in this file pass
                ``(spk1_rec, mem1, mem2, mem3)``).

        Returns:
            ``(loss, x_corr, y_corr)``: the sequence MSE clamped to
            ``[0, 10]`` as a Python float, followed by two placeholder
            zeros.

        Raises:
            ValueError: if ``timescale_mode`` is not one of ``'fast_only'``,
                ``'medium_only'``, ``'slow_only'`` or ``'dual'``.
        """
        # Fail before any trace/EMA is touched instead of with an
        # UnboundLocalError after the time loop.
        if self.timescale_mode not in ('fast_only', 'medium_only', 'slow_only', 'dual'):
            raise ValueError(f"Unsupported timescale_mode: {self.timescale_mode!r}")

        _, seq_len, _ = input_sequence.shape

        # EMA shift constants (loop-invariant)
        k_out = 4  # output-error statistic: EMA 1/16
        k = 5      # hidden-error statistics: EMA 2^-5 = 0.03125

        with torch.no_grad():
            # SINGLE forward pass to get predictions AND traces
            pred_sequence, _, traces = self.model(
                input_sequence, *initial_states, need_traces=True
            )
            (pre_ff_all, pre_rec_all, spk1_all, spk2_all,
             mem1_next_all, mem2_next_all, mem3_next_all) = traces

            # Loss over the whole sequence, clamped to [0, 10]
            combined_loss = torch.clamp(F.mse_loss(pred_sequence, desired_sequence), 0.0, 10.0)

            # 4-bit error bucket reward: the larger of the two per-channel
            # mean absolute errors selects the learning-rate scale.
            abs_err_xy = torch.mean(torch.abs(desired_sequence - pred_sequence), dim=(0, 1))
            bucket_xy = (abs_err_xy * 8).clamp(0, 15).to(torch.int64)
            bucket = torch.max(bucket_xy)
            lr_scale = self.reward_LUT[bucket]
            # True division by 16; computed once for the whole sequence
            effective_lr = self.fast_lr * lr_scale / 16

            for t in range(seq_len):
                # Activations for this timestep from the pre-computed traces
                pre_ff_t = pre_ff_all[:, t, :]
                pre_rec_t = pre_rec_all[:, t, :]
                spk1_t = spk1_all[:, t, :]
                spk2_t = spk2_all[:, t, :]
                mem1_next = mem1_next_all[:, t, :]
                mem2_next = mem2_next_all[:, t, :]
                mem3_next = mem3_next_all[:, t, :]

                output_error_t = desired_sequence[:, t, :] - pred_sequence[:, t, :]

                # Normalize output error per-channel. NOTE: the order of the
                # EMA updates matters, each lookup uses the freshly updated
                # statistic.
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0, 255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]

                # Propagate the error back through fc3 and fc2, rescaling
                # each stage with its own running statistic.
                err2_raw = torch.matmul(norm_err, self.model.fc3.weight)
                sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err2_sq_ema -= self.err2_sq_ema >> k
                self.err2_sq_ema += sq_err2 >> k
                idx = (self.err2_sq_ema >> 8).clamp_(0, 255)
                hidden2_error_t = err2_raw * self.inv_sqrt_LUT[idx]

                err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err1_sq_ema -= self.err1_sq_ema >> k
                self.err1_sq_ema += sq_err1 >> k
                idx = (self.err1_sq_ema >> 8).clamp_(0, 255)
                hidden1_error_t = err1_raw * self.inv_sqrt_LUT[idx]

                # Local sensitivities
                d_lif3_t = self.model.lif3.spike_grad(mem3_next)
                d_lif2_t = self.model.lif2.spike_grad(mem2_next)
                d_lif1_t = self.model.lif1.spike_grad(mem1_next)

                # Rescale spike activities (shift of 3 for the running statistic)
                sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.spk1_rms += -(self.spk1_rms >> 3) + (sq_spk1 >> 3)
                idx = (self.spk1_rms >> 8).clamp_(0, 255)
                spk1_t = spk1_t * self.inv_sqrt_LUT[idx]

                # Outer products of post-synaptic error and pre-synaptic activity
                if self.use_hebbian:
                    # Hebbian variant: error gated by the local sensitivities
                    hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                    hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                    hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                    hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
                else:
                    # Delta rule variant: no sensitivity gating
                    hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                    hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                    hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                    hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)

                # Decay-and-accumulate the eligibility traces that are enabled
                if self.decay_fast is not None:
                    self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
                    self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
                    self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
                    self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)

                if self.decay_slow is not None:
                    self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
                    self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
                    self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
                    self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

            # Choose which traces drive the update (mode validated above)
            if self.timescale_mode == 'slow_only':
                e_comb_fc1 = self.e_slow_fc1
                e_comb_fc2 = self.e_slow_fc2
                e_comb_fc3 = self.e_slow_fc3
                e_comb_rec = self.e_slow_rec
            elif self.timescale_mode == 'dual':
                e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
                e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
                e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
                e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec
            else:  # 'fast_only' or 'medium_only'
                e_comb_fc1 = self.e_fast_fc1
                e_comb_fc2 = self.e_fast_fc2
                e_comb_fc3 = self.e_fast_fc3
                e_comb_rec = self.e_fast_rec

            # Apply updates using the final state of the eligibility traces
            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec

            # Accumulate combined traces for the slow update path
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize all weights after every fast step to prevent explosion/drift
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)

        # Correlations are not computed here; zeros keep the 3-tuple shape
        # callers unpack.
        x_corr = 0
        y_corr = 0
        return combined_loss.item(), x_corr, y_corr

    def _rms(self, x, eps=1e-5):
        """Return ``x`` divided by its root-mean-square over all elements.

        ``eps`` is added to the RMS before dividing so an all-zero tensor
        does not divide by zero.
        """
        root_mean_square = x.pow(2).mean().sqrt()
        return x / (root_mean_square + eps)

    def renorm_row_(self, w):
        """
        Shrink, in place, every row of ``w`` whose norm statistic exceeds its cap.

        The per-row statistic is ``sum(int32(w) * w)``; one factor is
        truncated to an integer, so entries with magnitude below 1
        contribute nothing (NOTE(review): this looks designed for
        integer-valued fixed-point weights - confirm). A matrix with exactly
        two rows is compared against ``self.fc3_row_cap_q12``; any other
        matrix uses a cap of ``6.0 * 4096`` per row.

        Each offending row is converted to int32 and right-shifted by
        ``ceil(log2(stat) - log2(cap))`` bits (a power-of-two divide). Rows
        within their cap are left untouched.

        Args:
            w: 2-D weight tensor or parameter; modified through ``w.data``.
        """
        l2 = torch.sum(w.to(torch.int32) * w, dim=1, keepdim=True)
        if w.shape[0] == 2:
            caps = self.fc3_row_cap_q12
        else:
            caps = torch.full_like(l2, int(6.0 * 4096))

        mask = l2 > caps
        if not mask.any():
            return

        # Boolean-mask indexing flattens to 1-D: one shift per offending row.
        shift = (torch.log2(l2[mask].float()) - torch.log2(caps[mask].float())
                 ).ceil().to(torch.int32)
        rows = mask.squeeze(1)
        # unsqueeze(1) broadcasts each row's shift across that row's columns.
        shrunk = w.data[rows].to(torch.int32) >> shift.unsqueeze(1)
        # Write back only the shrunk rows so the remaining rows keep their
        # fractional values.
        w.data[rows] = shrunk.to(w.dtype)

    def normalize_weights(self, weights, max_norm=1.0):
        """Cap the L2 norm of every row of ``weights`` at ``max_norm``.

        Rows whose norm is already at or below ``max_norm`` are returned
        (up to the 1e-8 stabiliser) unchanged. Returns a new tensor; the
        input is not modified. Not used by the fast update path.
        """
        row_norms = weights.norm(p=2, dim=1, keepdim=True)
        capped_norms = row_norms.clamp(max=max_norm)
        ratio = capped_norms / (row_norms + 1e-8)
        return weights * ratio

    def slow_update(self):
        """
        Consolidate the accumulated update directions and open a new window.

        Steps, in order:
          1. If any fast steps were accumulated (``grad_scale > 0``), divide
             each accumulator by ``grad_scale`` and add its RMS-normalised
             value, scaled by ``slow_lr``, to the matching weight matrix.
          2. Multiply all four weight matrices by 0.999995 (light decay).
          3. Compute the window's average loss (capped at
             ``max_loss_value``) and, when meta-learning is enabled and a
             previous window exists, scale ``plasticity`` / ``sensitivity``
             up on improvement and down otherwise, clip them to
             ``[0.2, 1.5]`` and refresh ``fast_lr`` / ``slow_lr``.
          4. Zero the accumulators, the window counters and the eligibility
             traces.
        """
        net = self.model
        # (accumulator attribute on self, layer whose weight it drives)
        pairs = (
            ("grad_fc1_avg", net.fc1),
            ("grad_fc2_avg", net.fc2),
            ("grad_fc3_avg", net.fc3),
            ("grad_fc_rec_avg", net.fc_rec),
        )

        with torch.no_grad():
            if self.grad_scale > 0:
                for attr, _ in pairs:
                    accumulator = getattr(self, attr)
                    accumulator /= self.grad_scale
                    setattr(self, attr, accumulator)

                for attr, layer in pairs:
                    layer.weight.data += self.slow_lr * self._rms(getattr(self, attr))

            # Light weight decay on every layer (~1e-5 per slow step)
            for _, layer in pairs:
                layer.weight.mul_(0.999995)

        if self.step_counter > 0:
            window_loss = self.cumulative_error / self.step_counter
        else:
            window_loss = float('inf')
        window_loss = min(window_loss, self.max_loss_value)

        # Meta-learning adaptation; with it disabled the learning rates stay
        # at whatever they currently are.
        if self.use_meta_learning and self.prev_cumulative_error is not None:
            reference_loss = min(self.prev_cumulative_error, self.max_loss_value)
            if window_loss < reference_loss:
                factor = 1 + self.meta_lr
            else:
                factor = 1 - self.meta_lr

            for key in ('plasticity', 'sensitivity'):
                self.meta_params[key] = np.clip(self.meta_params[key] * factor, 0.2, 1.5)

            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        self.prev_cumulative_error = window_loss

        # Start a fresh accumulation window
        for attr, _ in pairs:
            getattr(self, attr).zero_()
        self.grad_scale = 0
        self.cumulative_error = 0
        self.step_counter = 0

        # Clear the eligibility traces; the slow ones are only looked up
        # when a slow decay is configured.
        trace_names = ["e_fast_fc1", "e_fast_fc2", "e_fast_fc3", "e_fast_rec"]
        if self.decay_slow is not None:
            trace_names += ["e_slow_fc1", "e_slow_fc2", "e_slow_fc3", "e_slow_rec"]

        for name in trace_names:
            trace = getattr(self, name)
            if trace is not None:
                trace.zero_()

    def update(self, input_sequence, desired_sequence, initial_states):
        """
        Main interface for windowed (sequence) updates.

        Moves the batch to the updater's device, obtains the sequence loss
        either from the fast plasticity step or, when that is disabled, from
        a gradient-free evaluation pass, then performs the catastrophic-loss
        check and drives the slow-update window counters.

        Args:
            input_sequence: input batch handed to the model.
            desired_sequence: targets matching the model's predictions.
            initial_states: states unpacked positionally into the model call.

        Returns:
            The sequence loss capped at ``self.max_loss_value``.
        """
        input_sequence = input_sequence.to(self.device)
        desired_sequence = desired_sequence.to(self.device)

        if self.use_fast_update:
            # fast_update already runs the forward pass and reports the
            # loss; a separate evaluation pass would be computed only to be
            # discarded.
            combined_loss, _, _ = self.fast_update(input_sequence, desired_sequence, initial_states)
        else:
            with torch.no_grad():
                pred_sequence, _ = self.model(input_sequence, *initial_states, need_traces=False)
                mse_loss = F.mse_loss(pred_sequence, desired_sequence)
                combined_loss = torch.clamp(mse_loss, 0.0, 10.0).item()

        bounded_loss = min(combined_loss, self.max_loss_value)

        # Compare the *unbounded* loss: the bounded value can never exceed
        # max_loss_value, so testing it against 1.5 * max_loss_value made
        # this branch unreachable. The loss is already clamped to 10.0
        # upstream, so the reset can only fire when
        # 1.5 * max_loss_value < 10.
        if combined_loss > self.max_loss_value * 1.5:
            print(f"CATASTROPHIC ERROR DETECTED: {combined_loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            return bounded_loss

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        if self.step_counter >= self.window_size:
            if self.use_slow_update:
                self.slow_update()
            else:
                # Reset counters even if slow update is disabled
                self.step_counter = 0
                self.cumulative_error = 0

        return bounded_loss


def train_snn_windowed(model, rl_updater, train_loader, val_loader, num_epochs, patience=10):
    """
    Train the SNN on windowed sequences with early stopping on validation correlation.

    Every training batch starts from all-zero hidden states and is handed to
    ``rl_updater.update``. After each epoch the model is evaluated on
    ``val_loader`` using the prediction at the final timestep of every
    window; the mean of the x and y correlations is the model-selection
    metric. The best-scoring weights are restored before returning.

    Args:
        model: network exposing ``fc1``, ``fc2`` and ``fc3`` with an
            ``out_features`` attribute, callable as
            ``model(input_seq, spk1_rec, mem1, mem2, mem3)``.
        rl_updater: object whose ``update(input_seq, target_seq, states)``
            performs the weight update and returns the batch loss.
        train_loader: iterable of ``(input_seq, target_seq)`` batches.
        val_loader: iterable of ``(input_seq, target_seq)`` batches; when it
            yields nothing the epoch is logged and model selection is
            skipped.
        num_epochs: maximum number of epochs.
        patience: number of consecutive epochs without improvement after
            which training stops.

    Returns:
        ``model``, loaded with the best validation weights if any epoch
        produced a validation score.
    """
    device = next(model.parameters()).device
    print(f"Starting SNN Causal Windowed Training for {num_epochs} epochs on {device}...")

    def zero_states(batch_size):
        # Fresh (spk1_rec, mem1, mem2, mem3) for a batch; spk1_rec and mem1
        # share fc1's width.
        widths = (
            model.fc1.out_features,
            model.fc1.out_features,
            model.fc2.out_features,
            model.fc3.out_features,
        )
        return tuple(torch.zeros(batch_size, width, device=device) for width in widths)

    best_val_metric = -np.inf  # correlation can be negative
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        num_batches = 0

        for input_seq, target_seq in train_loader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            initial_states = zero_states(input_seq.size(0))
            running_train_loss += rl_updater.update(input_seq, target_seq, initial_states)
            num_batches += 1

        # An empty loader yields NaN instead of a ZeroDivisionError.
        avg_train_loss = running_train_loss / num_batches if num_batches else float("nan")

        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for input_seq, target_seq in val_loader:
                input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                pred_seq, _ = model(input_seq, *zero_states(input_seq.size(0)))
                # Use final time-step to align with read-out
                val_preds.append(pred_seq[:, -1, :].cpu())
                val_targets.append(target_seq[:, -1, :].cpu())

            if not val_preds:
                print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: N/A (empty val set)")
                continue

            val_preds = torch.cat(val_preds)
            val_targets = torch.cat(val_targets)
            val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
            val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
            avg_val_corr = (val_corr_x + val_corr_y) / 2.0

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

        if avg_val_corr > best_val_metric:
            best_val_metric = avg_val_corr
            # Snapshot on CPU so later in-place weight updates cannot alter it
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    if best_model_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_model_state.items()})
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

def train_snn_online(model, rl_updater, train_stream, val_stream, num_epochs, patience=10, 
                     eval_every_n=500, reset_states_every_n=None):
    """
    Train SNN with true online learning - one timestep at a time.
    Maintains persistent hidden states and applies proper online adaptations.
    
    Args:
        model: SNN model to train
        rl_updater: Online-enabled updater (must have online_mode=True)
        train_stream: OnlineDataStream for training data
        val_stream: OnlineDataStream for validation data  
        num_epochs: Number of epochs to train
        patience: Early stopping patience (number of evaluation intervals without improvement)
        eval_every_n: Evaluate on validation every N steps
        reset_states_every_n: Reset hidden states every N steps (None = never reset)
    """
    if not rl_updater.online_mode:
        raise ValueError("rl_updater must be in online_mode for online training")
        
    device = next(model.parameters()).device
    print(f"Starting SNN Online Training for {num_epochs} epochs on {device}...")
    print(f"Evaluation every {eval_every_n} steps, reset states every {reset_states_every_n or 'never'} steps")
    
    best_val_metric = -np.inf
    best_model_state = None
    epochs_no_improve = 0
    
    # Initialize persistent hidden states - these persist across timesteps!
    spk1_rec = torch.zeros(1, model.fc1.out_features, device=device)
    mem1 = torch.zeros(1, model.fc1.out_features, device=device)
    mem2 = torch.zeros(1, model.fc2.out_features, device=device)
    mem3 = torch.zeros(1, model.fc3.out_features, device=device)
    states = (spk1_rec, mem1, mem2, mem3)
    
    # Warm-start RMS EMAs if we have training data
    if hasattr(train_stream, 'spike_data') and len(train_stream.spike_data) > 100:
        rl_updater.warm_start_rms_emas(
            train_stream.spike_data[:100], 
            train_stream.targets[:100], 
            num_warmup_steps=50
        )
    
    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        step_count = 0
        
        for x_t, y_t, should_reset in train_stream:
            # Reset states if requested (e.g., trial boundaries)
            if should_reset or (reset_states_every_n and step_count % reset_states_every_n == 0):
                spk1_rec = torch.zeros(1, model.fc1.out_features, device=device)
                mem1 = torch.zeros(1, model.fc1.out_features, device=device)
                mem2 = torch.zeros(1, model.fc2.out_features, device=device)
                mem3 = torch.zeros(1, model.fc3.out_features, device=device)
                states = (spk1_rec, mem1, mem2, mem3)
                
            # Single timestep update with persistent states
            loss, states = rl_updater.update_single_timestep(x_t, y_t, states)
            running_train_loss += loss
            step_count += 1
            
            # Periodic validation
            if step_count % eval_every_n == 0:
                model.eval()
                val_preds, val_targets = [], []
                
                # Initialize fresh states for validation
                val_spk1_rec = torch.zeros(1, model.fc1.out_features, device=device)
                val_mem1 = torch.zeros(1, model.fc1.out_features, device=device)
                val_mem2 = torch.zeros(1, model.fc2.out_features, device=device)
                val_mem3 = torch.zeros(1, model.fc3.out_features, device=device)
                val_states = (val_spk1_rec, val_mem1, val_mem2, val_mem3)
                
                with torch.no_grad():
                    val_step_count = 0
                    for x_val, y_val, should_reset_val in val_stream:
                        if should_reset_val or (reset_states_every_n and val_step_count % reset_states_every_n == 0):
                            val_spk1_rec = torch.zeros(1, model.fc1.out_features, device=device)
                            val_mem1 = torch.zeros(1, model.fc1.out_features, device=device)
                            val_mem2 = torch.zeros(1, model.fc2.out_features, device=device)
                            val_mem3 = torch.zeros(1, model.fc3.out_features, device=device)
                            val_states = (val_spk1_rec, val_mem1, val_mem2, val_mem3)
                            
                        x_val_seq = x_val.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, features]
                        pred_seq, val_states = model(x_val_seq, *val_states)
                        val_preds.append(pred_seq[0, 0, :].cpu())  # [2]
                        val_targets.append(y_val.cpu())
                        val_step_count += 1
                        
                        # Limit validation length to avoid overly long evaluation
                        if val_step_count >= 1000:
                            break
                
                if val_preds:
                    val_preds = torch.stack(val_preds)      # [N, 2]
                    val_targets = torch.stack(val_targets)  # [N, 2]
                    
                    # DEBUG: Check for problematic data patterns
                    pred_std_x, pred_std_y = val_preds[:, 0].std().item(), val_preds[:, 1].std().item()
                    target_std_x, target_std_y = val_targets[:, 0].std().item(), val_targets[:, 1].std().item()
                    pred_mean_x, pred_mean_y = val_preds[:, 0].mean().item(), val_preds[:, 1].mean().item()
                    target_mean_x, target_mean_y = val_targets[:, 0].mean().item(), val_targets[:, 1].mean().item()
                    
                    if pred_std_x < 1e-6 or pred_std_y < 1e-6 or target_std_x < 1e-6 or target_std_y < 1e-6:
                        print(f"  ⚠️  ZERO VARIANCE DETECTED:")
                        print(f"     Pred std: X={pred_std_x:.6f}, Y={pred_std_y:.6f}")
                        print(f"     Target std: X={target_std_x:.6f}, Y={target_std_y:.6f}")
                        print(f"     Pred mean: X={pred_mean_x:.6f}, Y={pred_mean_y:.6f}")
                        print(f"     Target mean: X={target_mean_x:.6f}, Y={target_mean_y:.6f}")
                        
                        # Auto-recovery from dead neurons
                        if pred_std_x < 1e-6 and pred_std_y < 1e-6 and step_count > 1000:
                            print(f"  🔧 ATTEMPTING DEAD NEURON RECOVERY:")
                            print(f"     Increasing learning rate by 5x and reinitializing model...")
                            
                            # Reinitialize model weights
                            for module in [model.fc1, model.fc2, model.fc3, model.fc_rec]:
                                nn.init.xavier_uniform_(module.weight)
                                if hasattr(module, 'bias') and module.bias is not None:
                                    nn.init.zeros_(module.bias)
                            
                            # Boost learning rates dramatically
                            rl_updater.fast_lr *= 5.0
                            rl_updater.base_fast_lr *= 5.0
                            print(f"     New fast LR: {rl_updater.fast_lr:.1e}")
                            
                            # Reset states to break any bad patterns
                            spk1_rec = torch.zeros(1, model.fc1.out_features, device=device)
                            mem1 = torch.zeros(1, model.fc1.out_features, device=device)
                            mem2 = torch.zeros(1, model.fc2.out_features, device=device)
                            mem3 = torch.zeros(1, model.fc3.out_features, device=device)
                            states = (spk1_rec, mem1, mem2, mem3)
                    
                    val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
                    val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
                    avg_val_corr = (val_corr_x + val_corr_y) / 2.0
                    
                    avg_train_loss = running_train_loss / eval_every_n
                    print(f"Epoch [{epoch+1}/{num_epochs}], Step {step_count}, "
                          f"Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")
                    
                    # Early stopping check
                    if avg_val_corr > best_val_metric:
                        best_val_metric = avg_val_corr
                        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                        print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1
                        if epochs_no_improve >= patience:
                            print(f"Early stopping triggered after {step_count} steps in epoch {epoch+1}.")
                            print(f"  (Patience: {patience} eval intervals = {patience * eval_every_n} steps)")
                            break
                    
                    running_train_loss = 0.0  # Reset for next interval
                    
                model.train()  # Return to training mode
                
        # Check if early stopping was triggered
        if epochs_no_improve >= patience:
            break
    
    # Load best model
    if best_model_state:
        model.load_state_dict({k: v.to(device) for k, v in best_model_state.items()})
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")
    
    return model

# %%
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from filterpy.kalman import KalmanFilter
def train_test_kalman_filter(train_rates, train_vel, test_rates, test_vel, dt=0.050):
    """
    Trains and applies a Kalman Filter for neural decoding.

    A Ridge regression decoder is fitted from neural activity to velocity on
    the training data.  Its bias-corrected output is then used as the noisy
    velocity measurement of a constant-velocity Kalman filter whose state is
    [px, py, vx, vy].  Measurement noise R and process noise Q are both
    derived from the decoder's residuals on the training data.

    Args:
        train_rates: Training neural features, 2-D (time, neurons); a torch
            tensor or numpy array.
        train_vel: Training velocities, 2-D (time, 2); a torch tensor or
            numpy array.
        test_rates: Test neural features, 2-D (time, neurons); a torch tensor
            or numpy array.
        test_vel: Unused here (kept for interface compatibility); the caller
            evaluates the predictions against it afterwards.
        dt (float): Time step of one bin in seconds, used in the
            constant-velocity transition matrix and the position process
            noise.  Defaults to 0.050 (the previously hard-coded value).

    Returns:
        predicted_velocities (np.array): KF predicted velocities for test data.
        kf (KalmanFilter): The trained Kalman Filter object.

    Raises:
        ValueError: If ``dt`` is not positive, ``train_rates`` is not 2-D, or
            the velocity targets are not 2-dimensional.
    """
    print("\n--- Setting up Kalman Filter ---")
    if dt <= 0:
        raise ValueError("dt must be a positive bin duration in seconds")

    # Ensure inputs are numpy arrays
    if torch.is_tensor(train_rates):
        train_rates = train_rates.detach().cpu().numpy()
    if torch.is_tensor(train_vel):
        train_vel = train_vel.detach().cpu().numpy()
    if torch.is_tensor(test_rates):
        test_rates = test_rates.detach().cpu().numpy()
    # test_vel is not used for KF training/prediction, only evaluation later

    if train_rates.ndim != 2:
        raise ValueError("train_rates must be a 2-D array of shape (time, neurons)")
    n_timesteps_test = test_rates.shape[0]
    n_outputs = train_vel.shape[1]  # Should be 2 for [vx, vy]

    if n_outputs != 2:
        raise ValueError("This KF implementation assumes 2D velocity (vx, vy)")

    # --- 1. Train linear decoder velocity = W * spikes + b ---
    print("Fitting linear velocity decoder (Ridge)")
    lin_decoder = Ridge(alpha=1e-3, fit_intercept=True)
    lin_decoder.fit(train_rates, train_vel)

    # Use decoder to obtain noisy velocity measurements
    vel_meas_train = lin_decoder.predict(train_rates)
    vel_meas_test = lin_decoder.predict(test_rates)

    # --- 1b. Remove systematic bias between decoder output and true velocity ---
    bias = vel_meas_train.mean(axis=0) - train_vel.mean(axis=0)
    vel_meas_train -= bias
    vel_meas_test -= bias

    # --- 2. Dimensions and Dynamics ---
    dim_x = 4           # [px, py, vx, vy]
    dim_z = 2           # measurement provides velocity only

    # Constant-velocity model.  Named `transition` rather than `F` so the
    # module-level `F` alias (torch.nn.functional) is not shadowed.
    transition = np.array([[1, 0, dt, 0],
                           [0, 1, 0, dt],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

    observation = np.array([[0, 0, 1, 0],
                            [0, 0, 0, 1]], dtype=np.float32)  # we observe velocity

    # --- 3. Measurement noise R from training residuals ---
    print("Estimating R from decoder residuals")
    residuals_meas = train_vel - vel_meas_train
    # Small diagonal jitter keeps R invertible
    R = np.cov(residuals_meas.T) + np.eye(dim_z) * 1e-6

    # --- 4. Process noise Q based on residual velocity variance ---
    q_vel = np.var(residuals_meas, axis=0).mean()
    q_pos = (dt ** 2) * q_vel
    Q = np.diag([q_pos, q_pos, q_vel, q_vel]) + 1e-6 * np.eye(dim_x)

    # --- 5. Initialize Kalman Filter ---
    print("Initializing Kalman Filter...")
    kf = KalmanFilter(dim_x=dim_x, dim_z=dim_z)
    # Initialise state vector [px, py, vx, vy].
    # We set the starting position to 0 and initialise the velocity with the
    # first sample from the training data.  This yields a 4-element state
    # vector, as required by the 4×4 transition matrix.
    initial_state = np.zeros(dim_x, dtype=np.float32)
    initial_state[2:] = train_vel[0]
    kf.x = initial_state
    kf.P = np.eye(dim_x) * 500.     # High initial uncertainty
    kf.F = transition
    kf.H = observation
    kf.R = R
    kf.Q = Q

    # --- 6. Run the filter on test data ---
    print(f"Running Kalman Filter for {n_timesteps_test} test steps...")
    predicted_velocities = np.zeros((n_timesteps_test, 2))
    start_time_kf = time.perf_counter()
    for t in range(n_timesteps_test):
        z = vel_meas_test[t]  # 1-D vector length 2
        kf.predict()
        kf.update(z)
        predicted_velocities[t] = kf.x[2:].copy()  # store vx, vy only
    end_time_kf = time.perf_counter()
    print(f"Kalman Filter processing time: {end_time_kf - start_time_kf:.2f} seconds")
    print("--- Kalman Filter Setup and Run Complete ---")
    return predicted_velocities, kf


# %%
# +++ Re-inserting LSTM Definitions +++
class LSTMRegression(nn.Module):
    """Stacked LSTM whose last-time-step output is mapped to the targets by a linear layer."""

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Dropout is only requested when there is more than one LSTM layer
        inter_layer_dropout = 0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run ``x`` of shape (batch, seq_len, input_size) through the LSTM and
        return the linear read-out of the final time step."""
        # Zero initial hidden/cell states, created directly on the input's device
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        hidden0 = torch.zeros(state_shape, device=x.device)
        cell0 = torch.zeros(state_shape, device=x.device)

        sequence_out, _ = self.lstm(x, (hidden0, cell0))
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

def train_lstm_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=5):
    """Basic training loop for the LSTM model with early stopping (handles val_loader=None).

    Args:
        model: Module to train; called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, targets)`` batches.  2-D inputs
            get a sequence dimension of length 1 inserted at position 1.
        val_loader: Optional loader used for validation; when it is
            None/empty, no validation or early stopping is performed.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device that each batch is moved to.
        num_epochs (int): Maximum number of epochs.
        patience (int): Number of consecutive epochs without validation
            improvement after which training stops.

    Returns:
        (model, train_losses, val_losses): the model (restored to the weights
        of the best validation epoch when validation was run, otherwise left
        at the last-epoch weights) and the per-epoch average losses.
    """
    print(f"Starting LSTM training for {num_epochs} epochs on {device}...")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    train_losses = []
    val_losses = []
    # Snapshot of the best weights.  Stays None until a validation improvement
    # is seen, so that without validation the last-epoch weights are returned.
    best_model_state = None
    best_epoch = 0
    has_validation = bool(val_loader) and len(val_loader.dataset) > 0

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)  # Add sequence dimension if needed
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if not has_validation:
            # If no validation, just print training loss
            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}")
            continue

        # Validation phase
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                if inputs.dim() == 2:
                    inputs = inputs.unsqueeze(1)
                outputs = model(inputs)
                running_val_loss += criterion(outputs, targets).item()
        avg_val_loss = running_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

        # Check for improvement and early stopping based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors,
            # so clone them; otherwise the "best" snapshot would silently
            # follow every later optimizer step.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            best_epoch = epoch + 1
            print(f"  New best validation loss: {best_val_loss:.6f}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    print("Finished LSTM Training.")
    # Load the best model state before returning
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state from epoch {best_epoch}.")

    return model, train_losses, val_losses
# +++ End Re-inserted LSTM Definitions +++


def create_sequences(data_x, data_y, seq_len=10):
    """Build overlapping (stride-1) input windows of length ``seq_len`` for the LSTM.

    Each window ``data_x[i:i + seq_len]`` is paired with ``data_y`` at the
    window's last index.  Returns the windows and targets as numpy arrays;
    both are empty when ``data_x`` is shorter than ``seq_len``.
    """
    window_starts = range(len(data_x) - seq_len + 1)
    windows = [data_x[start:start + seq_len] for start in window_starts]
    # Target is the last velocity in the window
    targets = [data_y[start + seq_len - 1] for start in window_starts]
    return np.array(windows), np.array(targets)

def _mean_metric(records, key='avg_corr'):
    """Return the mean of ``key`` over a list of per-session result dicts."""
    return np.mean([record[key] for record in records])


def run_timescale_ablation(num_sessions=10):
    """
    Comprehensive ablation study: Different eligibility trace timescale configurations AND fast/slow update combinations.
    Tests both eligibility trace timescales and update mechanism combinations with IDENTICAL training setup.

    For every configuration the random seeds are reset, up to ``num_sessions``
    ``.mat`` sessions are sampled from the dataset directory, an SNN is trained
    online on each session and evaluated on that session's held-out tail.

    Args:
        num_sessions (int): Number of sessions to sample per configuration.
            If the dataset directory holds fewer ``.mat`` files, all of them
            are used.

    Returns:
        dict: Maps each configuration name to a list of per-session dicts with
        keys 'session', 'config_name', 'timescale_mode', 'use_fast_update',
        'use_slow_update', 'corr_x', 'corr_y' and 'avg_corr'.  Sessions that
        fail to load or raise during training/evaluation are omitted.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE LEARNING MECHANISMS ABLATION STUDY")
    print("Testing: Eligibility Trace Timescales + Fast/Slow Update Combinations")
    print("="*80)

    results = {}

    # Test combinations of timescale modes and update configurations:
    # (config name, timescale_mode, use_fast_update, use_slow_update)
    test_configurations = [
        # # Original timescale ablation with both updates enabled (currently disabled)
        # ("FAST_TRACES_BOTH_UPDATES", 'fast_only', True, True),
        # ("SLOW_TRACES_BOTH_UPDATES", 'slow_only', True, True),
        # ("MEDIUM_TRACES_BOTH_UPDATES", 'medium_only', True, True),
        # ("DUAL_TRACES_BOTH_UPDATES", 'dual', True, True),

        # Fast update only (no slow updates) with best timescale
        ("DUAL_TRACES_FAST_ONLY", 'dual', True, False),

        # Slow update only (no fast updates) with best timescale
        ("DUAL_TRACES_SLOW_ONLY", 'dual', False, True),

        # No updates (baseline - should show no learning)
        ("DUAL_TRACES_NO_UPDATES", 'dual', False, False),
    ]

    # Human-readable description of each timescale mode (for logging only)
    trace_descriptions = {
        'fast_only': "fast traces (τ=80ms)",
        'slow_only': "slow traces (τ=800ms)",
        'medium_only': "medium traces (τ=400ms)",
        'dual': "dual timescales (τ_fast=80ms, τ_slow=800ms)",
    }

    basepath = os.path.expanduser('~/scratch/zenodo_dataset')
    # basepath = 'zenodo_dataset'
    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # sample below picks the same sessions on every machine and run.
    all_mat_files = sorted(f for f in os.listdir(basepath) if f.endswith('.mat'))
    # random.sample raises if asked for more items than exist, so cap the count
    n_selected = min(num_sessions, len(all_mat_files))
    if n_selected < num_sessions:
        print(f"WARNING: only {len(all_mat_files)} .mat files found in {basepath}; "
              f"using {n_selected} instead of {num_sessions} sessions.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for mode_name, timescale_mode, use_fast_update, use_slow_update in test_configurations:
        print(f"\n{'='*50}")
        print(f"Testing {mode_name}")

        # Describe the configuration
        trace_desc = trace_descriptions.get(timescale_mode, f"unknown mode ({timescale_mode})")

        update_desc = []
        if use_fast_update:
            update_desc.append("fast updates")
        if use_slow_update:
            update_desc.append("slow updates")
        if not update_desc:
            update_desc = ["no updates (baseline)"]

        print(f"  Traces: {trace_desc}")
        print(f"  Updates: {' + '.join(update_desc)}")
        print(f"{'='*50}")

        # Set seeds for reproducibility (reset per configuration so that every
        # configuration sees the same sessions)
        np.random.seed(42)
        random.seed(42)
        torch.manual_seed(42)

        selected_files = random.sample(all_mat_files, n_selected)

        session_results = []

        for session_idx, filename in enumerate(selected_files):
            print(f"\n--- Session {session_idx + 1} ({filename}) with {mode_name} ---")

            # Best-effort per session: a failing session is reported and
            # skipped so that the remaining sessions still run.
            try:
                data = load_data(
                    mat_file_path=os.path.join(basepath, filename),
                    bin_width_s=0.050,
                    stride_s=0.050,
                    spike_processing='count',
                    test_split_ratio=0,
                    verbose=False
                )

                if 'error' in data:
                    print(f"  Skipping {filename} due to load error: {data['error']}")
                    continue

                X_full = data.get('X_processed_full').numpy()
                y_full = data.get('y_processed_full').numpy()

                # Use at most the first 96 feature columns
                if X_full.shape[1] > 96:
                    X_full = X_full[:, :96]

                input_size = X_full.shape[1]

                # IDENTICAL preprocessing: chronological 70% train / 15% gap / rest test.
                # The middle 15% segment is not used by this function; the SNN
                # validation data is carved out of the training segment below.
                n_total = X_full.shape[0]
                n_train = int(n_total * 0.70)
                n_val = int(n_total * 0.15)

                X_train_norm = X_full[:n_train]
                y_train_raw = y_full[:n_train]
                X_test_norm = X_full[n_train + n_val:]
                y_test_raw = y_full[n_train + n_val:]

                # IDENTICAL normalization (scaler fitted on training velocities only)
                velocity_scaler = StandardScaler().fit(y_train_raw)
                y_train_norm = velocity_scaler.transform(y_train_raw)
                y_test_norm = velocity_scaler.transform(y_test_raw)

                # IDENTICAL model setup
                snn_model = SNNRegression(input_size=input_size, hidden_size=256, output_size=2).to(device)

                # IDENTICAL LUT
                vals = torch.linspace(0.25, 8.0, 256, device=device)
                inv_sqrt_LUT = 1.0 / torch.sqrt(vals)

                # IDENTICAL training setup: last 15% of the training segment is SNN validation
                snn_val_split_ratio = 0.15
                n_snn_train_full = X_train_norm.shape[0]
                n_snn_val = int(n_snn_train_full * snn_val_split_ratio)
                n_snn_train_split = n_snn_train_full - n_snn_val

                X_train_snn_final = X_train_norm[:n_snn_train_split]
                y_train_snn_norm_final = y_train_norm[:n_snn_train_split]
                X_val_snn = X_train_norm[n_snn_train_split:]
                y_val_snn_norm = y_train_norm[n_snn_train_split:]

                # KEY DIFFERENCES: timescale_mode, use_fast_update, use_slow_update parameters
                # Adaptive learning rate based on dataset characteristics, clamped to [1e-4, 1e-2]
                spike_activity = np.mean(X_train_snn_final)
                velocity_variance = np.var(y_train_snn_norm_final)
                adaptive_lr = min(max(3e-3 * (1 + spike_activity * 0.1) * (1 + velocity_variance * 0.5), 1e-4), 1e-2)

                snn_updater = TwoScaleMetaRLWeightUpdaterFull(
                    snn_model,
                    base_fast_lr=adaptive_lr,
                    base_slow_lr=1e-6,
                    window_size=50,
                    meta_lr=0.01,
                    online_mode=True,                   # CHANGED: online mode for all
                    use_hebbian=True,                   # IDENTICAL: use Hebbian learning for all
                    use_meta_learning=True,             # IDENTICAL: use meta-learning for all
                    timescale_mode=timescale_mode,      # ABLATION PARAMETER
                    use_fast_update=use_fast_update,    # NEW ABLATION PARAMETER
                    use_slow_update=use_slow_update,    # NEW ABLATION PARAMETER
                    inv_sqrt_LUT=inv_sqrt_LUT
                )

                # IDENTICAL data loading (for evaluation only)
                seq_len = 10
                batch_size = 32
                overlap_stride = max(1, seq_len // 2)

                # Online data streams
                train_stream = OnlineDataStream(X_train_snn_final, y_train_snn_norm_final)
                val_stream = OnlineDataStream(X_val_snn, y_val_snn_norm)

                # IDENTICAL training - switched to online mode
                trained_snn_model = train_snn_online(
                    snn_model, snn_updater, train_stream, val_stream,
                    num_epochs=15, patience=20, eval_every_n=500
                )

                # IDENTICAL evaluation
                test_loader = CausalBatcher(
                    X_test_norm, y_test_norm,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=False, stride=overlap_stride)

                trained_snn_model.eval()
                snn_preds_norm_list, snn_targets_norm_list = [], []

                with torch.no_grad():
                    for input_seq, target_seq in test_loader:
                        input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                        # Fresh zero states for every test batch
                        batch_size_test = input_seq.size(0)
                        test_spk1_rec = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem1 = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem2 = torch.zeros(batch_size_test, trained_snn_model.fc2.out_features, device=device)
                        test_mem3 = torch.zeros(batch_size_test, trained_snn_model.fc3.out_features, device=device)

                        pred_seq, _ = trained_snn_model(input_seq, test_spk1_rec, test_mem1, test_mem2, test_mem3)
                        # Only the last time step of each sequence is scored
                        snn_preds_norm_list.append(pred_seq[:, -1, :].cpu())
                        snn_targets_norm_list.append(target_seq[:, -1, :].cpu())

                snn_preds_normalized = torch.cat(snn_preds_norm_list).numpy()
                snn_targets_normalized = torch.cat(snn_targets_norm_list).numpy()

                # IDENTICAL evaluation
                snn_corr_x = compute_correlation(snn_targets_normalized[:, 0], snn_preds_normalized[:, 0])
                snn_corr_y = compute_correlation(snn_targets_normalized[:, 1], snn_preds_normalized[:, 1])

                session_results.append({
                    'session': filename,
                    'config_name': mode_name,
                    'timescale_mode': timescale_mode,
                    'use_fast_update': use_fast_update,
                    'use_slow_update': use_slow_update,
                    'corr_x': snn_corr_x,
                    'corr_y': snn_corr_y,
                    'avg_corr': (snn_corr_x + snn_corr_y) / 2
                })

                print(f"  {mode_name} Results: X={snn_corr_x:.4f}, Y={snn_corr_y:.4f}, Avg={((snn_corr_x + snn_corr_y)/2):.4f}")

            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"  ERROR in {mode_name} run: {e}")

        results[mode_name] = session_results

    # Compare results
    print("\n" + "="*80)
    print("COMPREHENSIVE LEARNING MECHANISMS ABLATION RESULTS SUMMARY")
    print("="*80)

    for mode, mode_results in results.items():
        if mode_results:
            avg_corr_x = _mean_metric(mode_results, 'corr_x')
            avg_corr_y = _mean_metric(mode_results, 'corr_y')
            avg_total = _mean_metric(mode_results, 'avg_corr')
            print(f"{mode:25s}: X={avg_corr_x:.4f}, Y={avg_corr_y:.4f}, Avg={avg_total:.4f}")

    # Detailed analysis
    print("\n" + "="*60)
    print("ANALYSIS")
    print("="*60)

    # 1. Trace timescale analysis (with both updates enabled)
    print("\n1. ELIGIBILITY TRACE TIMESCALE ANALYSIS:")
    trace_modes = ['FAST_TRACES_BOTH_UPDATES', 'SLOW_TRACES_BOTH_UPDATES',
                   'MEDIUM_TRACES_BOTH_UPDATES', 'DUAL_TRACES_BOTH_UPDATES']

    trace_results = {mode: results.get(mode, []) for mode in trace_modes}
    # Only configurations that actually produced results can be "best"
    evaluated_trace_modes = [mode for mode in trace_modes if trace_results[mode]]
    best_trace_mode = (
        max(evaluated_trace_modes, key=lambda m: _mean_metric(trace_results[m]))
        if evaluated_trace_modes else None
    )

    for mode in evaluated_trace_modes:
        avg = _mean_metric(trace_results[mode])
        marker = " ← BEST" if mode == best_trace_mode else ""
        print(f"   {mode}: {avg:.4f}{marker}")

    # 2. Update mechanism analysis (with dual traces)
    print("\n2. UPDATE MECHANISM ANALYSIS:")
    update_modes = ['DUAL_TRACES_BOTH_UPDATES', 'DUAL_TRACES_FAST_ONLY',
                    'DUAL_TRACES_SLOW_ONLY', 'DUAL_TRACES_NO_UPDATES']

    update_results = {mode: results.get(mode, []) for mode in update_modes}

    for mode in update_modes:
        if update_results[mode]:
            avg = _mean_metric(update_results[mode])
            print(f"   {mode}: {avg:.4f}")

    # 3. Key comparisons (only possible when all three configurations have results)
    print("\n3. KEY COMPARISONS:")

    if (update_results['DUAL_TRACES_BOTH_UPDATES'] and
        update_results['DUAL_TRACES_FAST_ONLY'] and
        update_results['DUAL_TRACES_SLOW_ONLY']):

        both_avg = _mean_metric(update_results['DUAL_TRACES_BOTH_UPDATES'])
        fast_only_avg = _mean_metric(update_results['DUAL_TRACES_FAST_ONLY'])
        slow_only_avg = _mean_metric(update_results['DUAL_TRACES_SLOW_ONLY'])

        print(f"   Both updates vs Fast-only: {both_avg - fast_only_avg:+.4f}")
        print(f"   Both updates vs Slow-only: {both_avg - slow_only_avg:+.4f}")
        print(f"   Fast-only vs Slow-only: {fast_only_avg - slow_only_avg:+.4f}")

        # Determine if fast or slow updates are more important
        if fast_only_avg > slow_only_avg:
            print("   → Fast updates appear more critical than slow updates")
        else:
            print("   → Slow updates appear more critical than fast updates")

        # Check if both together provide benefit.  The signed format avoids
        # printing "+-0.0123" when combining the updates actually hurts.
        better_single = max(fast_only_avg, slow_only_avg)
        improvement = both_avg - better_single
        if improvement > 0.01:
            print(f"   → Combining both updates provides significant benefit: {improvement:+.4f}")
        elif improvement > 0.005:
            print(f"   → Combining both updates provides modest benefit: {improvement:+.4f}")
        else:
            print(f"   → Combining both updates provides minimal benefit: {improvement:+.4f}")

    # 4. Baseline check
    if update_results['DUAL_TRACES_NO_UPDATES']:
        no_update_avg = _mean_metric(update_results['DUAL_TRACES_NO_UPDATES'])
        print(f"\n4. BASELINE CHECK:")
        print(f"   No updates (baseline): {no_update_avg:.4f}")
        print("   → This should show minimal learning, confirming update mechanisms are necessary")

    return results

# Clean main execution
if __name__ == "__main__":
    # Announce what the study covers, then run it over ten sessions.
    print("\n".join([
        "Running Comprehensive Learning Mechanisms Ablation Study...",
        "This study tests:",
        "1. Different eligibility trace timescale configurations",
        "2. Fast vs slow update mechanism combinations",
        "3. The necessity of each update component",
    ]))
    ablation_results = run_timescale_ablation(num_sessions=10)
