import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from collections import deque, defaultdict
import random
from filterpy.kalman import KalmanFilter as FilterPyKalmanFilter
from scipy.linalg import inv

# --- Attempt to import from local modules (excluding train_single_gpu) ---
try:
    # Project-local building blocks: spike generator, LSTM decoder, Kalman/LSTM
    # training helpers, the correlation metric and the meta-RL weight updater.
    from bci_simulation import (
        Synthetic_Neuron,
        LSTMRegression,
        train_test_kalman_filter,
        train_lstm_model,
        compute_correlation,
        
    )
    from bci_simulation import TwoScaleMetaRLWeightUpdaterFull
    print("Successfully imported components from bci_simulation.py")
except ImportError as e:
    print(f"Warning: Could not import from bci_simulation.py: {e}")
    # Define dummy classes/functions if import fails.
    # These placeholders only keep the names defined so this module can still be
    # imported; they are not functional (e.g. the stub classes take no
    # constructor arguments, and compute_correlation always returns 0.0).
    class Synthetic_Neuron: pass
    class LSTMRegression(nn.Module): pass
    def train_test_kalman_filter(*args, **kwargs): return None, None
    def train_lstm_model(*args, **kwargs): return None, [], []
    def compute_correlation(*args, **kwargs): return 0.0
    class TwoScaleMetaRLWeightUpdaterFull: pass

# --- Import ZENODO classes with renamed imports to avoid conflicts ---
try:
    # Imported under *Zenodo aliases so they do not clash with the SNNRegression
    # class defined in this file or with the bci_simulation updater of the same name.
    from ZENODO_SCRIPT2 import (
        SNNRegression as SNNRegressionZenodo,
        TwoScaleMetaRLWeightUpdaterFull as TwoScaleMetaRLWeightUpdaterFullZenodo
    )
    print("Successfully imported ZENODO components from ZENODO_SCRIPT2.py")
except ImportError as e:
    print(f"Warning: Could not import from ZENODO_SCRIPT2.py: {e}")
    # Define dummy classes if import fails (placeholders only; not functional).
    class SNNRegressionZenodo: pass
    class TwoScaleMetaRLWeightUpdaterFullZenodo: pass

import math

# --- SNNRegression class (copied from train_dist_old_zenodo.py or bci_simulation.py) ---
import snntorch as snn
from snntorch import surrogate

