import torch
import torch.nn as nn
import snntorch as snn
from torch.utils.data import DataLoader
from snntorch import surrogate
import os
from torch.cuda.amp import autocast, GradScaler

from nlb_tools.nwb_interface import NWBDataset
import torch
from torch.utils.data import Dataset
import numpy as np
import os


def compute_correlation(pred, target):
    """
    Compute the per-feature Pearson correlation between predictions and targets.

    Torch tensors are detached and moved to CPU; anything else is passed
    through ``np.asarray``. Inputs of rank 3 or higher, e.g.
    (batch, seq_len, features), are flattened to (samples, features), and
    inputs of rank 0 or 1 are treated as a single feature column.

    Args:
        pred: Predicted values (tensor or array-like).
        target: Ground-truth values (tensor or array-like). When it has the
            same number of feature columns as ``pred`` the columns are
            compared pairwise; otherwise its first column is compared against
            every column of ``pred``.

    Returns:
        np.ndarray: One coefficient per ``pred`` feature column (always an
        array with at least one element, never a bare float). A coefficient
        is 0.0 when fewer than two samples are available, when either column
        is (near-)constant, or when the two columns cannot be correlated
        (e.g. mismatched lengths).
    """

    def _as_samples_by_features(values):
        # Normalise any supported input to a 2-D (samples, features) array.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        values = np.asarray(values)
        if values.ndim < 2:
            return values.reshape(-1, 1)
        if values.ndim > 2:
            return values.reshape(-1, values.shape[-1])
        return values

    pred = _as_samples_by_features(pred)
    target = _as_samples_by_features(target)

    num_samples, num_dims = pred.shape

    # Pearson correlation is undefined for fewer than two samples. Return an
    # array here as well so callers can rely on ``.size`` / indexing.
    if num_samples < 2:
        return np.zeros(max(num_dims, 1))

    # A target with a different column count is compared through its first
    # column only.
    pairwise = target.shape[1] == num_dims

    corrs = []
    for dim in range(num_dims):
        pred_dim = pred[:, dim]
        target_dim = target[:, dim] if pairwise else target[:, 0]

        # A (near-)constant signal has no defined correlation; report 0.
        if np.std(pred_dim) < 1e-6 or np.std(target_dim) < 1e-6:
            corrs.append(0.0)
            continue

        try:
            corrs.append(np.corrcoef(pred_dim, target_dim)[0, 1])
        except (IndexError, ValueError):
            # e.g. pred and target have different numbers of samples.
            corrs.append(0.0)

    return np.array(corrs) if corrs else np.array([0.0])


from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


# Module-level side effect: the NWB training files are opened as soon as this
# module is imported/executed, not lazily inside load_data().
dataset = NWBDataset("~/scratch/000128/sub-Jenkins/", "*train", split_heldout=False)
# NOTE(review): presumably rebins the data to 10 ms bins - confirm against the
# nlb_tools documentation.
dataset.resample(10)


# Stores a smoothed copy of the spikes under the name 'smth_50'.
# NOTE(review): 50 is presumably the smoothing width in ms - confirm.
dataset.smooth_spk(50, name='smth_50')


# %%
# Shared handle read by load_data() for its "mc_maze" branch.
DATASET = dataset

# %%
from nlb_tools.nwb_interface import NWBDataset
import pandas as pd

from torch.utils.data import Dataset
import numpy as np
import os


