# -*- coding: utf-8 -*-
"""
UltraLIF Training - Full benchmark on static (MNIST/Fashion/CIFAR10) and neuromorphic (N-MNIST, DVS Gesture, CIFAR10-DVS, SHD, SSC) datasets

Usage:
    python train.py --model ultralif --epochs 100
    python train.py --model all --epochs 50 --dataset mnist
    python train.py --model all --dataset mnist --track-spikes  # with spike counting
"""
import sys
import argparse
import logging
from pathlib import Path
from datetime import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm


# =============================================================================
# LOGGING SETUP
# =============================================================================

class TeeLogger:
    """Write to both the console and a log file.

    Intended to stand in for ``sys.stdout``: every ``write`` goes to the stream
    that was ``sys.stdout`` when the logger was created and to ``log_path``
    (opened for writing, UTF-8, truncated). The file is flushed after every
    write so the log survives a crash.

    Attributes not defined here (``isatty``, ``encoding``, ``fileno``, ...) are
    delegated to the wrapped console stream, so code that probes
    ``sys.stdout`` for them keeps working. After ``close()`` further writes
    still reach the console instead of raising.
    """
    def __init__(self, log_path):
        self.terminal = sys.stdout
        self.log = open(log_path, 'w', encoding='utf-8')

    def write(self, message):
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()

    def flush(self):
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes missing on TeeLogger itself. Read
        # ``terminal`` via __dict__ so a partially-initialised instance
        # cannot recurse back into __getattr__.
        terminal = self.__dict__.get('terminal')
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Make the project's ``src/`` directory importable so the local ``datasets``
# module (providing rate_encode and set_seed) can be imported. It is inserted
# at index 0 of sys.path, so it is searched before installed packages.
# NOTE: the module-level name ``datasets`` still refers to torchvision.datasets
# (imported above); this from-import only binds rate_encode and set_seed.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from datasets import rate_encode, set_seed

# TPU support (optional): TPU_AVAILABLE records whether torch_xla imported.
try:
    import torch_xla.core.xla_model as xm
    TPU_AVAILABLE = True
except ImportError:
    TPU_AVAILABLE = False


# =============================================================================
# SPIKE COUNTING UTILITIES
# =============================================================================

def count_spikes_epoch(model, loader, device, timesteps, neuromorphic=False, dtype=torch.float32):
    """
    Estimate the average spike rate of ``model`` over every sample in ``loader``.

    Works with SNN, DeepSNN, and ConvSNN by reading the ``last_spike_rate``
    attribute those models set during their forward pass. Each batch's rate
    is weighted by its batch size, so a smaller final batch does not skew the
    dataset average. The model's train/eval mode is restored on return.

    Args:
        model: Module that sets ``last_spike_rate`` in ``forward``.
        loader: Iterable of ``(x, y)`` batches with a ``.dataset`` attribute.
        device: Device the inputs are moved to.
        timesteps: Timesteps per sample; only used for ``total_spikes``.
        neuromorphic: Unused; kept for call-site compatibility.
        dtype: Floating dtype the inputs are cast to.

    Returns:
        spike_rate: Average spike rate (spikes per neuron per timestep),
            0.0 if no batch reported a rate.
        total_spikes: Rough estimate ``spike_rate * len(dataset) * timesteps``
            (per neuron; not multiplied by the neuron count).
    """
    was_training = model.training
    model.train(False)
    weighted_sum = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for x, _ in loader:
                x = x.to(device=device, dtype=dtype)
                model(x)  # forward pass sets model.last_spike_rate
                rate_val = model.last_spike_rate
                if rate_val is None:
                    continue
                rate = rate_val.item() if hasattr(rate_val, 'item') else float(rate_val)
                batch_size = x.shape[0]
                weighted_sum += rate * batch_size
                n_samples += batch_size
    finally:
        model.train(was_training)

    spike_rate = weighted_sum / n_samples if n_samples else 0.0
    total_spikes = spike_rate * len(loader.dataset) * timesteps
    return spike_rate, total_spikes


def compute_energy_proxy(spike_rate, num_neurons, timesteps, ops_per_spike=1.0, baseline_rate=0.5):
    """
    Compute a relative energy proxy from a spike rate.

    Energy ∝ spike_rate × ops_per_spike, normalised by ``baseline_rate``
    (0.5 by default), so a model firing at the baseline rate with one op
    per spike scores 1.0.

    Args:
        spike_rate: Average spikes per neuron per timestep.
        num_neurons: Unused; kept for call-site compatibility.
        timesteps: Unused; kept for call-site compatibility.
        ops_per_spike: Relative cost of a single spike (default 1.0).
        baseline_rate: Reference rate to normalise against. If it is not
            positive, no normalisation is applied.

    Returns:
        The normalised energy proxy.
    """
    energy = spike_rate * ops_per_spike
    if baseline_rate > 0:
        energy = energy / baseline_rate
    return energy

class LIF(nn.Module):
    """
    Leaky Integrate-and-Fire with surrogate gradient (Neftci et al., 2019).

    Membrane dynamics:  v(t+1) = tau * v(t) + I(t)
    Spike:              z(t)   = sigma(beta * (v(t) - theta))
    Reset:              v(t)  *= (1 - z(t))   [soft reset]

    Args:
        dim: Number of neurons; ``reset`` allocates state of shape (b, dim).
        tau: Fixed membrane leak factor.
        thresh: Firing threshold theta.
        beta: Sharpness of the sigmoid surrogate.

    ``reset`` must be called before the first ``forward`` of a sequence. The
    membrane state is cast to the input dtype each step, so the returned
    spikes match the input dtype (e.g. when the model runs in bfloat16).
    """
    def __init__(self, dim, tau=0.9, thresh=0.5, beta=10.0):
        super().__init__()
        self.dim, self.tau, self.thresh, self.beta = dim, tau, thresh, beta
        self.v = None

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        # reset() allocates in the default dtype; without this cast the state
        # would promote reduced-precision inputs to float32.
        v = self.tau * self.v.to(dtype=x.dtype) + x               # leak + integrate
        spike = torch.sigmoid(self.beta * (v - self.thresh))      # surrogate spike
        self.v = v * (1 - spike)                                  # soft reset
        return spike