class SNNRegression(nn.Module):
    """Spiking regression network built from snnTorch ``Leaky`` neurons.

    Layout: input LIF -> fc1 -> LIF -> fc2 -> LIF -> fc3 -> readout LIF.
    The readout layer is created with an infinite threshold and no reset, so
    its membrane potential is what ``forward`` returns at every time step.

    Membrane potentials are kept on the instance (``mem_in``, ``mem1``,
    ``mem2``, ``mem3``) and carried over between ``forward`` calls until
    ``reset_states`` is called.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Width of the first hidden layer; the second hidden layer
            has ``hidden_size // 2`` units.
        output_size: Number of regression outputs.
        input_beta: ``beta`` passed to the input ``Leaky`` layer.
        use_mem: If True, ``fc1`` is driven by the input layer's membrane
            potential instead of its spikes.
    """

    def __init__(self, input_size=182, hidden_size=128, output_size=2, input_beta=0.9, use_mem=False):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # instance; assigning first is unsupported by PyTorch and breaks as soon
        # as the attribute is a parameter, buffer or sub-module.
        super().__init__()
        self.use_mem = use_mem
        spike_grad = surrogate.fast_sigmoid()

        self.lif_input = snn.Leaky(
            beta=input_beta,
            spike_grad=spike_grad,
            init_hidden=False
        )
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=False)
        self.fc3 = nn.Linear(hidden_size // 2, output_size)
        # Readout layer: never fires and never resets, so it only integrates.
        self.lif3 = snn.Leaky(
            beta=0.9,
            spike_grad=spike_grad,
            threshold=float('inf'),
            reset_mechanism="none",
            init_hidden=False
        )
        self.apply(self._init_weights)

        # Membrane potentials; created lazily on the first forward pass.
        self.mem_in = None
        self.mem1 = None
        self.mem2 = None
        self.mem3 = None

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def reset_states(self):
        """Replace all stored membrane potentials with each layer's ``reset_mem()`` result."""
        self.mem_in = self.lif_input.reset_mem()
        self.mem1 = self.lif1.reset_mem()
        self.mem2 = self.lif2.reset_mem()
        self.mem3 = self.lif3.reset_mem()

    def forward(self, x):
        """Run the network over a sequence.

        Args:
            x: Tensor unpacked as ``(batch, T, features)``; each time step is
                cast to float before entering the input layer.

        Returns:
            Tensor of readout membrane potentials stacked along dim 1, i.e.
            one ``fc3``-sized vector per sample per time step.
        """
        _, num_steps, _ = x.size()
        device = x.device
        # (Re)create the states when missing or living on another device.
        if self.mem_in is None or self.mem_in.device != device:
            self.reset_states()
            self.mem_in = self.mem_in.to(device)
            self.mem1 = self.mem1.to(device)
            self.mem2 = self.mem2.to(device)
            self.mem3 = self.mem3.to(device)

        outputs = []
        for t in range(num_steps):
            # x already lives on `device`, so only the dtype cast is needed.
            inp = x[:, t, :].float()
            spk_in, self.mem_in = self.lif_input(inp, self.mem_in)
            cur1 = self.fc1(self.mem_in if self.use_mem else spk_in)
            spk1, self.mem1 = self.lif1(cur1, self.mem1)
            cur2 = self.fc2(spk1)
            spk2, self.mem2 = self.lif2(cur2, self.mem2)
            cur3 = self.fc3(spk2)
            _, self.mem3 = self.lif3(cur3, self.mem3)
            outputs.append(self.mem3)
        return torch.stack(outputs, dim=1)

# --- evaluate_model function (copied from train_dist_old_zenodo.py) ---
def evaluate_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader``; return loss and per-axis correlation.

    The feature dimension of every batch is zero-padded or truncated to the
    width the model expects (``model.fc1.in_features`` or
    ``model.lstm.input_size``). Models exposing neither attribute receive the
    batch unchanged. Loss and correlations use only the final time step of the
    model output; a 2-D model output is used as-is.

    Args:
        model: Module to evaluate. ``reset_states()`` is called before each
            batch when the model provides it.
        data_loader: Iterable of ``(batch_spike, batch_target)`` tensor pairs;
            ``batch_spike`` must be 3-D.
        criterion: Callable ``criterion(prediction, target)`` returning a
            scalar tensor.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, corr_x, corr_y)``: the mean of the per-batch losses
        and ``compute_correlation`` applied to output/target column 0 and
        column 1. If the loader yields no batches, ``(nan, 0.0, 0.0)`` is
        returned instead of raising.
    """
    model.eval()

    # The expected input width depends only on the model, so resolve it once.
    if hasattr(model, 'fc1'):
        expected_input_size = model.fc1.in_features
    elif hasattr(model, 'lstm'):
        expected_input_size = model.lstm.input_size
    else:
        print("Warning: evaluate_model cannot determine expected input size for this model type.")
        expected_input_size = None  # Fallback: assume no padding needed

    total_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_targets = []
    with torch.no_grad():
        for batch_spike, batch_target in data_loader:
            batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)

            batch_size_dyn, seq_len_dyn, num_features_dyn = batch_spike.shape
            if expected_input_size is None:
                padding_size = 0
            else:
                padding_size = expected_input_size - num_features_dyn

            if padding_size > 0:
                padding_tensor = torch.zeros(batch_size_dyn, seq_len_dyn, padding_size, device=device)
                batch_spike_padded = torch.cat([batch_spike, padding_tensor], dim=2)
            elif padding_size < 0:
                print(f"Warning in evaluate_model: Input features ({num_features_dyn}) > model expected ({expected_input_size}). Truncating.")
                batch_spike_padded = batch_spike[:, :, :expected_input_size]
            else:
                batch_spike_padded = batch_spike

            if hasattr(model, 'reset_states'): model.reset_states()
            outputs = model(batch_spike_padded)
            # Only the last time step is scored; the target belongs to that step.
            final_step = outputs[:, -1, :] if outputs.dim() == 3 else outputs
            loss = criterion(final_step, batch_target)
            total_loss += loss.item()
            num_batches += 1
            # Keep just the scored step instead of the whole output sequence.
            all_outputs.append(final_step.cpu().numpy())
            all_targets.append(batch_target.cpu().numpy())

    if num_batches == 0:
        print("Warning in evaluate_model: data_loader produced no batches.")
        return float('nan'), 0.0, 0.0

    avg_loss = total_loss / num_batches
    all_outputs = np.concatenate(all_outputs, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    if all_targets.ndim == 3:
        all_targets = all_targets[:, -1, :]

    corr_x = compute_correlation(all_outputs[:, 0], all_targets[:, 0])
    corr_y = compute_correlation(all_outputs[:, 1], all_targets[:, 1])
    return avg_loss, corr_x, corr_y

from torch.cuda.amp import GradScaler, autocast

# --- Experiment Phase Definitions (New for this script) ---
# Phase identifiers. plot_phased_reach_time_comparison uses them as dictionary
# keys to look up each decoder's per-phase list of reach step counts, and plots
# the phases in this order (online training/collection first, evaluation second).
PHASE_ONLINE_TRAINING_COLLECTION = "ONLINE_TRAINING_COLLECTION"
PHASE_POST_TRAINING_EVALUATION = "POST_TRAINING_EVALUATION"


# Disruption functions are not used in this "no pretrain" script.

def train_bptt_snn_local(train_dataset, val_dataset, 
                        input_size, hidden_size, output_size,
                        epochs=200, batch_size=32768, device=torch.device("cpu"),
                        lr=1e-3,
                        checkpoint_path="snn_bptt_local_temp.pth",
                        patience=5, model_to_train=None):
    """Train an SNN with BPTT on (spike sequence, target velocity) samples.

    The model is optimised with Adam (weight decay 1e-5) against an MSE loss
    on the final time step of its output. When a non-empty validation set is
    given, the state with the lowest validation loss is kept and training
    stops after ``patience`` consecutive epochs without improvement; otherwise
    the state after the last completed epoch is kept.

    Args:
        train_dataset: Dataset yielding ``(spike_seq, target)`` pairs. If it is
            empty or None, no training happens.
        val_dataset: Optional validation dataset of the same form.
        input_size, hidden_size, output_size: Dimensions used to build a new
            ``SNNRegression`` when ``model_to_train`` is None.
        epochs: Maximum number of epochs.
        batch_size: Batch size for both loaders.
        device: Device to train on; mixed precision and pinned memory are only
            enabled for CUDA devices.
        lr: Adam learning rate.
        checkpoint_path: Currently unused; nothing is written to disk.
        patience: Early-stopping patience in epochs (validation only).
        model_to_train: Optional pre-built model; it is trained in place.

    Returns:
        Tuple ``(state_dict, history_df)``. ``state_dict`` is an independent
        snapshot of the selected weights (``model_to_train.state_dict()`` or
        None when there is no training data). ``history_df`` has one row per
        epoch with columns epoch, train_loss, val_loss, val_corr_x, val_corr_y.
    """
    import copy  # local import: only needed for state-dict snapshots

    print(f"Starting BPTT SNN training (for online data) on {device}...")
    
    if not train_dataset or len(train_dataset) == 0:
        print("  No training data provided for BPTT SNN. Returning initial model state.")
        if model_to_train is not None: return model_to_train.state_dict(), pd.DataFrame()
        return None, pd.DataFrame()

    use_cuda = torch.device(device).type == "cuda"
    num_workers = min(4, os.cpu_count() or 1)
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,  # pinning only helps host-to-GPU copies
        persistent_workers=num_workers > 0,
        prefetch_factor=2 if num_workers > 0 else None,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = None
    if val_dataset and len(val_dataset) > 0:
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    has_validation = val_loader is not None

    if model_to_train is not None:
        model = model_to_train.to(device)
    else:
        model = SNNRegression(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # AMP is a CUDA feature; on CPU both objects become no-ops.
    scaler = GradScaler(enabled=use_cuda)
    criterion = nn.MSELoss()

    history = []
    best_val_loss = float('inf')
    # state_dict() returns references to the live parameters, so a snapshot must
    # be deep-copied or later optimiser steps would silently overwrite it.
    # Start with current model state as best, especially if it's pre-initialized randomly
    best_model_state_dict = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0
    best_epoch = 0

    print(f"  Model (SNN_BPTT for online data): input={input_size}, hidden={hidden_size}, output={output_size}")

    try:
        for epoch in range(epochs):
            model.train()
            epoch_train_loss = 0.0
            for batch_spike, batch_target in train_loader:
                batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)
                
                # Pad with zeros / truncate the feature axis to the model's input width.
                batch_size_dyn, seq_len_dyn, num_features_dyn = batch_spike.shape
                expected_input_size = model.fc1.in_features
                padding_size = expected_input_size - num_features_dyn
                if padding_size > 0:
                    padding_tensor = torch.zeros(batch_size_dyn, seq_len_dyn, padding_size, device=device)
                    batch_spike_padded = torch.cat([batch_spike, padding_tensor], dim=2)
                elif padding_size < 0:
                    batch_spike_padded = batch_spike[:, :, :expected_input_size]
                else:
                    batch_spike_padded = batch_spike
                
                if hasattr(model, 'reset_states'): model.reset_states()
                optimizer.zero_grad()
                with autocast(enabled=use_cuda):
                    outputs = model(batch_spike_padded)
                    loss = criterion(outputs[:, -1, :], batch_target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_train_loss += loss.item()
            
            avg_train_loss = epoch_train_loss / max(1, len(train_loader))
            
            current_val_loss = avg_train_loss
            val_x_corr, val_y_corr = 0.0, 0.0

            if has_validation:
                current_val_loss, val_x_corr, val_y_corr = evaluate_model(model, val_loader, criterion, device)
                log_msg = f"  Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {current_val_loss:.4f}, Val Corr X: {val_x_corr:.4f}, Val Corr Y: {val_y_corr:.4f}"
            else:
                log_msg = f"  Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f} (No validation)"
            print(log_msg)

            history.append({
                'epoch': epoch + 1,
                'train_loss': avg_train_loss,
                'val_loss': current_val_loss if has_validation else None,
                'val_corr_x': val_x_corr if has_validation else None,
                'val_corr_y': val_y_corr if has_validation else None
            })

            if has_validation:
                if current_val_loss < best_val_loss:
                    best_val_loss = current_val_loss
                    best_model_state_dict = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                    best_epoch = epoch + 1
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break
            else:
                # Without validation the most recent epoch is the one we keep.
                best_model_state_dict = copy.deepcopy(model.state_dict())
                best_epoch = epoch + 1

    except KeyboardInterrupt:
        # Deliberate: an interrupt ends training early but still returns the best state so far.
        print("\nTraining interrupted by user.")
    finally:
        print("BPTT SNN training (for online data) finished.")
        if best_epoch > 0:
            print(f"Using model state from epoch {best_epoch}.")

    history_df = pd.DataFrame(history)
    return best_model_state_dict, history_df


class SpikeVelDataset(Dataset):
    """Sliding-window dataset pairing spike sequences with a target velocity.

    The input lists are filtered, concatenated along the time axis and exposed
    as overlapping windows: sample ``i`` is the spike rows
    ``i .. i + sequence_length - 1`` together with the velocity row at the
    window's last step.

    Args:
        spike_data_list_of_arrays: List of NumPy arrays, each either
            ``(steps, num_neurons_expected)`` or a single step of shape
            ``(num_neurons_expected,)``. ``None`` entries and arrays of any
            other shape are dropped.
        vel_data_list_of_arrays: List of NumPy arrays, each ``(steps, 2)`` or
            ``(2,)``; filtered the same way.
        sequence_length: Number of consecutive steps per sample.
        num_neurons_expected: Required width of every spike array.

    The dataset is empty (length 0) when either list is empty, when nothing
    survives filtering, when the two lists keep a different number of arrays,
    or when concatenation fails.
    """

    def __init__(self, spike_data_list_of_arrays, vel_data_list_of_arrays, sequence_length, num_neurons_expected):
        # Initialise every attribute up front so that all early-return paths
        # below leave a fully-formed (empty) dataset behind.
        self.sequence_length = sequence_length
        self.spike_data = torch.empty(0, num_neurons_expected, dtype=torch.float32)
        self.vel_data = torch.empty(0, 2, dtype=torch.float32)
        self.num_samples = 0

        if not spike_data_list_of_arrays or not vel_data_list_of_arrays:
            return

        # Filter out empty arrays and ensure correct shapes before concatenation
        filtered_spikes = self._keep_valid(spike_data_list_of_arrays, num_neurons_expected)
        filtered_vels = self._keep_valid(vel_data_list_of_arrays, 2)

        if not filtered_spikes or not filtered_vels or len(filtered_spikes) != len(filtered_vels):
            return

        try:
            concatenated_spikes = np.concatenate(filtered_spikes, axis=0)
            concatenated_vels = np.concatenate(filtered_vels, axis=0)
        except ValueError as e:
            print(f"Error concatenating data for SpikeVelDataset: {e}")
            return

        if len(concatenated_spikes) != len(concatenated_vels):
            print(f"Warning: SpikeVelDataset has {len(concatenated_spikes)} spike steps but "
                  f"{len(concatenated_vels)} velocity steps; using the common prefix.")

        self.spike_data = torch.tensor(concatenated_spikes, dtype=torch.float32)
        self.vel_data = torch.tensor(concatenated_vels, dtype=torch.float32)
        # Every window needs a velocity row at its last step, so the shorter
        # of the two streams bounds the number of usable windows.
        usable_steps = min(len(self.spike_data), len(self.vel_data))
        self.num_samples = max(0, usable_steps - self.sequence_length + 1)

    @staticmethod
    def _keep_valid(arrays, width):
        """Return the arrays of shape ``(steps > 0, width)``, promoting ``(width,)`` to one row."""
        kept = []
        for arr in arrays:
            if arr is None:
                continue
            if arr.ndim == 2 and arr.shape[0] > 0 and arr.shape[1] == width:
                kept.append(arr)
            elif arr.ndim == 1 and arr.shape[0] == width:  # single step recorded
                kept.append(arr.reshape(1, -1))
        return kept

    def __len__(self):
        """Number of sliding windows available."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(spikes_seq, target_vel)`` for window ``idx``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError("Index out of bounds")
        window_end = idx + self.sequence_length
        spikes_seq = self.spike_data[idx:window_end]
        target_vel = self.vel_data[window_end - 1]
        return spikes_seq, target_vel


