# filename: train_and_export_spiker_vhd_updated.py
# usage (example):
#   python train_and_export_spiker_vhd_updated.py \
#       --dataset image --data_path ./propdata/MNIST --num_steps 8 --num_epochs 5 \
#       --Frag True --Fault False --ECOC False --Soft True --Routing False \
#       --Astro False --Falvolt False --LIFA False \
#       --bias True --spiker_root ./spiker_public --vhdl_out ./vhdl_out \
#       --weights_bw 6 --neurons_bw 8 --fp_dec 6 --tau 2.0 --dt 1.0
#   (--dataset defaults to "sequential", i.e. UCI-HAR, so it must be set to "image" for MNIST/FMNIST.)

import os, sys, math, argparse
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

import numpy as np

# ========= (0) Optional: speed / device =========
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    # Opt into the expandable-segments allocator mode unless the user configured the allocator already.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def get_device(gpu_num: int):
    """Return the torch device to run on: CUDA when available, otherwise CPU.

    ``gpu_num`` selects the GPU (in PCI bus order) through ``CUDA_VISIBLE_DEVICES``;
    a value the user already exported wins because ``setdefault`` is used.

    Both environment variables are set *before* ``torch.cuda.is_available()``:
    that call may initialise the CUDA runtime, after which later changes to
    ``CUDA_VISIBLE_DEVICES`` are ignored and ``gpu_num`` would have no effect.
    """
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(gpu_num))
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========= (1) External modules (optional) =========
# benchmarks: ECOC, SoftSNN, Router, Astro, FalVolt, LIFA hooks
try:
    from benchmarks import (
        ECOCHead,
        install_softsnn,
        install_router_from_mask, autoroute_with_mask, attach_slot_activity_tracker,
        install_astro_auto, install_falvolt_auto, install_lifa_auto,
    )
except Exception as _benchmarks_exc:  # deliberately broad: any import-time failure falls back to no-ops
    import warnings
    # Make the fallback visible: otherwise the corresponding CLI flags silently do nothing.
    warnings.warn(
        f"Could not import 'benchmarks' ({_benchmarks_exc!r}); --ECOC/--Soft/--Routing/"
        "--Astro/--Falvolt/--LIFA will have no effect.",
        RuntimeWarning,
    )
    # Stand-ins that accept any arguments and do nothing.
    ECOCHead = None
    def install_softsnn(*a, **k): return None
    def install_router_from_mask(*a, **k): return None
    def autoroute_with_mask(*a, **k): return None
    def attach_slot_activity_tracker(*a, **k): return (None, None)
    def install_astro_auto(*a, **k): return None
    def install_falvolt_auto(*a, **k): return None
    def install_lifa_auto(*a, **k): return None

# fault manager
try:
    from fault_injection import build_fault_manager, get_fault_map
except Exception as _fault_import_exc:  # deliberately broad: any import-time failure falls back to no-ops
    import warnings
    # Make the fallback visible: with these stubs --Fault never injects any fault.
    warnings.warn(
        f"Could not import 'fault_injection' ({_fault_import_exc!r}); --Fault will not inject any faults.",
        RuntimeWarning,
    )
    # Stand-ins that accept any arguments and return None.
    def build_fault_manager(*a, **k): return None
    def get_fault_map(*a, **k): return None

# fragmentation: updated DynFrag + confidence-weighted aggregation + FragNorm
from fragmentation import batch_dynamic_fragments, agg_conf_logits, FragNorm

# spikingjelly
from spikingjelly.activation_based import surrogate, neuron, functional, encoding