# --- MODIFIED Signature: Add window/stride, remove resample_ms ---
def load_data(dataset_name="mc_maze", use_cached=True,
              bin_width_ms=100, stride_ms=10,
              spike_processing='count'):
    """Load neural activity and behavioural targets for one supported dataset.

    Supported ``dataset_name`` values:

    * ``"mc_maze"``: uses the module-level ``DATASET``. Every trial is cut
      around ``move_onset_time`` with ``align_range=(-130, 370)``; a sliding
      window of ``bin_width_ms`` advanced by ``stride_ms`` then sums the
      spikes and averages the hand velocity inside each window, and
      ``spike_processing`` is applied to the summed spikes.
    * ``"mc_rtt"``: smoothed spikes and finger velocity with the velocity
      lagged by 120 ms; only the first 1000 rows are kept.
    * ``"area2_bump"``: trial-averaged smoothed spikes and hand velocity for
      each of the 16 (ctr_hold_bump, cond_dir) conditions.
    * ``"dmfc_rsg"``: trial-averaged smoothed spikes for each of the 10
      (is_short, ts) conditions; the 2-D target is a pair of linear ramps
      (0 -> ts and ts -> 0) spanning the averaged trial.

    No normalisation is applied to either output.

    Args:
        dataset_name: Which dataset to load (see above).
        use_cached: Currently unused; kept for interface compatibility.
        bin_width_ms: Sliding-window width in ms ("mc_maze" only).
        stride_ms: Sliding-window stride in ms ("mc_maze" only).
        spike_processing: ``'count'`` (keep window sums), ``'binary'``
            (1 where the sum is > 0), or ``'rate'`` (sum divided by the
            window width in seconds). Any other value prints a warning and
            keeps the counts. "mc_maze" only.

    Returns:
        Tuple ``(spike_rates, hand_vel, windowed_trial_ids, hand_vel_mean,
        hand_vel_std, trial_data)``:

        * ``spike_rates``, ``hand_vel``: float32 arrays, one row per sample.
        * ``windowed_trial_ids``: int64 array with the trial id of every
          window for "mc_maze"; ``None`` for the other datasets.
        * ``hand_vel_mean``, ``hand_vel_std``: always ``None``.
        * ``trial_data``: the frame returned by ``make_trial_data`` (for the
          condition-averaged datasets, the frame of the last condition);
          ``None`` for "mc_rtt".

    Raises:
        ValueError: If ``dataset_name`` is not supported, if the window or
            stride rounds to less than one sample of the dataset, or if no
            "mc_maze" trial is long enough to hold a single window.
    """
    supported = ("mc_maze", "mc_rtt", "area2_bump", "dmfc_rsg")
    if dataset_name not in supported:
        raise ValueError(
            f"Unknown dataset_name '{dataset_name}'; expected one of {supported}"
        )

    spike_rates = None
    hand_vel = None
    hand_vel_mean = None
    hand_vel_std = None
    # Only some branches produce these; default them so that every branch can
    # reach the common return statement.
    windowed_trial_ids = None
    trial_data = None

    if dataset_name == "mc_maze":
        dataset = DATASET
        print("Processing MC_MAZE data from scratch...")

        # Window width and stride expressed in samples of the dataset.
        # NOTE(review): assumes dataset.bin_width is in milliseconds - confirm.
        bin_width_steps = int(round(bin_width_ms / dataset.bin_width))
        stride_steps = int(round(stride_ms / dataset.bin_width))
        if bin_width_steps < 1 or stride_steps < 1:
            raise ValueError(
                f"bin_width_ms={bin_width_ms} and stride_ms={stride_ms} must each "
                f"span at least one dataset bin ({dataset.bin_width})"
            )

        # Spikes and velocity share one alignment window (no lag between
        # them), so a single aligned frame serves both.
        trial_data = dataset.make_trial_data(
            align_field="move_onset_time",
            align_range=(-130, 370),
            allow_nans=False
        )

        spikes_list = []
        vel_list = []
        windowed_trial_id_list = []

        for tid in sorted(trial_data.trial_id.unique()):
            td = trial_data[trial_data.trial_id == tid]

            spikes = td.spikes.to_numpy()
            vel = td.hand_vel.to_numpy()

            # Skip trials too short to hold even one window.
            T_trial = min(spikes.shape[0], vel.shape[0])
            if T_trial * dataset.bin_width < bin_width_ms:
                continue

            num_bins_trial = (T_trial - bin_width_steps) // stride_steps + 1
            if num_bins_trial <= 0:
                continue

            window_starts = [i * stride_steps for i in range(num_bins_trial)]
            spike_windows_trial = np.stack(
                [spikes[s:s + bin_width_steps] for s in window_starts])
            vel_windows_trial = np.stack(
                [vel[s:s + bin_width_steps] for s in window_starts])

            spikes_list.append(spike_windows_trial.sum(axis=1))  # spikes per window
            vel_list.append(vel_windows_trial.mean(axis=1))      # mean velocity per window
            # One trial id per window so that callers can split by trial.
            windowed_trial_id_list.append(
                np.full(num_bins_trial, tid, dtype=np.int64))

        if not spikes_list:
            raise ValueError(
                f"No MC_MAZE trial is long enough for a {bin_width_ms} ms window"
            )

        spike_rates = np.vstack(spikes_list).astype(np.float32)
        hand_vel = np.vstack(vel_list).astype(np.float32)
        windowed_trial_ids = np.concatenate(windowed_trial_id_list)

        print(f"Applying spike processing: '{spike_processing}'")

        # Sanity check: sample before processing.
        print("BEFORE processing - spike_rates sample (first 3 timesteps, first 10 neurons):")
        print(spike_rates[:3, :10])

        if spike_processing == 'binary':
            spike_rates = (spike_rates > 0).astype(np.float32)
            print("Applied BINARY processing: converted to 0s and 1s")
        elif spike_processing == 'count':
            # Window sums are already raw counts; deliberately not converted.
            print("Applied COUNT processing: keeping raw spike counts")
        elif spike_processing == 'rate':
            bin_width_s = bin_width_ms / 1000.0
            spike_rates = (spike_rates / bin_width_s).astype(np.float32)
            print(f"Applied RATE processing: converted to rates using bin_width={bin_width_s}s")
        else:
            print(f"Warning: Unknown spike_processing '{spike_processing}'. Using raw counts.")

        # Sanity check: sample after processing.
        print("AFTER processing - spike_rates sample (first 3 timesteps, first 10 neurons):")
        print(spike_rates[:3, :10])
        print(f"spike_processing mode: {spike_processing}")
        print(f"Data type: {spike_rates.dtype}, Min: {spike_rates.min():.2f}, Max: {spike_rates.max():.2f}")

        print("Spike COUNTS (for SNN) sample 5x10:")
        print(f" {spike_rates[:5, :10]}")

        # Also show a "rates" view of the same sample for comparison.
        if spike_processing == 'count':
            rates_sample = spike_rates[:5, :10] / (bin_width_ms / 1000.0)
            print("Spike RATES (for KF/LSTM) sample 5x10:")
            print(f" {rates_sample}")
        else:
            print(f"Spike data (current mode: {spike_processing}) sample 5x10:")
            print(f" {spike_rates[:5, :10]}")

    elif dataset_name == "mc_rtt":
        print("Processing MC_RTT data from scratch...")
        dataset_path = "000129/sub-Indy/"
        dataset = NWBDataset(dataset_path, "*train", split_heldout=False)

        # Smooth spikes with 50ms window
        dataset.smooth_spk(50, name="smth_50", ignore_nans=True)

        # Lag velocity by 120 ms relative to neural data
        lag = 120
        lag_bins = int(round(lag / dataset.bin_width))

        # Drop rows where the velocity (or its lagged counterpart) is NaN.
        nans = dataset.data.finger_vel.x.isna().reset_index(drop=True)
        rates = dataset.data.spikes_smth_50[~nans.to_numpy() & ~nans.shift(-lag_bins, fill_value=True).to_numpy()]
        vel = dataset.data.finger_vel[~nans.to_numpy() & ~nans.shift(lag_bins, fill_value=True).to_numpy()]

        # Take only first 1000 timepoints for memory efficiency
        spike_rates = rates.iloc[:1000].to_numpy().astype(np.float32)
        hand_vel = vel.iloc[:1000].to_numpy().astype(np.float32)

    elif dataset_name == "area2_bump":
        print("Processing Area2_Bump data from scratch...")
        dataset_path = "000127/sub-Han/"
        dataset = NWBDataset(dataset_path, "*train", split_heldout=False)

        # Resampling to reduce memory usage
        dataset.resample(5)

        # Smooth spikes with 50ms window
        dataset.smooth_spk(50, name="smth_50", ignore_nans=True)

        # All 16 conditions, in the format (ctr_hold_bump, cond_dir)
        unique_conditions = [(False, 0.0), (False, 45.0), (False, 90.0), (False, 135.0),
                             (False, 180.0), (False, 225.0), (False, 270.0), (False, 315.0),
                             (True, 0.0), (True, 45.0), (True, 90.0), (True, 135.0),
                             (True, 180.0), (True, 225.0), (True, 270.0), (True, 315.0)]

        rate_list = []
        vel_list = []

        for cond in unique_conditions:
            # Boolean mask for trials in this condition
            cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
                        (dataset.trial_info.split != 'none')

            # Trial data aligned to movement onset
            trial_data = dataset.make_trial_data(
                align_field='move_onset_time',
                align_range=(-100, 400),
                allow_nans=False,
                ignored_trials=~cond_mask
            )

            # Average across trials
            rate = trial_data.groupby('align_time').mean().spikes_smth_50.to_numpy()
            vel = trial_data.groupby('align_time').mean().hand_vel.to_numpy()

            rate_list.append(rate)
            vel_list.append(vel)

        # Stack all conditions
        spike_rates = np.vstack(rate_list).astype(np.float32)
        hand_vel = np.vstack(vel_list).astype(np.float32)

    elif dataset_name == "dmfc_rsg":
        print("Processing DMFC_RSG data from scratch...")
        dataset = NWBDataset("000130/sub-Haydn/", "*train", split_heldout=False)

        # Smooth spikes with 50ms window
        dataset.smooth_spk(50, name="smth_50", ignore_nans=True)

        # The 10 timing conditions, in the format (is_short, ts)
        unique_conditions = [(True, 480.0), (True, 560.0), (True, 640.0), (True, 720.0), (True, 800.0),  # short prior
                             (False, 800.0), (False, 900.0), (False, 1000.0), (False, 1100.0), (False, 1200.0)]  # long prior

        rate_list = []
        target_list = []

        for cond in unique_conditions:
            # Boolean mask for trials in this condition
            cond_mask = (np.all(dataset.trial_info[['is_short', 'ts']] == cond, axis=1)) & \
                        (dataset.trial_info.split != 'none') & (~dataset.trial_info.is_outlier)

            trial_data = dataset.make_trial_data(
                start_field='ready_time',
                end_field='set_time',
                allow_nans=True,
                ignored_trials=~cond_mask
            )

            # Average across trials
            rate = trial_data.groupby('align_time').mean().spikes_smth_50.to_numpy()
            rate_list.append(rate)

            # Two-dimensional ramp target so the data fits velocity-style
            # models: X rises from 0 to ts, Y falls from ts to 0.
            trial_length = len(rate)
            time_value = cond[1]
            ramp_x = np.linspace(0, time_value, trial_length)
            ramp_y = np.linspace(time_value, 0, trial_length)
            target_list.append(np.column_stack((ramp_x, ramp_y)))

        # Stack all conditions
        spike_rates = np.vstack(rate_list).astype(np.float32)
        hand_vel = np.vstack(target_list).astype(np.float32)

    print(f"Spike rates shape: {spike_rates.shape}")
    print(f"Hand velocity shape: {hand_vel.shape}")

    # Neither caching nor normalisation happens here: the data is returned
    # raw, and hand_vel_mean / hand_vel_std stay None.
    return spike_rates, hand_vel, windowed_trial_ids, hand_vel_mean, hand_vel_std, trial_data