# --- Simulation Parameters for Reach Task (GLOBAL CONSTANTS) ---
# Workspace bounds: the simulated cursor is clipped to [0, SCREEN_WIDTH] x [0, SCREEN_HEIGHT].
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600
# A reach succeeds once the cursor-to-target distance drops below this radius.
TARGET_RADIUS = 50
# NOTE(review): presumably the value passed as movement_scale_abs to
# simulate_single_reach_attempt (cursor displacement per unit of decoded velocity) - confirm at the call site.
MOVEMENT_SCALE = 5 
# Centre of the workspace, as a float32 (x, y) array.
CENTER_POS = np.array([SCREEN_WIDTH/2, SCREEN_HEIGHT/2], dtype=np.float32)
# Also used by the plotting code as the cap for a reach time (steps * 0.01 s).
MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL = 300 # 3 seconds limit (300 steps * 10ms/step)

# --- Function to Simulate a Single Reach Attempt ---
def _zero_snn_online_states(model, device, batch_size=1):
    """Return fresh all-zero (spk1_rec, mem1, mem2, mem3) states sized from the model's fc layers."""
    return (
        torch.zeros(batch_size, model.fc1.out_features, device=device),  # spk1_rec
        torch.zeros(batch_size, model.fc1.out_features, device=device),  # mem1
        torch.zeros(batch_size, model.fc2.out_features, device=device),  # mem2
        torch.zeros(batch_size, model.fc3.out_features, device=device),  # mem3
    )


