import torch
import torch.nn as nn
import snntorch as snn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from snntorch import surrogate
import os
from torch.cuda.amp import autocast, GradScaler
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.signal import butter, sosfiltfilt
import numpy as np
import h5py
import warnings
import pandas as pd
import time

def compute_correlation(pred, target, min_std=1e-6):
    """
    Compute the Pearson correlation between predictions and targets per feature.

    Accepts torch tensors or array-likes. All leading axes are flattened so the
    last axis is treated as the feature axis, e.g.
    (batch, seq_len, features) -> (batch * seq_len, features).
    A 0-D or 1-D input is treated as a single feature.

    Args:
        pred: predicted values.
        target: ground-truth values; must have the same shape as ``pred``
            after flattening.
        min_std: a feature whose prediction or target standard deviation is
            below this value is considered constant and gets correlation 0.0.

    Returns:
        np.ndarray of shape (features,) with one correlation per feature.
        All zeros when fewer than two samples are available.

    Raises:
        ValueError: if ``pred`` and ``target`` shapes differ after flattening.
    """
    def _as_feature_matrix(a):
        # Convert to a (samples, features) float64 array. float64 avoids
        # precision loss from half-precision (autocast) outputs; going through
        # .double() also handles bfloat16, which NumPy cannot represent.
        if torch.is_tensor(a):
            a = a.detach().cpu().double().numpy()
        a = np.asarray(a, dtype=np.float64)
        if a.ndim <= 1:
            return a.reshape(-1, 1)
        return a.reshape(-1, a.shape[-1])

    pred = _as_feature_matrix(pred)
    target = _as_feature_matrix(target)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target shapes differ after flattening: {pred.shape} vs {target.shape}"
        )

    num_samples, num_dims = pred.shape
    corrs = np.zeros(num_dims)
    if num_samples < 2:
        return corrs

    for dim in range(num_dims):
        pred_dim = pred[:, dim]
        target_dim = target[:, dim]
        # Correlation is undefined for a constant series; report 0.0 instead of NaN.
        if np.std(pred_dim) < min_std or np.std(target_dim) < min_std:
            continue
        corrs[dim] = np.corrcoef(pred_dim, target_dim)[0, 1]

    return corrs


