#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Evaluate Spike-GTR on IAM-line/LAM/READ2016.

Example:
  python test.py \
    -c ./configs/IAM.yaml \
    --checkpoint ./output/iam_snn_ocr/EXP/model_best.pth.tar \
    --data-path /workspace/data/OCR/IAM-line/data \
    --val-split validation \
    --export-csv ./output/iam_snn_ocr/EXP/pred_test.csv \
    --export-csv-with-ids

Notes:
- If your training enabled model_ema and used EMA for evaluation, keep --use-ema (default).
- If your dataset has a dedicated test split, set --val-split to that split name.
"""

import argparse
import csv
import json
import logging
import os
import platform
import sys
import time
from contextlib import suppress
from typing import Any, Optional

import yaml

import torch
import torch.nn as nn

from spikingjelly.clock_driven import functional


from spike_htr import SpikeHTR
from ocr_datasets import build_ctc_dataloaders

_logger = logging.getLogger("test")

# ---------------------------
# PyTorch 2.6 checkpoint compat
# ---------------------------
try:
    import argparse as _argparse
    torch.serialization.add_safe_globals([_argparse.Namespace])
except Exception:
    pass


def torch_load_compat(path: str, map_location="cpu"):
    """Load a checkpoint, trying torch's default loader before full unpickling.

    Args:
        path: Checkpoint file path.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.

    Raises:
        OSError: If the file is missing or unreadable (raised immediately,
            without retrying).
        Exception: The error of the last attempted load otherwise.
    """
    try:
        return torch.load(path, map_location=map_location)
    except OSError:
        # A missing/unreadable file cannot be fixed by a different unpickler.
        raise
    except Exception as first_err:
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code embedded in the file. Make the downgrade visible.
        _logger.warning(
            "Default torch.load failed for %s (%s: %s); retrying with weights_only=False. "
            "Only load checkpoints from sources you trust.",
            path,
            type(first_err).__name__,
            first_err,
        )
        try:
            return torch.load(path, map_location=map_location, weights_only=False)
        except TypeError:
            # torch builds without the weights_only keyword: the first attempt
            # already was the full load, so surface its error instead of
            # repeating the identical call.
            raise first_err from None


# ---------------------------
# CTC decode & metrics
# ---------------------------

def _levenshtein_seq(a, b):
    n, m = len(a), len(b)
    if n == 0:
        return m
    if m == 0:
        return n
    prev = list(range(m + 1))
    cur = [0] * (m + 1)
    for i in range(1, n + 1):
        cur[0] = i
        ai = a[i - 1]
        for j in range(1, m + 1):
            cost = 0 if ai == b[j - 1] else 1
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost)
        prev, cur = cur, prev
    return prev[m]


def compute_cer(ref: str, hyp: str):
    """Return ``(char_edit_distance, reference_length)``; falsy inputs count as empty."""
    ref_chars = list(ref) if ref else []
    hyp_chars = list(hyp) if hyp else []
    return _levenshtein_seq(ref_chars, hyp_chars), len(ref_chars)


def compute_wer(ref: str, hyp: str):
    """Return ``(word_edit_distance, reference_word_count)`` on whitespace tokens.

    The reported word count is never below 1, even for an empty reference.
    """
    ref_words = ref.split() if ref else []
    hyp_words = hyp.split() if hyp else []
    n_ref = len(ref_words)
    return _levenshtein_seq(ref_words, hyp_words), (n_ref if n_ref > 1 else 1)


def ctc_greedy_decode(logits: torch.Tensor, blank_id: int = 0, input_lengths: Optional[torch.Tensor] = None):
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks.

    logits: [T, B, V]. When ``input_lengths`` is given, only the first
    ``input_lengths[b]`` frames (clamped to ``[0, T]``) of sample ``b`` are used.
    Returns one list of token ids per sample.
    """
    best = logits.argmax(2).detach().cpu().numpy()  # (T, B)
    n_frames, batch = best.shape

    if input_lengths is None:
        lengths = [n_frames] * batch
    else:
        lengths = [max(0, min(int(v), n_frames)) for v in input_lengths.detach().cpu().tolist()]

    decoded = []
    for b in range(batch):
        frames = [int(tok) for tok in best[: lengths[b], b]]
        # Keep a frame when it is non-blank and differs from the frame right
        # before it; a blank in between lets the same symbol be emitted again.
        decoded.append(
            [
                tok
                for pos, tok in enumerate(frames)
                if tok != blank_id and (pos == 0 or tok != frames[pos - 1])
            ]
        )
    return decoded


def ids_to_text(token_ids, charset):
    """Map 1-based token ids to characters of ``charset`` and join them.

    Ids <= 0 and ids beyond the charset are silently dropped.
    """
    return "".join(charset[i - 1] for i in token_ids if 0 < i <= len(charset))


# ---------------------------
# STRICT input_lengths inference
# ---------------------------