def simulate_single_reach_attempt(
    decoder_name, model_object_or_tuple, 
    spike_generator, device,
    target_pos_abs, initial_cursor_pos_abs,
    movement_scale_abs,
    max_steps_this_reach,
    input_size_snn_val, num_neurons_val,
    is_snn_online_learning_active,
    collect_trajectory_data=False 
):
    """Simulate one closed-loop reach of the cursor towards a target.

    At every step the ideal velocity (unit vector towards the target, scaled
    by ``min(1, distance / 200)``) is turned into spikes by
    ``spike_generator``, the selected decoder maps the spikes to a velocity,
    the velocity is clipped to [-1, 1] per axis and the cursor is moved by
    ``velocity * movement_scale_abs`` (clipped to the screen). The reach
    succeeds as soon as the cursor, after its move, is closer to the target
    than ``TARGET_RADIUS``.

    Args:
        decoder_name: One of 'KF', 'LSTM', 'SNN_BPTT', 'SNN_Online'. Any other
            name leaves the decoded velocity at zero.
        model_object_or_tuple: The decoder object; for 'SNN_Online' a
            ``(model, updater)`` tuple.
        spike_generator: Object whose ``generate_spikes(velocity,
            sequence_length=1)`` returns the spikes for one step
            (assumed to be a ``(1, 1, num_neurons)`` tensor - TODO confirm).
        device: Torch device for tensors created here.
        target_pos_abs: Target position, array of length 2.
        initial_cursor_pos_abs: Start position (NumPy array; it is copied).
        movement_scale_abs: Cursor displacement per unit of decoded velocity.
        max_steps_this_reach: Step budget for this attempt.
        input_size_snn_val: Input width of the SNN decoders.
        num_neurons_val: Width of the generated spikes; the SNN inputs are
            zero-padded by ``input_size_snn_val - num_neurons_val`` columns.
        is_snn_online_learning_active: If True, 'SNN_Online' runs in train
            mode and the updater's ``update_single_timestep`` is called at
            every step.
        collect_trajectory_data: If True, also return the per-step spikes and
            ideal velocities.

    Returns:
        ``(steps_taken, succeeded)`` or, with ``collect_trajectory_data``,
        ``(steps_taken, succeeded, spikes_list, ideal_velocity_list)``.
        ``steps_taken`` equals ``max_steps_this_reach`` when the reach fails.
    """
    trajectory_spikes = []
    trajectory_ideal_velocities = [] 

    current_model_obj = model_object_or_tuple[0] if isinstance(model_object_or_tuple, tuple) else model_object_or_tuple
    
    # Everything except an actively-learning online SNN is run in eval mode.
    if not (decoder_name == 'SNN_Online' and is_snn_online_learning_active):
        if hasattr(current_model_obj, 'eval'): current_model_obj.eval()
    
    if decoder_name != "SNN_Online":
        if hasattr(current_model_obj, 'reset_states'):
            current_model_obj.reset_states()
    elif hasattr(current_model_obj, '_persistent_states'):
        # For ZENODO SNN_Online, reset persistent states before each reach attempt
        current_model_obj._persistent_states = _zero_snn_online_states(current_model_obj, device)

    cursor_pos = initial_cursor_pos_abs.copy()
    reach_succeeded = False
    steps_taken_final = max_steps_this_reach 

    for step_num in range(max_steps_this_reach):
        desired_vec_np = target_pos_abs - cursor_pos
        distance = np.linalg.norm(desired_vec_np)
        
        if distance > 1e-6:
            # Full speed far away, slowing down linearly inside 200 px.
            current_ideal_vel_np = desired_vec_np / distance * min(1.0, distance / 200.0) 
        else:
            current_ideal_vel_np = np.zeros(2)
        
        current_ideal_vel_tensor = torch.tensor(current_ideal_vel_np, dtype=torch.float32).unsqueeze(0).to(device)
        # Spike generator now always in training mode for this experiment for noisy interaction
        dynamically_generated_spikes = spike_generator.generate_spikes(current_ideal_vel_tensor.squeeze(0), sequence_length=1)

        if collect_trajectory_data:
            trajectory_spikes.append(dynamically_generated_spikes.squeeze().cpu().numpy())
            trajectory_ideal_velocities.append(current_ideal_vel_np)

        pred_vel_for_cursor = np.zeros(2)

        if decoder_name == 'KF':
            kf_filter = current_model_obj
            if kf_filter is not None and hasattr(kf_filter, 'H') and kf_filter.H is not None: 
                try:
                    z = dynamically_generated_spikes.squeeze().cpu().numpy().reshape(-1, 1)
                    kf_filter.predict()
                    kf_filter.update(z)
                    # Flatten so a (2, 1) state vector also works for the cursor update.
                    pred_vel_for_cursor = kf_filter.x.copy().reshape(-1)
                except Exception:
                    # Deliberate best effort: a failing filter step falls back to a small random velocity.
                    pred_vel_for_cursor = np.random.randn(2) * 0.1 
            else: 
                pred_vel_for_cursor = np.random.randn(2) * 0.1 

        elif decoder_name == 'LSTM':
            lstm_model = current_model_obj
            lstm_input_tensor = dynamically_generated_spikes.to(device) 
            # Match the spike width to lstm_model.lstm.input_size: zero-pad when
            # narrower, truncate when wider.
            missing_features = lstm_model.lstm.input_size - lstm_input_tensor.shape[2]
            if missing_features > 0:
                padding_lstm = torch.zeros(lstm_input_tensor.shape[0], lstm_input_tensor.shape[1], missing_features, device=device)
                lstm_input_tensor = torch.cat([lstm_input_tensor, padding_lstm], dim=2)
            elif missing_features < 0:
                lstm_input_tensor = lstm_input_tensor[:, :, :lstm_model.lstm.input_size]

            with torch.no_grad(): 
                lstm_output = lstm_model(lstm_input_tensor)
                pred_vel_for_cursor = lstm_output.squeeze().detach().cpu().numpy()

        elif decoder_name == 'SNN_BPTT':
            snn_bptt_model = current_model_obj
            snn_input_step_bptt = dynamically_generated_spikes.squeeze(0).to(device) 
            padding_bptt = torch.zeros((snn_input_step_bptt.size(0), input_size_snn_val - num_neurons_val), device=device)
            snn_input_padded_bptt = torch.cat([snn_input_step_bptt, padding_bptt], dim=1)
            snn_input_tensor_bptt = snn_input_padded_bptt.unsqueeze(0)
            with torch.no_grad(): 
                # No need to reset states again here, done at start of function
                snn_bptt_output_seq = snn_bptt_model(snn_input_tensor_bptt)
                pred_vel_for_cursor = snn_bptt_output_seq[:, -1, :].squeeze().detach().cpu().numpy()
        
        elif decoder_name == 'SNN_Online':
            online_model, online_updater = model_object_or_tuple # unpack
            snn_input_step_online = dynamically_generated_spikes.squeeze(0).to(device) 
            padding_online = torch.zeros((snn_input_step_online.size(0), input_size_snn_val - num_neurons_val), device=device)
            snn_input_padded_online = torch.cat([snn_input_step_online, padding_online], dim=1)
            snn_input_tensor_online = snn_input_padded_online.unsqueeze(0)

            target_tensor_online = torch.tensor(current_ideal_vel_np, dtype=torch.float32).unsqueeze(0).to(device)

            # Initialize states for ZENODO SNN if not already present
            if not hasattr(online_model, '_persistent_states'):
                online_model._persistent_states = _zero_snn_online_states(online_model, device)

            if is_snn_online_learning_active:
                with torch.enable_grad():
                    online_model.train()
                    # Use ZENODO forward signature with states
                    pred_output_sequence_online, online_model._persistent_states = online_model(
                        snn_input_tensor_online, *online_model._persistent_states, need_traces=False
                    )
                    pred_velocity_online_tensor = pred_output_sequence_online[:, -1, :]
                    pred_vel_for_cursor = pred_velocity_online_tensor.squeeze().detach().cpu().numpy()
                    
                    # Use ZENODO updater's single timestep update method
                    _loss, online_model._persistent_states = online_updater.update_single_timestep(
                        snn_input_padded_online.squeeze(0), target_tensor_online.squeeze(0), 
                        online_model._persistent_states
                    )
                online_model.eval()
            else: 
                with torch.no_grad():
                    pred_output_sequence_online, online_model._persistent_states = online_model(
                        snn_input_tensor_online, *online_model._persistent_states, need_traces=False
                    )
                    pred_vel_for_cursor = pred_output_sequence_online[:, -1, :].squeeze().detach().cpu().numpy()

        pred_vel_for_cursor = np.clip(pred_vel_for_cursor, -1.0, 1.0) 
        cursor_pos += pred_vel_for_cursor * movement_scale_abs 
        cursor_pos[0] = np.clip(cursor_pos[0], 0, SCREEN_WIDTH)
        cursor_pos[1] = np.clip(cursor_pos[1], 0, SCREEN_HEIGHT)

        # Judge success on the position *after* this step's move. Using the
        # distance measured before the move would register the reach one step
        # late and miss a reach completed on the very last allowed step.
        if np.linalg.norm(target_pos_abs - cursor_pos) < TARGET_RADIUS:
            reach_succeeded = True
            steps_taken_final = step_num + 1
            break 
    
    if collect_trajectory_data:
        return steps_taken_final, reach_succeeded, trajectory_spikes, trajectory_ideal_velocities
    else:
        return steps_taken_final, reach_succeeded