def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch generators for reproducible runs.

    If CUDA is available, all GPUs are seeded too and cuDNN is put into
    deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_data(
    mat_file_path,
    bin_width_s=0.064,
    spike_processing='binary'
):
    """
    Load and bin one session from an Indy-style HDF5 (.mat v7.3) file.

    Steps:
      1. Read the time vector ``t`` (must be strictly increasing) and derive
         the sampling rate ``fs`` from its mean step.
      2. Build a (samples, units) spike matrix. If ``spikes`` holds HDF5 object
         references, each referenced dataset is read as spike times and every
         spike is marked at the first sample of ``t`` at or after it (binary
         per sample). Otherwise ``spikes`` is used as a dense matrix.
      3. Low-pass filter ``cursor_pos`` (4th-order Butterworth, 10 Hz cutoff,
         zero-phase via ``sosfiltfilt``) and differentiate it w.r.t. ``t`` to
         get velocity.
      4. Bin spikes (sum) and velocity (mean) into non-overlapping windows of
         ``round(bin_width_s * fs)`` samples; a trailing partial window is
         dropped.

    Args:
        mat_file_path: path to the .mat file.
        bin_width_s: requested bin width in seconds.
        spike_processing: 'binary' (any spike in a bin -> 1), 'count'
            (spikes per bin) or 'rate' (spikes per second).

    Returns:
        On success a dict with:
          'X_processed': float32 tensor (num_bins, n_units),
          'y_raw': float32 tensor (num_bins, n_cursor_dims) of binned velocity,
          'bin_width_s': the requested bin width (unchanged for compatibility),
          'actual_bin_width_s': bin width actually used, bin_samples / fs.
        If there are too few samples, warns and returns {'error': ...}.

    Raises:
        ValueError: if ``t`` is not strictly increasing, the bin width rounds
            to zero samples, or ``spike_processing`` is unknown.
    """
    with h5py.File(mat_file_path, 'r') as f:
        t_vec = f['t'][()]
        if t_vec.ndim == 1:
            t_vec = t_vec[:, None]
        elif t_vec.shape[0] < t_vec.shape[1]:
            t_vec = t_vec.T
        t = t_vec.squeeze()

        if not np.all(np.diff(t) > 0):
            raise ValueError("Time vector 't' is not monotonically increasing.")

        fs = 1.0 / np.mean(np.diff(t))
        T_orig = t_vec.shape[0]

        ds = f['spikes']
        if h5py.check_dtype(ref=ds.dtype) is not None:
            refs = ds[()]
            # NOTE(review): for a 2-D reference array only the first row (or
            # column) is read; confirm the remaining entries are meant to be
            # ignored.
            if refs.ndim == 2 and refs.shape[1] != 1:
                unit_refs = refs[0, :]
            elif refs.ndim == 2 and refs.shape[0] != 1:
                unit_refs = refs[:, 0]
            else:
                unit_refs = refs.ravel()

            spikes = np.zeros((T_orig, unit_refs.shape[0]), dtype=np.float32)
            for i, r in enumerate(unit_refs):
                if not isinstance(r, h5py.Reference):
                    continue
                cell = f[r]
                # MATLAB v7.3 stores an empty array as a placeholder dataset
                # holding its dimensions and flagged with 'MATLAB_empty';
                # those values are not spike times.
                if cell.attrs.get('MATLAB_empty', 0):
                    continue
                times = cell[()].squeeze()
                # Each spike goes to the first sample at or after its time.
                idx = np.searchsorted(t, times, 'left')
                idx = idx[idx < T_orig]
                spikes[np.unique(idx), i] = 1.0
        else:
            spikes = ds[()]
            if spikes.shape[0] < spikes.shape[1]:
                spikes = spikes.T
            spikes = spikes.astype(np.float32)

        cp = f['cursor_pos'][()]
        if cp.shape[0] < cp.shape[1]:
            cp = cp.T

        # Zero-phase low-pass before differentiating, so the velocity is not
        # dominated by high-frequency position noise.
        sos = butter(4, 10, 'low', fs=fs, output='sos')
        cp_filtered = sosfiltfilt(sos, cp, axis=0)
        vel = np.gradient(cp_filtered, t, axis=0)

    bin_width = int(round(bin_width_s * fs))
    if bin_width < 1:
        raise ValueError(
            f"bin_width_s={bin_width_s} is shorter than one sample at fs={fs:.3f} Hz."
        )
    actual_bin_width_s = bin_width / fs

    # Spikes and velocity must cover the same samples; never bin past the
    # shorter of the two.
    n_samples = min(spikes.shape[0], vel.shape[0])
    if spikes.shape[0] != vel.shape[0]:
        warnings.warn(
            f"Spike ({spikes.shape[0]}) and cursor ({vel.shape[0]}) sample counts differ "
            f"in {mat_file_path}; truncating both to {n_samples}.",
            RuntimeWarning,
        )
    num_bins = n_samples // bin_width  # non-overlapping; trailing partial bin dropped

    if num_bins <= 1:
        warnings.warn(f"Insufficient data for binning (num_bins={num_bins}) in {mat_file_path}", RuntimeWarning)
        return {'error': 'Insufficient data for binning'}

    # Non-overlapping bins via a plain reshape. This gives the same windows as
    # a strided view, but can never read memory outside the arrays.
    used = num_bins * bin_width
    spikes_b_raw = (
        spikes[:used].reshape(num_bins, bin_width, spikes.shape[1]).sum(axis=1).astype(np.float32)
    )
    vel_b_raw = (
        vel[:used].reshape(num_bins, bin_width, vel.shape[1]).mean(axis=1).astype(np.float32)
    )

    if spike_processing == 'binary':
        if not np.all(np.isin(spikes_b_raw, [0, 1])):
            warnings.warn("Spike data is not binary, but 'binary' processing was selected. Coercing to binary.", UserWarning)
        X_processed = (spikes_b_raw > 0).astype(np.float32)
    elif spike_processing == 'count':
        X_processed = spikes_b_raw
    elif spike_processing == 'rate':
        X_processed = (spikes_b_raw / actual_bin_width_s).astype(np.float32)
    else:
        raise ValueError(f"Unknown spike_processing: {spike_processing}")

    return {
        'X_processed': torch.tensor(X_processed, dtype=torch.float32),
        'y_raw': torch.tensor(vel_b_raw, dtype=torch.float32),
        'bin_width_s': bin_width_s,
        'actual_bin_width_s': actual_bin_width_s,
    }