# ========= (2) Spiker+ import helper =========
def import_spiker(spiker_root: Optional[str]):
    """Import Spiker+ and return ``(spikerplus_module, write_file_all)``.

    The installed package is tried first. If that fails and ``spiker_root`` is
    given, ``<root>/spiker`` and ``<root>`` are placed at the front of
    ``sys.path`` (in that priority order) and the import is retried. Existing
    copies of those entries are removed first, so repeated calls do not grow
    ``sys.path``.

    Raises:
        ImportError: when Spiker+ cannot be imported either way; the underlying
            exception is chained as ``__cause__``.
    """
    try:
        import spiker.spikerplus as spikerplus
        from spiker.spikerplus.vhdl.vhdl import write_file_all
        return spikerplus, write_file_all
    except Exception as first_exc:
        if spiker_root is None:
            raise ImportError(
                "Failed to import spikerplus. Provide --spiker_root to your Spiker public repo or install it with `pip install -e .`."
            ) from first_exc

    root = Path(spiker_root).expanduser().resolve()
    # Insert <root> then <root>/spiker at index 0, so <root>/spiker ends up first.
    for entry in (str(root), str(root / "spiker")):
        while entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)
    try:
        import spiker.spikerplus as spikerplus  # type: ignore
        from spiker.spikerplus.vhdl.vhdl import write_file_all  # type: ignore
    except ImportError as retry_exc:
        raise ImportError(
            f"Failed to import spikerplus even after adding {root} to sys.path."
        ) from retry_exc
    return spikerplus, write_file_all

# ========= (3) LIF tau↔beta (quantized) =========
import math as _math

def tau_to_beta(tau: float, dt: float = 1.0, method: str = "exp") -> float:
    """Convert a LIF time constant into a per-step membrane decay factor.

    ``method="exp"`` returns ``exp(-dt/tau)``; ``method="euler"`` returns the
    forward-Euler factor ``1 - dt/tau`` floored at 0. ``tau`` is clamped to at
    least 1e-12 to avoid division by zero.

    Raises:
        ValueError: for any other ``method``.
    """
    if method not in ("exp", "euler"):
        raise ValueError("method must be 'exp' or 'euler'")
    step_ratio = float(dt) / max(float(tau), 1e-12)
    if method == "euler":
        return max(0.0, 1.0 - step_ratio)
    return _math.exp(-step_ratio)

def quantize_beta(beta: float, fp_dec: int) -> float:
    """Round ``beta`` to the nearest multiple of ``2**-fp_dec`` (``fp_dec`` fractional bits).

    Uses Python's ``round`` (ties go to the even multiple).
    """
    lsb = 1.0 / (1 << int(fp_dec))  # weight of one fractional bit; exact for a power of two
    return round(float(beta) / lsb) * lsb

# ========= (4) Model (spikingjelly) =========
class Net(nn.Module):
    """Four-layer spiking MLP: in_dim -> 1024 -> 512 -> 128 -> num_classes.

    Each ``fc*`` linear layer feeds a spikingjelly ``LIFNode`` (threshold 1.0,
    reset to 0.0, ATan surrogate gradient). ``forward`` processes one time step
    and returns the output of the last LIF layer. With ``use_fragnorm`` the
    input first passes through ``FragNorm``.

    The attribute names ``fc1``..``fc4`` are looked up by name by the VHDL
    export code, so they must not be renamed.

    Args:
        in_dim: flattened input size (28*28 for MNIST by default).
        num_classes: output width of ``fc4``.
        bias: whether the linear layers carry a bias.
        use_fragnorm: prepend a ``FragNorm`` input normaliser.
        tau: membrane time constant shared by every LIF layer (default 2.0).
    """
    def __init__(self, in_dim=28*28, num_classes=10, bias: bool = True, use_fragnorm: bool = True,
                 tau: float = 2.0):
        super().__init__()
        # No functional.set_backend() call here: invoked before any LIFNode is registered it
        # only visits this (still empty) module and changes nothing.
        # NOTE(review): spikingjelly's 'cupy' backend appears to require step_mode='m' — confirm
        # before re-adding a set_backend call after the layers below exist.
        self.use_fragnorm = use_fragnorm
        if use_fragnorm:
            self.fn = FragNorm(num_features=in_dim, time_aggregate=False, affine=True,
                               track_running_stats=True, momentum=0.1, eps=1e-5)
        # Registration order (fn, fc1, lif1, ..., fc4, lif4) is unchanged, so state_dict keys
        # and seeded initialisation stay the same.
        self.fc1  = nn.Linear(in_dim, 1024, bias=bias)
        self.lif1 = self._make_lif(tau)
        self.fc2  = nn.Linear(1024, 512, bias=bias)
        self.lif2 = self._make_lif(tau)
        self.fc3  = nn.Linear(512, 128, bias=bias)
        self.lif3 = self._make_lif(tau)
        self.fc4  = nn.Linear(128, num_classes, bias=bias)
        self.lif4 = self._make_lif(tau)

    @staticmethod
    def _make_lif(tau: float) -> nn.Module:
        """Build one LIF layer with the neuron settings shared by every layer of this net."""
        return neuron.LIFNode(tau=tau, v_threshold=1.0, v_reset=0.0,
                              surrogate_function=surrogate.ATan())

    def forward(self, x):
        """Run a single time step; ``x`` must already be flattened to ``in_dim`` features."""
        if self.use_fragnorm:
            x = self.fn(x)
        x = self.fc1(x); x = self.lif1(x)
        x = self.fc2(x); x = self.lif2(x)
        x = self.fc3(x); x = self.lif3(x)
        x = self.fc4(x); x = self.lif4(x)
        return x