class PLIF(nn.Module):
    """
    Parametric LIF with learnable membrane time constant (Fang et al., ICCV 2021).

    Same as LIF but tau is a learnable parameter constrained to (0, 1) via
    sigmoid. The raw parameter ``_tau`` is initialised to logit(init_tau), so
    the neuron really starts with tau == init_tau.

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, init_tau=0.9, thresh=0.5, beta=10.0):
        super().__init__()
        self.dim, self.thresh, self.beta = dim, thresh, beta
        # Learnable tau stored pre-sigmoid; logit() inverts the sigmoid used by
        # `tau`. eps keeps init_tau of exactly 0 or 1 finite.
        self._tau = nn.Parameter(torch.logit(torch.tensor(float(init_tau)), eps=1e-6))
        self.v = None

    @property
    def tau(self):
        return torch.sigmoid(self._tau)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x               # leak + integrate
        spike = torch.sigmoid(self.beta * (v - self.thresh))      # surrogate spike
        self.v = v * (1 - spike)                                  # soft reset
        return spike

class AdaLIF(nn.Module):
    """
    Adaptive LIF (Bellec et al. NeurIPS 2018 - LSNN).

    The threshold DYNAMICALLY increases after each spike and decays back:
        B(t) = b0 + beta * b(t)
        b(t+1) = rho * b(t) + (1 - rho) * z(t)

    ``tau_adapt`` is used directly as the per-step decay factor rho (i.e. it
    is expected to already equal exp(-dt/tau_a); no exponential is taken here).

    Both state tensors (membrane ``v`` and adaptation ``b``) are cast to the
    input dtype each step so the neuron works in reduced precision.
    """
    def __init__(self, dim, tau=0.9, base_thresh=0.5, beta_adapt=0.1, tau_adapt=0.9, surrogate_beta=10.0):
        super().__init__()
        self.dim = dim
        self.tau = tau  # membrane leak factor
        self.base_thresh = base_thresh  # b0 in paper
        self.beta_adapt = beta_adapt  # beta in paper (adaptation strength)
        self.rho = tau_adapt  # decay of the adaptation variable, used as-is
        self.surrogate_beta = surrogate_beta
        self.v = None
        self.b = None  # adaptation variable

    def reset(self, batch, device):
        """Zero membrane and adaptation state for ``batch`` samples on ``device``."""
        self.v = torch.zeros(batch, self.dim, device=device)
        self.b = torch.zeros(batch, self.dim, device=device)

    def forward(self, x):
        # Membrane dynamics
        v = self.tau * self.v.to(dtype=x.dtype) + x
        b = self.b.to(dtype=x.dtype)
        # Adaptive threshold: B(t) = b0 + beta * b(t)
        thresh = self.base_thresh + self.beta_adapt * b
        # Spike with surrogate gradient
        spike = torch.sigmoid(self.surrogate_beta * (v - thresh))
        # Update adaptation: b(t+1) = rho * b(t) + (1 - rho) * z(t)
        self.b = self.rho * b + (1 - self.rho) * spike
        # Soft reset
        self.v = v * (1 - spike)
        return spike

class FullPLIF(nn.Module):
    """
    Fully Parametric LIF with learnable tau and threshold.

    Extends PLIF (Fang et al., 2021) by also making the firing threshold
    a learnable parameter, both constrained to (0, 1) via sigmoid. The raw
    parameters are initialised to logit(init_tau) / logit(init_thresh), so
    the neuron really starts at tau == init_tau and thresh == init_thresh.

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, init_tau=0.9, init_thresh=0.5, beta=10.0):
        super().__init__()
        self.dim, self.beta = dim, beta
        # Stored pre-sigmoid; logit() inverts the sigmoid in the properties.
        self._tau = nn.Parameter(torch.logit(torch.tensor(float(init_tau)), eps=1e-6))
        self._thresh = nn.Parameter(torch.logit(torch.tensor(float(init_thresh)), eps=1e-6))
        self.v = None

    @property
    def tau(self):
        return torch.sigmoid(self._tau)

    @property
    def thresh(self):
        return torch.sigmoid(self._thresh)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x               # leak + integrate
        spike = torch.sigmoid(self.beta * (v - self.thresh))      # surrogate spike
        self.v = v * (1 - spike)                                  # soft reset
        return spike