class SNNRegression(nn.Module):
    """
    Three-layer leaky integrate-and-fire (LIF) network for sequence regression.

    Per time step:
        fc1(x_t) + fc_rec(spk1_{t-1}) -> lif1 (beta=0.7) -> spk1
        fc2(spk1)                     -> lif2 (beta=0.7) -> spk2
        fc3(spk2)                     -> lif3 (beta=0.5, reset_mechanism="none")
    The prediction at each step is lif3's membrane potential, so the readout is
    continuous-valued; lif3's spikes are not used.

    NOTE: the order in which submodules are assigned below determines the order
    in which self.apply() initializes them (and therefore RNG consumption under
    a fixed seed), and the attribute names are the state_dict keys used by
    saved checkpoints.
    """
    def __init__(self, input_size: int, hidden_size: int = 128, output_size: int = 2):
        super(SNNRegression, self).__init__()
        # One surrogate-gradient function instance is shared by all LIF layers.
        spike_grad = surrogate.fast_sigmoid()

        # First layer: feedforward input plus recurrent spikes from the previous step.
        # (Per the original comments, intended to mirror ZENODO_SCRIPT2.py.)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)  # beta previously 0.9

        # Second layer halves the width.
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)  # beta previously 0.9

        # Readout: a non-resetting LIF whose membrane potential is the output.
        self.fc3 = nn.Linear(hidden_size // 2, output_size)
        self.lif3 = snn.Leaky(
            beta=0.5,
            spike_grad=spike_grad,
            init_hidden=False,
            threshold=1.0,
            reset_mechanism="none"
        )
        
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize one submodule; applied to every submodule via ``self.apply``.

        Only ``nn.Linear`` layers are touched: Xavier-uniform weights, biases of
        fc1/fc2 set to -0.1 (previously 0.0), any other bias set to 0.
        fc_rec has no bias.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                # Membership test compares module identity (nn.Module does not override __eq__).
                if module in (self.fc1, self.fc2):
                    nn.init.constant_(module.bias, -0.1)  # previously 0.0
                else:
                    nn.init.zeros_(module.bias)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network over every time step of ``x``.

        Args:
            x: input indexed as ``x[:, t, :]``, i.e. (batch, seq_len, input_size).

        Returns:
            lif3's membrane potential after each step, stacked along dim 1:
            (batch, seq_len, output_size).

        Notes:
            All states start from zero on every call. Every state is detached
            at the top of each step, so no gradient flows from step t back into
            earlier steps. Training is therefore one-step truncated rather than
            full backpropagation through time. Per the original comments, this
            mirrors the "online" / ZENODO_SCRIPT2.py version.
        """
        batch_size = x.size(0)
        # Zero-initialized states for this batch of sequences.
        spk1_rec = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        mem1 = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, device=x.device)
        mem3 = torch.zeros(batch_size, self.fc3.out_features, device=x.device)

        outputs = []
        for t in range(x.size(1)):
            inp = x[:, t, :]
            
            # Cut the autograd graph between steps (see Notes in the docstring).
            mem1 = mem1.detach()
            mem2 = mem2.detach()
            spk1_rec = spk1_rec.detach()
            mem3 = mem3.detach()
            
            # Feedforward drive plus recurrent drive from the previous step's spikes.
            cur1 = self.fc1(inp) + self.fc_rec(spk1_rec)
            spk1, mem1 = self.lif1(cur1, mem1)

            # Keep this step's layer-1 spikes for the recurrent input at t + 1.
            spk1_rec = spk1
            
            # Second hidden layer
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            # Readout: lif3's spikes (`out`) are unused; its membrane is the prediction.
            cur3 = self.fc3(spk2)
            out, mem3 = self.lif3(cur3, mem3)
            outputs.append(mem3)

        return torch.stack(outputs, dim=1)



class SpikeDataset(Dataset):
    """
    Sliding-window dataset over time-aligned spike and target arrays.

    Item ``i`` is ``(spike_data[i:i + L], targets[i:i + L])`` with
    ``L = sequence_length``. Windows advance by one row, so consecutive items
    overlap by ``L - 1`` rows. Only full windows are produced.

    Args:
        spike_data: array-like indexed along its first axis (time).
        targets: array-like aligned with ``spike_data`` along the first axis.
        sequence_length: rows per window; must be >= 1.

    Raises:
        ValueError: if ``sequence_length`` < 1.
    """

    def __init__(self, spike_data, targets, sequence_length=100):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.spike_data = spike_data
        self.targets = targets
        self.sequence_length = sequence_length
        # Only rows present in both arrays can form a full (input, target) window.
        n_rows = min(spike_data.shape[0], targets.shape[0])
        # Clamp at 0: with fewer rows than sequence_length there are no full
        # windows, and a negative count would make len() raise ValueError.
        self.num_sequences = max(0, n_rows - sequence_length + 1)

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        # Support negative indices like a sequence; raise IndexError past the
        # end so plain iteration over the dataset terminates.
        if idx < 0:
            idx += self.num_sequences
        if not 0 <= idx < self.num_sequences:
            raise IndexError(f"SpikeDataset index out of range (len={self.num_sequences})")
        end = idx + self.sequence_length
        return self.spike_data[idx:end], self.targets[idx:end]

def evaluate_model(model, data_loader, criterion, device):
    """
    Evaluate ``model`` over every batch of ``data_loader`` without gradients.

    Args:
        model: module mapping a batch of input sequences to predictions.
        data_loader: iterable of (inputs, targets) tensor batches.
        criterion: loss function. Assumed to return the mean over its batch
            (e.g. nn.MSELoss with default reduction) — TODO confirm for other criteria.
        device: torch.device to run on; autocast is enabled only for CUDA.

    Returns:
        (avg_loss, x_corr, y_corr):
          - avg_loss: loss averaged over samples; each batch is weighted by its
            size so a short final batch does not skew the mean.
          - x_corr, y_corr: Pearson correlations of output dims 0 and 1 over all
            flattened time steps, or NaN if the dimension does not exist.
        An empty loader yields (nan, nan, nan) with a RuntimeWarning.

    NOTE(review): when the loader yields overlapping windows (as SpikeDataset
    does with stride 1), each time bin contributes to the correlation several
    times, including at the zero-state start of each window.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for batch_spike, batch_target in data_loader:
            batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)
            with autocast(enabled=(device.type == 'cuda')):
                outputs = model(batch_spike)
                loss = criterion(outputs, batch_target)

            batch_n = batch_spike.size(0)
            total_loss += loss.item() * batch_n
            n_samples += batch_n

            # Autocast may return half precision; upcast before the correlation step.
            if outputs.dtype in (torch.float16, torch.bfloat16):
                outputs = outputs.float()
            all_outputs.append(outputs.cpu())
            all_targets.append(batch_target.cpu())

    if n_samples == 0:
        warnings.warn("evaluate_model received an empty data loader; returning NaN metrics.", RuntimeWarning)
        return np.nan, np.nan, np.nan

    avg_loss = total_loss / n_samples
    all_outputs_tensor = torch.cat(all_outputs, dim=0)
    all_targets_tensor = torch.cat(all_targets, dim=0)

    correlations = compute_correlation(all_outputs_tensor, all_targets_tensor)
    x_corr = correlations[0] if correlations.size > 0 else np.nan
    y_corr = correlations[1] if correlations.size > 1 else np.nan

    return avg_loss, x_corr, y_corr