# ========= (5) Utils =========
class ParameterClipper:
    """In-place parameter clamp intended for ``nn.Module.apply``.

    Conv1d/Conv2d/Conv3d and Linear modules get ``weight`` and ``bias`` clamped
    to ``[-limit, limit]``. When ``clip_bn`` is true, BatchNorm-like modules
    (a BatchNorm1d/2d/3d instance, a class whose name contains "batchnorm", or
    any module exposing both ``running_mean`` and ``running_var``) get
    ``weight``/``gamma``/``bias``/``beta`` clamped to ``[-bn_limit, bn_limit]``.
    All other modules are left untouched.
    """

    def __init__(self, limit=1.0, clip_bn=True, bn_limit=1.0):
        self.limit = float(limit)
        self.clip_bn = bool(clip_bn)
        self.bn_limit = float(bn_limit)

    @staticmethod
    def _clamp_attrs(module, attr_names, bound):
        """Clamp each named tensor attribute of ``module`` that is present to ``[-bound, bound]``."""
        for attr in attr_names:
            tensor = getattr(module, attr, None)
            if tensor is not None:
                tensor.data.clamp_(-bound, bound)

    def __call__(self, module: nn.Module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            self._clamp_attrs(module, ("weight", "bias"), self.limit)
            return
        if not self.clip_bn:
            return
        looks_like_bn = (
            isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
            or "batchnorm" in type(module).__name__.lower()
            or (hasattr(module, "running_mean") and hasattr(module, "running_var"))
        )
        if looks_like_bn:
            self._clamp_attrs(module, ("weight", "gamma", "bias", "beta"), self.bn_limit)

# ========= (6) Training =========
class UCIHARDataset(Dataset):
    """UCI-HAR raw inertial signals as a map-style dataset.

    Reads the nine ``<data_path>/<split>/Inertial Signals/<signal>_<split>.txt``
    files, stacks them channel-last into a float32 tensor ``self.data`` of shape
    [num_samples, readings_per_row, 9], and loads ``y_<split>.txt`` into
    ``self.targets`` shifted from 1-based to 0-based labels.
    """

    def __init__(self, data_path, split="train"):
        self.split = split
        signal_dir = os.path.join(data_path, split, "Inertial Signals")
        # Channel order: total_acc_{x,y,z}, body_acc_{x,y,z}, body_gyro_{x,y,z}
        channel_names = [f"{sensor}_{axis}"
                         for sensor in ("total_acc", "body_acc", "body_gyro")
                         for axis in ("x", "y", "z")]
        channels = [np.loadtxt(os.path.join(signal_dir, f"{channel}_{split}.txt"))
                    for channel in channel_names]
        self.data = torch.tensor(np.stack(channels, axis=2), dtype=torch.float)

        raw_labels = np.loadtxt(os.path.join(data_path, split, f"y_{split}.txt"), dtype=int)
        self.targets = torch.tensor(raw_labels - 1, dtype=torch.long)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.targets[idx]
        return sample, label

def time_roll_forward(net: nn.Module, x_flat: torch.Tensor, T: int,
                      encoder=None) -> torch.Tensor:
    """Reset ``net`` and run it for ``T`` time steps on the same flattened input.

    Args:
        net: single-step network; its state is cleared with ``functional.reset_net`` first.
        x_flat: input already flattened per sample, fed at every step.
        T: number of time steps.
        encoder: optional callable applied to ``x_flat`` at every step, with its
            output cast to float (e.g. the ``encoding.PoissonEncoder`` used during
            training). ``None`` feeds ``x_flat`` directly.

    Returns:
        Per-step outputs stacked along dim 0 ([T, B, K] for a [B, K] output).
    """
    functional.reset_net(net)
    seq = []
    for _ in range(T):
        x_t = x_flat if encoder is None else encoder(x_flat).float()
        seq.append(net(x_t))
    return torch.stack(seq, dim=0)  # [T,B,K]


def time_roll_forward_frag(net: nn.Module, x_imgs: torch.Tensor, T: int, *,
                           method: str = "combo", overlap: bool = True,
                           kernel_size: int = 15, overlap_iter: int = 3,
                           per_image: bool = False, power_norm: Optional[dict] = None,
                           weak_model: Optional[nn.Module] = None,
                           encoder=None) -> torch.Tensor:
    """Run ``net`` for ``T`` steps, feeding one DynFrag fragment of ``x_imgs`` per step.

    Fragment masks are computed first (``batch_dynamic_fragments`` receives
    ``weak_model`` and may run it). ``net``'s state is then reset, and step ``t``
    receives ``x_imgs * masks[:, t]`` flattened per sample.

    Args:
        net: single-step network to unroll.
        x_imgs: image-shaped batch whose first dim is the batch size.
        T: number of time steps / fragments.
        method, overlap, kernel_size, overlap_iter, per_image, power_norm, weak_model:
            forwarded unchanged to ``batch_dynamic_fragments``.
        encoder: optional callable applied to each flattened fragment, with its
            output cast to float (e.g. the ``encoding.PoissonEncoder`` used during
            training). ``None`` feeds the fragment directly.

    Returns:
        Per-step outputs stacked along dim 0 ([T, B, K]).
    """
    masks = batch_dynamic_fragments(
        x_imgs, n_steps=T, n_angles=180, method=method, overlap=overlap,
        kernel_size=kernel_size, overlap_iter=overlap_iter,
        per_image=per_image, power_norm=power_norm, weak_model=weak_model
    )  # [B,T,1,H,W]
    B = x_imgs.shape[0]
    # Reset after computing masks: the weakness estimate may have advanced net's state.
    functional.reset_net(net)
    outs = []
    for t in range(T):
        xt = (x_imgs * masks[:, t]).view(B, -1)   # [B, H*W]
        if encoder is not None:
            xt = encoder(xt).float()
        outs.append(net(xt))                       # [B,K]
    return torch.stack(outs, dim=0)                # [T,B,K]


def _load_datasets(args):
    """Return ``(train_ds, test_ds, input_dim, num_classes)`` for ``args.dataset``."""
    if args.dataset == "sequential":
        # UCI-HAR: samples are [128, 9] windows, labels in {0..5}
        return (UCIHARDataset(args.data_path, split="train"),
                UCIHARDataset(args.data_path, split="test"),
                128 * 9, 6)
    # MNIST / FMNIST (28x28, 10 classes); FashionMNIST when the path ends with "FMNIST"
    tfm = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])
    ds_cls = datasets.FashionMNIST if args.data_path.endswith("FMNIST") else datasets.MNIST
    return (ds_cls(args.data_path, train=True, download=True, transform=tfm),
            ds_cls(args.data_path, train=False, download=True, transform=tfm),
            28 * 28, 10)


