import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from collections import deque, defaultdict
import random
from filterpy.kalman import KalmanFilter as FilterPyKalmanFilter
from scipy.linalg import inv

# --- Attempt to import from local modules (excluding train_single_gpu) ---
# If bci_simulation cannot be imported, inert stand-ins are bound to the same
# names so that this module still imports. NOTE: the stand-ins do no real work
# (e.g. the fallback compute_correlation always returns 0.0 and the fallback
# trainers return empty results), so any numbers produced while they are active
# are meaningless; watch for the warning printed below.
try:
    from bci_simulation import (
        Synthetic_Neuron,
        LSTMRegression,
        train_test_kalman_filter,
        train_lstm_model,
        compute_correlation,
        
    )
    from bci_simulation import TwoScaleMetaRLWeightUpdaterFull
    print("Successfully imported components from bci_simulation.py")
except ImportError as e:
    print(f"Warning: Could not import from bci_simulation.py: {e}")
    # Define dummy classes/functions if import fails
    class Synthetic_Neuron: pass
    class LSTMRegression(nn.Module): pass
    # Stand-in returns a pair of None values.
    def train_test_kalman_filter(*args, **kwargs): return None, None
    # Stand-in returns None plus two empty lists.
    def train_lstm_model(*args, **kwargs): return None, [], []
    # Stand-in always reports zero correlation.
    def compute_correlation(*args, **kwargs): return 0.0
    class TwoScaleMetaRLWeightUpdaterFull: pass

# --- Import ZENODO classes with renamed imports to avoid conflicts ---
# The aliases keep these classes distinct from the same-named SNNRegression
# defined in this file and the TwoScaleMetaRLWeightUpdaterFull imported from
# bci_simulation. On ImportError, empty placeholder classes are bound to the
# aliases so later references to the names still resolve.
try:
    from ZENODO_SCRIPT2 import (
        SNNRegression as SNNRegressionZenodo,
        TwoScaleMetaRLWeightUpdaterFull as TwoScaleMetaRLWeightUpdaterFullZenodo
    )
    print("Successfully imported ZENODO components from ZENODO_SCRIPT2.py")
except ImportError as e:
    print(f"Warning: Could not import from ZENODO_SCRIPT2.py: {e}")
    # Define dummy classes if import fails
    class SNNRegressionZenodo: pass
    class TwoScaleMetaRLWeightUpdaterFullZenodo: pass

# --- SNNRegression class (copied from train_dist_old_zenodo.py or bci_simulation.py) ---
import snntorch as snn
from snntorch import surrogate

class SNNRegression(nn.Module):
    """Spiking regression network built from snntorch ``Leaky`` neurons.

    Layout: an input leaky layer, then ``fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3``.
    ``lif3`` is created with an infinite threshold and no reset mechanism, and
    its membrane value at every time step is what ``forward`` returns.

    Membrane states are stored on the instance (``mem_in``, ``mem1``, ``mem2``,
    ``mem3``) and persist across ``forward`` calls until ``reset_states`` is
    called.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Width of the first hidden layer; the second hidden layer
            has ``hidden_size // 2`` units.
        output_size: Number of regression outputs.
        input_beta: ``beta`` passed to the input leaky layer.
        use_mem: When true, ``fc1`` is driven by the input layer's membrane
            value instead of its spikes.
    """

    def __init__(self, input_size=182, hidden_size=128, output_size=2, input_beta=0.9, use_mem=False):
        super().__init__()
        self.use_mem = use_mem

        half_hidden = hidden_size // 2
        # One surrogate-gradient object is shared by every leaky layer.
        surrogate_fn = surrogate.fast_sigmoid()

        self.lif_input = snn.Leaky(beta=input_beta, spike_grad=surrogate_fn, init_hidden=False)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate_fn, init_hidden=False)
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate_fn, init_hidden=False)
        self.fc3 = nn.Linear(half_hidden, output_size)
        # Infinite threshold + no reset: this layer's membrane is the readout.
        self.lif3 = snn.Leaky(
            beta=0.9,
            spike_grad=surrogate_fn,
            threshold=float('inf'),
            reset_mechanism="none",
            init_hidden=False,
        )

        self.apply(self._init_weights)

        # Membrane states are created lazily on the first forward pass.
        self.mem_in = None
        self.mem1 = None
        self.mem2 = None
        self.mem3 = None

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def reset_states(self):
        """Replace all stored membrane states with freshly reset ones."""
        self.mem_in, self.mem1, self.mem2, self.mem3 = (
            self.lif_input.reset_mem(),
            self.lif1.reset_mem(),
            self.lif2.reset_mem(),
            self.lif3.reset_mem(),
        )

    def forward(self, x):
        """Run the network over a 3-D input and return the readout per step.

        Args:
            x: Tensor unpacked as ``(batch, T, features)``.

        Returns:
            The ``lif3`` membrane values of all ``T`` steps, stacked along
            dimension 1.
        """
        _, num_steps, _ = x.size()
        target_device = x.device

        # (Re)create the states when missing or living on another device.
        if self.mem_in is None or self.mem_in.device != target_device:
            self.reset_states()
            self.mem_in = self.mem_in.to(target_device)
            self.mem1 = self.mem1.to(target_device)
            self.mem2 = self.mem2.to(target_device)
            self.mem3 = self.mem3.to(target_device)

        readout = []
        for step in range(num_steps):
            frame = x[:, step, :].float().to(target_device)
            in_spikes, self.mem_in = self.lif_input(frame, self.mem_in)
            drive = self.mem_in if self.use_mem else in_spikes
            hidden1_spikes, self.mem1 = self.lif1(self.fc1(drive), self.mem1)
            hidden2_spikes, self.mem2 = self.lif2(self.fc2(hidden1_spikes), self.mem2)
            _, self.mem3 = self.lif3(self.fc3(hidden2_spikes), self.mem3)
            readout.append(self.mem3)
        return torch.stack(readout, dim=1)