class UltraLIF(nn.Module):
    """
    UltraDLIF (spatial, 3-term LSE from diffusion PDE). Fixed tau.

    Each step integrates ``v = tau * v + x``. It then takes an eps-scaled
    log-sum-exp over every neuron and its left/right neighbours along the
    last dimension (periodic, via ``torch.roll``) and fires
    ``sigmoid(v_max / eps)``. The temperature ``eps`` is learnable: it is
    stored as a log and clamped to [0.1, 20].

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, tau=0.9, init_eps=1.0):
        super().__init__()
        self.dim, self.tau = dim, tau
        self._log_eps = nn.Parameter(torch.tensor(float(init_eps)).log())
        self.v = None

    @property
    def eps(self):
        return self._log_eps.exp().clamp(0.1, 20.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x
        # Spatial neighbours (diffusion PDE discretization)
        v_l, v_r = torch.roll(v, 1, -1), torch.roll(v, -1, -1)
        # 3-term LSE: (left neighbour, centre, right neighbour)
        stack = torch.stack([v_l, v, v_r], -1)
        eps = self.eps
        m = stack.max(-1, keepdim=True).values
        # m + eps*LSE((s - m)/eps) == eps*log(sum(exp(s/eps))), computed stably
        v_max = m.squeeze(-1) + eps * torch.logsumexp((stack - m) / eps, -1)
        spike = torch.sigmoid(v_max / eps)
        self.v = v * (1 - spike)   # soft reset
        return spike

class UltraPLIF(nn.Module):
    """
    UltraDPLIF (spatial, 3-term LSE from diffusion PDE). Learnable tau.

    Same update as UltraLIF, but tau = sigmoid(_tau) is learnable. ``_tau``
    is initialised to logit(init_tau), so the neuron starts at
    tau == init_tau. ``eps`` is learnable, stored as a log and clamped to
    [0.1, 20].

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, init_tau=0.9, init_eps=1.0):
        super().__init__()
        self.dim = dim
        # Stored pre-sigmoid; logit() inverts the sigmoid in `tau`.
        self._tau = nn.Parameter(torch.logit(torch.tensor(float(init_tau)), eps=1e-6))
        self._log_eps = nn.Parameter(torch.tensor(float(init_eps)).log())
        self.v = None

    @property
    def tau(self):
        return torch.sigmoid(self._tau)

    @property
    def eps(self):
        return self._log_eps.exp().clamp(0.1, 20.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x
        # Periodic left/right neighbours along the neuron dimension
        v_l, v_r = torch.roll(v, 1, -1), torch.roll(v, -1, -1)
        stack = torch.stack([v_l, v, v_r], -1)
        eps = self.eps
        m = stack.max(-1, keepdim=True).values
        v_max = m.squeeze(-1) + eps * torch.logsumexp((stack - m) / eps, -1)
        spike = torch.sigmoid(v_max / eps)
        self.v = v * (1 - spike)   # soft reset
        return spike


class UltraTLIF(nn.Module):
    """
    UltraLIF (temporal, 2-term LSE from LIF ODE). Fixed tau.

    From ultradiscretization of Euler-discretized LIF:
        v^(t+1) = τ·v^(t) + I^(t)

    Ultradiscretization gives:
        V^(t+1) = LSE_ε(V^(t) + log τ, I^(t))

    LSE over 2 terms: (decayed membrane, input). ``eps`` is learnable,
    stored as a log and clamped to [0.1, 20].

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, tau=0.9, init_eps=1.0):
        super().__init__()
        self.dim = dim
        self.tau = tau
        self.log_tau = float(torch.tensor(tau).log())
        self._log_eps = nn.Parameter(torch.tensor(float(init_eps)).log())
        self.v = None

    @property
    def eps(self):
        return self._log_eps.exp().clamp(0.1, 20.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        eps = self.eps
        v_prev = self.v.to(dtype=x.dtype)

        # The two LSE terms: (V^(t) + log τ, I^(t))
        stack = torch.stack([v_prev + self.log_tau, x], dim=-1)

        # Stable eps-scaled LSE over the 2 terms (last dim)
        m = stack.max(-1, keepdim=True).values
        v = m.squeeze(-1) + eps * torch.logsumexp((stack - m) / eps, dim=-1)

        # Spike probability
        spike = torch.sigmoid(v / eps)

        # Soft reset
        self.v = v * (1 - spike)
        return spike


class UltraTPLIF(nn.Module):
    """
    UltraPLIF (temporal, 2-term LSE from LIF ODE). Learnable tau.

    From ultradiscretization of Euler-discretized LIF:
        V^(t+1) = LSE_ε(V^(t) + log τ, I^(t))

    tau is stored as a log and exposed clamped to [0.01, 0.99]. ``forward``
    uses that clamped value, so the decay factor cannot exceed the bounds
    even if ``_log_tau`` drifts during training. ``eps`` is learnable, stored
    as a log and clamped to [0.1, 20].

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, init_tau=0.9, init_eps=1.0):
        super().__init__()
        self.dim = dim
        self._log_tau = nn.Parameter(torch.tensor(float(init_tau)).log())
        self._log_eps = nn.Parameter(torch.tensor(float(init_eps)).log())
        self.v = None

    @property
    def tau(self):
        return self._log_tau.exp().clamp(0.01, 0.99)

    @property
    def eps(self):
        return self._log_eps.exp().clamp(0.1, 20.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        eps = self.eps
        log_tau = self.tau.log()   # clamped tau, consistent with the `tau` property
        v_prev = self.v.to(dtype=x.dtype)

        # The two LSE terms: (V^(t) + log τ, I^(t))
        stack = torch.stack([v_prev + log_tau, x], dim=-1)

        # Stable eps-scaled LSE over the 2 terms (last dim)
        m = stack.max(-1, keepdim=True).values
        v = m.squeeze(-1) + eps * torch.logsumexp((stack - m) / eps, dim=-1)

        # Spike probability
        spike = torch.sigmoid(v / eps)

        # Soft reset
        self.v = v * (1 - spike)
        return spike


def dspike_fn(x, b):
    """
    Differentiable spike of Li et al. (NeurIPS 2021), Eq. 12.

    Evaluates

        DSpike(x, b) = [tanh(b(x - 0.5)) + tanh(b/2)] / [2·tanh(b/2)]

    with a 1e-8 guard on the denominator, then clamps the result to [0, 1]
    (so x < 0 maps to 0 and x > 1 maps to 1).

    Args:
        x: Normalised membrane potential tensor (ideally in [0, 1]).
        b: Sharpness/temperature tensor (expected positive).

    Returns:
        Tensor broadcast from ``x`` and ``b`` with values in [0, 1].
    """
    half_b_tanh = torch.tanh(b / 2.0)
    numerator = torch.tanh(b * (x - 0.5)) + half_b_tanh
    denominator = 2.0 * half_b_tanh + 1e-8
    return (numerator / denominator).clamp(0.0, 1.0)


class DSpike(nn.Module):
    """
    Differentiable Spike (Li et al. NeurIPS 2021) baseline.

    Uses tanh-based smooth spike function with learnable temperature b.
    DSpike(x, b) = [tanh(b(x - 0.5)) + tanh(b/2)] / [2·tanh(b/2)]

    Key difference from UltraLIF:
    - DSpike: tanh-based smooth function, normalized to [0,1]
    - UltraLIF: sigmoid((v-θ)/ε) derived from ultradiscretization

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, tau=0.9, init_b=4.0, thresh=0.5):
        super().__init__()
        self.dim, self.tau, self.thresh = dim, tau, thresh
        # Learnable temperature; kept positive (and in [0.5, 50]) by `k`
        self._b = nn.Parameter(torch.tensor(float(init_b)))
        self.v = None

    @property
    def k(self):
        """Temperature b = clamp(softplus(_b), 0.5, 50). Named 'k' for compatibility."""
        return torch.nn.functional.softplus(self._b).clamp(0.5, 50.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x
        # Map v in [0, 2*thresh] onto [0, 1] so the DSpike midpoint (0.5)
        # corresponds to v == thresh.
        x_norm = v / (2.0 * self.thresh)
        spike = dspike_fn(x_norm, self.k)
        self.v = v * (1 - spike)  # soft reset
        return spike


class DSpikePlus(nn.Module):
    """
    DSpike+: DSpike (Li et al., NeurIPS 2021) with learnable tau.

    Extends DSpike by making the membrane time constant learnable,
    providing a fair comparison to PLIF-family models. tau = sigmoid(_tau),
    and ``_tau`` is initialised to logit(init_tau) so the neuron starts at
    tau == init_tau.

    The membrane state is cast to the input dtype each step, so the neuron
    also works when the model runs in reduced precision.
    """
    def __init__(self, dim, init_tau=0.9, init_b=4.0, thresh=0.5):
        super().__init__()
        self.dim, self.thresh = dim, thresh
        # Stored pre-sigmoid; logit() inverts the sigmoid in `tau`.
        self._tau = nn.Parameter(torch.logit(torch.tensor(float(init_tau)), eps=1e-6))
        self._b = nn.Parameter(torch.tensor(float(init_b)))
        self.v = None

    @property
    def tau(self):
        return torch.sigmoid(self._tau)

    @property
    def k(self):
        """Temperature b = clamp(softplus(_b), 0.5, 50). Named 'k' for compatibility."""
        return torch.nn.functional.softplus(self._b).clamp(0.5, 50.0)

    def reset(self, b, d):
        """Zero the membrane state for a batch of ``b`` on device ``d``."""
        self.v = torch.zeros(b, self.dim, device=d)

    def forward(self, x):
        v = self.tau * self.v.to(dtype=x.dtype) + x
        # Map v in [0, 2*thresh] onto [0, 1]; midpoint 0.5 <-> v == thresh
        x_norm = v / (2.0 * self.thresh)
        spike = dspike_fn(x_norm, self.k)
        self.v = v * (1 - spike)  # soft reset
        return spike


class SNN(nn.Module):
    """
    Single-hidden-layer spiking network: fc1 -> spiking neuron -> fc2.

    Static inputs are flattened per sample and rate-encoded into
    ``timesteps`` frames. Neuromorphic inputs ([batch, T, ...]) are used
    as-is, with T taken from the data. Returns the fc2 output averaged over
    time. The mean hidden spike rate is stored in ``last_spike_rate``; it is
    kept in the autograd graph so it can serve as a sparsity penalty.

    Args:
        neuron: Spiking neuron module with ``reset(batch, device)`` and
            ``forward(current) -> spikes``.
        in_dim, hid_dim, out_dim: Layer widths.
        timesteps: Encoding length for static inputs.
        neuromorphic: Whether inputs already carry a time axis.
    """
    def __init__(self, neuron, in_dim, hid_dim, out_dim, timesteps=30, neuromorphic=False):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.neuron = neuron
        self.fc2 = nn.Linear(hid_dim, out_dim)
        self.T = timesteps
        self.neuromorphic = neuromorphic
        self.last_spike_rate = None  # Track for sparsity penalty

    def forward(self, x):
        batch = x.shape[0]
        device, dtype = x.device, x.dtype
        if self.neuromorphic:
            # Neuromorphic: x is [batch, T, channels, H, W] or [batch, T, features]
            if x.dim() > 3:
                # reshape (not view) so non-contiguous inputs are accepted
                x = x.reshape(batch, x.shape[1], -1)  # [batch, T, features]
            spikes_in = x
            T = spikes_in.shape[1]
        else:
            # Static: flatten each sample, then rate-encode into self.T frames
            x = x.reshape(batch, -1)
            # Cast in case rate_encode returns another dtype (no-op otherwise),
            # so fc1 sees inputs matching its weights.
            spikes_in = rate_encode(x, self.T, gain=0.5).to(dtype=dtype)
            T = self.T
        self.neuron.reset(batch, device)
        out_sum = torch.zeros(batch, self.fc2.out_features, device=device, dtype=dtype)
        spike_sum = 0.0
        for t in range(T):
            h = self.fc1(spikes_in[:, t, :])
            spike = self.neuron(h)
            spike_sum = spike_sum + spike.mean()
            out_sum = out_sum + self.fc2(spike)  # Use + instead of += for torch.compile
        self.last_spike_rate = spike_sum / T
        return out_sum / T


class DeepSNN(nn.Module):
    """
    2-layer Spiking Neural Network for deeper feature learning.

    Architecture:
    Input -> fc1 -> neuron1 -> fc2 -> neuron2 -> fc3 -> Output

    Each layer has its own spiking neuron with independent learnable params.
    Input handling matches SNN: static inputs are flattened and rate-encoded
    into ``timesteps`` frames, and neuromorphic inputs keep their own T.
    ``last_spike_rate`` is the mean of both layers' spike rates, kept in the
    autograd graph for an optional sparsity penalty.
    """
    def __init__(self, neuron1, neuron2, in_dim, hid_dim, out_dim, timesteps=30, neuromorphic=False):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.neuron1 = neuron1
        self.fc2 = nn.Linear(hid_dim, hid_dim)
        self.neuron2 = neuron2
        self.fc3 = nn.Linear(hid_dim, out_dim)
        self.T = timesteps
        self.neuromorphic = neuromorphic
        self.last_spike_rate = None

    def forward(self, x):
        batch = x.shape[0]
        device, dtype = x.device, x.dtype
        if self.neuromorphic:
            if x.dim() > 3:
                # reshape (not view) so non-contiguous inputs are accepted
                x = x.reshape(batch, x.shape[1], -1)  # [batch, T, features]
            spikes_in = x
            T = spikes_in.shape[1]
        else:
            x = x.reshape(batch, -1)
            # Cast in case rate_encode returns another dtype (no-op otherwise)
            spikes_in = rate_encode(x, self.T, gain=0.5).to(dtype=dtype)
            T = self.T

        self.neuron1.reset(batch, device)
        self.neuron2.reset(batch, device)
        out_sum = torch.zeros(batch, self.fc3.out_features, device=device, dtype=dtype)
        spike_sum = 0.0

        for t in range(T):
            # Layer 1
            spike1 = self.neuron1(self.fc1(spikes_in[:, t, :]))
            # Layer 2
            spike2 = self.neuron2(self.fc2(spike1))
            # Track spikes from both layers
            spike_sum = spike_sum + (spike1.mean() + spike2.mean()) / 2
            out_sum = out_sum + self.fc3(spike2)

        self.last_spike_rate = spike_sum / T
        return out_sum / T


class ConvSNN(nn.Module):
    """
    Convolutional Spiking Neural Network for image classification.

    Architecture (for 32x32 input like CIFAR10):
    Input [C,32,32] → Conv1(32,3x3) → Pool → Spike → Conv2(64,3x3) → Pool → Spike → FC → Output

    For MNIST/Fashion (28x28): same but different FC input size
    (64*7*7 = 3136 instead of 64*8*8 = 4096).

    Inputs must be square (``input_size`` x ``input_size``). A static
    [batch, C, H, W] input is presented unchanged at each of ``timesteps``
    steps. A [batch, T, C, H, W] input is stepped through its own T.

    Args:
        neuron_cls: Callable ``neuron_cls(dim)`` returning a spiking neuron
            with ``reset(batch, device)`` and ``forward``.
        in_channels: Input image channels.
        out_dim: Number of output classes.
        timesteps: Steps used for static inputs.
        input_size: Spatial side length of the input.
    """
    def __init__(self, neuron_cls, in_channels, out_dim, timesteps=30, input_size=32):
        super().__init__()
        self.T = timesteps
        self.input_size = input_size
        half = input_size // 2      # spatial size after pool1 (32 -> 16, 28 -> 14)
        quarter = input_size // 4   # spatial size after pool2 (32 -> 8, 28 -> 7)

        # Conv layers
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.AvgPool2d(2)

        fc_in = 64 * quarter * quarter
        self.fc = nn.Linear(fc_in, out_dim)

        # Spiking neurons operate on flattened feature maps
        self.neuron1 = neuron_cls(32 * half * half)
        self.neuron2 = neuron_cls(fc_in)

        self.last_spike_rate = None

    def forward(self, x):
        batch = x.shape[0]
        device = x.device
        half = self.input_size // 2

        if x.dim() == 4:
            # Static image: show the same frame every step. expand() is a
            # zero-copy view, unlike repeat(), which materialises T copies.
            x = x.unsqueeze(1).expand(-1, self.T, -1, -1, -1)  # [batch, T, C, H, W]

        T = x.shape[1]
        self.neuron1.reset(batch, device)
        self.neuron2.reset(batch, device)

        out_sum = torch.zeros(batch, self.fc.out_features, device=device, dtype=x.dtype)
        spike_sum = 0.0

        for t in range(T):
            xt = x[:, t]  # [batch, C, H, W]

            # Conv1 + Pool + Spike
            h1 = self.pool1(self.conv1(xt))  # [batch, 32, H/2, W/2]
            spike1 = self.neuron1(h1.reshape(batch, -1)).reshape(batch, 32, half, half)

            # Conv2 + Pool + Spike
            h2 = self.pool2(self.conv2(spike1))  # [batch, 64, H/4, W/4]
            spike2 = self.neuron2(h2.reshape(batch, -1))

            spike_sum = spike_sum + (spike1.mean() + spike2.mean()) / 2
            out_sum = out_sum + self.fc(spike2)

        self.last_spike_rate = spike_sum / T
        return out_sum / T


def get_dataset(name, batch_size=128, timesteps=30, num_workers=1, pin_memory=True):
    """
    Load dataset - full train/test sets with optimized DataLoader settings.

    Static datasets (mnist, fashion, cifar10) come from torchvision.
    Neuromorphic ones (nmnist, dvs_gesture, cifar10_dvs, shd, ssc) come from
    ``tonic`` and are binned into ``timesteps`` frames. Files are stored
    under ``<project root>/data``.

    Args:
        name: Dataset key (see the ValueError message for options).
        batch_size: Batch size for both loaders.
        timesteps: Number of time bins for neuromorphic frame conversion
            (unused for static datasets).
        num_workers: DataLoader worker processes.
        pin_memory: Passed to DataLoader.

    Returns:
        (train_loader, test_loader, input_dim, num_classes)

    Raises:
        ValueError: ``name`` is unknown (checked before any import/download).
        ImportError: a neuromorphic dataset was requested without ``tonic``.
    """
    static_names = ("mnist", "fashion", "cifar10")
    neuro_names = ("nmnist", "dvs_gesture", "cifar10_dvs", "shd", "ssc")
    # Validate first so a typo is reported as such, not as a missing `tonic`.
    if name not in static_names + neuro_names:
        raise ValueError(f"Unknown dataset: {name}. Options: mnist, fashion, cifar10, nmnist, dvs_gesture, cifar10_dvs, shd, ssc")

    data_dir = Path(__file__).parent.parent / "data"

    # DataLoader kwargs for speed
    # num_workers=1 because we run 2-3 parallel training processes per T4 VM
    loader_kwargs = {
        'num_workers': num_workers,
        'pin_memory': pin_memory,
        'persistent_workers': num_workers > 0,  # Keep workers alive between epochs
    }

    def static_loaders(train, test):
        # Shuffled train loader, ordered test loader
        return (DataLoader(train, batch_size, shuffle=True, **loader_kwargs),
                DataLoader(test, batch_size, **loader_kwargs))

    # Static datasets (torchvision)
    if name == "mnist":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
        test = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
        return (*static_loaders(train, test), 784, 10)
    if name == "fashion":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))])
        train = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
        test = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
        return (*static_loaders(train, test), 784, 10)
    if name == "cifar10":
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262))
        tr_t = transforms.Compose([transforms.RandomCrop(32, 4), transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(), normalize])
        te_t = transforms.Compose([transforms.ToTensor(), normalize])
        train = datasets.CIFAR10(data_dir, train=True, download=True, transform=tr_t)
        test = datasets.CIFAR10(data_dir, train=False, download=True, transform=te_t)
        return (*static_loaders(train, test), 3072, 10)

    # Neuromorphic datasets (tonic)
    try:
        import tonic
        import tonic.transforms as T
    except ImportError as err:
        raise ImportError("tonic required for neuromorphic datasets: pip install tonic") from err

    store = str(data_dir)

    def neuro_loader(dataset, shuffle=False):
        # tonic's PadTensors collates variable-length samples into one batch
        return DataLoader(dataset, batch_size, shuffle=shuffle,
                          collate_fn=tonic.collation.PadTensors(batch_first=True),
                          **loader_kwargs)

    def to_frames(sensor_size):
        # Bin events into `timesteps` dense frames, then convert to a tensor
        return T.Compose([T.ToFrame(sensor_size=sensor_size, n_time_bins=timesteps),
                          torch.from_numpy])

    if name == "nmnist":
        frames = to_frames(tonic.datasets.NMNIST.sensor_size)
        train = tonic.datasets.NMNIST(save_to=store, train=True, transform=frames)
        test = tonic.datasets.NMNIST(save_to=store, train=False, transform=frames)
        return neuro_loader(train, shuffle=True), neuro_loader(test), 2*34*34, 10

    if name == "dvs_gesture":
        frames = to_frames(tonic.datasets.DVSGesture.sensor_size)
        train = tonic.datasets.DVSGesture(save_to=store, train=True, transform=frames)
        test = tonic.datasets.DVSGesture(save_to=store, train=False, transform=frames)
        return neuro_loader(train, shuffle=True), neuro_loader(test), 2*128*128, 11

    if name == "cifar10_dvs":
        frames = to_frames(tonic.datasets.CIFAR10DVS.sensor_size)
        full = tonic.datasets.CIFAR10DVS(save_to=store, transform=frames)
        # CIFAR10-DVS has no train/test split - split 80/20 with a fixed seed
        n_train = int(0.8 * len(full))
        train_set, test_set = torch.utils.data.random_split(
            full, [n_train, len(full) - n_train],
            generator=torch.Generator().manual_seed(42))
        return neuro_loader(train_set, shuffle=True), neuro_loader(test_set), 2*128*128, 10

    # SHD and SSC share a (700, 1, 1) sensor layout
    frames = to_frames((700, 1, 1))
    if name == "shd":
        train = tonic.datasets.SHD(save_to=store, train=True, transform=frames)
        test = tonic.datasets.SHD(save_to=store, train=False, transform=frames)
        return neuro_loader(train, shuffle=True), neuro_loader(test), 700, 20

    # name == "ssc" (the only remaining valid option)
    train = tonic.datasets.SSC(save_to=store, split="train", transform=frames)
    test = tonic.datasets.SSC(save_to=store, split="test", transform=frames)
    return neuro_loader(train, shuffle=True), neuro_loader(test), 700, 35

def _spiking_neurons(model):
    """Return the model's spiking neurons: ``.neuron`` (SNN) or ``.neuron1``/``.neuron2`` (DeepSNN, ConvSNN)."""
    if hasattr(model, 'neuron'):
        return [model.neuron]
    if hasattr(model, 'neuron1'):
        return [model.neuron1, model.neuron2]
    return []


def _mean_tensor_attr(neurons, attr):
    """Mean of ``neuron.<attr>.item()`` over neurons whose ``attr`` is a tensor; None if none have one."""
    vals = []
    for neuron in neurons:
        value = getattr(neuron, attr, None)
        if isinstance(value, torch.Tensor):
            vals.append(value.item())
    return sum(vals) / len(vals) if vals else None


def _evaluate_accuracy(model, loader, device, dtype):
    """Top-1 accuracy on ``loader`` (0.0 if it is empty); leaves the model in eval mode."""
    model.train(False)
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device=device, dtype=dtype), y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0