# --- Plotting function for reach times ---
def plot_phased_reach_time_comparison(phased_results_single_run, base_filename="nopretrain_reach_time.png", smoothing_window=5, num_training_reaches=30):
    """Plot smoothed reach times per decoder across the experiment phases and save the figure.

    Args:
        phased_results_single_run: Mapping ``decoder_name -> {phase_name:
            [steps, ...]}``. Phases are drawn in the order online
            training/collection, then post-training evaluation, on one
            continuous trial axis. Nothing is drawn for an empty mapping.
        base_filename: Path the figure is written to.
        smoothing_window: Window of the rolling mean applied within each phase
            (``min_periods=1``).
        num_training_reaches: Position of the vertical "end of online
            training" marker; drawn only if it lies inside the plotted range.

    Step counts are converted to seconds with 0.01 s per step and capped at
    ``MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL * 0.01``. Errors while saving are
    printed, not raised.
    """
    plt.ioff()
    if not phased_results_single_run: return

    fig, ax = plt.subplots(figsize=(15, 8))
    colors = {'KF': 'blue', 'LSTM': 'green', 'SNN_BPTT': 'red', 'SNN_Online': 'purple'}
    phase_names_ordered = [PHASE_ONLINE_TRAINING_COLLECTION, PHASE_POST_TRAINING_EVALUATION]
    max_time_cap = MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL * 0.01
    max_total_reaches_plot = 0

    for decoder_name, phase_data_for_decoder in phased_results_single_run.items():
        current_total_reach_idx_plot = 0
        decoder_color = colors.get(decoder_name, 'gray')
        # One legend entry per decoder, attached to the first segment that is
        # actually drawn (the first phase may be missing or empty).
        needs_legend_entry = True
        for phase_name in phase_names_ordered:
            reach_steps_this_phase = phase_data_for_decoder.get(phase_name) if phase_name in phase_data_for_decoder else None
            if not reach_steps_this_phase: continue

            reach_times_s = np.clip(np.array(reach_steps_this_phase) * 0.01, 0, max_time_cap)
            smoothed_times = pd.Series(reach_times_s).rolling(window=smoothing_window, min_periods=1).mean().to_numpy()
            
            num_reaches_in_phase = len(reach_times_s)
            trial_indices_this_segment = np.arange(current_total_reach_idx_plot, current_total_reach_idx_plot + num_reaches_in_phase)
            
            # Both phases use solid lines; the legend is managed per decoder.
            ax.plot(trial_indices_this_segment, smoothed_times, 
                    color=decoder_color, linestyle='-',
                    linewidth=2, label=f'{decoder_name}' if needs_legend_entry else None) 
            needs_legend_entry = False
            current_total_reach_idx_plot += num_reaches_in_phase
        max_total_reaches_plot = max(max_total_reaches_plot, current_total_reach_idx_plot)
            
    if 0 < num_training_reaches < max_total_reaches_plot:
        ax.axvline(x=num_training_reaches - 0.5, color='black', linestyle=':', alpha=0.7, 
                   label=f'End Online Training ({num_training_reaches} reaches)')

    ax.set_xlabel("Reach Trial Index")
    ax.set_ylabel(f"Smoothed Time to Reach Target (seconds, N={smoothing_window})")
    ax.set_title("Decoder Performance: Online Learning from Scratch")
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(bottom=0, top=MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL * 0.01 * 1.1)
    if max_total_reaches_plot > 0 : ax.set_xlim(left=-1, right=max_total_reaches_plot)

    try:
        plt.tight_layout()
        try:
            fig.savefig(base_filename, dpi=150)
            print(f"Saved phased reach time comparison plot to {base_filename}")
        except Exception as e: print(f"Error saving plot: {e}")
    finally:
        # Always release the figure, even if layout or saving failed.
        plt.close(fig)