def _prepare_inputs(xb, args, device):
    """Move ``xb`` to ``device`` and apply the dataset-specific preprocessing.

    For UCI-HAR each [128, 9] sample is min-max scaled to [0, 1]. With
    fragmentation it is reshaped to a [1, 36, 32] pseudo-image (36*32 == 128*9).
    """
    xb = xb.to(device)
    if args.dataset == "sequential":
        dmin = xb.amin(dim=(1, 2), keepdim=True)
        dmax = xb.amax(dim=(1, 2), keepdim=True)
        xb = (xb - dmin) / (dmax - dmin + 1e-8)
        if args.Frag:
            # Use the real batch size, not args.batch_size, so partial batches work.
            xb = xb.view(xb.size(0), 1, 36, 32)
    return xb


def _train_logits(net, xb, args, encoder, power_cfg):
    """Run one Poisson-encoded training sequence and return aggregated logits [B, K]."""
    if args.Frag:
        # vectorized fragments + auto-weakness (method='combo')
        masks = batch_dynamic_fragments(
            xb, n_steps=args.num_steps, n_angles=180, method="combo",
            overlap=True, kernel_size=15, overlap_iter=3,
            per_image=False, power_norm=power_cfg, weak_model=net
        )  # [B,T,1,H,W]
        # net is passed as weak_model and may have been run; clear its state,
        # as time_roll_forward_frag does in evaluation.
        functional.reset_net(net)
        per_step = []
        for t in range(masks.size(1)):
            xt = (xb * masks[:, t]).view(xb.size(0), -1)
            per_step.append(net(encoder(xt).float()))
        return agg_conf_logits(torch.stack(per_step, dim=0), tau=2.0, time_major=True)
    x_flat = xb.view(xb.size(0), -1)
    total = 0
    for _ in range(args.num_steps):
        total = total + net(encoder(x_flat).float())
    return total / args.num_steps