def train_model(model, train_loader, test_loader, epochs, lr, device, verbose=True,
                use_tpu=False, save_path=None, track_spikes=False, neuromorphic=False, timesteps=30,
                sparsity_lambda=0.0, dtype=torch.float32):
    """
    Train with Adam + cosine LR schedule, evaluating on ``test_loader`` each epoch.

    The best-accuracy weights are snapshotted (copied to CPU) in memory and,
    if ``save_path`` is given, written once with ``torch.save`` after training.

    Args:
        model: SNN / DeepSNN / ConvSNN; the sparsity penalty reads its
            ``last_spike_rate``.
        train_loader, test_loader: Iterables of ``(x, y)`` batches.
        epochs: Number of epochs (also the cosine schedule's T_max).
        lr: Adam learning rate.
        device: Target device for the model and batches.
        verbose: Show tqdm progress bars.
        use_tpu: Call ``xm.mark_step()`` after each optimizer step.
        save_path: Checkpoint file for the best epoch (None = don't save).
        track_spikes: If True, compute the test-set spike rate every 10 epochs
            (and for each new best checkpoint).
        neuromorphic: Whether dataset is neuromorphic (for spike counting).
        timesteps: Number of timesteps (for spike counting / energy proxy).
        sparsity_lambda: Penalty coefficient for spike rate (0.0 = disabled).
            Loss = CE + sparsity_lambda * spike_rate
        dtype: Floating dtype for the model and inputs.

    Returns:
        best_acc: Best test accuracy (fraction in [0, 1]).
        history: One dict per epoch with acc and loss, plus eps / k / tau /
            spike_rate when available.
        final_info: Dict with eps_history, k_history, tau_history and
            spike_rate_history lists.
    """
    model = model.to(device=device, dtype=dtype)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
    crit = nn.CrossEntropyLoss()
    best = 0
    best_state = None
    history = []

    # Per-epoch history of learnable neuron parameters (averaged over layers)
    eps_history, k_history, tau_history = [], [], []
    spike_rate_history = []
    neurons = _spiking_neurons(model)

    epoch_pbar = tqdm(range(epochs), desc="Training", unit="epoch", disable=not verbose)

    for ep in epoch_pbar:
        model.train()
        train_loss = 0.0
        n_batches = 0

        batch_pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs}",
                          leave=False, disable=not verbose)
        for x, y in batch_pbar:
            x, y = x.to(device=device, dtype=dtype), y.to(device)
            opt.zero_grad()
            out = model(x)
            loss = crit(out, y)
            # Sparsity penalty (optional)
            spike_rate_t = getattr(model, 'last_spike_rate', None)
            if sparsity_lambda > 0 and spike_rate_t is not None:
                loss = loss + sparsity_lambda * spike_rate_t
            loss.backward()
            opt.step()
            if use_tpu:
                xm.mark_step()  # TPU sync

            loss_val = loss.item()  # a single host sync per batch
            train_loss += loss_val
            n_batches += 1
            batch_pbar.set_postfix(loss=f"{loss_val:.4f}")

        scheduler.step()
        avg_loss = train_loss / n_batches if n_batches else 0.0

        for hist, attr in ((eps_history, 'eps'), (k_history, 'k'), (tau_history, 'tau')):
            mean_val = _mean_tensor_attr(neurons, attr)
            if mean_val is not None:
                hist.append(mean_val)

        acc = _evaluate_accuracy(model, test_loader, device, dtype)

        # Track spike rate (optional, adds compute)
        spike_rate = None
        if track_spikes and (ep + 1) % 10 == 0:  # Every 10 epochs to save time
            spike_rate, _ = count_spikes_epoch(model, test_loader, device, timesteps, neuromorphic, dtype)
            spike_rate_history.append({"epoch": ep+1, "rate": spike_rate})

        epoch_record = {"epoch": ep+1, "acc": acc, "loss": avg_loss}
        for key, hist in (("eps", eps_history), ("k", k_history), ("tau", tau_history)):
            if hist:
                epoch_record[key] = hist[-1]
        if spike_rate is not None:
            epoch_record["spike_rate"] = spike_rate
        history.append(epoch_record)

        if acc > best:
            best = acc
            if save_path:
                ckpt_spike_rate = spike_rate
                if track_spikes and ckpt_spike_rate is None:
                    ckpt_spike_rate, _ = count_spikes_epoch(model, test_loader, device, timesteps, neuromorphic, dtype)
                best_state = {
                    # Copy every tensor: on a CPU device .cpu() returns the live
                    # parameter storage, which later optimizer steps overwrite.
                    "model_state_dict": {k: v.detach().to("cpu", copy=True)
                                         for k, v in model.state_dict().items()},
                    "best_acc": acc,
                    "epoch": ep + 1,
                    "spike_rate": ckpt_spike_rate,
                    # `is not None` so a genuine 0.0 rate gives 0.0, not None
                    "energy_proxy": ckpt_spike_rate * timesteps if ckpt_spike_rate is not None else None,
                    "eps": eps_history[-1] if eps_history else None,
                    "tau": tau_history[-1] if tau_history else None,
                    "k": k_history[-1] if k_history else None,
                }

        postfix = {"acc": f"{acc:.2%}", "best": f"{best:.2%}", "loss": f"{avg_loss:.4f}"}
        if eps_history:
            postfix["eps"] = f"{eps_history[-1]:.2f}"
        if k_history:
            postfix["k"] = f"{k_history[-1]:.2f}"
        epoch_pbar.set_postfix(postfix)

    # Save best checkpoint
    if save_path and best_state is not None:
        torch.save(best_state, save_path)
        print(f"  Saved: {save_path}")

    final_info = {
        "eps_history": eps_history,
        "k_history": k_history,
        "tau_history": tau_history,
        "spike_rate_history": spike_rate_history
    }

    return best, history, final_info

