#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import math
import os
import platform
import sys
import time
from collections import OrderedDict, defaultdict
from contextlib import suppress
from datetime import datetime
from typing import Any, Optional

import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils

from spikingjelly.clock_driven import functional

from timm.models import safe_model_name, resume_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import (
    get_outdir, CheckpointSaver, update_summary,
    AverageMeter, random_seed, ModelEmaV2, distribute_bn, reduce_tensor
)

from spike_htr import SpikeHTR
from ocr_datasets import build_ctc_dataloaders

try:
    from apex import amp  # type: ignore
    has_apex = True
except Exception:
    has_apex = False

try:
    import wandb  # type: ignore
    has_wandb = True
except Exception:
    has_wandb = False

_logger = logging.getLogger('train')
# Let cuDNN pick the fastest conv algorithms (the script opts into this globally).
torch.backends.cudnn.benchmark = True

# ---------------------------
# PyTorch 2.6 checkpoint compat
# ---------------------------
# PyTorch >= 2.6 makes torch.load default to weights_only=True, which rejects arbitrary
# pickled classes. Allowlist argparse.Namespace so checkpoints that carry a Namespace
# (presumably the training args stored by the checkpoint saver) still load via the safe path.
# Torch builds without torch.serialization.add_safe_globals simply skip this step.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([argparse.Namespace])


def torch_load_compat(path: str, map_location="cpu"):
    """
    Load a checkpoint, preferring the safe ``weights_only`` unpickler.

    The first attempt uses ``torch.load``'s default (``weights_only=True`` on
    PyTorch >= 2.6). If that fails for any reason other than an OS-level error
    (missing file, permission denied, ...), a warning is logged and the load is
    retried with ``weights_only=False`` so legacy checkpoints that pickle
    arbitrary objects still load.

    SECURITY: the fallback uses the full pickle unpickler, which can execute
    arbitrary code. Only call this on checkpoints from a trusted source.

    Args:
        path: checkpoint file path.
        map_location: forwarded to ``torch.load``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        OSError: the file cannot be opened (retrying would not help).
        Exception: whatever the unsafe fallback load raises.
    """
    try:
        return torch.load(path, map_location=map_location)
    except OSError:
        # Missing/unreadable file: do not re-open it through the unsafe unpickler.
        raise
    except Exception as safe_err:
        # Same logger object the module-level `_logger` refers to.
        logging.getLogger("train").warning(
            "Safe torch.load failed for %s (%s); retrying with weights_only=False. "
            "Only do this for trusted checkpoints.",
            path, type(safe_err).__name__,
        )

    # Fallback (trust required)
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # older torch without weights_only arg
        return torch.load(path, map_location=map_location)

# ============================================================
# Logging / Distributed
# ============================================================

def setup_logging_clean(output_dir: Optional[str], args):
    """Configure logging to stdout (+ train.log for main process) without duplicates.

    Safe to call repeatedly. Every handler already attached to the root logger or to
    the 'train' logger is detached AND closed; otherwise a previous ``train.log``
    FileHandler would keep its file open. Fresh handlers are then attached to 'train'.

    Args:
        output_dir: directory for ``train.log`` (created if needed). Ignored when
            falsy or on non-main processes.
        args: run config; only its rank is consulted (via ``_is_main_process``).

    Side effects:
        Rebinds the module-level ``_logger``. Sets the root logger level. Disables
        propagation from 'train' to root. Non-main ranks log at WARNING and above only.
    """
    root = logging.getLogger()
    lg = logging.getLogger('train')
    for owner in (root, lg):
        for h in list(owner.handlers):
            owner.removeHandler(h)
            # Detaching alone leaks the file descriptor of a dropped FileHandler.
            with suppress(Exception):
                h.close()

    is_main = _is_main_process(args)
    level = logging.INFO if is_main else logging.WARNING

    lg.setLevel(level)
    lg.propagate = False  # handlers live on 'train' only -> no duplicate lines via root

    fmt = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setLevel(level)
    sh.setFormatter(fmt)
    lg.addHandler(sh)

    if is_main and output_dir:
        os.makedirs(output_dir, exist_ok=True)
        fh = logging.FileHandler(os.path.join(output_dir, "train.log"), mode="a", encoding="utf-8")
        fh.setLevel(logging.INFO)
        fh.setFormatter(fmt)
        lg.addHandler(fh)

    root.setLevel(level)

    global _logger
    _logger = lg


def _dump_json(obj: Any, path: str):
    """Write ``obj`` to ``path`` as pretty-printed UTF-8 JSON, creating parent dirs.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). A bare filename
    (no directory part) is written to the current working directory.

    Raises:
        TypeError: ``obj`` is not JSON-serializable.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)


def _dump_text(text: str, path: str):
    """Write ``text`` to ``path`` as UTF-8, overwriting it and creating parent dirs.

    A bare filename (no directory part) is written to the current working directory.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def _count_params(model: nn.Module):
    """Count parameter and buffer elements of ``model``.

    Returns:
        dict with int values: ``total`` (all parameter elements), ``trainable``
        (``requires_grad=True``), ``frozen`` (total - trainable) and ``buffers``
        (all buffer elements).
    """
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        count = param.numel()
        n_total += count
        if param.requires_grad:
            n_trainable += count
    n_buffers = sum(buf.numel() for buf in model.buffers())
    return {
        "total": int(n_total),
        "trainable": int(n_trainable),
        "frozen": int(n_total - n_trainable),
        "buffers": int(n_buffers),
    }


def _human_int(n: int) -> str:
    """Render ``int(n)`` with comma thousands separators, e.g. 1234567 -> '1,234,567'."""
    return format(int(n), ",")


def _module_param_breakdown(model: nn.Module, depth: int = 1):
    """Sum parameter elements per dotted-name prefix.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        depth: number of leading name components forming the group key
            (e.g. depth=1: "encoder.layer.0.weight" -> "encoder"); ``depth <= 0``
            keeps the full parameter name.

    Returns:
        list of ``(prefix, numel)`` sorted by numel, largest first. Ties keep
        first-seen order.
    """
    counts = {}
    for full_name, tensor in model.named_parameters():
        key = full_name if depth <= 0 else ".".join(full_name.split(".")[:depth])
        counts[key] = counts.get(key, 0) + tensor.numel()
    ranked = list(counts.items())
    ranked.sort(key=lambda item: item[1], reverse=True)
    return ranked


def log_env_and_args(args, output_dir: Optional[str]):
    """Log runtime environment info and all run args (main process only).

    When ``output_dir`` is set, also writes ``env.json`` and ``args.json`` there.
    Args that ``vars()`` cannot handle are recorded as ``{"args": str(args)}``.
    """
    if not _is_main_process(args):
        return

    cuda_ok = torch.cuda.is_available()
    n_gpu = torch.cuda.device_count() if cuda_ok else 0
    env = {
        "python": sys.version.replace(os.linesep, " "),
        "platform": platform.platform(),
        "pytorch": torch.__version__,
        "cuda_available": cuda_ok,
        "cuda_runtime": getattr(torch.version, "cuda", None),
        "cudnn": torch.backends.cudnn.version() if hasattr(torch.backends, "cudnn") else None,
        "gpu_count": n_gpu,
        "gpu_name_0": torch.cuda.get_device_name(0) if n_gpu > 0 else None,
    }

    _logger.info("===== Environment =====")
    for key, value in env.items():
        _logger.info(f"{key}: {value}")

    try:
        args_dict = vars(args)
    except Exception:
        # Not a Namespace-like object: keep at least its string form.
        args_dict = {"args": str(args)}

    _logger.info("===== Args =====")
    _logger.info(json.dumps(args_dict, ensure_ascii=False, indent=2))

    if not output_dir:
        return
    _dump_json(env, os.path.join(output_dir, "env.json"))
    _dump_json(args_dict, os.path.join(output_dir, "args.json"))


def log_model_and_params(args, model: nn.Module, output_dir: Optional[str], top_k: int = 30):
    """Log (and optionally dump) the model architecture and parameter stats (main process only).

    Logs ``str(model)``, the total/trainable/frozen/buffer element counts, and the
    ``top_k`` largest parameter groups by name prefix at depth 1 and 2.

    When ``output_dir`` is set, also writes these files there:
    ``model_arch.txt``, ``param_stats.json``, and the full (untruncated)
    ``param_breakdown_depth1.json`` / ``param_breakdown_depth2.json``.

    Args:
        args: run config; only rank is consulted via ``_is_main_process``.
        model: model to inspect. A wrapper exposing ``.module`` is unwrapped first.
        output_dir: optional directory for the dump files.
        top_k: breakdown rows to log per depth (default 30).
    """
    if not _is_main_process(args):
        return

    model_u = unwrap_model(model)
    arch = str(model_u)

    _logger.info("===== Model Architecture =====")
    _logger.info(arch)

    stats = _count_params(model_u)
    _logger.info("===== Parameter Stats =====")
    _logger.info(
        f"params_total={_human_int(stats['total'])} | "
        f"trainable={_human_int(stats['trainable'])} | "
        f"frozen={_human_int(stats['frozen'])} | "
        f"buffers={_human_int(stats['buffers'])}"
    )

    # Walk named_parameters() once per depth and reuse the result for both the
    # log lines and the JSON dumps.
    breakdowns = {depth: _module_param_breakdown(model_u, depth=depth) for depth in (1, 2)}

    total = max(stats["total"], 1)  # avoid division by zero for parameter-less models
    for depth, bd in breakdowns.items():
        top = bd[:top_k]
        _logger.info(f"===== Param Breakdown (prefix depth={depth}) Top-{len(top)} =====")
        for k, v in top:
            _logger.info(f"{k:<40} {_human_int(v):>15} ({v/total*100:6.2f}%)")

    if output_dir:
        _dump_text(arch + "\n", os.path.join(output_dir, "model_arch.txt"))
        _dump_json(stats, os.path.join(output_dir, "param_stats.json"))
        for depth, bd in breakdowns.items():
            _dump_json(
                [{"module": k, "numel": int(v)} for k, v in bd],
                os.path.join(output_dir, f"param_breakdown_depth{depth}.json"),
            )


def _is_dist_avail_and_initialized() -> bool:
    """True only if torch.distributed is available AND a process group is initialized."""
    dist = torch.distributed
    if not dist.is_available():
        return False
    return dist.is_initialized()


def _is_main_process(args) -> bool:
    """Return True for rank 0; args without a ``rank`` attribute count as rank 0."""
    rank = getattr(args, "rank", 0)
    return not int(rank)


def init_distributed_and_device(args):
    """Fill distributed/device fields on ``args`` from the environment.

    Sets ``args.rank``, ``args.world_size``, ``args.local_rank``, ``args.distributed``
    and ``args.device``. Single-process defaults are used unless both WORLD_SIZE and
    RANK are present in the environment; LOCAL_RANK defaults to 0.

    When ``world_size > 1``, initializes the process group (NCCL on CUDA, else Gloo,
    ``init_method="env://"``) and waits on a barrier. Non-main ranks lower the root
    logger level to WARNING.
    """
    # Single-process defaults (assignment order kept: rank, world_size, local_rank, distributed).
    args.rank, args.world_size, args.local_rank, args.distributed = 0, 1, 0, False

    env = os.environ
    if "WORLD_SIZE" in env and "RANK" in env:
        args.world_size = int(env["WORLD_SIZE"])
        args.rank = int(env["RANK"])
        args.local_rank = int(env.get("LOCAL_RANK", 0))
        args.distributed = args.world_size > 1

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)
    args.device = f"cuda:{args.local_rank}" if use_cuda else "cpu"

    if args.distributed:
        torch.distributed.init_process_group(
            backend="nccl" if use_cuda else "gloo",
            init_method="env://",
        )
        barrier_kwargs = {"device_ids": [args.local_rank]} if use_cuda else {}
        torch.distributed.barrier(**barrier_kwargs)

    if not _is_main_process(args):
        logging.getLogger().setLevel(logging.WARNING)


def wrap_ddp(model: nn.Module, args) -> nn.Module:
    """Wrap ``model`` in DistributedDataParallel when ``args.distributed``; else return it unchanged.

    Buffers are not broadcast (``broadcast_buffers=False``). ``find_unused_parameters``
    follows ``args.find_unused_parameters`` (default False). On CUDA devices
    (``args.device`` starting with "cuda") the module is pinned to ``args.local_rank``.
    """
    if not args.distributed:
        return model
    ddp_kwargs = {
        "broadcast_buffers": False,
        "find_unused_parameters": bool(getattr(args, "find_unused_parameters", False)),
    }
    if args.device.startswith("cuda"):
        ddp_kwargs["device_ids"] = [args.local_rank]
        ddp_kwargs["output_device"] = args.local_rank
    return torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)


