# -*- coding: utf-8 -*-
"""
Epsilon Ablation Study (Appendix E.1)

Compares fixed vs learned epsilon across all four ultradiscretized neuron
variants on MNIST (T=1 timestep by default; override with --timesteps). The
epsilon parameter controls the softness of the log-sum-exp approximation to
the max operator, LSE_eps(a, b, ...) = eps * log(exp(a/eps) + exp(b/eps) + ...);
as eps -> 0 the LSE recovers the hard max (the classical ultradiscrete limit).

Models tested (code name -> paper name):
    UltraTLIF  -> UltraLIF   (temporal, 2-term LSE from LIF ODE)
    UltraTPLIF -> UltraPLIF  (temporal, learnable tau)
    UltraLIF   -> UltraDLIF  (spatial, 3-term LSE from diffusion PDE)
    UltraPLIF  -> UltraDPLIF (spatial, learnable tau)

Configurations: Fixed eps in {0.5, 1.0, 2.0}, Learned eps (init=1.0)

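Usage:
    python eps_ablation.py [--epochs 50] [--timesteps 1] [--batch-size 128] [--seed 42]

Outputs:
    ablations/checkpoints/  best-epoch checkpoint for each (model, eps config)
    ablations/results/      timestamped JSON with all metrics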
"""

import argparse
import copy
import json
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# =============================================================================
# UltraTLIF - Paper name: UltraLIF (temporal, 2-term LSE from LIF ODE)
# =============================================================================
class UltraTLIF(nn.Module):
    """
    UltraLIF (temporal), derived from the Euler-discretized LIF ODE

        v(t+1) = tau * v(t) + I(t)

    whose ultradiscretization is the two-term smooth maximum

        V(t+1) = LSE_eps(V(t) + log(tau), I(t)),

    with LSE_eps(a, b) = eps * log(exp(a/eps) + exp(b/eps)).

    Args:
        dim: number of neurons (size of the last axis of the membrane state).
        tau: fixed leak factor; it enters the dynamics only as log(tau).
        init_eps: initial (learned) or constant (fixed) LSE temperature.
        learn_eps: if True, eps is a trainable parameter stored as log(eps)
            and clamped to [0.1, 20.0] when read; otherwise it is a
            non-trainable buffer.
    """

    def __init__(self, dim, tau=0.9, init_eps=1.0, learn_eps=True):
        super().__init__()
        self.dim = dim
        self.tau = tau
        self.log_tau = np.log(tau)  # additive leak term in log space
        self.threshold = 1.0
        self.learn_eps = learn_eps

        eps0 = torch.tensor(float(init_eps))
        if not learn_eps:
            self.register_buffer('_fixed_eps', eps0)
        else:
            self._log_eps = nn.Parameter(eps0.log())

        self.v = None  # membrane state; allocated by reset()

    @property
    def eps(self):
        """Current LSE temperature (clamped to [0.1, 20.0] when learned)."""
        if not self.learn_eps:
            return self._fixed_eps
        return self._log_eps.exp().clamp(0.1, 20.0)

    @staticmethod
    def _smooth_max(terms, eps):
        """Max-shifted eps * logsumexp(terms / eps) over the last axis."""
        shift = terms.max(dim=-1, keepdim=True).values
        return shift.squeeze(-1) + eps * torch.logsumexp((terms - shift) / eps, dim=-1)

    def reset(self, batch_size, device):
        """Zero the membrane state for a new batch of `batch_size` samples."""
        self.v = torch.zeros(batch_size, self.dim, device=device)

    def forward(self, x):
        """Advance one timestep with input `x` and return soft spikes in (0, 1)."""
        eps = self.eps
        leaked = self.v + self.log_tau
        self.v = self._smooth_max(torch.stack([leaked, x], dim=-1), eps)

        fired = torch.sigmoid((self.v - self.threshold) / eps)
        # Soft reset: the membrane is scaled toward 0 in proportion to the spike.
        self.v = self.v * (1 - fired)
        return fired


# =============================================================================
# UltraTPLIF - Paper name: UltraPLIF (temporal, 2-term LSE, learnable tau)
# =============================================================================
class UltraTPLIF(nn.Module):
    """
    UltraPLIF (temporal). Same dynamics as UltraLIF,

        V(t+1) = LSE_eps(V(t) + log(tau), I(t)),

    but with a learnable tau (Fang et al., ICCV 2021 style parametric
    extension). tau is stored as log(tau) and confined to [TAU_MIN, TAU_MAX].

    Args:
        dim: number of neurons (size of the last axis of the membrane state).
        init_tau: initial decay factor.
        init_eps: initial (learned) or constant (fixed) LSE temperature.
        learn_eps: if True, eps is a trainable parameter stored as log(eps)
            and clamped to [0.1, 20.0] when read; otherwise it is a
            non-trainable buffer.
    """

    # Bounds for the learnable decay factor. Both the `tau` property and
    # forward() derive their clamp from these constants, so the reported tau and
    # the tau actually used in the dynamics agree. Previously forward() used the
    # rounded log-bounds [-2.3, -0.01], which correspond to tau ~ [0.10026, 0.99005].
    TAU_MIN = 0.1
    TAU_MAX = 0.99
    _LOG_TAU_MIN = float(np.log(TAU_MIN))
    _LOG_TAU_MAX = float(np.log(TAU_MAX))

    def __init__(self, dim, init_tau=0.9, init_eps=1.0, learn_eps=True):
        super().__init__()
        self.dim = dim
        self.threshold = 1.0

        self._log_tau = nn.Parameter(torch.tensor(float(init_tau)).log())

        if learn_eps:
            self._log_eps = nn.Parameter(torch.tensor(float(init_eps)).log())
        else:
            self.register_buffer('_fixed_eps', torch.tensor(float(init_eps)))

        self.learn_eps = learn_eps
        self.v = None  # membrane state; allocated by reset()

    @property
    def eps(self):
        """Current LSE temperature (clamped to [0.1, 20.0] when learned)."""
        if self.learn_eps:
            return self._log_eps.exp().clamp(0.1, 20.0)
        return self._fixed_eps

    @property
    def tau(self):
        """Current decay factor, clamped to [TAU_MIN, TAU_MAX]."""
        return self._log_tau.exp().clamp(self.TAU_MIN, self.TAU_MAX)

    def reset(self, batch_size, device):
        """Zero the membrane state for a new batch of `batch_size` samples."""
        self.v = torch.zeros(batch_size, self.dim, device=device)

    def forward(self, x):
        """Advance one timestep with input `x` and return soft spikes in (0, 1)."""
        eps = self.eps
        # Clamp in log space to the exact log of the tau bounds (tau in [TAU_MIN, TAU_MAX]).
        log_tau = self._log_tau.clamp(self._LOG_TAU_MIN, self._LOG_TAU_MAX)

        # 2-term LSE: (decayed membrane, input)
        v_leak = self.v + log_tau
        stack = torch.stack([v_leak, x], dim=-1)
        m = stack.max(dim=-1, keepdim=True).values  # numerical stability
        self.v = m.squeeze(-1) + eps * torch.logsumexp((stack - m) / eps, dim=-1)

        spike = torch.sigmoid((self.v - self.threshold) / eps)  # soft spike
        self.v = self.v * (1 - spike)  # soft reset
        return spike


# =============================================================================
# UltraLIF - Paper name: UltraDLIF (spatial, 3-term LSE from diffusion PDE)
# =============================================================================
class UltraLIF(nn.Module):
    """
    UltraDLIF (spatial), derived from a discretized diffusion PDE

        v(t+1) = v_left + v_center + v_right + I

    whose ultradiscretization replaces the neighbour sum with a three-term
    smooth maximum:

        V(t+1) = LSE_eps(V_left, V_center, V_right) + I

    Neighbours are taken along the last (neuron) axis with circular
    wrap-around.

    Args:
        dim: number of neurons.
        tau: not used by this variant (presumably kept so every neuron class
            shares one constructor signature).
        init_eps: initial (learned) or constant (fixed) LSE temperature.
        learn_eps: if True, eps is a trainable parameter stored as log(eps)
            and clamped to [0.1, 20.0] when read; otherwise it is a
            non-trainable buffer.
    """

    def __init__(self, dim, tau=0.9, init_eps=1.0, learn_eps=True):
        super().__init__()
        self.dim = dim
        self.threshold = 1.0
        self.learn_eps = learn_eps

        eps0 = torch.tensor(float(init_eps))
        if not learn_eps:
            self.register_buffer('_fixed_eps', eps0)
        else:
            self._log_eps = nn.Parameter(eps0.log())

        self.v = None  # membrane state; allocated by reset()

    @property
    def eps(self):
        """Current LSE temperature (clamped to [0.1, 20.0] when learned)."""
        if not self.learn_eps:
            return self._fixed_eps
        return self._log_eps.exp().clamp(0.1, 20.0)

    @staticmethod
    def _smooth_max(terms, eps):
        """Max-shifted eps * logsumexp(terms / eps) over the last axis."""
        shift = terms.max(dim=-1, keepdim=True).values
        return shift.squeeze(-1) + eps * torch.logsumexp((terms - shift) / eps, dim=-1)

    def reset(self, batch_size, device):
        """Zero the membrane state for a new batch of `batch_size` samples."""
        self.v = torch.zeros(batch_size, self.dim, device=device)

    def forward(self, x):
        """Advance one timestep with input `x` and return soft spikes in (0, 1)."""
        eps = self.eps
        state = self.v
        # (left neighbour, self, right neighbour), circular along the neuron axis.
        neighbourhood = torch.stack(
            [torch.roll(state, 1, dims=-1), state, torch.roll(state, -1, dims=-1)],
            dim=-1,
        )
        self.v = self._smooth_max(neighbourhood, eps) + x

        fired = torch.sigmoid((self.v - self.threshold) / eps)
        # Soft reset: the membrane is scaled toward 0 in proportion to the spike.
        self.v = self.v * (1 - fired)
        return fired


# =============================================================================
# UltraPLIF - Paper name: UltraDPLIF (spatial, 3-term LSE, learnable tau)
# =============================================================================
class UltraPLIF(nn.Module):
    """
    UltraDPLIF (spatial). The UltraDLIF three-term neighbourhood update with a
    learnable decay factor tau applied to the centre term (Fang et al., ICCV
    2021 style parametric extension):

        V(t+1) = LSE_eps(V_left, tau * V_center, V_right) + I

    tau is stored as log(tau) and clamped to [0.1, 0.99] when read.
    Neighbours are taken along the last (neuron) axis with circular wrap-around.

    NOTE(review): the centre term is decayed multiplicatively (V * tau), whereas
    the temporal variants apply the decay additively in log space
    (V + log(tau)). Confirm this matches the paper's UltraDPLIF definition.
    """

    def __init__(self, dim, init_tau=0.9, init_eps=1.0, learn_eps=True):
        super().__init__()
        self.dim = dim
        self.threshold = 1.0
        self.learn_eps = learn_eps

        # Registered before eps so parameter order is (_log_tau, _log_eps).
        self._log_tau = nn.Parameter(torch.tensor(float(init_tau)).log())

        eps0 = torch.tensor(float(init_eps))
        if not learn_eps:
            self.register_buffer('_fixed_eps', eps0)
        else:
            self._log_eps = nn.Parameter(eps0.log())

        self.v = None  # membrane state; allocated by reset()

    @property
    def eps(self):
        """Current LSE temperature (clamped to [0.1, 20.0] when learned)."""
        if not self.learn_eps:
            return self._fixed_eps
        return self._log_eps.exp().clamp(0.1, 20.0)

    @property
    def tau(self):
        """Current decay factor, clamped to [0.1, 0.99]."""
        return self._log_tau.exp().clamp(0.1, 0.99)

    @staticmethod
    def _smooth_max(terms, eps):
        """Max-shifted eps * logsumexp(terms / eps) over the last axis."""
        shift = terms.max(dim=-1, keepdim=True).values
        return shift.squeeze(-1) + eps * torch.logsumexp((terms - shift) / eps, dim=-1)

    def reset(self, batch_size, device):
        """Zero the membrane state for a new batch of `batch_size` samples."""
        self.v = torch.zeros(batch_size, self.dim, device=device)

    def forward(self, x):
        """Advance one timestep with input `x` and return soft spikes in (0, 1)."""
        eps = self.eps
        decay = self.tau
        state = self.v
        neighbourhood = torch.stack(
            [torch.roll(state, 1, dims=-1), state * decay, torch.roll(state, -1, dims=-1)],
            dim=-1,
        )
        self.v = self._smooth_max(neighbourhood, eps) + x

        fired = torch.sigmoid((self.v - self.threshold) / eps)
        # Soft reset: the membrane is scaled toward 0 in proportion to the spike.
        self.v = self.v * (1 - fired)
        return fired


# =============================================================================
# Network
# =============================================================================
class SNNNetwork(nn.Module):
    """Two-layer spiking classifier built from a pluggable neuron class.

    Per timestep: flatten(x) -> fc1 -> lif1 (hidden) -> fc2 -> lif2 (readout).
    forward() returns the readout spikes averaged over `timesteps`.

    `neuron_class` is called as neuron_class(dim, init_eps=..., learn_eps=...)
    and must provide reset(batch_size, device), a forward(x), and an `eps`
    tensor attribute.
    """

    def __init__(self, neuron_class, input_dim=784, hidden_dim=64, output_dim=10,
                 init_eps=1.0, learn_eps=True):
        super().__init__()
        eps_kwargs = dict(init_eps=init_eps, learn_eps=learn_eps)
        # Creation order (fc1, lif1, fc2, lif2) fixes the parameter registration
        # order, and therefore state_dict / optimizer ordering; keep it stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.lif1 = neuron_class(hidden_dim, **eps_kwargs)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.lif2 = neuron_class(output_dim, **eps_kwargs)

    def reset(self, batch_size, device):
        """Reset both neuron layers' state for a batch of `batch_size` samples."""
        for layer in (self.lif1, self.lif2):
            layer.reset(batch_size, device)

    def _step(self, flat):
        """Run one timestep on flattened input; return readout spikes."""
        return self.lif2(self.fc2(self.lif1(self.fc1(flat))))

    def forward(self, x, timesteps=1):
        """Reset state, run `timesteps` steps on `x`, and return mean readout spikes."""
        batch = x.size(0)
        self.reset(batch, x.device)
        flat = x.view(batch, -1)
        per_step = [self._step(flat) for _ in range(timesteps)]
        return torch.stack(per_step).mean(0)

    def get_eps(self):
        """Return each layer's current eps as a Python float."""
        return {key: layer.eps.item()
                for key, layer in (('layer1', self.lif1), ('layer2', self.lif2))}


# =============================================================================
# Training
# =============================================================================
def train_epoch(model, loader, optimizer, device, timesteps):
    """Train `model` for one pass over `loader` with cross-entropy loss.

    Args:
        model: module called as model(data, timesteps), returning class scores.
        loader: iterable of (data, target) batches.
        optimizer: optimizer over `model`'s parameters.
        device: device the batches are moved to.
        timesteps: number of simulation steps passed to the model.

    Returns:
        Training accuracy in percent, measured on the outputs computed before
        each optimizer step. Returns 0.0 if `loader` yields no samples
        (previously this raised ZeroDivisionError).
    """
    model.train()
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, timesteps)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        correct += output.argmax(1).eq(target).sum().item()
        total += target.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total


def test_model(model, loader, device, timesteps):
    model.train(False)
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data, timesteps)
            correct += output.argmax(1).eq(target).sum().item()
            total += target.size(0)

    return 100. * correct / total


def compute_spike_rate(model, loader, device, timesteps):
    """Return the mean soft-spike value over every neuron of both layers.

    Replays the network step by step (reset, flatten, then
    fc1 -> lif1 -> fc2 -> lif2 per timestep) so that hidden-layer spikes are
    counted alongside readout spikes. Runs in eval mode without gradients.
    Returns 0 when nothing was evaluated (e.g. an empty loader).
    """
    model.train(False)
    spike_sum = 0
    neuron_count = 0

    with torch.no_grad():
        for batch, _ in loader:
            batch = batch.to(device)
            n = batch.size(0)
            model.reset(n, device)
            flat = batch.view(n, -1)

            for _ in range(timesteps):
                hidden = model.lif1(model.fc1(flat))
                readout = model.lif2(model.fc2(hidden))
                # Accumulate hidden first, then readout, as one running total.
                for spikes in (hidden, readout):
                    spike_sum += spikes.sum().item()
                    neuron_count += spikes.numel()

    if not neuron_count:
        return 0
    return spike_sum / neuron_count


def run_exp(neuron_class, model_name, config, train_loader, test_loader, device, epochs, timesteps, seed, ckpt_dir):
    """Train one (neuron variant, eps config) pair and checkpoint its best epoch.

    Args:
        neuron_class: neuron module class passed to SNNNetwork.
        model_name: short name used in log output and the checkpoint filename.
        config: dict with 'name', 'init_eps' and 'learn_eps' keys.
        train_loader, test_loader: MNIST data loaders.
        device: torch device to train on.
        epochs: number of training epochs.
        timesteps: simulation steps per sample.
        seed: seeds torch and numpy before the model is built.
        ckpt_dir: Path of the directory the best checkpoint is written to.

    Returns:
        dict with 'best_acc' (best test accuracy, %), 'final_eps' (layer-1 eps
        after the last epoch), 'final_spike_rate' (spike rate of the final
        model), 'eps_history' (layer-1 eps after each epoch) and
        'spike_history' (spike rate every 20 epochs).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    model = SNNNetwork(
        neuron_class,
        init_eps=config['init_eps'],
        learn_eps=config['learn_eps']
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_acc = 0
    best_state = None
    eps_history = []
    spike_history = []

    for epoch in range(epochs):
        train_epoch(model, train_loader, optimizer, device, timesteps)
        test_acc = test_model(model, test_loader, device, timesteps)
        eps_history.append(model.get_eps()['layer1'])

        # Always record the first epoch so a checkpoint exists even if accuracy
        # never rises above 0.
        if best_state is None or test_acc > best_acc:
            best_acc = test_acc
            # state_dict() returns references to the live tensors. Without a deep
            # copy, later epochs would overwrite this snapshot in place, and the
            # saved "best" checkpoint would actually hold the final weights.
            best_state = {
                'epoch': epoch + 1,
                'model_state_dict': copy.deepcopy(model.state_dict()),
                'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
                'accuracy': best_acc,
                'eps': model.get_eps(),
                'config': config
            }

        if (epoch + 1) % 20 == 0:
            spike_rate = compute_spike_rate(model, test_loader, device, timesteps)
            spike_history.append(spike_rate)
            print(f"    Ep {epoch+1}: {test_acc:.2f}%, eps={model.get_eps()['layer1']:.3f}, spike={spike_rate:.3f}")

    # Save best checkpoint
    if best_state is not None:
        cfg_name = config['name'].replace(' ', '_').replace('=', '').replace('(', '').replace(')', '')
        ckpt_path = ckpt_dir / f"{model_name}_{cfg_name}_T{timesteps}.pt"
        torch.save(best_state, ckpt_path)
        print(f"    Saved: {ckpt_path.name}")

    final_spike = compute_spike_rate(model, test_loader, device, timesteps)
    return {
        'best_acc': best_acc,
        'final_eps': model.get_eps()['layer1'],
        'final_spike_rate': final_spike,
        'eps_history': eps_history,
        'spike_history': spike_history
    }


def _print_summary_table(title, all_results, eps_configs, metric, cell_fmt):
    """Print a model-by-config table of one metric from `all_results`.

    Args:
        title: heading printed between '=' rules.
        all_results: {model_name: {config_name: result_dict}} as built by main().
        eps_configs: config dicts; their order fixes the column order.
        metric: key read from each result dict (e.g. 'best_acc').
        cell_fmt: str.format template for one cell, including its leading space.
    """
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)
    # Learned configs get the short label 'Learned' so they fit the 10-char column.
    labels = ['Learned' if cfg['learn_eps'] else cfg['name'] for cfg in eps_configs]
    print(f"{'Model':<12}" + "".join(f" {label:>10}" for label in labels))
    print("-" * 70)

    for model_name, results in all_results.items():
        cells = "".join(cell_fmt.format(results[cfg['name']][metric]) for cfg in eps_configs)
        print(f"{model_name:<12}{cells}")


def main():
    """Run the eps ablation for every neuron variant and every eps config.

    Trains on MNIST, saves each run's best checkpoint under
    ablations/checkpoints/, prints accuracy / spike-rate / final-eps summary
    tables, and writes all results to a timestamped JSON file under
    ablations/results/.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--timesteps', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}, Epochs: {args.epochs}, T={args.timesteps}\n")

    # MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST('data', train=False, transform=transform)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=args.batch_size)

    # Create checkpoint directory
    ckpt_dir = Path('ablations/checkpoints')
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Models (code name -> class; see paper name in each class docstring)
    # UltraTLIF  = UltraLIF   (temporal)
    # UltraTPLIF = UltraPLIF  (temporal, learnable tau)
    # UltraLIF   = UltraDLIF  (spatial)
    # UltraPLIF  = UltraDPLIF (spatial, learnable tau)
    models = {
        'UltraTLIF': UltraTLIF,
        'UltraTPLIF': UltraTPLIF,
        'UltraLIF': UltraLIF,
        'UltraPLIF': UltraPLIF,
    }

    # Eps configs
    eps_configs = [
        {'name': 'Fixed=0.5', 'init_eps': 0.5, 'learn_eps': False},
        {'name': 'Fixed=1.0', 'init_eps': 1.0, 'learn_eps': False},
        {'name': 'Fixed=2.0', 'init_eps': 2.0, 'learn_eps': False},
        {'name': 'Learned(1.0)', 'init_eps': 1.0, 'learn_eps': True},
    ]

    all_results = {}

    for model_name, neuron_class in models.items():
        print("=" * 60)
        print(f"MODEL: {model_name}")
        print("=" * 60)

        model_results = {}

        for cfg in eps_configs:
            print(f"\n  {cfg['name']}:")
            result = run_exp(neuron_class, model_name, cfg, train_loader, test_loader,
                             device, args.epochs, args.timesteps, args.seed, ckpt_dir)
            model_results[cfg['name']] = result
            print(f"    => Best: {result['best_acc']:.2f}%, eps: {result['final_eps']:.3f}, spike: {result['final_spike_rate']:.3f}")

        all_results[model_name] = model_results

    # Summary tables. Titles use the actual --timesteps value instead of a
    # hard-coded "T=1".
    setting = f"MNIST T={args.timesteps}"
    _print_summary_table(f"ACCURACY: Epsilon Ablation on {setting}",
                         all_results, eps_configs, 'best_acc', " {:>9.2f}%")
    _print_summary_table(f"SPIKE RATE: Epsilon Ablation on {setting}",
                         all_results, eps_configs, 'final_spike_rate', " {:>10.3f}")
    _print_summary_table(f"FINAL EPS: Epsilon Ablation on {setting}",
                         all_results, eps_configs, 'final_eps', " {:>10.3f}")

    # Save
    output_dir = Path('ablations/results')
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    outfile = output_dir / f'eps_ablation_T{args.timesteps}_{timestamp}.json'

    with open(outfile, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)

    print(f"\nSaved: {outfile}")


if __name__ == '__main__':
    # Launch the full ablation sweep only when executed as a script, not on import.
    main()