# Neuron registries: CLI key -> (display name, neuron class).
# The keys of all three dicts are combined to form the `--model` choices in
# main(). The display name is presumably used for printing/reporting (TODO
# confirm in main). Insertion order is kept as-is since "all" runs may
# iterate these dicts in order.

# Single-hidden-layer variants (FC)
NEURONS = {
    "lif": ("LIF", LIF),
    "plif": ("PLIF", PLIF),
    "adalif": ("AdaLIF", AdaLIF),
    "fullplif": ("FullPLIF", FullPLIF),
    "ultralif": ("UltraLIF", UltraLIF),
    "ultraplif": ("UltraPLIF", UltraPLIF),
    "ultratlif": ("UltraTLIF", UltraTLIF),
    "ultratplif": ("UltraTPLIF", UltraTPLIF),
    "dspike": ("DSpike", DSpike),
    "dspike+": ("DSpike+", DSpikePlus),
}

# 2-layer deep variants (FC); presumably built with DeepSNN — verify in main
DEEP_NEURONS = {
    "ultralif2": ("UltraLIF-2L", UltraLIF),
    "ultraplif2": ("UltraPLIF-2L", UltraPLIF),
    "ultratlif2": ("UltraTLIF-2L", UltraTLIF),
    "ultratplif2": ("UltraTPLIF-2L", UltraTPLIF),
    "lif2": ("LIF-2L", LIF),
}

