# %%
import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
import warnings
import torch
import snntorch as snn
from snntorch import surrogate
import numpy as np
import torch
import h5py
from sklearn.preprocessing import StandardScaler
import warnings
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import json
import time
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from filterpy.kalman import KalmanFilter
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from scipy.signal import butter, sosfiltfilt
from sklearn.linear_model import Ridge
import math
from scipy.signal import correlate, correlation_lags
import torch.nn.functional as F
from nlb_tools.nwb_interface import NWBDataset

# Load the NWB recording(s) under the given directory that match the "*train"
# pattern. NOTE(review): path and pattern semantics are defined by
# nlb_tools.NWBDataset -- confirm against its documentation.
dataset = NWBDataset("~/scratch/000128/sub-Jenkins/", "*train", split_heldout=False)
# Resample the dataset; presumably the argument is the target bin width in
# milliseconds (load_data below reads dataset.bin_width) -- TODO confirm.
dataset.resample(10)


# Add a smoothed copy of the spikes under the name 'smth_50'; presumably a
# 50 ms smoothing kernel -- TODO confirm. The smoothed signal is not read by
# the code visible in this file section.
dataset.smooth_spk(50, name='smth_50')


# %%
# Module-level handle that load_data() reads instead of re-loading from disk.
DATASET = dataset

# %%
from nlb_tools.nwb_interface import NWBDataset
import pandas as pd

from torch.utils.data import Dataset
import numpy as np
import os


# --- MODIFIED Signature: Add window/stride, remove resample_ms ---
def load_data(dataset_name="mc_maze", use_cached=True,
              bin_width_ms=100, stride_ms=10,
              spike_processing='binary'):
    """Window the module-level ``DATASET`` into per-bin spike and velocity arrays.

    Each trial is aligned to ``move_onset_time`` over the range (-130, 370),
    then cut into sliding windows of ``bin_width_ms`` advanced by
    ``stride_ms``. Spikes are summed and hand velocity is averaged inside each
    window. Windows never span two trials.

    Args:
        dataset_name: Only ``"mc_maze"`` is implemented.
        use_cached: Accepted for interface compatibility; currently unused
            (no cache is read or written).
        bin_width_ms: Window length in milliseconds.
        stride_ms: Step between consecutive window starts in milliseconds.
        spike_processing: ``'binary'`` (window has any spike -> 1.0),
            ``'count'`` (summed counts) or ``'rate'`` (counts divided by the
            window length in seconds). Any other value prints a warning and
            leaves the summed counts untouched.

    Returns:
        Tuple ``(spike_rates, hand_vel, windowed_trial_ids, hand_vel_mean,
        hand_vel_std, trial_data)``. The first two are float32 arrays with one
        row per window, ``windowed_trial_ids`` is an int64 array giving the
        trial of each row, ``hand_vel_mean`` / ``hand_vel_std`` are always
        ``None`` (no normalisation is done here) and ``trial_data`` is the
        aligned frame returned by ``dataset.make_trial_data``.

    Raises:
        ValueError: If ``dataset_name`` is not supported, if the window or
            stride is shorter than one dataset bin, or if no trial is long
            enough to yield a window.
    """
    # Fail early and explicitly: the previous fall-through crashed later with
    # an AttributeError/NameError on the never-assigned outputs.
    if dataset_name != "mc_maze":
        raise ValueError(
            f"Unsupported dataset_name '{dataset_name}': only 'mc_maze' is implemented."
        )

    hand_vel_mean = None
    hand_vel_std = None

    dataset = DATASET
    print("Processing MC_MAZE data from scratch...")

    # Window geometry in dataset bins; identical for every trial, so compute once.
    bin_width_steps = int(round(bin_width_ms / dataset.bin_width))
    stride_steps = int(round(stride_ms / dataset.bin_width))
    if bin_width_steps < 1 or stride_steps < 1:
        raise ValueError(
            f"bin_width_ms={bin_width_ms} and stride_ms={stride_ms} must each cover at "
            f"least one dataset bin of {dataset.bin_width} ms."
        )

    # Spikes and velocity are read from the same aligned frame: the former
    # "lagged" frame was built with byte-identical arguments (no lag applied).
    trial_data = dataset.make_trial_data(
        align_field="move_onset_time",
        align_range=(-130, 370),
        allow_nans=False
    )
    trial_ids = sorted(set(trial_data.trial_id.unique()))

    spikes_list = []
    vel_list = []
    windowed_trial_id_list = []

    for tid in trial_ids:
        td = trial_data[trial_data.trial_id == tid]
        spikes = td.spikes.to_numpy()
        vel = td.hand_vel.to_numpy()

        n_steps = min(spikes.shape[0], vel.shape[0])
        # Trial shorter than one window in wall-clock time: nothing to emit.
        if n_steps * dataset.bin_width < bin_width_ms:
            continue

        num_bins_trial = (n_steps - bin_width_steps) // stride_steps + 1
        if num_bins_trial <= 0:
            continue

        window_starts = [i * stride_steps for i in range(num_bins_trial)]
        spike_windows = np.stack([spikes[s:s + bin_width_steps] for s in window_starts])
        vel_windows = np.stack([vel[s:s + bin_width_steps] for s in window_starts])

        spikes_list.append(spike_windows.sum(axis=1))   # summed spikes per window
        vel_list.append(vel_windows.mean(axis=1))       # mean velocity per window
        windowed_trial_id_list.append(np.full(num_bins_trial, tid, dtype=np.int64))

    if not spikes_list:
        raise ValueError(
            "No trial is long enough to produce a window with "
            f"bin_width_ms={bin_width_ms} and stride_ms={stride_ms}."
        )

    spike_rates = np.vstack(spikes_list).astype(np.float32)
    hand_vel = np.vstack(vel_list).astype(np.float32)
    windowed_trial_ids = np.concatenate(windowed_trial_id_list)

    print(f"Applying spike processing: '{spike_processing}'")
    if spike_processing == 'binary':
        spike_rates = (spike_rates > 0).astype(np.float32)
    elif spike_processing == 'count':
        pass  # spike_rates already holds the summed counts
    elif spike_processing == 'rate':
        bin_width_s = bin_width_ms / 1000.0
        spike_rates = (spike_rates / bin_width_s).astype(np.float32)
    else:
        print(f"Warning: Unknown spike_processing '{spike_processing}'. Using raw counts.")

    print(f"Spike rates shape: {spike_rates.shape}")
    print(f"Hand velocity shape: {hand_vel.shape}")

    return spike_rates, hand_vel, windowed_trial_ids, hand_vel_mean, hand_vel_std, trial_data




# %%
import random
import time
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from scipy.interpolate import interp1d
import numpy as np
import pandas as pd