def unwrap_model(m: nn.Module) -> nn.Module:
    """Return ``m.module`` for wrappers that expose it (e.g. the DDP wrapper from wrap_ddp), else ``m``."""
    return getattr(m, "module", m)


# ============================================================
# AMP scalers
# ============================================================

class NativeScalerWithGradAccum:
    """GradScaler (native AMP) + grad accumulation.

    Each call runs backward on the (scaled) loss. Clipping and the optimizer step
    happen only when ``need_update=True``, so gradients accumulate across calls made
    with ``need_update=False``. Zeroing gradients is left to the caller.
    """
    state_dict_key = "amp_scaler"

    def __init__(self, enabled: bool = True):
        self.enabled = bool(enabled)
        try:
            from torch.amp import GradScaler
            self._scaler = GradScaler('cuda', enabled=self.enabled)
        except Exception:
            # Older torch without torch.amp.GradScaler(device, ...).
            self._scaler = torch.cuda.amp.GradScaler(enabled=self.enabled)

    @staticmethod
    def _require_params(parameters):
        """Raise ValueError when clipping was requested without parameters to clip."""
        if parameters is None:
            raise ValueError("clip_grad is set but `parameters` was not provided.")

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        clip_mode='norm',
        parameters=None,
        create_graph=False,
        need_update: bool = True,
    ):
        """Backward on ``loss``; on update steps optionally clip, then step the optimizer.

        Args:
            loss: loss tensor to backpropagate.
            optimizer: optimizer stepped when ``need_update`` is True.
            clip_grad: clipping value (None = no clipping); requires ``parameters``.
            clip_mode: forwarded to timm's ``dispatch_clip_grad``.
            parameters: parameters to clip.
            create_graph: forwarded to ``backward``.
            need_update: False = only accumulate gradients.

        Raises:
            ValueError: ``clip_grad`` is set on an update step but ``parameters`` is None.
        """
        if not self.enabled:
            loss.backward(create_graph=create_graph)
            if need_update:
                if clip_grad is not None:
                    self._require_params(parameters)
                    from timm.utils import dispatch_clip_grad
                    dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
                optimizer.step()
            return None

        self._scaler.scale(loss).backward(create_graph=create_graph)

        if not need_update:
            return None

        if clip_grad is not None:
            self._require_params(parameters)
            # Unscale first so the threshold applies to the true gradient magnitudes.
            self._scaler.unscale_(optimizer)
            from timm.utils import dispatch_clip_grad
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()
        return None

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore scaler state. An empty/None state is ignored.

        A disabled GradScaler saves ``{}`` (e.g. a CPU run). An enabled GradScaler
        raises RuntimeError when asked to load that. Skipping keeps the freshly
        initialized scale instead, e.g. when resuming a CPU checkpoint on GPU.
        """
        if not state_dict:
            return
        self._scaler.load_state_dict(state_dict)


class ApexScalerWithGradAccum:
    """APEX AMP + grad accumulation.

    Same call contract as the native scaler: backward always runs. Clipping and
    ``optimizer.step()`` only happen when ``need_update=True``. Requires the
    module-level ``amp`` (apex) whenever ``enabled`` is True or the state-dict
    methods are used.
    """
    state_dict_key = "amp_scaler"

    def __init__(self, enabled: bool = True):
        self.enabled = bool(enabled)

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        clip_mode='norm',
        parameters=None,
        create_graph=False,
        need_update: bool = True,
    ):
        """Backward on ``loss``; on update steps optionally clip, then step the optimizer.

        When enabled, clipping acts on ``amp.master_params(optimizer)``, and the
        ``parameters`` argument is ignored.

        Raises:
            ValueError: disabled mode, ``clip_grad`` set on an update step, and
                ``parameters`` is None.
        """
        if not self.enabled:
            loss.backward(create_graph=create_graph)
            if need_update:
                if clip_grad is not None:
                    if parameters is None:
                        raise ValueError("clip_grad is set but `parameters` was not provided.")
                    from timm.utils import dispatch_clip_grad
                    dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
                optimizer.step()
            return None

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)

        if not need_update:
            return None

        if clip_grad is not None:
            from timm.utils import dispatch_clip_grad
            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)

        optimizer.step()
        return None

    def state_dict(self):
        return amp.state_dict()

    def load_state_dict(self, state_dict):
        """Best-effort restore of APEX AMP state. A failure is logged and training continues."""
        try:
            amp.load_state_dict(state_dict)
        except Exception as err:
            _logger.warning(f"Could not restore APEX AMP state ({err!r}); continuing without it.")


def _resolve_autocast_dtype(args) -> torch.dtype:
    """Map ``args.amp_dtype`` (case-insensitive) to an autocast dtype.

    - "bf16"/"bfloat16" -> torch.bfloat16. On a CUDA machine that reports no bf16
      support, it falls back to torch.float16 with a warning.
    - "fp16"/"float16"/"half", None, or a missing attribute -> torch.float16.
    - Anything else -> torch.float16 with a warning, so typos are not silently ignored.
    """
    raw = getattr(args, "amp_dtype", None)
    want = "fp16" if raw is None else str(raw).lower()
    if want in ("bf16", "bfloat16"):
        if torch.cuda.is_available():
            try:
                ok = torch.cuda.is_bf16_supported()
            except Exception:
                ok = False
            if not ok:
                _logger.warning("amp_dtype=bf16 requested but not supported. Fallback to fp16.")
                return torch.float16
        return torch.bfloat16
    if want not in ("fp16", "float16", "half"):
        _logger.warning(f"Unrecognized amp_dtype={raw!r}; using fp16.")
    return torch.float16


def create_amp(args, model, optimizer):
    """
    Select and set up mixed precision from ``args``.

    Priority: APEX (``args.apex_amp`` and apex importable) > native torch AMP
    (``args.native_amp``; a truthy ``args.amp`` forces it on) > float32.

    Returns:
        (amp_autocast_ctx_factory, loss_scaler, model, optimizer, use_amp_name):
        - ``amp_autocast_ctx_factory``: zero-arg callable returning a context manager.
        - ``loss_scaler``: None for float32.
        - ``use_amp_name``: "apex" | "native" | None.
        - ``model`` / ``optimizer``: may be replaced by ``amp.initialize`` (APEX).

    Side effects:
        Sets ``args.native_amp = True`` when ``args.amp`` is truthy.
    """
    if getattr(args, "amp", False):
        args.native_amp = True
    want_apex = bool(getattr(args, "apex_amp", False))
    want_native = bool(getattr(args, "native_amp", False))

    use_amp = None
    if want_apex and has_apex:
        use_amp = "apex"
    elif want_native:
        use_amp = "native"
    elif want_apex:
        _logger.warning("APEX AMP requested but apex not found. Fallback to float32.")

    # No-op context for float32; APEX also runs without an autocast wrapper here.
    amp_autocast = suppress
    loss_scaler = None

    if use_amp == "apex":
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        loss_scaler = ApexScalerWithGradAccum(enabled=True)
        _logger.info("Using NVIDIA APEX AMP.")

    elif use_amp == "native":
        dtype = _resolve_autocast_dtype(args)
        cuda_ok = torch.cuda.is_available()
        try:
            from torch.amp import autocast as torch_amp_autocast

            def amp_autocast():
                return torch_amp_autocast("cuda", enabled=cuda_ok, dtype=dtype)
        except Exception:
            # Older torch: pass dtype explicitly, otherwise bf16 silently became fp16.
            def amp_autocast():
                return torch.cuda.amp.autocast(enabled=cuda_ok, dtype=dtype)

        loss_scaler = NativeScalerWithGradAccum(enabled=cuda_ok)
        # Report the dtype actually used (bf16 may have fallen back to fp16).
        _logger.info(f"Using native Torch AMP, dtype={str(dtype).replace('torch.', '')}.")
    else:
        _logger.info("AMP not enabled. Training in float32.")

    return amp_autocast, loss_scaler, model, optimizer, use_amp


def autocast_off():
    """Return a context manager that disables CUDA autocast.

    Without CUDA, returns a no-op ``suppress()`` context.
    """
    if torch.cuda.is_available():
        try:
            from torch.amp import autocast as _amp_autocast
            return _amp_autocast("cuda", enabled=False)
        except Exception:
            # Older torch: legacy CUDA-only autocast API.
            return torch.cuda.amp.autocast(enabled=False)
    return suppress()


# ============================================================
# Debug dump
# ============================================================

def dump_non_finite_batch(
    output_dir: str,
    epoch: int,
    batch_idx: int,
    images: torch.Tensor,
    targets_concat: torch.Tensor,
    target_lengths: torch.Tensor,
    input_lengths: torch.Tensor,
    texts=None,
    extra=None,
    logits: Optional[torch.Tensor] = None,
):
    """Best-effort debug dump of a problematic batch; never raises.

    Writes ``<output_dir>/nan_debug/epoch{epoch:03d}_iter{batch_idx:06d}.pt`` containing
    CPU copies of the batch tensors plus ``texts``/``extra``. When ``logits`` is given,
    the file also stores its min/max/mean/all-finite summary. If that summary fails
    (e.g. empty logits), ``logits_stats_error`` is stored instead and the rest of the
    dump is still written.

    Also tries to save a normalized ``.png`` preview of ``images``. That step is
    optional and its failures are ignored. Any other failure is logged as a warning
    so training can continue.
    """
    try:
        dbg_dir = os.path.join(output_dir, "nan_debug")
        os.makedirs(dbg_dir, exist_ok=True)
        stem = os.path.join(dbg_dir, f"epoch{epoch:03d}_iter{batch_idx:06d}")
        payload = {
            "epoch": int(epoch),
            "batch_idx": int(batch_idx),
            "images": images.detach().cpu(),
            "targets_concat": targets_concat.detach().cpu(),
            "target_lengths": target_lengths.detach().cpu(),
            "input_lengths": input_lengths.detach().cpu(),
            "texts": texts,
            "extra": extra,
        }
        if logits is not None:
            try:
                with autocast_off():
                    lf = logits.detach().float()
                    payload.update({
                        "logits_min": float(lf.min().item()),
                        "logits_max": float(lf.max().item()),
                        "logits_mean": float(lf.mean().item()),
                        "logits_isfinite": bool(torch.isfinite(lf).all().item()),
                    })
            except Exception as stat_err:
                # Keep the batch dump even if the summary stats cannot be computed.
                payload["logits_stats_error"] = repr(stat_err)
        torch.save(payload, stem + ".pt")
        try:
            torchvision.utils.save_image(
                images.detach().cpu(),
                stem + ".png",
                padding=0,
                normalize=True
            )
        except Exception:
            pass  # preview image is optional (e.g. unsupported tensor shape)
    except Exception as err:
        _logger.warning(
            f"dump_non_finite_batch failed (epoch={epoch}, iter={batch_idx}): {err!r}"
        )


# ============================================================
# Utils: reduce SNN memory
# ============================================================

def force_disable_store_v(model: nn.Module):
    """Best-effort: set ``store_v_seq`` / ``store_v`` to False on every submodule that has them.

    Submodules whose attribute cannot be set are silently skipped.
    """
    flag_names = ("store_v_seq", "store_v")
    for submodule in model.modules():
        for flag in (name for name in flag_names if hasattr(submodule, name)):
            with suppress(Exception):
                setattr(submodule, flag, False)

def _set_requires_grad(module: nn.Module, flag: bool):
    """Set ``requires_grad = bool(flag)`` on every parameter of ``module`` (recursively)."""
    value = bool(flag)
    for param in module.parameters():
        param.requires_grad_(value)


# ============================================================
# CTC decode & metrics
# ============================================================

def _levenshtein_seq(a, b):
    """Edit distance between two sequences (unit-cost insert/delete/substitute).

    Uses a single DP row of length ``len(b) + 1``, with ``diag`` carrying the
    previous row's value at ``j - 1``.
    """
    len_a, len_b = len(a), len(b)
    if not len_a or not len_b:
        return len_a or len_b
    row = list(range(len_b + 1))
    for i, x in enumerate(a, start=1):
        diag, row[0] = row[0], i
        for j, y in enumerate(b, start=1):
            above = row[j]
            row[j] = min(above + 1, row[j - 1] + 1, diag + (0 if x == y else 1))
            diag = above
    return row[len_b]


def compute_cer(ref: str, hyp: str):
    """Character-level error counts: ``(edit_distance, len(ref))``. None is treated as "".

    Unlike compute_wer, the denominator may be 0 for an empty reference; callers
    must guard the division.
    """
    ref_chars = list(ref or "")
    hyp_chars = list(hyp or "")
    return _levenshtein_seq(ref_chars, hyp_chars), len(ref_chars)


def compute_wer(ref: str, hyp: str):
    """Word-level error counts on whitespace tokens: ``(edit_distance, max(#ref_words, 1))``.

    None is treated as "".
    """
    ref_words, hyp_words = (ref or "").split(), (hyp or "").split()
    return _levenshtein_seq(ref_words, hyp_words), max(1, len(ref_words))


def ctc_greedy_decode(
    logits: torch.Tensor,
    blank_id: int = 0,
    input_lengths: Optional[torch.Tensor] = None,
):
    """
    Greedy (best-path) CTC decoding.

    Takes the per-frame argmax, collapses consecutive repeats and drops ``blank_id``.
    A blank between two identical ids keeps both (1, 0, 1 -> [1, 1]).

    Args:
        logits: [T_seq, B, V]; argmax is taken over the last dim.
        blank_id: index of the CTC blank class.
        input_lengths: optional [B] valid frame counts, each clamped to [0, T_seq].
            Later frames are ignored. None => all T_seq frames.

    Returns:
        list of B lists of int token ids.
    """
    best_path = logits.argmax(2).detach().cpu()  # (T_seq, B)
    T_seq, B = best_path.shape
    if input_lengths is None:
        frame_limits = [T_seq] * B
    else:
        frame_limits = [max(0, min(int(n), T_seq)) for n in input_lengths.detach().cpu().tolist()]

    decoded = []
    for b, frames in enumerate(best_path.t().tolist()):  # one row of T_seq ids per sample
        tokens = []
        last = None
        for idx in frames[:frame_limits[b]]:
            if idx == blank_id:
                last = None  # a blank separates genuine repeats
            elif idx != last:
                tokens.append(idx)
                last = idx
        decoded.append(tokens)
    return decoded


def ids_to_text(token_ids, charset):
    """Convert token ids to text.

    Id 0 (blank) and negative ids are dropped. Id ``k`` maps to ``charset[k - 1]``.
    Ids beyond ``len(charset)`` are dropped.
    """
    vocab_size = len(charset)
    return "".join(charset[i - 1] for i in token_ids if 0 < i <= vocab_size)


# ============================================================
# STRICT input_lengths inference
# ============================================================

def _to_1d_long_tensor(x, device) -> Optional[torch.Tensor]:
    """Convert a tensor or non-empty list/tuple of int/float to a flat int64 tensor on ``device``.

    0-d tensors become shape [1]. Returns None for None, empty sequences, or any
    other input type.
    """
    if x is None:
        return None
    if torch.is_tensor(x):
        # view(-1) flattens N-d tensors and turns a 0-d scalar into shape [1].
        return x.to(device=device, dtype=torch.long).view(-1)
    is_numeric_seq = (
        isinstance(x, (list, tuple))
        and bool(x)
        and all(isinstance(v, (int, float)) for v in x)
    )
    if not is_numeric_seq:
        return None
    return torch.tensor(x, device=device, dtype=torch.long).view(-1)


def _to_1d_float_tensor(x, device) -> Optional[torch.Tensor]:
    """Convert a tensor or non-empty list/tuple of int/float to a flat float32 tensor on ``device``.

    0-d tensors become shape [1]. Returns None for None, empty sequences, or any
    other input type.
    """
    if x is None:
        return None
    if torch.is_tensor(x):
        # view(-1) flattens N-d tensors and turns a 0-d scalar into shape [1].
        return x.to(device=device, dtype=torch.float32).view(-1)
    is_numeric_seq = (
        isinstance(x, (list, tuple))
        and bool(x)
        and all(isinstance(v, (int, float)) for v in x)
    )
    if not is_numeric_seq:
        return None
    return torch.tensor(x, device=device, dtype=torch.float32).view(-1)


def infer_input_lengths_strict(extra: Any, T_seq: int, B: int, device: torch.device) -> torch.Tensor:
    """Derive per-sample CTC input lengths, clamped to [1, T_seq], from the batch ``extra`` dict.

    Resolution order:
      1. The first of ``feat_lengths``, ``input_lengths``, ``feature_lengths``,
         ``ctc_input_lengths`` present in ``extra`` is used directly. It must
         have B entries.
      2. Otherwise ``valid_w`` (B entries) and ``target_w`` give
         ``ceil(valid_w / target_w * T_seq)``. ``target_w`` may be a scalar, a
         1-element tensor/list, or B entries. ``valid_w`` and per-sample
         ``target_w`` are clamped to >= 1 first. A scalar ``target_w`` must be > 1.
         (Presumably valid_w is the un-padded image width and target_w the padded
         width; confirm against the dataset.)

    Raises:
        ValueError: extra is None, required keys are missing, or shapes/values are invalid.
        TypeError: extra is not a dict.
    """
    if extra is None:
        raise ValueError(
            "CTC strict lengths: extra is None. DataLoader must return (.., extra dict with lengths)."
        )
    if not isinstance(extra, dict):
        raise TypeError(f"CTC strict lengths: extra must be dict, got {type(extra)}")

    # 1) Explicit lengths: the first key found (in this priority order) wins.
    direct_key = next(
        (k for k in ("feat_lengths", "input_lengths", "feature_lengths", "ctc_input_lengths") if k in extra),
        None,
    )
    if direct_key is not None:
        lengths = _to_1d_long_tensor(extra.get(direct_key), device)
        if lengths is None or lengths.numel() != B:
            raise ValueError(f"CTC strict lengths: extra['{direct_key}'] invalid shape, expect B={B}.")
        return lengths.clamp(min=1, max=int(T_seq))

    # 2) Width-ratio fallback needs both keys.
    if "valid_w" not in extra or "target_w" not in extra:
        raise ValueError(
            "CTC strict lengths: extra missing keys. Need feat_lengths/input_lengths OR (valid_w + target_w). "
            f"Got keys={list(extra.keys())}"
        )

    valid_w = _to_1d_float_tensor(extra.get("valid_w"), device)
    if valid_w is None or valid_w.numel() != B:
        raise ValueError(f"CTC strict lengths: extra['valid_w'] invalid shape, expect B={B}.")

    target_w = extra.get("target_w")
    scalar = None
    per_sample = None
    if isinstance(target_w, (int, float)):
        scalar = float(target_w)
    else:
        per_sample = _to_1d_float_tensor(target_w, device)
        if per_sample is None:
            raise ValueError("CTC strict lengths: target_w must be int/float or tensor/list.")
        if per_sample.numel() == 1:
            scalar, per_sample = float(per_sample.item()), None
        elif per_sample.numel() != B:
            raise ValueError(f"CTC strict lengths: target_w invalid shape, expect 1 or B={B}.")

    valid_w = valid_w.to(device=device, dtype=torch.float32).clamp(min=1.0)

    if per_sample is None:
        if scalar is None or scalar <= 1:
            raise ValueError(f"CTC strict lengths: invalid target_w={target_w}.")
        ratio = valid_w / float(scalar)
    else:
        ratio = valid_w / per_sample.to(device=device, dtype=torch.float32).clamp(min=1.0)

    feat = torch.ceil(ratio * float(T_seq)).to(dtype=torch.long)
    return feat.clamp(min=1, max=int(T_seq))


# ============================================================
# argparse (SAME as train2.py, just note the model name)
# ============================================================

# Pre-parser that only understands -c/--config. Its help text says the YAML file
# specifies default arguments. Presumably it is applied before `parser` parses the
# full CLI (that code is not visible here).
config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch Line-CTC Training (Spike-HTR OPTIMIZED)')

# Model
parser.add_argument('--model', default='snn_ocr_optimized', type=str, metavar='MODEL')
parser.add_argument('--norm', type=str, default=None, choices=['bn', 'gn'],
                    help='Override norm: bn|gn (None => follow model defaults).')
parser.add_argument('--gn-max-groups', dest='gn_max_groups', type=int, default=32)

# Tri-state pair sharing dest `gn_nocast`: None unless one of the two flags is given
# (argparse takes the default from the first action registered for a dest).
parser.add_argument('--gn-nocast', dest='gn_nocast', action='store_true', default=None)
parser.add_argument('--no-gn-nocast', dest='gn_nocast', action='store_false')

parser.add_argument('--height-pool-mode', dest='height_pool_mode', type=str, default='sigmoid',
                    choices=['mean', 'max', 'sigmoid', 'softmax', 'attn', 'gate', 'attn_gate'])
parser.add_argument('--height-pool-mix', dest='height_pool_mix', type=float, default=0.65)
# MaxFormer-inspired anti-lowpass knobs
parser.add_argument('--dual-res-down-mode', dest='dual_res_down_mode', type=str, default=None,
                    choices=['avg', 'max', 'mix'],
                    help='DualResFusion downsample: avg|max|mix (None => model default).')
parser.add_argument('--dual-res-down-mix-init', dest='dual_res_down_mix_init', type=float, default=None,
                    help='Init for down_mode=mix; w in avg + w*(max-avg). (None => model default)')

# Tri-state (None / True / False); the two flags are mutually exclusive.
g_mem = parser.add_mutually_exclusive_group()
g_mem.add_argument('--use-mem-residual', dest='use_mem_residual', action='store_true',
                   help='Enable analog bypass in spiking blocks (high-frequency friendly)')
g_mem.add_argument('--no-mem-residual', dest='use_mem_residual', action='store_false',
                   help='Disable analog bypass in spiking blocks')
parser.set_defaults(use_mem_residual=None)
parser.add_argument('--mem-residual-init', dest='mem_residual_init', type=float, default=None,
                    help='Init alpha for analog bypass (0 keeps original behavior).')

# store_true with default=True: the positive flags below are no-ops kept for
# symmetry; use the --no-* variant to turn the feature off.
parser.add_argument('--use-abs-pos', dest='use_abs_pos', action='store_true', default=True)
parser.add_argument('--no-abs-pos', dest='use_abs_pos', action='store_false')
parser.add_argument('--use-conv-pos', dest='use_conv_pos', action='store_true', default=True)
parser.add_argument('--no-conv-pos', dest='use_conv_pos', action='store_false')
parser.add_argument('--conv-pos-k', dest='conv_pos_k', type=int, default=7)

parser.add_argument('--pos-allow-interp', dest='pos_allow_interp', action='store_true', default=True,
                    help='Allow interpolating abs pos embedding when W > max_seq_len.')
parser.add_argument('--no-pos-allow-interp', dest='pos_allow_interp', action='store_false')

# NOTE: dest is `temporal_max_T` (capital T), unlike the lowercase flag name.
parser.add_argument('--temporal-max-t', dest='temporal_max_T', type=int, default=16,
                    help='TemporalFusion MAX_T (must be >= time_step).')
parser.add_argument('--temporal-fp32-reduce', dest='temporal_fp32_reduce', action='store_true', default=True,
                    help='Temporal fusion reduce ops in fp32 for AMP stability.')
parser.add_argument('--no-temporal-fp32-reduce', dest='temporal_fp32_reduce', action='store_false')

parser.add_argument('--use-checkpoint', dest='use_checkpoint', action='store_true', default=False,
                    help='Enable activation checkpointing inside model (saves VRAM, slower).')

# Model detail
parser.add_argument('-T', '--time-step', dest='time_step', type=int, default=4)
parser.add_argument('-L', '--layer', dest='layer', type=int, default=4)
parser.add_argument('--dim', dest='dim', type=int, default=384)
parser.add_argument('--mlp-ratio', dest='mlp_ratio', type=float, default=4.0)


# 1D mixer layout (LinearAttn / QK-style / Spiking-SSM)
parser.add_argument('--seq-block-layout', dest='seq_block_layout', type=str, default='auto',
                    help="Seq mixer layout: auto|linear|qk|ssm|comma-list (len=1 or seq_layers).")
parser.add_argument('--ssm-kernel', dest='ssm_kernel', type=int, default=31,
                    help='Kernel size for SpikingSSM blocks (odd preferred).')
parser.add_argument('--ssm-expand-ratio', dest='ssm_expand_ratio', type=float, default=2.0,
                    help='Channel expansion ratio for SpikingSSM blocks.')

# Temporal fusion
parser.add_argument('--temporal-fuse', dest='temporal_fuse', type=str, default='mean',
                    choices=['mean', 'wavg', 'gate', 'wg', 'max', 'last'])
parser.add_argument('--temporal-gate', dest='temporal_gate', type=str, default='scalar',
                    choices=['scalar', 'channel', 'vector'])
parser.add_argument('--temporal-eps', dest='temporal_eps', type=float, default=1e-6)

parser.add_argument('--temporal-fuse-pre', dest='temporal_fuse_pre', type=str, default='none',
                    choices=['none', 'mean', 'wavg', 'gate', 'wg', 'max', 'last'])
parser.add_argument('--temporal-fuse-final', dest='temporal_fuse_final', type=str, default='none',
                    choices=['none', 'mean', 'wavg', 'gate', 'wg', 'max', 'last'])

# Aux deep supervision
parser.add_argument('--use-aux-ctc', dest='use_aux_ctc', action='store_true', default=False)
parser.add_argument('--aux-ctc-weight', dest='aux_ctc_weight', type=float, default=0.25)
parser.add_argument('--aux-temporal-fuse', dest='aux_temporal_fuse', type=str, default='none',
                    choices=['none', 'mean', 'wavg', 'gate', 'wg', 'max', 'last'])

# Consistency KL term (weight 0.0 by default). Symmetric is the default; the
# --consistency-kl-symmetric flag is a no-op (store_true, default=True), and only
# --consistency-kl-asymmetric changes the setting.
parser.add_argument('--consistency-kl-weight', type=float, default=0.0)
parser.add_argument('--consistency-kl-temp', type=float, default=1.0)
parser.add_argument('--consistency-kl-symmetric', action='store_true', default=True)
parser.add_argument('--consistency-kl-asymmetric', dest='consistency_kl_symmetric', action='store_false')
# Optional module knobs (tri-state override)
def _add_tri_state_bool(name: str, help_str: str):
    """Register a mutually exclusive ``--NAME`` / ``--no-NAME`` pair on ``parser``.

    Both flags write to the same destination (``NAME`` with dashes turned into
    underscores). The destination defaults to ``None``, giving three states:
    ``True`` (enable flag given), ``False`` (disable flag given) or ``None``
    (neither given, i.e. "leave the model default alone").
    """
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group()
    group.add_argument(f'--{name}', dest=dest, action='store_true', help=f'{help_str} (enable)')
    group.add_argument(f'--no-{name}', dest=dest, action='store_false', help=f'{help_str} (disable)')
    # Override both actions' built-in defaults (False / True) with the "unset" marker.
    parser.set_defaults(**{dest: None})

# Tri-state module switches: each registers --X / --no-X with a default of None.
# main() drops None-valued entries from model_kwargs, so an omitted switch keeps
# SpikeHTR's own default.
_add_tri_state_bool('use-temporal-coding', 'Enable Temporal Coding block')
_add_tri_state_bool('use-dual-res-fusion', 'Enable Dual-res Fusion (S2+S3)')
_add_tri_state_bool('use-token-merge', 'Enable Token Merge (blank pruning)')

# InkCoder / Temporal Coding (only used when `use_temporal_coding` is enabled)
ink_g = parser.add_argument_group('InkCoder / Temporal Coding')
ink_g.add_argument('--ink-int-blur-ks', dest='ink_int_blur_ks', type=int, default=None)
ink_g.add_argument('--ink-edge-blur-ks', dest='ink_edge_blur_ks', type=int, default=None)
ink_g.add_argument('--ink-q-low', dest='ink_q_low', type=float, default=None)
ink_g.add_argument('--ink-q-high', dest='ink_q_high', type=float, default=None)
ink_g.add_argument('--ink-q-edge', dest='ink_q_edge', type=float, default=None)

_add_tri_state_bool('ink-edge-consistency', 'Enable edge-consistency gate inside InkCoder')
ink_g.add_argument('--ink-edge-cons-ks', dest='ink_edge_cons_ks', type=int, default=None)
ink_g.add_argument('--ink-edge-cons-kappa', dest='ink_edge_cons_kappa', type=float, default=None)
ink_g.add_argument('--ink-edge-cons-tau', dest='ink_edge_cons_tau', type=float, default=None)

_add_tri_state_bool('ink-edge-int-gate', 'Enable intensity-guided edge gate inside InkCoder')
ink_g.add_argument('--ink-edge-int-tau', dest='ink_edge_int_tau', type=float, default=None)
ink_g.add_argument('--ink-edge-int-kappa', dest='ink_edge_int_kappa', type=float, default=None)

_add_tri_state_bool('ink-d-speckle-suppress', 'Enable speckle suppression on fused evidence D')
ink_g.add_argument('--ink-d-cons-ks', dest='ink_d_cons_ks', type=int, default=None)
ink_g.add_argument('--ink-d-cons-kappa', dest='ink_d_cons_kappa', type=float, default=None)
ink_g.add_argument('--ink-d-cons-tau', dest='ink_d_cons_tau', type=float, default=None)
ink_g.add_argument('--ink-d-cons-power', dest='ink_d_cons_power', type=float, default=None)

_add_tri_state_bool('ink-use-multiscale-edge', 'Enable multi-scale edge magnitude in InkCoder')
ink_g.add_argument('--ink-edge-ms-down', dest='ink_edge_ms_down', type=int, default=None)

ink_g.add_argument('--ink-base-alpha', dest='ink_base_alpha', type=float, default=None)
ink_g.add_argument('--ink-alpha-decay', dest='ink_alpha_decay', type=float, default=None)

_add_tri_state_bool('ink-use-time-varying-fusion', 'Enable time-varying fusion of intensity/edge evidence')
ink_g.add_argument('--ink-fuse-bias', dest='ink_fuse_bias', type=float, default=None)
ink_g.add_argument('--ink-fuse-slope', dest='ink_fuse_slope', type=float, default=None)
_add_tri_state_bool('ink-force-aux-for-fusion', 'Force computing I_int/I_edge even when return_aux=False')

ink_g.add_argument('--ink-theta-min', dest='ink_theta_min', type=float, default=None)
ink_g.add_argument('--ink-theta-max', dest='ink_theta_max', type=float, default=None)
ink_g.add_argument('--ink-theta-gamma', dest='ink_theta_gamma', type=float, default=None)
ink_g.add_argument('--ink-eps', dest='ink_eps', type=float, default=None)

parser.add_argument('--token-min-keep-ratio', dest='token_min_keep_ratio', type=float, default=None)
parser.add_argument('--token-blank-thresh', dest='token_blank_thresh', type=float, default=None)
parser.add_argument('--token-merge-k', dest='token_merge_k', type=int, default=None)

# Legacy spellings: write into the same dests as the two options above.
parser.add_argument('--token-merge-keep-ratio', dest='token_min_keep_ratio', type=float, default=None)
parser.add_argument('--token-merge-blank-thresh', dest='token_blank_thresh', type=float, default=None)


# Dataset
parser.add_argument('--dataset', '-d', metavar='NAME', default='iam_line')
parser.add_argument('--data-path', dest='data_path', metavar='DIR', default="")
parser.add_argument('--train-split', dest='train_split', default='train')
parser.add_argument('--val-split', dest='val_split', default='validation')

parser.add_argument('--img-height', dest='img_height', type=int, default=64)
parser.add_argument('--img-max-width', dest='img_max_width', type=int, default=512)

parser.add_argument('--crop-pad', dest='crop_pad', type=int, default=2)
parser.add_argument('--allow-empty', dest='allow_empty', action='store_true', default=False)
parser.add_argument('--cache-size', dest='cache_size', type=int, default=8)

# Text normalization / charset
parser.add_argument('--text-norm', dest='text_norm', action='store_true', default=True)
parser.add_argument('--no-text-norm', dest='text_norm', action='store_false')
parser.add_argument('--text-norm-form', dest='text_norm_form', type=str, default='NFKC')
parser.add_argument('--text-collapse-ws', dest='text_collapse_ws', action='store_true', default=True)
parser.add_argument('--no-text-collapse-ws', dest='text_collapse_ws', action='store_false')
parser.add_argument('--text-drop-chars', dest='text_drop_chars', type=str, default="¬")
parser.add_argument('--max-label-len', dest='max_label_len', type=int, default=0)
parser.add_argument('--oov', dest='oov', type=str, default='error', choices=['error', 'unk', 'drop'])
parser.add_argument('--add-unk', dest='add_unk', action='store_true', default=False)
parser.add_argument('--charset-from', dest='charset_from', type=str, default='trainval',
                    choices=['train', 'trainval', 'all'])

# Image normalize / valid_w estimation / aug
parser.add_argument('--normalize', dest='normalize', type=str, default='half',
                    choices=['half', 'imagenet', 'none'])
parser.add_argument('--estimate-valid-w', dest='estimate_valid_w', action='store_true', default=True)
parser.add_argument('--no-estimate-valid-w', dest='estimate_valid_w', action='store_false')
parser.add_argument('--bg-thresh', dest='bg_thresh', type=int, default=250)

parser.add_argument('--aug-affine-p', dest='aug_affine_p', type=float, default=0.12)
parser.add_argument('--aug-degrees', dest='aug_degrees', type=float, default=2.0)
parser.add_argument('--aug-translate', dest='aug_translate', type=float, default=0.01)
parser.add_argument('--aug-shear', dest='aug_shear', type=float, default=2.0)

parser.add_argument('--aug-stroke-p', dest='aug_stroke_p', type=float, default=0.25)
parser.add_argument('--aug-stroke-kmin', dest='aug_stroke_kmin', type=int, default=3)
parser.add_argument('--aug-stroke-kmax', dest='aug_stroke_kmax', type=int, default=5)

parser.add_argument('--aug-sharpen-p', dest='aug_sharpen_p', type=float, default=0.15)
parser.add_argument('--aug-invert-p', dest='aug_invert_p', type=float, default=0.0)
parser.add_argument('--aug-noise-p', dest='aug_noise_p', type=float, default=0.15)
parser.add_argument('--aug-noise-std', dest='aug_noise_std', type=float, default=0.03)

parser.add_argument('--aug-wstretch-p', dest='aug_wstretch_p', type=float, default=0.25)
parser.add_argument('--aug-wstretch-min', dest='aug_wstretch_min', type=float, default=0.7)
parser.add_argument('--aug-wstretch-max', dest='aug_wstretch_max', type=float, default=1.3)

# Resume / init
parser.add_argument('--initial-checkpoint', dest='initial_checkpoint', default='', type=str)
parser.add_argument('--resume', dest='resume', default='', type=str)
parser.add_argument('--no-resume-opt', dest='no_resume_opt', action='store_true', default=False)

# Batch / workers
parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, default=4)
parser.add_argument('-vb', '--val-batch-size', dest='val_batch_size', type=int, default=4)
parser.add_argument('-j', '--workers', dest='workers', type=int, default=4)

# Optimizer / scheduler
parser.add_argument('--opt', dest='opt', default='adamw', type=str)
parser.add_argument('--weight-decay', dest='weight_decay', type=float, default=0.05)
parser.add_argument('--momentum', dest='momentum', type=float, default=0.9)
parser.add_argument('--opt-eps', dest='opt_eps', type=float, default=None)
parser.add_argument('--opt-betas', dest='opt_betas', type=float, nargs=2, default=None)
parser.add_argument('--clip-grad', dest='clip_grad', type=float, default=1.0)
parser.add_argument('--clip-mode', dest='clip_mode', type=str, default='norm')

parser.add_argument('--sched', dest='sched', default='cosine', type=str)
# Tri-state: None means "infer from the scheduler" in main(). argparse seeds a shared
# dest from the FIRST registered action's default, so --sched-on-updates (default None)
# must stay registered before --sched-on-epochs (store_false, implicit default True).
parser.add_argument('--sched-on-updates', dest='sched_on_updates', action='store_true', default=None)
parser.add_argument('--sched-on-epochs', dest='sched_on_updates', action='store_false')
parser.add_argument('--lr', dest='lr', type=float, default=3e-4)
parser.add_argument('--warmup-lr', dest='warmup_lr', type=float, default=1e-5)
parser.add_argument('--min-lr', dest='min_lr', type=float, default=1e-6)
parser.add_argument('--epochs', dest='epochs', type=int, default=210)
parser.add_argument('--warmup-epochs', dest='warmup_epochs', type=int, default=5)
parser.add_argument('--cooldown-epochs', dest='cooldown_epochs', type=int, default=10)
parser.add_argument('--decay-epochs', dest='decay_epochs', type=float, default=30)
parser.add_argument('--decay-rate', '--dr', dest='decay_rate', type=float, default=0.1)

# Regularization
parser.add_argument('--drop', '--drop-rate', dest='drop', type=float, default=0.1)
parser.add_argument('--drop-path', '--drop-path-rate', dest='drop_path', type=float, default=0.1)

# Sequence module
parser.add_argument('--seq-nhead', dest='seq_nhead', type=int, default=8)
parser.add_argument('--seq-layers', dest='seq_layers', type=int, default=2)
parser.add_argument('--max-seq-len', dest='max_seq_len', type=int, default=512)

# Grad accumulation
# --accum-iter is an alias of --accum-steps (same dest). It must stay after
# --accum-steps so the shared dest is seeded with 1 rather than None.
parser.add_argument('--accum-steps', dest='accum_steps', type=int, default=1)
parser.add_argument('--accum-iter', dest='accum_steps', type=int)

# AMP / perf
parser.add_argument('--amp', dest='amp', action='store_true', default=False)
parser.add_argument('--apex-amp', dest='apex_amp', action='store_true', default=False)
parser.add_argument('--native-amp', dest='native_amp', action='store_true', default=True)
# --native-amp already defaults to True, so it cannot switch anything off on its own;
# provide the matching off switch (same pattern as --no-tf32 / --no-text-norm).
parser.add_argument('--no-native-amp', dest='native_amp', action='store_false',
                    help='Disable native_amp (defaults to True).')
parser.add_argument('--amp-dtype', dest='amp_dtype', type=str, default='bf16', choices=['fp16', 'bf16'])

parser.add_argument('--channels-last', dest='channels_last', action='store_true', default=False)
parser.add_argument('--pin-mem', dest='pin_mem', action='store_true', default=True)
# --pin-mem defaults to True; this is the only CLI way to turn it off.
parser.add_argument('--no-pin-mem', dest='pin_mem', action='store_false',
                    help='Disable pin_mem (defaults to True).')
parser.add_argument('--no-prefetcher', dest='no_prefetcher', action='store_true', default=False)

parser.add_argument('--tf32', dest='tf32', action='store_true', default=True)
parser.add_argument('--no-tf32', dest='tf32', action='store_false')
parser.add_argument('--matmul-precision', dest='matmul_precision', type=str, default='high',
                    choices=['highest', 'high', 'medium'])

# Logging / output
parser.add_argument('--seed', dest='seed', type=int, default=42)
parser.add_argument('--log-interval', dest='log_interval', type=int, default=50)
parser.add_argument('--checkpoint-hist', dest='checkpoint_hist', type=int, default=1)
parser.add_argument('--output', dest='output', default='./output/train', type=str)
parser.add_argument('--experiment', dest='experiment', default='', type=str)
parser.add_argument('--eval-metric', dest='eval_metric', default='cer', type=str,
                    help="loss / cer / wer / exact_acc (default: cer)")
parser.add_argument('--save-images', dest='save_images', action='store_true', default=False)
parser.add_argument('--log-wandb', dest='log_wandb', action='store_true', default=False)

# EMA / DDP
parser.add_argument('--model-ema', dest='model_ema', action='store_true', default=False)
parser.add_argument('--model-ema-force-cpu', dest='model_ema_force_cpu', action='store_true', default=False)
parser.add_argument('--model-ema-decay', dest='model_ema_decay', type=float, default=0.9998)
parser.add_argument('--find-unused-parameters', dest='find_unused_parameters', action='store_true', default=False)

parser.add_argument('--dist-bn', dest='dist_bn', type=str, default='',
                    choices=['', 'broadcast', 'reduce'])

parser.add_argument('--force-no-store-v', dest='force_no_store_v', action='store_true', default=False)

# strict ctc lengths
parser.add_argument('--no-strict-ctc-lengths', dest='strict_ctc_lengths', action='store_false')
parser.set_defaults(strict_ctc_lengths=True)

parser.add_argument('--ctc-len-violation', dest='ctc_len_violation', type=str, default='clamp',
                    choices=['raise', 'clamp', 'skip'])
parser.add_argument('--dump-on-ctc-violation', dest='dump_on_ctc_violation', action='store_true', default=False)
parser.add_argument('--logit-clip', dest='logit_clip', type=float, default=0.0,
                    help='If >0 clamp logits to [-c,c] before log_softmax (fp32). 0 disables.')




def _sanitize_argv(argv):
    """Return a copy of ``argv`` with a stray trailing ':' removed from bare long options.

    A token such as ``--epochs:`` (presumably pasted from YAML-style
    ``key: value`` text) becomes ``--epochs``. Only bare option names are
    rewritten:

    * tokens in ``--opt=value`` form are left untouched, because the value
      itself may legitimately end with ':' (e.g. ``--text-drop-chars=:``);
    * the lone token ``--:`` is not shortened into the ``--`` end-of-options
      marker.

    Non-string items are passed through unchanged.
    """
    out = []
    for a in argv:
        if (
            isinstance(a, str)
            and len(a) > 3
            and a.startswith('--')
            and a.endswith(':')
            and '=' not in a
        ):
            a = a[:-1]
        out.append(a)
    return out


def _parse_args():
    """Parse CLI arguments, layering an optional YAML config underneath them.

    Two-stage parse: ``config_parser`` first extracts the config path from the
    (sanitized) argv. If a config is given, its top-level mapping is installed
    as ``parser`` defaults, so flags given explicitly on the command line still
    take precedence over the file.

    Returns:
        ``(args, args_text)``: the parsed namespace and a YAML dump of
        ``args.__dict__`` (``main`` writes it to ``args.yaml``).

    Raises:
        TypeError: if the config file's top level is not a mapping.
    """
    args_config, remaining = config_parser.parse_known_args(_sanitize_argv(sys.argv[1:]))
    if args_config.config:
        # Explicit UTF-8: configs may contain non-ASCII text (text_drop_chars defaults
        # to '¬'), and the platform's default encoding is not always UTF-8.
        with open(args_config.config, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f"Config file {args_config.config!r} must contain a YAML mapping at the top level, "
                f"got {type(cfg).__name__}"
            )
        parser.set_defaults(**cfg)

    args = parser.parse_args(remaining)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False, sort_keys=False)
    return args, args_text


def ensure_optimizer_namespace_attrs(args):
    """Backfill optimizer-related attributes that may be missing from ``args``.

    ``main`` later hands this namespace to ``optimizer_kwargs(cfg=args)``.
    Only attributes that are absent are filled in (``momentum=0.9``,
    ``opt_eps=None``, ``opt_betas=None``); existing values, including
    ``None``, are kept. The same namespace object is returned.
    """
    fallbacks = (
        ('momentum', 0.9),
        ('opt_eps', None),
        ('opt_betas', None),
    )
    for attr_name, fallback in fallbacks:
        if not hasattr(args, attr_name):
            setattr(args, attr_name, fallback)
    return args


def apply_yaml_compat_patch(args):
    """Map legacy / alternative config keys onto the names this script uses.

    YAML configs are injected through ``parser.set_defaults`` and may carry
    keys from older versions of the script. ``args`` is updated in place and
    also returned. "Parser default" below means ``parser.get_default(...)``,
    which already includes any value the YAML config set for that dest.

    * ``accum_iter`` -> ``accum_steps``. Only positive ints are used, and only
      while ``accum_steps`` is still at its parser default, so an explicit
      ``--accum-steps`` / ``--accum-iter`` on the command line wins.
    * ``drop_rate`` / ``drop_path_rate`` -> ``drop`` / ``drop_path``, only
      while the target is still at its parser default.
    * ``temporal_gate: vector`` -> ``channel``.
    * ``token_merge_keep_ratio`` / ``token_merge_blank_thresh`` ->
      ``token_min_keep_ratio`` / ``token_blank_thresh``, when those are unset.
    * ``trocr_layers`` / ``trocr_nhead`` -> ``seq_layers`` / ``seq_nhead``,
      only while the target is still at its parser default.
    * ``norm: "none"`` / ``"null"`` (string) -> ``None``.
    * Aliases of ``normalize`` -> canonical ``half`` / ``imagenet`` / ``none``.

    All conversions are best-effort: malformed legacy values (wrong type,
    ``null``) are ignored instead of aborting startup.
    """
    conversion_errors = (TypeError, ValueError, OverflowError)

    # accum_iter (legacy name) -> accum_steps.
    if hasattr(args, "accum_iter"):
        default_accum = parser.get_default("accum_steps")
        try:
            v = int(getattr(args, "accum_iter"))
        except conversion_errors:
            v = 0
        # Same precedence rule as the other aliases: don't clobber a value the
        # user set explicitly (anything other than the parser/YAML default).
        if v > 0 and getattr(args, "accum_steps", default_accum) == default_accum:
            args.accum_steps = v

    # drop_rate / drop_path_rate (timm-style names) -> drop / drop_path.
    for legacy, dest in (("drop_rate", "drop"), ("drop_path_rate", "drop_path")):
        if not hasattr(args, legacy):
            continue
        default = parser.get_default(dest)
        try:
            if float(getattr(args, dest, default)) == float(default):
                setattr(args, dest, float(getattr(args, legacy)))
        except conversion_errors:
            pass

    if str(getattr(args, "temporal_gate", "scalar")).lower() == "vector":
        args.temporal_gate = "channel"

    if getattr(args, "token_min_keep_ratio", None) is None and hasattr(args, "token_merge_keep_ratio"):
        args.token_min_keep_ratio = getattr(args, "token_merge_keep_ratio")
    if getattr(args, "token_blank_thresh", None) is None and hasattr(args, "token_merge_blank_thresh"):
        args.token_blank_thresh = getattr(args, "token_merge_blank_thresh")

    # Backward compatibility: older configs used `trocr_layers` / `trocr_nhead`.
    # We renamed them to `seq_layers` / `seq_nhead` to reflect that this is now a
    # generic 1D sequence module (not tied to TrOCR).
    for legacy, dest in (("trocr_layers", "seq_layers"), ("trocr_nhead", "seq_nhead")):
        if not hasattr(args, legacy):
            continue
        default = parser.get_default(dest)
        try:
            if int(getattr(args, dest, default)) == int(default):
                setattr(args, dest, int(getattr(args, legacy)))
        except conversion_errors:
            pass

    if isinstance(getattr(args, "norm", None), str) and getattr(args, "norm").lower() in ("none", "null"):
        args.norm = None

    norm = str(getattr(args, "normalize", "half")).lower()
    if norm in ("half", "0.5", "mean0.5", "std0.5"):
        args.normalize = "half"
    elif norm in ("imagenet", "in"):
        args.normalize = "imagenet"
    elif norm in ("none", "no", "off"):
        args.normalize = "none"

    return args


def _zero_grad(optimizer):
    """Clear the optimizer's gradients, preferring ``set_to_none=True``.

    If the optimizer's ``zero_grad`` rejects the ``set_to_none`` keyword
    (raises ``TypeError``), it is called again with no arguments.
    """
    clear = optimizer.zero_grad
    try:
        clear(set_to_none=True)
    except TypeError:
        clear()


def _load_initial_checkpoint(model: nn.Module, ckpt_path: str):
    """Warm-start ``model`` from a checkpoint file (non-strict load).

    Behaviour:

    * An empty path is a no-op; a path that is not a file only logs a warning.
    * For a dict checkpoint, the weights come from the first truthy entry among
      ``state_dict`` and ``model``; otherwise whatever ``model_ema`` holds is
      used. If that lookup yields ``None``, the dict is treated as a raw state
      dict and filtered to its tensor-valued entries.
    * A non-dict checkpoint is used as the state dict directly.
    * Anything that is not a non-empty dict at that point logs a warning and
      returns.
    * A leading ``module.`` (DDP wrapper prefix) is stripped from every key.
      The result is loaded with ``strict=False``, and the counts of
      missing/unexpected keys are logged.
    """
    if not ckpt_path:
        return
    if not os.path.isfile(ckpt_path):
        _logger.warning(f"--initial-checkpoint not found: {ckpt_path}")
        return

    loaded = torch_load_compat(ckpt_path, map_location='cpu')

    if not isinstance(loaded, dict):
        weights = loaded
    else:
        for key in ('state_dict', 'model'):
            candidate = loaded.get(key, None)
            if candidate:
                weights = candidate
                break
        else:
            weights = loaded.get('model_ema', None)
        if weights is None:
            weights = {name: tensor for name, tensor in loaded.items() if torch.is_tensor(tensor)}

    if not isinstance(weights, dict) or not weights:
        _logger.warning(f"initial checkpoint format not recognized: {ckpt_path}")
        return

    prefix = 'module.'
    stripped = OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in weights.items()
    )

    # Per the original note, the model's overridden load_state_dict is expected to
    # auto-resize positional/temporal weights.
    result = model.load_state_dict(stripped, strict=False)
    try:
        missing = result.missing_keys
        unexpected = result.unexpected_keys
    except Exception:
        missing = unexpected = []

    _logger.info(f"Loaded initial checkpoint: {ckpt_path}")
    if missing:
        _logger.info(f"  Missing keys: {len(missing)}")
    if unexpected:
        _logger.info(f"  Unexpected keys: {len(unexpected)}")


def _model_cfg_as_dict(cfg):
    """Best-effort conversion of a model config object to a plain dict; None if impossible."""
    from collections.abc import Mapping
    from dataclasses import asdict, is_dataclass

    if is_dataclass(cfg) and not isinstance(cfg, type):
        try:
            return asdict(cfg)
        except Exception:
            pass  # e.g. a field value that asdict cannot deep-copy; try the fallbacks
    if isinstance(cfg, Mapping):
        return dict(cfg)
    try:
        return dict(vars(cfg))
    except TypeError:
        # No __dict__ (e.g. __slots__-only objects).
        return None


def _log_resolved_model(args, model_kwargs: dict, model=None):
    """Log the model constructor kwargs and, if present, the resolved ``model.cfg``.

    Only the main process logs. ``model.cfg`` may be a dataclass, a mapping, or
    any object with a ``__dict__``. Objects exposing none of these are logged
    via ``repr``. This is diagnostic output only and must never abort training,
    so unusual config types no longer raise.
    """
    if not _is_main_process(args):
        return

    _logger.info("Model init kwargs:")
    for k in sorted(model_kwargs.keys()):
        _logger.info(f"  - {k}: {model_kwargs[k]}")

    if model is None:
        return
    cfg = getattr(model, "cfg", None)
    if cfg is None:
        return

    cfg_dict = _model_cfg_as_dict(cfg)
    if cfg_dict is None:
        _logger.info(f"Resolved model config (model.cfg): {cfg!r}")
        return

    _logger.info("Resolved model config (model.cfg):")
    # key=str keeps sorting safe for mappings with mixed key types.
    for k in sorted(cfg_dict.keys(), key=str):
        _logger.info(f"  - {k}: {cfg_dict[k]}")


# ============================================================
# main
# ============================================================

def _restore_ema_from_resume(model_ema, resume_path: str, args) -> None:
    """Reload EMA shadow weights from a resume checkpoint, best-effort.

    ``resume_checkpoint`` in ``main`` restores only the live model (plus
    optimizer / loss scaler). timm's ``CheckpointSaver`` (which ``main``
    builds with ``model_ema``) stores EMA weights under ``state_dict_ema``.
    Without reloading them, the EMA would restart from the just-resumed live
    weights. A missing or unloadable EMA state only logs a warning, so a resume
    is never aborted by this step.
    """
    try:
        checkpoint = torch_load_compat(resume_path, map_location='cpu')
    except Exception as e:
        _logger.warning(f"Could not read EMA state from {resume_path}: {e}")
        return

    ema_state = checkpoint.get('state_dict_ema', None) if isinstance(checkpoint, dict) else None
    if not isinstance(ema_state, dict) or not ema_state:
        _logger.warning(
            f"No 'state_dict_ema' in {resume_path}; EMA restarts from the resumed model weights."
        )
        return

    cleaned = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v) for k, v in ema_state.items()
    )
    try:
        model_ema.module.load_state_dict(cleaned, strict=False)
    except RuntimeError as e:
        _logger.warning(f"Failed to load EMA state from {resume_path}: {e}")
        return
    if _is_main_process(args):
        _logger.info(f"Restored EMA weights from {resume_path}")


def main():
    """Train SpikeHTR with a CTC objective.

    Stages:
      0. Parse args (YAML config + CLI, legacy-key compat), init distributed /
         device, seed, TF32 / matmul-precision knobs, output dir and logging.
      1. Build CTC dataloaders and charset.
      2. Build SpikeHTR. None-valued overrides are dropped so the model keeps
         its own defaults. Optionally warm-start from --initial-checkpoint.
      3. Optimizer, AMP, CTC loss; resume (before the DDP wrap), DDP, optional
         EMA (restored from the resume checkpoint when available), LR
         scheduler and checkpoint saver.
      4. Epoch loop: train, optional BN sync, validate (EMA metrics replace
         the live ones when the EMA can be evaluated), scheduler step, summary
         / wandb logging, checkpointing.
    """
    args, args_text = _parse_args()
    ensure_optimizer_namespace_attrs(args)
    apply_yaml_compat_patch(args)

    init_distributed_and_device(args)
    args.prefetcher = not bool(getattr(args, "no_prefetcher", False))
    random_seed(args.seed, int(args.rank))

    # perf knobs
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = bool(getattr(args, "tf32", True))
        torch.backends.cudnn.allow_tf32 = bool(getattr(args, "tf32", True))
    try:
        torch.set_float32_matmul_precision(str(getattr(args, "matmul_precision", "high")))
    except Exception:
        pass

    device = torch.device(args.device)

    # =======================================================
    # Output dir & logging
    # =======================================================
    exp_name = args.experiment or '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        safe_model_name(args.model),
        f"{args.dataset}_ctc",
    ])

    output_dir: Optional[str] = None
    if _is_main_process(args):
        output_dir = get_outdir(args.output, exp_name)
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, 'args.yaml'), 'w', encoding='utf-8') as f:
            f.write(args_text)

    # Non-main ranks receive the output dir chosen by rank 0 ("" encodes None).
    if args.distributed and _is_dist_avail_and_initialized():
        obj_list = [output_dir if output_dir else ""]
        torch.distributed.broadcast_object_list(obj_list, src=0)
        output_dir = obj_list[0] if obj_list[0] else None

    setup_logging_clean(output_dir, args)
    if output_dir and _is_main_process(args):
        _logger.info(f"Output dir: {output_dir}")

    _logger.info(f"device={args.device} distributed={args.distributed} rank={args.rank} world_size={args.world_size}")
    log_env_and_args(args, output_dir)

    # =======================================================
    # 1) Data
    # =======================================================
    train_loader, val_loader, charset, blank_id, num_ctc_classes = build_ctc_dataloaders(args, _logger)
    _logger.info(f"CTC vocab size (incl blank) = {num_ctc_classes}, blank_id = {blank_id}")

    # =======================================================
    # 2) Model - OPTIMIZED
    # =======================================================
    model_kwargs = dict(
        num_classes=num_ctc_classes,
        blank_id=blank_id,

        norm=args.norm,
        gn_max_groups=args.gn_max_groups,
        gn_nocast=args.gn_nocast,
        height_pool_mode=args.height_pool_mode,
        height_pool_mix=args.height_pool_mix,

        use_abs_pos=args.use_abs_pos,
        use_conv_pos=args.use_conv_pos,
        conv_pos_k=args.conv_pos_k,
        pos_allow_interp=args.pos_allow_interp,

        time_step=args.time_step,
        layer=args.layer,
        dim=args.dim,
        mlp_ratio=args.mlp_ratio,
        max_seq_len=args.max_seq_len,

        temporal_fuse=args.temporal_fuse,
        temporal_gate=args.temporal_gate,
        temporal_eps=args.temporal_eps,
        temporal_fuse_pre=args.temporal_fuse_pre,
        temporal_fuse_final=args.temporal_fuse_final,
        temporal_max_T=args.temporal_max_T,
        temporal_fp32_reduce=args.temporal_fp32_reduce,

        seq_layers=args.seq_layers,
        seq_nhead=args.seq_nhead,

        seq_block_layout=args.seq_block_layout,
        ssm_kernel=args.ssm_kernel,
        ssm_expand_ratio=args.ssm_expand_ratio,

        drop_rate=args.drop,
        drop_path_rate=args.drop_path,

        use_aux_ctc=args.use_aux_ctc,
        aux_ctc_weight=args.aux_ctc_weight,
        aux_temporal_fuse=args.aux_temporal_fuse,

        use_checkpoint=args.use_checkpoint,

        # optional overrides (None => keep model defaults)
        use_temporal_coding=getattr(args, 'use_temporal_coding', None),
        use_dual_res_fusion=getattr(args, 'use_dual_res_fusion', None),
        use_token_merge=getattr(args, 'use_token_merge', None),
        token_min_keep_ratio=getattr(args, 'token_min_keep_ratio', None),
        token_blank_thresh=getattr(args, 'token_blank_thresh', None),
        token_merge_k=getattr(args, 'token_merge_k', None),
        dual_res_down_mode=getattr(args, 'dual_res_down_mode', None),
        dual_res_down_mix_init=getattr(args, 'dual_res_down_mix_init', None),
        use_mem_residual=getattr(args, 'use_mem_residual', None),
        mem_residual_init=getattr(args, 'mem_residual_init', None),

        # InkCoder / Temporal Coding hyperparameters (applied only when provided)
        ink_int_blur_ks=getattr(args, 'ink_int_blur_ks', None),
        ink_edge_blur_ks=getattr(args, 'ink_edge_blur_ks', None),
        ink_q_low=getattr(args, 'ink_q_low', None),
        ink_q_high=getattr(args, 'ink_q_high', None),
        ink_q_edge=getattr(args, 'ink_q_edge', None),
        ink_edge_consistency=getattr(args, 'ink_edge_consistency', None),
        ink_edge_cons_ks=getattr(args, 'ink_edge_cons_ks', None),
        ink_edge_cons_kappa=getattr(args, 'ink_edge_cons_kappa', None),
        ink_edge_cons_tau=getattr(args, 'ink_edge_cons_tau', None),
        ink_edge_int_gate=getattr(args, 'ink_edge_int_gate', None),
        ink_edge_int_tau=getattr(args, 'ink_edge_int_tau', None),
        ink_edge_int_kappa=getattr(args, 'ink_edge_int_kappa', None),
        ink_d_speckle_suppress=getattr(args, 'ink_d_speckle_suppress', None),
        ink_d_cons_ks=getattr(args, 'ink_d_cons_ks', None),
        ink_d_cons_kappa=getattr(args, 'ink_d_cons_kappa', None),
        ink_d_cons_tau=getattr(args, 'ink_d_cons_tau', None),
        ink_d_cons_power=getattr(args, 'ink_d_cons_power', None),
        ink_use_multiscale_edge=getattr(args, 'ink_use_multiscale_edge', None),
        ink_edge_ms_down=getattr(args, 'ink_edge_ms_down', None),
        ink_base_alpha=getattr(args, 'ink_base_alpha', None),
        ink_alpha_decay=getattr(args, 'ink_alpha_decay', None),
        ink_use_time_varying_fusion=getattr(args, 'ink_use_time_varying_fusion', None),
        ink_fuse_bias=getattr(args, 'ink_fuse_bias', None),
        ink_fuse_slope=getattr(args, 'ink_fuse_slope', None),
        ink_force_aux_for_fusion=getattr(args, 'ink_force_aux_for_fusion', None),
        ink_theta_min=getattr(args, 'ink_theta_min', None),
        ink_theta_max=getattr(args, 'ink_theta_max', None),
        ink_theta_gamma=getattr(args, 'ink_theta_gamma', None),
        ink_eps=getattr(args, 'ink_eps', None),
    )
    # Drop unset overrides so SpikeHTR's own defaults apply.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

    # ============ CREATE OPTIMIZED MODEL ============
    model = SpikeHTR(**model_kwargs)
    _log_resolved_model(args, model_kwargs, model=model)
    _logger.info(f"Model {safe_model_name(args.model)} OPTIMIZED created, param count: {sum(p.numel() for p in model.parameters())}")

    if args.force_no_store_v:
        force_disable_store_v(model)
        _logger.info("Applied --force-no-store-v (best-effort).")

    model = model.to(device)
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # A resume checkpoint takes precedence over the warm-start checkpoint.
    if args.initial_checkpoint and not args.resume:
        _load_initial_checkpoint(model, args.initial_checkpoint)

    log_model_and_params(args, model, output_dir)

    # =======================================================
    # 3) Optimizer / AMP / Loss
    # =======================================================
    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
    amp_autocast, loss_scaler, model, optimizer, _ = create_amp(args, model, optimizer)
    ctc_loss_fn = nn.CTCLoss(blank=blank_id, zero_infinity=True).to(device)


    # resume BEFORE DDP wrap
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=_is_main_process(args)
        )

    # wrap DDP now
    model = wrap_ddp(model, args)
    model_for_saving = unwrap_model(model)

    # EMA
    model_ema = None
    if args.model_ema:
        model_ema = ModelEmaV2(
            model_for_saving,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None
        )
        if args.resume:
            # resume_checkpoint() restored only the live model; bring back the
            # EMA shadow weights too instead of restarting them from scratch.
            _restore_ema_from_resume(model_ema, args.resume, args)

    # A CPU-forced EMA cannot be validated with device=<non-CPU device> (the
    # loader batches are presumably moved there, as train_one_epoch does) and is
    # not safe to sync with distribute_bn over a GPU collective backend. In that
    # case it is still passed to train_one_epoch and saved by CheckpointSaver,
    # but not evaluated per epoch (same policy as timm's train.py).
    ema_evaluable = model_ema is not None and not (
        args.model_ema_force_cpu and device.type != 'cpu'
    )
    if model_ema is not None and not ema_evaluable:
        _logger.info(
            f"EMA is forced to CPU while training on {device}; skipping per-epoch EMA validation."
        )

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = int(resume_epoch) if resume_epoch is not None else 0

    # Scheduler stepping policy
    sched_on_updates = None if getattr(args, 'sched_on_updates', None) is None else bool(getattr(args, 'sched_on_updates'))
    if lr_scheduler is not None and sched_on_updates is None:
        if hasattr(lr_scheduler, 't_in_epochs'):
            sched_on_updates = (not bool(getattr(lr_scheduler, 't_in_epochs')))
        else:
            sched_on_updates = hasattr(lr_scheduler, 'step_update')
    if sched_on_updates is None:
        sched_on_updates = False
    args._sched_on_updates = bool(sched_on_updates)

    # Epoch-based schedulers must be fast-forwarded to the resumed epoch.
    if lr_scheduler is not None and start_epoch > 0 and (not args._sched_on_updates):
        lr_scheduler.step(start_epoch)

    _logger.info(f"Scheduled epochs: {num_epochs}")

    saver = None
    if _is_main_process(args):
        decreasing = args.eval_metric in ('loss', 'cer', 'wer')
        saver = CheckpointSaver(
            model=model_for_saving,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist
        )

        if args.log_wandb and has_wandb:
            wandb.init(project=args.dataset, name=exp_name, config=vars(args))
        elif args.log_wandb:
            _logger.warning("--log-wandb requested but wandb is not installed; skipping wandb logging.")

    # =======================================================
    # 4) Train loop
    # =======================================================
    best_metric = None
    best_epoch = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'set_epoch'):
                train_loader.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, train_loader, optimizer, ctc_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema,
                device=device,
            )

            # BN distribute only makes sense if using BN
            if args.distributed and (args.norm == 'bn'):
                if args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_for_saving, args.world_size, args.dist_bn == 'reduce')
                if ema_evaluable and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema.module, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(
                model, val_loader, ctc_loss_fn, args,
                charset, blank_id,
                amp_autocast=amp_autocast,
                device=device,
            )

            if ema_evaluable:
                ema_eval_metrics = validate(
                    model_ema.module, val_loader, ctc_loss_fn, args,
                    charset, blank_id,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)',
                    device=device,
                )
                # EMA metrics drive scheduling, summaries and checkpoint selection.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None and (not bool(getattr(args, '_sched_on_updates', False))):
                metric_for_scheduler = eval_metrics.get(args.eval_metric, eval_metrics.get('loss', None))
                lr_scheduler.step(epoch + 1, metric_for_scheduler)

            if output_dir is not None and _is_main_process(args):
                update_summary(
                    epoch, train_metrics, eval_metrics,
                    os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None,
                    log_wandb=False
                )
                if args.log_wandb and has_wandb:
                    wandb.log({
                        "train/loss": train_metrics['loss'],
                        "val/loss": eval_metrics['loss'],
                        "val/cer": eval_metrics['cer'],
                        "val/wer": eval_metrics['wer'],
                        "val/exact_acc": eval_metrics['exact_acc'],
                        "lr": float(optimizer.param_groups[0]['lr']),
                    }, step=epoch)


            if saver is not None:
                save_metric = eval_metrics.get(args.eval_metric, eval_metrics.get('loss', None))
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
                _logger.info(f"*** Best metric: {best_metric} (epoch {best_epoch})")

    except KeyboardInterrupt:
        _logger.info("Interrupted by user.")

    if args.distributed and _is_dist_avail_and_initialized():
        torch.distributed.barrier()

    if best_metric is not None and _is_main_process(args):
        _logger.info(f"*** Best metric: {best_metric} (epoch {best_epoch})")


# ============================================================
# train_one_epoch & validate
# ============================================================

def train_one_epoch(
    epoch,
    model,
    loader,
    optimizer,
    loss_fn,
    args,
    lr_scheduler=None,
    saver=None,
    output_dir=None,
    amp_autocast=suppress,
    loss_scaler=None,
    model_ema=None,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Run one training epoch of the CTC recognizer with optional gradient accumulation.

    Each batch from ``loader`` must be a 5-tuple
    ``(images, targets_concat, target_lengths, texts, extra)``. The model output may
    be a bare logits tensor, a ``(logits, feat_lengths)`` pair, or a dict with a
    ``"logits"`` key (plus optional ``"aux_logits"`` and ``"feat_lengths"`` /
    ``"feature_lengths"``). ``logits.size(0)`` is used as the sequence length and
    ``log_softmax`` is taken over dim 2 (presumably a time-first ``(T, B, C)`` layout).

    The optimized loss is the main CTC loss. When an auxiliary head is present and
    ``args.aux_ctc_weight > 0``, it also includes ``aux_ctc_weight`` times the auxiliary
    CTC loss. When ``args.consistency_kl_weight > 0``, it further adds that weight times
    a length-masked KL consistency term between the two heads.

    Args:
        epoch: Epoch index; used for logging and as the start of the update counter
            (``epoch * updates_per_epoch``).
        model: Network to train (may be DDP-wrapped; ``no_sync`` is used when present).
        loader: Iterable of 5-tuple batches (see above); must support ``len()``.
        optimizer: Optimizer whose ``param_groups`` supply the logged learning rate.
        loss_fn: Called as ``loss_fn(log_probs, targets_concat, input_lengths,
            target_lengths)`` on fp32 log-probabilities; must return a scalar.
        args: Namespace of training options (``accum_steps``, ``strict_ctc_lengths``,
            ``ctc_len_violation``, ``clip_grad``, ``clip_mode``, ``log_interval``, ...).
        lr_scheduler: Stepped with ``step_update`` after every optimizer update, but only
            when ``args._sched_on_updates`` is truthy.
        saver: Accepted for interface compatibility; not used by this function.
        output_dir: Destination for debug dumps and optional batch images.
        amp_autocast: Context-manager factory wrapping the forward pass.
        loss_scaler: If given, it performs backward, optional clipping, and the optimizer
            step. Otherwise a plain ``backward()`` + ``optimizer.step()`` is used.
        model_ema: EMA model, updated after each optimizer update.
        device: Device the batch tensors are moved to.

    Returns:
        OrderedDict with ``'loss'``: the batch-size-weighted mean of the finite losses.

    Raises:
        RuntimeError: If a batch is not a 5-tuple, the model returns a dict without
            ``'logits'``, or ``args.ctc_len_violation == "raise"`` and a target is longer
            than its input length.
    """
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    accum_steps = max(1, int(getattr(args, 'accum_steps', 1)))
    num_updates_per_epoch = max(1, math.ceil(len(loader) / accum_steps))
    num_updates = epoch * num_updates_per_epoch

    # These settings are constant for the whole epoch, so read them once.
    # `or 0.0` keeps an explicitly-None option from crashing float().
    logit_clip = float(getattr(args, 'logit_clip', 0.0) or 0.0)
    aux_ctc_weight = float(getattr(args, 'aux_ctc_weight', 0.0) or 0.0)
    kl_weight = float(getattr(args, 'consistency_kl_weight', 0.0) or 0.0)
    kl_symmetric = bool(getattr(args, 'consistency_kl_symmetric', True))
    len_policy = getattr(args, "ctc_len_violation", "clamp")
    dump_on_violation = bool(getattr(args, "dump_on_ctc_violation", False))
    sched_on_updates = bool(getattr(args, '_sched_on_updates', False))

    # Manual gradient clipping is only needed on the no-scaler path.
    clip_fn = None
    if loss_scaler is None and args.clip_grad is not None:
        from timm.utils import dispatch_clip_grad as clip_fn

    def _fp32_log_probs(raw_logits):
        # CTC is evaluated in fp32, optionally on clipped logits, for numerical stability.
        logits32 = raw_logits.float()
        if logit_clip > 0:
            logits32 = torch.clamp(logits32, -logit_clip, logit_clip)
        return logits32, logits32.log_softmax(2)

    def _input_lengths_for(seq_len, batch_extra, batch_size):
        # Per-sample valid lengths (strict mode) or the full sequence length for everyone.
        if args.strict_ctc_lengths:
            return infer_input_lengths_strict(batch_extra, T_seq=seq_len, B=batch_size, device=device)
        return torch.full((batch_size,), int(seq_len), dtype=torch.long, device=device)

    model.train()
    _zero_grad(optimizer)

    end = time.time()
    last_idx = len(loader) - 1

    for batch_idx, batch in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        if not isinstance(batch, (list, tuple)) or len(batch) != 5:
            raise RuntimeError(
                "DataLoader must return 5 items: (images, targets_concat, target_lengths, texts, extra). "
                f"Got type={type(batch)} len={len(batch) if isinstance(batch,(list,tuple)) else 'NA'}"
            )

        images, targets_concat, target_lengths, texts, extra = batch

        images = images.to(device, non_blocking=True)
        targets_concat = targets_concat.to(device, non_blocking=True)
        target_lengths = target_lengths.to(device, non_blocking=True)
        B = images.size(0)

        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        # Step the optimizer at the end of each accumulation window and on the final batch.
        need_update = ((batch_idx + 1) % accum_steps == 0) or last_batch

        # On intermediate micro-batches, skip the DDP gradient all-reduce.
        sync_ctx = suppress()
        if args.distributed and accum_steps > 1 and (not need_update) and hasattr(model, "no_sync"):
            sync_ctx = model.no_sync()

        with sync_ctx:
            with amp_autocast():
                out = model(images)
                aux_logits = None
                feat_lengths = None
                if isinstance(out, (tuple, list)) and len(out) == 2:
                    logits, feat_lengths = out
                elif isinstance(out, dict):
                    logits = out.get("logits", None)
                    aux_logits = out.get("aux_logits", None)
                    feat_lengths = out.get("feat_lengths", out.get("feature_lengths", None))
                    if logits is None:
                        raise RuntimeError(f"Model returned dict without 'logits' key. keys={list(out.keys())}")
                else:
                    logits = out

            # Expose model-reported lengths to the strict length inference (copy; don't mutate the batch).
            if feat_lengths is not None and isinstance(extra, dict):
                extra = dict(extra)
                extra["feat_lengths"] = feat_lengths

            T_seq = int(logits.size(0))
            input_lengths = _input_lengths_for(T_seq, extra, B)

            if torch.any(target_lengths > input_lengths):
                msg = (
                    f"[CTC] target_lengths > input_lengths at epoch={epoch}, iter={batch_idx}. "
                    f"tgt_max={int(target_lengths.max().item())}, in_max={int(input_lengths.max().item())}, T_seq={T_seq}. "
                    "Check extra['feat_lengths'] or valid_w/target_w estimation."
                )

                if dump_on_violation and _is_main_process(args) and output_dir:
                    dump_non_finite_batch(
                        output_dir, epoch, batch_idx,
                        images, targets_concat, target_lengths, input_lengths,
                        texts=texts, extra=extra, logits=logits
                    )

                if len_policy == "raise":
                    raise RuntimeError(msg)
                elif len_policy == "skip":
                    # NOTE(review): this decision is per-rank. Under DDP, a rank that skips
                    # issues no backward for this batch while the other ranks do. Verify that
                    # this cannot diverge across ranks, or the gradient all-reduce may hang.
                    if _is_main_process(args):
                        _logger.warning(msg + " -> skip batch")
                    functional.reset_net(unwrap_model(model))
                    _zero_grad(optimizer)
                    end = time.time()
                    continue
                else:
                    if _is_main_process(args) and batch_idx == 0:
                        _logger.warning(msg + " -> clamp input_lengths to target_lengths")
                    input_lengths = torch.maximum(input_lengths, target_lengths).clamp(min=1, max=int(T_seq))

            with autocast_off():
                logits_fp32, log_probs = _fp32_log_probs(logits)
                loss = loss_fn(log_probs, targets_concat, input_lengths, target_lengths)

                if (aux_logits is not None) and aux_ctc_weight > 0:
                    T_aux = int(aux_logits.size(0))
                    aux_input_lengths = _input_lengths_for(T_aux, extra, B)

                    if torch.any(target_lengths > aux_input_lengths):
                        # Same repair as the main head, including the upper bound:
                        # CTC input lengths must never exceed the sequence length T_aux.
                        aux_input_lengths = torch.maximum(aux_input_lengths, target_lengths).clamp(min=1, max=T_aux)

                    aux_fp32, aux_log_probs = _fp32_log_probs(aux_logits)
                    aux_loss = loss_fn(aux_log_probs, targets_concat, aux_input_lengths, target_lengths)
                    loss = loss + aux_ctc_weight * aux_loss

                    if kl_weight > 0:
                        temp = max(float(getattr(args, 'consistency_kl_temp', 1.0)), 1e-6)
                        main_lp = (logits_fp32 / temp).log_softmax(dim=2)
                        aux_lp = (aux_fp32 / temp).log_softmax(dim=2)

                        # The heads may differ in length; compare only the common prefix.
                        t_min = min(int(main_lp.size(0)), int(aux_lp.size(0)))
                        main_lp = main_lp[:t_min]
                        aux_lp = aux_lp[:t_min]

                        # F.kl_div(log q, p) gives the pointwise terms of KL(p || q); sum over classes.
                        kl = F.kl_div(main_lp, aux_lp.exp(), reduction='none', log_target=False).sum(dim=2)
                        if kl_symmetric:
                            kl_rev = F.kl_div(aux_lp, main_lp.exp(), reduction='none', log_target=False).sum(dim=2)
                            kl = 0.5 * (kl + kl_rev)

                        # Average only over frames that are valid for BOTH heads.
                        valid_lengths = torch.minimum(input_lengths, aux_input_lengths)
                        t = torch.arange(t_min, device=kl.device).unsqueeze(1)
                        mask = (t < valid_lengths.unsqueeze(0)).to(dtype=kl.dtype)
                        kl_loss = (kl * mask).sum() / mask.sum().clamp(min=1.0)

                        loss = loss + kl_weight * kl_loss

            if not torch.isfinite(loss):
                # NOTE(review): like the "skip" policy above, this is decided per rank.
                if _is_main_process(args):
                    _logger.error(
                        f"[NaN/Inf] Non-finite loss at epoch={epoch}, iter={batch_idx}. "
                        f"loss={loss.detach().cpu().item()}"
                    )
                    if output_dir:
                        dump_non_finite_batch(
                            output_dir, epoch, batch_idx,
                            images, targets_concat, target_lengths, input_lengths,
                            texts=texts, extra=extra, logits=logits
                        )
                functional.reset_net(unwrap_model(model))
                _zero_grad(optimizer)
                end = time.time()
                continue

            # Scale so that the gradients accumulated over a window average the micro-batch losses.
            loss_to_backward = loss / accum_steps
            losses_m.update(loss.item(), B)

            if loss_scaler is not None:
                loss_scaler(
                    loss_to_backward, optimizer,
                    clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                    parameters=model_parameters(unwrap_model(model), exclude_head='agc' in args.clip_mode),
                    create_graph=second_order,
                    need_update=need_update,
                )
            else:
                loss_to_backward.backward(create_graph=second_order)
                if need_update:
                    if clip_fn is not None:
                        clip_fn(
                            model_parameters(unwrap_model(model), exclude_head='agc' in args.clip_mode),
                            value=args.clip_grad, mode=args.clip_mode
                        )
                    optimizer.step()

        if need_update:
            _zero_grad(optimizer)

        # Reset spiking-neuron state between batches.
        functional.reset_net(unwrap_model(model))

        if need_update:
            if model_ema is not None:
                model_ema.update(unwrap_model(model))
            num_updates += 1
            if lr_scheduler is not None and sched_on_updates:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        do_log = last_batch or (batch_idx % args.log_interval == 0)
        if do_log and device.type == "cuda":
            # Synchronize so the measured batch time covers the queued GPU work.
            torch.cuda.synchronize()

        batch_time_m.update(time.time() - end)

        if do_log:
            loss_for_log = loss.detach()
            if args.distributed:
                loss_for_log = reduce_tensor(loss_for_log, args.world_size)

            lr = float(sum(pg['lr'] for pg in optimizer.param_groups) / max(len(optimizer.param_groups), 1))
            if _is_main_process(args):
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss_val:>9.6f} ({loss_avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / max(1, last_idx),
                        loss_val=float(loss_for_log.item()),
                        loss_avg=float(losses_m.avg),
                        batch_time=batch_time_m,
                        rate=B * args.world_size / max(batch_time_m.val, 1e-8),
                        rate_avg=B * args.world_size / max(batch_time_m.avg, 1e-8),
                        lr=lr,
                        data_time=data_time_m
                    )
                )

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        images,
                        os.path.join(output_dir, f'train-batch-{batch_idx}.jpg'),
                        padding=0,
                        normalize=True
                    )

        end = time.time()

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    return OrderedDict([('loss', float(losses_m.avg))])