def _to_1d_long_tensor(x, device) -> Optional[torch.Tensor]:
    if x is None:
        return None
    if torch.is_tensor(x):
        t = x.to(device=device, dtype=torch.long)
        if t.ndim == 0:
            return t.view(1)
        return t.view(-1)
    if isinstance(x, (list, tuple)) and len(x) > 0 and all(isinstance(v, (int, float)) for v in x):
        return torch.tensor(x, device=device, dtype=torch.long).view(-1)
    return None


def _to_1d_float_tensor(x, device) -> Optional[torch.Tensor]:
    if x is None:
        return None
    if torch.is_tensor(x):
        t = x.to(device=device, dtype=torch.float32)
        if t.ndim == 0:
            return t.view(1)
        return t.view(-1)
    if isinstance(x, (list, tuple)) and len(x) > 0 and all(isinstance(v, (int, float)) for v in x):
        return torch.tensor(x, device=device, dtype=torch.float32).view(-1)
    return None


def infer_input_lengths_strict(extra: Any, T_seq: int, B: int, device: torch.device) -> torch.Tensor:
    """Derive per-sample CTC input lengths from batch metadata, or fail loudly.

    Resolution order:
      1. An explicit lengths entry (``feat_lengths``, ``input_lengths``,
         ``feature_lengths`` or ``ctc_input_lengths``), first key present wins.
      2. ``ceil(valid_w / target_w * T_seq)`` when both widths are present.

    The result is an int64 tensor with B entries, clamped to ``[1, T_seq]``.

    Raises:
        TypeError: ``extra`` is not a dict.
        ValueError: ``extra`` is None, lacks the needed keys, or holds values
            of the wrong size / an invalid ``target_w``.
    """
    if extra is None:
        raise ValueError("CTC strict lengths: extra is None.")
    if not isinstance(extra, dict):
        raise TypeError(f"CTC strict lengths: extra must be dict, got {type(extra)}")

    max_len = int(T_seq)

    # 1) explicit lengths
    length_keys = ("feat_lengths", "input_lengths", "feature_lengths", "ctc_input_lengths")
    present = next((name for name in length_keys if name in extra), None)
    if present is not None:
        lengths = _to_1d_long_tensor(extra.get(present), device)
        if lengths is None or lengths.numel() != B:
            raise ValueError(f"CTC strict lengths: extra['{present}'] invalid shape, expect B={B}.")
        return lengths.clamp(min=1, max=max_len)

    # 2) width ratio
    if "valid_w" not in extra or "target_w" not in extra:
        raise ValueError(
            "CTC strict lengths: extra missing keys. Need feat_lengths/input_lengths OR (valid_w + target_w). "
            f"Got keys={list(extra.keys())}"
        )

    widths = _to_1d_float_tensor(extra.get("valid_w"), device)
    if widths is None or widths.numel() != B:
        raise ValueError(f"CTC strict lengths: extra['valid_w'] invalid shape, expect B={B}.")

    tw = extra.get("target_w")
    scalar_denom = None
    vector_denom = None
    if isinstance(tw, (int, float)):
        scalar_denom = float(tw)
    else:
        vector_denom = _to_1d_float_tensor(tw, device)
        if vector_denom is None:
            raise ValueError("CTC strict lengths: target_w must be int/float or tensor/list.")
        n_denoms = vector_denom.numel()
        if n_denoms == 1:
            # a single-element tensor/list is treated like a plain scalar
            scalar_denom, vector_denom = float(vector_denom.item()), None
        elif n_denoms != B:
            raise ValueError(f"CTC strict lengths: target_w invalid shape, expect 1 or B={B}.")

    widths = widths.to(device=device, dtype=torch.float32).clamp(min=1.0)
    if vector_denom is not None:
        divisor = vector_denom.to(device=device, dtype=torch.float32).clamp(min=1.0)
    else:
        if scalar_denom is None or scalar_denom <= 1:
            raise ValueError(f"CTC strict lengths: invalid target_w={tw}.")
        divisor = float(scalar_denom)

    frames = torch.ceil(widths / divisor * float(T_seq)).to(dtype=torch.long)
    return frames.clamp(min=1, max=max_len)