# Convolutional variants; presumably built with ConvSNN (which takes a
# neuron class, not an instance) — verify in main
CONV_NEURONS = {
    "ultralif_conv": ("UltraLIF-Conv", UltraLIF),
    "ultraplif_conv": ("UltraPLIF-Conv", UltraPLIF),
    "ultratlif_conv": ("UltraTLIF-Conv", UltraTLIF),
    "ultratplif_conv": ("UltraTPLIF-Conv", UltraTPLIF),
    "lif_conv": ("LIF-Conv", LIF),
}

def _collect_learned_params(neurons_to_check, neuron_names):
    """Read learned neuron parameters (eps, k, tau) into a flat dict and print them.

    Args:
        neurons_to_check: Neuron modules to inspect.
        neuron_names: One label per neuron. An empty label means "no prefix";
            a non-empty label such as "L1" becomes a "L1_" key prefix.

    Returns:
        dict mapping parameter names (e.g. "eps", "L1_k", "L2_tau") to Python
        numbers obtained via ``.item()``. Only attributes the neuron actually has
        are reported, and ``tau`` is reported only when it is a ``torch.Tensor``.
    """
    learned = {}
    for neuron, label in zip(neurons_to_check, neuron_names):
        prefix = f"{label}_" if label else ""
        for attr in ("eps", "k", "tau"):
            if not hasattr(neuron, attr):
                continue
            value = getattr(neuron, attr)
            # A non-tensor tau is skipped (matches the isinstance check used for tau only).
            if attr == "tau" and not isinstance(value, torch.Tensor):
                continue
            # Distinct name on purpose: the caller's loop variable must not be clobbered.
            param_key = f"{prefix}{attr}"
            learned[param_key] = value.item()
            print(f"  Learned {param_key}: {learned[param_key]:.3f}")
    return learned