def _sample_reach_target(start_pos, max_resamples=100):
    """Draw a random target position that is not too close to ``start_pos``.

    The target is sampled uniformly inside the screen, keeping a margin of
    1.5 * TARGET_RADIUS from every edge. A target within 3 * TARGET_RADIUS of
    ``start_pos`` is first mirrored horizontally (the original heuristic). If
    the mirrored position is still too close, new positions are drawn until
    one is far enough or ``max_resamples`` extra draws have been used.

    Args:
        start_pos: Cursor start position as a length-2 array (x, y).
        max_resamples: Upper bound on extra draws so the loop always ends.

    Returns:
        float32 array ``[x, y]`` holding the target position.
    """
    margin = TARGET_RADIUS * 1.5
    min_distance = TARGET_RADIUS * 3

    target_x = np.random.uniform(margin, SCREEN_WIDTH - margin)
    target_y = np.random.uniform(margin, SCREEN_HEIGHT - margin)
    target = np.array([target_x, target_y], dtype=np.float32)
    if np.linalg.norm(target - start_pos) > min_distance:
        return target

    # Simple heuristic kept from the original code: mirror x across the screen.
    target = np.array([SCREEN_WIDTH - target_x, target_y], dtype=np.float32)

    # Mirroring leaves the distance unchanged when start_pos lies on the
    # vertical midline of the screen, so re-check and resample if needed.
    for _ in range(max_resamples):
        if np.linalg.norm(target - start_pos) > min_distance:
            break
        target = np.array(
            [
                np.random.uniform(margin, SCREEN_WIDTH - margin),
                np.random.uniform(margin, SCREEN_HEIGHT - margin),
            ],
            dtype=np.float32,
        )
    return target