# --- evaluate_model function (copied from train_dist_old_zenodo.py) ---
def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` using its last-time-step output.

    Each batch is zero-padded (or truncated, with a one-time warning) along the
    feature axis to ``model.fc1.in_features``, the model states are reset, and
    the loss is computed between the model's final time step and the target.

    Args:
        model: Module exposing ``fc1.in_features``, ``reset_states()`` and a
            forward pass returning a 3-D tensor indexed as ``[:, -1, :]``.
        data_loader: Iterable of ``(batch_spike, batch_target)`` pairs, where
            ``batch_spike`` is 3-D ``(batch, seq_len, features)``.
        criterion: Loss callable applied to ``(prediction, target)``.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, corr_x, corr_y)`` where ``avg_loss`` is the unweighted mean
        of the per-batch losses and the correlations are computed with
        ``compute_correlation`` on output columns 0 and 1.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    expected_input_size = model.fc1.in_features

    total_loss = 0.0
    num_batches = 0
    final_step_outputs = []
    all_targets = []
    warned_truncation = False

    with torch.no_grad():
        for batch_spike, batch_target in data_loader:
            batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)

            # --- Match the feature axis to the model's expected input size ---
            batch_size_dyn, seq_len_dyn, num_features_dyn = batch_spike.shape
            padding_size = expected_input_size - num_features_dyn
            if padding_size > 0:
                padding_tensor = torch.zeros(batch_size_dyn, seq_len_dyn, padding_size, device=device)
                batch_spike_padded = torch.cat([batch_spike, padding_tensor], dim=2)
            elif padding_size < 0:
                # The model expects fewer features than provided.
                if not warned_truncation:
                    print(f"Warning in evaluate_model: Input features ({num_features_dyn}) > model expected ({expected_input_size}). Truncating.")
                    warned_truncation = True
                batch_spike_padded = batch_spike[:, :, :expected_input_size]
            else:
                batch_spike_padded = batch_spike

            model.reset_states()
            # Only the final time step is ever used, so keep just that slice
            # instead of copying the whole output sequence to the host.
            final_step = model(batch_spike_padded)[:, -1, :]
            loss = criterion(final_step, batch_target)
            total_loss += loss.item()
            num_batches += 1
            final_step_outputs.append(final_step.cpu().numpy())
            all_targets.append(batch_target.cpu().numpy())

    if num_batches == 0:
        raise ValueError("evaluate_model: data_loader produced no batches; nothing to evaluate.")

    avg_loss = total_loss / num_batches
    all_outputs = np.concatenate(final_step_outputs, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    # Targets may carry a sequence dimension; if so, use their last time step.
    if all_targets.ndim == 3:
        all_targets = all_targets[:, -1, :]

    corr_x = compute_correlation(all_outputs[:, 0], all_targets[:, 0])
    corr_y = compute_correlation(all_outputs[:, 1], all_targets[:, 1])
    return avg_loss, corr_x, corr_y

# --- Locally defined and adapted BPTT SNN training function ---
from torch.cuda.amp import GradScaler, autocast

# --- Experiment Phase Definitions ---
# String labels identifying the two experiment phases.
PHASE_INITIAL_LEARNING = "INITIAL_LEARNING"
PHASE_ADAPT_TO_DISRUPTION = "ADAPT_TO_DISRUPTION"

# --- Disruption Parameters (to be used in run_single_comparison) ---
# The real values are defined inside run_single_comparison or passed to it.
# The commented-out names below are conceptual placeholders only.
# DISRUPTION_TYPE_CONFIG = "remapping" # Options: "dropout", "remapping", "drift"
# DISRUPTION_INTENSITY_CONFIG = 0.5
# NUM_TARGETS_INITIAL_CONFIG = 400
# NUM_TARGETS_ADAPT_CONFIG = 120

# Function to apply disruption (adapted from bci_simulation.py)
def apply_disruption_to_spike_generator(spike_gen, disruption_type, intensity, original_spike_gen_properties, device, num_neurons_in_gen):
    """Apply an aggressive neural disruption to ``spike_gen`` in place.

    Supported ``disruption_type`` values:

    * ``"dropout"``: builds a 0/1 mask silencing ``min(0.9, 1.8 * intensity)``
      of the neurons and raises ``noise_level`` (capped at 0.3).
    * ``"remapping"``: overwrites the preferred directions of
      ``min(0.95, 1.9 * intensity)`` of the neurons with random unit vectors.
    * ``"drift"``: lowers ``max_firing_rate``, raises ``min_firing_rate`` (kept
      at or below 80% of the new maximum) and raises ``noise_level`` (capped at
      0.4).
    * ``"catastrophic"``: ignores ``intensity``; drops 70% of the neurons,
      remaps the survivors, cuts the max rate to 20%, raises the min rate and
      multiplies the noise by 10 (capped at 0.5).

    Args:
        spike_gen: Object whose attributes (``noise_level``,
            ``preferred_directions``, ``max_firing_rate``, ``min_firing_rate``)
            are modified. Attributes that are missing are skipped.
        disruption_type: One of the strings listed above.
        intensity: Disruption strength used by dropout, remapping and drift.
        original_spike_gen_properties: Mapping holding the pre-disruption
            values; new values are derived from these, not from the current
            state of ``spike_gen``.
        device: Torch device for the generated index and mask tensors.
        num_neurons_in_gen: Number of neurons in the generator.

    Returns:
        A float tensor of shape ``(num_neurons_in_gen,)`` with 1 for active and
        0 for silenced neurons when the disruption drops neurons (``"dropout"``,
        and ``"catastrophic"`` when at least one neuron is dropped); otherwise
        ``None``. The caller is responsible for applying the mask.
    """
    print(f"\nAPPLYING AGGRESSIVE DISRUPTION: {disruption_type.upper()} with intensity {intensity}")
    dropout_mask = None
    saved = original_spike_gen_properties

    if disruption_type == "dropout":
        # Scale intensity aggressively: 0.5 -> 90% dropout, 0.3 -> 54%, ...
        aggressive_intensity = min(0.9, intensity * 1.8)
        num_to_drop = int(num_neurons_in_gen * aggressive_intensity)
        active_mask = torch.ones(num_neurons_in_gen, device=device)

        if 0 < num_to_drop < num_neurons_in_gen:
            drop_indices = torch.randperm(num_neurons_in_gen, device=device)[:num_to_drop]
            active_mask[drop_indices] = 0
            dropout_mask = active_mask
            print(f"  AGGRESSIVE: Generated dropout mask: {num_to_drop} neurons will be silenced ({aggressive_intensity*100:.1f}%).")
        elif num_to_drop >= num_neurons_in_gen:
            dropout_mask = torch.zeros(num_neurons_in_gen, device=device)
            print("  AGGRESSIVE: Generated dropout mask: ALL neurons will be silenced!")
        else:
            dropout_mask = active_mask
            print("  Generated dropout mask: NO neurons will be silenced.")

        if hasattr(spike_gen, 'noise_level') and 'noise_level' in saved:
            spike_gen.noise_level = min(0.3, saved['noise_level'] * (1 + intensity * 4))
            print(f"  AGGRESSIVE: Increased spike_gen noise level to {spike_gen.noise_level:.3f}")

    elif disruption_type == "remapping":
        if not (hasattr(spike_gen, 'preferred_directions') and 'preferred_directions' in saved):
            print("  ERROR: Spike generator does not have 'preferred_directions' or not saved in original_properties. Cannot remap.")
            return None

        # Scale intensity aggressively: 0.5 -> 95% remapped, 0.3 -> 57%, ...
        aggressive_intensity = min(0.95, intensity * 1.9)
        # Clamp at zero: a negative count would slice from the end of the
        # permutation and remap almost every neuron.
        num_to_remap = max(0, int(num_neurons_in_gen * aggressive_intensity))
        if num_to_remap == 0 and intensity > 0:
            # Any positive intensity remaps at least one neuron.
            num_to_remap = 1

        remap_indices = torch.randperm(num_neurons_in_gen, device=device)[:num_to_remap]
        print(f"  AGGRESSIVE: Remapping {len(remap_indices)} neurons ({aggressive_intensity*100:.1f}% of population)")

        # Random angle plus a random 0-or-180-degree shift per remapped neuron.
        new_angles = torch.rand(len(remap_indices), device=device) * 2 * np.pi
        angle_shifts = torch.randint(0, 2, (len(remap_indices),), device=device) * np.pi
        new_angles = new_angles + angle_shifts
        new_dirs = torch.stack([torch.cos(new_angles), torch.sin(new_angles)], dim=1)

        if spike_gen.preferred_directions.device != device:
            spike_gen.preferred_directions = spike_gen.preferred_directions.to(device)

        spike_gen.preferred_directions[remap_indices] = new_dirs
        print("  AGGRESSIVE: Neural remapping applied successfully to spike_gen.")

    elif disruption_type == "drift":
        rates_available = (
            hasattr(spike_gen, 'max_firing_rate')
            and hasattr(spike_gen, 'min_firing_rate')
            and 'max_firing_rate' in saved
            and 'min_firing_rate' in saved
        )
        if not rates_available:
            print("  ERROR: Spike generator missing firing rate attributes or not saved. Cannot apply drift.")
            return None

        original_max_rate = saved['max_firing_rate']
        original_min_rate = saved['min_firing_rate']

        # Reduce the max rate by up to 80% at intensity=1, never below 1.0.
        max_reduction = 0.8
        new_max_rate = max(original_max_rate * (1.0 - max_reduction * float(intensity)), 1.0)

        # Raise the min rate by up to 300% at intensity=1, but keep it
        # non-negative and at most 80% of the new max to preserve dynamic range.
        min_increase = 3.0
        proposed_min = original_min_rate * (1.0 + min_increase * float(intensity))
        new_min_rate = max(0.0, min(proposed_min, 0.8 * new_max_rate))

        spike_gen.max_firing_rate = new_max_rate
        spike_gen.min_firing_rate = new_min_rate
        print(f"  DRIFT: Applied firing rate drift: max {original_max_rate:.1f} -> {new_max_rate:.1f}, min {original_min_rate:.1f} -> {new_min_rate:.1f}")

        if hasattr(spike_gen, 'noise_level') and 'noise_level' in saved:
            base_noise = saved['noise_level']
            spike_gen.noise_level = min(0.4, base_noise * (1.0 + 4.0 * float(intensity)))
            print(f"  DRIFT: Updated noise level to {spike_gen.noise_level:.3f}")

    elif disruption_type == "catastrophic":
        print("  CATASTROPHIC: Applying multiple aggressive disruptions simultaneously!")

        # 1. Drop 70% of the neurons.
        num_to_drop = int(num_neurons_in_gen * 0.7)
        active_mask = torch.ones(num_neurons_in_gen, device=device)
        if num_to_drop > 0:
            drop_indices = torch.randperm(num_neurons_in_gen, device=device)[:num_to_drop]
            active_mask[drop_indices] = 0
            dropout_mask = active_mask
            print(f"  CATASTROPHIC: Dropping {num_to_drop} neurons (70%)")

        # 2. Remap the surviving neurons (skipped when the generator has no
        #    tuning directions, consistent with the other attribute guards).
        remaining_neurons = torch.where(active_mask == 1)[0]
        if not hasattr(spike_gen, 'preferred_directions'):
            print("  CATASTROPHIC: spike_gen has no 'preferred_directions'; skipping remapping.")
        elif len(remaining_neurons) > 0:
            new_angles = torch.rand(len(remaining_neurons), device=device) * 2 * np.pi
            new_dirs = torch.stack([torch.cos(new_angles), torch.sin(new_angles)], dim=1)
            if spike_gen.preferred_directions.device != device:
                spike_gen.preferred_directions = spike_gen.preferred_directions.to(device)
            spike_gen.preferred_directions[remaining_neurons] = new_dirs
            print(f"  CATASTROPHIC: Remapped {len(remaining_neurons)} remaining neurons")

        # 3. Drastic firing-rate changes.
        if hasattr(spike_gen, 'max_firing_rate') and 'max_firing_rate' in saved:
            new_max_rate = max(saved['max_firing_rate'] * 0.2, 1.0)
            spike_gen.max_firing_rate = new_max_rate
            print(f"  CATASTROPHIC: Reduced max firing rate to {new_max_rate:.1f}")

        if hasattr(spike_gen, 'min_firing_rate') and 'min_firing_rate' in saved:
            proposed_min = saved['min_firing_rate'] * 5.0
            spike_gen.min_firing_rate = min(proposed_min, 0.8 * spike_gen.max_firing_rate)
            print(f"  CATASTROPHIC: Set min firing rate to {spike_gen.min_firing_rate:.1f}")

        # 4. Tenfold noise increase, capped at 0.5.
        if hasattr(spike_gen, 'noise_level') and 'noise_level' in saved:
            spike_gen.noise_level = min(0.5, saved['noise_level'] * 10)
            print(f"  CATASTROPHIC: Increased noise level by 10x to {spike_gen.noise_level:.3f}")

    else:
        print(f"  Warning: Unknown disruption type '{disruption_type}' for spike_generator.")

    return dropout_mask

# Function to revert disruption (adapted from bci_simulation.py)
def revert_disruption_of_spike_generator(spike_gen, original_spike_gen_properties):
    """Write the saved pre-disruption properties back onto ``spike_gen``.

    A property is restored only when it is both present in
    ``original_spike_gen_properties`` and an existing attribute of
    ``spike_gen``. ``min_firing_rate`` is only considered when
    ``max_firing_rate`` is restored. Dropout masks are not handled here; they
    are cleared simply by the caller no longer applying them.

    Returns:
        None. ``spike_gen`` is modified in place.
    """
    print(f"\nREVERTING SPIKE GENERATOR DISRUPTION...")
    saved = original_spike_gen_properties

    def can_restore(attr_name):
        return attr_name in saved and hasattr(spike_gen, attr_name)

    restored = []

    if can_restore('noise_level'):
        spike_gen.noise_level = saved['noise_level']
        print(f"  Restored spike_gen noise level to {spike_gen.noise_level:.3f}")
        restored.append('noise_level')

    if can_restore('preferred_directions'):
        saved_dirs = saved['preferred_directions']
        # Accept numpy arrays / lists as well as tensors.
        if not torch.is_tensor(saved_dirs):
            saved_dirs = torch.tensor(saved_dirs, dtype=torch.float32)
        # Clone so the saved copy is never aliased by the generator, and keep
        # the generator's current device.
        current_device = spike_gen.preferred_directions.device
        spike_gen.preferred_directions = saved_dirs.clone().to(current_device)
        print(f"  Restored original spike_gen tuning directions.")
        restored.append('preferred_directions')

    if can_restore('max_firing_rate'):
        spike_gen.max_firing_rate = saved['max_firing_rate']
        if not can_restore('min_firing_rate'):
            print(f"  Restored original spike_gen max_firing_rate={spike_gen.max_firing_rate:.1f} (min_rate not found in original_properties)")
        else:
            spike_gen.min_firing_rate = saved['min_firing_rate']
            print(f"  Restored original spike_gen firing rates: max={spike_gen.max_firing_rate:.1f}, min={spike_gen.min_firing_rate:.1f}")
        restored.append('max_firing_rate')

    if len(restored) == 0:
        print("  No specific disruption properties found to revert in spike_generator based on original_properties.")

def train_bptt_snn_local(train_dataset, val_dataset, 
                        input_size, hidden_size, output_size,
                        epochs=200, batch_size=32768, device=torch.device("cpu"),
                        lr=1e-3,
                        checkpoint_path="snn_bptt_local_temp.pth",
                        patience=5):
    """Train an ``SNNRegression`` with BPTT and early stopping on validation loss.

    Every batch is zero-padded (or truncated) along the feature axis to
    ``input_size``, the model states are reset, and the MSE between the model's
    last time step and the target is minimised with Adam. After each epoch the
    model is scored with ``evaluate_model``; training stops once the validation
    loss has not improved for ``patience`` consecutive epochs.

    Args:
        train_dataset: Dataset yielding ``(spike_sequence, target)`` pairs.
        val_dataset: Dataset used for validation after every epoch.
        input_size: Model input width; narrower inputs are zero-padded, wider
            ones are truncated.
        hidden_size: Hidden width passed to ``SNNRegression``.
        output_size: Output width passed to ``SNNRegression``.
        epochs: Maximum number of epochs.
        batch_size: Batch size for both loaders.
        device: Device (``torch.device`` or string) to train on. Mixed
            precision is only enabled when this is a CUDA device.
        lr: Adam learning rate.
        checkpoint_path: Currently unused; kept for interface compatibility.
        patience: Number of non-improving epochs tolerated before stopping.

    Returns:
        ``(best_model_state_dict, history_df)``: an independent copy of the
        weights from the epoch with the lowest validation loss (or of the
        current weights if no epoch completed), and a DataFrame with one row
        per completed epoch.

    Raises:
        ValueError: If ``train_dataset`` is empty.
    """
    # Local import: the module-level imports do not provide ``copy``.
    import copy

    print(f"Starting BPTT SNN training (local) on {device}...")

    if len(train_dataset) == 0:
        raise ValueError("train_bptt_snn_local: train_dataset is empty; cannot train.")

    use_cuda = torch.device(device).type == "cuda"
    num_workers = min(4, os.cpu_count() or 1)
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=use_cuda,
        persistent_workers=num_workers > 0,
        prefetch_factor=2 if num_workers > 0 else None,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    model = SNNRegression(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # CUDA AMP must be disabled off-GPU: an enabled scaler rejects CPU tensors.
    scaler = GradScaler(enabled=use_cuda)
    criterion = nn.MSELoss()

    history = []
    best_val_loss = float('inf')
    best_model_state_dict = None
    epochs_no_improve = 0
    best_epoch = 0

    print(f"  Model: input={input_size}, hidden={hidden_size}, output={output_size}")

    try:
        for epoch in range(epochs):
            model.train()
            epoch_train_loss = 0.0
            for batch_spike, batch_target in train_loader:
                batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)

                # --- Match the feature axis to the model's input size ---
                batch_size_dyn, seq_len_dyn, _ = batch_spike.shape
                padding_size = input_size - batch_spike.shape[-1]
                if padding_size > 0:
                    padding_tensor = torch.zeros(batch_size_dyn, seq_len_dyn, padding_size, device=device)
                    batch_spike_padded = torch.cat([batch_spike, padding_tensor], dim=2)
                elif padding_size < 0:
                    # Same policy as evaluate_model: drop the surplus features.
                    batch_spike_padded = batch_spike[:, :, :input_size]
                else:
                    batch_spike_padded = batch_spike

                model.reset_states()
                optimizer.zero_grad()
                with autocast(enabled=use_cuda):
                    outputs = model(batch_spike_padded)
                    loss = criterion(outputs[:, -1, :], batch_target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_train_loss += loss.item()

            avg_train_loss = epoch_train_loss / len(train_loader)

            current_val_loss, val_x_corr, val_y_corr = evaluate_model(model, val_loader, criterion, device)

            print(f"  Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {current_val_loss:.4f}, Val Corr X: {val_x_corr:.4f}, Val Corr Y: {val_y_corr:.4f}")
            history.append({
                'epoch': epoch + 1,
                'train_loss': avg_train_loss,
                'val_loss': current_val_loss,
                'val_corr_x': val_x_corr,
                'val_corr_y': val_y_corr
            })

            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                # state_dict() returns references to the live parameters, so a
                # deep copy is needed to freeze this epoch's weights.
                best_model_state_dict = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
                best_epoch = epoch + 1
                print(f"    New best val_loss: {best_val_loss:.4f}. Saving model state.")
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch + 1} epochs.")
                    break

    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
    finally:
        if best_model_state_dict is None:
            # No epoch finished with an improvement; fall back to the current weights.
            best_model_state_dict = copy.deepcopy(model.state_dict())
        print("BPTT SNN training (local) finished.")
        if best_epoch > 0:
            print(f"Returning best model state from epoch {best_epoch}.")

    history_df = pd.DataFrame(history)
    return best_model_state_dict, history_df

# --- Helper Classes and Functions (SpikeVelDataset, generate_data) ---
class SpikeVelDataset(Dataset):
    """Sliding-window dataset pairing spike sequences with velocities.

    Sample ``i`` is the window ``spike_data[i : i + sequence_length]`` together
    with the velocity at the window's last step,
    ``vel_data[i + sequence_length - 1]``.

    Args:
        spike_data: Array-like or tensor indexed by time step along axis 0.
        vel_data: Array-like or tensor indexed by time step along axis 0.
        sequence_length: Window length; must be at least 1.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, spike_data, vel_data, sequence_length):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.spike_data = self._as_float_tensor(spike_data)
        self.vel_data = self._as_float_tensor(vel_data)
        self.sequence_length = sequence_length
        # Only windows whose target velocity exists are valid, so count against
        # the shorter of the two series.
        usable_steps = min(len(self.spike_data), len(self.vel_data))
        self.num_samples = max(0, usable_steps - sequence_length + 1)

    @staticmethod
    def _as_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``."""
        if torch.is_tensor(data):
            # torch.tensor(tensor) warns; clone explicitly instead.
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Support negative indices the way Python sequences do; without this a
        # negative idx would silently yield a wrong or empty window.
        if idx < 0:
            idx = idx + self.num_samples
        if idx < 0 or idx >= self.num_samples:
            raise IndexError("Index out of bounds")
        window_end = idx + self.sequence_length
        spikes_seq = self.spike_data[idx:window_end]
        target_vel = self.vel_data[window_end - 1]
        return spikes_seq, target_vel

def generate_data(spike_generator, num_steps, sequence_length=1, batch_size=1):
    """Simulate a cursor chasing random targets and record spikes per step.

    A new random target is drawn inside an 800 x 600 area every 100 steps. At
    each step the desired velocity points from the cursor to the target, with
    magnitude ``min(1, distance / 200)``; the cursor moves by 25 times that
    velocity and is clipped to the area. The velocity is passed to
    ``spike_generator.generate_spikes(..., sequence_length=1)``.

    Args:
        spike_generator: Object with a ``device`` attribute and a
            ``generate_spikes(velocity_tensor, sequence_length=1)`` method
            returning a tensor.
        num_steps: Number of simulation steps.
        sequence_length: Not used by this function.
        batch_size: Not used by this function.

    Returns:
        ``(spikes, velocities)`` as NumPy arrays with one row per step.
    """
    print(f"Generating {num_steps} steps of synthetic data...")
    area = np.array([800.0, 600.0])
    cursor = np.array([400.0, 300.0])
    goal = np.random.rand(2) * area

    spike_rows = []
    velocity_rows = []
    for tick in range(num_steps):
        # Pick a fresh target every 100 steps (including step 0).
        if tick % 100 == 0:
            goal = np.random.rand(2) * area

        offset = goal - cursor
        dist = np.linalg.norm(offset)
        if dist > 1e-6:
            velocity = offset / dist * min(1.0, dist / 200.0)
        else:
            velocity = np.zeros(2)

        # Advance the cursor and keep it inside the area.
        cursor = np.clip(cursor + velocity * 25, 0.0, area)

        velocity_tensor = torch.tensor(velocity, dtype=torch.float32, device=spike_generator.device)
        step_spikes = spike_generator.generate_spikes(velocity_tensor, sequence_length=1)
        spike_rows.append(step_spikes.squeeze().cpu().numpy())
        velocity_rows.append(velocity)

    print("Data generation complete.")
    return np.array(spike_rows), np.array(velocity_rows)

# --- Simulation Parameters for Reach Task (GLOBAL CONSTANTS) ---
# Workspace dimensions (units presumably pixels; confirm against the simulation code).
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600
# Radius of a reach target, in the same units as the workspace.
TARGET_RADIUS = 50
# MAX_REACH_STEPS = 1500 # This was for the old simulate_decoder_reach_task
# Movement scaling factor (presumably applied to decoded velocities; verify at the call sites).
MOVEMENT_SCALE = 5
# Centre of the workspace as a float32 (x, y) array.
CENTER_POS = np.array([SCREEN_WIDTH/2, SCREEN_HEIGHT/2], dtype=np.float32)
# NUM_EVAL_REACHES = 200 # This was for the old simulate_decoder_reach_task

# --- New Global Constant for Phased Simulation ---
# Step limit for a single reach attempt in the phased simulation.
MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL = 300 # Global definition

# --- Function to Simulate a Single Reach Attempt (ensure defined before run_single_comparison) ---
# This is the core simulation logic for one reach, called by run_single_comparison.
def _pad_spikes_to_snn_input(spikes_step, input_size_snn_val, num_neurons_val, device):
    """Zero-pad one step of generated spikes up to the SNN input width.

    Drops the leading dim of ``spikes_step``, moves it to ``device``, appends
    ``input_size_snn_val - num_neurons_val`` zero features along dim 1 and
    restores the leading dim.
    """
    step_2d = spikes_step.squeeze(0).to(device)
    padding = torch.zeros(
        (step_2d.size(0), input_size_snn_val - num_neurons_val), device=device
    )
    return torch.cat([step_2d, padding], dim=1).unsqueeze(0)


def simulate_single_reach_attempt(
    decoder_name, model_object_or_tuple,
    spike_generator, velocity_scaler, device,
    target_pos_abs, initial_cursor_pos_abs,
    center_pos_abs, target_radius_abs, movement_scale_abs,
    max_steps_this_reach,
    input_size_snn_val, num_neurons_val,
    is_snn_online_learning_active,
    current_phase_name,
    dropout_mask_to_apply=None
):
    """Simulate one closed-loop reach with the given decoder.

    Each step computes the desired velocity from the cursor to the target,
    asks ``spike_generator`` for spikes encoding it, decodes a velocity with
    the selected decoder and moves the cursor by
    ``prediction * movement_scale_abs`` (clipped to the screen).

    Args:
        decoder_name: One of ``'KF'``, ``'LSTM'``, ``'SNN_BPTT'``, ``'SNN_Online'``.
        model_object_or_tuple: The decoder. For ``'SNN_Online'`` it is
            unpacked as ``(model, updater)``; for ``'KF'`` it must expose
            ``predict()``, ``update(z)`` and ``x``; for ``'LSTM'`` it must
            expose ``lstm``, ``fc``, ``num_layers`` and ``hidden_size``.
        spike_generator: Object with ``generate_spikes(velocity, sequence_length=1)``.
        velocity_scaler: Object with ``transform``; used to normalize the KF
            output and the SNN_Online training target.
        device: Torch device for model inputs and states.
        target_pos_abs: NOTE(review): never read. The target is drawn at
            random inside this function; confirm against the caller whether
            the supplied target was meant to be honored.
        initial_cursor_pos_abs: Starting cursor position (copied, not mutated).
        center_pos_abs: Reference point the random target must be farther
            than ``3 * target_radius_abs`` from.
        target_radius_abs: Success radius around the target.
        movement_scale_abs: Multiplier applied to the decoded velocity.
        max_steps_this_reach: Step budget before the reach times out.
        input_size_snn_val: SNN input width after zero-padding.
        num_neurons_val: Number of spike features before padding.
        is_snn_online_learning_active: NOTE(review): never read. SNN_Online
            always performs its update step here; confirm intended gating.
        current_phase_name: Not read by this function.
        dropout_mask_to_apply: Optional tensor multiplied into the spikes
            (reshaped to ``(1, 1, -1)``) before decoding.

    Returns:
        ``(steps, succeeded)``: the 1-based step count at which the target
        was reached and ``True``, or ``(max_steps_this_reach, False)`` on
        timeout or for an unknown ``decoder_name``.
    """
    import warnings  # local import: only used to report swallowed KF failures

    if decoder_name != 'KF' and isinstance(model_object_or_tuple, nn.Module):
        model_object_or_tuple.eval()

    # SNN_Online state resets are deliberately left to the caller so that
    # pre-training and evaluation loops can each choose their own schedule.
    if decoder_name == "SNN_BPTT":
        # SNN_BPTT is reset at the start of every reach; accept model or (model, updater).
        snn_bptt_model_to_reset = (
            model_object_or_tuple[0]
            if isinstance(model_object_or_tuple, tuple)
            else model_object_or_tuple
        )
        if hasattr(snn_bptt_model_to_reset, 'reset_states'):
            snn_bptt_model_to_reset.reset_states()
    elif decoder_name == "LSTM":
        # Fresh hidden/cell state per reach, batch size 1.
        h_lstm_prev = torch.zeros(
            model_object_or_tuple.num_layers, 1, model_object_or_tuple.hidden_size, device=device
        )
        c_lstm_prev = torch.zeros(
            model_object_or_tuple.num_layers, 1, model_object_or_tuple.hidden_size, device=device
        )

    cursor_pos = initial_cursor_pos_abs.copy()

    # Rejection-sample a target that is far enough from the reference centre.
    while True:
        target_x = np.random.uniform(target_radius_abs * 1.5, SCREEN_WIDTH - target_radius_abs * 1.5)
        target_y = np.random.uniform(target_radius_abs * 1.5, SCREEN_HEIGHT - target_radius_abs * 1.5)
        target_pos = np.array([target_x, target_y], dtype=np.float32)
        if np.linalg.norm(target_pos - center_pos_abs) > target_radius_abs * 3:
            break

    reach_succeeded = False

    for step_num in range(max_steps_this_reach):
        # --- 1. Desired velocity: unit direction, slowed inside 200 units ---
        desired_vec_np = target_pos - cursor_pos
        distance = np.linalg.norm(desired_vec_np)
        if distance > 1e-6:
            current_desired_vel_np = desired_vec_np / distance * min(1.0, distance / 200.0)
        else:
            current_desired_vel_np = np.zeros(2)
        current_desired_vel_tensor = torch.tensor(
            current_desired_vel_np, dtype=torch.float32, device=device
        )

        # --- 2. Spikes for this step (expected to be [1, 1, num_neurons]) ---
        dynamically_generated_spikes = spike_generator.generate_spikes(
            current_desired_vel_tensor, sequence_length=1
        )

        if dropout_mask_to_apply is not None:
            mask = dropout_mask_to_apply.to(dynamically_generated_spikes.device)
            dynamically_generated_spikes = dynamically_generated_spikes * mask.view(1, 1, -1)

        pred_vel_norm = np.zeros(2)

        # --- 3. Decode a normalized velocity ---
        if decoder_name == 'KF':
            # Only the KF needs a NumPy copy of the spikes.
            current_spikes_for_kf_np = dynamically_generated_spikes.squeeze().cpu().numpy()
            try:
                z = current_spikes_for_kf_np.reshape(-1, 1)  # column vector measurement
                model_object_or_tuple.predict()
                model_object_or_tuple.update(z)
                kf_pred_vel_raw = model_object_or_tuple.x.copy()
                pred_vel_norm = velocity_scaler.transform(kf_pred_vel_raw.reshape(1, -1)).squeeze()
            except Exception as kf_error:
                # Best-effort: a failed KF step decodes as zero velocity, but the
                # failure is surfaced so it is not mistaken for poor decoding.
                warnings.warn(
                    f"KF decoder step failed ({type(kf_error).__name__}: {kf_error}); "
                    "using zero velocity for this step.",
                    RuntimeWarning,
                )
                pred_vel_norm = np.zeros(2)

        elif decoder_name == 'LSTM':
            lstm_input_tensor = dynamically_generated_spikes.to(device)
            if lstm_input_tensor.shape[2] != model_object_or_tuple.lstm.input_size:
                print(f"LSTM input size mismatch: model expects {model_object_or_tuple.lstm.input_size}, got {lstm_input_tensor.shape[2]}")
            with torch.no_grad():
                # Stateful evaluation: drive the lstm and fc layers directly so the
                # hidden state carries over from step to step within the reach.
                lstm_out_step, (h_lstm_new, c_lstm_new) = model_object_or_tuple.lstm(
                    lstm_input_tensor, (h_lstm_prev, c_lstm_prev)
                )
                pred_vel_norm_tensor = model_object_or_tuple.fc(lstm_out_step.squeeze(1))
                pred_vel_norm = pred_vel_norm_tensor.squeeze().detach().cpu().numpy()
                h_lstm_prev, c_lstm_prev = h_lstm_new, c_lstm_new

        elif decoder_name == 'SNN_BPTT':
            snn_input_tensor_bptt = _pad_spikes_to_snn_input(
                dynamically_generated_spikes, input_size_snn_val, num_neurons_val, device
            )
            with torch.no_grad():
                snn_bptt_output_seq = model_object_or_tuple(snn_input_tensor_bptt)
                pred_vel_norm = snn_bptt_output_seq[:, -1, :].squeeze().detach().cpu().numpy()

        elif decoder_name == 'SNN_Online':
            online_model, online_updater = model_object_or_tuple

            snn_input_tensor_online = _pad_spikes_to_snn_input(
                dynamically_generated_spikes, input_size_snn_val, num_neurons_val, device
            )

            # The updater's target is the desired velocity, normalized with the
            # same scaler as the decoder outputs.
            current_desired_vel_normalized_np = velocity_scaler.transform(
                current_desired_vel_np.reshape(1, -1)
            ).squeeze()
            target_tensor_online = torch.tensor(
                current_desired_vel_normalized_np, dtype=torch.float32
            ).unsqueeze(0).to(device)

            with torch.enable_grad():  # the online update needs gradients
                online_model.train()

                if hasattr(online_model, '_persistent_states'):
                    # Carry the explicit SNN state across steps.
                    spk1_rec, mem1, mem2, mem3 = online_model._persistent_states
                    pred_output_sequence_online, final_states = online_model(
                        snn_input_tensor_online, spk1_rec, mem1, mem2, mem3, need_traces=False
                    )
                    online_model._persistent_states = final_states
                elif hasattr(online_model, 'fc1') and hasattr(online_model, 'fc2') and hasattr(online_model, 'fc3'):
                    # Explicit-state model without initialized states: start from zeros.
                    batch_size = snn_input_tensor_online.size(0)
                    spk1_rec = torch.zeros(batch_size, online_model.fc1.out_features, device=device)
                    mem1 = torch.zeros(batch_size, online_model.fc1.out_features, device=device)
                    mem2 = torch.zeros(batch_size, online_model.fc2.out_features, device=device)
                    mem3 = torch.zeros(batch_size, online_model.fc3.out_features, device=device)
                    pred_output_sequence_online, final_states = online_model(
                        snn_input_tensor_online, spk1_rec, mem1, mem2, mem3, need_traces=False
                    )
                    online_model._persistent_states = final_states
                else:
                    # Plain forward for models that manage their own state.
                    pred_output_sequence_online = online_model(snn_input_tensor_online)

                pred_velocity_online_tensor = pred_output_sequence_online[:, -1, :]
                pred_vel_norm = pred_velocity_online_tensor.squeeze().detach().cpu().numpy()

                if hasattr(online_updater, 'update_single_timestep'):
                    # Single-timestep update operating on the persistent states.
                    x_t = snn_input_tensor_online.squeeze(0).squeeze(0)
                    y_t = target_tensor_online.squeeze(0)
                    loss, new_states = online_updater.update_single_timestep(
                        x_t, y_t, online_model._persistent_states
                    )
                    online_model._persistent_states = new_states
                else:
                    online_updater.update(
                        snn_input_tensor_online, pred_velocity_online_tensor, target_tensor_online
                    )
            online_model.eval()

        else:
            print(f"Error: Unknown decoder name '{decoder_name}'")
            return max_steps_this_reach, False

        # --- 4. Move the cursor with the normalized prediction and clip to screen ---
        cursor_pos += pred_vel_norm * movement_scale_abs
        cursor_pos[0] = np.clip(cursor_pos[0], 0, SCREEN_WIDTH)
        cursor_pos[1] = np.clip(cursor_pos[1], 0, SCREEN_HEIGHT)

        distance = np.linalg.norm(target_pos - cursor_pos)
        if distance < target_radius_abs:
            reach_succeeded = True
            break

    if reach_succeeded:
        # step_num is 0-based, so the number of steps taken is step_num + 1.
        return step_num + 1, True
    # Timeout: report the full step budget.
    return max_steps_this_reach, False

# --- Plotting function for reach times (MOVED TO GLOBAL SCOPE) ---
def plot_reach_time_comparison(reach_results, filename="decoder_reach_time_comparison.png", smoothing_window=15):
    """Plot the rolling-mean time-to-target per decoder and save it to disk.

    Args:
        reach_results: Either a dict ``{decoder_name: [steps, ...]}`` or a
            non-empty list of such dicts (only the first one is plotted).
        filename: Output image path.
        smoothing_window: Window size of the rolling mean.

    Returns:
        None. For an unsupported ``reach_results`` format an error is printed
        and nothing is plotted.
    """
    # Validate before creating a figure so the error path cannot leak one.
    if isinstance(reach_results, list) and reach_results and isinstance(reach_results[0], dict):
        if len(reach_results) > 1:
            print("Warning: plot_reach_time_comparison received multiple run results, plotting first run only for standalone execution.")
        reach_results_to_plot = reach_results[0]
    elif isinstance(reach_results, dict):
        reach_results_to_plot = reach_results
    else:
        print("Error: Invalid format for reach_results in plot_reach_time_comparison.")
        return

    colors = {'KF': 'blue', 'LSTM': 'green', 'SNN_BPTT': 'red', 'SNN_Online': 'purple'}
    fig = plt.figure(figsize=(12, 7))
    try:
        plotted_any = False
        for name, times in reach_results_to_plot.items():
            # Length check instead of truthiness so NumPy arrays are accepted too.
            if times is None or len(times) == 0:
                print(f"No reach times recorded for {name}, skipping plot.")
                continue
            # Steps are converted to seconds at 10 ms per step.
            times_seconds = (np.array(times) * 10) / 1000.0
            smoothed_times = pd.Series(times_seconds).rolling(window=smoothing_window, min_periods=1).mean()
            trials = np.arange(1, len(smoothed_times) + 1)
            plt.plot(trials, smoothed_times, label=f'{name} (Smoothed N={smoothing_window})', color=colors.get(name, 'black'), linewidth=2, alpha=0.8)
            plotted_any = True

        plt.xlabel("Successful Reach Trial Number")
        plt.ylabel("Time to Reach Target (seconds)")
        plt.title("Decoder Time-to-Target Comparison (Smoothed)")
        if plotted_any:
            # Skipped when empty to avoid matplotlib's "no labeled artists" warning.
            plt.legend()
        plt.grid(True, alpha=0.4)
        plt.tight_layout()
        plt.savefig(filename, dpi=150)
        print(f"\nSaved reach time comparison plot to {filename}")
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)
# --- End of moved plotting function ---

# --- Plotting function for Phased Reach Time Comparison (ensure defined before run_single_comparison if called directly from it) ---
def _set_phase_time_ylim(axis, phase_times_s, timeout_s, no_data_top_scale):
    """Choose y-limits for one phase axis from its raw (unsmoothed) reach times.

    Times at or above 99% of ``timeout_s`` are treated as timeouts and are
    excluded when estimating a reasonable upper limit.
    """
    if not phase_times_s:
        axis.set_ylim(bottom=0, top=max(timeout_s * no_data_top_scale, 1.0))
        return

    min_val = np.min(phase_times_s)
    max_val = np.max(phase_times_s)
    successful_times = [t for t in phase_times_s if t < (timeout_s * 0.99)]
    if successful_times:
        reasonable_max = np.percentile(successful_times, 99) * 1.5
        top_limit = min(max_val, reasonable_max, timeout_s * 1.1)
        # Keep at least 0.1 s of headroom above the minimum.
        axis.set_ylim(bottom=max(0, min_val * 0.8), top=max(top_limit, min_val + 0.1))
    else:
        # Every reach timed out: show the full timeout range (at least 1 s).
        axis.set_ylim(bottom=0, top=max(timeout_s * 1.1, 1.0))


def plot_phased_reach_time_comparison(
    phased_results_single_run, 
    base_filename="phased_decoder_reach_time_comparison.png", 
    smoothing_window=10,
    # Add parameters for disruption info to be passed for the title
    disruption_type_str=None, 
    disruption_intensity_val=None 
):
    """
    Plots the smoothed time to reach target for each decoder across different experiment phases.
    Smoothing is reset per phase. The initial-learning phase (and any unknown
    phase) is drawn on the left y-axis, the disruption phase on the right one.

    Args:
        phased_results_single_run (dict): Data from a single run.
            Format: {decoder_name: {phase_name: [steps_reach1, steps_reach2, ...], ...}, ...}
        base_filename (str): Filename of the saved plot.
        smoothing_window (int): Window size for rolling mean smoothing.
        disruption_type_str (str): Disruption type to include in the plot title.
        disruption_intensity_val (float): Disruption intensity (fraction) to include
            in the plot title; omitted from the title when None.

    Returns:
        None. Prints a message and returns early when there is no data.
    """
    if len(phased_results_single_run) == 0:
        print("No data to plot for phased reach time comparison.")
        return

    plt.ioff() # Turn off interactive mode

    # NOTE: the style sheet and rcParams below are applied process-wide and
    # stay in effect for figures created after this function returns.
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 12, 
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14,
        'legend.fontsize': 12,
        'legend.title_fontsize': 14
    })

    # Phase order is taken from the first decoder: known phases first, then the rest.
    first_decoder_phases = phased_results_single_run[next(iter(phased_results_single_run))]
    phase_names_ordered = []
    for known_phase in (PHASE_INITIAL_LEARNING, PHASE_ADAPT_TO_DISRUPTION):
        if known_phase in first_decoder_phases and known_phase not in phase_names_ordered:
            phase_names_ordered.append(known_phase)
    for phase_key in first_decoder_phases.keys():
        if phase_key not in phase_names_ordered:
            phase_names_ordered.append(phase_key)

    decoder_colors = {
        'KF': 'blue',
        'LSTM': 'green', 
        'SNN_BPTT': 'red',
        'SNN_Online': 'purple'
    }
    line_styles_phase = {PHASE_INITIAL_LEARNING: '-', PHASE_ADAPT_TO_DISRUPTION: '--'}
    decoder_markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*'] # Cycled if >9 decoders
    timeout_s = MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL * 0.01

    fig, ax1 = plt.subplots(figsize=(17, 9))
    ax2 = ax1.twinx() # Second y-axis for the disruption phase
    try:
        # Raw times per axis, used afterwards for y-axis scaling.
        all_times_phase1 = []
        all_times_phase2 = []

        for i, decoder_name in enumerate(phased_results_single_run.keys()):
            decoder_phases = phased_results_single_run[decoder_name]
            next_trial_index = 0 # Phases of one decoder are laid end to end on the x-axis
            decoder_marker_style = decoder_markers[i % len(decoder_markers)]
            decoder_color = decoder_colors.get(decoder_name, 'gray')

            for phase_name in phase_names_ordered:
                if phase_name not in decoder_phases:
                    continue
                reach_steps_this_phase = decoder_phases[phase_name]
                # Length check instead of truthiness so NumPy arrays are accepted too.
                if reach_steps_this_phase is None or len(reach_steps_this_phase) == 0:
                    continue

                reach_times_this_phase_s = np.array(reach_steps_this_phase) * 0.01 # seconds
                num_reaches_in_phase = len(reach_times_this_phase_s)

                # Shrink the window for phases shorter than smoothing_window.
                effective_window = min(smoothing_window, num_reaches_in_phase)
                smoothed_times = pd.Series(reach_times_this_phase_s).rolling(
                    window=effective_window, min_periods=1
                ).mean().to_numpy()

                trial_indices_this_segment = np.arange(
                    next_trial_index, next_trial_index + num_reaches_in_phase
                )
                label_text = f'{decoder_name} ({phase_name.replace("_", " ").lower()})'

                if phase_name == PHASE_ADAPT_TO_DISRUPTION:
                    current_axis_for_plot = ax2
                    all_times_phase2.extend(reach_times_this_phase_s)
                else: # Initial learning and any other phase share ax1
                    current_axis_for_plot = ax1
                    all_times_phase1.extend(reach_times_this_phase_s)

                current_axis_for_plot.plot(
                    trial_indices_this_segment, smoothed_times,
                    color=decoder_color,
                    linestyle=line_styles_phase.get(phase_name, '-'),
                    linewidth=2.5,
                    marker=decoder_marker_style,
                    markersize=5,
                    markevery=max(1, num_reaches_in_phase // 10), # Roughly ten markers per segment
                    label=label_text,
                )

                next_trial_index += num_reaches_in_phase

        # With no data, ax1 falls back to half the timeout and ax2 to 1.1x of it.
        _set_phase_time_ylim(ax1, all_times_phase1, timeout_s, 0.5)
        _set_phase_time_ylim(ax2, all_times_phase2, timeout_s, 1.1)

        # Vertical line marking the phase transition
        vline = None
        vline_label_text = ""
        if PHASE_INITIAL_LEARNING in phase_names_ordered and PHASE_ADAPT_TO_DISRUPTION in phase_names_ordered:
            # Length of the first non-empty initial-learning phase across decoders.
            num_reaches_phase1 = 0
            for dec_name_temp in phased_results_single_run:
                if PHASE_INITIAL_LEARNING in phased_results_single_run[dec_name_temp]:
                    num_reaches_phase1 = len(phased_results_single_run[dec_name_temp].get(PHASE_INITIAL_LEARNING, []))
                    if num_reaches_phase1 > 0:
                        break

            if num_reaches_phase1 > 0:
                vline_label_text = f'{PHASE_INITIAL_LEARNING.replace("_", " ")} End / {PHASE_ADAPT_TO_DISRUPTION.replace("_", " ")} Start'
                vline = ax1.axvline(x=num_reaches_phase1 - 0.5, color='black', linestyle=':', alpha=0.7, linewidth=1.5, label=vline_label_text)

        # Labels and title (font sizes come from rcParams)
        ax1.set_xlabel("Reach Completion Trial Index")
        title_text = "Decoder Performance: Smoothed Time to Reach Target (Phased, Dual Y-Axis)"
        if disruption_type_str:
            title_text += f'\nDisruption: {disruption_type_str.capitalize()}'
            # The intensity defaults to None; only format it when it was supplied.
            if disruption_intensity_val is not None:
                title_text += f' ({disruption_intensity_val*100}%)'
        ax1.set_title(title_text)

        ax1.set_ylabel(f"Smoothed Time (s) - {PHASE_INITIAL_LEARNING.replace('_', ' ')}", color='black')
        ax1.tick_params(axis='y', labelcolor='black')

        if PHASE_ADAPT_TO_DISRUPTION in phase_names_ordered:
            ax2.set_ylabel(f"Smoothed Time (s) - {PHASE_ADAPT_TO_DISRUPTION.replace('_', ' ')}", color='black')
            ax2.tick_params(axis='y', labelcolor='black')
        else: # Hide ax2 ticks if there is no disruption phase
            ax2.set_yticks([])
            ax2.set_yticklabels([])

        ax1.tick_params(axis='x')
        # Aim for ~10 x ticks, pruning the ends if needed
        ax1.xaxis.set_major_locator(plt.MaxNLocator(nbins=10, prune='both'))

        ax1.grid(True, which='major', axis='x', linestyle='-', alpha=0.6)
        ax1.grid(True, which='major', axis='y', linestyle='--', alpha=0.4, color='gray')
        if PHASE_ADAPT_TO_DISRUPTION in phase_names_ordered:
            ax2.grid(True, which='major', axis='y', linestyle=':', alpha=0.4, color='gray')

        # Legend: one solid proxy line per decoder, plus the transition marker if drawn.
        legend_lines = []
        legend_labels = []
        for decoder_name in phased_results_single_run.keys():
            dummy_line, = ax1.plot(
                [], [],
                color=decoder_colors.get(decoder_name, 'gray'),
                linewidth=2.5, linestyle='-', label=decoder_name,
            )
            legend_lines.append(dummy_line)
            legend_labels.append(decoder_name)

        if vline is not None:
            legend_lines.append(vline)
            legend_labels.append(vline_label_text)

        fig.legend(
            legend_lines, 
            legend_labels, 
            loc='upper center', 
            bbox_to_anchor=(0.5, -0.05), # Positioned below the plot
            fancybox=True, 
            shadow=False, 
            ncol=max(1, len(legend_lines) // 2 + 1), 
            title="Decoders",
            title_fontsize=plt.rcParams['legend.title_fontsize']
        )

        plt.tight_layout(rect=[0, 0.08, 1, 0.95]) # Leave room for legend and title

        plot_filename = base_filename 
        try:
            fig.savefig(plot_filename, dpi=200)
            print(f"Saved phased reach time comparison plot to {plot_filename}")
        except Exception as e:
            print(f"Error saving phased reach time plot: {e}")
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)


# --- New Function for Closed-Loop Offline Pretraining Data Generation ---
def generate_main_offline_pretraining_data_closed_loop(
    driving_decoder_model, # Untrained LSTM
    spike_generator,
    num_total_steps,
    device,
    sequence_length_for_spikes=1, # Usually 1 for this kind of step-by-step generation
    # Simulation parameters from global scope
    screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT,
    target_radius=TARGET_RADIUS, movement_scale=MOVEMENT_SCALE,
    center_pos=CENTER_POS,
    max_steps_per_sim_reach=300 # Max steps per reach during data generation
):
    """Generate (spikes, desired-velocity) pairs by simulating closed-loop reaches.

    For every simulated step the "oracle" desired velocity towards the current
    target is computed, spikes are generated from that desired velocity, and the
    cursor is then moved with the velocity predicted by ``driving_decoder_model``
    from those spikes. The stored label is always the desired velocity, never
    the decoder's prediction.

    Args:
        driving_decoder_model: Model exposing ``.lstm.input_size`` and callable
            on the spike tensor. It is switched to ``eval()`` here and is NOT
            switched back, so callers that train it afterwards must re-enable
            training mode themselves.
        spike_generator: Object with ``generate_spikes(velocity, sequence_length=...)``.
            Presumably returns a tensor shaped [1, sequence_length, num_neurons]
            -- TODO confirm against the generator implementation.
        num_total_steps: Number of simulation steps (samples) to collect.
        device: Torch device used for the velocity and spike tensors.
        sequence_length_for_spikes: ``sequence_length`` forwarded to the generator.
        screen_width, screen_height: Cursor positions are clipped to
            ``[0, screen_width] x [0, screen_height]``.
        target_radius: A reach ends once the cursor is closer than this to the target.
        movement_scale: Multiplier applied to the predicted velocity per step.
        center_pos: Start position of every reach (copied, never mutated).
        max_steps_per_sim_reach: Upper bound on steps per simulated reach.

    Returns:
        Tuple ``(spikes, velocities)`` of NumPy arrays built from the per-step
        lists; both have ``num_total_steps`` entries along the first axis.

    Raises:
        ValueError: If the model lacks ``lstm.input_size``, if
            ``max_steps_per_sim_reach`` is not positive while steps are
            requested, or if the model output is neither 2-D nor 3-D.
    """
    print(f"Generating {num_total_steps} steps of closed-loop offline pretraining data using an untrained LSTM...")
    all_spikes_list = []
    all_true_velocities_list = [] # These are the DESIRED velocities, used as targets for supervised learning

    # Ensure driving_decoder_model has lstm attribute and input_size
    if not hasattr(driving_decoder_model, 'lstm') or not hasattr(driving_decoder_model.lstm, 'input_size'):
        raise ValueError("driving_decoder_model must be an LSTMRegression model with an lstm attribute and input_size.")
    lstm_input_size = driving_decoder_model.lstm.input_size

    # Without at least one step per reach the outer loop below could never
    # make progress and would spin forever.
    if num_total_steps > 0 and max_steps_per_sim_reach <= 0:
        raise ValueError("max_steps_per_sim_reach must be positive when num_total_steps > 0.")

    collected_steps = 0
    # Progress is reported roughly every 5% of the requested steps. A running
    # threshold is used because reaches end at arbitrary step counts, so an
    # exact "multiple of the interval" test would almost never fire.
    report_interval = max(1, num_total_steps // 20)
    next_report_at = report_interval

    with torch.no_grad(): # Ensure LSTM is not accidentally trained here
        driving_decoder_model.eval()

        while collected_steps < num_total_steps:
            current_sim_cursor_pos = center_pos.copy()
            # Generate a random target that is not trivially close to the start
            while True:
                target_x = np.random.uniform(target_radius * 1.5, screen_width - target_radius * 1.5)
                target_y = np.random.uniform(target_radius * 1.5, screen_height - target_radius * 1.5)
                current_sim_target_pos = np.array([target_x, target_y], dtype=np.float32)
                if np.linalg.norm(current_sim_target_pos - current_sim_cursor_pos) > target_radius * 2:
                    break

            for _ in range(max_steps_per_sim_reach):
                if collected_steps >= num_total_steps:
                    break

                # 1. Calculate desired velocity (oracle for target label):
                #    unit direction to the target, with speed ramping down
                #    linearly inside a 200-unit distance.
                desired_vec_np = current_sim_target_pos - current_sim_cursor_pos
                distance_to_target = np.linalg.norm(desired_vec_np)
                if distance_to_target > 1e-6:
                    current_desired_vel_np = desired_vec_np / distance_to_target * min(1.0, distance_to_target / 200.0)
                else:
                    # Keep the dtype of the regular branch so the stacked
                    # label array does not get silently upcast.
                    current_desired_vel_np = np.zeros_like(desired_vec_np)

                current_desired_vel_tensor = torch.tensor(current_desired_vel_np, dtype=torch.float32).to(device)

                # 2. Generate spikes based on this DESIRED velocity
                spikes_this_step_tensor = spike_generator.generate_spikes(
                    current_desired_vel_tensor, sequence_length=sequence_length_for_spikes
                )

                # Store the generated spikes and the TRUE (desired) velocity.
                # squeeze() drops all singleton dimensions of the spike tensor.
                all_spikes_list.append(spikes_this_step_tensor.squeeze().cpu().numpy())
                all_true_velocities_list.append(current_desired_vel_np.copy())

                # 3. Feed spikes to the (untrained) LSTM to get a predicted velocity
                lstm_input_tensor = spikes_this_step_tensor.to(device)

                if lstm_input_tensor.shape[2] != lstm_input_size:
                    # Configuration problem: only logged here, the model call
                    # below is still attempted.
                    print(f"CRITICAL WARNING in data gen: LSTM input size mismatch. Model expects {lstm_input_size}, Got {lstm_input_tensor.shape[2]}. This may lead to errors or poor performance.")

                lstm_output_seq = driving_decoder_model(lstm_input_tensor)
                if lstm_output_seq.ndim == 2: # Already [batch, output_size]
                    predicted_vel_from_lstm_tensor = lstm_output_seq
                elif lstm_output_seq.ndim == 3: # [batch, seq_len, output_size]: keep the last step
                    predicted_vel_from_lstm_tensor = lstm_output_seq[:, -1, :]
                else:
                    raise ValueError(f"Unexpected LSTM output dimension: {lstm_output_seq.ndim}")

                predicted_vel_from_lstm_np = predicted_vel_from_lstm_tensor.squeeze().detach().cpu().numpy()

                # 4. Update cursor position using the LSTM's PREDICTED velocity
                current_sim_cursor_pos += predicted_vel_from_lstm_np * movement_scale
                current_sim_cursor_pos[0] = np.clip(current_sim_cursor_pos[0], 0, screen_width)
                current_sim_cursor_pos[1] = np.clip(current_sim_cursor_pos[1], 0, screen_height)

                collected_steps += 1

                if np.linalg.norm(current_sim_target_pos - current_sim_cursor_pos) < target_radius:
                    break

            if collected_steps >= next_report_at or collected_steps == num_total_steps:
                print(f"  Generated {collected_steps}/{num_total_steps} closed-loop pretraining steps...")
                # Advance to the first interval boundary strictly above the current count.
                next_report_at = (collected_steps // report_interval + 1) * report_interval

    print("Closed-loop offline pretraining data generation complete.")
    return np.array(all_spikes_list), np.array(all_true_velocities_list)
# --- End of New Function ---


def run_single_comparison(seed_value, disruption_type="dropout", disruption_intensity=0.5): # Added disruption parameters
    """Run one seeded, two-phase closed-loop comparison of all decoders.

    Pipeline, in order:
      1. Seed Python/NumPy/torch RNGs and build the synthetic spike generator.
      2. Fit a ``StandardScaler`` on open-loop oracle velocities.
      3. Generate closed-loop offline data driven by an untrained LSTM.
      4. Train KF and LSTM on rolling-mean firing rates, the BPTT SNN on the
         raw spikes, and set up (and optionally pre-train) the online SNN.
      5. Evaluate every decoder over ``PHASE_INITIAL_LEARNING`` and
         ``PHASE_ADAPT_TO_DISRUPTION``; the disruption is applied to the spike
         generator at the start of the second phase and reverted afterwards.

    Any decoder whose training produced ``None`` is skipped during evaluation
    and recorded as a timeout (``MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL``) for
    every target of the phase.

    Args:
        seed_value: Seed applied to ``torch``, ``numpy`` and ``random``.
        disruption_type: Forwarded to ``apply_disruption_to_spike_generator``.
        disruption_intensity: Forwarded to ``apply_disruption_to_spike_generator``.

    Returns:
        ``(phased_reach_results, summary_for_run_dict)`` where the first maps
        decoder name -> phase name -> list of per-target step counts (timeouts
        stored as ``MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL``) and the second maps
        decoder name -> phase name -> formatted summary string.
        Returns ``({}, None)`` if there is too little rate-coded data to split.
    """
    # --- Main Comparison Logic ---
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    print(f"--- Running Phased Comparison with SEED: {seed_value} ---")
    print(f"--- Disruption: {disruption_type.upper()} with intensity {disruption_intensity} ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Local helpers for the online SNN's recurrent state ---
    def _zero_snn_states(model):
        # Fresh all-zero state tuple for a batch of one.
        batch_size = 1
        return (
            torch.zeros(batch_size, model.fc1.out_features, device=device),  # spk1_rec
            torch.zeros(batch_size, model.fc1.out_features, device=device),  # mem1
            torch.zeros(batch_size, model.fc2.out_features, device=device),  # mem2
            torch.zeros(batch_size, model.fc3.out_features, device=device),  # mem3
        )

    def _reset_snn_states(model):
        # Prefer overwriting `_persistent_states`; otherwise fall back to a
        # `reset_states()` method if the model offers one.
        if hasattr(model, '_persistent_states'):
            model._persistent_states = _zero_snn_states(model)
        elif hasattr(model, 'reset_states'):
            model.reset_states()

    # --- Phase & Disruption Parameters ---
    NUM_TARGETS_INITIAL = 150 # Reduced for quicker test runs, bci_sim uses 400
    NUM_TARGETS_ADAPT = 50   # Reduced, bci_sim uses 120
    DISRUPTION_TYPE = disruption_type # Now passed as parameter
    DISRUPTION_INTENSITY = disruption_intensity # Now passed as parameter

    # --- Model & Training Parameters ---
    NUM_TRAIN_STEPS_OFFLINE = 10000 # For initial offline training of KF, LSTM, BPTT
    NUM_STEPS_FOR_SCALER_FIT = 5000  # For fitting the velocity scaler (can be open-loop)
    NUM_NEURONS = 96
    INPUT_SIZE_SNN = NUM_NEURONS + 2 # Assuming 2 for padding, adjust if different
    INPUT_SIZE_LSTM = NUM_NEURONS
    OUTPUT_SIZE = 2
    BPTT_SEQUENCE_LENGTH = 50 # For offline training of SNN_BPTT/LSTM
    NN_HIDDEN_SIZE = 256       # For SNN_BPTT
    LSTM_HIDDEN_SIZE = NN_HIDDEN_SIZE
    LSTM_NUM_LAYERS = 2
    LSTM_DROPOUT = 0.2
    ONLINE_SNN_HIDDEN_SIZE = 256

    OFFLINE_EPOCHS = 30 # Reduced for speed
    OFFLINE_BATCH_SIZE = 64
    OFFLINE_LR = 1e-3

    ONLINE_SNN_FAST_LR = 1e-4
    ONLINE_SNN_SLOW_LR = 1e-3
    ONLINE_SNN_META_LR = 0.1
    ONLINE_SNN_WINDOW = 10

    PRETRAIN_ONLINE_SNN = True
    NUM_PRETRAIN_TARGETS_ONLINE = 100 # Reduced for speed, bci_sim uses 400
    MAX_REACH_TIME_SECONDS_PRETRAIN = 7.0 # bci_sim uses 10.0

    # --- Spike Generator & Scaler ---
    spike_gen = Synthetic_Neuron(num_neurons=NUM_NEURONS, noise_level=0.02)
    if hasattr(spike_gen, 'to') and callable(getattr(spike_gen, 'to')):
        spike_gen.to(device)

    # Snapshot of the undisturbed generator, used to apply and later revert the disruption.
    original_spike_gen_properties = {
        'noise_level': spike_gen.noise_level if hasattr(spike_gen, 'noise_level') else 0.02,
        'preferred_directions': spike_gen.preferred_directions.clone().cpu() if hasattr(spike_gen, 'preferred_directions') else None,
        'max_firing_rate': spike_gen.max_firing_rate if hasattr(spike_gen, 'max_firing_rate') else None,
        'min_firing_rate': spike_gen.min_firing_rate if hasattr(spike_gen, 'min_firing_rate') else None
    }

    spike_gen.training = False
    # Generate data for velocity scaler fitting (using original open-loop oracle method)
    print(f"Generating {NUM_STEPS_FOR_SCALER_FIT} steps of oracle data for velocity scaler fitting...")
    _, train_vel_raw_for_scaler = generate_data(spike_gen, NUM_STEPS_FOR_SCALER_FIT, sequence_length=1)
    velocity_scaler = StandardScaler()
    velocity_scaler.fit(train_vel_raw_for_scaler)
    print("Velocity scaler fitted.")

    # Generate main offline pretraining data using the closed-loop method.
    # The LSTM that drives the generation is the same instance that is
    # trained on the generated data further below.
    print("\nInitializing LSTM model for driving closed-loop data generation...")
    lstm_driver_model = LSTMRegression(
        input_size=INPUT_SIZE_LSTM,
        hidden_size=LSTM_HIDDEN_SIZE,
        output_size=OUTPUT_SIZE,
        num_layers=LSTM_NUM_LAYERS,
        dropout=LSTM_DROPOUT # Dropout will be off due to model.eval() in gen function
    ).to(device)

    train_spikes_raw_offline, train_vel_raw_offline = generate_main_offline_pretraining_data_closed_loop(
        driving_decoder_model=lstm_driver_model,
        spike_generator=spike_gen,
        num_total_steps=NUM_TRAIN_STEPS_OFFLINE,
        device=device,
        sequence_length_for_spikes=1, # Generate spikes step-by-step
        # screen_width, screen_height, etc., will use global defaults
    )

    # --- Data Preparation for a Fair Comparison ---
    # Normalize the raw velocities obtained from closed-loop generation
    train_vel_norm_offline = velocity_scaler.transform(train_vel_raw_offline)

    # --- 1. Data for SNNs (using raw binary spikes) ---
    print("\n--- Preparing SNN Data (using raw binary spikes) ---")
    offline_dataset_snn = SpikeVelDataset(train_spikes_raw_offline, train_vel_norm_offline, BPTT_SEQUENCE_LENGTH)
    train_size_snn = int(0.8 * len(offline_dataset_snn))
    val_size_snn = len(offline_dataset_snn) - train_size_snn
    offline_train_subset_snn, offline_val_subset_snn = random_split(offline_dataset_snn, [train_size_snn, val_size_snn])

    # --- 2. Data for LSTM & KF (using rate-coded spikes) ---
    print("--- Preparing LSTM & KF Data (converting spikes to firing rates) ---")
    rate_smoothing_window = 5  # 50 ms window if simulation step is 10ms
    spikes_df = pd.DataFrame(train_spikes_raw_offline)
    # Rolling mean leaves NaN in the first (window - 1) rows; dropna() removes them.
    rates_df = spikes_df.rolling(window=rate_smoothing_window).mean().dropna()
    train_spikes_rates_offline = rates_df.to_numpy()

    # The dropped rows are at the head, so the matching velocities are the tail.
    num_valid_rates = len(train_spikes_rates_offline)
    train_vel_raw_for_rates = train_vel_raw_offline[-num_valid_rates:]
    train_vel_norm_for_rates = train_vel_norm_offline[-num_valid_rates:]
    print(f"  Converted binary spikes to rates using a {rate_smoothing_window}-step window. New dataset size: {num_valid_rates} steps.")

    # Create dataset and dataloaders for LSTM
    offline_dataset_lstm = SpikeVelDataset(train_spikes_rates_offline, train_vel_norm_for_rates, BPTT_SEQUENCE_LENGTH)
    if len(offline_dataset_lstm) < 2:
        print("Error: Not enough rate-coded data for train/val split. Increase NUM_TRAIN_STEPS_OFFLINE or decrease smoothing window.")
        return {}, None
    train_size_lstm = int(0.8 * len(offline_dataset_lstm))
    val_size_lstm = len(offline_dataset_lstm) - train_size_lstm
    offline_train_subset_lstm, offline_val_subset_lstm = random_split(offline_dataset_lstm, [train_size_lstm, val_size_lstm])

    offline_train_loader_lstm = DataLoader(offline_train_subset_lstm, batch_size=OFFLINE_BATCH_SIZE, shuffle=True, num_workers=0)
    offline_val_loader_lstm = DataLoader(offline_val_subset_lstm, batch_size=OFFLINE_BATCH_SIZE, shuffle=False, num_workers=0)

    trained_decoders = {}

    # a) Kalman Filter (trained on RATE-CODED data)
    print("\n--- Training Kalman Filter (on rate-coded data) ---")
    # Only the trained filter object is needed, so a tiny slice of the
    # training data stands in for the test set argument.
    num_dummy_samples_rates = min(10, len(train_spikes_rates_offline))
    dummy_test_rates_for_kf = train_spikes_rates_offline[:num_dummy_samples_rates]
    dummy_test_vel_for_kf = train_vel_norm_for_rates[:num_dummy_samples_rates]

    _, kf_model = train_test_kalman_filter(
        train_spikes_rates_offline,
        train_vel_norm_for_rates, # KF trained on normalized velocities
        dummy_test_rates_for_kf,
        dummy_test_vel_for_kf
    )
    if kf_model:
        trained_decoders['KF'] = kf_model
        print("Kalman Filter trained successfully.")
    else:
        print("Kalman Filter training failed.")
        trained_decoders['KF'] = None

    # b) LSTM (trained on RATE-CODED data)
    print("\n--- Training LSTM (on rate-coded data) ---")
    lstm_model = lstm_driver_model # Use the same instance
    criterion_lstm = nn.MSELoss()
    optimizer_lstm = optim.Adam(lstm_model.parameters(), lr=OFFLINE_LR)
    lstm_model, _, _ = train_lstm_model(
        lstm_model, offline_train_loader_lstm, offline_val_loader_lstm, criterion_lstm, optimizer_lstm,
        device, num_epochs=OFFLINE_EPOCHS, patience=10
    )
    trained_decoders['LSTM'] = lstm_model
    print("LSTM trained.")

    # c) BPTT SNN (trained on BINARY data)
    print("\n--- Training BPTT SNN (on binary spike data) ---")
    snn_bptt_model_instance = SNNRegression(
        input_size=INPUT_SIZE_SNN, hidden_size=NN_HIDDEN_SIZE, output_size=OUTPUT_SIZE
    ).to(device)
    # Note: Using the SNN-specific data subsets now
    snn_bptt_state_dict, _ = train_bptt_snn_local(
        train_dataset=offline_train_subset_snn, val_dataset=offline_val_subset_snn,
        input_size=INPUT_SIZE_SNN, hidden_size=NN_HIDDEN_SIZE, output_size=OUTPUT_SIZE,
        epochs=OFFLINE_EPOCHS, batch_size=OFFLINE_BATCH_SIZE, device=device, lr=OFFLINE_LR,
        checkpoint_path=f"snn_bptt_comp_local_temp_seed{seed_value}.pth", patience=10
    )
    if snn_bptt_state_dict:
        snn_bptt_model_instance.load_state_dict(snn_bptt_state_dict)
        trained_decoders['SNN_BPTT'] = snn_bptt_model_instance
        print("BPTT SNN trained and state loaded.")
    else:
        print("BPTT SNN training did not return a model state. Skipping.")
        trained_decoders['SNN_BPTT'] = None

    # d) Online SNN Setup and Pre-training (uses BINARY spikes)
    print("\n--- Setting up and Pre-training Online SNN (on binary spike data) ---")

    # --- Create hardware-friendly lookup table for ZENODO algorithm ---
    vals = torch.linspace(0.25, 8.0, 256, device=device)
    inv_sqrt_LUT = 1.0 / torch.sqrt(vals)

    snn_online_model = SNNRegressionZenodo(
        input_size=INPUT_SIZE_SNN, hidden_size=ONLINE_SNN_HIDDEN_SIZE, output_size=OUTPUT_SIZE
    ).to(device)

    # Initialize persistent states for ZENODO SNN (unconditionally, so the
    # attribute exists for every later reset).
    snn_online_model._persistent_states = _zero_snn_states(snn_online_model)

    online_updater = TwoScaleMetaRLWeightUpdaterFullZenodo(
        snn_online_model, base_fast_lr=ONLINE_SNN_FAST_LR, base_slow_lr=ONLINE_SNN_SLOW_LR,
        window_size=ONLINE_SNN_WINDOW, meta_lr=ONLINE_SNN_META_LR,
        online_mode=True, inv_sqrt_LUT=inv_sqrt_LUT
    )

    if PRETRAIN_ONLINE_SNN:
        print(f"--- Pre-training Online SNN with {NUM_PRETRAIN_TARGETS_ONLINE} closed-loop reaches ---")
        snn_online_model.train()
        online_pretrain_start_time = time.time()
        original_spike_gen_training_state_pt = spike_gen.training
        spike_gen.training = True # Enable more stochasticity for pre-training

        _reset_snn_states(snn_online_model)

        pretrain_cursor_pos_pt = CENTER_POS.copy()
        # --- Inner helper for pretrain target (kept local to pretrain block) ---
        def _get_pretrain_target_pt(current_center_pos_pt_local): # Argument is currently unused
            margin = TARGET_RADIUS * 2
            x = np.random.uniform(margin, SCREEN_WIDTH - margin)
            y = np.random.uniform(margin, SCREEN_HEIGHT - margin)
            return np.array([x, y], dtype=np.float32)
        # --- End Inner helper ---
        pretrain_target_pos_pt = _get_pretrain_target_pt(pretrain_cursor_pos_pt)

        for reach_idx_pt in range(NUM_PRETRAIN_TARGETS_ONLINE):
            if (reach_idx_pt > 0 and reach_idx_pt % 50 == 0): # Periodic reset during pretrain
                _reset_snn_states(snn_online_model)

            steps_this_pretrain_reach, success_pt = simulate_single_reach_attempt(
                decoder_name="SNN_Online",
                model_object_or_tuple=(snn_online_model, online_updater),
                spike_generator=spike_gen, velocity_scaler=velocity_scaler, device=device,
                target_pos_abs=pretrain_target_pos_pt,
                initial_cursor_pos_abs=pretrain_cursor_pos_pt,
                center_pos_abs=CENTER_POS, target_radius_abs=TARGET_RADIUS,
                movement_scale_abs=MOVEMENT_SCALE,
                max_steps_this_reach=int(MAX_REACH_TIME_SECONDS_PRETRAIN * 100), # Convert s to steps (assuming 10ms/step)
                input_size_snn_val=INPUT_SIZE_SNN, num_neurons_val=NUM_NEURONS,
                is_snn_online_learning_active=True,
                current_phase_name="PRETRAIN_ONLINE_SNN"
            )
            if (reach_idx_pt + 1) % 20 == 0 or reach_idx_pt == NUM_PRETRAIN_TARGETS_ONLINE -1 :
                print(f"    SNN_Online Pre-train Reach {reach_idx_pt + 1}/{NUM_PRETRAIN_TARGETS_ONLINE} - {'Success' if success_pt else 'Timeout'}. Steps: {steps_this_pretrain_reach}.")
            pretrain_cursor_pos_pt = CENTER_POS.copy()
            pretrain_target_pos_pt = _get_pretrain_target_pt(pretrain_cursor_pos_pt)

        spike_gen.training = original_spike_gen_training_state_pt
        print(f"--- Online SNN Pre-training finished in {time.time() - online_pretrain_start_time:.2f} seconds ---")
        snn_online_model.eval()
        # Reset states after pretraining
        _reset_snn_states(snn_online_model)

    trained_decoders['SNN_Online'] = (snn_online_model, online_updater)
    print("Online SNN setup and pre-training complete.")

    # --- Main Phased Experiment Loop ---
    phased_reach_results = defaultdict(lambda: defaultdict(list))
    experiment_phases = [PHASE_INITIAL_LEARNING, PHASE_ADAPT_TO_DISRUPTION]
    spike_gen.training = False # Ensure controlled evaluation
    active_dropout_mask = None # Initialize active dropout mask for the run

    for phase_name in experiment_phases:
        print(f"\n===== Starting Phase: {phase_name} =====")
        num_targets_for_phase = NUM_TARGETS_INITIAL if phase_name == PHASE_INITIAL_LEARNING else NUM_TARGETS_ADAPT

        if phase_name == PHASE_ADAPT_TO_DISRUPTION:
            # Apply disruption and get a potential dropout mask.
            # Decoder states are deliberately NOT reset here: a real BCI does
            # not know when a disruption occurs, so the decoders must adapt
            # from whatever state they are in.
            active_dropout_mask = apply_disruption_to_spike_generator(
                spike_gen, DISRUPTION_TYPE, DISRUPTION_INTENSITY,
                original_spike_gen_properties, device, NUM_NEURONS
            )

        for decoder_name, model_or_tuple_val in trained_decoders.items():
            # Any decoder that failed to train (KF and SNN_BPTT can both be
            # stored as None) is recorded as all-timeouts instead of being
            # handed to the simulator as a None model.
            if model_or_tuple_val is None:
                print(f"  Skipping {decoder_name} for phase {phase_name} as it was not trained successfully.")
                phased_reach_results[decoder_name][phase_name] = [MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL] * num_targets_for_phase # USE GLOBAL
                continue

            print(f"  -- Evaluating {decoder_name} in {phase_name} for {num_targets_for_phase} targets --")

            current_model_obj_eval = model_or_tuple_val[0] if isinstance(model_or_tuple_val, tuple) else model_or_tuple_val
            if decoder_name in ['LSTM', 'SNN_BPTT'] and hasattr(current_model_obj_eval, 'reset_states'):
                current_model_obj_eval.reset_states() # Reset at start of this decoder's block in this phase

            for target_idx in range(num_targets_for_phase):
                current_cursor_pos_eval = CENTER_POS.copy()
                # Rejection-sample a target at least three radii from the start.
                while True:
                    target_x = np.random.uniform(TARGET_RADIUS * 1.5, SCREEN_WIDTH - TARGET_RADIUS * 1.5)
                    target_y = np.random.uniform(TARGET_RADIUS * 1.5, SCREEN_HEIGHT - TARGET_RADIUS * 1.5)
                    current_target_pos_eval = np.array([target_x, target_y], dtype=np.float32)
                    if np.linalg.norm(current_target_pos_eval - current_cursor_pos_eval) > TARGET_RADIUS * 3:
                        break

                # Explicitly reset SNN_Online states before each reach attempt in main experiment phases
                if decoder_name == 'SNN_Online' and isinstance(model_or_tuple_val, tuple):
                    _reset_snn_states(model_or_tuple_val[0])

                # Only the online SNN keeps learning during evaluation.
                is_learning_active = (decoder_name == 'SNN_Online')

                steps_taken, success = simulate_single_reach_attempt(
                    decoder_name=decoder_name, model_object_or_tuple=model_or_tuple_val,
                    spike_generator=spike_gen, velocity_scaler=velocity_scaler, device=device,
                    target_pos_abs=current_target_pos_eval, initial_cursor_pos_abs=current_cursor_pos_eval,
                    center_pos_abs=CENTER_POS, target_radius_abs=TARGET_RADIUS, movement_scale_abs=MOVEMENT_SCALE,
                    max_steps_this_reach=MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL, # USE GLOBAL
                    input_size_snn_val=INPUT_SIZE_SNN, num_neurons_val=NUM_NEURONS,
                    is_snn_online_learning_active=is_learning_active,
                    current_phase_name=phase_name,
                    dropout_mask_to_apply=active_dropout_mask # Pass the active mask
                )
                phased_reach_results[decoder_name][phase_name].append(steps_taken if success else MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL) # USE GLOBAL
                if (target_idx + 1) % 10 == 0 or target_idx == num_targets_for_phase -1 :
                    print(f"    {decoder_name}, {phase_name}, Target {target_idx+1}/{num_targets_for_phase}: {'Success' if success else 'Timeout'} in {steps_taken} steps.")

            phase_steps_list = phased_reach_results[decoder_name][phase_name]
            if phase_steps_list:
                # Entries equal to the global maximum are treated as timeouts.
                successful_steps = [s for s in phase_steps_list if s < MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL] # USE GLOBAL
                avg_steps = np.mean(successful_steps) if successful_steps else float('nan')
                success_rate_phase = len(successful_steps) / len(phase_steps_list)
                print(f"    Summary for {decoder_name} in {phase_name}: Success Rate: {success_rate_phase*100:.1f}%, Avg Steps/Success: {avg_steps:.1f}")

    if PHASE_ADAPT_TO_DISRUPTION in experiment_phases:
        revert_disruption_of_spike_generator(spike_gen, original_spike_gen_properties)
        active_dropout_mask = None # Clear the mask after disruption phase is over

    print("\n--- Phased Experiment Simulation Complete --- ")
    summary_for_run_dict = {}
    print("\nOverall Phased Evaluation Results Summary:")
    for decoder_name, phase_data in phased_reach_results.items():
        summary_for_run_dict[decoder_name] = {}
        for phase_name_key, steps_list_val in phase_data.items():
            if steps_list_val:
                successful_steps_val = [s for s in steps_list_val if s < MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL] # USE GLOBAL
                mean_s_steps_val = np.mean(successful_steps_val) if successful_steps_val else float('nan')
                s_rate_val = len(successful_steps_val) / len(steps_list_val) * 100
                summary_for_run_dict[decoder_name][phase_name_key] = f"{mean_s_steps_val:.1f} steps / {s_rate_val:.1f}% success"
                print(f"  {decoder_name} - {phase_name_key}: Avg Steps/Success={mean_s_steps_val:.1f}, Success Rate={s_rate_val:.1f}%")
            else:
                summary_for_run_dict[decoder_name][phase_name_key] = "N/A / 0.0%"

    print(f"--- Phased Comparison with SEED: {seed_value} Finished ---")
    return phased_reach_results, summary_for_run_dict


if __name__ == "__main__":
    # Standalone entry point: execute one seeded comparison and plot it.
    # Aggregating and plotting several runs is left to the external runner script.
    default_seed = 42
    print(f"Executing decoder_comparison.py as main script with seed: {default_seed}")

    # Resolve the device up front so it is available at module scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Standalone settings; the neuron count mirrors NUM_NEURONS inside
    # run_single_comparison, and the disruption settings are forwarded to
    # both the simulation and the plot so the two stay in sync.
    NUM_NEURONS_standalone = 96
    standalone_disruption_type = "dropout"
    standalone_disruption_intensity = 0.5

    # NOTE(review): this generator is not passed to run_single_comparison
    # below, which constructs its own -- confirm whether it is still needed.
    spike_generator_standalone = Synthetic_Neuron(num_neurons=NUM_NEURONS_standalone, noise_level=0.02)
    generator_on_other_device = (
        hasattr(spike_generator_standalone, 'device')
        and spike_generator_standalone.device != device
    )
    if generator_on_other_device and hasattr(spike_generator_standalone, 'to'):
        spike_generator_standalone.to(device)
    elif generator_on_other_device:
        # No .to() available: overwrite the device attribute directly.
        spike_generator_standalone.device = device

    # run_single_comparison returns (per-phase reach times, summary strings).
    single_run_phased_reach_times, single_run_summary_text = run_single_comparison(
        seed_value=default_seed,
        disruption_type=standalone_disruption_type,
        disruption_intensity=standalone_disruption_intensity
    )

    # An empty result means the run bailed out early; nothing to plot then.
    if single_run_phased_reach_times:
        # The disruption settings are not part of the returned values, so the
        # same standalone values handed to the run above are reused here.
        plot_phased_reach_time_comparison(
            single_run_phased_reach_times,
            base_filename=f"phased_reach_time_seed{default_seed}.png",
            disruption_type_str=standalone_disruption_type,
            disruption_intensity_val=standalone_disruption_intensity
        )

    print("\nComparison script finished (standalone run).")