import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import torch.nn.functional as F  # Add this import
import datetime
import torch.nn as nn # Added for SNNRegression
import snntorch as snn # Added for SNNRegression
from snntorch import surrogate # Added for SNNRegression
from scipy.interpolate import interp1d # Added for plotting
import pandas as pd # Added for saving results
import matplotlib.quiver # <-- IMPORT ADDED
from collections import deque # Use deque for efficient sliding window buffer
from sklearn.linear_model import LinearRegression
from filterpy.kalman import KalmanFilter
from sklearn.preprocessing import StandardScaler # If you plan to normalize KF inputs/outputs later, though not strictly needed for the basic version.

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from filterpy.kalman import KalmanFilter
def train_test_kalman_filter(train_rates, train_vel, test_rates, test_vel):
    """
    Fit a linear Kalman Filter on training data and decode the test firing rates.

    The filter state is the 2D velocity [vx, vy] and the measurement is the
    vector of per-neuron firing rates. Every model matrix is estimated from the
    training set with an intercept-free least-squares fit:
        H, R: rates regressed on velocity, plus the residual covariance.
        F, Q: velocity[t+1] regressed on velocity[t], plus the residual covariance.

    Args:
        train_rates: 2D (n_timesteps_train, n_neurons) rates, ndarray or tensor.
        train_vel: 2D (n_timesteps_train, 2) velocities, ndarray or tensor.
        test_rates: (n_timesteps_test, n_neurons) rates, ndarray or tensor.
        test_vel: Not used by this function; accepted so callers can keep a
            uniform call signature and evaluate against it afterwards.

    Returns:
        predicted_velocities (np.array): KF predicted velocities for test data.
        kf (KalmanFilter): The trained Kalman Filter object.

    Raises:
        ValueError: If train_vel does not have exactly two columns.
    """
    print("\n--- Setting up Kalman Filter ---")

    def _as_numpy(data):
        # Tensors are moved off the graph/device; anything else passes through.
        return data.detach().cpu().numpy() if torch.is_tensor(data) else data

    train_rates = _as_numpy(train_rates)
    train_vel = _as_numpy(train_vel)
    test_rates = _as_numpy(test_rates)

    # Unpacking also enforces that train_rates is two-dimensional.
    _, n_neurons = train_rates.shape
    n_timesteps_test = test_rates.shape[0]
    n_outputs = train_vel.shape[1]
    if n_outputs != 2:
        raise ValueError("This KF implementation assumes 2D velocity (vx, vy)")

    # State is the velocity, measurement is the population rate vector.
    dim_x, dim_z = n_outputs, n_neurons

    # Observation model: rates ~ H @ velocity.
    print("Estimating Observation Matrix H using Linear Regression...")
    observation_fit = LinearRegression(fit_intercept=False)
    observation_fit.fit(train_vel, train_rates)
    H = observation_fit.coef_
    print(f"Estimated H shape: {H.shape}")

    # Measurement noise: covariance of what H fails to explain (+ small ridge).
    print("Estimating Measurement Noise Covariance R from residuals...")
    rate_residuals = train_rates - observation_fit.predict(train_vel)
    R = np.cov(rate_residuals.T)
    R += 1e-6 * np.eye(dim_z)
    print(f"Estimated R shape: {R.shape} (Full covariance)")

    # State transition: velocity[t+1] ~ F @ velocity[t].
    print("Estimating State Transition Matrix F using Linear Regression...")
    vel_now, vel_next = train_vel[:-1], train_vel[1:]
    transition_fit = LinearRegression(fit_intercept=False)
    transition_fit.fit(vel_now, vel_next)
    transition = transition_fit.coef_
    print(f"Estimated State Transition Matrix F:\n{transition}")

    # Process noise: covariance of the one-step prediction error (+ small ridge).
    print("Estimating Process Noise Covariance Q from residuals...")
    vel_residuals = vel_next - transition_fit.predict(vel_now)
    Q = np.cov(vel_residuals.T)
    Q += 1e-6 * np.eye(dim_x)
    print(f"Estimated Process Noise Covariance Q:\n{Q}")

    print("Initializing Kalman Filter...")
    kf = KalmanFilter(dim_x=dim_x, dim_z=dim_z)
    kf.x = train_vel[0].copy()       # start from the first training velocity
    kf.P = np.eye(dim_x) * 500.      # large prior uncertainty on that start
    kf.F, kf.H, kf.R, kf.Q = transition, H, R, Q

    print(f"Running Kalman Filter for {n_timesteps_test} test steps...")
    predicted_velocities = np.zeros((n_timesteps_test, dim_x))
    start_time_kf = time.time()
    for t, rates_t in enumerate(test_rates):
        kf.predict()
        kf.update(rates_t.reshape(-1, 1))  # measurement as a column vector
        predicted_velocities[t] = kf.x
    end_time_kf = time.time()
    print(f"Kalman Filter processing time: {end_time_kf - start_time_kf:.2f} seconds")
    print("--- Kalman Filter Setup and Run Complete ---")
    return predicted_velocities, kf


##########################################
# SIMPLIFIED SPIKE GENERATION MODEL
##########################################
class Synthetic_Neuron:
    """
    Simple synthetic population that simulates spike trains from directional tuning.

    Preferred directions are unit vectors split evenly over four 90-degree
    sectors (right, up, left, down). Each sector is sampled with an inclusive
    linspace, so the boundary angles appear in two neighbouring sectors.

    Attributes:
        num_neurons: Population size (must be divisible by 4).
        device: CUDA if available, otherwise CPU; all tensors live here.
        preferred_directions: float32 tensor of shape [num_neurons, 2].
        time_step: Bin width in seconds used to turn rates into probabilities.
        max_firing_rate / min_firing_rate: Rate bounds in Hz.
        max_acceleration: Divisor applied to the unit velocity before tuning.
        noise_level: Std of the Gaussian noise added to spike probabilities.
        training: When True, noise is added in generate_spikes.
    """
    def __init__(self, num_neurons=180, noise_level=0.02):
        self.num_neurons = num_neurons
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Divide neurons equally among four quadrants.
        assert num_neurons % 4 == 0, "num_neurons must be divisible by 4"
        neurons_per_quadrant = num_neurons // 4
        # (start, stop) angle of each sector: right, up, left, down.
        sector_bounds = [
            (-np.pi/4, np.pi/4),
            (np.pi/4, 3*np.pi/4),
            (3*np.pi/4, 5*np.pi/4),
            (5*np.pi/4, 7*np.pi/4),
        ]
        sector_dirs = []
        for start, stop in sector_bounds:
            angles = np.linspace(start, stop, neurons_per_quadrant)
            sector_dirs.append(np.column_stack([np.cos(angles), np.sin(angles)]))
        all_dirs = np.vstack(sector_dirs)
        self.preferred_directions = torch.tensor(all_dirs, dtype=torch.float32, device=self.device)
        # Parameters
        self.time_step = 0.01  # seconds
        self.max_firing_rate = 100  # Hz
        self.min_firing_rate = 5    # Hz
        self.max_acceleration = 2.0  # scaling factor for velocity
        self.noise_level = noise_level
        self.training = True

    def generate_spikes(self, velocity, sequence_length=50):
        """
        Generate a Bernoulli spike train for one target velocity.

        Args:
            velocity: Length-2 velocity as a tensor, NumPy array or sequence.
                It is converted to float32 on self.device, so callers do not
                need to match the dtype/device of preferred_directions.
            sequence_length: Number of time bins to sample.

        Returns:
            Tensor of shape [1, sequence_length, num_neurons] with 0/1 entries.
            The per-neuron spike probability (including the single noise draw
            made when self.training is True) is identical for every time bin.
        """
        # Coerce dtype/device up front; a float64 or off-device input would
        # otherwise make the matmul below raise.
        velocity = torch.as_tensor(velocity, dtype=torch.float32, device=self.device)

        vel_magnitude = torch.norm(velocity)
        if vel_magnitude > 1e-5:
            normalized_vel = velocity / vel_magnitude
        else:
            # Near-zero velocity has no direction; use it unnormalized.
            normalized_vel = velocity
        scaled_vel = normalized_vel / self.max_acceleration

        # Cosine-like tuning: projection of the scaled direction on each neuron.
        dir_match = torch.matmul(self.preferred_directions, scaled_vel)
        normalized_match = torch.clamp(dir_match, -0.5, 1.0)
        rate_range = self.max_firing_rate - self.min_firing_rate
        firing_rates = self.min_firing_rate + (normalized_match + 0.5) * rate_range * 0.67

        # Rate (Hz) * bin width (s) -> probability of a spike in one bin.
        spike_probabilities = torch.clamp(firing_rates * self.time_step, 0.0, 1.0)
        if self.training:
            noise = torch.randn_like(spike_probabilities) * self.noise_level
            spike_probabilities = torch.clamp(spike_probabilities + noise, 0.0, 1.0)

        spike_probabilities = spike_probabilities.unsqueeze(0).repeat(sequence_length, 1)
        spikes = torch.bernoulli(spike_probabilities)
        return spikes.unsqueeze(0)  # [1, sequence_length, num_neurons]

class LSTMRegression(nn.Module):
    """Stacked LSTM whose final time-step output is mapped by a linear layer."""

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Dropout is handed to nn.LSTM only for stacked configurations.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map x of shape (batch, seq_len, input_size) to (batch, output_size)."""
        # Fresh zero hidden/cell states for every call, created on x's device.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        hidden0 = torch.zeros(state_shape, device=x.device)
        cell0 = torch.zeros(state_shape, device=x.device)

        sequence_out, _ = self.lstm(x, (hidden0, cell0))
        # Only the last time step feeds the regression head.
        return self.fc(sequence_out[:, -1])

def train_lstm_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=5):
    """
    Train `model` with optional validation-based early stopping.

    Args:
        model: Module called as model(inputs) on (batch, seq, features) input.
        train_loader: Iterable of (inputs, targets) batches supporting len().
        val_loader: Optional loader with a `.dataset`; validation is skipped
            when it is None/empty or its dataset is empty.
        criterion: Loss callable, criterion(outputs, targets).
        optimizer: Optimizer already bound to the model's parameters.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping early.

    Returns:
        (model, train_losses, val_losses). When validation ran and improved at
        least once, the model holds the weights of the best validation epoch;
        otherwise it keeps the weights of the last epoch trained. val_losses is
        empty when validation was skipped.
    """
    # Local import keeps this fix self-contained without touching file imports.
    import copy

    print(f"Starting LSTM training for {num_epochs} epochs on {device}...")
    has_validation = bool(val_loader) and len(val_loader.dataset) > 0

    best_val_loss = float('inf')
    best_model_state = None  # set only when a validation epoch improves
    best_epoch = 0
    epochs_no_improve = 0
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)  # add a length-1 sequence dimension
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        if not has_validation:
            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}")
            continue

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                if inputs.dim() == 2:
                    inputs = inputs.unsqueeze(1)
                outputs = model(inputs)
                running_val_loss += criterion(outputs, targets).item()
        avg_val_loss = running_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors;
            # without a deep copy the "best" snapshot would silently track
            # every later optimizer step and restoring it would be a no-op.
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1
            print(f"  New best validation loss: {best_val_loss:.6f}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    print("Finished LSTM Training.")
    # Restore the best validation snapshot; with no snapshot (no validation, or
    # it never improved) the model simply keeps its final weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state from epoch {best_epoch}.")

    return model, train_losses, val_losses
# +++ End Re-inserted LSTM Definitions +++


##########################################
# SNNREGRESSION (FROM simplified_bci_area2bump.py)
##########################################
# %%
import random
import time
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from scipy.interpolate import interp1d
import numpy as np
import pandas as pd

class SNNRegression(nn.Module):
    """
    Three-layer leaky-integrate-and-fire (snntorch `Leaky`) regression network.

    Per time step the data flow is:
        input -> lif_input -> fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3
    and the membrane potential of `lif3` is emitted as the output.

    Membrane potentials are stored on the instance (mem_in, mem1, mem2, mem3)
    and carried over between calls to forward(); call reset_states() before
    starting a new, unrelated sequence.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Width of the first hidden layer; the second is half of it.
        output_size: Number of regression outputs.
        input_beta: `beta` passed to the input LIF layer.
        use_mem: If True, fc1 is driven by the input layer's membrane potential
            instead of its spikes.
    """
    def __init__(self, input_size=182, hidden_size=128, output_size=2, input_beta=0.9, use_mem=False):
        # Plain attribute set before nn.Module.__init__; it is a bool, not a
        # parameter/submodule.
        self.use_mem = use_mem
        super(SNNRegression, self).__init__()
        # One surrogate-gradient object shared by every LIF layer.
        spike_grad = surrogate.fast_sigmoid()

        # --- Input Smoothing Layer ---
        # Applied directly to the raw input (no Linear layer in front of it).
        self.lif_input = snn.Leaky(
            beta=input_beta,
            spike_grad=spike_grad,
            init_hidden=False
        )

        # Hidden & readout layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=False)

        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=False)

        self.fc3 = nn.Linear(hidden_size // 2, output_size)
        # Readout layer: infinite threshold and no reset, and forward() discards
        # its spikes, so only the integrated membrane potential is used.
        self.lif3 = snn.Leaky(
            beta=0.9,
            spike_grad=spike_grad,
            threshold=float('inf'), # original author's note: a threshold of 1 would not matter either
            reset_mechanism="none",
            init_hidden=False
        )

        # Xavier initialization
        self.apply(self._init_weights)

        # placeholders for membrane potentials (populated by reset_states())
        self.mem_in = None
        self.mem1 = None
        self.mem2 = None
        self.mem3 = None

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every nn.Linear submodule."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def reset_states(self):
        """Call once before running a new sequence to zero out all hidden states."""
        # NOTE(review): relies on snntorch's `Leaky.reset_mem()` returning the
        # fresh membrane state; its exact return shape/device depends on the
        # installed snntorch version — confirm.
        # self.lif1.reset_hidden()
        # self.lif2.reset_hidden()
        # self.lif3.reset_hidden()
        self.mem_in = self.lif_input.reset_mem()
        self.mem1   = self.lif1.reset_mem()
        self.mem2   = self.lif2.reset_mem()
        self.mem3   = self.lif3.reset_mem()


    def forward(self, x):
        """
        x: [batch, T, features]
        This version *does not* re-init hidden states every call.
        Call reset_states() once before feeding in a new sequence.

        Returns the per-step `lif3` membrane potentials stacked along dim 1
        (presumably [batch, T, output_size] — TODO confirm).
        """
        batch_size, T, _ = x.size()
        device = x.device

        # ensure hidden states exist & are on correct device
        # (states are only re-created here when missing or on another device).
        # NOTE(review): a change in batch size between calls, or backprop across
        # calls without detaching the stored states, is not handled here —
        # verify callers reset/detach as needed.
        if self.mem_in is None or self.mem_in.device != device:
            self.reset_states()
            self.mem_in = self.mem_in.to(device)
            self.mem1   = self.mem1.to(device)
            self.mem2   = self.mem2.to(device)
            self.mem3   = self.mem3.to(device)

        outputs = []
        for t in range(T):
            inp = x[:, t, :].float().to(device)

            # 1) input smoothing LIF
            spk_in, self.mem_in = self.lif_input(inp, self.mem_in)
            #print(f"t={t:2d}  spike_in[0,0]={inp[0,0].item():.0f}  →  mem_in[0,0]={self.mem_in[0,0].item():.4f} spk_in[0,0]={spk_in[0,0].item():.4f})")

            # 2) hidden layer 1
            # With use_mem the *just-updated* input membrane drives fc1;
            # otherwise the input layer's spikes do.
            if self.use_mem:
                cur1 = self.fc1(self.mem_in)
            else:
                cur1 = self.fc1(spk_in)

            spk1, self.mem1 = self.lif1(cur1, self.mem1)

            # 3) hidden layer 2
            cur2 = self.fc2(spk1)
            spk2, self.mem2 = self.lif2(cur2, self.mem2)

            # 4) output layer (membrane potential readout)
            cur3 = self.fc3(spk2)
            _, self.mem3 = self.lif3(cur3, self.mem3)

            outputs.append(self.mem3)

        return torch.stack(outputs, dim=1)


##########################################
# HELPER: GET ACTIVATIONS (REFACTORED TO BE STATELESS)
##########################################
def get_activations(model: SNNRegression, x_t, mem_in_prev, mem1_prev, mem2_prev, mem3_prev):
    """
    Run one stateless time step through the model's layers.

    The supplied membrane potentials are detached and used as the starting
    state; nothing is written back to `model`, so its stored states are left
    untouched.

    Args:
        model: Network providing lif_input, fc1/lif1, fc2/lif2, fc3/lif3 and
            the `use_mem` flag.
        x_t: Input for a single time step (cast to float).
        mem_in_prev, mem1_prev, mem2_prev, mem3_prev: Membrane potentials to
            start from, one per LIF layer.

    Returns:
        (pre_fc1, spk1, spk2, mem1_next, mem2_next, mem3_next) where pre_fc1 is
        whatever was fed into fc1: the *previous* input membrane when
        model.use_mem is set, otherwise the input layer's spikes.

    NOTE(review): SNNRegression.forward drives fc1 with the *updated* input
    membrane when use_mem is set, while this helper uses the previous one —
    confirm which of the two is intended.
    """
    step_input = x_t.float()

    # Sever the incoming states from any autograd history.
    prev_in, prev_1, prev_2, prev_3 = (
        state.detach() for state in (mem_in_prev, mem1_prev, mem2_prev, mem3_prev)
    )

    # The input LIF is always evaluated; its new membrane is not needed here.
    input_spikes, _ = model.lif_input(step_input, prev_in)

    # Pre-synaptic signal of fc1 (also what the Hebbian rule pairs with fc1).
    pre_for_hebbian = prev_in if model.use_mem else input_spikes

    hidden1_spikes, mem1_next = model.lif1(model.fc1(pre_for_hebbian), prev_1)
    hidden2_spikes, mem2_next = model.lif2(model.fc2(hidden1_spikes), prev_2)
    # Readout layer: only the membrane potential matters.
    _, mem3_next = model.lif3(model.fc3(hidden2_spikes), prev_3)

    return pre_for_hebbian, hidden1_spikes, hidden2_spikes, mem1_next, mem2_next, mem3_next
# --- END REFACTOR ---

def compute_correlation(pred, target):
    """
    Compute the Pearson correlation between predicted and target values.

    Inputs may be torch tensors, NumPy arrays or plain (nested) sequences; they
    are flattened before comparison.

    Returns:
        float: The correlation coefficient, or 0.0 for every degenerate case
        (fewer than two samples, mismatched sizes, zero variance in either
        input, or a non-finite result such as when the data contain NaN).
    """
    if torch.is_tensor(pred):
        pred = pred.detach().cpu().numpy()
    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()

    # asarray lets plain lists/tuples through as well as ndarrays.
    pred = np.asarray(pred, dtype=float).ravel()
    target = np.asarray(target, dtype=float).ravel()

    # Correlation is undefined for these shapes; keep the 0.0 convention.
    if pred.size != target.size or pred.size < 2:
        return 0.0

    if np.std(pred) == 0 or np.std(target) == 0:
        return 0.0

    corr = np.corrcoef(pred, target)[0, 1]
    # NaN/inf (e.g. NaN samples) is reported like the other degenerate cases.
    return float(corr) if np.isfinite(corr) else 0.0

import torch.nn.functional as F
##########################################
# TWO-TIMESCALE META RL UPDATER WITH META-LEARNING
##########################################
class TwoScaleMetaRLWeightUpdaterFull:
    """
    Two-timescale, error-modulated Hebbian updater for a model's fc1/fc2/fc3 weights.

    Fast timescale (every call to `update`):
        The output error (desired - predicted) is propagated backwards linearly
        through fc3 and fc2, gated per neuron by `lifN.spike_grad(membrane)`,
        and combined with the pre-synaptic activity as an outer product. The
        result is added to the weights with `fast_lr`, and each weight row is
        then limited to an L2 norm of 1.0.

    Slow timescale (every `window_size` accepted steps):
        A momentum average of the fast updates is applied with `slow_lr`, the
        weights are clipped to [-10, 10], and the two meta-parameters
        ('plasticity', 'sensitivity') are nudged up when the window's mean loss
        improved on the previous window and down otherwise (kept in [0.2, 1.5]).

    Safety:
        A step whose raw loss is non-finite or exceeds 1.5 * max_loss_value is
        treated as catastrophic: the linear layers are re-initialised (Xavier),
        the accumulated averages are cleared and the step is not counted.

    Only weights are adapted; biases are touched solely by the Xavier reset.
    """
    def __init__(self, model, base_fast_lr=1e-3, base_slow_lr=1e-2, window_size=180,
                 meta_lr=1e-3):
        """
        Args:
            model: Network exposing fc1, fc2, fc3 (nn.Linear), lif1..lif3 and
                the membrane attributes mem_in, mem1, mem2, mem3.
            base_fast_lr: Base step size of the per-step Hebbian update.
            base_slow_lr: Base step size of the consolidated (windowed) update.
            window_size: Number of accepted steps between slow updates.
            meta_lr: Relative change applied to the meta-parameters per window.
        """
        self.model = model
        self.device = next(model.parameters()).device  # all work happens on the model's device
        self.base_fast_lr = base_fast_lr
        self.base_slow_lr = base_slow_lr
        self.meta_lr = meta_lr

        # Multipliers on the base learning rates, adapted in slow_update().
        self.meta_params = {
            'plasticity': 1.0,
            'sensitivity': 1.0
        }
        self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
        self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        # Window bookkeeping for the slow update / meta-learning.
        self.window_size = window_size
        self.step_counter = 0
        self.cumulative_error = 0
        self.prev_cumulative_error = None  # mean loss of the previous window

        self.max_loss_value = 5.0  # Maximum loss value for stability
        self.max_error_norm = 2.0  # Not referenced inside this class

        # Momentum averages of the Hebbian updates, one per adapted layer.
        self.grad_fc1_avg = torch.zeros_like(self.model.fc1.weight.data, device=self.device)
        self.grad_fc2_avg = torch.zeros_like(self.model.fc2.weight.data, device=self.device)
        self.grad_fc3_avg = torch.zeros_like(self.model.fc3.weight.data, device=self.device)

        # Momentum coefficient of those averages.
        self.momentum = 0.9

    def reset_weights_with_xavier(self):
        """
        Re-initialise fc1, fc2 and fc3: Xavier/Glorot-uniform weights, zero biases.

        Used to recover after a catastrophic error.
        """
        with torch.no_grad():
            for layer in (self.model.fc1, self.model.fc2, self.model.fc3):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        print("✅ Weights reset using Xavier/Glorot initialization for better convergence")

    def correlation_loss(self, pred, target):
        """
        Differentiable correlation loss over the two output components.

        Args:
            pred, target: Tensors indexed as [:, 0] (x) and [:, 1] (y).

        Returns:
            (loss, x_corr, y_corr) where each corr is the Pearson correlation of
            that component across dim 0 and loss = mean(1 - |corr|).
        """
        eps = 1e-8  # Avoid division by zero

        def _pearson(a, b):
            a_centered = a - a.mean()
            b_centered = b - b.mean()
            denom = torch.sqrt((a_centered ** 2).sum() * (b_centered ** 2).sum()) + eps
            return (a_centered * b_centered).sum() / denom

        x_corr = _pearson(pred[:, 0], target[:, 0])
        y_corr = _pearson(pred[:, 1], target[:, 1])

        x_loss = 1.0 - torch.abs(x_corr)
        y_loss = 1.0 - torch.abs(y_corr)

        return (x_loss + y_loss) / 2.0, x_corr, y_corr

    def fast_update(self, input_tensor, desired_velocity, pred_velocity):
        """
        Apply one error-modulated Hebbian step to fc1/fc2/fc3.

        Args:
            input_tensor: Indexed as [:, 0, :]; only the first time step is used.
            desired_velocity: Target output.
            pred_velocity: Model output from the caller's forward pass.

        Returns:
            (loss, 0, 0): the MSE between prediction and target clamped to
            [0, 10] as a Python float; the two trailing zeros are placeholders
            for the (currently disabled) x/y correlation values.
        """
        input_tensor = input_tensor.to(self.device)
        des = desired_velocity.to(self.device)
        pred = pred_velocity.to(self.device)  # model output (lif3 membrane) from the caller

        # Single time-step input for the stateless activation pass.
        x_t = input_tensor[:, 0, :].float()

        # One extra, side-effect-free step starting from the model's current
        # membrane states (i.e. the states left by the caller's forward pass).
        pre_fc1, spk1, spk2, mem1_next, mem2_next, mem3_next = get_activations(
            self.model, x_t,
            self.model.mem_in, self.model.mem1, self.model.mem2, self.model.mem3
        )

        # The correlation term is disabled (weight 0), so the loss is pure MSE.
        combined_loss = torch.clamp(F.mse_loss(pred, des), 0.0, 10.0)

        with torch.no_grad():
            output_error = des - pred

            # Linear back-propagation of the error through the layer weights.
            hidden2_error = torch.matmul(output_error, self.model.fc3.weight)
            hidden1_error = torch.matmul(hidden2_error, self.model.fc2.weight)

            # Per-neuron gating from the surrogate object applied to the
            # membrane potentials of the extra step.
            # NOTE(review): in snntorch, calling a surrogate object appears to
            # run its forward pass (a Heaviside step) rather than return its
            # derivative — confirm these are the intended "sensitivities".
            d_lif3 = self.model.lif3.spike_grad(mem3_next).to(self.device)
            d_lif2 = self.model.lif2.spike_grad(mem2_next).to(self.device)
            d_lif1 = self.model.lif1.spike_grad(mem1_next).to(self.device)

            # Hebbian term: (gated post-synaptic error)^T @ pre-synaptic activity.
            hebbian_fc3 = torch.matmul((output_error * d_lif3).t(), spk2)
            hebbian_fc2 = torch.matmul((hidden2_error * d_lif2).t(), spk1)
            hebbian_fc1 = torch.matmul((hidden1_error * d_lif1).t(), pre_fc1)

            effective_lr = self.fast_lr

            self.model.fc3.weight.data += effective_lr * hebbian_fc3
            self.model.fc2.weight.data += effective_lr * hebbian_fc2
            self.model.fc1.weight.data += effective_lr * hebbian_fc1

            # Momentum average of the raw Hebbian terms, consumed by slow_update().
            self.grad_fc3_avg = self.momentum * self.grad_fc3_avg + (1 - self.momentum) * hebbian_fc3
            self.grad_fc2_avg = self.momentum * self.grad_fc2_avg + (1 - self.momentum) * hebbian_fc2
            self.grad_fc1_avg = self.momentum * self.grad_fc1_avg + (1 - self.momentum) * hebbian_fc1

            # Keep every weight row within a fixed L2 norm for stability.
            strict_max_norm = 1.0
            self.model.fc3.weight.data = self.normalize_weights(self.model.fc3.weight.data, max_norm=strict_max_norm)
            self.model.fc2.weight.data = self.normalize_weights(self.model.fc2.weight.data, max_norm=strict_max_norm)
            self.model.fc1.weight.data = self.normalize_weights(self.model.fc1.weight.data, max_norm=strict_max_norm)

        return combined_loss.item(), 0, 0

    def normalize_weights(self, weights, max_norm=10.0):
        """
        Cap the L2 norm of each row of `weights` at `max_norm`.

        Rows already within the limit are returned (numerically) unchanged;
        longer rows are rescaled to `max_norm` while keeping their direction.
        """
        norm = torch.norm(weights, dim=1, keepdim=True)
        scale = torch.clamp(norm, 0, max_norm)
        return weights * (scale / (norm + 1e-8))

    def slow_update(self):
        """
        Apply the consolidated update and adapt the meta-parameters.

        Adds `slow_lr * grad_fcN_avg` to each layer, clips weights to
        [-10, 10], compares this window's mean loss with the previous window's
        to scale 'plasticity'/'sensitivity' by (1 ± meta_lr), refreshes the
        learning rates, and resets all window accumulators.
        """
        with torch.no_grad():
            self.model.fc1.weight.data += self.slow_lr * self.grad_fc1_avg
            self.model.fc2.weight.data += self.slow_lr * self.grad_fc2_avg
            self.model.fc3.weight.data += self.slow_lr * self.grad_fc3_avg

            # Hard clip as a last-resort bound on individual weights.
            self.model.fc1.weight.data.clamp_(-10.0, 10.0)
            self.model.fc2.weight.data.clamp_(-10.0, 10.0)
            self.model.fc3.weight.data.clamp_(-10.0, 10.0)

        # Mean loss over the window that just ended (bounded for stability).
        current_avg_loss = self.cumulative_error / self.step_counter if self.step_counter > 0 else float('inf')
        current_avg_loss = min(current_avg_loss, self.max_loss_value)

        if self.prev_cumulative_error is not None:
            prev_loss = min(self.prev_cumulative_error, self.max_loss_value)

            # Improvement -> learn a little faster; otherwise slow down.
            factor = (1 + self.meta_lr) if current_avg_loss < prev_loss else (1 - self.meta_lr)
            self.meta_params['plasticity'] *= factor
            self.meta_params['sensitivity'] *= factor

            self.meta_params['plasticity'] = np.clip(self.meta_params['plasticity'], 0.2, 1.5)
            self.meta_params['sensitivity'] = np.clip(self.meta_params['sensitivity'], 0.2, 1.5)

        self.fast_lr = self.base_fast_lr * self.meta_params['plasticity']
        self.slow_lr = self.base_slow_lr * self.meta_params['sensitivity']

        # Remember this window's mean loss for the next comparison.
        self.prev_cumulative_error = current_avg_loss

        # Start the next window from a clean slate.
        self.grad_fc1_avg.zero_()
        self.grad_fc2_avg.zero_()
        self.grad_fc3_avg.zero_()
        self.cumulative_error = 0
        self.step_counter = 0

    def update(self, input_tensor, pred_velocity, desired_velocity):
        """
        Run one learning step: fast update, bookkeeping, and (when the window
        is full) the slow update.

        Args:
            input_tensor: Model input for this step (see fast_update).
            pred_velocity: Model output for this step.
            desired_velocity: Target output for this step.

        Returns:
            (bounded_loss, x_corr, y_corr): the step loss capped at
            max_loss_value, plus the two values passed through from
            fast_update.
        """
        pred_velocity = pred_velocity.to(self.device)
        desired_velocity = desired_velocity.to(self.device)
        input_tensor = input_tensor.to(self.device)

        combined_loss, x_corr, y_corr = self.fast_update(input_tensor, desired_velocity, pred_velocity)

        # A NaN/inf loss must never reach the accumulators; report it at the cap.
        loss_is_finite = bool(np.isfinite(combined_loss))
        bounded_loss = min(combined_loss, self.max_loss_value) if loss_is_finite else self.max_loss_value

        # The catastrophic check has to look at the *raw* loss: the bounded
        # value can never exceed max_loss_value, so testing it against
        # 1.5 * max_loss_value would make this branch unreachable.
        if not loss_is_finite or combined_loss > self.max_loss_value * 1.5:
            print(f"🚨 CATASTROPHIC ERROR DETECTED: {combined_loss:.4f}. Performing Xavier weight reset.")
            self.reset_weights_with_xavier()
            # The averages were accumulated for the discarded weights (and may
            # hold NaN/inf); drop them so the next slow update starts clean.
            self.grad_fc1_avg.zero_()
            self.grad_fc2_avg.zero_()
            self.grad_fc3_avg.zero_()
            return bounded_loss, x_corr, y_corr

        # Track the window's loss for meta-learning.
        self.cumulative_error += bounded_loss
        self.step_counter += 1

        # slow_update() also resets step_counter and cumulative_error.
        if self.step_counter >= self.window_size:
            self.slow_update()

        return bounded_loss, x_corr, y_corr

##########################################
# CURSOR TASK ENVIRONMENT
##########################################
# Task geometry (presumably in pixels — TODO confirm against the renderer).
screen_width = 800
screen_height = 600
target_radius = 50
CENTER_POS = np.array([screen_width/2, screen_height/2], dtype=np.float32)

def new_target():
    """
    Draw a random target position from one of the four corner regions.

    A quadrant is picked uniformly, then x and y are sampled uniformly inside
    that corner: the outer 30% of each axis, kept at least two target radii
    away from the screen edge.

    Returns:
        np.ndarray: float32 array [x, y].
    """
    margin = target_radius * 2

    # 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right
    quadrant = np.random.randint(0, 4)

    near_x = (margin, screen_width * 0.3)
    far_x = (screen_width * 0.7, screen_width - margin)
    near_y = (margin, screen_height * 0.3)
    far_y = (screen_height * 0.7, screen_height - margin)

    x_bounds = near_x if quadrant in (0, 2) else far_x
    y_bounds = near_y if quadrant in (0, 1) else far_y

    # x is drawn before y so the random stream is consumed in a fixed order.
    x = np.random.uniform(*x_bounds)
    y = np.random.uniform(*y_bounds)

    return np.array([x, y], dtype=np.float32)

# ---------------------------------- #
# EXPERIMENT FRAMEWORK SETUP         #
# ---------------------------------- #
# Define experiment phases with clear boundaries
(
    PHASE_INITIAL_LEARNING,     # 0: initial learning phase
    PHASE_ADAPT_TO_DISRUPTION,  # 1: neural disruption applied and SNN adapts to it
    PHASE_RECOVERY,             # 2: system adapts to disruption
    PHASE_EVALUATION,           # 3: evaluate final performance
) = range(4)

# Phase durations, counted in successful target reaches
INITIAL_LEARNING_TARGETS = 400      # reaches required for initial learning
ADAPT_TO_DISRUPTION_TARGETS = 120   # reaches while adapting to the disruption

# Disruption parameters
DISRUPTION_TYPE = "remapping"   # one of: "dropout", "remapping", "drift"
DISRUPTION_INTENSITY = 0.5      # disruption intensity in [0, 1] (lowered from 0.99)

# Output directory for graphs and data; named after the disruption type
RESULTS_DIR = f"experiment_results_{DISRUPTION_TYPE}"

# Function to get target based on phase
def get_target_for_phase(phase, step_or_eval_idx, current_center_pos):
    """Return the next peripheral target position for the given experiment phase.

    Args:
        phase: Experiment phase id (compared against ``PHASE_EVALUATION``).
        step_or_eval_idx: For the evaluation phase, the evaluation target
            count; it selects one of 8 directions (index modulo 8). Ignored
            for every other phase.
        current_center_pos: Accepted for call-site symmetry but not read;
            targets are always computed relative to ``CENTER_POS``.

    Returns:
        np.ndarray: float32 array ``[x, y]``.
    """
    # Every non-evaluation phase uses a random corner target.
    if phase != PHASE_EVALUATION:
        return new_target()

    # Evaluation: 8 evenly spaced directions (45 degrees apart) around the
    # screen center, at 35% of the smaller screen dimension.
    theta = (step_or_eval_idx % 8) * (np.pi / 4)
    reach_distance = min(screen_width, screen_height) * 0.35

    raw_x = CENTER_POS[0] + np.cos(theta) * reach_distance
    raw_y = CENTER_POS[1] + np.sin(theta) * reach_distance

    # Keep the target at least one target radius away from every screen edge.
    edge = target_radius
    clipped_x = np.clip(raw_x, edge, screen_width - edge)
    clipped_y = np.clip(raw_y, edge, screen_height - edge)
    return np.array([clipped_x, clipped_y], dtype=np.float32)

# Function to apply disruption
def apply_disruption(spike_gen, disruption_type, intensity, original_properties, device, num_neurons):
    """Apply a controlled neural disruption to ``spike_gen`` in place.

    Supported ``disruption_type`` values:

    * ``"dropout"``: only the generator's ``noise_level`` is changed here, to
      ``min(0.1, original noise * (1 + 2 * intensity))``. Per the original
      comment, the silencing of neurons itself is done by the main loop.
    * ``"remapping"``: ``int(num_neurons * intensity)`` randomly chosen rows
      of ``spike_gen.preferred_directions`` are overwritten with unit vectors
      at uniformly random angles.
    * ``"drift"``: ``max_firing_rate`` is scaled by ``1 - 0.5 * intensity``
      and ``min_firing_rate`` by ``1 + intensity``.

    Any other type leaves ``spike_gen`` untouched.

    Args:
        spike_gen: Spike generator object whose attributes are modified.
        disruption_type: Name of the disruption to apply.
        intensity: Disruption strength.
        original_properties: Mapping of saved generator properties; only
            ``'noise_level'`` is read (dropout case).
        device: Torch device on which the new directions are created.
        num_neurons: Size of the neuron population.

    Returns:
        bool: Always ``True``.
    """
    print(f"\nAPPLYING DISRUPTION: {disruption_type.upper()} with intensity {intensity}")

    if disruption_type == "drift":
        # Compress the firing-rate range: lower ceiling, higher floor.
        old_max = spike_gen.max_firing_rate
        old_min = spike_gen.min_firing_rate
        drifted_max = old_max * (1.0 - intensity*0.5)
        drifted_min = old_min * (1.0 + intensity)

        spike_gen.max_firing_rate = drifted_max
        spike_gen.min_firing_rate = drifted_min

        print(f"Applied firing rate drift: max rate {old_max:.1f} → {drifted_max:.1f}, min rate {old_min:.1f} → {drifted_min:.1f}")

    elif disruption_type == "remapping":
        # Pick a random subset of neurons to receive new preferred directions.
        subset_size = int(num_neurons * intensity)
        remap_indices = torch.randperm(num_neurons)[:subset_size]
        n_remapped = len(remap_indices)
        print(f"Remapping {n_remapped} neurons ({intensity*100}% of population)")

        # Fresh unit vectors at random angles, created on the target device.
        new_angles = torch.rand(n_remapped, device=device) * 2 * np.pi
        unit_vectors = torch.stack((torch.cos(new_angles), torch.sin(new_angles)), dim=1)
        spike_gen.preferred_directions[remap_indices] = unit_vectors.to(device)
        print("Neural remapping applied successfully")

    elif disruption_type == "dropout":
        # Only the noise level changes here; masking of spikes happens elsewhere.
        print(f"Simulating dropout: {intensity*100}% of neurons will be silenced externally.")
        scaled_noise = original_properties['noise_level'] * (1 + intensity*2)
        spike_gen.noise_level = min(0.1, scaled_noise)
        print(f"Increased noise level to {spike_gen.noise_level:.3f}")

    return True

# Function to end disruption
def end_disruption(spike_gen, disruption_type, original_properties):
    """Restore the generator properties that a disruption changed.

    Args:
        spike_gen: Spike generator object to restore in place.
        disruption_type: ``"dropout"`` restores ``noise_level``;
            ``"remapping"`` restores ``preferred_directions`` (as a clone of
            the saved tensor); ``"drift"`` restores ``max_firing_rate`` and
            ``min_firing_rate``. Any other value restores nothing.
        original_properties: Mapping holding the saved values under the same
            keys as the attribute names above.

    Returns:
        bool: Always ``True``.
    """
    print(f"\nENDING DISRUPTION: {disruption_type.upper()}")

    if disruption_type == "drift":
        saved_max = original_properties['max_firing_rate']
        saved_min = original_properties['min_firing_rate']
        spike_gen.max_firing_rate = saved_max
        spike_gen.min_firing_rate = saved_min
        print(f"Restored original firing rates: max={spike_gen.max_firing_rate:.1f}, min={spike_gen.min_firing_rate:.1f}")

    elif disruption_type == "remapping":
        # Clone so later in-place edits cannot corrupt the saved copy.
        saved_directions = original_properties['preferred_directions']
        spike_gen.preferred_directions = saved_directions.clone()
        print("Restored original neural tuning directions")

    elif disruption_type == "dropout":
        spike_gen.noise_level = original_properties['noise_level']
        print(f"Restored noise level to {spike_gen.noise_level:.3f}")

    return True

# Module-level history and sliding-window state. Defined here, ahead of the
# functions that use these names.

# --- Correlation calculation parameters ---
corr_calc_window = 50   # number of steps each correlation is computed over
lag = 25                # lag (in steps) applied to desired velocity for correlation
pred_buffer_size = corr_calc_window          # predicted-velocity buffer length
des_buffer_size = corr_calc_window + lag     # desired-velocity buffer also holds the lag

# Fixed-length sliding-window buffers (oldest samples are dropped automatically).
pred_vx_buffer, pred_vy_buffer = (deque(maxlen=pred_buffer_size) for _ in range(2))
des_vx_buffer, des_vy_buffer = (deque(maxlen=des_buffer_size) for _ in range(2))

# Sliding-window correlation history and the most recently calculated values.
sliding_x_corr_history, sliding_y_corr_history = [], []
last_calc_x_corr = 0.0
last_calc_y_corr = 0.0

# Per-step history lists.
time_points = []
cursor_x_history, cursor_y_history = [], []
desired_x_history, desired_y_history = [], []
pred_vx_history, pred_vy_history = [], []
desired_vx_history, desired_vy_history = [], []

# Weight-norm history, one list per layer name.
weight_norms = {layer_name: [] for layer_name in ('fc1', 'fc2', 'fc3')}

# Function to plot final results by phase
def plot_experiment_results(phase_metrics_data, current_results_dir, disruption_type="drift", phase_names_map_for_plot=None):
    """Save per-phase MSE and distance line plots plus an MSE summary bar chart.

    Three PNG files are written into ``current_results_dir``:
    ``mse_by_phase_<type>.png``, ``distance_by_phase_<type>.png`` and
    ``performance_summary_mse_<type>.png``.

    Args:
        phase_metrics_data: Mapping from phase id to a mapping with ``'mse'``
            and ``'distance'`` sequences. ``'distance'`` is read at the same
            indices as ``'mse'``, so it must be at least as long.
        current_results_dir: Output directory; created if it does not exist.
        disruption_type: Name used in plot titles and output file names.
        phase_names_map_for_plot: Optional mapping from phase id to a phase
            name key. Only the keys ``"INITIAL_LEARNING"`` and
            ``"ADAPT_TO_DISRUPTION"`` are plotted; any other key is skipped
            with a warning. Defaults to the two-phase mapping built from
            ``PHASE_INITIAL_LEARNING`` and ``PHASE_ADAPT_TO_DISRUPTION``.

    Returns:
        None
    """
    plt.ioff()  # Turn off interactive mode; figures are only written to disk

    os.makedirs(current_results_dir, exist_ok=True)

    # One color per plotted phase; the keys also define which phases are plotted.
    colors = {
        "INITIAL_LEARNING": "blue",
        "ADAPT_TO_DISRUPTION": "red",
    }
    phases_data_for_plot = {
        phase_name: {"steps": [], "mse": [], "distance": []} for phase_name in colors
    }

    if phase_names_map_for_plot is None:
        phase_names_map_for_plot = {
            PHASE_INITIAL_LEARNING: "INITIAL_LEARNING",
            PHASE_ADAPT_TO_DISRUPTION: "ADAPT_TO_DISRUPTION",
        }

    # Lay the phases out on one continuous step axis, in mapping order.
    total_steps_processed = 0
    for phase_id, phase_name_key in phase_names_map_for_plot.items():
        if phase_id not in phase_metrics_data:
            continue
        if phase_name_key not in phases_data_for_plot:
            print(f"Warning: phase_name_key '{phase_name_key}' not in phases_data_for_plot. Skipping.")
            continue

        phase_metrics = phase_metrics_data[phase_id]
        bucket = phases_data_for_plot[phase_name_key]
        num_steps_in_phase = len(phase_metrics['mse'])
        first_step = total_steps_processed + 1
        bucket["steps"].extend(range(first_step, first_step + num_steps_in_phase))
        bucket["mse"].extend(phase_metrics['mse'][i] for i in range(num_steps_in_phase))
        bucket["distance"].extend(phase_metrics['distance'][i] for i in range(num_steps_in_phase))
        total_steps_processed += num_steps_in_phase

    # A separator is drawn only when both phases have data and the initial
    # phase ends before the last processed step.
    initial_steps = phases_data_for_plot["INITIAL_LEARNING"]["steps"]
    adapt_steps = phases_data_for_plot["ADAPT_TO_DISRUPTION"]["steps"]
    separator_step = None
    if initial_steps and adapt_steps and initial_steps[-1] < total_steps_processed:
        separator_step = initial_steps[-1]

    def _save_line_plot(metric_key, label_y, label_text, title, ylabel, filename):
        """Draw one metric for every phase with data and save the figure."""
        plt.figure(figsize=(12, 6))
        for phase_name, series in phases_data_for_plot.items():
            if not series["steps"]:
                continue
            values = series[metric_key]
            color = colors[phase_name]
            plt.plot(series["steps"], values, color=color, label=phase_name, linewidth=2)

            # Dashed line and centered text label at the phase average.
            avg_value = np.mean(values)
            plt.axhline(y=avg_value, color=color, linestyle='--', alpha=0.5)
            plt.text(np.mean(series["steps"]), label_y(values, avg_value), label_text(avg_value),
                     ha='center', color=color, fontweight='bold')

        if separator_step is not None:
            plt.axvline(x=separator_step, color='black', linestyle=':', alpha=0.6)

        plt.title(f"{title}\nDisruption Type: {disruption_type.capitalize()}")
        plt.xlabel("Experiment Step")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(current_results_dir, filename), dpi=300)
        plt.close()

    def _mse_label_y(values, avg_value):
        # Offset the label by a tenth of the spread; fixed offset for one sample.
        if len(values) > 1:
            return avg_value + np.std(values) * 0.1
        return avg_value + 0.01

    # --- MSE PLOT ---
    _save_line_plot(
        "mse",
        _mse_label_y,
        lambda avg_value: f"Avg: {avg_value:.4f}",
        "MSE Error During Experiment Phases",
        "MSE Error",
        f"mse_by_phase_{disruption_type}.png",
    )

    # --- DISTANCE PLOT ---
    _save_line_plot(
        "distance",
        lambda values, avg_value: avg_value + 20,
        lambda avg_value: f"Avg: {avg_value:.1f}",
        "Distance to Target During Experiment Phases",
        "Distance (pixels)",
        f"distance_by_phase_{disruption_type}.png",
    )

    # --- SUMMARY BAR CHART ---
    plt.figure(figsize=(8, 6))

    phase_names_list_for_bar = list(phases_data_for_plot)
    x = np.arange(len(phase_names_list_for_bar))
    width = 0.35

    # Phases without data are shown as a zero-height bar.
    avg_mses = [
        np.mean(phases_data_for_plot[p]["mse"]) if phases_data_for_plot[p]["mse"] else 0
        for p in phase_names_list_for_bar
    ]

    # Draw the bars exactly once, each in its phase color. Drawing them twice
    # (as before) stacked translucent bars and duplicated the legend entry.
    bar_plot_colors = [colors[phase_name] for phase_name in phase_names_list_for_bar]
    bars_mse = plt.bar(x, avg_mses, width, label="Mean MSE Error", color=bar_plot_colors, alpha=0.7)
    plt.bar_label(bars_mse, fmt='%.4f')

    plt.xlabel("Experiment Phase")
    plt.ylabel("Average MSE Error")
    plt.title(f"Performance Summary (MSE) by Experiment Phase\nDisruption Type: {disruption_type.capitalize()}")
    plt.xticks(x, phase_names_list_for_bar)
    plt.legend()
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(current_results_dir, f"performance_summary_mse_{disruption_type}.png"), dpi=300)
    plt.close()

    print(f"\nExperiment result graphs saved to {current_results_dir}/ directory (MSE and Distance only for two phases)")


# ---> DEFINE save_trajectory_plots HERE (after global vars are defined) < ---
# Function to save final trajectory plots
# All history data, correlation parameters and screen dimensions are passed in as arguments
def save_trajectory_plots(
    current_results_dir,
    time_points_data, cursor_x_history_data, cursor_y_history_data,
    desired_x_history_data, desired_y_history_data,
    pred_vx_history_data, pred_vy_history_data,
    desired_vx_history_data, desired_vy_history_data,
    sliding_x_corr_history_data, sliding_y_corr_history_data,
    target_reach_steps_data, target_reach_durations_data,
    initial_learning_completion_step_data,
    phase_metrics_data,
    corr_calc_window_val, lag_val,
    screen_width_val, screen_height_val,
    phase_names_val
    ):
    """Save the trajectory overview figure and the smoothed reach-time figure.

    The overview figure is a 3x3 grid with six panels: cursor trajectory,
    X/Y position over time, X/Y velocity over time and the sliding-window
    correlations. It is skipped entirely when ``time_points_data`` is empty.
    When both reach lists are non-empty, a second figure shows reach durations
    smoothed with a 10-trial rolling mean, smoothed separately per phase if a
    phase transition can be located.

    Args:
        current_results_dir: Output directory; created if missing.
        time_points_data: X-axis values for all time-series panels.
        cursor_x_history_data, cursor_y_history_data: Actual cursor positions.
        desired_x_history_data, desired_y_history_data: Desired positions.
        pred_vx_history_data, pred_vy_history_data: Predicted velocities.
        desired_vx_history_data, desired_vy_history_data: Desired velocities.
        sliding_x_corr_history_data, sliding_y_corr_history_data: Sliding
            correlation values plotted against ``time_points_data``.
        target_reach_steps_data: Simulation step at which each reach completed.
        target_reach_durations_data: Duration of each reach (plotted as seconds).
        initial_learning_completion_step_data: Simulation step at which the
            initial learning phase ended, or ``None``.
        phase_metrics_data: Accepted but not read by this function.
        corr_calc_window_val, lag_val: Shown in the correlation panel title.
        screen_width_val, screen_height_val: Axis limits of the trajectory panel.
        phase_names_val: Mapping from phase id to display name.

    Returns:
        None
    """
    plt.ioff()  # figures are only written to disk

    # One timestamp shared by both output file names.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    save_fig = plt.figure(figsize=(18, 12), dpi=150)
    save_gs = plt.GridSpec(3, 3, figure=save_fig)

    if not time_points_data:
        print("No data points recorded, skipping trajectory plot saving.")
        plt.close(save_fig)
        return

    # Trajectory takes the upper-left 2x2 area; the other panels are 1x1.
    save_ax_traj = save_fig.add_subplot(save_gs[0:2, 0:2])
    save_ax_posx = save_fig.add_subplot(save_gs[0, 2])
    save_ax_posy = save_fig.add_subplot(save_gs[1, 2])
    save_ax_velx = save_fig.add_subplot(save_gs[2, 0])
    save_ax_vely = save_fig.add_subplot(save_gs[2, 1])
    save_ax_corr = save_fig.add_subplot(save_gs[2, 2])

    def label_axes(ax, title, x_label, y_label):
        """Apply the common title and axis-label font sizes."""
        ax.set_title(title, fontsize=14)
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)

    def draw_time_panel(ax, title, y_label, measured, measured_fmt, measured_label, desired):
        """Plot a measured series against its desired series over time."""
        label_axes(ax, title, "Time Step", y_label)
        ax.grid(True)
        ax.plot(time_points_data, measured, measured_fmt, linewidth=2, label=measured_label)
        ax.plot(time_points_data, desired, 'g--', linewidth=1.5, label="Desired")
        ax.legend()

    # 1. Cursor trajectory in screen coordinates
    label_axes(save_ax_traj, "Cursor Trajectory", "X Position", "Y Position")
    save_ax_traj.set_xlim(0, screen_width_val)
    save_ax_traj.set_ylim(0, screen_height_val)
    save_ax_traj.set_aspect('equal')
    save_ax_traj.grid(True)
    save_ax_traj.plot(cursor_x_history_data, cursor_y_history_data, 'b-', alpha=0.8, linewidth=2, label="Cursor Path")

    # 2-3. Position over time
    draw_time_panel(save_ax_posx, "X Position vs Time", "X Position",
                    cursor_x_history_data, 'b-', "Actual", desired_x_history_data)
    draw_time_panel(save_ax_posy, "Y Position vs Time", "Y Position",
                    cursor_y_history_data, 'b-', "Actual", desired_y_history_data)

    # 4-5. Velocity over time
    draw_time_panel(save_ax_velx, "X Velocity vs Time", "X Velocity",
                    pred_vx_history_data, 'm-', "Predicted", desired_vx_history_data)
    draw_time_panel(save_ax_vely, "Y Velocity vs Time", "Y Velocity",
                    pred_vy_history_data, 'm-', "Predicted", desired_vy_history_data)

    # 6. Sliding-window correlation
    label_axes(save_ax_corr,
               f"Sliding Window Correlation (N={corr_calc_window_val}, Lag={lag_val})",
               "Time Step", "Correlation")
    save_ax_corr.set_ylim(-1.1, 1.1)
    save_ax_corr.grid(True)
    save_ax_corr.plot(time_points_data, sliding_x_corr_history_data, 'r-', linewidth=2, label="X Corr (Sliding)")
    save_ax_corr.plot(time_points_data, sliding_y_corr_history_data, 'b-', linewidth=2, label="Y Corr (Sliding)")
    save_ax_corr.legend()

    plt.tight_layout()

    if not os.path.exists(current_results_dir):
        os.makedirs(current_results_dir)

    plot_filename = os.path.join(current_results_dir, f'trajectory_plots_{timestamp}.png')
    save_fig.savefig(plot_filename, dpi=150, bbox_inches='tight')
    plt.close(save_fig)

    print(f"Saved trajectory plots to {plot_filename}")

    # ---- Separate figure: smoothed time-to-reach per trial ----
    if not (target_reach_steps_data and target_reach_durations_data):
        print("No targets reached or durations recorded, skipping smoothed reach time plot.")
        return

    reach_fig, ax_reach = plt.subplots(figsize=(12, 6), dpi=100)

    ax_reach.set_title("Smoothed Time to Reach Target (Smoothing Reset per Phase)", fontsize=14)
    ax_reach.set_xlabel("Reach Completion Trial Index", fontsize=12)
    ax_reach.set_ylabel("Time (seconds)", fontsize=12)
    ax_reach.grid(True, alpha=0.5)

    window_size = 10
    durations_series = pd.Series(target_reach_durations_data)
    num_total_reaches = len(target_reach_durations_data)
    # Trial indices 1..num_total_reaches for the x-axis
    overall_reach_indices = np.arange(1, num_total_reaches + 1)

    # Number of leading reaches that completed before the initial phase
    # ended; -1 when the transition step is unknown.
    transition_trial_count_phase1 = -1
    if initial_learning_completion_step_data is not None and target_reach_steps_data:
        leading = 0
        for completion_step in target_reach_steps_data:
            if not completion_step < initial_learning_completion_step_data:
                break
            leading += 1
        transition_trial_count_phase1 = leading

    phase1_name = phase_names_val.get(PHASE_INITIAL_LEARNING, "Phase 1")
    phase2_name = phase_names_val.get(PHASE_ADAPT_TO_DISRUPTION, "Phase 2")

    def plot_phase_segment(segment, trial_indices, line_fmt, series_name, error_tag):
        """Plot one phase's rolling mean; report (do not raise) any failure."""
        if len(segment) < 1:
            return
        try:
            smoothed = segment.rolling(window=window_size, min_periods=1).mean()
            ax_reach.plot(trial_indices, smoothed, line_fmt, linewidth=2, label=f'{series_name} (N={window_size})')
        except Exception as e:
            print(f"Error smoothing/plotting {error_tag}: {e}")

    if 0 < transition_trial_count_phase1 < num_total_reaches:
        # Smooth each phase on its own so the window never spans the transition.
        split = transition_trial_count_phase1
        plot_phase_segment(durations_series.iloc[:split], overall_reach_indices[:split],
                           'r-', f'{phase1_name} Smoothed', "Phase 1")
        plot_phase_segment(durations_series.iloc[split:], overall_reach_indices[split:],
                           'b-', f'{phase2_name} Smoothed', "Phase 2")

        # Marker halfway between the last phase-1 and first phase-2 trial.
        ax_reach.axvline(x=split + 0.5, color='black', linestyle=':', alpha=0.7,
                         label=f'{phase1_name} End / {phase2_name} Start')
        ax_reach.legend()
    elif num_total_reaches >= 1:
        # No usable transition point: treat all reaches as a single phase.
        try:
            moving_avg = durations_series.rolling(window=window_size, min_periods=1).mean()
            ax_reach.plot(overall_reach_indices, moving_avg, 'r-', linewidth=2, label=f'Smoothed Reach Duration (N={window_size})')
            ax_reach.legend()
        except Exception as e:
            print(f"Error smoothing/plotting single phase: {e}")
    else:
        ax_reach.text(0.5, 0.5, 'Need at least 1 reach for smoothing', horizontalalignment='center', verticalalignment='center', transform=ax_reach.transAxes)

    reach_plot_filename = os.path.join(current_results_dir, f'reach_time_smoothed_{timestamp}.png')
    try:
        reach_fig.savefig(reach_plot_filename, dpi=150, bbox_inches='tight')
        print(f"Saved smoothed reach time plot to {reach_plot_filename}")
    except Exception as e:
        print(f"Error saving smoothed reach time plot: {e}")
    plt.close(reach_fig)

# ---> END save_trajectory_plots DEFINITION < ---

# ---------------------------------- #
# PRE-TRAINING PHASE FUNCTION        #
# ---------------------------------- #
# Reach count for the offline pre-training phase (presumably passed as
# `num_reaches` to run_pretraining_phase; verify against the caller).
# NOTE(review): the previous comment read "5 epochs * 400 reaches per epoch
# equivalent", which would be 2000 and does not match the value 400; confirm
# which is intended.
NUM_PRETRAIN_TARGETS = 400
# Wall-clock limit for a single pre-training reach attempt; once exceeded,
# the attempt is abandoned.
MAX_REACH_TIME_SECONDS = 10.0 # 10 seconds
def run_pretraining_phase(model, spike_gen, rl_updater, device, num_reaches, sim_lr_scale_val):
    """Run the offline center-out pre-training phase for the SNN decoder.

    Each of the ``num_reaches`` attempts starts with the cursor at
    ``CENTER_POS`` and a target obtained from
    ``get_target_for_phase(PHASE_INITIAL_LEARNING, ...)``. On every step the
    desired velocity is encoded with ``spike_gen.generate_spikes``, zero-padded
    up to ``model.fc1.in_features``, passed through ``model``, and
    ``rl_updater.update`` is called with the input, the prediction and the
    desired velocity. The cursor is then moved by the *predicted* velocity
    scaled by ``sim_lr_scale_val`` and clipped to the screen. An attempt ends
    when the cursor is within ``target_radius`` of the target or when
    ``MAX_REACH_TIME_SECONDS`` of wall-clock time have elapsed.

    Args:
        model: Decoder called on the padded spike tensor; must expose
            ``fc1.in_features``. ``reset_states()`` is called only if the
            model provides it.
        spike_gen: Spike generator exposing ``generate_spikes`` and
            ``num_neurons``. If it has a ``training`` attribute, it is set to
            True for the duration of pre-training and restored afterwards,
            including when an exception propagates out of the loop.
        rl_updater: Object whose ``update(input, prediction, desired)``
            returns a 3-tuple whose first element is the loss.
        device: Torch device on which the tensors are created.
        num_reaches: Number of reach attempts to simulate.
        sim_lr_scale_val: Scale applied to the predicted velocity when the
            cursor position is integrated.

    Returns:
        None. The function works through side effects on ``model``,
        ``rl_updater`` and ``spike_gen``, and prints progress to stdout.
    """
    import gc  # Local import: only needed for the periodic cleanup below.

    print(f"\n========== STARTING OFFLINE PRE-TRAINING PHASE ({num_reaches} reaches) ==========")
    pretrain_cursor_pos = CENTER_POS.copy()
    # Initial peripheral target for the very first pre-train reach.
    pretrain_target_pos = get_target_for_phase(PHASE_INITIAL_LEARNING, 0, pretrain_cursor_pos)

    original_spike_gen_training_state = None
    if hasattr(spike_gen, 'training'):
        original_spike_gen_training_state = spike_gen.training
        spike_gen.training = True  # Ensure noise etc. is active for robust pre-training
        print("Spike generator set to training mode for pre-training.")

    can_reset_states = hasattr(model, 'reset_states')
    if can_reset_states:
        print("Resetting SNN model states before pre-training starts.")
        model.reset_states()

    total_pretrain_steps_simulated = 0
    # Defined up front so the per-reach log line cannot hit an unbound name
    # when a reach ends before any learning step has been executed.
    loss_val_pt = float('nan')

    try:
        for reach_idx in range(num_reaches):
            # Monotonic clock: the time limit must not be affected by
            # system clock adjustments.
            reach_start = time.monotonic()
            # NOTE(review): the period of 201 reaches is kept from the
            # original code; confirm it is intentional (vs. e.g. 200).
            if can_reset_states and reach_idx % 201 == 0:
                model.reset_states()

            step_in_current_reach = 0
            target_reached = False
            while True:
                # Counters include the iteration that terminates the reach,
                # matching the step totals reported previously.
                total_pretrain_steps_simulated += 1
                step_in_current_reach += 1

                if time.monotonic() - reach_start > MAX_REACH_TIME_SECONDS:
                    break  # Timed out

                desired_vec = pretrain_target_pos - pretrain_cursor_pos
                distance = np.linalg.norm(desired_vec)
                if distance < target_radius:
                    target_reached = True
                    break

                if distance > 1e-6:
                    # Unit direction, slowed linearly once within 200 px of the target.
                    desired_vel_np_pt = (desired_vec / distance) * min(1.0, distance / 200.0)
                else:
                    desired_vel_np_pt = np.zeros_like(desired_vec)

                desired_vel_tensor_pt = torch.tensor(
                    desired_vel_np_pt, dtype=torch.float32, device=device
                ).unsqueeze(0)

                spikes_pt = spike_gen.generate_spikes(desired_vel_tensor_pt.squeeze(0), sequence_length=1)
                current_spikes_pt = spikes_pt.squeeze(0)

                # Zero-pad the spike vector up to the model's input width;
                # never pad by a negative amount.
                padding_size = max(0, model.fc1.in_features - spike_gen.num_neurons)
                padding_pt = torch.zeros((current_spikes_pt.size(0), padding_size), device=device)
                input_tensor_pt = torch.cat([current_spikes_pt, padding_pt], dim=1).unsqueeze(0)

                pred_output_sequence_snn_pt = model(input_tensor_pt)
                pred_velocity_snn_tensor_pt = pred_output_sequence_snn_pt[:, -1, :]

                loss_val_pt, _, _ = rl_updater.update(
                    input_tensor_pt, pred_velocity_snn_tensor_pt, desired_vel_tensor_pt
                )

                # Integrate the *predicted* velocity and keep the cursor on screen.
                pred_velocity_snn_np_pt = pred_velocity_snn_tensor_pt.detach().cpu().numpy().squeeze()
                pretrain_cursor_pos = pretrain_cursor_pos + pred_velocity_snn_np_pt * sim_lr_scale_val
                pretrain_cursor_pos[0] = np.clip(pretrain_cursor_pos[0], 0, screen_width)
                pretrain_cursor_pos[1] = np.clip(pretrain_cursor_pos[1], 0, screen_height)

            # After the attempt (success or timeout), restart from the center
            # with a fresh target for the next reach.
            pretrain_cursor_pos = CENTER_POS.copy()
            pretrain_target_pos = get_target_for_phase(PHASE_INITIAL_LEARNING, reach_idx + 1, pretrain_cursor_pos)

            # Status comes from how the loop actually ended, not from a
            # possibly stale distance value.
            status_msg = "SUCCESS" if target_reached else "TIMEOUT"
            print(f"  Pre-train Reach {reach_idx + 1}/{num_reaches} - {status_msg}. Steps in reach: {step_in_current_reach}. Last Loss: {loss_val_pt:.4f}")

            # Periodically try to free up memory and print a summary.
            if (reach_idx + 1) % 50 == 0:
                print(f"    Pre-train: Completed {reach_idx + 1} reaches. Performing GC.")
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        # Restore the generator's mode even if pre-training raised.
        if hasattr(spike_gen, 'training') and original_spike_gen_training_state is not None:
            spike_gen.training = original_spike_gen_training_state  # Restore original state
            print(f"Spike generator restored to training={original_spike_gen_training_state}.")

    print(f"========== OFFLINE PRE-TRAINING COMPLETE ({total_pretrain_steps_simulated} total steps simulated) ==========")
    if can_reset_states:
        print("Resetting SNN model states after pre-training before main simulation starts.")
        model.reset_states()  # Reset for the main simulation sequence

# # ---------------------------------- #
# # MAIN SIMULATION LOGIC              #
# # ---------------------------------- #
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# num_neurons = 96
# print(f"========== SIMULATION SETUP ==========")
# print(f"Device: {device}")
# print(f"Number of neurons: {num_neurons}")
# print(f"Screen dimensions: {screen_width}x{screen_height}")

# spike_gen = Synthetic_Neuron(num_neurons=num_neurons, noise_level=0.02)
# model = SNNRegression(input_size=98, hidden_size=512, output_size=2).to(device)

# # Save original neural properties for resetting between experiments
# original_properties = {
#     'noise_level': spike_gen.noise_level,
#     'preferred_directions': spike_gen.preferred_directions.clone(),
#     'max_firing_rate': spike_gen.max_firing_rate,
#     'min_firing_rate': spike_gen.min_firing_rate # Added min rate
# }


# print("Model architecture:")
# print("  Input size: 182")
# print("  Hidden size: 512")
# print("  Output size: 2")

# # Initialize the UPDATED updater
# rl_updater = TwoScaleMetaRLWeightUpdaterFull(
#     model, 
#     base_fast_lr=1e-4,   # Using values from user change
#     base_slow_lr=1e-3,   # Using values from user change
#     window_size=10,      # Using values from user change
#     meta_lr=0.1,         # Using values from mc_maze online_adaptation
# )
# print("RL Updater configuration (using mc_maze parameters):")
# print(f"  Base Fast learning rate: {rl_updater.base_fast_lr}")
# print(f"  Base Slow learning rate: {rl_updater.base_slow_lr}")
# print(f"  Meta Parameters: {rl_updater.meta_params}")
# print(f"  Window size: {rl_updater.window_size}")

# cursor_pos = np.array([screen_width/2, screen_height/2], dtype=np.float32)
# target_pos = new_target()
# print(f"Initial cursor position: ({cursor_pos[0]:.2f}, {cursor_pos[1]:.2f})")
# print(f"Initial target position: ({target_pos[0]:.2f}, {target_pos[1]:.2f})")
# initial_distance = np.linalg.norm(target_pos - cursor_pos)
# print(f"Initial distance to target: {initial_distance:.2f}")

# # For tracking experiment phases
# current_phase = PHASE_INITIAL_LEARNING
# phase_step = 0 # Uncomment and keep for logging and benign pass to get_target_for_phase
# total_steps = 0 
# targets_in_current_phase = 0 
# initial_learning_completion_step_global = None # To store when the first phase ends

# # --- ADD for center-out logic ---
# evaluation_target_count = 0 # Counter for evaluation phase targets
# # --- END ADD ---

# # Initialize first target properly for center-out
# cursor_pos = CENTER_POS.copy() # Start cursor at center

# if current_phase == PHASE_EVALUATION: # Unlikely to start in EVAL, but handle it
#     target_pos = get_target_for_phase(current_phase, evaluation_target_count, cursor_pos)
# else: # e.g., PHASE_INITIAL_LEARNING
#     target_pos = get_target_for_phase(current_phase, phase_step, cursor_pos) # phase_step is 0 initially

# print(f"Initial cursor position (center-out): ({cursor_pos[0]:.2f}, {cursor_pos[1]:.2f}) (Center)")
# print(f"Initial peripheral target position (center-out): ({target_pos[0]:.2f}, {target_pos[1]:.2f})")
# initial_distance = np.linalg.norm(target_pos - cursor_pos)
# print(f"Initial distance to peripheral target: {initial_distance:.2f}")

# # Ensure results directory exists right at the start
# os.makedirs(RESULTS_DIR, exist_ok=True)
# print(f"Created results directory: {os.path.abspath(RESULTS_DIR)}")

# # Write experiment info file
# try:
#     with open(os.path.join(RESULTS_DIR, "experiment_info.txt"), "w") as f:
#         f.write(f"Experiment started at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
#         f.write(f"Disruption type: {DISRUPTION_TYPE}\n")
#         f.write(f"Disruption intensity: {DISRUPTION_INTENSITY}\n")
#         f.write(f"Phase targets: INITIAL_LEARNING={INITIAL_LEARNING_TARGETS}, ADAPT_TO_DISRUPTION={ADAPT_TO_DISRUPTION_TARGETS}\n") # Updated to reflect target counts
#         f.write(f"Updater: TwoScaleMetaRLWeightUpdaterFull\n")
#         f.write(f"  Fast LR: {rl_updater.base_fast_lr}, Slow LR: {rl_updater.base_slow_lr}, Window: {rl_updater.window_size}, Meta LR: {rl_updater.meta_lr}\n")
#     print(f"Successfully created experiment info file in {RESULTS_DIR}")
# except Exception as e:
#     print(f"WARNING: Could not write to results directory: {e}")
#     print(f"Graphs may not be saved correctly!")

# # For measuring performance in each phase
# phase_metrics = {
#     PHASE_INITIAL_LEARNING: {'mse': [], 'distance': []},
#     PHASE_ADAPT_TO_DISRUPTION: {'mse': [], 'distance': []}
# }

# # --- Global variables for simulation state & history, defined before pre-training ---
# time_points = []
# cursor_x_history = []
# cursor_y_history = []
# desired_x_history = [] 
# desired_y_history = [] 
# pred_vx_history = []
# pred_vy_history = []
# desired_vx_history = [] 
# desired_vy_history = [] 

# total_targets_reached = 0
# slow_updates_applied = 0 # Tracked by updater
# weight_norms = {'fc1': [], 'fc2': [], 'fc3': []} # Keep for monitoring
# simulation_lr_scale = 25

# # Lists for Time-to-Target Tracking
# target_reach_steps = []
# target_reach_durations = []
# MAX_REACH_TIME_SECONDS = 10.0  # Maximum time allowed per reach attempt

# print(f"Movement scaling factor: {simulation_lr_scale}")

# # ---> CALL TO PRE-TRAINING PHASE HERE < ---
# # run_pretraining_phase(model, spike_gen, rl_updater, device, NUM_PRETRAIN_TARGETS, simulation_lr_scale)
# # --- END CALL ---

# print("\n========== STARTING SIMULATION (with SNN and KF) ==========\n")

# # --- MOVED 6-PANEL PLOT SETUP HERE (AFTER PRE-TRAINING) ---
# # plt.ion() # COMMENTED OUT FOR NO LIVE PLOTTING
# fig = plt.figure(figsize=(18, 12))
# gs = plt.GridSpec(3, 3, figure=fig)
# # 1. TRAJECTORY PLOT (TOP LEFT - LARGER)
# ax_traj = fig.add_subplot(gs[0:2, 0:2])
# ax_traj.set_title("Real-time Cursor Trajectory", fontsize=14)
# ax_traj.set_xlabel("X Position", fontsize=12)
# ax_traj.set_ylabel("Y Position", fontsize=12)
# ax_traj.set_xlim(0, screen_width)
# ax_traj.set_ylim(0, screen_height)
# ax_traj.set_aspect('equal')
# ax_traj.grid(True)
# # Initialize trajectory visualization elements
# # target_pos is initialized before pre-training, so it's available here
# cursor_plot, = ax_traj.plot([], [], 'bo', markersize=10, label="Current Position")
# target_plot = plt.Circle((target_pos[0], target_pos[1]), target_radius, fill=False, color='r', linewidth=2, label="Target")
# ax_traj.add_patch(target_plot)
# cursor_trajectory, = ax_traj.plot([], [], 'b-', alpha=0.6, linewidth=2, label="Cursor Path")
# desired_trajectory, = ax_traj.plot([], [], 'g--', alpha=0.5, linewidth=1.5, label="Desired Path")
# phase_indicator = ax_traj.text(10, 10, "PHASE: INITIAL_LEARNING", fontsize=12, color='blue', # Start with INITIAL_LEARNING
#                              bbox=dict(facecolor='white', alpha=0.7))
# ax_traj.legend(loc='upper right')
# # 2. X POSITION OVER TIME (TOP RIGHT)
# ax_posx = fig.add_subplot(gs[0, 2])
# ax_posx.set_title("X Position vs Time", fontsize=14)
# ax_posx.set_xlabel("Time Step", fontsize=12)
# ax_posx.set_ylabel("X Position", fontsize=12)
# ax_posx.grid(True)
# pos_x_actual, = ax_posx.plot([], [], 'b-', linewidth=2, label="Actual")
# pos_x_desired, = ax_posx.plot([], [], 'g--', linewidth=1.5, label="Desired")
# ax_posx.legend()
# # 3. Y POSITION OVER TIME (MIDDLE RIGHT)
# ax_posy = fig.add_subplot(gs[1, 2])
# ax_posy.set_title("Y Position vs Time", fontsize=14)
# ax_posy.set_xlabel("Time Step", fontsize=12)
# ax_posy.set_ylabel("Y Position", fontsize=12)
# ax_posy.grid(True)
# pos_y_actual, = ax_posy.plot([], [], 'b-', linewidth=2, label="Actual")
# pos_y_desired, = ax_posy.plot([], [], 'g--', linewidth=1.5, label="Desired")
# ax_posy.legend()
# # 4. X VELOCITY OVER TIME (BOTTOM LEFT)
# ax_velx = fig.add_subplot(gs[2, 0])
# ax_velx.set_title("X Velocity vs Time", fontsize=14)
# ax_velx.set_xlabel("Time Step", fontsize=12)
# ax_velx.set_ylabel("X Velocity", fontsize=12)
# ax_velx.grid(True)
# vel_x_pred, = ax_velx.plot([], [], 'm-', linewidth=2, label="Predicted")
# vel_x_desired, = ax_velx.plot([], [], 'g--', linewidth=1.5, label="Desired")
# ax_velx.legend()
# # 5. Y VELOCITY OVER TIME (BOTTOM MIDDLE)
# ax_vely = fig.add_subplot(gs[2, 1])
# ax_vely.set_title("Y Velocity vs Time", fontsize=14)
# ax_vely.set_xlabel("Time Step", fontsize=12)
# ax_vely.set_ylabel("Y Velocity", fontsize=12)
# ax_vely.grid(True)
# vel_y_pred, = ax_vely.plot([], [], 'm-', linewidth=2, label="Predicted")
# vel_y_desired, = ax_vely.plot([], [], 'g--', linewidth=1.5, label="Desired")
# ax_vely.legend()
# # 6. SLIDING WINDOW CORRELATION METRICS (BOTTOM RIGHT) - MODIFIED
# ax_corr = fig.add_subplot(gs[2, 2])
# ax_corr.set_title("Sliding Window Velocity Correlation (N=50)", fontsize=14) # MODIFIED TITLE
# ax_corr.set_xlabel("Time Step", fontsize=12)
# ax_corr.set_ylabel("Correlation", fontsize=12)
# ax_corr.set_ylim(-1.1, 1.1) # MODIFIED Y-LIMITS for correlation range
# ax_corr.grid(True)
# # Use different variable names for sliding window plots
# sliding_corr_x_plot, = ax_corr.plot([], [], 'r-', linewidth=2, label="X Corr (Sliding)") # MODIFIED LABEL
# sliding_corr_y_plot, = ax_corr.plot([], [], 'b-', linewidth=2, label="Y Corr (Sliding)") # MODIFIED LABEL
# ax_corr.legend()

# plt.tight_layout()
# # plt.show() # COMMENTED OUT FOR NO LIVE PLOTTING (figures saved at end)
# print("Visualization setup (deferred to end of simulation)")
# # --- END MOVED 6-PANEL PLOT SETUP ---

# start_time = time.time()
# target_start_time = time.time()

# # Map phase IDs to names for printing/plotting
# phase_names = {
#     PHASE_INITIAL_LEARNING: "INITIAL_LEARNING", 
#     PHASE_ADAPT_TO_DISRUPTION: "ADAPT_TO_DISRUPTION"
# }

# print("\n========== PRE-TRAINING KALMAN FILTER ==========")
# # Generate a batch of training data for the Kalman Filter
# kf_train_steps = 500  # Number of timesteps for KF training
# kf_train_spikes_list = []
# kf_train_vel_list = []
# temp_cursor_pos = np.array([screen_width/2, screen_height/2], dtype=np.float32)
# temp_target_pos = new_target() # Get an initial random target

# for _ in range(kf_train_steps):
#     temp_desired_vec = temp_target_pos - temp_cursor_pos
#     temp_distance = np.linalg.norm(temp_desired_vec)
#     if temp_distance > 1e-6:
#         temp_desired_vel_np = temp_desired_vec / temp_distance
#         temp_desired_vel_np = temp_desired_vel_np * min(1.0, temp_distance / 200.0)
#     else:
#         temp_desired_vel_np = np.zeros_like(temp_desired_vec)
    
#     temp_desired_vel_tensor = torch.tensor(temp_desired_vel_np, dtype=torch.float32, device=device).unsqueeze(0)
    
#     # Generate spikes (sequence_length=1 for single step)
#     # The synthetic neuron expects a 1D velocity tensor if sequence_length=1
#     temp_spikes_seq = spike_gen.generate_spikes(temp_desired_vel_tensor.squeeze(0), sequence_length=1) 
    
#     # Ensure temp_spikes is [num_neurons] for KF training data
#     # generate_spikes returns [1, sequence_length, num_neurons]
#     # For KF, train_rates expects [timesteps, num_neurons]
#     current_spikes_for_kf = temp_spikes_seq.squeeze(0).squeeze(0).cpu().numpy() # Shape [num_neurons]
#                                                                                 # Squeeze batch and time
    
#     kf_train_spikes_list.append(current_spikes_for_kf)
#     kf_train_vel_list.append(temp_desired_vel_np) # Store the 2D numpy velocity

#     # Simple cursor and target update for generating varied data
#     temp_cursor_pos = temp_cursor_pos + temp_desired_vel_np * simulation_lr_scale 
#     temp_cursor_pos[0] = np.clip(temp_cursor_pos[0], 0, screen_width)
#     temp_cursor_pos[1] = np.clip(temp_cursor_pos[1], 0, screen_height)
#     if _ % 50 == 0: # Get new target periodically
#         temp_target_pos = new_target()

# kf_train_rates = np.array(kf_train_spikes_list)
# kf_train_velocities = np.array(kf_train_vel_list)

# if kf_train_rates.size == 0 or kf_train_velocities.size == 0:
#     print("ERROR: No data generated for KF training. Skipping KF setup.")
#     kf_decoder = None
# else:
#     try:
#         # --- MODIFIED CALL: We only need the kf_decoder object ---
#         # Pass None for test_rates and test_vel as we only need the trained KF
#         # Ensure your train_test_kalman_filter can handle None for test_rates/vel
#         # or provide minimal dummy data if it cannot.
        
#         # Let's ensure dummy_test_rates and dummy_test_vel are defined
#         # even if kf_train_rates is small.
#         dummy_n_test_samples = min(10, kf_train_rates.shape[0])
#         if dummy_n_test_samples > 0:
#              dummy_test_rates_kf = kf_train_rates[:dummy_n_test_samples] 
#              dummy_test_vel_kf = kf_train_velocities[:dummy_n_test_samples]
#         else: # Handle edge case if kf_train_rates itself is too small
#              dummy_test_rates_kf = np.array([]) 
#              dummy_test_vel_kf = np.array([])

#         # --- FIX: Unpack the tuple return value ---
#         _, kf_decoder = train_test_kalman_filter( # Use _ to discard the first element (offline preds)
#             kf_train_rates, kf_train_velocities, 
#             dummy_test_rates_kf, dummy_test_vel_kf 
#         )
#         # --- END FIX ---
        
#         if kf_decoder is not None:
#             print("Kalman Filter trained successfully.")
#         else:
#             print("ERROR: Kalman Filter training failed, kf_decoder is None (returned from function).")
#     except Exception as e:
#         print(f"ERROR training Kalman Filter: {e}")
#         import traceback
#         traceback.print_exc() # More detailed error
#         kf_decoder = None 

# # Store KF predictions history
# kf_pred_vx_history = []
# kf_pred_vy_history = []

# # --- ADD KF Performance Metric History ---
# kf_sliding_x_corr_history = []
# kf_sliding_y_corr_history = []
# kf_pred_vx_buffer = deque(maxlen=pred_buffer_size) # Reuse existing buffer size
# kf_pred_vy_buffer = deque(maxlen=pred_buffer_size)
# kf_last_calc_x_corr = 0.0 # Initialize
# kf_last_calc_y_corr = 0.0 # Initialize

# # ---> DEFINE max_history HERE <---
# max_history = 1000 # Or whatever value you prefer for plotting history length
# # ---> END DEFINE max_history <---

# try:
#     # --- MODIFIED MAIN LOOP WITH PHASES (NOW TARGET-BASED) ---
#     while True:
#         total_steps += 1
#         phase_step += 1 # Increment phase_step
#         step_start_time = time.time()
        
#         current_phase_name = phase_names.get(current_phase, "UNKNOWN")

#         # Check for phase transitions (NOW BASED ON total_targets_reached for initial, and targets_in_current_phase for adaptive)
#         if current_phase == PHASE_INITIAL_LEARNING and total_targets_reached >= INITIAL_LEARNING_TARGETS:
#             current_phase = PHASE_ADAPT_TO_DISRUPTION
#             targets_in_current_phase = 0 
#             current_phase_name = phase_names.get(current_phase, "UNKNOWN") 
#             initial_learning_completion_step_global = total_steps 
#             phase_step = 1 # Reset phase_step for the new phase
#             print(f"\n===== {total_targets_reached} TARGETS REACHED. TRANSITION TO {current_phase_name} (Step {total_steps}) =====")
#             apply_disruption(spike_gen, DISRUPTION_TYPE, DISRUPTION_INTENSITY, original_properties, device, num_neurons)
            
#             # Save graphs after initial learning phase
#             print(f"Saving interim graphs after {phase_names[PHASE_INITIAL_LEARNING]} phase...")
#             plot_experiment_results(phase_metrics, disruption_type=DISRUPTION_TYPE, current_results_dir=RESULTS_DIR)
            
#         elif current_phase == PHASE_ADAPT_TO_DISRUPTION and targets_in_current_phase >= ADAPT_TO_DISRUPTION_TARGETS:
#              print(f"\n========== {targets_in_current_phase} TARGETS REACHED IN {current_phase_name} PHASE. ==========")
#              print("\n========== SIMULATION COMPLETE (Trial-Based Procedure) ==========")
#              print(f"Total simulation time: {(time.time() - start_time):.2f} seconds")
#              print(f"Total targets reached overall: {total_targets_reached}")

#              save_trajectory_plots(current_results_dir=RESULTS_DIR,
#                                   time_points_data=time_points,
#                                   cursor_x_history_data=cursor_x_history,
#                                   cursor_y_history_data=cursor_y_history,
#                                   desired_x_history_data=desired_x_history,
#                                   desired_y_history_data=desired_y_history,
#                                   pred_vx_history_data=pred_vx_history,
#                                   pred_vy_history_data=pred_vy_history,
#                                   desired_vx_history_data=desired_vx_history,
#                                   desired_vy_history_data=desired_vy_history,
#                                   sliding_x_corr_history_data=sliding_x_corr_history,
#                                   sliding_y_corr_history_data=sliding_y_corr_history,
#                                   target_reach_steps_data=target_reach_steps,
#                                   target_reach_durations_data=target_reach_durations,
#                                   initial_learning_completion_step_data=initial_learning_completion_step_global,
#                                   phase_metrics_data=phase_metrics,
#                                   corr_calc_window_val=corr_calc_window,
#                                   lag_val=lag,
#                                   screen_width_val=screen_width,
#                                   screen_height_val=screen_height,
#                                   phase_names_val=phase_names)

#              print("\nExperiment completed successfully!")
#              break 
        
#         # Update phase indicator text on plot
#         if current_phase == PHASE_INITIAL_LEARNING:
#             phase_progress = f"(Target {total_targets_reached}/{INITIAL_LEARNING_TARGETS})"
#         elif current_phase == PHASE_ADAPT_TO_DISRUPTION:
#             phase_progress = f"(Target {targets_in_current_phase}/{ADAPT_TO_DISRUPTION_TARGETS})"
#         else:
#             phase_progress = ""
        
#         phase_indicator.set_text(f"PHASE: {current_phase_name} {phase_progress} Steps: {total_steps}")
        
#         # --- Step Header (less frequent or conditional) ---
#         if total_steps % 50 == 0 : # Log header less frequently
#             print(f"\n----- Current Step: {total_steps} (Phase: {current_phase_name}) -----")
        
#         desired_vec = target_pos - cursor_pos
#         distance = np.linalg.norm(desired_vec)
#         # print(f"Cursor: ({cursor_pos[0]:.2f}, {cursor_pos[1]:.2f}) | Target: ({target_pos[0]:.2f}, {target_pos[1]:.2f})") # Less frequent
#         # print(f"Distance to target: {distance:.2f} pixels") # Less frequent
        
#         if distance > 1e-6: # Use epsilon
#             desired_vel = desired_vec / distance
#             desired_vel = desired_vel * min(1.0, distance / 200.0)
#         else:
#             desired_vel = np.zeros_like(desired_vec)
#         # print(f"Desired velocity vector: ({desired_vel[0]:.4f}, {desired_vel[1]:.4f}) | Magnitude: {np.linalg.norm(desired_vel):.4f}") # Commented out
#         desired_vel_tensor = torch.tensor(desired_vel, dtype=torch.float32, device=device).unsqueeze(0)
#         desired_vel_np = desired_vel
        
#         # print("Generating spike train (seq_len=1)...") # Commented out
#         spikes = spike_gen.generate_spikes(desired_vel_tensor.squeeze(0), sequence_length=1)
        
#         # Apply controlled dropout for ADAPT_TO_DISRUPTION phase if needed
#         # NOTE: This assumes DISRUPTION_TYPE is 'dropout' and intensity is set
#         if current_phase == PHASE_ADAPT_TO_DISRUPTION and DISRUPTION_TYPE == "dropout":
#             dropout_mask = torch.bernoulli(torch.ones_like(spikes) * (1.0 - DISRUPTION_INTENSITY)).to(device)
#             spikes = spikes * dropout_mask
#             print(f"Applied neural dropout with rate {DISRUPTION_INTENSITY:.2f}")
        
#         # Prepare input tensor (T=1)
#         current_spikes = spikes.squeeze(0)
#         padding = torch.zeros((current_spikes.size(0), 2), device=device)
#         current_spikes_padded = torch.cat([current_spikes, padding], dim=1)
#         input_tensor = current_spikes_padded.unsqueeze(0) # [1, 1, 182]

#         # --- SNN Prediction ---
#         # print("Forward pass through SNN...") # Commented out
#         pred_output_sequence_snn = model(input_tensor) 
#         pred_velocity_snn_tensor = pred_output_sequence_snn[:, -1, :] 
#         pred_velocity_snn_np = pred_velocity_snn_tensor.detach().cpu().numpy().squeeze()
#         # print(f"SNN Predicted velocity: ({pred_velocity_snn_np[0]:.4f}, {pred_velocity_snn_np[1]:.4f})") # Commented out

#         # --- Kalman Filter Prediction ---
#         kf_pred_velocity_np = np.zeros(2) 
#         if kf_decoder is not None:
#             # KF expects spike_rates as [num_neurons] or [num_neurons, 1]
#             # current_spikes is [seq_len=1, num_neurons] from spike_gen
#             # For KF, we need to ensure it's the raw spike vector for the current step
#             # The kf_train_rates was [timesteps, num_neurons]
#             # spikes from spike_gen is [1, 1, num_neurons]
#             # current_spikes_for_kf should be [num_neurons]
            
#             # Use the same spikes that were prepared for KF training (before padding for SNN)
#             # Assuming 'spikes' from spike_gen.generate_spikes is [1,1,num_neurons]
#             kf_input_spikes_current_step = spikes.squeeze(0).squeeze(0).cpu().numpy() # Shape [num_neurons]
            
#             try:
#                 kf_decoder.predict()
#                 kf_decoder.update(kf_input_spikes_current_step.reshape(-1, 1)) 
#                 kf_pred_velocity_np = kf_decoder.x.copy()
#                 # print(f"KF Predicted velocity:  ({kf_pred_velocity_np[0]:.4f}, {kf_pred_velocity_np[1]:.4f})") # Commented out
#             except Exception as e:
#                 print(f"ERROR during KF predict/update: {e}")
#                 # kf_pred_velocity_np remains zeros
                
#         # Now you have:
#         # pred_velocity_snn_np (from your SNN)
#         # kf_pred_velocity_np (from Kalman Filter)

#         # For cursor movement, you need to decide which prediction to use,
#         # or if you want to run them in parallel for comparison without affecting the main cursor.
#         # For now, let's assume the SNN still drives the cursor.
#         # The SNN learning part uses 'pred_velocity_snn_tensor' and 'desired_vel_tensor'
        
#         # ... (SNN learning update using rl_updater.update(input_tensor, pred_velocity_snn_tensor, desired_vel_tensor)) ...
#         # step_loss, x_corr_snn, y_corr_snn = rl_updater.update(...)

#         # ---> Store SNN predictions (already done as pred_velocity_np) <---
#         # pred_velocity_np is the SNN prediction used for cursor update and SNN metrics

#         # ---> Store KF predictions (for plotting and KF-specific metrics) <---
#         kf_pred_vx_history.append(kf_pred_velocity_np[0])
#         kf_pred_vy_history.append(kf_pred_velocity_np[1])

#         # ---> UPDATE KF SLIDING WINDOW BUFFERS & CORRELATION <---
#         kf_pred_vx_buffer.append(kf_pred_velocity_np[0])
#         kf_pred_vy_buffer.append(kf_pred_velocity_np[1])
#         # desired_vel_np is already available from earlier in the loop and added to des_vx_buffer

#         kf_current_lagged_x_corr = kf_last_calc_x_corr # Default
#         kf_current_lagged_y_corr = kf_last_calc_y_corr # Default

#         if total_steps % 10 == 0 and len(des_vx_buffer) == des_buffer_size: # Assuming des_vx_buffer is for SNN's target
#             kf_pred_window_x = list(kf_pred_vx_buffer)
#             kf_pred_window_y = list(kf_pred_vy_buffer)
            
#             des_lagged_window_x_for_kf = list(des_vx_buffer)[0:corr_calc_window] # Use the same desired for comparison
#             des_lagged_window_y_for_kf = list(des_vy_buffer)[0:corr_calc_window]
                
#             if len(kf_pred_window_x) == corr_calc_window and len(des_lagged_window_x_for_kf) == corr_calc_window:
#                 kf_current_lagged_x_corr = compute_correlation(np.array(kf_pred_window_x), np.array(des_lagged_window_x_for_kf))
#                 kf_current_lagged_y_corr = compute_correlation(np.array(kf_pred_window_y), np.array(des_lagged_window_y_for_kf))
            
#             kf_last_calc_x_corr = kf_current_lagged_x_corr
#             kf_last_calc_y_corr = kf_current_lagged_y_corr
        
#         kf_sliding_x_corr_history.append(kf_last_calc_x_corr)
#         kf_sliding_y_corr_history.append(kf_last_calc_y_corr)

#         if total_steps % 10 == 0 and len(des_vx_buffer) == des_buffer_size:
#             print(f"KF Lagged Sliding Window Corr (N={corr_calc_window}, Lag={lag}, Step {total_steps}): X={kf_last_calc_x_corr:.4f}, Y={kf_last_calc_y_corr:.4f}")

#         # ... (rest of your loop: cursor update, target check, SNN plot updates) ...

#         # Limit KF history length (inside the history limiting block)
#         if len(time_points) > max_history: # Use the defined max_history
#             # ... (pop from history lists) ...
#             kf_sliding_x_corr_history.pop(0) # Add these
#             kf_sliding_y_corr_history.pop(0) # Add these
#             if kf_pred_vx_history: kf_pred_vx_history.pop(0) # Check if not empty
#             if kf_pred_vy_history: kf_pred_vy_history.pop(0)

#         # Initialize step_loss before the if/else block as a safeguard
#         step_loss = 0.0 

#         # Gate learning based on phase (Simplified: learning is always ON for these two phases)
#         if current_phase in [PHASE_INITIAL_LEARNING, PHASE_ADAPT_TO_DISRUPTION]:
#             # This print statement can be removed if too verbose, or kept for clarity
#             # print(f"Applying Advanced Hebbian + Meta weight updates (learning ON in {current_phase_name})...")
#             loss_value, x_corr, y_corr = rl_updater.update(input_tensor, pred_velocity_snn_tensor, desired_vel_tensor) 
#             step_loss = loss_value
#             if total_steps % 50 == 0: # Log loss less frequently
#                  print(f"  Step {total_steps} Learning ON - Loss: {step_loss:.4f}, SNN Corr X/Y: {x_corr:.3f}/{y_corr:.3f}")
#         else:
#             # This else block should ideally not be reached if logic is correct
#             print(f"WARNING: Unexpected phase {current_phase_name} encountered for learning gate. Defaulting to no learning.")
#             with torch.no_grad():
#                 step_loss = F.mse_loss(pred_velocity_snn_tensor, desired_vel_tensor).item()
        
#         pred_velocity_np = pred_velocity_snn_np

#         # Store metrics for the current phase
#         phase_metrics[current_phase]['mse'].append(step_loss) # Use step_loss
#         phase_metrics[current_phase]['distance'].append(distance)

#         # ---> UPDATE SLIDING WINDOW BUFFERS <---
#         pred_vx_buffer.append(pred_velocity_snn_np[0]) # Use SNN prediction
#         pred_vy_buffer.append(pred_velocity_snn_np[1]) # Use SNN prediction
#         des_vx_buffer.append(desired_vel_np[0])
#         des_vy_buffer.append(desired_vel_np[1])

#         # ---> CALCULATE SLIDING WINDOW CORRELATION (Lagged and Less Frequent) <---
#         current_lagged_x_corr = 0.0 # Default value
#         current_lagged_y_corr = 0.0 # Default value

#         # Only calculate every 10 steps and if buffers are full enough
#         if total_steps % 10 == 0 and len(des_vx_buffer) == des_buffer_size:
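#             # Copy the whole prediction buffer; it is expected to hold exactly the latest
#             # corr_calc_window samples (the length is verified before correlating below)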
#             pred_window_x = list(pred_vx_buffer)
#             pred_window_y = list(pred_vy_buffer)
            
#             # Extract the desired window from L steps ago (elements 0 to W-1)
#             des_lagged_window_x = list(des_vx_buffer)[0:corr_calc_window]
#             des_lagged_window_y = list(des_vy_buffer)[0:corr_calc_window]
            
#             # Compute correlation if windows are valid
#             if len(pred_window_x) == corr_calc_window and len(des_lagged_window_x) == corr_calc_window:
#                 current_lagged_x_corr = compute_correlation(np.array(pred_window_x), np.array(des_lagged_window_x))
#                 current_lagged_y_corr = compute_correlation(np.array(pred_window_y), np.array(des_lagged_window_y))
#             else:
#                  # Should not happen if buffer logic is correct, but safety check
#                  current_lagged_x_corr = last_calc_x_corr 
#                  current_lagged_y_corr = last_calc_y_corr

#             # Update the last calculated values
#             last_calc_x_corr = current_lagged_x_corr
#             last_calc_y_corr = current_lagged_y_corr
#         else:
#              # On steps where we don't calculate, use the previous value
#              current_lagged_x_corr = last_calc_x_corr
#              current_lagged_y_corr = last_calc_y_corr
        
#         # Append the current value (either newly calculated or the last held value) to history
#         sliding_x_corr_history.append(last_calc_x_corr)
#         sliding_y_corr_history.append(last_calc_y_corr)
        
#         # Print the *calculated* value when it happens
#         if total_steps % 10 == 0 and len(des_vx_buffer) == des_buffer_size:
#              print(f"Lagged Sliding Window Corr (N={corr_calc_window}, Lag={lag}, Step {total_steps}): X={last_calc_x_corr:.4f}, Y={last_calc_y_corr:.4f}")
#         # ---> END LAGGED/INFREQUENT CORRELATION CALCULATION <---

#         # Updater status logging
#         if rl_updater.step_counter == 0 and total_steps > rl_updater.window_size: # Check if slow update was just applied
#             # print(f"📢 SLOW UPDATE APPLIED! Meta params: Plasticity={rl_updater.meta_params['plasticity']:.3f}, Sensitivity={rl_updater.meta_params['sensitivity']:.3f}") # Moved to periodic log
#             pass # Actual print moved

#         # print(f"Current update window: {rl_updater.step_counter}/{rl_updater.window_size} ({rl_updater.step_counter/rl_updater.window_size*100:.1f}%)") # Commented out

#         # Monitor weight norms (less frequent)
#         # if total_steps % 50 == 0: # Log less frequently
#         #    with torch.no_grad():
#         #        fc1_norm = torch.norm(model.fc1.weight.data).item()
#         #        fc2_norm = torch.norm(model.fc2.weight.data).item()
#         #        fc3_norm = torch.norm(model.fc3.weight.data).item()
#         #        weight_norms['fc1'].append(fc1_norm)
#         #        weight_norms['fc2'].append(fc2_norm)
#         #        weight_norms['fc3'].append(fc3_norm)
#         #        print(f"  Weight norms - FC1: {fc1_norm:.4f} | FC2: {fc2_norm:.4f} | FC3: {fc3_norm:.4f}")


#         prev_pos = cursor_pos.copy()
#         cursor_pos = cursor_pos + pred_velocity_snn_np * simulation_lr_scale
#         # movement = np.linalg.norm(cursor_pos - prev_pos) # Commented out
#         # print(f"Cursor moved: {movement:.2f} pixels") # Commented out
        
#         # Boundary check
#         if cursor_pos[0] < 0 or cursor_pos[0] > screen_width or cursor_pos[1] < 0 or cursor_pos[1] > screen_height:
#             # print("WARNING: Cursor went out of bounds! Correcting position.")
#             cursor_pos[0] = np.clip(cursor_pos[0], 0, screen_width)
#             cursor_pos[1] = np.clip(cursor_pos[1], 0, screen_height)
        
#         # --- Target Check (MODIFIED FOR INSTANTANEOUS CENTER RESET + TIME LIMIT) ---
#         current_reach_time = time.time() - target_start_time
#         timed_out = current_reach_time > MAX_REACH_TIME_SECONDS

#         if distance < target_radius or timed_out: 
#             if not timed_out: 
#                 total_targets_reached += 1 
#                 targets_reached_in_phase_current_run = 0
#                 if current_phase == PHASE_INITIAL_LEARNING:
#                     targets_reached_in_phase_current_run = total_targets_reached
#                 elif current_phase == PHASE_ADAPT_TO_DISRUPTION:
#                     targets_in_current_phase += 1 
#                     targets_reached_in_phase_current_run = targets_in_current_phase
                
#                 target_reach_durations.append(current_reach_time)
#                 print(f"\n✅ TARGET REACHED! (Overall: {total_targets_reached}, In Phase '{current_phase_name}': {targets_reached_in_phase_current_run}) Time: {current_reach_time:.2f}s. Step: {total_steps}")
#             else: 
#                 print(f"\n❌ TARGET TIMED OUT! (> {MAX_REACH_TIME_SECONDS:.1f}s). Step: {total_steps}. (Overall Reaches: {total_targets_reached})")
#                 target_reach_durations.append(MAX_REACH_TIME_SECONDS)

#             target_reach_steps.append(total_steps)
            
#             # INSTANTANEOUSLY RESET CURSOR TO CENTER
#             cursor_pos = CENTER_POS.copy()
#             print(f"Cursor RESET to center: ({cursor_pos[0]:.2f}, {cursor_pos[1]:.2f})")

#             # Set a new peripheral target starting from the center
#             if current_phase == PHASE_EVALUATION:
#                 print(f"Setting new EVALUATION target from center, index: {evaluation_target_count}")
#                 target_pos = get_target_for_phase(current_phase, evaluation_target_count, cursor_pos) # cursor_pos is CENTER_POS
#                 evaluation_target_count += 1
#             else: # Any phase other than PHASE_EVALUATION
#                 target_pos = get_target_for_phase(current_phase, phase_step, cursor_pos) # cursor_pos is CENTER_POS
            
#             target_start_time = time.time() # Reset timer for the new reach from center
#             new_peripheral_dist = np.linalg.norm(cursor_pos - target_pos) # Dist from center to new target
#             print(f"New peripheral target from center: ({target_pos[0]:.2f}, {target_pos[1]:.2f}). Dist: {new_peripheral_dist:.2f}")

#             # Update target visualization for the new peripheral target
#             target_plot.center = (target_pos[0], target_pos[1])
#             target_plot.set_radius(target_radius) # Ensure peripheral target radius is used
        
#         # --- BEGIN 6-PANEL PLOT UPDATE LOGIC ---
#         # Update History
#         time_points.append(total_steps)
#         cursor_x_history.append(cursor_pos[0])
#         cursor_y_history.append(cursor_pos[1])
#         desired_x_history.append(target_pos[0])
#         desired_y_history.append(target_pos[1])
#         pred_vx_history.append(pred_velocity_np[0])
#         pred_vy_history.append(pred_velocity_np[1])
#         desired_vx_history.append(desired_vel_np[0])
#         desired_vy_history.append(desired_vel_np[1])

#         # Limit history length
#         max_history = 1000
#         if len(time_points) > max_history:
#             time_points.pop(0)
#             cursor_x_history.pop(0)
#             cursor_y_history.pop(0)
#             desired_x_history.pop(0)
#             desired_y_history.pop(0)
#             pred_vx_history.pop(0)
#             pred_vy_history.pop(0)
#             desired_vx_history.pop(0)
#             desired_vy_history.pop(0)
#             sliding_x_corr_history.pop(0)
#             sliding_y_corr_history.pop(0)

#         # Update Visualization
#         cursor_plot.set_data([cursor_pos[0]], [cursor_pos[1]])
#         target_plot.center = (target_pos[0], target_pos[1])
#         cursor_trajectory.set_data(cursor_x_history, cursor_y_history)
        
#         # Update desired path line
#         desired_path_x = np.linspace(cursor_pos[0], target_pos[0], 20)
#         desired_path_y = np.linspace(cursor_pos[1], target_pos[1], 20)
#         desired_trajectory.set_data(desired_path_x, desired_path_y)
        
#         # Update position plots
#         pos_x_actual.set_data(time_points, cursor_x_history)
#         pos_x_desired.set_data(time_points, desired_x_history)
#         pos_y_actual.set_data(time_points, cursor_y_history)
#         pos_y_desired.set_data(time_points, desired_y_history)
        
#         # Update velocity plots
#         vel_x_pred.set_data(time_points, pred_vx_history)
#         vel_x_desired.set_data(time_points, desired_vx_history)
#         vel_y_pred.set_data(time_points, pred_vy_history)
#         vel_y_desired.set_data(time_points, desired_vy_history)
        
#         # Update correlation plots (NOW SLIDING WINDOW)
#         sliding_corr_x_plot.set_data(time_points, sliding_x_corr_history)
#         sliding_corr_y_plot.set_data(time_points, sliding_y_corr_history)
        
#         # Clear and redraw velocity arrows on trajectory plot
#         for collection in ax_traj.collections[:]:
#             if isinstance(collection, matplotlib.quiver.Quiver): # Check type correctly
#                 collection.remove()
#         ax_traj.quiver([cursor_pos[0]], [cursor_pos[1]], [pred_velocity_np[0]], [pred_velocity_np[1]],
#                       color='m', scale=20, width=0.008, headwidth=5, headlength=7, label='Predicted Vel' if total_steps==1 else "")
#         if total_steps == 1: ax_traj.legend(loc='upper right') # Add legend only once
                
#         # Auto-adjust axes limits for time series plots
#         for ax in [ax_posx, ax_posy, ax_velx, ax_vely, ax_corr]:
#             ax.relim()
#             ax.autoscale_view()
#         # ---> Ensure ax_corr maintains [-1.1, 1.1] y-limits <---
#         ax_corr.set_ylim(-1.1, 1.1) 
        
#         # Update plot title
#         update_status = "SLOW" if rl_updater.step_counter == 0 else f"FAST ({rl_updater.step_counter}/{rl_updater.window_size})"
#         ax_traj.set_title(f"Targets: {total_targets_reached} | Steps: {total_steps} | Loss: {step_loss:.4f} | {update_status}")
        
#         # Redraw the figure
#         # fig.canvas.draw() # COMMENTED OUT FOR NO LIVE PLOTTING
#         # fig.canvas.flush_events() # COMMENTED OUT FOR NO LIVE PLOTTING
#         # --- END 6-PANEL PLOT UPDATE LOGIC ---
        
#         step_time = time.time() - step_start_time
#         elapsed_time = time.time() - start_time
#         # print(f"Step processing time: {step_time*1000:.1f}ms | Total elapsed time: {elapsed_time:.1f}s")
#         # print(f"Simulation speed: {1.0/(step_time+1e-9):.1f} steps/second") # Add epsilon
#         # print("---------------------------------")
        
#         time.sleep(0.01) # Reduced sleep

#         # --- PERIODIC SUMMARY LOG ---
#         # Log a summary when a reach attempt (success or timeout) ended on this step and the
#         # overall success count is a multiple of 20 (0 included), OR on every 500th step.
#         if len(target_reach_steps) > 0 and (target_reach_steps[-1] == total_steps and total_targets_reached % 20 == 0) or (total_steps % 500 == 0) :
#             print(f"\n--- Summary @ Step {total_steps} (Phase: {current_phase_name}) ---")
#             print(f"  Total Targets Reached (Overall): {total_targets_reached}")
#             if current_phase == PHASE_INITIAL_LEARNING:
#                  print(f"  Targets for current phase ({current_phase_name}): {total_targets_reached} / {INITIAL_LEARNING_TARGETS}")
#             elif current_phase == PHASE_ADAPT_TO_DISRUPTION:
#                  print(f"  Targets for current phase ({current_phase_name}): {targets_in_current_phase} / {ADAPT_TO_DISRUPTION_TARGETS}")

#             if len(target_reach_durations) > 5: # Need a few reaches for a meaningful average
#                 avg_reach_time_recent = np.mean(target_reach_durations[-10:]) # Avg of last 10
#                 print(f"  Avg Reach Time (last 10 actual): {avg_reach_time_recent:.2f}s")
            
#             print(f"  SNN Sliding Corr (X/Y): {last_calc_x_corr:.3f} / {last_calc_y_corr:.3f}")
#             if kf_decoder is not None:
#                 print(f"  KF Sliding Corr  (X/Y): {kf_last_calc_x_corr:.3f} / {kf_last_calc_y_corr:.3f}")

#             print(f"  RL Updater Meta: Plasticity={rl_updater.meta_params['plasticity']:.3f}, Sensitivity={rl_updater.meta_params['sensitivity']:.3f}")
#             if rl_updater.step_counter == 0 and total_steps > rl_updater.window_size : print("    (Slow update was just applied)")
#             print(f"------------------------------------")


#         step_time = time.time() - step_start_time
#         # elapsed_time = time.time() - start_time # Can be part of periodic log
#         # print(f"Step processing time: {step_time*1000:.1f}ms | Total elapsed time: {elapsed_time:.1f}s") # Commented
#         # print(f"Simulation speed: {1.0/(step_time+1e-9):.1f} steps/second") # Commented
#         # print("---------------------------------") # Commented out, part of new summary
        
#         # Brief sleep if not plotting, to yield CPU slightly and allow interruption
#         if total_steps % 100 == 0: # Sleep very infrequently
#              time.sleep(0.001) 

# except (KeyboardInterrupt, Exception) as e:
#     print("\n\n========== SIMULATION ENDED ==========")
#     if isinstance(e, KeyboardInterrupt):
#         print("Interrupted by user.")
#     else:
#         print(f"Exception: {type(e).__name__}: {e}")
#         import traceback
#         traceback.print_exc()

#     print("\nFINAL STATISTICS:")
#     print(f"Total steps: {total_steps}")
#     print(f"Total targets reached: {total_targets_reached}")
#     # Avg error calculation might need adjustment depending on how loss is tracked/defined now
#     # avg_error = rl_updater.cumulative_error / max(1, rl_updater.step_counter) # Example using updater state
#     # print(f"Average error in last window: {avg_error:.6f}") # Be specific about what error means
#     print(f"Final Meta parameters: Plasticity={rl_updater.meta_params['plasticity']:.3f}, Sensitivity={rl_updater.meta_params['sensitivity']:.3f}")
#     # ---> Add final sliding correlation print <---
#     if sliding_x_corr_history:
#         print(f"Final Sliding Window Correlation - X: {sliding_x_corr_history[-1]:.4f} | Y: {sliding_y_corr_history[-1]:.4f}")
#     # ---> END Add final sliding correlation print <---
#     if weight_norms['fc1']: # Check if list is not empty
#         print(f"Final weight norms - FC1: {weight_norms['fc1'][-1]:.4f} | FC2: {weight_norms['fc2'][-1]:.4f} | FC3: {weight_norms['fc3'][-1]:.4f}")
#     print(f"Total runtime: {time.time() - start_time:.1f} seconds")
    
#     # Generate plots of experimental results
#     print("Generating final experiment plots...")
#     plot_experiment_results(phase_metrics, disruption_type=DISRUPTION_TYPE, current_results_dir=RESULTS_DIR)
#     # Save trajectory plots
#     save_trajectory_plots(current_results_dir=RESULTS_DIR,
#                          time_points_data=time_points,
#                          cursor_x_history_data=cursor_x_history,
#                          cursor_y_history_data=cursor_y_history,
#                          desired_x_history_data=desired_x_history,
#                          desired_y_history_data=desired_y_history,
#                          pred_vx_history_data=pred_vx_history,
#                          pred_vy_history_data=pred_vy_history,
#                          desired_vx_history_data=desired_vx_history,
#                          desired_vy_history_data=desired_vy_history,
#                          sliding_x_corr_history_data=sliding_x_corr_history,
#                          sliding_y_corr_history_data=sliding_y_corr_history,
#                          target_reach_steps_data=target_reach_steps,
#                          target_reach_durations_data=target_reach_durations,
#                          initial_learning_completion_step_data=initial_learning_completion_step_global,
#                          phase_metrics_data=phase_metrics,
#                          corr_calc_window_val=corr_calc_window,
#                          lag_val=lag,
#                          screen_width_val=screen_width,
#                          screen_height_val=screen_height,
#                          phase_names_val=phase_names)

# plt.ioff()
# # plt.show() # COMMENTED OUT FOR NO LIVE PLOTTING (figures saved at end)


# # Removed the __main__ block as the simulation runs directly now