def run_single_comparison(seed_value):
    """Run one seeded 'no pretrain' comparison of the four decoders.

    Steps:
      1. Seed torch / numpy / random with ``seed_value``.
      2. Build an untrained KF, LSTM, SNN_BPTT and SNN_Online decoder.
      3. Phase 1: every decoder attempts the same random targets while the
         spikes and ideal velocities of its own trajectories are recorded.
      4. KF, LSTM and SNN_BPTT are fitted on their own recorded data
         (SNN_Online is not batch-trained here).
      5. Phase 2: every decoder attempts a second set of random targets.

    Args:
        seed_value: Integer seed applied to all random number generators.

    Returns:
        Tuple ``(phased_reach_results, summary_for_run_dict)`` where
        ``phased_reach_results[decoder][phase]`` is the list of steps taken
        per reach and ``summary_for_run_dict[decoder][phase]`` is a formatted
        "mean steps / success rate" string.
    """
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    print(f"--- Running 'No Pretrain' Comparison with SEED: {seed_value} ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Experiment configuration ---
    NUM_TARGETS_ONLINE_TRAINING = 30
    NUM_TARGETS_POST_TRAINING_EVAL = 70
    NUM_NEURONS = 96
    INPUT_SIZE_SNN = NUM_NEURONS + 2
    INPUT_SIZE_LSTM = NUM_NEURONS  # LSTM input size is just num_neurons
    OUTPUT_SIZE = 2
    NN_HIDDEN_SIZE = 256
    LSTM_HIDDEN_SIZE = NN_HIDDEN_SIZE
    LSTM_NUM_LAYERS = 2
    LSTM_DROPOUT = 0.2
    ONLINE_DATA_TRAIN_EPOCHS = 30  # Reduced from 50, as 30 reaches give limited data
    ONLINE_DATA_BATCH_SIZE = 64    # Further reduced batch size
    ONLINE_DATA_LR = 1e-3
    ONLINE_DATA_SEQ_LEN = 5  # Use short sequences from online data

    ONLINE_SNN_FAST_LR = 1e-4
    ONLINE_SNN_SLOW_LR = 1e-3
    ONLINE_SNN_META_LR = 0.1
    ONLINE_SNN_WINDOW = 10

    spike_gen = Synthetic_Neuron(num_neurons=NUM_NEURONS, noise_level=0.02)
    if hasattr(spike_gen, 'to'):
        spike_gen.to(device)
    spike_gen.training = True

    # --- Decoder construction (all start untrained) ---
    trained_decoders = {}
    kf_model = FilterPyKalmanFilter(dim_x=OUTPUT_SIZE, dim_z=NUM_NEURONS)
    kf_model.F = np.eye(OUTPUT_SIZE)
    kf_model.H = np.random.rand(NUM_NEURONS, OUTPUT_SIZE) * 0.01
    kf_model.Q = np.eye(OUTPUT_SIZE) * 0.1
    kf_model.R = np.eye(NUM_NEURONS) * 1.0
    kf_model.x = np.zeros(OUTPUT_SIZE)
    kf_model.P = np.eye(OUTPUT_SIZE) * 100.0
    trained_decoders['KF'] = kf_model
    print("Kalman Filter initialized (untrained).")

    lstm_model = LSTMRegression(
        input_size=INPUT_SIZE_LSTM, hidden_size=LSTM_HIDDEN_SIZE,
        output_size=OUTPUT_SIZE, num_layers=LSTM_NUM_LAYERS, dropout=LSTM_DROPOUT,
    ).to(device)
    trained_decoders['LSTM'] = lstm_model
    print("LSTM initialized with random weights.")

    snn_bptt_model = SNNRegression(
        input_size=INPUT_SIZE_SNN, hidden_size=NN_HIDDEN_SIZE, output_size=OUTPUT_SIZE,
    ).to(device)
    trained_decoders['SNN_BPTT'] = snn_bptt_model
    print("SNN_BPTT initialized with random weights.")

    # --- Create hardware-friendly lookup table for ZENODO algorithm ---
    vals = torch.linspace(0.25, 8.0, 256, device=device)
    inv_sqrt_LUT = 1.0 / torch.sqrt(vals)

    snn_online_model = SNNRegressionZenodo(
        input_size=INPUT_SIZE_SNN, hidden_size=NN_HIDDEN_SIZE, output_size=OUTPUT_SIZE,
    ).to(device)
    online_updater = TwoScaleMetaRLWeightUpdaterFullZenodo(
        snn_online_model, base_fast_lr=ONLINE_SNN_FAST_LR, base_slow_lr=ONLINE_SNN_SLOW_LR,
        window_size=ONLINE_SNN_WINDOW, meta_lr=ONLINE_SNN_META_LR,
        online_mode=True, inv_sqrt_LUT=inv_sqrt_LUT
    )
    trained_decoders['SNN_Online'] = (snn_online_model, online_updater)
    print("SNN_Online initialized with random weights (will learn online).")

    phased_reach_results = defaultdict(lambda: defaultdict(list))

    # Per-decoder data gathered during phase 1:
    # {decoder_name: {'spikes': [], 'vels': []}}
    phase1_collected_data_for_training = defaultdict(lambda: {'spikes': [], 'vels': []})

    print(f"\n===== Starting Phase: {PHASE_ONLINE_TRAINING_COLLECTION} ({NUM_TARGETS_ONLINE_TRAINING} reaches) =====")
    for target_idx in range(NUM_TARGETS_ONLINE_TRAINING):
        current_cursor_pos_sim = CENTER_POS.copy()
        current_target_pos_sim = _sample_reach_target(current_cursor_pos_sim)
        report_progress = (
            (target_idx + 1) % 5 == 0 or target_idx == NUM_TARGETS_ONLINE_TRAINING - 1
        )

        for decoder_name, model_or_tuple_val in trained_decoders.items():
            steps_taken, success, traj_spikes, traj_ideal_vels = simulate_single_reach_attempt(
                decoder_name=decoder_name, model_object_or_tuple=model_or_tuple_val,
                spike_generator=spike_gen, device=device,
                target_pos_abs=current_target_pos_sim, initial_cursor_pos_abs=current_cursor_pos_sim,
                movement_scale_abs=MOVEMENT_SCALE,
                max_steps_this_reach=MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL,
                input_size_snn_val=INPUT_SIZE_SNN, num_neurons_val=NUM_NEURONS,
                is_snn_online_learning_active=(decoder_name == 'SNN_Online'),
                collect_trajectory_data=True
            )
            phased_reach_results[decoder_name][PHASE_ONLINE_TRAINING_COLLECTION].append(steps_taken)

            # Collect data from this decoder's interaction for its later training
            phase1_collected_data_for_training[decoder_name]['spikes'].extend(traj_spikes)
            phase1_collected_data_for_training[decoder_name]['vels'].extend(traj_ideal_vels)

            if report_progress:
                print(f"    {decoder_name}, Phase1 Target {target_idx+1}/{NUM_TARGETS_ONLINE_TRAINING}: {'Success' if success else 'Timeout'} in {steps_taken} steps.")

    print("\n===== Training KF, LSTM, SNN_BPTT on their respective collected online data =====")
    for decoder_name in ['KF', 'LSTM', 'SNN_BPTT']:
        collected_spikes = phase1_collected_data_for_training[decoder_name]['spikes']
        collected_vels = phase1_collected_data_for_training[decoder_name]['vels']

        if not collected_spikes or not collected_vels:
            print(f"  No data collected for {decoder_name}. Skipping its training.")
            continue

        print(f"  Training {decoder_name} on {len(collected_spikes)} collected steps...")
        current_model_to_train = trained_decoders[decoder_name]

        if decoder_name == 'KF':
            kf_train_spikes_arr = np.array(collected_spikes)
            kf_train_vels_arr = np.array(collected_vels)
            if kf_train_spikes_arr.shape[0] > 1 and kf_train_vels_arr.shape[0] > 1:
                # The first (up to) 10 training samples double as the "test"
                # split required by the helper's signature.
                _, fitted_kf = train_test_kalman_filter(
                    kf_train_spikes_arr, kf_train_vels_arr,
                    kf_train_spikes_arr[:min(10, len(kf_train_spikes_arr))],
                    kf_train_vels_arr[:min(10, len(kf_train_vels_arr))],
                )
                # Explicit None check: truthiness of an arbitrary fitted
                # object is not a reliable success signal.
                if fitted_kf is not None:
                    trained_decoders['KF'] = fitted_kf
                    print("  KF fitted.")
                else:
                    print("  KF fitting failed.")
            else:
                print("  Not enough data for KF.")

        elif decoder_name == 'LSTM':
            # SpikeVelDataset expects list of arrays, here we have list of step_data
            # Create a single "reach" from all collected steps for this decoder
            lstm_train_dataset = SpikeVelDataset(
                [np.array(collected_spikes)], [np.array(collected_vels)],
                sequence_length=ONLINE_DATA_SEQ_LEN, num_neurons_expected=NUM_NEURONS,
            )
            if len(lstm_train_dataset) > 0:
                lstm_train_loader = DataLoader(lstm_train_dataset, batch_size=ONLINE_DATA_BATCH_SIZE, shuffle=True)
                criterion_lstm = nn.MSELoss()
                optimizer_lstm = optim.Adam(current_model_to_train.parameters(), lr=ONLINE_DATA_LR)
                trained_lstm, _, _ = train_lstm_model(
                    current_model_to_train, lstm_train_loader, None,
                    criterion_lstm, optimizer_lstm, device,
                    num_epochs=ONLINE_DATA_TRAIN_EPOCHS, patience=ONLINE_DATA_TRAIN_EPOCHS,
                )
                # Keep the existing model if training returned nothing, so the
                # evaluation phase never receives None as a decoder.
                if trained_lstm is not None:
                    trained_decoders['LSTM'] = trained_lstm
                    print("  LSTM trained.")
                else:
                    print("  LSTM training failed to return a model.")
            else:
                print(f"  Not enough samples for LSTM training ({len(lstm_train_dataset)} samples from {len(collected_spikes)} steps).")

        elif decoder_name == 'SNN_BPTT':
            snn_bptt_train_dataset = SpikeVelDataset(
                [np.array(collected_spikes)], [np.array(collected_vels)],
                sequence_length=ONLINE_DATA_SEQ_LEN, num_neurons_expected=NUM_NEURONS,
            )
            if len(snn_bptt_train_dataset) > 0:
                snn_bptt_state_dict, _ = train_bptt_snn_local(snn_bptt_train_dataset, None, INPUT_SIZE_SNN, NN_HIDDEN_SIZE, OUTPUT_SIZE, ONLINE_DATA_TRAIN_EPOCHS, ONLINE_DATA_BATCH_SIZE, device, ONLINE_DATA_LR, ONLINE_DATA_TRAIN_EPOCHS, current_model_to_train)
                if snn_bptt_state_dict:
                    current_model_to_train.load_state_dict(snn_bptt_state_dict)
                    trained_decoders['SNN_BPTT'] = current_model_to_train  # Ensure it's the updated model object
                    print("  SNN_BPTT trained.")
                else:
                    print("  SNN_BPTT training failed to return state.")
            else:
                print(f"  Not enough samples for SNN_BPTT training ({len(snn_bptt_train_dataset)} samples from {len(collected_spikes)} steps).")

    print(f"\n===== Starting Phase: {PHASE_POST_TRAINING_EVALUATION} ({NUM_TARGETS_POST_TRAINING_EVAL} reaches) =====")
    for target_idx in range(NUM_TARGETS_POST_TRAINING_EVAL):
        current_cursor_pos_sim = CENTER_POS.copy()
        current_target_pos_sim = _sample_reach_target(current_cursor_pos_sim)
        report_progress = (
            (target_idx + 1) % 10 == 0 or target_idx == NUM_TARGETS_POST_TRAINING_EVAL - 1
        )

        for decoder_name, model_or_tuple_val in trained_decoders.items():
            steps_taken, success = simulate_single_reach_attempt(
                decoder_name=decoder_name, model_object_or_tuple=model_or_tuple_val,
                spike_generator=spike_gen, device=device,
                target_pos_abs=current_target_pos_sim, initial_cursor_pos_abs=current_cursor_pos_sim,
                movement_scale_abs=MOVEMENT_SCALE,
                max_steps_this_reach=MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL,
                input_size_snn_val=INPUT_SIZE_SNN, num_neurons_val=NUM_NEURONS,
                is_snn_online_learning_active=(decoder_name == 'SNN_Online'),
                collect_trajectory_data=False
            )
            phased_reach_results[decoder_name][PHASE_POST_TRAINING_EVALUATION].append(steps_taken)
            if report_progress:
                print(f"    {decoder_name}, Phase2 Target {target_idx+1}/{NUM_TARGETS_POST_TRAINING_EVAL}: {'Success' if success else 'Timeout'} in {steps_taken} steps.")

    print("\n--- 'No Pretrain' Experiment Simulation Complete --- ")
    summary_for_run_dict = {}
    for decoder_name, phase_data in phased_reach_results.items():
        summary_for_run_dict[decoder_name] = {}
        for phase_name_key, steps_list_val in phase_data.items():
            if not steps_list_val:
                summary_for_run_dict[decoder_name][phase_name_key] = "N/A / 0.0%"
                continue
            # A reach counts as successful when it finished before the step cap.
            successful_reaches = [s for s in steps_list_val if s < MAX_REACH_STEPS_PER_ATTEMPT_GLOBAL]
            num_total_attempted = len(steps_list_val)
            s_rate_val = len(successful_reaches) / num_total_attempted * 100
            mean_s_steps_val = np.mean(successful_reaches) if successful_reaches else float('nan')
            summary_for_run_dict[decoder_name][phase_name_key] = f"{mean_s_steps_val:.1f} steps / {s_rate_val:.1f}% success"
            print(f"  {decoder_name} - {phase_name_key}: Avg Steps/Success={mean_s_steps_val:.1f}, Success Rate={s_rate_val:.1f}% ({len(successful_reaches)}/{num_total_attempted})")

    print(f"--- 'No Pretrain' Comparison with SEED: {seed_value} Finished ---")
    return phased_reach_results, summary_for_run_dict

if __name__ == "__main__":
    # Standalone entry point: run one seeded comparison and plot the
    # per-phase reach times for every decoder.
    default_seed = 42
    print(f"Executing decoder_comparison_nopretrain.py as main script with seed: {default_seed}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    reach_times, summary_dict_for_run = run_single_comparison(seed_value=default_seed)

    # Use the phase constant that actually keys the results rather than a
    # string literal that may not match its value.
    print(f"DEBUG: Collection phase data present: {PHASE_ONLINE_TRAINING_COLLECTION in reach_times.get('KF', {})}")
    print(f"DEBUG: Phase data: {[phase for decoder in reach_times for phase in reach_times[decoder]]}")

    if reach_times:
        # Derive the training/evaluation boundary from the recorded data so it
        # stays in sync with the number of phase-1 reaches actually run.
        # `.get` is used so the nested defaultdict is not mutated.
        num_training_reaches = max(
            (
                len(phases.get(PHASE_ONLINE_TRAINING_COLLECTION, []))
                for phases in reach_times.values()
            ),
            default=0,
        )
        plot_phased_reach_time_comparison(
            reach_times,
            base_filename=f"nopretrain_reach_time_seed{default_seed}.png",
            num_training_reaches=num_training_reaches
        )

    print("\n'No Pretrain' comparison script finished (standalone run).")