class SpikeDataset(Dataset):
    """Sliding-window dataset pairing spike sequences with target sequences.

    Item ``i`` is ``(spike_rates[i:i + L], targets[i:i + L])`` with
    ``L = sequence_length``, so consecutive items overlap by ``L - 1`` rows.
    Windows are taken over the rows exactly as given; nothing here knows
    about trial boundaries.

    Args:
        spike_rates: Array-like of input rows; converted to a float32 tensor.
        targets: Array-like of target rows; converted to a float32 tensor.
        sequence_length: Number of consecutive rows per item (>= 1).

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, spike_rates, targets, sequence_length=100):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.spike_rates = torch.tensor(spike_rates, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.sequence_length = sequence_length
        # Clamp at zero: data shorter than one window yields an empty dataset
        # rather than a negative length that makes len() raise.
        self.num_sequences = max(0, len(self.spike_rates) - sequence_length + 1)

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        # Slicing never fails by itself, so reject out-of-range indices
        # explicitly instead of returning truncated or empty windows.
        if not 0 <= idx < self.num_sequences:
            raise IndexError(
                f"index {idx} out of range for {self.num_sequences} sequences"
            )
        stop = idx + self.sequence_length
        return self.spike_rates[idx:stop], self.targets[idx:stop]

def load_dmfc_rsg():
    """Load the DMFC_RSG dataset as condition-averaged, z-scored arrays.

    The NWB training files under "000130/sub-Haydn/" are resampled with
    ``resample(5)`` and smoothed into ``spikes_smth_50``. For each of the 10
    (is_short, ts) conditions the smoothed spikes of the non-outlier trials
    whose split is not 'none' are averaged per ``align_time`` between
    ``ready_time`` and ``set_time``. Every averaged row gets the target
    ``[ts, ts]``.

    Returns:
        Tuple ``(spike_rates, target, target_mean, target_std)`` where
        ``spike_rates`` and ``target`` are float32 arrays standardised per
        column, and ``target_mean`` / ``target_std`` are the statistics used
        for the targets (the std already includes the 1e-6 offset).
    """
    nwb = NWBDataset("000130/sub-Haydn/", "*train", split_heldout=False)
    nwb.resample(5)
    nwb.smooth_spk(50, name="smth_50", ignore_nans=True)

    # Timing conditions as (is_short, ts): short prior first, then long prior.
    short_prior = [(True, ts) for ts in (480.0, 560.0, 640.0, 720.0, 800.0)]
    long_prior = [(False, ts) for ts in (800.0, 900.0, 1000.0, 1100.0, 1200.0)]

    averaged_rates = []
    target_rows = []
    for condition in short_prior + long_prior:
        # Keep only usable trials belonging to this condition.
        in_condition = np.all(nwb.trial_info[['is_short', 'ts']] == condition, axis=1)
        usable = in_condition & (nwb.trial_info.split != 'none') & (~nwb.trial_info.is_outlier)

        trials = nwb.make_trial_data(
            start_field='ready_time',
            end_field='set_time',
            allow_nans=True,
            ignored_trials=~usable,
        )

        # Trial average at every aligned time step.
        mean_rate = trials.groupby('align_time').mean().spikes_smth_50.to_numpy()
        averaged_rates.append(mean_rate)

        # One [ts, ts] target row per averaged time step.
        interval = condition[1]
        target_rows.extend([interval, interval] for _ in range(len(mean_rate)))

    spike_rates = np.vstack(averaged_rates).astype(np.float32)
    target = np.array(target_rows).astype(np.float32)

    # Standardise each spike column.
    rate_mean = spike_rates.mean(axis=0)
    rate_std = spike_rates.std(axis=0) + 1e-6
    spike_rates = (spike_rates - rate_mean) / rate_std

    # Standardise the targets and hand the statistics back to the caller.
    target_mean = target.mean(axis=0)
    target_std = target.std(axis=0) + 1e-6
    target = (target - target_mean) / target_std

    return spike_rates, target, target_mean, target_std


def evaluate_model(model, data_loader, criterion, device):
    """Evaluate the model on the given data loader.

    The model is switched to evaluation mode (and left in that mode) and all
    batches are run under ``torch.no_grad()``.

    Args:
        model: The SNN model to evaluate.
        data_loader: Iterable of ``(batch_spike, batch_target)`` pairs, e.g. a
            DataLoader for the validation or test set.
        criterion: The loss function (e.g., nn.MSELoss).
        device: The device to run evaluation on ('cuda' or 'cpu').

    Returns:
        Tuple: (average_loss, x_corr, y_corr). ``average_loss`` is the mean of
        the per-batch losses (batches are weighted equally regardless of
        size). ``x_corr`` / ``y_corr`` are the correlations of the first and
        second output dimension; a missing dimension yields NaN. If the
        loader yields no batches, all three values are NaN.
    """
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for batch_spike, batch_target in data_loader:
            batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)

            outputs = model(batch_spike)
            total_loss += criterion(outputs, batch_target).item()
            num_batches += 1

            # Keep results on the CPU so device memory is not held across batches.
            all_outputs.append(outputs.cpu())
            all_targets.append(batch_target.cpu())

    # Nothing to evaluate: report NaN rather than dividing by zero or
    # concatenating an empty list.
    if num_batches == 0:
        return np.nan, np.nan, np.nan

    avg_loss = total_loss / num_batches

    # Correlation is computed once over all predictions, not per batch.
    all_outputs_tensor = torch.cat(all_outputs, dim=0)
    all_targets_tensor = torch.cat(all_targets, dim=0)
    # atleast_1d tolerates a scalar result, so ``.size`` / indexing is safe.
    correlations = np.atleast_1d(
        compute_correlation(all_outputs_tensor, all_targets_tensor)
    )

    # Extract X and Y correlations (assuming 2D output)
    x_corr = correlations[0] if correlations.size > 0 else np.nan
    y_corr = correlations[1] if correlations.size > 1 else np.nan

    return avg_loss, x_corr, y_corr


class SNNRegression(nn.Module):
    """Three-layer spiking network that regresses a continuous output.

    Layer 1 combines a feedforward projection with a recurrent projection of
    its own previous spikes, layer 2 is feedforward, and the output layer is a
    leaky neuron with ``reset_mechanism="none"`` whose membrane potential is
    returned as the prediction at every time step.
    """

    def __init__(self, input_size, hidden_size=1024, output_size=2):
        super().__init__()
        half_hidden = hidden_size // 2
        surrogate_fn = surrogate.fast_sigmoid()

        # First hidden layer: feedforward input plus recurrent spikes.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Second hidden layer.
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=surrogate_fn, init_hidden=False)

        # Readout layer: membrane potential is never reset.
        self.fc3 = nn.Linear(half_hidden, output_size)
        self.lif3 = snn.Leaky(
            beta=0.5,
            threshold=1.0,
            spike_grad=surrogate_fn,
            init_hidden=False,
            reset_mechanism="none",
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise linear weights; bias is -0.1 for fc1/fc2, else 0."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        is_hidden_projection = module is self.fc1 or module is self.fc2
        nn.init.constant_(module.bias, -0.1 if is_hidden_projection else 0.0)

    def forward(self, x):
        """Run the network over a sequence and return the readout potentials.

        ``x`` is indexed as ``x[:, t, :]``, i.e. (batch, time, features). All
        states start at zero for every call and are detached at the start of
        every time step, so gradients do not propagate across time steps.

        Returns:
            Tensor of the output-layer membrane potential at each time step,
            stacked along dim 1.
        """
        n_batch = x.size(0)
        dev = x.device

        def fresh_state(width):
            return torch.zeros(n_batch, width, device=dev)

        prev_spk1 = fresh_state(self.fc1.out_features)
        mem1 = fresh_state(self.fc1.out_features)
        mem2 = fresh_state(self.fc2.out_features)
        mem3 = fresh_state(self.fc3.out_features)

        readout = []
        for step in range(x.size(1)):
            # Cut the graph between time steps.
            mem1, mem2, mem3, prev_spk1 = (
                state.detach() for state in (mem1, mem2, mem3, prev_spk1)
            )

            # Layer 1: current input plus the previous step's own spikes;
            # the new spikes become the recurrent input of the next step.
            drive1 = self.fc1(x[:, step, :]) + self.fc_rec(prev_spk1)
            prev_spk1, mem1 = self.lif1(drive1, mem1)

            # Layer 2.
            spk2, mem2 = self.lif2(self.fc2(prev_spk1), mem2)

            # Readout: output spikes are ignored, the potential is recorded.
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

        return torch.stack(readout, dim=1)


import os

def train_single_gpu(train_dataset, val_dataset, input_size, epochs=200, batch_size=32768, checkpoint_path="snn_single_gpu_results.pth"):
    """Train an SNNRegression model on one device and save a checkpoint.

    Uses CUDA when available, otherwise the CPU. Training always starts from
    scratch (no resume from an existing checkpoint). After every epoch the
    model is evaluated on ``val_dataset`` and a summary line is printed.

    The checkpoint (model state, optimizer state and the list of per-epoch
    training losses) is written in a ``finally`` block, so it is saved on
    normal completion, on Ctrl-C, and also when another exception is raised.

    Args:
        train_dataset: Dataset yielding ``(spikes, targets)`` training pairs.
        val_dataset: Dataset yielding ``(spikes, targets)`` validation pairs.
        input_size: Number of input features per time step.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        checkpoint_path: File the checkpoint is written to.

    Returns:
        None. The trained weights are only available through the checkpoint.
    """
    # 1) Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Mixed precision and pinned host memory only make sense on CUDA.
    use_cuda = device.type == "cuda"

    # 2) Create DataLoaders
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=use_cuda,
        persistent_workers=True,
        prefetch_factor=2,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # 3) Instantiate Model, Loss, Optimizer
    model = SNNRegression(input_size=input_size, hidden_size=1024).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scaler = GradScaler(enabled=use_cuda)  # loss scaling for mixed precision
    criterion = nn.MSELoss()

    train_losses = []

    # 4) Training Loop
    try:
        for epoch in range(epochs):
            model.train()
            epoch_loss = 0.0

            for batch_spike, batch_target in train_loader:
                batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)

                optimizer.zero_grad()
                # Only the forward pass and the loss belong inside autocast;
                # the backward pass and optimizer step run outside of it.
                with autocast(enabled=use_cuda):
                    outputs = model(batch_spike)
                    loss = criterion(outputs, batch_target)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()

            avg_train_loss = epoch_loss / len(train_loader)
            train_losses.append(avg_train_loss)

            # Evaluate on the validation set after every epoch.
            model.eval()
            val_loss, val_x_corr, val_y_corr = evaluate_model(model, val_loader, criterion, device)

            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Corr X: {val_x_corr:.4f}, Val Corr Y: {val_y_corr:.4f}")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving progress...")

    finally:
        # Save whatever has been trained so far.
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_losses": train_losses,
        }, checkpoint_path)
        print(f"Model saved at {checkpoint_path}. Exiting training loop.")


if __name__ == "__main__":
    # Script entry point: load MC_Maze once, then train and test one model per
    # random trial-level train/validation/test split and report the spread.
    if DATASET is None:
        print("Exiting because global dataset could not be loaded.")
        raise SystemExit(1)

    # --- Load the dataset once ---
    bin_width_ms = 100    # Window size for aggregation (ms)
    stride_ms = 10        # Stride of the window (ms)
    spike_proc = 'count'  # Feed raw spike counts to the network
    sequence_length = 50  # Consecutive windows per SpikeDataset item
    num_splits = 10

    print(f"Loading data with bin_width={bin_width_ms}ms, stride={stride_ms}ms, spike_processing='{spike_proc}'")
    firing_rates, hand_vel, windowed_trial_ids, _, _, trial_data = load_data(
        "mc_maze",
        use_cached=False,
        bin_width_ms=bin_width_ms,
        stride_ms=stride_ms,
        spike_processing=spike_proc
    )
    print("Raw data loaded and windowed.")

    # Collect the trial identifiers, either from the first index level or
    # from a trial_id column, depending on how trial_data is laid out.
    try:
        if isinstance(trial_data.index, pd.MultiIndex):
            unique_trials = np.unique(trial_data.index.get_level_values(0))
        else:
            unique_trials = np.unique(trial_data['trial_id'])
    except (AttributeError, KeyError, TypeError) as e:
        print(f"Error accessing trial IDs from trial_data: {e}")
        print("Ensure trial_data is returned correctly from load_data with trial identifiers.")
        raise SystemExit(1)

    n_total = len(unique_trials)
    print(f"Total unique trials: {n_total}")

    # With fewer than three trials at least one subset would be empty.
    if n_total < 3:
        print("Need at least 3 trials to build train/validation/test splits.")
        raise SystemExit(1)

    # 70/15/15 split by trial. Validation gets at least one trial, test takes
    # the rounding remainder, and train receives whatever is left.
    n_val = max(1, int(0.15 * n_total))
    n_test = n_total - int(0.70 * n_total) - int(0.15 * n_total)
    n_train = n_total - n_val - n_test
    print(f"Splitting each random permutation into: {n_train} train, {n_val} validation, {n_test} test trials")

    # One result record per split.
    test_results_list = []

    # --- LOOP OVER MULTIPLE RANDOM SPLITS ---
    for split in range(num_splits):
        print(f"\n--- Split {split+1}/{num_splits} ---")
        # New random assignment of trials to the three subsets.
        permuted = np.random.permutation(unique_trials)
        train_trials = permuted[:n_train]
        val_trials   = permuted[n_train:n_train+n_val]
        test_trials  = permuted[n_train+n_val:]

        # --- Select the windows that belong to each subset's trials ---
        train_mask = np.isin(windowed_trial_ids, train_trials)
        val_mask   = np.isin(windowed_trial_ids, val_trials)
        test_mask  = np.isin(windowed_trial_ids, test_trials)

        # Inputs are used as returned by load_data (no input normalisation).
        train_firing_rates = firing_rates[train_mask]
        val_firing_rates   = firing_rates[val_mask]
        test_firing_rates  = firing_rates[test_mask]
        train_targets_raw  = hand_vel[train_mask]
        val_targets_raw    = hand_vel[val_mask]
        test_targets_raw   = hand_vel[test_mask]

        # --- Target normalisation statistics from the training set ONLY ---
        print("Calculating normalization statistics from training set...")
        hand_vel_mean_local = train_targets_raw.mean(axis=0)
        hand_vel_std_local  = train_targets_raw.std(axis=0) + 1e-6
        print(f"Vel Mean (local): {hand_vel_mean_local}, Vel Std (local): {hand_vel_std_local}")

        train_targets = (train_targets_raw - hand_vel_mean_local) / hand_vel_std_local
        val_targets   = (val_targets_raw - hand_vel_mean_local) / hand_vel_std_local
        test_targets  = (test_targets_raw - hand_vel_mean_local) / hand_vel_std_local

        # --- Report Dataset Sizes for this split ---
        print(f"\nSplit {split+1} Data Sizes:")
        print(f"  Train: {len(train_firing_rates)} samples")
        print(f"  Validation: {len(val_firing_rates)} samples")
        print(f"  Test: {len(test_firing_rates)} samples")

        # A subset shorter than one sequence cannot form a single item;
        # record the split as failed instead of crashing later.
        smallest = min(len(train_firing_rates), len(val_firing_rates), len(test_firing_rates))
        if smallest < sequence_length:
            print(f"ERROR: Split {split+1} has a subset with fewer than {sequence_length} samples; skipping.")
            test_results_list.append({'split': split+1, 'loss': np.nan, 'corr_x': np.nan, 'corr_y': np.nan})
            continue

        # --- Create PyTorch Datasets for this split ---
        train_split_dataset = SpikeDataset(train_firing_rates, train_targets, sequence_length=sequence_length)
        val_split_dataset = SpikeDataset(val_firing_rates, val_targets, sequence_length=sequence_length)
        test_split_dataset = SpikeDataset(test_firing_rates, test_targets, sequence_length=sequence_length)

        # --- Train Model for this split ---
        print(f"\nTraining model for Split {split+1}...")
        checkpoint_path_split = f"snn_mcmaze_split_{split+1}.pth"
        train_single_gpu(
            train_split_dataset,
            val_split_dataset,
            input_size=train_firing_rates.shape[1],
            epochs=50,
            batch_size=32768,
            checkpoint_path=checkpoint_path_split
        )

        # --- Test Set Evaluation for this split ---
        print(f"\n--- Evaluating on Test Set for Split {split+1} ---")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # train_single_gpu returns nothing, so rebuild the model and restore
        # the weights from the checkpoint it wrote.
        model = SNNRegression(input_size=firing_rates.shape[1], hidden_size=1024).to(device)
        if os.path.exists(checkpoint_path_split):
            print(f"Loading model from {checkpoint_path_split} for testing...")
            checkpoint = torch.load(checkpoint_path_split, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            print(f"ERROR: Checkpoint {checkpoint_path_split} not found for testing Split {split+1}!")
            test_results_list.append({'split': split+1, 'loss': np.nan, 'corr_x': np.nan, 'corr_y': np.nan})  # Record failure
            continue  # Skip to next split

        # Create Test DataLoader
        test_loader = DataLoader(
            test_split_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=2
        )

        # Evaluate
        criterion = nn.MSELoss()
        test_loss, test_x_corr, test_y_corr = evaluate_model(model, test_loader, criterion, device)

        print(f"\nTest Set Results (Split {split+1}):")
        print(f"  Loss: {test_loss:.4f}")
        print(f"  Correlation X: {test_x_corr:.4f}")
        print(f"  Correlation Y: {test_y_corr:.4f}")
        print("--------------------------")

        # Store results
        test_results_list.append({
            'split': split+1,
            'loss': test_loss,
            'corr_x': test_x_corr,
            'corr_y': test_y_corr,
            'train_trials': len(train_trials),
            'val_trials': len(val_trials),
            'test_trials': len(test_trials),
            'train_samples': len(train_firing_rates),
            'val_samples': len(val_firing_rates),
            'test_samples': len(test_firing_rates)
        })

    # --- Aggregate and Report Final Results ---
    print("\n\n===== FINAL RESULTS ACROSS ALL SPLITS =====")
    if test_results_list:
        results_df = pd.DataFrame(test_results_list)
        print(results_df.round(4))  # Display results per split

        # Averages and spreads ignore the NaN entries of failed splits.
        avg_loss = results_df['loss'].mean(skipna=True)
        avg_corr_x = results_df['corr_x'].mean(skipna=True)
        avg_corr_y = results_df['corr_y'].mean(skipna=True)
        std_loss = results_df['loss'].std(skipna=True)
        std_corr_x = results_df['corr_x'].std(skipna=True)
        std_corr_y = results_df['corr_y'].std(skipna=True)

        print("\nAverage Performance:")
        print(f"  Average Test Loss: {avg_loss:.4f} +/- {std_loss:.4f}")
        print(f"  Average Test Correlation X: {avg_corr_x:.4f} +/- {std_corr_x:.4f}")
        print(f"  Average Test Correlation Y: {avg_corr_y:.4f} +/- {std_corr_y:.4f}")
    else:
        print("No splits completed successfully.")
    print("==========================================")