def _eval_logits(net, xb, args, power_cfg):
    """Run one evaluation sequence and return confidence-aggregated logits [B, K].

    NOTE(review): training feeds Poisson-encoded spikes (and averages logits when
    Frag is off), while evaluation feeds the analog input and always uses
    agg_conf_logits; confirm this direct-input evaluation is intended.
    """
    if args.Frag:
        logits_t = time_roll_forward_frag(net, xb, args.num_steps,
                                          method="combo", overlap=True,
                                          kernel_size=15, overlap_iter=3,
                                          per_image=False, power_norm=power_cfg,
                                          weak_model=net)
    else:
        logits_t = time_roll_forward(net, xb.view(xb.size(0), -1), args.num_steps)
    return agg_conf_logits(logits_t, tau=2.0, time_major=True)


def train_and_eval(args):
    """Train the spiking MLP configured by ``args`` and print test accuracy after each epoch.

    Optional features (ECOC head, SoftSNN, fault injection with routing / Astro /
    FalVolt / LIFA hooks) are enabled by the corresponding ``args`` flags. After
    every optimizer step parameters are clipped to ``args.limit``. From
    ``args.fault_start_epoch`` on, faults are re-applied after every step.

    Returns:
        ``(net, ecoc)``: the trained (possibly ECOC-patched) network and the ECOC
        head, or ``None`` when ECOC is off or unavailable.
    """
    device = get_device(args.gpu_num)
    torch.backends.cudnn.benchmark = True

    ds_train, ds_test, input_dim, num_classes = _load_datasets(args)
    train_loader = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, drop_last=True)
    # drop_last=False: every test sample must count toward the reported accuracy.
    test_loader = DataLoader(ds_test, batch_size=args.batch_size, shuffle=False, drop_last=False)

    # Model
    net = Net(in_dim=input_dim, num_classes=num_classes, bias=args.bias, use_fragnorm=args.Frag).to(device)

    # ECOC head (optional); its codebook must cover the dataset's classes (6 for UCI-HAR).
    ecoc = None
    if args.ECOC:
        if ECOCHead is None:
            print("[WARN] --ECOC requested but ECOCHead is unavailable; training without ECOC.")
        else:
            ecoc = ECOCHead(num_classes=num_classes, bit_values=(0.0, 1.0))
            net = ecoc.patch_last_linear(net)

    # Other hooks (optional); the returned handles are currently unused.
    bounder = install_softsnn(net, mode="bnp2", per="channel", symmetric=True) if args.Soft else None

    fault_mgr = None
    router = pairs = stuck_map = None
    if args.Fault:
        fault_mgr = build_fault_manager(
            net,
            ratio=args.fault_ratio,
            fault_type=args.fault_type,
            distribution=args.fault_dist,
            stuck_at=args.limit,
            noise_std=args.noise_std,
            limit=args.limit,
            include_bias=True,
        )
        if args.Routing:
            stuck_map = get_fault_map(fault_mgr, include_bias=False)
            router, pairs = install_router_from_mask(net, stuck_map, arch="mlp")
            attach_slot_activity_tracker(net, beta=0.9)

    astro   = install_astro_auto(net, fault_mgr, start_epoch=args.fault_start_epoch) if (args.Fault and args.Astro) else None
    falvolt = install_falvolt_auto(net, fault_mgr, start_epoch=args.fault_start_epoch,
                                   clamp=(0.3, 2.0), include_bias=False, verbose=True) if (args.Fault and args.Falvolt) else None
    lifa    = install_lifa_auto(net, fault_mgr, start_epoch=args.fault_start_epoch, arch="mlp") if (args.Fault and args.LIFA) else None

    # Optimizer, loss, encoder and clipper are built once.
    opt = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
    ce = nn.CrossEntropyLoss()
    encoder = encoding.PoissonEncoder()
    clipper = ParameterClipper(limit=args.limit, clip_bn=True, bn_limit=args.limit)

    # DynFrag power normalizer config
    power_cfg = dict(mode="rms", target=1.0, per_channel=False, use_mask=True,
                     max_gain=6.0, detach_stats=True) if args.Frag else None

    fault_setup_done = False
    for epoch in range(args.num_epochs):
        net.train()
        if epoch >= 1:
            sch.step()
        inject_faults = args.Fault and fault_mgr is not None and epoch >= args.fault_start_epoch
        for xb, yb in train_loader:
            xb = _prepare_inputs(xb, args, device)
            yb = yb.to(device)

            logits_seq = _train_logits(net, xb, args, encoder, power_cfg)
            if ecoc is None:
                loss = ce(logits_seq, yb)
            else:
                loss = ecoc.loss_ce(logits_seq, yb, metric="euclidean", temp=1.0, squared=True)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            functional.reset_net(net)

            with torch.no_grad():
                if inject_faults:
                    fault_mgr.apply_(net)
                    if not fault_setup_done and epoch == args.fault_start_epoch:
                        # One-time setup on the first faulty step: reroute and restart
                        # Adam. The scheduler is rebound so LR decay keeps applying
                        # to the optimizer that is actually used.
                        if args.Routing and router is not None:
                            autoroute_with_mask(router, pairs, stuck_map)
                        opt = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
                        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
                        fault_setup_done = True
                net.apply(clipper)

        # quick eval
        net.eval()
        correct = total = 0
        with torch.no_grad():
            for xb, yb in test_loader:
                xb = _prepare_inputs(xb, args, device)
                yb = yb.to(device)
                logits_seq = _eval_logits(net, xb, args, power_cfg)
                pred = logits_seq.argmax(dim=1) if ecoc is None else ecoc.decode(logits_seq)
                total += yb.size(0)
                correct += (pred == yb).sum().item()
        acc = 100.0 * correct / max(1, total)
        print(f"[Epoch {epoch}] test acc = {acc:.2f}%")

    return net, ecoc