def setup_logging():
    """Route this script's logger to stdout at INFO level.

    All handlers currently attached to the root logger are removed and the
    root level is set to INFO. ``_logger`` gets exactly one stdout
    ``StreamHandler`` and does not propagate to the root logger. Calling this
    function repeatedly is safe: handlers left on ``_logger`` by an earlier
    call are removed first so messages are not emitted more than once.
    """
    root = logging.getLogger()
    for h in list(root.handlers):
        root.removeHandler(h)

    level = logging.INFO
    _logger.setLevel(level)
    _logger.propagate = False

    # Without this, every additional call would stack another StreamHandler
    # on _logger and duplicate each log line.
    for h in list(_logger.handlers):
        _logger.removeHandler(h)

    fmt = logging.Formatter(fmt="%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setLevel(level)
    sh.setFormatter(fmt)
    _logger.addHandler(sh)

    root.setLevel(level)


def ensure_parent_dir(filepath: str) -> None:
    """Create the directory that will contain ``filepath`` if it is missing."""
    parent = os.path.dirname(os.path.abspath(filepath))
    if not parent or os.path.exists(parent):
        return
    os.makedirs(parent, exist_ok=True)


def _extra_get(extra: Any, key: str, i: int) -> str:
    """Best-effort extract a per-sample string field from `extra` dict."""
    if not isinstance(extra, dict):
        return ""
    if key not in extra:
        return ""
    v = extra.get(key)
    if v is None:
        return ""
    # common patterns: list/tuple of strings, tensor-like, single string
    if isinstance(v, (list, tuple)):
        if 0 <= i < len(v):
            return str(v[i])
        return ""
    if isinstance(v, str):
        return v
    # torch tensors won't be strings; skip
    try:
        return str(v)
    except Exception:
        return ""


def extract_sample_id_and_path(extra: Any, i: int) -> tuple[str, str]:
    """Look up sample ``i``'s id and file path under commonly used key names.

    For each of the two fields the candidate keys are tried in order and the
    first non-empty value wins; "" is returned when none matches.
    """

    def first_hit(candidates) -> str:
        for name in candidates:
            value = _extra_get(extra, name, i)
            if value:
                return value
        return ""

    sample_id = first_hit(("id", "ids", "uid", "uids", "sample_id", "sample_ids", "name", "names"))
    sample_path = first_hit(
        (
            "path",
            "paths",
            "img_path",
            "img_paths",
            "image_path",
            "image_paths",
            "fname",
            "fnames",
            "file",
            "files",
        )
    )
    return sample_id, sample_path



def extract_sample_meta(extra: Any, i: int) -> dict[str, Any]:
    """Build the metadata columns written to the CSV export for sample ``i``.

    Datasets disagree on what `extra` contains, so only two best-effort
    fields are exposed: `sample_id` and `path` (each "" when unavailable).
    """
    return dict(zip(("sample_id", "path"), extract_sample_id_and_path(extra, i)))
def apply_yaml_compat_patch(args):
    """Keep consistent with train.py for old key aliases.

    Mutates ``args`` in place and returns it:
      * ``drop_rate`` / ``drop_path_rate`` fill ``drop`` / ``drop_path`` only
        while the canonical key is unset or still at its 0.1 default.
      * ``temporal_gate == "vector"`` is renamed to ``"channel"``.
      * ``token_merge_keep_ratio`` / ``token_merge_blank_thresh`` fill the
        ``token_min_keep_ratio`` / ``token_blank_thresh`` keys when those are None.
      * ``trocr_layers`` / ``trocr_nhead`` fill ``seq_layers`` / ``seq_nhead``
        when the latter are absent.
      * ``normalize`` is mapped to one of ``half`` / ``imagenet`` / ``none``
        when it matches a known spelling.

    Alias values that cannot be converted to a number are ignored with a warning.
    """
    # drop/drop_path aliases. The legacy key must not clobber a canonical
    # value the user set explicitly, so compare against the literal default.
    for canonical, legacy, default in (("drop", "drop_rate", 0.1), ("drop_path", "drop_path_rate", 0.1)):
        if not hasattr(args, legacy):
            continue
        current = getattr(args, canonical, None)
        if current is not None and current != default:
            continue
        try:
            setattr(args, canonical, float(getattr(args, legacy)))
        except (TypeError, ValueError):
            _logger.warning("Ignoring non-numeric %s=%r", legacy, getattr(args, legacy))

    # temporal_gate legacy naming
    if str(getattr(args, "temporal_gate", "scalar")).lower() == "vector":
        args.temporal_gate = "channel"

    # token merge aliases
    if getattr(args, "token_min_keep_ratio", None) is None and hasattr(args, "token_merge_keep_ratio"):
        args.token_min_keep_ratio = getattr(args, "token_merge_keep_ratio")
    if getattr(args, "token_blank_thresh", None) is None and hasattr(args, "token_merge_blank_thresh"):
        args.token_blank_thresh = getattr(args, "token_merge_blank_thresh")

    # trocr legacy aliases (converted independently so one bad value does not
    # block the other)
    for canonical, legacy in (("seq_layers", "trocr_layers"), ("seq_nhead", "trocr_nhead")):
        if hasattr(args, legacy) and not hasattr(args, canonical):
            try:
                setattr(args, canonical, int(getattr(args, legacy)))
            except (TypeError, ValueError, OverflowError):
                _logger.warning("Ignoring non-integer %s=%r", legacy, getattr(args, legacy))

    # normalize canonical
    norm = str(getattr(args, "normalize", "half")).lower()
    if norm in ("half", "0.5", "mean0.5", "std0.5"):
        args.normalize = "half"
    elif norm in ("imagenet", "in"):
        args.normalize = "imagenet"
    elif norm in ("none", "no", "off"):
        args.normalize = "none"

    return args


def _resolve_autocast_dtype(args) -> torch.dtype:
    want = str(getattr(args, "amp_dtype", "fp16")).lower()
    if want in ("bf16", "bfloat16"):
        if torch.cuda.is_available():
            try:
                ok = torch.cuda.is_bf16_supported()
            except Exception:
                ok = False
            if not ok:
                _logger.warning("amp_dtype=bf16 requested but not supported. Fallback to fp16.")
                return torch.float16
        return torch.bfloat16
    return torch.float16


def amp_autocast_ctx(args):
    """Inference autocast context (native AMP).

    Returns a no-op context when ``args.native_amp`` is falsy or CUDA is
    unavailable; otherwise a CUDA autocast context using the resolved dtype,
    falling back to ``torch.cuda.amp.autocast`` if ``torch.amp`` cannot be used.
    """
    amp_requested = bool(getattr(args, "native_amp", True))
    if not (amp_requested and torch.cuda.is_available()):
        return suppress()

    dtype = _resolve_autocast_dtype(args)
    try:
        from torch.amp import autocast as modern_autocast

        return modern_autocast("cuda", enabled=True, dtype=dtype)
    except Exception:
        return torch.cuda.amp.autocast(enabled=True)


def load_best_state_dict(ckpt: dict, prefer_ema: bool = True) -> dict:
    """Pick the model weights out of a checkpoint dict.

    Known container keys are probed in order (EMA variants first when
    ``prefer_ema`` is set) and the first non-empty dict is returned. If none
    matches, the checkpoint's own tensor-valued entries are used as weights.

    Raises:
        TypeError: ``ckpt`` is not a dict.
        KeyError: no weights could be located.
    """
    if not isinstance(ckpt, dict):
        raise TypeError(f"Checkpoint must be a dict, got {type(ckpt)}")

    # timm CheckpointSaver common keys
    ema_keys = ("state_dict_ema", "model_ema", "ema_state_dict")
    plain_keys = ("state_dict", "model", "model_state_dict")
    search_order = ema_keys + plain_keys if prefer_ema else plain_keys

    for name in search_order:
        weights = ckpt.get(name)
        if isinstance(weights, dict) and weights:
            return weights

    # fallback: the checkpoint itself may be a flat name -> tensor mapping
    loose_tensors = {name: value for name, value in ckpt.items() if torch.is_tensor(value)}
    if loose_tensors:
        return loose_tensors

    raise KeyError(f"Cannot find state_dict in checkpoint. keys={list(ckpt.keys())}")


def strip_module_prefix(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` with a leading "module." removed from string keys."""
    prefix = "module."
    return {
        (name[len(prefix):] if isinstance(name, str) and name.startswith(prefix) else name): value
        for name, value in state_dict.items()
    }


@torch.no_grad()
def validate(model, loader, loss_fn, args, charset, blank_id, device, *, csv_writer=None, split_name="validation"):
    """Run one evaluation pass and return aggregate CTC metrics.

    Args:
        model: Network called as ``model(images)``. It may return the logits
            alone, a ``(logits, feat_lengths)`` pair, or a dict with a
            ``"logits"`` entry and optionally ``"feat_lengths"`` /
            ``"feature_lengths"``. Logits are indexed time-first
            (``logits.size(0)`` is used as the sequence length).
        loader: Iterable yielding 5-item batches
            ``(images, targets_concat, target_lengths, texts, extra)``.
        loss_fn: Called as ``loss_fn(log_probs, targets, input_lengths, target_lengths)``.
        args: Namespace read for ``strict_ctc_lengths``, ``export_csv_max_samples``,
            ``export_csv_with_ids`` and the AMP settings.
        charset: Sequence mapping ``token_id - 1`` to a character.
        blank_id: CTC blank index used for decoding.
        device: Device the batch tensors are moved to.
        csv_writer: Optional ``csv.DictWriter``-like object; one row per sample is written.
        split_name: Value stored in the ``split`` CSV column.

    Returns:
        Dict with ``loss`` (sample-weighted mean), ``cer``, ``wer``,
        ``exact_acc`` (percent) and ``time_per_batch_s``.

    Raises:
        RuntimeError: A batch does not have 5 items, or a dict model output lacks ``"logits"``.
    """
    batch_time_m = 0.0
    n_batches = 0
    loss_sum = 0.0
    n_samples = 0

    total_char_errs = 0
    total_chars = 0
    total_word_errs = 0
    total_words = 0
    total_exact = 0

    exported = 0
    export_max = int(getattr(args, "export_csv_max_samples", 0) or 0)
    export_with_ids = bool(getattr(args, "export_csv_with_ids", False))

    model.eval()
    end = time.time()

    for batch in loader:
        if not isinstance(batch, (list, tuple)) or len(batch) != 5:
            raise RuntimeError("DataLoader must return 5 items (images, targets_concat, target_lengths, texts, extra).")

        images, targets_concat, target_lengths, texts, extra = batch
        images = images.to(device, non_blocking=True)
        targets_concat = targets_concat.to(device, non_blocking=True)
        target_lengths = target_lengths.to(device, non_blocking=True)
        B = images.size(0)

        with amp_autocast_ctx(args):
            out = model(images)
            feat_lengths = None
            if isinstance(out, (tuple, list)) and len(out) == 2:
                logits, feat_lengths = out
            elif isinstance(out, dict):
                logits = out.get("logits", None)
                feat_lengths = out.get("feat_lengths", out.get("feature_lengths", None))
                if logits is None:
                    raise RuntimeError(f"Model returned dict without 'logits'. keys={list(out.keys())}")
            else:
                logits = out

        if feat_lengths is not None:
            # Lengths reported by the model take precedence. Work on a copy so
            # the loader's dict is untouched; when the loader supplied no dict
            # at all, start a fresh one so the lengths are still honoured.
            extra = dict(extra) if isinstance(extra, dict) else {}
            extra["feat_lengths"] = feat_lengths

        T_seq = int(logits.size(0))
        if bool(getattr(args, "strict_ctc_lengths", True)):
            input_lengths = infer_input_lengths_strict(extra, T_seq=T_seq, B=B, device=device)
        else:
            input_lengths = torch.full((B,), int(T_seq), dtype=torch.long, device=device)

        # Never hand the loss an input shorter than its target (still capped at T_seq).
        if torch.any(target_lengths > input_lengths):
            input_lengths = torch.maximum(input_lengths, target_lengths).clamp(min=1, max=int(T_seq))

        # Compute the loss in fp32 regardless of the autocast dtype.
        logits_fp32 = logits.float()
        log_probs = logits_fp32.log_softmax(2)
        loss = loss_fn(log_probs, targets_concat, input_lengths, target_lengths)

        # Reset the network state after every forward pass.
        functional.reset_net(model)

        loss_sum += float(loss.item()) * B
        n_samples += B

        pred_seqs = ctc_greedy_decode(logits, blank_id=blank_id, input_lengths=input_lengths)

        # unpack targets
        target_seqs = []
        offset = 0
        tlen_list = target_lengths.detach().cpu().tolist()
        for L in tlen_list:
            L = int(L)
            target_seqs.append(targets_concat[offset : offset + L].detach().cpu().tolist())
            offset += L

        for bi, (pred_ids, tgt_ids) in enumerate(zip(pred_seqs, target_seqs)):
            # Prefer dataset-provided normalized text (if present) to match training logs.
            if isinstance(texts, (list, tuple)) and bi < len(texts):
                ref_text = str(texts[bi])
            else:
                ref_text = ids_to_text(tgt_ids, charset)
            pred_text = ids_to_text(pred_ids, charset)

            cerr, clen = compute_cer(ref_text, pred_text)
            werr, wlen = compute_wer(ref_text, pred_text)

            total_char_errs += int(cerr)
            total_chars += int(clen)
            total_word_errs += int(werr)
            total_words += int(wlen)
            if pred_text == ref_text:
                total_exact += 1

            # export_max <= 0 means "no limit"; metrics above always cover every sample.
            if csv_writer is not None and (export_max <= 0 or exported < export_max):
                meta = extract_sample_meta(extra, bi)
                row = {
                    # n_samples already includes this batch, hence the "- B".
                    "index": int(n_samples - B + bi),
                    "split": str(split_name),
                    "sample_id": meta.get("sample_id", ""),
                    "path": meta.get("path", ""),
                    "gt": ref_text,
                    "pred": pred_text,
                    "char_err": int(cerr),
                    "char_len": int(clen),
                    "cer": float(cerr) / max(int(clen), 1),
                    "word_err": int(werr),
                    "word_len": int(wlen),
                    "wer": float(werr) / max(int(wlen), 1),
                    "exact": int(pred_text == ref_text),
                    "input_len": int(input_lengths[bi].item()) if bi < input_lengths.numel() else "",
                    "target_len": int(target_lengths[bi].item()) if bi < target_lengths.numel() else "",
                    "T_seq": int(T_seq),
                }
                if export_with_ids:
                    row["gt_ids"] = " ".join(str(int(x)) for x in tgt_ids)
                    row["pred_ids"] = " ".join(str(int(x)) for x in pred_ids)
                csv_writer.writerow(row)
                exported += 1

        # Wall time since the previous batch finished (includes fetching this batch).
        batch_time_m += (time.time() - end)
        n_batches += 1
        end = time.time()

    avg_loss = loss_sum / max(n_samples, 1)
    cer = total_char_errs / max(total_chars, 1)
    wer = total_word_errs / max(total_words, 1)
    exact_acc = 100.0 * total_exact / max(n_samples, 1)

    return {
        "loss": float(avg_loss),
        "cer": float(cer),
        "wer": float(wer),
        "exact_acc": float(exact_acc),
        "time_per_batch_s": float(batch_time_m / max(n_batches, 1)),
    }


def build_model_kwargs(args, num_ctc_classes: int, blank_id: int) -> dict:
    """Collect the model constructor keyword arguments from ``args``.

    Each entry is read with ``getattr`` and a fallback; entries whose final
    value is None are left out so the model's own defaults apply.
    """
    # (constructor kwarg, attribute on args, fallback when the attribute is absent)
    with_fallback = (
        ("norm", "norm", None),
        ("gn_max_groups", "gn_max_groups", 32),
        ("gn_nocast", "gn_nocast", None),
        ("height_pool_mode", "height_pool_mode", "sigmoid"),
        ("height_pool_mix", "height_pool_mix", 0.65),
        ("use_abs_pos", "use_abs_pos", True),
        ("use_conv_pos", "use_conv_pos", True),
        ("conv_pos_k", "conv_pos_k", 7),
        ("pos_allow_interp", "pos_allow_interp", True),
        ("time_step", "time_step", 4),
        ("layer", "layer", 4),
        ("dim", "dim", 384),
        ("mlp_ratio", "mlp_ratio", 4.0),
        ("max_seq_len", "max_seq_len", 512),
        ("temporal_fuse", "temporal_fuse", "mean"),
        ("temporal_gate", "temporal_gate", "scalar"),
        ("temporal_eps", "temporal_eps", 1e-6),
        ("temporal_fuse_pre", "temporal_fuse_pre", "none"),
        ("temporal_fuse_final", "temporal_fuse_final", "none"),
        ("temporal_max_T", "temporal_max_T", 16),
        ("temporal_fp32_reduce", "temporal_fp32_reduce", True),
        ("seq_layers", "seq_layers", 2),
        ("seq_nhead", "seq_nhead", 8),
        ("seq_block_layout", "seq_block_layout", "auto"),
        ("ssm_kernel", "ssm_kernel", 31),
        ("ssm_expand_ratio", "ssm_expand_ratio", 2.0),
        # the two dropout kwargs are stored under shorter names on args
        ("drop_rate", "drop", 0.1),
        ("drop_path_rate", "drop_path", 0.1),
        ("use_aux_ctc", "use_aux_ctc", False),
        ("aux_ctc_weight", "aux_ctc_weight", 0.25),
        ("aux_temporal_fuse", "aux_temporal_fuse", "none"),
        ("use_checkpoint", "use_checkpoint", False),
    )

    # optional overrides (None => keep model defaults); kwarg name == attribute name
    optional_overrides = (
        "use_temporal_coding",
        "use_dual_res_fusion",
        "use_token_merge",
        "token_min_keep_ratio",
        "token_blank_thresh",
        "token_merge_k",
        "dual_res_down_mode",
        "dual_res_down_mix_init",
        "use_mem_residual",
        "mem_residual_init",
        # InkCoder / Temporal Coding hyperparameters
        "ink_int_blur_ks",
        "ink_edge_blur_ks",
        "ink_q_low",
        "ink_q_high",
        "ink_q_edge",
        "ink_edge_consistency",
        "ink_edge_cons_ks",
        "ink_edge_cons_kappa",
        "ink_edge_cons_tau",
        "ink_edge_int_gate",
        "ink_edge_int_tau",
        "ink_edge_int_kappa",
        "ink_d_speckle_suppress",
        "ink_d_cons_ks",
        "ink_d_cons_kappa",
        "ink_d_cons_tau",
        "ink_d_cons_power",
        "ink_use_multiscale_edge",
        "ink_edge_ms_down",
        "ink_base_alpha",
        "ink_alpha_decay",
        "ink_use_time_varying_fusion",
        "ink_fuse_bias",
        "ink_fuse_slope",
        "ink_force_aux_for_fusion",
        "ink_theta_min",
        "ink_theta_max",
        "ink_theta_gamma",
        "ink_eps",
    )

    collected = {"num_classes": num_ctc_classes, "blank_id": blank_id}
    for kwarg_name, attr_name, fallback in with_fallback:
        collected[kwarg_name] = getattr(args, attr_name, fallback)
    for name in optional_overrides:
        collected[name] = getattr(args, name, None)

    return {name: value for name, value in collected.items() if value is not None}


def parse_args(argv=None):
    """Parse command-line options, using an optional YAML file for defaults.

    ``-c/--config`` is parsed first; every top-level key of that YAML file is
    installed as a parser default, so explicit command-line flags still win.
    Unknown command-line arguments are ignored with a warning.

    Args:
        argv: Argument list; None lets argparse read ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with the effective settings (including any
        extra keys that only exist in the YAML file).

    Raises:
        TypeError: The YAML file's top level is not a mapping.
    """
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("-c", "--config", default="", type=str, metavar="FILE", help="YAML config file")

    parser = argparse.ArgumentParser("SNN-OCR evaluation (CTC)")

    # required
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model_best.pth.tar")

    # common runtime
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")

    # EMA selection (prefer EMA weights if checkpoint provides them)
    ema_group = parser.add_mutually_exclusive_group()
    ema_group.add_argument("--use-ema", dest="use_ema", action="store_true",
                           help="Prefer EMA weights if checkpoint provides them.")
    ema_group.add_argument("--no-use-ema", dest="use_ema", action="store_false",
                           help="Do not use EMA weights even if present.")
    # aliases for train script compatibility
    ema_group.add_argument("--model-ema", dest="use_ema", action="store_true",
                           help="Alias of --use-ema (train script compatibility).")
    ema_group.add_argument("--no-model-ema", dest="use_ema", action="store_false",
                           help="Alias of --no-use-ema.")
    parser.set_defaults(use_ema=True)

    # minimal dataset knobs (others can still come from YAML defaults)
    parser.add_argument("--dataset", type=str, default="iam_line")
    parser.add_argument("--data-path", dest="data_path", type=str, default="")
    parser.add_argument("--train-split", dest="train_split", type=str, default="train")
    parser.add_argument("--val-split", dest="val_split", type=str, default="validation")
    parser.add_argument("--val-batch-size", dest="val_batch_size", type=int, default=8)
    parser.add_argument("--workers", type=int, default=4)

    # optional per-sample dump
    parser.add_argument(
        "--export-csv",
        default="",
        type=str,
        help="If set, write per-sample GT/pred (and CER/WER) to this CSV file.",
    )
    parser.add_argument(
        "--export-csv-max-samples",
        default=0,
        type=int,
        help="If >0, export at most this many samples to CSV (metrics still run on full split).",
    )
    parser.add_argument(
        "--export-csv-with-ids",
        action="store_true",
        default=False,
        help="Also export raw token id sequences (gt_ids/pred_ids) to CSV.",
    )

    # key model knobs (optional to override); None means "not overridden"
    parser.add_argument("--time-step", dest="time_step", type=int, default=None)
    parser.add_argument("--temporal-max-t", dest="temporal_max_T", type=int, default=None)
    parser.add_argument("--use-temporal-coding", dest="use_temporal_coding", action="store_true", default=None)
    parser.add_argument("--no-use-temporal-coding", dest="use_temporal_coding", action="store_false")

    # ---- model architecture toggles / overrides (tri-state) ----

    dual_group = parser.add_mutually_exclusive_group()
    dual_group.add_argument("--use-dual-res-fusion", dest="use_dual_res_fusion", action="store_true",
                            help="Enable dual-res fusion (override).")
    dual_group.add_argument("--no-use-dual-res-fusion", dest="use_dual_res_fusion", action="store_false",
                            help="Disable dual-res fusion (override).")
    parser.set_defaults(use_dual_res_fusion=None)

    parser.add_argument("--dual-res-down-mode", dest="dual_res_down_mode", type=str, default=None,
                        choices=["avg", "max", "mix"],
                        help="Dual-res downsample mode (avg|max|mix).")
    parser.add_argument("--dual-res-down-mix-init", dest="dual_res_down_mix_init", type=float, default=None,
                        help="Initial mix ratio for dual-res downsample when mode=mix.")

    tmerge_group = parser.add_mutually_exclusive_group()
    tmerge_group.add_argument("--use-token-merge", dest="use_token_merge", action="store_true",
                              help="Enable blank-driven token merge (override).")
    tmerge_group.add_argument("--no-use-token-merge", dest="use_token_merge", action="store_false",
                              help="Disable blank-driven token merge (override).")
    parser.set_defaults(use_token_merge=None)

    parser.add_argument("--token-min-keep-ratio", dest="token_min_keep_ratio", type=float, default=None)
    parser.add_argument("--token-blank-thresh", dest="token_blank_thresh", type=float, default=None)
    parser.add_argument("--token-merge-k", dest="token_merge_k", type=int, default=None)

    # perf knobs
    parser.add_argument("--native-amp", dest="native_amp", action="store_true", default=True)
    parser.add_argument("--no-native-amp", dest="native_amp", action="store_false")
    parser.add_argument("--amp-dtype", dest="amp_dtype", type=str, default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--tf32", dest="tf32", action="store_true", default=True)
    parser.add_argument("--no-tf32", dest="tf32", action="store_false")
    parser.add_argument("--matmul-precision", dest="matmul_precision", type=str, default="high", choices=["highest", "high", "medium"])

    # strict CTC lengths (recommended: keep on)
    parser.add_argument("--no-strict-ctc-lengths", dest="strict_ctc_lengths", action="store_false")
    parser.set_defaults(strict_ctc_lengths=True)

    # parse -c first
    args_config, remaining = config_parser.parse_known_args(argv)
    if args_config.config:
        with open(args_config.config, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            # set_defaults(**cfg) would otherwise fail with an opaque message.
            raise TypeError(
                f"Config file {args_config.config} must contain a top-level mapping, "
                f"got {type(cfg).__name__}"
            )
        # YAML values become defaults, so explicit CLI flags keep priority.
        parser.set_defaults(**cfg)

    args, unknown = parser.parse_known_args(remaining)
    if unknown:
        _logger.warning("Ignored unknown CLI args: %s", unknown)

    return args


def main(argv=None):
    """Entry point: evaluate a checkpoint on the configured split and log metrics.

    Steps: configure logging, parse arguments, verify the checkpoint exists,
    build the data loaders and the model, load the (optionally EMA) weights
    non-strictly, run ``validate`` and log loss / CER / WER / exact-match
    accuracy. When ``--export-csv`` is set, per-sample rows are written to
    that file.

    Args:
        argv: Argument list forwarded to ``parse_args``; None uses ``sys.argv``.

    Raises:
        FileNotFoundError: The checkpoint path does not point to a file.
    """
    setup_logging()
    args = parse_args(argv)
    args = apply_yaml_compat_patch(args)

    # Fail fast: check the checkpoint before paying for dataset/model construction.
    ckpt_path = args.checkpoint
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # perf knobs
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = bool(getattr(args, "tf32", True))
        torch.backends.cudnn.allow_tf32 = bool(getattr(args, "tf32", True))
        torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision(str(getattr(args, "matmul_precision", "high")))
    except Exception:
        # best-effort: not every torch build provides this setter
        pass

    device = torch.device(args.device)

    _logger.info("===== Environment =====")
    _logger.info(f"python: {sys.version.replace(os.linesep, ' ')}")
    _logger.info(f"platform: {platform.platform()}")
    _logger.info(f"pytorch: {torch.__version__}")
    _logger.info(f"cuda_available: {torch.cuda.is_available()}")

    _logger.info("===== Args (effective) =====")
    # default=str: YAML may inject values json cannot encode (e.g. dates);
    # this dump is for the log only, so stringify those instead of crashing.
    _logger.info(json.dumps(vars(args), ensure_ascii=False, indent=2, default=str))

    # build loaders (only the validation loader is used here)
    _train_loader, val_loader, charset, blank_id, num_ctc_classes = build_ctc_dataloaders(args, _logger)
    _logger.info(f"CTC vocab size (incl blank) = {num_ctc_classes}, blank_id = {blank_id}")

    # build model
    model_kwargs = build_model_kwargs(args, num_ctc_classes=num_ctc_classes, blank_id=blank_id)
    _logger.info("Model init kwargs (non-None):")
    for k in sorted(model_kwargs.keys()):
        _logger.info(f"  - {k}: {model_kwargs[k]}")

    model = SpikeHTR(**model_kwargs).to(device)

    # load checkpoint
    ckpt = torch_load_compat(ckpt_path, map_location="cpu")
    state_dict = load_best_state_dict(ckpt, prefer_ema=bool(getattr(args, "use_ema", True)))
    state_dict = strip_module_prefix(state_dict)

    incompatible = model.load_state_dict(state_dict, strict=False)
    try:
        missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    except Exception:
        missing, unexpected = [], []

    _logger.info(f"Loaded checkpoint: {ckpt_path}")
    # strict=False hides mismatches, so report them at WARNING level: missing
    # keys mean part of the model keeps its freshly initialised weights.
    if missing:
        _logger.warning(f"  Missing keys: {len(missing)} (show up to 20): {missing[:20]}")
    if unexpected:
        _logger.warning(f"  Unexpected keys: {len(unexpected)} (show up to 20): {unexpected[:20]}")

    # eval (+ optional per-sample CSV export)
    loss_fn = nn.CTCLoss(blank=blank_id, zero_infinity=True).to(device)

    csv_f = None
    try:
        # The file is opened inside the try so it is closed even when writing
        # the header (or anything before validate) fails.
        csv_writer = None
        if getattr(args, "export_csv", ""):
            out_csv = str(args.export_csv)
            ensure_parent_dir(out_csv)
            csv_f = open(out_csv, "w", encoding="utf-8", newline="")

            fieldnames = [
                "index",
                "split",
                "sample_id",
                "path",
                "gt",
                "pred",
                "char_err",
                "char_len",
                "cer",
                "word_err",
                "word_len",
                "wer",
                "exact",
                "input_len",
                "target_len",
                "T_seq",
            ]
            if bool(getattr(args, "export_csv_with_ids", False)):
                fieldnames += ["gt_ids", "pred_ids"]

            csv_writer = csv.DictWriter(csv_f, fieldnames=fieldnames)
            csv_writer.writeheader()
            _logger.info(f"Exporting per-sample predictions to CSV: {out_csv}")

        metrics = validate(
            model,
            val_loader,
            loss_fn,
            args,
            charset,
            blank_id,
            device,
            csv_writer=csv_writer,
            split_name=str(getattr(args, "val_split", "validation")),
        )
    finally:
        if csv_f is not None:
            csv_f.close()

    _logger.info("===== Eval Results =====")
    _logger.info(f"split={getattr(args, 'val_split', 'validation')}")
    _logger.info(f"loss={metrics['loss']:.6f}")
    _logger.info(f"CER={metrics['cer']:.6f}")
    _logger.info(f"WER={metrics['wer']:.6f}")
    _logger.info(f"ExactAcc={metrics['exact_acc']:.2f}%")
    _logger.info(f"time_per_batch_s={metrics['time_per_batch_s']:.4f}")


# Run the evaluation only when this file is executed as a script; importing
# it as a module must not start a run.
if __name__ == "__main__":
    main()