class CausalBatcher:
    """
    Iterator over fixed-size batches of contiguous sequences cut from a
    time-series. Every sequence lies inside a single trial; sequences left
    over after filling whole batches are not yielded.
    """
    def __init__(self, spike_data, targets, trial_ids, batch_size=64, sequence_length=20, shuffle=False, stride=None):
        self.spike_data = torch.tensor(spike_data, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.trial_ids = trial_ids
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.shuffle = shuffle
        # With no explicit stride, consecutive sequences do not overlap.
        self.stride = sequence_length if stride is None else stride

        # Collect the first timestep of every sequence that fits in one trial.
        distinct_trials = np.unique(trial_ids)
        self.valid_starts = []
        for tid in distinct_trials:
            self.valid_starts.extend(self._starts_within_trial(tid))

        self.num_sequences = len(self.valid_starts)
        self.num_batches = self.num_sequences // self.batch_size

        if self.num_batches == 0:
            raise ValueError(
                f"Not enough valid trial-aware sequences to create a single batch. "
                f"Have {self.num_sequences} sequences for seq_len={sequence_length}, but batch_size is {self.batch_size}."
            )

        print(f"Created {self.num_sequences} trial-aware sequences from {len(distinct_trials)} trials")

        self.indices = np.arange(self.num_sequences)
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _starts_within_trial(self, trial_id):
        """Return the start timesteps of all sequences fully inside `trial_id`."""
        rows = np.where(self.trial_ids == trial_id)[0]
        if len(rows) < self.sequence_length:
            return []  # trial cannot hold even one sequence

        first_row, last_row = rows[0], rows[-1]
        span = last_row - first_row + 1

        starts = []
        for offset in range(0, span - self.sequence_length + 1, self.stride):
            begin = first_row + offset
            stop = begin + self.sequence_length  # exclusive end
            if stop - 1 > last_row:
                continue
            # Reject sequences that would include rows of another trial.
            if np.all(self.trial_ids[begin:stop] == trial_id):
                starts.append(begin)
        return starts

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        self.current_batch = 0
        # A fresh permutation is drawn at the start of every pass.
        if self.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        if self.current_batch >= self.num_batches:
            raise StopIteration

        lo = self.current_batch * self.batch_size
        chosen = self.indices[lo:lo + self.batch_size]

        # Translate sequence numbers into [begin, end) timestep ranges.
        spans = []
        for seq_no in chosen:
            begin = self.valid_starts[seq_no]
            spans.append((begin, begin + self.sequence_length))

        batch_spikes = torch.stack([self.spike_data[a:b] for a, b in spans])
        batch_targets = torch.stack([self.targets[a:b] for a, b in spans])

        self.current_batch += 1
        return batch_spikes, batch_targets

class OnlineDataStream:
    """
    Yields one (input, target, should_reset) triple per timestep for online
    learning. `should_reset` is True on the first sample of a new trial and,
    optionally, once `reset_every_n` steps have passed since the last reset.
    """
    def __init__(self, spike_data, targets, trial_ids, reset_every_n=None):
        self.spike_data = torch.tensor(spike_data, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.trial_ids = trial_ids
        self.reset_every_n = reset_every_n  # optional periodic reset interval
        self._rewind()

    def _rewind(self):
        """Put the cursor and reset bookkeeping back to the start of the stream."""
        self.current_idx = 0
        self.steps_since_reset = 0
        self.prev_trial_id = None

    def __len__(self):
        return self.spike_data.shape[0]

    def __iter__(self):
        self._rewind()
        return self

    def __next__(self):
        idx = self.current_idx
        if idx >= len(self.spike_data):
            raise StopIteration

        x_t = self.spike_data[idx]
        y_t = self.targets[idx]
        trial_now = self.trial_ids[idx]
        trial_before = self.prev_trial_id

        should_reset = False
        if trial_before is not None and trial_now != trial_before:
            # Trial boundary: always reset and restart the periodic counter.
            should_reset = True
            self.steps_since_reset = 0
            # Only log early boundaries and those landing on a multiple of 500.
            if idx < 10 or idx % 500 == 0:
                print(f"Trial change detected: {trial_before} -> {trial_now}, resetting states")
        elif self.reset_every_n and self.steps_since_reset >= self.reset_every_n:
            # Periodic reset inside a trial.
            should_reset = True
            self.steps_since_reset = 0

        self.prev_trial_id = trial_now
        self.current_idx = idx + 1
        self.steps_since_reset += 1

        return x_t, y_t, should_reset

class SNNRegression(nn.Module):
    """Three-layer leaky-integrate-and-fire network with a recurrent first layer.

    The first hidden layer receives the input plus its own previous spikes via
    `fc_rec`. The output layer is built with reset_mechanism="none" and its
    membrane potential is what `forward` returns as the prediction.
    """

    def __init__(self, input_size, hidden_size=128, output_size=2):
        super().__init__()
        spike_grad = surrogate.fast_sigmoid()
        half_hidden = hidden_size // 2

        # Layer 1: feedforward + recurrent drive into the first LIF population.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)

        # Layer 2: second hidden LIF population at half the width.
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)

        # Layer 3: readout neurons whose membrane is never reset.
        self.fc3 = nn.Linear(half_hidden, output_size)
        self.lif3 = snn.Leaky(
            beta=0.5,
            spike_grad=spike_grad,
            init_hidden=False,
            threshold=1.0,
            reset_mechanism="none"
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise Linear weights; biases are -0.1 for fc1/fc2, else 0."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        if module is self.fc1 or module is self.fc2:
            nn.init.constant_(module.bias, -0.1)
        else:
            nn.init.zeros_(module.bias)

    def forward(self, x, spk1_rec, mem1, mem2, mem3, need_traces: bool = False):
        """
        Run the network over a sequence.

        Args
        ----
        x           : [batch, T, features]
        spk1_rec    : previous layer-1 spikes fed back through fc_rec
        mem1..mem3  : membrane potentials of the three LIF layers
        need_traces : if True, additionally return the detached per-step
                      tensors (inputs, previous layer-1 spikes, layer-1 and
                      layer-2 spikes, and the updated membranes) stacked
                      along dim 1.

        Returns
        -------
        (out_seq, final_states) or (out_seq, final_states, traces), where
        out_seq stacks the output-layer membrane over time along dim 1 and
        final_states is (spk1_rec, mem1, mem2, mem3) after the last step.
        """
        _, num_steps, _ = x.shape
        readout = []

        trace_keys = ("pre_ff", "pre_rec", "spk1", "spk2", "mem1", "mem2", "mem3")
        recorded = {key: [] for key in trace_keys} if need_traces else None

        for step in range(num_steps):
            frame = x[:, step, :]

            # Cut the graph at every step: no gradient flows through carried state.
            mem1, mem2, spk1_rec, mem3 = (
                state.detach() for state in (mem1, mem2, spk1_rec, mem3)
            )
            prev_spk1 = spk1_rec

            # Layer 1 is driven by the current input and last step's spikes.
            drive1 = self.fc1(frame) + self.fc_rec(prev_spk1)
            spk1, mem1 = self.lif1(drive1, mem1)
            spk1_rec = spk1

            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            # Output spikes are discarded; the membrane is the prediction.
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

            if need_traces:
                step_values = (frame, prev_spk1, spk1, spk2, mem1, mem2, mem3)
                for key, value in zip(trace_keys, step_values):
                    recorded[key].append(value.detach())

        out_seq = torch.stack(readout, dim=1)
        final_states = (spk1_rec, mem1, mem2, mem3)

        if not need_traces:
            return out_seq, final_states

        traces = tuple(torch.stack(recorded[key], dim=1) for key in trace_keys)
        return out_seq, final_states, traces


def compute_correlation(pred, target):
    """
    Compute the Pearson correlation between predicted and target values.

    Accepts torch tensors, numpy arrays or plain sequences; inputs of any
    shape are flattened before comparison.

    Returns
    -------
    float
        The correlation coefficient, or 0.0 when it is undefined: fewer than
        two samples, (near-)zero variance in either input, or inputs of
        different sizes (the last case also emits a RuntimeWarning).
    """
    if torch.is_tensor(pred):
        pred = pred.detach().cpu().numpy()
    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()

    # asarray lets plain lists/tuples through, which .flatten() alone rejected.
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()

    if pred.size != target.size:
        # Keep the historical best-effort 0.0, but make the mismatch visible.
        warnings.warn(
            f"compute_correlation got mismatched sizes {pred.size} and {target.size}; returning 0.0.",
            RuntimeWarning,
        )
        return 0.0

    # Correlation needs at least two samples; avoids NaN plus numpy warnings.
    if pred.size < 2:
        return 0.0

    if np.std(pred) < 1e-10 or np.std(target) < 1e-10:
        return 0.0

    return float(np.corrcoef(pred, target)[0, 1])

def compute_fvaf(y_true, y_pred):
    """
    Fraction of Variance Accounted For: 1 - SS_res / SS_tot (the R^2 score).

    Torch tensors are converted to numpy first. Inputs with more than one
    dimension are flattened after emitting a RuntimeWarning. When the target
    has (near-)zero variance the result is 1.0 if the residual is also
    (near-)zero and 0.0 otherwise.
    """
    y_true, y_pred = (
        arr.detach().cpu().numpy() if torch.is_tensor(arr) else arr
        for arr in (y_true, y_pred)
    )

    needs_flatten = y_true.ndim > 1 or y_pred.ndim > 1
    if needs_flatten:
        warnings.warn(f"FVAF expects 1D arrays, but got shapes {y_true.shape} and {y_pred.shape}. Flattening.", RuntimeWarning)
        y_true, y_pred = y_true.flatten(), y_pred.flatten()

    residual = y_true - y_pred
    centered = y_true - np.mean(y_true)
    ss_res = np.sum(residual**2)
    ss_tot = np.sum(centered**2)

    # Constant target: the ratio is undefined, so fall back to exact-match logic.
    if ss_tot < 1e-10:
        if ss_res < 1e-10:
            return 1.0
        return 0.0

    return 1 - (ss_res / ss_tot)

def butter_lowpass(data, fs=1/0.05, cutoff=8, order=4):
    """Low-pass `data` along axis 0 with a Butterworth filter.

    The filter is designed in second-order sections with `butter` and applied
    with `sosfiltfilt`. `fs` and `cutoff` must be in the same frequency units.
    """
    sections = butter(N=order, Wn=cutoff, btype='low', output='sos', fs=fs)
    filtered = sosfiltfilt(sections, data, axis=0)
    return filtered

# %%
##########################################
# TWO-TIMESCALE META RL UPDATER WITH META-LEARNING
##########################################
class TwoScaleMetaRLWeightUpdaterFull:
    """
    This updater adapts all layers using configurable timescales and incorporates meta-learning.
    It uses eligibility traces to integrate credit assignment over time, 
    providing memory of recent activity without acausally looking into the future. 
    The learning rate is modulated by a biologically-inspired reward signal based on performance.
    
    This approach integrates Hebbian principles with the needs of the BCI task.
    """
    def __init__(self, model, base_fast_lr=1e-3, base_slow_lr=1e-2, window_size=180, 
                 meta_lr=1e-3, tau_e_fast=0.12, tau_e_slow=0.7, online_mode=False,
                 use_hebbian=True,  # ABLATION: Enable/disable Hebbian learning
                 use_meta_learning=True,  # ABLATION: Enable/disable meta-learning
                 rms_mode='with_rms',  # NEW ABLATION: 'with_rms', 'without_rms', 'partial_rms'
                 **kwargs):
        self.model = model
        self.device = next(model.parameters()).device
        self.base_fast_lr = base_fast_lr
        self.base_slow_lr = base_slow_lr
        self.meta_lr = meta_lr
        self.grad_scale = 0
        self.online_mode = online_mode
        self.use_hebbian = use_hebbian  # FIX: Use the parameter, not hardcoded False!
        self.use_meta_learning = use_meta_learning  # Meta-learning ablation parameter
        self.rms_mode = rms_mode  # NEW: RMS normalization ablation parameter
        
        # Print ablation status for clarity
        learning_type = "Hebbian (error × d_lif × activity)" if self.use_hebbian else "Delta Rule (error × activity)"
        meta_status = "WITH meta-learning" if self.use_meta_learning else "WITHOUT meta-learning (fixed LR)"
        rms_status = f"RMS mode: {rms_mode.upper()}"
        print(f"Initializing updater with {learning_type} {meta_status} {rms_status}")

        # RMS accumulators for error and spike normalization (Tweaks A & B)
        # Use per-unit, integer-based EMAs for hardware friendliness
        # Initialize based on RMS mode
        if self.rms_mode in ['with_rms', 'partial_rms']:
            self.err2_sq_ema = torch.ones(model.fc2.out_features, dtype=torch.int32, device=self.device)
            self.err1_sq_ema = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
            self.spk1_rms    = torch.ones(model.fc1.out_features, dtype=torch.int32, device=self.device)
        else:  # without_rms mode
            self.err2_sq_ema = None
            self.err1_sq_ema = None
            self.spk1_rms = None
        
        # Meta-learning parameters - only used if use_meta_learning=True
        self.meta_params = {
            'plasticity': 1.0, 
            'sensitivity': 1.0
        } 
        
        # Set learning rates based on meta-learning mode
        if self.use_meta_learning:
            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']
        else:
            # Fixed learning rates when meta-learning is disabled
            self.fast_lr = self.base_fast_lr
            self.slow_lr = self.base_slow_lr

        # --- Hardware-friendly LUTs ---
        # Conditional initialization based on RMS mode
        if self.rms_mode in ['with_rms', 'partial_rms']:
            # Passed in from the main training script
            self.inv_sqrt_LUT = kwargs.get('inv_sqrt_LUT', None)
            if self.inv_sqrt_LUT is None:
                raise ValueError("inv_sqrt_LUT must be provided to the updater when using RMS normalization.")
            self.out_rms = torch.ones(2, dtype=torch.int32, device=self.device)  # vx, vy
        else:  # without_rms mode
            self.inv_sqrt_LUT = None
            self.out_rms = None
            
        # LUT for error-bucket based reward scaling (always used)
        self.reward_LUT = torch.tensor(
       [15,14,13,12,11,10,9,8,8,8,7,7,6,5,4,3], device=self.device)
        self.fc3_row_cap_q12 = torch.full((2, 1),
                                     int(6.0*4096),
                                     dtype=torch.int32,
                                     device=self.device)

        # --- Standard Dual-Timescale Eligibility Trace Parameters ---
        # Use standard dual timescale system (not being ablated in this study)
        if online_mode:
            self.decay_fast = math.exp(-0.064 / (tau_e_fast * 0.5))  # Faster decay
            self.decay_slow = math.exp(-0.064 / (tau_e_slow * 0.8))  # Slightly faster decay
            self.trace_mix_a = 0.8  # More weight on fast traces for responsiveness
        else:
            self.decay_fast = math.exp(-0.064 / tau_e_fast)
            self.decay_slow = math.exp(-0.064 / tau_e_slow)
            self.trace_mix_a = 0.5  # Fixed mixing parameter for fast and slow traces

        # Initialize eligibility traces (dual-timescale)
        self.e_fast_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_fast_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_fast_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_fast_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        
        self.e_slow_fc1  = torch.zeros_like(model.fc1.weight, device=self.device)
        self.e_slow_fc2  = torch.zeros_like(model.fc2.weight, device=self.device)
        self.e_slow_fc3  = torch.zeros_like(model.fc3.weight, device=self.device)
        self.e_slow_rec  = torch.zeros_like(model.fc_rec.weight, device=self.device)
        # --- End Eligibility Trace ---

        self.window_size = window_size
        self.step_counter = 0
        self.cumulative_error = 0
        self.prev_cumulative_error = None
        
        self.max_loss_value = 5.0
        
        # Initialize gradient averages with zeros matching model layer shapes and device
        self.grad_fc1_avg = torch.zeros_like(self.model.fc1.weight.data, device=self.device)
        self.grad_fc2_avg = torch.zeros_like(self.model.fc2.weight.data, device=self.device)
        self.grad_fc3_avg = torch.zeros_like(self.model.fc3.weight.data, device=self.device)
        self.grad_fc_rec_avg = torch.zeros_like(self.model.fc_rec.weight.data, device=self.device) # RESTORED
        
        # Initialize momentum for gradient accumulation
        self.momentum = 0.9

    def reset_weights_with_xavier(self):
        """
        Reset weights using Xavier/Glorot initialization for better convergence
        after catastrophic errors.
        """
        with torch.no_grad():
            for module in [self.model.fc1, self.model.fc2, self.model.fc3, self.model.fc_rec]: # RESTORED self.model.fc_rec
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        print("Weights reset using Xavier/Glorot initialization for better convergence")

    def warm_start_rms_emas(self, warmup_data_x, warmup_data_y, num_warmup_steps=100):
        """
        Warm-start the RMS EMAs with some initial data to avoid cold-start issues in online mode.
        """
        if not self.online_mode:
            return
            
        print(f"Warm-starting RMS EMAs with {num_warmup_steps} steps...")
        with torch.no_grad():
            for i in range(min(num_warmup_steps, len(warmup_data_x))):
                x_t = warmup_data_x[i:i+1].to(self.device)  # [1, features]
                y_t = warmup_data_y[i:i+1].to(self.device)  # [1, 2]
                
                # Quick forward pass to get some initial error statistics
                x_t_seq = x_t.unsqueeze(1)  # [1, 1, features]
                y_t_seq = y_t.unsqueeze(1)  # [1, 1, 2]
                
                # Initialize temporary states
                batch_size = 1
                spk1_rec = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem1 = torch.zeros(batch_size, self.model.fc1.out_features, device=self.device)
                mem2 = torch.zeros(batch_size, self.model.fc2.out_features, device=self.device)
                mem3 = torch.zeros(batch_size, self.model.fc3.out_features, device=self.device)
                
                pred_seq, _, _ = self.model(x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True)
                output_error = y_t_seq - pred_seq
                
                # Update output RMS - need to sum over batch AND sequence dims to get [2] shape
                k_out = 2  # Faster adaptation for warmup
                sq_out = ((output_error**2) * 4096).sum(dim=(0,1)).to(torch.int32)  # Sum over batch and seq dims
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                
        print("RMS EMA warm-start complete.")

    def fast_update_single_timestep(self, x_t, y_t, states):
        """
        Online single-timestep update optimized for real-time learning.
        Maintains persistent hidden states and applies scaled learning rates.
        """
        if not self.online_mode:
            raise ValueError("Single timestep update only available in online mode")
            
        # Reshape for model forward pass: [batch=1, seq=1, features]
        x_t_seq = x_t.unsqueeze(0).unsqueeze(0)  # [1, 1, features]
        y_t_seq = y_t.unsqueeze(0).unsqueeze(0)  # [1, 1, 2]
        
        device = self.device
        spk1_rec, mem1, mem2, mem3 = states

        # Single forward pass to get prediction and traces
        pred_seq, final_states, traces = self.model(
            x_t_seq, spk1_rec, mem1, mem2, mem3, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss for single timestep
        mse_loss = F.mse_loss(pred_seq, y_t_seq)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation for single timestep
        with torch.no_grad():
            # Get single timestep data
            pre_ff_t = pre_ff_all[:, 0, :]    # [1, features]
            pre_rec_t = pre_rec_all[:, 0, :]  # [1, hidden]
            spk1_t = spk1_all[:, 0, :]        # [1, hidden]
            spk2_t = spk2_all[:, 0, :]        # [1, hidden//2]
            mem1_next = mem1_next_all[:, 0, :] # [1, hidden]
            mem2_next = mem2_next_all[:, 0, :] # [1, hidden//2]
            mem3_next = mem3_next_all[:, 0, :] # [1, 2]
            
            # Single timestep error
            output_error_t = y_t_seq[:, 0, :] - pred_seq[:, 0, :]  # [1, 2]
            
            # Calculate bucket-based reward with online-adapted EMAs
            abs_err_xy = torch.abs(output_error_t).mean(dim=0)  # [2]
            bucket_xy = (abs_err_xy * 8).clamp(0, 15).to(torch.int64)
            bucket = torch.max(bucket_xy)
            lr_scale = self.reward_LUT[bucket]
            
            # Scale learning rate for single timestep (was calibrated for seq_len=10)
            effective_lr = self.fast_lr * lr_scale / 16 * 0.1  # 0.1 scale factor for single timestep
            
            # RMS normalization (conditional on rms_mode)
            if self.rms_mode == 'with_rms':
                # Online-adapted EMA constants (faster adaptation)
                k_out = 2  # Faster than batch mode (was 4)
                k = 3      # Faster than batch mode (was 5)
                
                # Normalize output error per-channel
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0, 255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                
                # Backpropagate error through layers with online-adapted normalization
                err2_raw = torch.matmul(norm_err, self.model.fc3.weight)
                sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err2_sq_ema -= self.err2_sq_ema >> k
                self.err2_sq_ema += sq_err2 >> k
                idx = (self.err2_sq_ema >> 8).clamp_(0, 255)
                hidden2_error_t = err2_raw * self.inv_sqrt_LUT[idx]

                err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                self.err1_sq_ema -= self.err1_sq_ema >> k
                self.err1_sq_ema += sq_err1 >> k
                idx = (self.err1_sq_ema >> 8).clamp_(0, 255)
                hidden1_error_t = err1_raw * self.inv_sqrt_LUT[idx]
            elif self.rms_mode == 'partial_rms':
                # Only normalize output error, not hidden errors
                k_out = 2
                sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.out_rms -= self.out_rms >> k_out
                self.out_rms += sq_out >> k_out
                idx_out = (self.out_rms >> 8).clamp_(0, 255)
                norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                
                # Backpropagate without normalization
                hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            else:  # without_rms
                # No normalization at all
                norm_err = output_error_t
                hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
            
            # Local sensitivities
            d_lif3_t = self.model.lif3.spike_grad(mem3_next)
            d_lif2_t = self.model.lif2.spike_grad(mem2_next)
            d_lif1_t = self.model.lif1.spike_grad(mem1_next)

            # Normalize spike activities (conditional on RMS mode)
            if self.rms_mode in ['with_rms', 'partial_rms']:
                sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                self.spk1_rms += -(self.spk1_rms >> 2) + (sq_spk1 >> 2)  # Faster adaptation (k=2)
                idx = (self.spk1_rms >> 8).clamp_(0, 255)
                spk1_t = spk1_t * self.inv_sqrt_LUT[idx]
            # else: no spike normalization in without_rms mode

            # Compute weight updates using either Hebbian or Delta rule
            if self.use_hebbian:
                # Original Hebbian updates with surrogate gradients
                hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
            else:
                # Delta rule updates (no surrogate gradients)
                hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)

            # Update eligibility traces with computed terms (dual-timescale)
            self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
            self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
            self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
            self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
            
            self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
            self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
            self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
            self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

            # Apply updates using combined dual-timescale traces
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate for slow updates
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize weights
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
        
        return combined_loss.item(), final_states
    
    def update_single_timestep(self, x_t, y_t, states):
        """
        Main interface for online single-timestep updates.

        Moves the sample to ``self.device``, runs
        ``fast_update_single_timestep`` and then performs the bookkeeping
        that drives the slow update path.

        Args:
            x_t: Input for the current timestep.
            y_t: Target for the current timestep.
            states: Hidden states, forwarded unchanged to
                ``fast_update_single_timestep``.

        Returns:
            ``(bounded_loss, new_states)``: the fast-update loss capped at
            ``self.max_loss_value`` and the states returned by the fast update.
        """
        x_t, y_t = x_t.to(self.device), y_t.to(self.device)

        loss, new_states = self.fast_update_single_timestep(x_t, y_t, states)

        bounded_loss = min(loss, self.max_loss_value)

        # The divergence test must look at the RAW loss. The capped value is
        # never above max_loss_value, so (for a positive cap) comparing it to
        # 1.5x the cap made this branch unreachable.
        if loss > self.max_loss_value * 1.5:
            print(f"CATASTROPHIC ERROR DETECTED: {loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # A reset step is not counted towards the slow-update window.
            return bounded_loss, new_states

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        # Trigger the slow path once a full window of steps has accumulated.
        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss, new_states
              
    def fast_update(self, input_sequence, desired_sequence, initial_states):
        """
        Run one fast plasticity step over a whole input window.

        The model is run once with ``need_traces=True``. Under ``no_grad`` the
        method then walks the window one time step at a time, turns each
        step's output error into per-layer error terms, multiplies them with
        the recorded activities and accumulates the result into the fast and
        slow eligibility traces. After the loop the mixed traces are added to
        the four weight matrices (fc1, fc2, fc3, fc_rec), folded into the
        running averages used by ``slow_update``, and every weight matrix is
        passed through ``renorm_row_``.

        Eligibility traces and the RMS/EMA state on ``self`` are NOT reset
        here; they carry over from call to call.

        Args:
            input_sequence: Input window; its shape is unpacked into three
                dimensions and it is indexed as [batch, time, feature].
            desired_sequence: Target window, indexed ``[:, t, :]`` per step.
            initial_states: 4-tuple ``(spk1_rec, mem1, mem2, mem3)`` that is
                forwarded positionally to the model.

        Returns:
            ``(loss, x_corr, y_corr)``: the MSE over the window clamped to
            [0, 10] as a Python float, followed by two placeholders that are
            always 0.
        """
        # NOTE(review): batch_size and device are not used below.
        batch_size, seq_len, _ = input_sequence.shape
        device = self.device

        # Unpack initial states for the sequence.
        # NOTE(review): the four names are not used afterwards (the tuple is
        # forwarded with *initial_states); the unpack only enforces its length.
        spk1_rec, mem1, mem2, mem3 = initial_states

        # NOTE: Eligibility traces are NOT reset here. They persist and decay across batches
        # to maintain a continuous memory of activity, as per best practices.

        # SINGLE forward pass to get predictions AND traces for Hebbian update.
        # final_states is not used by this method.
        pred_sequence, final_states, traces = self.model(
            input_sequence, *initial_states, need_traces=True
        )
        (pre_ff_all, pre_rec_all, spk1_all, spk2_all, 
         mem1_next_all, mem2_next_all, mem3_next_all) = traces
        
        # Calculate loss over the whole sequence (reported only; the weight
        # changes below are computed by hand, not by autograd).
        mse_loss = F.mse_loss(pred_sequence, desired_sequence)
        combined_loss = torch.clamp(mse_loss, 0.0, 10.0)
        
        # Hebbian update calculation
        with torch.no_grad():
            # Calculate a 4-bit error bucket reward signal to modulate plasticity.
            output_error_t_for_reward = desired_sequence - pred_sequence
            # Mean absolute error per output channel (averaged over dims 0 and 1).
            abs_err_xy = torch.mean(torch.abs(output_error_t_for_reward), dim=(0,1))  # [2]
            bucket_xy  = (abs_err_xy * 8).clamp(0,15).to(torch.int64)      # 4-bit each
            # The worst channel selects the learning-rate scale from the LUT.
            bucket     = torch.max(bucket_xy)                              # int64 scalar
            lr_scale   = self.reward_LUT[bucket]
            effective_lr = self.fast_lr * lr_scale / 16 # true division by 16 (not integer division)

            # NO second pass needed. Use the returned traces.
            for t in range(seq_len):
                # Get activations from the pre-computed traces
                pre_ff_t = pre_ff_all[:, t, :]
                pre_rec_t = pre_rec_all[:, t, :]
                spk1_t = spk1_all[:, t, :]
                spk2_t = spk2_all[:, t, :]
                mem1_next = mem1_next_all[:, t, :]
                mem2_next = mem2_next_all[:, t, :]
                mem3_next = mem3_next_all[:, t, :]
                
                # Propagate error backward from this time step's prediction
                output_error_t = desired_sequence[:, t, :] - pred_sequence[:, t, :]
                
                # RMS normalization (conditional on rms_mode).
                # Each EMA below has the form x -= x >> k; x += new >> k, i.e.
                # a shift-based moving average with weight 2^-k. The >> on the
                # state implies integer tensors (presumably int32 - confirm
                # where they are created). The EMA value, shifted right by 8
                # and clamped to 0..255, indexes inv_sqrt_LUT for the scale.
                if self.rms_mode == 'with_rms':
                    # Patch A: Normalize output error per-channel before backprop
                    k_out = 4           # EMA 1/16
                    # Squared error scaled by 4096 and summed over the batch dim.
                    sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)  # [2]
                    self.out_rms -= self.out_rms >> k_out
                    self.out_rms += sq_out       >> k_out
                    idx_out = (self.out_rms >> 8).clamp_(0,255)
                    norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]   # element-wise
                    
                    # Tweak A: RMS-Rescale the local errors (Hardware-Friendly Version)
                    k = 5 # EMA constant: 2^-5 = 0.03125
                    
                    # Project the normalised output error through fc3's weights.
                    err2_raw = torch.matmul(norm_err, self.model.fc3.weight)     # [B,H2]
                    sq_err2 = ((err2_raw**2) * 4096).sum(dim=0).to(torch.int32) # Sum over batch
                    self.err2_sq_ema -= self.err2_sq_ema >> k
                    self.err2_sq_ema += sq_err2 >> k
                    idx = (self.err2_sq_ema >> 8).clamp_(0,255)
                    hidden2_error_t  = err2_raw * self.inv_sqrt_LUT[idx]

                    # Same again one layer down, through fc2's weights.
                    err1_raw = torch.matmul(hidden2_error_t, self.model.fc2.weight)    # [B,H1]
                    sq_err1 = ((err1_raw**2) * 4096).sum(dim=0).to(torch.int32)
                    self.err1_sq_ema -= self.err1_sq_ema >> k
                    self.err1_sq_ema += sq_err1 >> k
                    idx = (self.err1_sq_ema >> 8).clamp_(0,255)
                    hidden1_error_t  = err1_raw * self.inv_sqrt_LUT[idx]
                elif self.rms_mode == 'partial_rms':
                    # Only normalize output error, not hidden errors
                    k_out = 4
                    sq_out = ((output_error_t**2) * 4096).sum(dim=0).to(torch.int32)
                    self.out_rms -= self.out_rms >> k_out
                    self.out_rms += sq_out >> k_out
                    idx_out = (self.out_rms >> 8).clamp_(0, 255)
                    norm_err = output_error_t * self.inv_sqrt_LUT[idx_out]
                    
                    # Backpropagate without normalization
                    hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                    hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                else:  # without_rms
                    # No normalization at all
                    norm_err = output_error_t
                    hidden2_error_t = torch.matmul(norm_err, self.model.fc3.weight)
                    hidden1_error_t = torch.matmul(hidden2_error_t, self.model.fc2.weight)
                
                # Get local sensitivities (surrogate gradients) from the
                # post-step membrane potentials of each LIF layer.
                d_lif3_t = self.model.lif3.spike_grad(mem3_next)
                d_lif2_t = self.model.lif2.spike_grad(mem2_next)
                d_lif1_t = self.model.lif1.spike_grad(mem1_next)

                # Tweak B: RMS-Rescale spike activities (conditional on RMS mode)
                if self.rms_mode in ['with_rms', 'partial_rms']:
                    # Runs under no_grad, so no autograd graph is built here.
                    sq_spk1 = ((spk1_t**2) * 4096).sum(dim=0).to(torch.int32)
                    self.spk1_rms += - (self.spk1_rms >> 3) + (sq_spk1 >> 3) # k=3 for spikes
                    idx = (self.spk1_rms >> 8).clamp_(0,255)
                    spk1_t = spk1_t * self.inv_sqrt_LUT[idx]
                # else: no spike normalization in without_rms mode

                # Build the instantaneous update term for each weight matrix:
                # (post-side error)^T @ (pre-side activity).
                if self.use_hebbian:
                    # Hebbian updates with surrogate gradients
                    hebb_fc3 = torch.matmul((norm_err * d_lif3_t).t(), spk2_t)
                    hebb_fc2 = torch.matmul((hidden2_error_t * d_lif2_t).t(), spk1_t)
                    hebb_fc1 = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_ff_t)
                    hebb_rec = torch.matmul((hidden1_error_t * d_lif1_t).t(), pre_rec_t)
                else:
                    # Delta rule updates (no surrogate gradients)
                    hebb_fc3 = torch.matmul(norm_err.t(), spk2_t)
                    hebb_fc2 = torch.matmul(hidden2_error_t.t(), spk1_t)
                    hebb_fc1 = torch.matmul(hidden1_error_t.t(), pre_ff_t)
                    hebb_rec = torch.matmul(hidden1_error_t.t(), pre_rec_t)
                
                # Update eligibility traces (dual-timescale): decay in place,
                # then add this step's term. Fast and slow differ only in decay.
                self.e_fast_fc3.mul_(self.decay_fast).add_(hebb_fc3)
                self.e_fast_fc2.mul_(self.decay_fast).add_(hebb_fc2)
                self.e_fast_fc1.mul_(self.decay_fast).add_(hebb_fc1)
                self.e_fast_rec.mul_(self.decay_fast).add_(hebb_rec)
                
                self.e_slow_fc3.mul_(self.decay_slow).add_(hebb_fc3)
                self.e_slow_fc2.mul_(self.decay_slow).add_(hebb_fc2)
                self.e_slow_fc1.mul_(self.decay_slow).add_(hebb_fc1)
                self.e_slow_rec.mul_(self.decay_slow).add_(hebb_rec)

                # NO state update at end of loop

            # Apply updates using the final state of the eligibility traces
            # The effective_lr is now calculated once for the whole sequence at the top
            
            # Combine traces for the update (dual-timescale): trace_mix_a is
            # the weight of the fast trace, 1 - trace_mix_a that of the slow.
            e_comb_fc1 = self.trace_mix_a * self.e_fast_fc1 + (1 - self.trace_mix_a) * self.e_slow_fc1
            e_comb_fc2 = self.trace_mix_a * self.e_fast_fc2 + (1 - self.trace_mix_a) * self.e_slow_fc2
            e_comb_fc3 = self.trace_mix_a * self.e_fast_fc3 + (1 - self.trace_mix_a) * self.e_slow_fc3
            e_comb_rec = self.trace_mix_a * self.e_fast_rec + (1 - self.trace_mix_a) * self.e_slow_rec

            # Weights are modified directly through .data, outside autograd.
            self.model.fc1.weight.data += effective_lr * e_comb_fc1
            self.model.fc2.weight.data += effective_lr * e_comb_fc2
            self.model.fc3.weight.data += effective_lr * e_comb_fc3
            self.model.fc_rec.weight.data += effective_lr * e_comb_rec
            
            # Accumulate combined traces for the slow update path
            # (momentum-weighted running average; grad_scale counts the calls).
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * e_comb_fc1
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * e_comb_fc2
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * e_comb_fc3
            self.grad_fc_rec_avg = self.momentum * self.grad_fc_rec_avg + (1 - self.momentum) * e_comb_rec
            self.grad_scale += 1

            # Renormalize all weights after every fast step to prevent explosion/drift
            for w in [self.model.fc1.weight, self.model.fc2.weight, self.model.fc3.weight, self.model.fc_rec.weight]:
                self.renorm_row_(w)
            
        # Placeholder correlation values: always zero, kept so the return
        # stays a 3-tuple for callers that unpack it.
        x_corr=0
        y_corr=0
        return combined_loss.item(), x_corr, y_corr

    def _rms(self, x, eps=1e-5):
        return x / (torch.sqrt(torch.mean(x**2)) + eps)

    def renorm_row_(self, w):
        """
        Scale down, in place, every row of ``w`` whose energy exceeds its cap.

        The per-row energy is ``sum(int32(w) * w)`` along dim 1. A row above
        its cap is truncated to int32 and right-shifted by
        ``ceil(log2(energy) - log2(cap))`` bits (a power-of-two divide), then
        written back in ``w``'s dtype. Rows at or below the cap are left
        untouched.

        Args:
            w: 2-D weight tensor or parameter; modified through ``w.data``.
                A tensor with exactly two rows uses ``self.fc3_row_cap_q12``
                as cap, anything else a cap of ``int(6.0 * 4096)`` per row.

        Returns:
            None.
        """
        # NOTE(review): the original "Q12" remark suggests fixed-point
        # weights; with ordinary sub-unit float weights int32(w) is 0 and the
        # cap is effectively never hit. The trigger is kept exactly as it was.
        row_energy = torch.sum((w.to(torch.int32) * w), dim=1, keepdim=True)
        if w.shape[0] == 2:
            # Indexed with the [rows, 1] mask below, so presumably a tensor
            # of that shape - TODO confirm where it is created.
            caps = self.fc3_row_cap_q12
        else:
            caps = torch.full_like(row_energy, int(6.0 * 4096))

        over_cap = row_energy > caps
        if not over_cap.any():
            return

        # One shift per offending row. Boolean-mask indexing yields 1-D
        # tensors, so the result has shape [K]. Clamp before the integer
        # cast so an infinite energy cannot produce an undefined shift.
        shift = (
            torch.log2(row_energy[over_cap].float())
            - torch.log2(caps[over_cap].float())
        ).ceil().clamp(max=31).to(torch.int32)

        # The previous code called shift.squeeze(1) on this 1-D tensor, which
        # raises IndexError, and copied an int32-truncated version of the
        # WHOLE matrix back. Broadcast the shift as a column instead and
        # write back only the rows that were actually rescaled.
        rows = over_cap.squeeze(1)
        shifted = w.data[rows].to(torch.int32) >> shift.unsqueeze(1)  # power-of-two divide
        w.data[rows] = shifted.to(w.dtype)

    def normalize_weights(self, weights, max_norm=1.0):
        """Return ``weights`` with each row's L2 norm limited to ``max_norm``.

        Rows whose norm is below ``max_norm`` come back essentially unchanged
        (up to the 1e-8 guard in the denominator). Nothing in the fast update
        path calls this helper any more; it is kept for other uses.
        """
        row_norms = weights.norm(p=2, dim=1, keepdim=True)
        capped_norms = row_norms.clamp(max=max_norm)
        # The 1e-8 keeps the ratio finite for all-zero rows.
        return weights * (capped_norms / (row_norms + 1e-8))

    def slow_update(self):
        """
        Apply the slow consolidation step and reset the per-window state.

        In order: (1) if any fast steps were accumulated, divide the running
        averages by ``grad_scale`` and add their RMS-normalised value, scaled
        by ``slow_lr``, to the four weight matrices; (2) multiply all four
        weight matrices by 0.999995; (3) optionally nudge the meta-parameters
        up or down depending on whether the window's average loss improved,
        and derive the learning rates from them; (4) zero the accumulators,
        counters and all eligibility traces.
        """
        layers = (self.model.fc1, self.model.fc2, self.model.fc3, self.model.fc_rec)

        with torch.no_grad():
            if self.grad_scale > 0:
                grad_avgs = (self.grad_fc1_avg, self.grad_fc2_avg,
                             self.grad_fc3_avg, self.grad_fc_rec_avg)
                # In-place division of every accumulator first ...
                for grad in grad_avgs:
                    grad /= self.grad_scale
                # ... then the RMS-normalised step on the matching layer.
                for layer, grad in zip(layers, grad_avgs):
                    layer.weight.data += self.slow_lr * self._rms(grad)

            # Light weight decay on all layers, applied on every slow step
            # whether or not anything was accumulated.
            for layer in layers:
                layer.weight.mul_(0.999995)  # ≈ 1e-5 weight-decay per slow step

        # Average (bounded) loss of the window that just finished.
        if self.step_counter > 0:
            current_avg_loss = self.cumulative_error / self.step_counter
        else:
            current_avg_loss = float('inf')
        current_avg_loss = min(current_avg_loss, self.max_loss_value)

        # Meta-learning adaptation - only if enabled and a previous window exists.
        if self.use_meta_learning and self.prev_cumulative_error is not None:
            prev_loss = min(self.prev_cumulative_error, self.max_loss_value)
            if current_avg_loss < prev_loss:
                factor = 1 + self.meta_lr
            else:
                factor = 1 - self.meta_lr

            for key in ('plasticity', 'sensitivity'):
                self.meta_params[key] *= factor
            for key in ('plasticity', 'sensitivity'):
                self.meta_params[key] = np.clip(self.meta_params[key], 0.2, 1.5)

            # Learning rates follow the meta-parameters.
            self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
            self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']
        # With meta-learning disabled the learning rates are left as they are.

        self.prev_cumulative_error = current_avg_loss

        # Start the next window from a clean slate.
        for grad in (self.grad_fc1_avg, self.grad_fc2_avg,
                     self.grad_fc3_avg, self.grad_fc_rec_avg):
            grad.zero_()
        self.grad_scale = 0
        self.cumulative_error = 0
        self.step_counter = 0

        # Zero the fast and slow eligibility traces as well, so they cannot
        # saturate over long runs.
        for trace in (self.e_fast_fc1, self.e_fast_fc2, self.e_fast_fc3, self.e_fast_rec,
                      self.e_slow_fc1, self.e_slow_fc2, self.e_slow_fc3, self.e_slow_rec):
            trace.zero_()

    def update(self, input_sequence, desired_sequence, initial_states):
        """
        Main interface for windowed (sequence) updates.

        Moves the window to ``self.device``, runs ``fast_update`` and then
        performs the bookkeeping that drives the slow update path.

        Args:
            input_sequence: Input window passed to ``fast_update``.
            desired_sequence: Target window passed to ``fast_update``.
            initial_states: Hidden states forwarded unchanged to
                ``fast_update``.

        Returns:
            The fast-update loss capped at ``self.max_loss_value``.
        """
        input_sequence, desired_sequence = input_sequence.to(self.device), desired_sequence.to(self.device)

        combined_loss, _, _ = self.fast_update(input_sequence, desired_sequence, initial_states)

        bounded_loss = min(combined_loss, self.max_loss_value)

        # The divergence test must look at the RAW loss. The capped value is
        # never above max_loss_value, so (for a positive cap) comparing it to
        # 1.5x the cap made this branch unreachable.
        if combined_loss > self.max_loss_value * 1.5:
            print(f"CATASTROPHIC ERROR DETECTED: {combined_loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # A reset step is not counted towards the slow-update window.
            return bounded_loss

        self.cumulative_error += bounded_loss
        self.step_counter += 1

        # Trigger the slow path once a full window of steps has accumulated.
        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss


def train_snn_windowed(model, rl_updater, train_loader, val_loader, num_epochs, patience=10):
    """
    Train the SNN on windowed batches with per-epoch validation.

    Every training batch starts from all-zero hidden states and is handed to
    ``rl_updater.update``. After each epoch the model is evaluated on
    ``val_loader`` using the last time step of every window; the mean of the
    two per-channel correlations is the early-stopping metric.

    Args:
        model: Network exposing ``fc1``/``fc2``/``fc3`` (for state sizes).
        rl_updater: Object whose ``update(inputs, targets, states)`` returns
            a scalar loss.
        train_loader: Iterable of ``(input_seq, target_seq)`` batches.
        val_loader: Iterable of ``(input_seq, target_seq)`` batches.
        num_epochs: Maximum number of epochs.
        patience: Epochs without improvement before stopping early.

    Returns:
        ``model``, with the best validation weights loaded if any were saved.
    """
    device = next(model.parameters()).device
    print(f"Starting SNN Causal Windowed Training for {num_epochs} epochs on {device}...")

    def zero_states(n_rows):
        # (spk1_rec, mem1, mem2, mem3), all zeros, one row per sample.
        widths = (model.fc1.out_features, model.fc1.out_features,
                  model.fc2.out_features, model.fc3.out_features)
        return tuple(torch.zeros(n_rows, width, device=device) for width in widths)

    best_val_metric = -np.inf  # correlations can be negative
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0.0
        for input_seq, target_seq in train_loader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            fresh = zero_states(input_seq.size(0))
            running_train_loss += rl_updater.update(input_seq, target_seq, fresh)

        avg_train_loss = running_train_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        pred_chunks, target_chunks = [], []
        with torch.no_grad():
            for input_seq, target_seq in val_loader:
                input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                pred_seq, _ = model(input_seq, *zero_states(input_seq.size(0)))
                # Use final time-step to align with read-out
                pred_chunks.append(pred_seq[:, -1, :].cpu())
                target_chunks.append(target_seq[:, -1, :].cpu())

            if not pred_chunks:
                print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: N/A (empty val set)")
                continue

            all_preds = torch.cat(pred_chunks)
            all_targets = torch.cat(target_chunks)
            val_corr_x = compute_correlation(all_targets[:, 0], all_preds[:, 0])
            val_corr_y = compute_correlation(all_targets[:, 1], all_preds[:, 1])
            avg_val_corr = (val_corr_x + val_corr_y) / 2.0

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

        # ---- model selection / early stopping ----
        if avg_val_corr > best_val_metric:
            best_val_metric = avg_val_corr
            best_model_state = {name: tensor.cpu().clone()
                                for name, tensor in model.state_dict().items()}
            print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
            epochs_no_improve = 0
            continue

        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    if best_model_state:
        restored = {name: tensor.to(device) for name, tensor in best_model_state.items()}
        model.load_state_dict(restored)
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

def train_snn_online(model, rl_updater, train_stream, val_stream, num_epochs, patience=10,
                     eval_every_n=500, reset_states_every_n=None):
    """
    Train the SNN online, one timestep per update, with periodic validation.

    Hidden states are carried from step to step (and across epochs) and are
    only zeroed when the stream flags a reset or the periodic reset interval
    is hit. Every ``eval_every_n`` training steps the model is run over up to
    1000 validation steps; the mean of the two per-channel correlations is
    used for model selection and early stopping.

    Args:
        model: SNN model to train.
        rl_updater: Updater with ``online_mode`` set; its
            ``update_single_timestep`` is called for every training step.
        train_stream: Iterable yielding ``(x_t, y_t, should_reset)``.
        val_stream: Iterable yielding ``(x_val, y_val, should_reset)``; it is
            iterated again at every evaluation.
        num_epochs: Number of passes over ``train_stream``.
        patience: Number of evaluations without improvement before stopping.
        eval_every_n: Evaluate on the validation stream every N steps.
        reset_states_every_n: Zero the hidden states every N steps
            (None = never).

    Returns:
        ``model``, with the best validation weights loaded if any were saved.

    Raises:
        ValueError: If ``rl_updater.online_mode`` is falsy.
    """
    if not rl_updater.online_mode:
        raise ValueError("rl_updater must be in online_mode for online training")

    device = next(model.parameters()).device
    print(f"Starting SNN Online Training for {num_epochs} epochs on {device}...")
    print(f"Evaluation every {eval_every_n} steps, reset states every {reset_states_every_n or 'never'} steps")

    def fresh_states():
        # (spk1_rec, mem1, mem2, mem3): zero tensors for a batch of one.
        return (
            torch.zeros(1, model.fc1.out_features, device=device),
            torch.zeros(1, model.fc1.out_features, device=device),
            torch.zeros(1, model.fc2.out_features, device=device),
            torch.zeros(1, model.fc3.out_features, device=device),
        )

    def periodic_reset_due(step_index):
        # Truthy only when an interval is configured and step_index hits it.
        return reset_states_every_n and step_index % reset_states_every_n == 0

    best_val_metric = -np.inf
    best_model_state = None
    epochs_no_improve = 0

    # Persistent hidden states - these survive across timesteps.
    states = fresh_states()

    # Warm-start RMS EMAs if the stream exposes enough raw training data.
    if hasattr(train_stream, 'spike_data') and len(train_stream.spike_data) > 100:
        rl_updater.warm_start_rms_emas(
            train_stream.spike_data[:100],
            train_stream.targets[:100],
            num_warmup_steps=50
        )

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        step_count = 0

        for x_t, y_t, should_reset in train_stream:
            # Zero the states on request (e.g. trial boundaries) or on schedule.
            if should_reset or periodic_reset_due(step_count):
                states = fresh_states()

            # Single timestep update with persistent states.
            loss, states = rl_updater.update_single_timestep(x_t, y_t, states)
            running_train_loss += loss
            step_count += 1

            # Everything below only runs on evaluation steps.
            if step_count % eval_every_n != 0:
                continue

            model.eval()
            val_preds, val_targets = [], []

            # Validation always starts from fresh states.
            val_states = fresh_states()

            with torch.no_grad():
                val_step_count = 0
                for x_val, y_val, should_reset_val in val_stream:
                    if should_reset_val or periodic_reset_due(val_step_count):
                        val_states = fresh_states()

                    x_val_seq = x_val.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, features]
                    pred_seq, val_states = model(x_val_seq, *val_states)
                    val_preds.append(pred_seq[0, 0, :].cpu())  # [2]
                    val_targets.append(y_val.cpu())
                    val_step_count += 1

                    # Cap the validation length so evaluation stays short.
                    if val_step_count >= 1000:
                        break

            if val_preds:
                val_preds = torch.stack(val_preds)      # [N, 2]
                val_targets = torch.stack(val_targets)  # [N, 2]

                # DEBUG: per-channel statistics to spot degenerate outputs.
                pred_std_x = val_preds[:, 0].std().item()
                pred_std_y = val_preds[:, 1].std().item()
                target_std_x = val_targets[:, 0].std().item()
                target_std_y = val_targets[:, 1].std().item()
                pred_mean_x = val_preds[:, 0].mean().item()
                pred_mean_y = val_preds[:, 1].mean().item()
                target_mean_x = val_targets[:, 0].mean().item()
                target_mean_y = val_targets[:, 1].mean().item()

                if pred_std_x < 1e-6 or pred_std_y < 1e-6 or target_std_x < 1e-6 or target_std_y < 1e-6:
                    print(f"  ⚠️  ZERO VARIANCE DETECTED:")
                    print(f"     Pred std: X={pred_std_x:.6f}, Y={pred_std_y:.6f}")
                    print(f"     Target std: X={target_std_x:.6f}, Y={target_std_y:.6f}")
                    print(f"     Pred mean: X={pred_mean_x:.6f}, Y={pred_mean_y:.6f}")
                    print(f"     Target mean: X={target_mean_x:.6f}, Y={target_mean_y:.6f}")

                    # Auto-recovery when both outputs are flat late in the epoch.
                    if pred_std_x < 1e-6 and pred_std_y < 1e-6 and step_count > 1000:
                        print(f"  🔧 ATTEMPTING DEAD NEURON RECOVERY:")
                        print(f"     Increasing learning rate by 5x and reinitializing model...")

                        # Reinitialize model weights (biases, if any, to zero).
                        for module in [model.fc1, model.fc2, model.fc3, model.fc_rec]:
                            nn.init.xavier_uniform_(module.weight)
                            if hasattr(module, 'bias') and module.bias is not None:
                                nn.init.zeros_(module.bias)

                        # Boost both the live and the base fast learning rate.
                        rl_updater.fast_lr *= 5.0
                        rl_updater.base_fast_lr *= 5.0
                        print(f"     New fast LR: {rl_updater.fast_lr:.1e}")

                        # Restart the training states as well.
                        states = fresh_states()

                val_corr_x = compute_correlation(val_targets[:, 0], val_preds[:, 0])
                val_corr_y = compute_correlation(val_targets[:, 1], val_preds[:, 1])
                avg_val_corr = (val_corr_x + val_corr_y) / 2.0

                avg_train_loss = running_train_loss / eval_every_n
                print(f"Epoch [{epoch+1}/{num_epochs}], Step {step_count}, "
                      f"Train Loss: {avg_train_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

                # Model selection / early stopping (counted in evaluations).
                if avg_val_corr > best_val_metric:
                    best_val_metric = avg_val_corr
                    best_model_state = {name: tensor.cpu().clone()
                                        for name, tensor in model.state_dict().items()}
                    print(f"  >>> New best validation model saved with Avg Corr: {best_val_metric:.4f} <<<")
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping triggered after {step_count} steps in epoch {epoch+1}.")
                        print(f"  (Patience: {patience} eval intervals = {patience * eval_every_n} steps)")
                        break

                running_train_loss = 0.0  # Reset for next interval

            model.train()  # Return to training mode

        # Propagate an early stop out of the epoch loop.
        if epochs_no_improve >= patience:
            break

    # Load best model
    if best_model_state:
        restored = {name: tensor.to(device) for name, tensor in best_model_state.items()}
        model.load_state_dict(restored)
        print(f"Loaded best model with validation Corr: {best_val_metric:.4f}")

    return model

# %%
import torch
import numpy as np
import time
import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split



def _mean_metric(session_results, key='avg_corr'):
    """Return the mean of ``key`` over per-session result dicts, or None when there are none."""
    if not session_results:
        return None
    return float(np.mean([r[key] for r in session_results]))


def _format_relative_pct(diff, baseline):
    """Format ``diff`` as a signed percentage of ``|baseline|`` ('n/a' when the baseline is 0)."""
    if baseline == 0:
        return "n/a"
    # Divide by the magnitude so the sign of the percentage always follows
    # the sign of ``diff``, even when the baseline correlation is negative.
    return f"{diff / abs(baseline) * 100:+.1f}%"


def _print_rms_ablation_report(results):
    """Print the per-mode summary table and the WITH_RMS vs. reduced-mode analysis."""
    print("\n" + "="*80)
    print("MCMAZE RMS NORMALIZATION ABLATION RESULTS SUMMARY")
    print("="*80)

    for mode, mode_results in results.items():
        if mode_results:
            avg_corr_x = np.mean([r['corr_x'] for r in mode_results])
            avg_corr_y = np.mean([r['corr_y'] for r in mode_results])
            avg_total = np.mean([r['avg_corr'] for r in mode_results])
            print(f"{mode:12s}: X={avg_corr_x:.4f}, Y={avg_corr_y:.4f}, Avg={avg_total:.4f}")

    # Detailed analysis
    print("\n" + "="*50)
    print("ANALYSIS")
    print("="*50)

    # A mode key exists even when every session of that mode failed, so test
    # for actual results rather than key presence (avoids np.mean([]) -> nan).
    with_rms_avg = _mean_metric(results.get('WITH_RMS'))
    reduced_avgs = {}
    for mode in ('PARTIAL_RMS', 'WITHOUT_RMS'):
        mode_avg = _mean_metric(results.get(mode))
        if mode_avg is not None:
            reduced_avgs[mode] = mode_avg

    if with_rms_avg is None or not reduced_avgs:
        print("Not enough successful sessions to compare WITH_RMS against the reduced RMS modes.")
        return

    # Compare each reduced RMS mode to full RMS
    for mode, reduced_avg in reduced_avgs.items():
        diff = with_rms_avg - reduced_avg
        print(f"WITH_RMS vs {mode}: {diff:+.4f} ({_format_relative_pct(diff, reduced_avg)} relative)")

    # Find best reduced RMS mode (first one wins on ties, as dicts keep insertion order)
    best_reduced_mode = max(reduced_avgs, key=reduced_avgs.get)
    best_reduced_avg = reduced_avgs[best_reduced_mode]

    print(f"\nBest reduced RMS mode: {best_reduced_mode} ({best_reduced_avg:.4f})")
    rms_improvement = with_rms_avg - best_reduced_avg
    print(f"Full RMS improvement: {rms_improvement:+.4f} "
          f"({_format_relative_pct(rms_improvement, best_reduced_avg)} relative)")

    if abs(rms_improvement) < 0.01:
        print("→ Full RMS normalization provides minimal benefit (hardware simplification recommended)")
    elif rms_improvement < 0:
        # Full RMS is clearly worse than the best reduced mode; this must not
        # be reported as a "benefit".
        print("→ Full RMS normalization underperforms the best reduced mode (hardware simplification recommended)")
    elif rms_improvement > 0.02:
        print("→ Full RMS normalization provides significant benefit (justifies complexity)")
    else:
        print("→ Full RMS normalization provides modest benefit (design trade-off)")


def run_rms_ablation_mcmaze(num_sessions=10):
    """
    Clean ablation study: different RMS normalization configurations with IDENTICAL training setup.

    Only the ``rms_mode`` parameter (and the inverse-sqrt lookup table that goes
    with it) differs between runs. Uses MCMaze data with trial-based splitting.

    For every mode ('with_rms', 'partial_rms', 'without_rms') the function runs
    ``num_sessions`` independent sessions. Session ``i`` seeds numpy, ``random``
    and torch with ``42 + i`` before drawing the 70/15/15 trial split, so the
    same session index uses the same split in every mode. Targets are z-scored
    with training-set statistics only. A failing session is reported (message
    plus traceback) and skipped; it does not abort the study.

    Args:
        num_sessions: Number of seeded train/val/test sessions to run per mode.

    Returns:
        dict mapping the mode label ("WITH_RMS", "PARTIAL_RMS", "WITHOUT_RMS")
        to a list of per-session dicts with keys 'session', 'rms_mode',
        'corr_x', 'corr_y' and 'avg_corr'. The list is empty for a mode whose
        sessions all failed or produced no test predictions.
    """
    print("\n" + "="*80)
    print("MCMAZE RMS NORMALIZATION ABLATION STUDY")
    print("="*80)

    results = {}

    # Load MCMaze data once
    print("Loading MCMaze data...")
    firing_rates, hand_vel, windowed_trial_ids, _, _, _ = load_data(
        "mc_maze", use_cached=False, spike_processing='count'
    )

    # Get unique trials for splitting
    unique_trials = np.unique(windowed_trial_ids)
    n_total_trials = len(unique_trials)
    print(f"Found {n_total_trials} unique trials")

    # Trial-based split sizes (70/15/15); the remainder goes to the test set.
    n_train_trials = int(0.70 * n_total_trials)
    n_val_trials = int(0.15 * n_total_trials)

    # Test all RMS modes
    # NOTE(review): the PARTIAL_RMS comment below says "only output error"
    # while the message printed for that mode says "output error + spikes only";
    # confirm against the updater implementation which one is accurate.
    rms_modes = [
        ("WITH_RMS", 'with_rms'),        # Full RMS normalization (current default)
        ("PARTIAL_RMS", 'partial_rms'),  # Only output error normalization
        ("WITHOUT_RMS", 'without_rms')   # No RMS normalization at all
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # IDENTICAL batching configuration for every mode and session
    seq_len = 10
    batch_size = 32
    overlap_stride = max(1, seq_len // 2)

    for mode_name, rms_mode in rms_modes:
        print(f"\n{'='*50}")
        print(f"Testing {mode_name} RMS Mode")
        if rms_mode == 'with_rms':
            print("  Using full RMS normalization (errors + spikes)")
        elif rms_mode == 'partial_rms':
            print("  Using partial RMS normalization (output error + spikes only)")
        elif rms_mode == 'without_rms':
            print("  Using no RMS normalization")
        print(f"{'='*50}")

        session_results = []

        for session_idx in range(num_sessions):
            print(f"\n--- Session {session_idx + 1} with {mode_name} ---")

            try:
                # Set seeds for reproducibility per session
                session_seed = 42 + session_idx
                np.random.seed(session_seed)
                random.seed(session_seed)
                torch.manual_seed(session_seed)

                # Create trial-based split (70/15/15)
                permuted_trials = np.random.permutation(unique_trials)
                train_trials = permuted_trials[:n_train_trials]
                val_trials = permuted_trials[n_train_trials:n_train_trials + n_val_trials]
                test_trials = permuted_trials[n_train_trials + n_val_trials:]

                # Create masks based on trial IDs
                train_mask = np.isin(windowed_trial_ids, train_trials)
                val_mask = np.isin(windowed_trial_ids, val_trials)
                test_mask = np.isin(windowed_trial_ids, test_trials)

                # Split data using masks
                train_firing_rates = firing_rates[train_mask]
                train_targets_raw = hand_vel[train_mask]
                val_firing_rates = firing_rates[val_mask]
                val_targets_raw = hand_vel[val_mask]
                test_firing_rates = firing_rates[test_mask]
                test_targets_raw = hand_vel[test_mask]

                train_trial_ids = windowed_trial_ids[train_mask]
                val_trial_ids = windowed_trial_ids[val_mask]
                test_trial_ids = windowed_trial_ids[test_mask]

                # IDENTICAL normalization - use only training data statistics
                hand_vel_mean = train_targets_raw.mean(axis=0)
                hand_vel_std = train_targets_raw.std(axis=0) + 1e-6  # epsilon avoids division by zero

                train_targets = (train_targets_raw - hand_vel_mean) / hand_vel_std
                val_targets = (val_targets_raw - hand_vel_mean) / hand_vel_std
                test_targets = (test_targets_raw - hand_vel_mean) / hand_vel_std

                input_size = firing_rates.shape[1]
                print(f"  Input size: {input_size}, Train: {len(train_firing_rates)}, Val: {len(val_firing_rates)}, Test: {len(test_firing_rates)}")

                # IDENTICAL model setup
                snn_model = SNNRegression(input_size=input_size, hidden_size=256, output_size=2).to(device)

                # LUT setup (conditional on RMS mode): 256 samples of 1/sqrt(v)
                # for v in [0.25, 8.0]; not built when RMS is disabled.
                if rms_mode in ['with_rms', 'partial_rms']:
                    vals = torch.linspace(0.25, 8.0, 256, device=device)
                    inv_sqrt_LUT = 1.0 / torch.sqrt(vals)
                else:
                    inv_sqrt_LUT = None

                # IDENTICAL data loading - trial-aware batching
                train_loader = CausalBatcher(
                    train_firing_rates, train_targets, train_trial_ids,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=True, stride=overlap_stride)
                val_loader = CausalBatcher(
                    val_firing_rates, val_targets, val_trial_ids,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=False, stride=overlap_stride)

                # ONLY DIFFERENCE: rms_mode parameter
                snn_updater = TwoScaleMetaRLWeightUpdaterFull(
                    snn_model,
                    base_fast_lr=1e-5,  # Adjusted for MCMaze count data
                    base_slow_lr=1e-6,
                    window_size=50,
                    meta_lr=0.01,
                    online_mode=False,              # IDENTICAL: windowed mode for all
                    use_hebbian=True,               # IDENTICAL: use Hebbian learning for all
                    use_meta_learning=True,         # IDENTICAL: use meta-learning for all
                    rms_mode=rms_mode,              # ONLY DIFFERENCE
                    inv_sqrt_LUT=inv_sqrt_LUT
                )

                # IDENTICAL training
                trained_snn_model = train_snn_windowed(
                    snn_model, snn_updater, train_loader, val_loader,
                    num_epochs=15, patience=5
                )

                # IDENTICAL evaluation
                test_loader = CausalBatcher(
                    test_firing_rates, test_targets, test_trial_ids,
                    batch_size=batch_size, sequence_length=seq_len,
                    shuffle=False, stride=None)

                trained_snn_model.eval()
                snn_preds_norm_list, snn_targets_norm_list = [], []

                with torch.no_grad():
                    for input_seq, target_seq in test_loader:
                        input_seq, target_seq = input_seq.to(device), target_seq.to(device)
                        # Fresh zero states for every test batch, sized from the layers.
                        batch_size_test = input_seq.size(0)
                        test_spk1_rec = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem1 = torch.zeros(batch_size_test, trained_snn_model.fc1.out_features, device=device)
                        test_mem2 = torch.zeros(batch_size_test, trained_snn_model.fc2.out_features, device=device)
                        test_mem3 = torch.zeros(batch_size_test, trained_snn_model.fc3.out_features, device=device)

                        pred_seq, _ = trained_snn_model(input_seq, test_spk1_rec, test_mem1, test_mem2, test_mem3)
                        # Score only the last time step of each sequence.
                        snn_preds_norm_list.append(pred_seq[:, -1, :].cpu())
                        snn_targets_norm_list.append(target_seq[:, -1, :].cpu())

                if snn_preds_norm_list:
                    snn_preds_normalized = torch.cat(snn_preds_norm_list).numpy()
                    snn_targets_normalized = torch.cat(snn_targets_norm_list).numpy()

                    # IDENTICAL evaluation
                    snn_corr_x = compute_correlation(snn_targets_normalized[:, 0], snn_preds_normalized[:, 0])
                    snn_corr_y = compute_correlation(snn_targets_normalized[:, 1], snn_preds_normalized[:, 1])

                    session_results.append({
                        'session': session_idx + 1,
                        'rms_mode': mode_name,
                        'corr_x': snn_corr_x,
                        'corr_y': snn_corr_y,
                        'avg_corr': (snn_corr_x + snn_corr_y) / 2
                    })

                    print(f"  {mode_name} Results: X={snn_corr_x:.4f}, Y={snn_corr_y:.4f}, Avg={((snn_corr_x + snn_corr_y)/2):.4f}")
                else:
                    print(f"  WARNING: No test predictions generated for session {session_idx + 1}")

            except Exception as e:
                # Deliberate best-effort: report the failure and move on so one
                # bad session does not abort the whole ablation.
                print(f"  ERROR in {mode_name} run: {e}")
                import traceback
                traceback.print_exc()

        results[mode_name] = session_results

    # Compare results
    _print_rms_ablation_report(results)

    return results

# Script entry point: when this file is executed directly, run the MCMaze RMS
# normalization ablation with 10 sessions per mode and keep the returned
# per-mode results in the module-level name `ablation_results`.
if __name__ == "__main__":
    print("Running RMS Normalization Ablation (MCMaze)...")
    ablation_results = run_rms_ablation_mcmaze(num_sessions=10)