# ========= (7) Spiker+ weight mapping & VHDL export =========

def to_spiker_and_export(trained_net: nn.Module, num_steps: int, ecoc: Optional[object],
                         spiker_root: Optional[str], out_dir: str,
                         w_bw: int = 6, n_bw: int = 8, fp_dec: int = 6,
                         tau: float = 2.0, dt: float = 1.0,
                         beta_method: str = "euler", decay_input: bool = True):
    """Rebuild ``trained_net`` as a Spiker+ LIF network and write its VHDL into ``out_dir``.

    The input size and the four layer widths are read from the trained
    ``fc1``..``fc4`` weights, so both MNIST (784 inputs) and UCI-HAR (1152
    inputs) models export.

    Per spikingjelly's ``LIFNode`` documentation, with ``decay_input=True`` and
    ``v_reset=0`` (as configured in ``Net``) the training neuron charges as
    ``V[t] = (1 - 1/tau) * V[t-1] + X[t] / tau``. The exported layers therefore
    get ``beta = 1 - dt/tau`` (``beta_method="euler"``). When ``decay_input`` is
    true, weights and biases are scaled by ``dt/tau``.

    Args:
        trained_net: network exposing ``fc1``..``fc4`` layers with 2-D ``weight``.
        num_steps: time steps per sample (Spiker+ ``n_cycles``).
        ecoc: unused; kept for call compatibility (the ECOC-patched output width
            is read from ``fc4``'s weight).
        spiker_root: forwarded to ``import_spiker``.
        out_dir: output directory, created if missing.
        w_bw, n_bw, fp_dec: Spiker+ weight / neuron bit widths and fractional bits.
        tau, dt: LIF time constant and step used to derive beta.
        beta_method: ``"euler"`` or ``"exp"`` discretisation (see ``tau_to_beta``).
        decay_input: scale exported weights/biases by ``dt/tau``.

    Raises:
        ImportError: if Spiker+ cannot be imported.
        RuntimeError: if an ``fc*`` layer lacks a 2-D weight or the shapes do not chain.
    """
    spikerplus, write_file_all = import_spiker(spiker_root)
    NetBuilder = spikerplus.NetBuilder
    VhdlGenerator = spikerplus.VhdlGenerator

    # (Spiker+ layer name, trained layer name)
    mapping = [
        ("fc0", "fc1"),
        ("fc1", "fc2"),
        ("fc2", "fc3"),
        ("fc3", "fc4"),
    ]

    # Collect trained parameters; layer sizes come from these instead of constants.
    weights, biases = [], []
    for _, sj_name in mapping:
        layer = getattr(trained_net, sj_name, None)
        w = getattr(layer, "weight", None)
        if w is None or w.dim() != 2:
            raise RuntimeError(f"Could not find a 2-D weight on trained_net.{sj_name}.")
        weights.append(w.detach().cpu())
        b = getattr(layer, "bias", None)
        biases.append(None if b is None else b.detach().cpu())
    for idx in range(1, len(weights)):
        if weights[idx].shape[1] != weights[idx - 1].shape[0]:
            raise RuntimeError(
                f"{mapping[idx][1]} expects {weights[idx].shape[1]} inputs but "
                f"{mapping[idx - 1][1]} produces {weights[idx - 1].shape[0]} outputs."
            )
    n_inputs = int(weights[0].shape[1])
    widths = [int(w.shape[0]) for w in weights]

    # Warn when the export tau does not match the LIF nodes actually trained.
    lif_taus = {float(m.tau) for m in trained_net.modules()
                if isinstance(m, neuron.LIFNode) and getattr(m, "tau", None) is not None}
    if lif_taus and lif_taus != {float(tau)}:
        print(f"[WARN] export tau={tau} differs from the trained LIF tau(s) {sorted(lif_taus)}; "
              "the hardware dynamics will not match training.")
    if getattr(trained_net, "use_fragnorm", False):
        print("[WARN] trained_net applies FragNorm to its input; that normalization is not part "
              "of the exported Spiker+ network.")

    # beta from tau, quantized to fp_dec fractional bits
    beta_cont = tau_to_beta(tau, dt, method=beta_method)
    beta_q = quantize_beta(beta_cont, fp_dec)
    in_gain = float(dt) / max(float(tau), 1e-12) if decay_input else 1.0
    print(f"[LIF] tau={tau} dt={dt} method={beta_method} -> beta_cont={beta_cont:.6f} "
          f"-> beta_q(fp{fp_dec})={beta_q:.6f}, input gain={in_gain:.6f}")

    def _lif_layer(n_neurons: int) -> dict:
        """Spiker+ description of one LIF layer with the shared neuron settings."""
        return {
            "neuron_model": "lif", "n_neurons": int(n_neurons),
            "alpha": None, "learn_alpha": False,
            "beta": float(beta_q), "learn_beta": False,
            "threshold": 1.0, "learn_threshold": False,
            "reset_mechanism": "zero",
        }

    # Same key order as before: n_cycles, n_inputs, layer_0..layer_3, readout.
    net_dict = {"n_cycles": int(num_steps), "n_inputs": n_inputs}
    for idx, width in enumerate(widths):
        net_dict[f"layer_{idx}"] = _lif_layer(width)
    net_dict["readout"] = "mem"

    # build spiker net
    spk_net = NetBuilder(net_dict).build()

    with torch.no_grad():
        for (spk_name, sj_name), w_sj, b_sj in zip(mapping, weights, biases):
            spk_layer = getattr(spk_net.layers, spk_name)
            spk_layer.weight.data.copy_(w_sj * in_gain)
            spk_bias = getattr(spk_layer, "bias", None)
            if spk_bias is not None:
                # Overwrite whatever Spiker+ initialised so the bias matches training.
                if b_sj is None:
                    spk_bias.data.zero_()
                else:
                    spk_bias.data.copy_(b_sj * in_gain)
            elif b_sj is not None and bool(torch.any(b_sj != 0)):
                print(f"[WARN] {sj_name} has a non-zero bias but Spiker+ layer {spk_name} has none; "
                      "the bias is dropped from the export.")
        # thresholds are 1.0; beta already set in net_dict

    optim_params = {"weights_bw": int(w_bw), "neurons_bw": int(n_bw), "fp_dec": int(fp_dec)}
    vhdl_gen = VhdlGenerator(spk_net, optim_params)
    vhdl_top = vhdl_gen.generate(interface = True, functional = False)

    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    write_file_all(vhdl_top, output_dir=str(out_path))
    print(f"[OK] VHDL written to: {out_path.resolve()}")