def train_snn_bptt(train_dataset, val_dataset, input_size, epochs=50, batch_size=2048, checkpoint_path="snn_bptt_model.pth"):
    """
    Train an ``SNNRegression`` model and checkpoint its best validation epoch.

    Setup:
      - hidden_size=256, Adam (lr=1e-3, weight_decay=1e-5), MSE loss;
      - gradient-norm clipping at 1.0;
      - mixed precision (autocast + GradScaler) when CUDA is available;
      - ``ReduceLROnPlateau`` halves the learning rate after 3 epochs without
        validation-loss improvement.

    After each epoch the model is evaluated on ``val_dataset``. Whenever the mean
    of the x/y validation correlations improves, the state dict is saved to
    ``checkpoint_path``. Training stops early after 7 epochs without improvement.

    Args:
        train_dataset: dataset of (input sequence, target sequence) pairs.
        val_dataset: validation dataset with the same item structure.
        input_size: number of input features per time step.
        epochs: maximum number of epochs.
        batch_size: batch size for both loaders.
        checkpoint_path: file the best state dict is written to.

    Returns:
        The best mean validation correlation reached. This is ``-inf`` if no
        epoch produced a comparable (non-NaN) value, in which case nothing was saved.

    Raises:
        ValueError: if either dataset is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    use_cuda = device.type == 'cuda'

    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"train_snn_bptt needs non-empty datasets "
            f"(train={len(train_dataset)}, val={len(val_dataset)})."
        )

    # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_cuda)

    model = SNNRegression(input_size=input_size, hidden_size=256).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3, verbose=True)
    criterion = nn.MSELoss()

    scaler = GradScaler(enabled=use_cuda)

    # -inf so the first comparable validation correlation is always saved
    # (a correlation of exactly -1.0 would never beat an initial -1.0).
    best_val_corr = float('-inf')
    epochs_no_improve = 0
    patience = 7

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for batch_spike, batch_target in train_loader:
            batch_spike, batch_target = batch_spike.to(device), batch_target.to(device)
            optimizer.zero_grad()

            with autocast(enabled=use_cuda):
                outputs = model(batch_spike)
                loss = criterion(outputs, batch_target)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)
        val_loss, val_x_corr, val_y_corr = evaluate_model(model, val_loader, criterion, device)
        avg_val_corr = (val_x_corr + val_y_corr) / 2.0

        scheduler.step(val_loss)

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Avg Corr: {avg_val_corr:.4f}")

        # A NaN correlation compares False, so it never counts as an improvement.
        if avg_val_corr > best_val_corr:
            best_val_corr = avg_val_corr
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  New best model saved with Val Avg Corr: {best_val_corr:.4f}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {patience} epochs with no improvement.")
            break

    return best_val_corr

def run_within_session_bptt_evaluation(num_sessions=5, snn_epochs=50, sequence_length=25, batch_size=512):
    """
    Train and test an SNN velocity decoder separately on randomly chosen sessions.

    For each of ``num_sessions`` .mat files sampled from ``~/scratch/zenodo_dataset``:
      1. load binned spike counts and cursor velocity with ``load_data``;
      2. keep the first 96 units and pair each spike bin with the velocity
         ``round(0.1 s / bin width)`` bins later;
      3. split chronologically (no shuffling): the final ~20% is the test set,
         and the final ~15% of the remainder is the validation set;
      4. drop units whose training-split variance is <= 1e-6;
      5. standardize velocity using training-split statistics only;
      6. train with ``train_snn_bptt`` and evaluate the saved best checkpoint
         on the test split.

    Sessions that cannot be loaded, lagged, split, or windowed are reported and
    skipped. If training fails or no checkpoint is produced, the session is
    recorded with NaN metrics. Per-session results are saved to
    ``results_snn_bptt_within_session.csv`` and summarized on stdout.

    Args:
        num_sessions: number of .mat files to sample (via the ``random`` module).
        snn_epochs: maximum number of training epochs per session.
        sequence_length: window length in bins for the SNN datasets.
        batch_size: batch size for training and test evaluation.
    """
    print(f"\n===== Running Within-Session BPTT SNN Evaluation ({num_sessions} Sessions) =====")
    basepath = os.path.expanduser('~/scratch/zenodo_dataset')
    train_split_ratio = 0.8
    val_split_ratio = 0.15  # as a fraction of the *non-test* data
    neurons_to_use = 96
    lag_s = 0.1  # spikes are paired with the velocity about this many seconds later

    all_results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Sort first: os.listdir order depends on the filesystem, so without it
        # the seeded random.sample would not pick the same sessions everywhere.
        all_mat_files = sorted(f for f in os.listdir(basepath) if f.endswith('.mat'))
        if len(all_mat_files) < num_sessions:
            print(f"Error: Found only {len(all_mat_files)} files, need at least {num_sessions}.")
            return
        selected_files = random.sample(all_mat_files, num_sessions)
    except FileNotFoundError:
        print(f"Error: Basepath not found: {basepath}")
        return
    except Exception as e:
        print(f"Error listing files: {e}")
        return

    for session_idx, filename in enumerate(selected_files):
        print(f"\n--- Session {session_idx + 1}/{num_sessions} ({filename}) ---")

        try:
            data = load_data(
                mat_file_path=os.path.join(basepath, filename),
                spike_processing='count',
            )
            if 'error' in data:
                print(f"  Skipping {filename} due to load error: {data['error']}")
                continue

            X_full_spikes = data.get('X_processed')[:, :neurons_to_use].numpy()
            y_full_raw = data.get('y_raw').numpy()
        except Exception as e:
            print(f"  Error loading session file {filename}: {e}")
            continue

        bin_width_s = data['bin_width_s']
        lag_bins = int(round(lag_s / bin_width_s))
        n_lagged = X_full_spikes.shape[0] - lag_bins
        if n_lagged <= 0:
            print(f"  Skipping session: a lag of {lag_bins} bins leaves no samples.")
            continue

        # Pair spikes at bin i with velocity at bin i + lag_bins. An explicit end
        # index is used because X[:-lag_bins] would be empty when lag_bins == 0.
        X_lagged_spikes = X_full_spikes[:n_lagged]
        y_lagged = y_full_raw[lag_bins:lag_bins + n_lagged]

        try:
            # Chronological splits (shuffle=False, so random_state has no effect):
            # first hold out the test set, then carve validation from the rest.
            X_train_val_raw, X_test_raw, y_train_val_raw, y_test_raw = train_test_split(
                X_lagged_spikes, y_lagged, test_size=(1 - train_split_ratio), random_state=42, shuffle=False
            )
            X_train_raw, X_val_raw, y_train_raw, y_val_raw = train_test_split(
                X_train_val_raw, y_train_val_raw, test_size=val_split_ratio, random_state=42, shuffle=False
            )
        except ValueError as e:
            print(f"  Skipping session: cannot split {n_lagged} samples: {e}")
            continue

        # Every split must hold at least one full window, or the datasets/loaders
        # downstream are empty.
        short_splits = [
            name
            for name, split in (("train", X_train_raw), ("val", X_val_raw), ("test", X_test_raw))
            if split.shape[0] < sequence_length
        ]
        if short_splits:
            print(f"  Skipping session: {', '.join(short_splits)} split(s) shorter than sequence_length={sequence_length}.")
            continue

        # Prune neurons with zero variance in the training set
        variance = np.var(X_train_raw, axis=0)
        keep = variance > 1e-6  # This threshold is fine for counts too
        if np.sum(~keep) > 0:
            print(f"  Pruning {np.sum(~keep)}/{len(keep)} neurons with zero variance.")
        if not keep.any():
            print("  Skipping session: no neurons with non-zero variance in the training split.")
            continue

        # Spike COUNTS are used directly (no input scaling, to match the online version).
        X_train_final = X_train_raw[:, keep]
        X_val_final = X_val_raw[:, keep]
        X_test_final = X_test_raw[:, keep]
        input_size = X_train_final.shape[1]

        # Normalize outputs (velocities) based on the training set ONLY
        velocity_scaler = StandardScaler().fit(y_train_raw)
        y_train_scaled = velocity_scaler.transform(y_train_raw)
        y_val_scaled = velocity_scaler.transform(y_val_raw)
        y_test_scaled = velocity_scaler.transform(y_test_raw)

        train_dataset = SpikeDataset(X_train_final, y_train_scaled, sequence_length)
        val_dataset = SpikeDataset(X_val_final, y_val_scaled, sequence_length)
        test_dataset = SpikeDataset(X_test_final, y_test_scaled, sequence_length)

        print(f"  Datasets created: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
        if len(test_dataset) == 0:
            print("  Skipping session due to empty test set.")
            continue

        checkpoint_path_session = f"snn_bptt_session_{filename.replace('.mat', '')}.pth"
        # Remove a checkpoint left by an earlier run. train_snn_bptt only writes
        # on validation improvement, so a stale file could otherwise be
        # evaluated as if this run had produced it.
        if os.path.exists(checkpoint_path_session):
            os.remove(checkpoint_path_session)

        try:
            train_snn_bptt(
                train_dataset, val_dataset,
                input_size=input_size,
                epochs=snn_epochs,
                batch_size=batch_size,
                checkpoint_path=checkpoint_path_session
            )
        except Exception as e:
            # Keep going with the remaining sessions; this one is recorded with
            # NaN metrics unless a checkpoint was saved before the failure.
            print(f"  ERROR training model: {e}")

        print(f"\n--- Evaluating on Test Set for Session {session_idx+1} ---")
        model = SNNRegression(input_size=input_size, hidden_size=256).to(device)
        test_loss, test_x_corr, test_y_corr = np.nan, np.nan, np.nan

        if os.path.exists(checkpoint_path_session):
            print(f"  Loading best model from {checkpoint_path_session}")
            try:
                model.load_state_dict(torch.load(checkpoint_path_session, map_location=device))
                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                criterion = nn.MSELoss()
                test_loss, test_x_corr, test_y_corr = evaluate_model(model, test_loader, criterion, device)
                print(f"  Test Set Results: Loss={test_loss:.4f}, CorrX={test_x_corr:.4f}, CorrY={test_y_corr:.4f}")
            except Exception as e:
                print(f"  ERROR evaluating model: {e}")
        else:
            print(f"  ERROR: Checkpoint {checkpoint_path_session} not found for evaluation!")

        all_results.append({
            'session': filename,
            'test_loss': test_loss,
            'test_corr_x': test_x_corr,
            'test_corr_y': test_y_corr,
        })

    print("\n\n===== FINAL SNN BPTT RESULTS (WITHIN-SESSION) =====")
    if all_results:
        results_df = pd.DataFrame(all_results)
        results_df.to_csv("results_snn_bptt_within_session.csv", index=False)
        print("Full results saved to results_snn_bptt_within_session.csv")

        print("\n--- Results per Session ---")
        print(results_df)

        print("\n--- Average Performance ---")
        print(f"  Average Test Loss:        {results_df['test_loss'].mean():.4f} +/- {results_df['test_loss'].std():.4f}")
        print(f"  Average Test Correlation X: {results_df['test_corr_x'].mean():.4f} +/- {results_df['test_corr_x'].std():.4f}")
        print(f"  Average Test Correlation Y: {results_df['test_corr_y'].mean():.4f} +/- {results_df['test_corr_y'].std():.4f}")
    else:
        print("No sessions completed successfully.")
    print("======================================================")

if __name__ == "__main__":
    # Seed before anything random happens: session sampling (random.sample)
    # and weight initialization both draw from the seeded generators.
    set_seed(seed=42)
    run_within_session_bptt_evaluation(snn_epochs=50, num_sessions=10)