def _args_to_jsonable(args):
    """Return ``vars(args)`` without entries that ``json.dump`` cannot encode.

    ``main`` stores the parsed ``torch.dtype`` on ``args.dtype_torch``; json raises
    TypeError on dtype objects, so they are dropped (``args.dtype`` keeps the
    string form of the same setting).
    """
    return {k: v for k, v in vars(args).items() if not isinstance(v, torch.dtype)}


def _run_experiments(args, base_dir):
    """Select the device, train every requested model, print a summary, and save results.

    Args:
        args: Parsed CLI namespace (including ``args.dtype_torch`` set by ``main``).
        base_dir: Project root; checkpoints/results/figures directories live under it.
    """
    import json

    # Device selection
    if args.tpu:
        if not TPU_AVAILABLE:
            print("ERROR: torch_xla not installed. Run: pip install torch_xla")
            sys.exit(1)
        device = xm.xla_device()
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("="*70)
    print(f"UltraLIF Training - {args.dataset.upper()}")
    print("="*70)
    print(f"Device: {device}")

    # Output directories (logs dir is created by main)
    ckpt_dir = base_dir / "checkpoints" / args.dataset
    results_dir = base_dir / "results"
    figs_dir = base_dir / "figures" / args.dataset
    for directory in (ckpt_dir, results_dir, figs_dir):
        directory.mkdir(parents=True, exist_ok=True)

    train_loader, test_loader, in_dim, n_classes = get_dataset(args.dataset, args.batch_size, args.timesteps)
    print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
    print(f"Config: {args.hidden} hidden, {args.timesteps} T, {args.epochs} epochs, lr={args.lr}, dtype={args.dtype}")
    if args.track_spikes:
        print("Spike tracking: ENABLED (every 10 epochs)")

    if args.model == "all":
        models_to_run = list(NEURONS.keys())
    elif args.model == "all-deep":
        models_to_run = list(DEEP_NEURONS.keys())
    elif args.model == "all-conv":
        models_to_run = list(CONV_NEURONS.keys())
    else:
        models_to_run = [args.model]
    is_neuromorphic = args.dataset in ["nmnist", "dvs_gesture", "cifar10_dvs", "shd", "ssc"]

    # Determine input size for conv models
    if args.dataset in ["mnist", "fashion"]:
        input_size, in_channels = 28, 1
    elif args.dataset in ["cifar10"]:
        input_size, in_channels = 32, 3
    else:
        input_size, in_channels = 32, 2  # DVS datasets

    # Only apply sparsity penalty to UltraLIF/UltraPLIF/UltraTLIF (including deep/conv variants)
    # Rationale: Only these models have learnable spike mechanism (eps) that can
    # respond to sparsity penalty. Traditional LIF variants lack learnable spike
    # control - applying sparsity would just hurt accuracy without mechanism to adapt.
    # DSpike's learnable k controls surrogate gradient sharpness, not spike rate.
    ultra_keys = ("ultralif", "ultraplif", "ultralif2", "ultraplif2", "ultralif_conv", "ultraplif_conv",
                  "ultratlif", "ultratplif", "ultratlif2", "ultratplif2", "ultratlif_conv", "ultratplif_conv")

    results = {}

    for model_key in models_to_run:
        is_deep = model_key in DEEP_NEURONS
        is_conv = model_key in CONV_NEURONS

        if is_conv:
            name, neuron_cls = CONV_NEURONS[model_key]
        elif is_deep:
            name, neuron_cls = DEEP_NEURONS[model_key]
        else:
            name, neuron_cls = NEURONS[model_key]

        print(f"\n--- {name} ---")
        set_seed(args.seed)

        # Build the model and remember which neuron modules to report learned params for
        # (captured before torch.compile wraps the model).
        if is_conv:
            # Conv model - neuron_cls is passed to ConvSNN which creates neurons internally
            model = ConvSNN(neuron_cls, in_channels, n_classes, args.timesteps, input_size)
            neurons_to_check, neuron_names = [model.neuron1], [""]
        elif is_deep:
            # Create 2 separate neurons for 2-layer network
            neuron1 = neuron_cls(args.hidden)
            neuron2 = neuron_cls(args.hidden)
            model = DeepSNN(neuron1, neuron2, in_dim, args.hidden, n_classes, args.timesteps, neuromorphic=is_neuromorphic)
            neurons_to_check, neuron_names = [neuron1, neuron2], ["L1", "L2"]
        else:
            neuron = neuron_cls(args.hidden)
            model = SNN(neuron, in_dim, args.hidden, n_classes, args.timesteps, neuromorphic=is_neuromorphic)
            neurons_to_check, neuron_names = [neuron], [""]

        # torch.compile for speed (PyTorch 2.0+, CUDA only, not TPU)
        if not args.tpu and torch.cuda.is_available() and hasattr(torch, 'compile'):
            try:
                model = torch.compile(model, mode='reduce-overhead')
                print("  torch.compile: enabled (reduce-overhead)")
            except Exception as e:
                print(f"  torch.compile: disabled ({e})")

        params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"  Params: {params:,}")

        use_sparsity = args.sparsity_lambda if model_key in ultra_keys else 0.0
        if use_sparsity > 0:
            print(f"  Sparsity penalty: {use_sparsity}")

        # Build checkpoint filename with config
        sparse_suffix = f"_sp{use_sparsity}" if use_sparsity > 0 else ""
        save_path = None if args.no_save else ckpt_dir / f"{model_key}_T{args.timesteps}{sparse_suffix}_seed{args.seed}.pt"
        acc, history, final_info = train_model(
            model, train_loader, test_loader, args.epochs, args.lr, device,
            verbose=not args.quiet, use_tpu=args.tpu, save_path=save_path,
            track_spikes=args.track_spikes, neuromorphic=is_neuromorphic, timesteps=args.timesteps,
            sparsity_lambda=use_sparsity, dtype=args.dtype_torch
        )

        learned = _collect_learned_params(neurons_to_check, neuron_names)

        # Final spike rate
        if args.track_spikes:
            final_spike_rate, _ = count_spikes_epoch(model, test_loader, device, args.timesteps, is_neuromorphic, args.dtype_torch)
            learned["final_spike_rate"] = final_spike_rate
            energy = compute_energy_proxy(final_spike_rate, args.hidden, args.timesteps)
            learned["energy_proxy"] = energy
            print(f"  Final spike rate: {final_spike_rate:.4f}")
            print(f"  Energy proxy: {energy:.3f}x baseline")

        print(f"  Best: {acc:.2%}")

        # Keyed by the model key (never by a learned-parameter name).
        results[model_key] = {
            "name": name,
            "acc": acc,
            "params": params,
            "learned": learned,
            "history": history,
            "eps_history": final_info["eps_history"],
            "k_history": final_info["k_history"],
            "tau_history": final_info["tau_history"],
            "spike_rate_history": final_info["spike_rate_history"]
        }

    # Print summary, best accuracy first (marked with "*")
    print("\n" + "="*70 + "\nRESULTS\n" + "="*70)
    print(f"{'Model':12} {'Acc':>8} {'Params':>10} {'eps/k':>8} {'Spk Rate':>10} {'Energy':>8}")
    print("-"*60)
    ranked = sorted(results.values(), key=lambda r: -r["acc"])
    for rank, r in enumerate(ranked):
        m = "*" if rank == 0 else " "
        learned = r["learned"]
        # Deep models store per-layer keys ("L1_eps"), so fall back to layer 1.
        eps_k = next((learned[k] for k in ("eps", "k", "L1_eps", "L1_k") if k in learned), None)
        eps_k_str = f"{eps_k:.2f}" if isinstance(eps_k, float) else "-"
        spk = learned.get("final_spike_rate")
        spk_str = f"{spk:.4f}" if spk is not None else "-"
        energy = learned.get("energy_proxy")
        energy_str = f"{energy:.2f}x" if energy is not None else "-"
        print(f"{m} {r['name']:12} {r['acc']:>7.2%} {r['params']:>10,} {eps_k_str:>8} {spk_str:>10} {energy_str:>8}")

    # Save results to JSON (args filtered: torch.dtype is not JSON serializable)
    results_file = results_dir / f"{args.dataset}_{args.model}_seed{args.seed}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump({"args": _args_to_jsonable(args), "results": results}, f, indent=2)
    print(f"\nResults saved: {results_file}")

    # Generate plots if the visualization module is available. Only the import is
    # guarded, so errors raised while plotting are not mistaken for a missing module.
    try:
        from visualization.plots import plot_epsilon_evolution
    except ImportError:
        print("(Skipping plots - visualization module not found)")
    else:
        # Epsilon/k evolution for models that have it
        for model_key, r in results.items():
            if r["eps_history"]:
                plot_epsilon_evolution(
                    r["eps_history"],
                    title=f"{r['name']} - ε Evolution",
                    save_path=figs_dir / f"{model_key}_eps_evolution.png"
                )
            if r["k_history"]:
                plot_epsilon_evolution(
                    r["k_history"],
                    title=f"{r['name']} - k Evolution",
                    ylabel="k (sharpness)",
                    save_path=figs_dir / f"{model_key}_k_evolution.png"
                )
        print(f"Figures saved to: {figs_dir}")

    # Print summary of saved files
    print("\n" + "="*70)
    print("SAVED FILES")
    print("="*70)
    print(f"  Checkpoints: {ckpt_dir}/")
    print(f"  Results:     {results_dir}/")
    print(f"  Figures:     {figs_dir}/")