# ========= (8) CLI =========

def _str2bool(text: str) -> bool:
    """Parse a boolean CLI value, rejecting anything unrecognised.

    A plain membership test would silently turn typos such as "Ture" into False.

    Raises:
        argparse.ArgumentTypeError: for values that are neither truthy nor falsy.
    """
    value = str(text).strip().lower()
    if value in ("true", "1", "yes", "y", "t", "on"):
        return True
    if value in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean such as true/false, got {text!r}")


def parse_args(argv: Optional[list] = None):
    """Parse command-line options for training, fault injection and Spiker+ export.

    Args:
        argv: arguments to parse; ``None`` (default) reads ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(description="Train a spiking MLP and export it to VHDL with Spiker+.")
    # data / train
    p.add_argument("--batch_size", type=int, default=100)
    p.add_argument("--dataset", type=str, default="sequential", choices=["image", "sequential"],
                   help="'image': MNIST (FashionMNIST if --data_path ends with 'FMNIST'); 'sequential': UCI-HAR")
    p.add_argument("--data_path", type=str, default="propdata/MNIST")
    p.add_argument("--num_steps", type=int, default=2)
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--learning_rate", type=float, default=1e-3)
    p.add_argument("--limit", type=float, default=1.0)
    p.add_argument("--bias", type=_str2bool, default=True)
    p.add_argument("--gpu_num", type=int, default=0)
    # fragmentation
    p.add_argument("--Frag", type=_str2bool, default=True)
    # faults & hooks
    p.add_argument("--Fault", type=_str2bool, default=False)
    p.add_argument("--fault_type", default="stuck", choices=["stuck", "random", "connectivity"])
    p.add_argument("--fault_dist", default="sporadic", choices=["sporadic", "clustered"])
    p.add_argument("--fault_ratio", type=float, default=0.1)
    p.add_argument("--noise_std", type=float, default=0.05)
    p.add_argument("--fault_start_epoch", type=int, default=5)
    p.add_argument("--ECOC",    type=_str2bool, default=False)
    p.add_argument("--Soft",    type=_str2bool, default=False)
    p.add_argument("--Routing", type=_str2bool, default=False)
    p.add_argument("--Astro",   type=_str2bool, default=False)
    p.add_argument("--Falvolt", type=_str2bool, default=False)
    p.add_argument("--LIFA",    type=_str2bool, default=False)
    # spiker export
    p.add_argument("--spiker_root", type=str, default=None)
    p.add_argument("--vhdl_out", type=str, default="./vhdl_out")
    p.add_argument("--weights_bw", type=int, default=6)
    p.add_argument("--neurons_bw", type=int, default=8)
    p.add_argument("--fp_dec", type=int, default=6)
    p.add_argument("--tau", type=float, default=2.0)
    p.add_argument("--dt", type=float, default=1.0)
    return p.parse_args(argv)


def main():
    """Command-line entry point: parse flags, train and evaluate, then export VHDL."""
    cli = parse_args()
    trained, ecoc_head = train_and_eval(cli)
    export_opts = {
        "spiker_root": cli.spiker_root,
        "out_dir": cli.vhdl_out,
        "w_bw": cli.weights_bw,
        "n_bw": cli.neurons_bw,
        "fp_dec": cli.fp_dec,
        "tau": cli.tau,
        "dt": cli.dt,
    }
    to_spiker_and_export(trained, cli.num_steps, ecoc_head, **export_opts)


if __name__ == "__main__":
    main()