def validate(
    model,
    loader,
    loss_fn,
    args,
    charset,
    blank_id,
    amp_autocast=suppress,
    log_suffix='',
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Evaluate ``model`` on ``loader`` and return CTC loss and recognition metrics.

    Batches use the same 5-tuple layout as training:
    ``(images, targets_concat, target_lengths, texts, extra)``. Predictions come from
    ``ctc_greedy_decode``. Predicted and reference id sequences are converted to text
    with ``ids_to_text`` and scored with ``compute_cer`` / ``compute_wer``.

    Args:
        model: Network to evaluate; it is put in ``eval()`` mode and left there.
        loader: Iterable of 5-tuple batches; must support ``len()``.
        loss_fn: Called as ``loss_fn(log_probs, targets_concat, input_lengths,
            target_lengths)`` on fp32 log-probabilities.
        args: Namespace with ``strict_ctc_lengths``, ``channels_last``, ``distributed``,
            ``world_size``, ``log_interval`` and optional ``logit_clip``.
        charset: Passed to ``ids_to_text`` to map ids to characters.
        blank_id: CTC blank index used by the greedy decoder.
        amp_autocast: Context-manager factory wrapping the forward pass.
        log_suffix: Appended to the ``'Val'`` log prefix (e.g. ``' (EMA)'``).
        device: Device the batch tensors are moved to.

    Returns:
        OrderedDict with:
          - ``'loss'``: per-batch losses averaged with local batch-size weights. Each
            batch loss is first averaged across ranks when distributed.
          - ``'cer'`` / ``'wer'``: summed error counts divided by summed lengths, as
            returned by ``compute_cer`` / ``compute_wer`` (corpus-level). Counters are
            summed over ranks in distributed mode.
          - ``'exact_acc'``: percentage of samples whose predicted text equals the reference.

    Raises:
        RuntimeError: If a batch is not a 5-tuple or a dict output lacks ``'logits'``.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()

    total_char_errs = 0
    total_chars = 0
    total_word_errs = 0
    total_words = 0
    total_exact = 0
    total_samples = 0

    # Constant for the whole pass; `or 0.0` tolerates an explicitly-None option.
    logit_clip = float(getattr(args, 'logit_clip', 0.0) or 0.0)

    model.eval()
    end = time.time()
    last_idx = len(loader) - 1

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            last_batch = batch_idx == last_idx

            if not isinstance(batch, (list, tuple)) or len(batch) != 5:
                raise RuntimeError("DataLoader must return 5 items (images, targets_concat, target_lengths, texts, extra).")

            images, targets_concat, target_lengths, texts, extra = batch

            images = images.to(device, non_blocking=True)
            targets_concat = targets_concat.to(device, non_blocking=True)
            target_lengths = target_lengths.to(device, non_blocking=True)
            B = images.size(0)

            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                out = model(images)
                feat_lengths = None
                if isinstance(out, (tuple, list)) and len(out) == 2:
                    logits, feat_lengths = out
                elif isinstance(out, dict):
                    logits = out.get("logits", None)
                    feat_lengths = out.get("feat_lengths", out.get("feature_lengths", None))
                    if logits is None:
                        raise RuntimeError(f"Model returned dict without 'logits' key. keys={list(out.keys())}")
                else:
                    logits = out

            # Expose model-reported lengths to the strict length inference (copy; don't mutate the batch).
            if feat_lengths is not None and isinstance(extra, dict):
                extra = dict(extra)
                extra["feat_lengths"] = feat_lengths

            T_seq = int(logits.size(0))
            if args.strict_ctc_lengths:
                input_lengths = infer_input_lengths_strict(extra, T_seq=T_seq, B=B, device=device)
            else:
                input_lengths = torch.full((B,), int(T_seq), dtype=torch.long, device=device)

            # Same repair as training: CTC needs input_length >= target_length, and input_length <= T.
            if torch.any(target_lengths > input_lengths):
                input_lengths = torch.maximum(input_lengths, target_lengths).clamp(min=1, max=int(T_seq))

            with autocast_off():
                logits_fp32 = logits.float()
                if logit_clip > 0:
                    logits_fp32 = torch.clamp(logits_fp32, -logit_clip, logit_clip)
                log_probs = logits_fp32.log_softmax(2)
                loss = loss_fn(log_probs, targets_concat, input_lengths, target_lengths)

            # Reset spiking-neuron state between batches.
            functional.reset_net(unwrap_model(model))

            reduced_loss = reduce_tensor(loss.detach(), args.world_size) if args.distributed else loss.detach()
            losses_m.update(float(reduced_loss.item()), B)

            pred_seqs = ctc_greedy_decode(logits, blank_id=blank_id, input_lengths=input_lengths)

            # Split the concatenated targets per sample. Copy to host once per batch rather
            # than once per sample (each .cpu() is a separate device sync).
            flat_targets = targets_concat.detach().cpu().tolist()
            target_seqs = []
            offset = 0
            for length in target_lengths.detach().cpu().tolist():
                length = int(length)
                target_seqs.append(flat_targets[offset:offset + length])
                offset += length

            for pred_ids, tgt_ids in zip(pred_seqs, target_seqs):
                ref_text = ids_to_text(tgt_ids, charset)
                pred_text = ids_to_text(pred_ids, charset)

                cerr, clen = compute_cer(ref_text, pred_text)
                werr, wlen = compute_wer(ref_text, pred_text)

                total_char_errs += int(cerr)
                total_chars += int(clen)
                total_word_errs += int(werr)
                total_words += int(wlen)
                total_samples += 1
                if pred_text == ref_text:
                    total_exact += 1

            batch_time_m.update(time.time() - end)
            end = time.time()

            if _is_main_process(args) and (last_batch or batch_idx % args.log_interval == 0):
                # Running (local-rank) metrics for progress logging only.
                cer_val = total_char_errs / max(total_chars, 1)
                wer_val = total_word_errs / max(total_words, 1)
                exact_acc = 100.0 * total_exact / max(total_samples, 1)
                log_name = 'Val' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'CER: {cer:.4f}  WER: {wer:.4f}  '
                    'ExactAcc: {ea:.2f}%'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        cer=cer_val, wer=wer_val, ea=exact_acc
                    )
                )

    # Sum raw counters across ranks so the final ratios are global, not averages of ratios.
    if args.distributed and _is_dist_avail_and_initialized():
        stats = torch.tensor(
            [total_char_errs, total_chars, total_word_errs, total_words, total_exact, total_samples],
            device=device, dtype=torch.long
        )
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.SUM)
        total_char_errs, total_chars, total_word_errs, total_words, total_exact, total_samples = stats.tolist()

    cer_final = total_char_errs / max(total_chars, 1)
    wer_final = total_word_errs / max(total_words, 1)
    exact_acc_final = 100.0 * total_exact / max(total_samples, 1)

    return OrderedDict([
        ('loss', float(losses_m.avg)),
        ('cer', float(cer_final)),
        ('wer', float(wer_final)),
        ('exact_acc', float(exact_acc_final)),
    ])


if __name__ == '__main__':
    # Script entry point: run main() only when this file is executed directly, not on import.
    main()