def main():
    """Entry point: parse CLI args, optionally tee stdout to a log file, run experiments.

    Unless ``--no-log`` is given, stdout is redirected through ``TeeLogger`` into
    ``<project>/logs/<dataset>_<model>_seed<seed>_<timestamp>.log``. The original
    stdout is restored and the log is closed even if a run fails or exits early.
    """
    parser = argparse.ArgumentParser(description="UltraLIF Training")
    all_models = list(NEURONS.keys()) + list(DEEP_NEURONS.keys()) + list(CONV_NEURONS.keys())
    parser.add_argument("--model", default="ultralif", choices=["all", "all-deep", "all-conv"] + all_models)
    parser.add_argument("--dataset", default="mnist", choices=["mnist", "fashion", "cifar10", "nmnist", "dvs_gesture", "cifar10_dvs", "shd", "ssc"])
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--hidden", type=int, default=256)
    parser.add_argument("--timesteps", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--tpu", action="store_true", help="Use TPU (requires torch_xla)")
    parser.add_argument("--no-save", action="store_true", help="Don't save checkpoints")
    parser.add_argument("--track-spikes", action="store_true", help="Track spike rates (adds compute)")
    parser.add_argument("--sparsity-lambda", type=float, default=0.0,
                        help="Sparsity penalty coefficient (default: 0.0 = disabled)")
    parser.add_argument("--no-log", action="store_true", help="Disable automatic logging to file")
    parser.add_argument("--dtype", default="float32", choices=["float32", "float16", "bfloat16"],
                        help="Data type for training (default: float32)")
    args = parser.parse_args()

    # Parse dtype
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    args.dtype_torch = dtype_map[args.dtype]

    # Setup directories first (needed for logging)
    base_dir = Path(__file__).parent.parent
    logs_dir = base_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)

    # Setup automatic logging to file
    tee_logger = None
    log_path = None
    if not args.no_log:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filename = f"{args.dataset}_{args.model}_seed{args.seed}_{timestamp}.log"
        log_path = logs_dir / log_filename
        tee_logger = TeeLogger(log_path)
        sys.stdout = tee_logger
        print(f"Logging to: {log_path}")

    try:
        _run_experiments(args, base_dir)
        if tee_logger is not None:
            print(f"  Log:         {log_path}")
    finally:
        # Runs on success, exceptions and sys.exit(): restore the real stdout and
        # close the log file so it is not left open and later prints reach the console.
        if tee_logger is not None:
            sys.stdout = tee_logger.terminal  # Restore stdout
            tee_logger.close()
            print(f"Log saved to: {log_path}")


if __name__ == "__main__":
    main()
