#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ocr_datasets.common

Contract:
- Every dataset returns: (image_tensor, target_ids_1d, text, meta_dict)
- meta_dict MUST include:
    - valid_w: effective content width in pixels AFTER augmentation+resize, BEFORE right-pad region
    - target_w: padded width in pixels (the transform's target width)
- Collate ALWAYS returns a 5-tuple:
    images, targets_concat, target_lengths, texts, extra
  where extra contains tensors needed to compute per-sample CTC input_lengths.

"""

from __future__ import annotations

import math
import random
import unicodedata
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps
from torch.utils.data import Sampler
from torchvision import transforms
from torchvision.transforms import functional as TF


# -----------------------------
# Text normalization (HTR-VT style)
# -----------------------------

def normalize_transcription(
    text: str,
    *,
    form: str = "NFKC",
    collapse_whitespace: bool = True,
    strip: bool = True,
    drop_chars: str = "¬",
    replace_map: Optional[Dict[str, str]] = None,
) -> str:
    """Clean up a raw transcription string.

    Steps, applied in this order:
      1. Unicode normalization with ``form`` (skipped when ``form`` is empty;
         a failing normalization leaves the string untouched).
      2. Substring replacements from ``replace_map``, in the dict's iteration order.
      3. Removal of every character listed in ``drop_chars``.
      4. Collapse of whitespace runs into single spaces (``" ".join(s.split())``).
      5. Stripping of leading/trailing whitespace.

    Returns "" when ``text`` is None; any other input is converted with ``str()``.
    """
    if text is None:
        return ""
    result = str(text)

    if form:
        try:
            result = unicodedata.normalize(form, result)
        except Exception:
            # Best effort: keep the un-normalized string if normalization fails.
            pass

    for old, new in (replace_map or {}).items():
        result = result.replace(old, new)

    for unwanted in drop_chars or "":
        result = result.replace(unwanted, "")

    if collapse_whitespace:
        # split() with no argument also discards leading/trailing whitespace.
        result = " ".join(result.split())

    return result.strip() if strip else result


# -----------------------------
# Charset mapper (CTC blank_id = 0)
# -----------------------------

@dataclass
class CharsetMapper:
    """Bidirectional mapping between characters and integer class ids.

    ``from_texts`` assigns ids 1..V to the characters in ``charset`` order;
    ``decode_ids`` relies on that layout (id ``i`` maps to ``charset[i - 1]``).
    ``blank_id`` is the id skipped while decoding.
    """

    charset: List[str]
    char2id: Dict[str, int]
    blank_id: int = 0
    unk_token: Optional[str] = None  # set only when an unknown-character entry was added

    @classmethod
    def from_texts(
        cls,
        texts: Sequence[str],
        *,
        blank_id: int = 0,
        sort: bool = True,
        add_unk: bool = False,
        unk_token: str = "�",
    ) -> "CharsetMapper":
        """Build a mapper from every distinct character found in ``texts``.

        With ``sort=True`` the charset is sorted; otherwise first-occurrence
        order is kept. With ``add_unk=True`` ``unk_token`` is appended to the
        charset (unless already present) and remembered on the instance.
        """
        joined = "".join(texts)
        if sort:
            symbols = sorted(set(joined))
        else:
            symbols = list(dict.fromkeys(joined))

        if add_unk and unk_token not in symbols:
            symbols.append(unk_token)

        # ids start at 1; 0 is left free for the blank symbol
        lookup = dict(zip(symbols, range(1, len(symbols) + 1)))
        remembered_unk = unk_token if add_unk else None
        return cls(charset=symbols, char2id=lookup, blank_id=blank_id, unk_token=remembered_unk)

    @property
    def num_classes(self) -> int:
        """Number of output classes: one per charset entry plus the blank."""
        return 1 + len(self.charset)

    def encode(self, text: str, *, oov: str = "error") -> torch.Tensor:
        """Convert ``text`` to a 1-D ``torch.long`` tensor of class ids.

        oov:
          - "drop": silently skip characters missing from the charset
          - "unk": map them to ``unk_token`` (ValueError if no unk_token is set)
          - anything else (default "error"): raise KeyError on the first unknown character

        Raises ValueError when ``text`` is None or empty, or when nothing is
        left after OOV handling.
        """
        if text is None:
            raise ValueError("encode: text is None")
        if len(text) == 0:
            raise ValueError("encode: empty transcription is not allowed for CTC")

        table = self.char2id
        encoded: List[int] = []
        for symbol in text:
            if symbol not in table:
                if oov == "drop":
                    continue
                if oov != "unk":
                    raise KeyError(f"Unknown character: {repr(symbol)}")
                if self.unk_token is None:
                    raise ValueError("encode(oov='unk') requires add_unk=True when building CharsetMapper")
                symbol = self.unk_token
            encoded.append(table[symbol])

        if not encoded:
            raise ValueError(f"encode: after OOV handling, target is empty. text={repr(text)} oov={oov}")
        return torch.tensor(encoded, dtype=torch.long)

    def decode_ids(self, ids: Sequence[int]) -> str:
        """Map ids back to characters, skipping blanks and ids outside 1..V."""
        vocab = self.charset
        size = len(vocab)
        pieces = []
        for token in ids:
            if token == self.blank_id:
                continue
            position = int(token) - 1
            if 0 <= position < size:
                pieces.append(vocab[position])
        return "".join(pieces)

    def __len__(self) -> int:
        """Charset size (blank not counted)."""
        return len(self.charset)

    def __getitem__(self, idx: int) -> str:
        """Return the charset entry at list index ``idx`` (0-based, not a class id)."""
        return self.charset[idx]


# -----------------------------
# Geometry: resize height, right-pad to target width
# -----------------------------

class LineResizePad:
    """
    Bring a line image to exactly (target_width x target_height), RGB.

    The image is first scaled to target_height with its aspect ratio kept:
      - if the scaled width fits, the image is padded on the right with `fill`;
      - otherwise it is scaled down again so its width equals target_width and
        the lost height is padded back (split between top and bottom).
    """

    def __init__(self, target_height: int = 128, target_width: int = 512, fill: int = 255):
        self.target_height = int(target_height)
        self.target_width = int(target_width)
        self.fill = int(fill)

    def __call__(self, img: Image.Image) -> Image.Image:
        out_h, out_w = self.target_height, self.target_width
        if img.mode != "RGB":
            img = img.convert("RGB")

        src_w, src_h = img.size
        if src_w <= 0 or src_h <= 0:
            # degenerate input: hand back a blank canvas of the target size
            return Image.new("RGB", (out_w, out_h), (self.fill, self.fill, self.fill))

        # height-normalized width
        fitted_w = max(1, int(round(src_w * (out_h / float(src_h)))))
        img = img.resize((fitted_w, out_h), resample=Image.BICUBIC)

        if fitted_w <= out_w:
            # fits: fill the remaining columns on the right
            return TF.pad(img, (0, 0, out_w - fitted_w, 0), fill=self.fill)

        # too wide: shrink to the target width, then restore the height with padding
        shrink = out_w / float(fitted_w)
        shrunk_h = max(1, int(round(out_h * shrink)))
        img = img.resize((out_w, shrunk_h), resample=Image.BICUBIC)
        slack = out_h - shrunk_h
        top = slack // 2
        return TF.pad(img, (0, top, 0, slack - top), fill=self.fill)


def compute_valid_width(raw_w: int, raw_h: int, target_h: int, target_w: int) -> int:
    """Cheap valid_w estimate from the raw image size alone (ignores augmentation).

    Scales raw_w by target_h / raw_h and caps the result at target_w.
    If any argument is non-positive, returns max(target_w, 1).
    """
    src_w, src_h, dst_h, dst_w = (int(v) for v in (raw_w, raw_h, target_h, target_w))
    if min(src_w, src_h, dst_h, dst_w) <= 0:
        return max(dst_w, 1)
    scaled_w = max(1, int(round(src_w * (dst_h / float(src_h)))))
    return int(min(scaled_w, dst_w))


def _otsu_ink_threshold(arr: np.ndarray, fallback: int = 250) -> int:
    """Return an exclusive bound ``thr`` such that ``arr < thr`` selects Otsu's dark class.

    Falls back to ``fallback`` when no split is possible (empty image, or
    every pixel has the same gray level).
    """
    hist = np.bincount(arr.reshape(-1), minlength=256).astype(np.float64)
    total = hist.sum()
    if total <= 0:
        return fallback

    prob = hist / total
    omega = np.cumsum(prob)
    mu = np.cumsum(prob * np.arange(256))
    denom = omega * (1.0 - omega)

    # Levels where one class is empty carry no between-class variance.
    usable = denom > 0
    if not usable.any():
        return fallback

    sigma_b2 = np.full(256, -np.inf)
    sigma_b2[usable] = (mu[-1] * omega[usable] - mu[usable]) ** 2 / denom[usable]
    k = int(np.argmax(sigma_b2))
    # Otsu's dark class is "gray <= k"; callers test "gray < thr", hence k + 1.
    return min(k + 1, 255)


def estimate_valid_width_from_ink(
    img_rgb_padded: Image.Image,
    *,
    bg_thresh: int = 250,
    min_w: int = 1,
    max_w: Optional[int] = None,
) -> int:
    """
    Robust valid_w estimation: find rightmost column that contains 'ink' pixels.
    Assumes right padding is white-ish.

    Important: **never underestimate** valid_w in strict-CTC pipelines.
    If ink cannot be detected (very faint writing / compression artifacts),
    fall back to max_w (usually target_width) to avoid input_lengths collapsing to 1;
    when max_w is None the fallback is min_w.

    bg_thresh:
      - >=0: fixed threshold, ink = gray < bg_thresh
      - <0 : auto threshold (Otsu) per image; a uniform image cannot be split
             and uses the fixed value 250 instead

    Returns:
      1-based index of the rightmost ink column, raised to at least min_w and
      capped at max_w when max_w is given.
    """
    gray = img_rgb_padded.convert("L")
    arr = np.array(gray, dtype=np.uint8)  # [H,W]

    thr = int(bg_thresh)
    if thr < 0:
        thr = _otsu_ink_threshold(arr)

    ink_cols = np.flatnonzero((arr < thr).any(axis=0))
    if ink_cols.size == 0:
        vw = int(max_w) if max_w is not None else int(min_w)
    else:
        vw = max(int(ink_cols[-1]) + 1, int(min_w))
    if max_w is not None:
        vw = min(vw, int(max_w))
    return vw


def merge_meta_keep_max_valid_w(meta: dict, meta_u: dict) -> None:
    """Fold `meta_u` into `meta` in place; `valid_w` can only grow.

    - valid_w: becomes int(max(old, new)); a missing or unparsable value counts as 1.
    - target_w: overwritten with int(new) when that conversion succeeds, otherwise left alone.
    - every other key of `meta_u` overwrites the entry in `meta`.

    Does nothing unless both arguments are dicts.
    """
    if not (isinstance(meta, dict) and isinstance(meta_u, dict)):
        return

    def _width_or_one(source: dict) -> float:
        # Unreadable widths count as the smallest legal width.
        try:
            return float(source.get("valid_w", 1))
        except Exception:
            return 1.0

    if "valid_w" in meta_u:
        meta["valid_w"] = int(max(_width_or_one(meta), _width_or_one(meta_u)))

    if "target_w" in meta_u:
        try:
            meta["target_w"] = int(meta_u.get("target_w"))
        except Exception:
            # keep the previous target_w when the update is not an integer
            pass

    meta.update(
        (key, value)
        for key, value in meta_u.items()
        if key not in ("valid_w", "target_w")
    )

# -----------------------------
# HTR-VT-inspired augmentation pack
# -----------------------------

class RandomStrokeWidth:
    """With probability `p`, run a PIL MinFilter or MaxFilter (50/50) to vary pen thickness.

    The kernel size is drawn from the odd integers in [k_min, k_max]; if that
    range holds no odd value, size 3 is used.
    """

    def __init__(self, p: float = 0.25, k_min: int = 3, k_max: int = 5):
        self.p = float(p)
        self.k_min = int(k_min)
        self.k_max = int(k_max)

    def __call__(self, img: Image.Image) -> Image.Image:
        if self.p <= 0 or random.random() >= self.p:
            return img
        # PIL rank filters only accept odd kernel sizes
        odd_sizes = [size for size in range(self.k_min, self.k_max + 1) if size % 2 == 1]
        kernel = random.choice(odd_sizes or [3])
        # MinFilter spreads dark pixels, MaxFilter spreads bright ones
        rank_filter = ImageFilter.MinFilter if random.random() < 0.5 else ImageFilter.MaxFilter
        return img.filter(rank_filter(size=kernel))


class RandomSharpenPIL:
    """With probability `p`, sharpen the image using PIL's UnsharpMask filter."""

    def __init__(self, p: float = 0.15, radius: float = 1.0, percent: int = 150, threshold: int = 3):
        self.p = float(p)
        self.radius = float(radius)
        self.percent = int(percent)
        self.threshold = int(threshold)

    def __call__(self, img: Image.Image) -> Image.Image:
        if self.p <= 0 or random.random() >= self.p:
            return img
        mask = ImageFilter.UnsharpMask(
            radius=self.radius,
            percent=self.percent,
            threshold=self.threshold,
        )
        return img.filter(mask)


class RandomInvert:
    """With probability `p` (default 0, i.e. disabled), invert the image via ImageOps.invert."""

    def __init__(self, p: float = 0.0):
        self.p = float(p)

    def __call__(self, img: Image.Image) -> Image.Image:
        keep_original = self.p <= 0 or random.random() >= self.p
        if keep_original:
            return img
        return ImageOps.invert(img)

class RandomWidthStretchPIL:
    """With probability `p`, stretch or compress the content horizontally about the image center.

    The canvas size is unchanged: areas uncovered by a compression are painted
    with `fill`. The scale is drawn uniformly from [min_scale, max_scale].
    Images that are at most 1 pixel wide or tall are returned untouched.
    """

    def __init__(self, p: float = 0.0, min_scale: float = 0.7, max_scale: float = 1.3, fill: int = 255):
        self.p = float(p)
        self.min_scale = float(min_scale)
        self.max_scale = float(max_scale)
        self.fill = int(fill)

    def __call__(self, img: Image.Image) -> Image.Image:
        if self.p <= 0 or random.random() >= self.p:
            return img
        width, height = img.size
        if width <= 1 or height <= 1:
            return img

        stretch = random.uniform(self.min_scale, self.max_scale)
        center_x = (width - 1) * 0.5
        # Affine coefficients (a, b, c, d, e, f): only x is scaled (a) and
        # shifted (c) so that the center column stays in place.
        coeffs = (1.0 / stretch, 0.0, center_x - center_x / stretch, 0.0, 1.0, 0.0)
        background = (self.fill, self.fill, self.fill)
        return img.transform(
            (width, height),
            Image.AFFINE,
            coeffs,
            resample=Image.BICUBIC,
            fillcolor=background,
        )

class RandomAffineLite:
    """
    With probability `p`, apply a torchvision RandomAffine with a small
    rotation (+-degrees), translation (fraction `translate` per axis) and
    x-shear (+-shear).

    The constructor first tries to pass `fill` to RandomAffine; if that raises
    TypeError (torchvision without the 'fill' keyword) it builds the transform
    without it and records that in `_no_fill`. In that case the uncovered
    edges may come out black.
    """

    def __init__(
        self,
        p: float = 0.15,
        degrees: float = 2.0,
        translate: float = 0.01,
        shear: float = 2.0,
        fill: int = 255,
    ):
        self.p = float(p)
        self.degrees = float(degrees)
        self.translate = float(translate)
        self.shear = float(shear)
        self.fill = int(fill)

        affine_kwargs = dict(
            degrees=self.degrees,
            translate=(self.translate, self.translate),
            shear=(-self.shear, self.shear),
        )
        try:
            self.aff = transforms.RandomAffine(fill=self.fill, **affine_kwargs)
            self._no_fill = False
        except TypeError:
            # older torchvision without 'fill' kw
            self.aff = transforms.RandomAffine(**affine_kwargs)
            self._no_fill = True

    def __call__(self, img: Image.Image) -> Image.Image:
        if self.p <= 0 or random.random() >= self.p:
            return img
        # The result is returned as-is whether or not `fill` was supported.
        return self.aff(img)




class AddGaussianNoiseTensor:
    """With probability `p`, add zero-mean Gaussian noise (std `std`) and clamp to [0, 1].

    Intended to run on the ToTensor output, before Normalize.
    """

    def __init__(self, p: float = 0.15, std: float = 0.03):
        self.p = float(p)
        self.std = float(std)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        disabled = self.p <= 0 or self.std <= 0
        if disabled or random.random() >= self.p:
            return x
        noisy = x + torch.randn_like(x) * self.std
        return torch.clamp(noisy, min=0.0, max=1.0)


# -----------------------------
# Unified line transform (returns tensor and meta update)
# -----------------------------

class LineCTCTransform:
    """
    A single callable that applies, in this order:
      width stretch -> affine -> resize_pad -> invert -> stroke -> sharpen
      -> valid_w estimation -> ToTensor -> noise -> Normalize

    Each augmentation fires with its own probability; a probability of 0
    disables it.

    Returns:
      - tensor [C,H,W]
      - meta_update dict {"valid_w": ..., "target_w": ...}

    valid_w is estimated from the ink in the padded PIL image when
    estimate_valid_w is True, otherwise it equals target_width.

    normalize is resolved by `_get_normalize`, so the accepted names and
    aliases ('half', 'imagenet', 'none', ...) match the rest of this module.
    """

    def __init__(
        self,
        *,
        target_height: int,
        target_width: int,
        normalize: str = "half",
        # text/ink related
        estimate_valid_w: bool = True,
        bg_thresh: int = 250,
        # augmentation knobs
        aug_affine_p: float = 0.15,
        aug_wstretch_p: float = 0.0,
        aug_wstretch_min: float = 0.7,
        aug_wstretch_max: float = 1.3,
        aug_degrees: float = 2.0,
        aug_translate: float = 0.01,
        aug_shear: float = 2.0,
        aug_stroke_p: float = 0.25,
        aug_stroke_kmin: int = 3,
        aug_stroke_kmax: int = 5,
        aug_sharpen_p: float = 0.15,
        aug_invert_p: float = 0.0,
        aug_noise_p: float = 0.15,
        aug_noise_std: float = 0.03,
    ):
        self.target_height = int(target_height)
        self.target_width = int(target_width)
        self.bg_thresh = int(bg_thresh)
        self.estimate_valid_w = bool(estimate_valid_w)

        # geometry augmentations applied to the raw image, before resize+pad
        self.pre_affine = RandomAffineLite(
            p=aug_affine_p,
            degrees=aug_degrees,
            translate=aug_translate,
            shear=aug_shear,
            fill=255,
        )
        self.pre_wstretch = RandomWidthStretchPIL(
            p=aug_wstretch_p,
            min_scale=aug_wstretch_min,
            max_scale=aug_wstretch_max,
            fill=255,
        )

        self.resize_pad = LineResizePad(target_height=self.target_height, target_width=self.target_width, fill=255)

        # photometric augmentations applied to the padded image
        self.stroke = RandomStrokeWidth(p=aug_stroke_p, k_min=aug_stroke_kmin, k_max=aug_stroke_kmax)
        self.sharpen = RandomSharpenPIL(p=aug_sharpen_p)
        self.invert = RandomInvert(p=aug_invert_p)

        self.to_tensor = transforms.ToTensor()
        self.noise = AddGaussianNoiseTensor(p=aug_noise_p, std=aug_noise_std)

        # Shared helper keeps the alias table in one place; this also makes
        # "false" and "0" disable normalization here, as they do elsewhere.
        self.norm = _get_normalize(normalize)

    def __call__(self, img: Image.Image):
        """Transform one PIL line image; returns (tensor, meta_update)."""
        # pre-aug
        if img.mode != "RGB":
            img = img.convert("RGB")
        img = self.pre_wstretch(img)
        img = self.pre_affine(img)

        # resize+pad
        img = self.resize_pad(img)

        # post-pad aug
        img = self.invert(img)
        img = self.stroke(img)
        img = self.sharpen(img)

        # valid_w estimation (done on the PIL image, before tensor noise is added)
        if self.estimate_valid_w:
            vw = estimate_valid_width_from_ink(img, bg_thresh=self.bg_thresh, min_w=1, max_w=self.target_width)
        else:
            vw = self.target_width

        x = self.to_tensor(img)  # [0,1]
        x = self.noise(x)
        if self.norm is not None:
            x = self.norm(x)

        meta_u = {"valid_w": int(vw), "target_w": int(self.target_width)}
        return x, meta_u


def create_line_transform(
    target_height: int = 128,
    target_width: int = 512,
    *,
    normalize: str = "half",
    # ink-based valid_w
    estimate_valid_w: bool = True,
    bg_thresh: int = 250,
    # augmentation
    aug_affine_p: float = 0.15,
    aug_wstretch_p: float = 0.0,
    aug_wstretch_min: float = 0.7,
    aug_wstretch_max: float = 1.3,
    aug_degrees: float = 2.0,
    aug_translate: float = 0.01,
    aug_shear: float = 2.0,
    aug_stroke_p: float = 0.25,
    aug_stroke_kmin: int = 3,
    aug_stroke_kmax: int = 5,
    aug_sharpen_p: float = 0.15,
    aug_invert_p: float = 0.0,
    aug_noise_p: float = 0.15,
    aug_noise_std: float = 0.03,
):
    """
    Factory for the fixed-width line transform.

    Every argument is forwarded unchanged to LineCTCTransform, whose instances
    return (tensor, meta_update) so that valid_w stays accurate after
    augmentation.
    """
    geometry = dict(target_height=target_height, target_width=target_width)
    preprocessing = dict(
        normalize=normalize,
        estimate_valid_w=estimate_valid_w,
        bg_thresh=bg_thresh,
    )
    augmentation = dict(
        aug_affine_p=aug_affine_p,
        aug_wstretch_p=aug_wstretch_p,
        aug_wstretch_min=aug_wstretch_min,
        aug_wstretch_max=aug_wstretch_max,
        aug_degrees=aug_degrees,
        aug_translate=aug_translate,
        aug_shear=aug_shear,
        aug_stroke_p=aug_stroke_p,
        aug_stroke_kmin=aug_stroke_kmin,
        aug_stroke_kmax=aug_stroke_kmax,
        aug_sharpen_p=aug_sharpen_p,
        aug_invert_p=aug_invert_p,
        aug_noise_p=aug_noise_p,
        aug_noise_std=aug_noise_std,
    )
    return LineCTCTransform(**geometry, **preprocessing, **augmentation)


# -----------------------------
# Strict CTC collate (always 5-tuple)
# -----------------------------

def ctc_collate(batch: Sequence[Tuple[torch.Tensor, torch.Tensor, str, Dict[str, Any]]]):
    """Collate (img, target, text, meta) samples into the strict CTC 5-tuple.

    Returns:
      images:          stacked image tensor (all images must share one shape)
      targets_concat:  1-D long tensor, all targets concatenated
      target_lengths:  1-D long tensor, number of elements of each target
      texts:           list of the raw strings
      extra:           {"valid_w": long tensor [B],
                        "target_w": int if identical for every sample, else long tensor [B]}

    Raises:
      ValueError: empty batch, an item that is not a 4-element tuple/list,
                  or valid_w <= 0 / target_w <= 1.
      TypeError:  a meta entry that is not a dict.
      KeyError:   a meta dict missing valid_w or target_w.
    """
    if len(batch) == 0:
        raise ValueError("ctc_collate: empty batch")

    # Check every item, not just the first: zip() below would silently drop
    # the extra fields of a malformed later item.
    for item in batch:
        if not (isinstance(item, (tuple, list)) and len(item) == 4):
            raise ValueError("ctc_collate: each item must be (img, target, text, meta)")

    images, targets, texts, metas = zip(*batch)

    images_t = torch.stack(images, dim=0)
    target_lengths = torch.tensor([t.numel() for t in targets], dtype=torch.long)
    targets_concat = torch.cat([t.to(dtype=torch.long) for t in targets], dim=0)

    valid_ws: List[int] = []
    target_ws: List[int] = []
    for i, m in enumerate(metas):
        if not isinstance(m, dict):
            raise TypeError(f"ctc_collate: meta must be dict, got {type(m)} at idx={i}")
        if "valid_w" not in m or "target_w" not in m:
            raise KeyError(f"ctc_collate: meta must include valid_w and target_w, idx={i}, keys={list(m.keys())}")
        vw = int(m["valid_w"])
        tw = int(m["target_w"])
        if vw <= 0 or tw <= 1:
            raise ValueError(f"ctc_collate: invalid valid_w/target_w at idx={i}: {vw}/{tw}")
        valid_ws.append(vw)
        target_ws.append(tw)

    extra: Dict[str, Any] = {"valid_w": torch.tensor(valid_ws, dtype=torch.long)}
    # Collapse target_w to a plain int when the whole batch agrees on it.
    tw0 = target_ws[0]
    extra["target_w"] = int(tw0) if all(t == tw0 for t in target_ws) else torch.tensor(target_ws, dtype=torch.long)

    return images_t, targets_concat, target_lengths, list(texts), extra


# -----------------------------
# Distributed eval sampler (NO duplicates)
# -----------------------------

class DistributedShardSampler(Sampler[int]):
    """Shard dataset indices across ranks by striding, without padding or duplicates.

    Rank r receives indices[r::num_replicas], so shards may differ in length
    by one and together cover every index exactly once.

    If either num_replicas or rank is None, BOTH are taken from
    torch.distributed when it is initialized, and default to (1, 0) otherwise.
    With shuffle=True the index order is a permutation seeded by `seed`,
    computed once at construction.

    Raises:
      ValueError: num_replicas is not positive, or rank is outside
                  [0, num_replicas). Such values would otherwise produce an
                  empty or wrong shard without any error.
    """

    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ):
        self.dataset = dataset
        if num_replicas is None or rank is None:
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                num_replicas = torch.distributed.get_world_size()
                rank = torch.distributed.get_rank()
            else:
                num_replicas = 1
                rank = 0
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        if self.num_replicas <= 0:
            raise ValueError(f"DistributedShardSampler: num_replicas must be positive, got {self.num_replicas}")
        if not 0 <= self.rank < self.num_replicas:
            raise ValueError(
                f"DistributedShardSampler: rank must be in [0, {self.num_replicas}), got {self.rank}"
            )
        self.shuffle = bool(shuffle)
        self.seed = int(seed)

        total = len(self.dataset)
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed)
            self.indices = torch.randperm(total, generator=g).tolist()
        else:
            self.indices = list(range(total))

        # strided slice: every rank sees a disjoint subset
        self.rank_indices = self.indices[self.rank::self.num_replicas]

    def __iter__(self):
        return iter(self.rank_indices)

    def __len__(self):
        return len(self.rank_indices)


# -----------------------------
# CTC input length inference (strict)
# -----------------------------

def infer_ctc_input_lengths(
    extra: Dict[str, Any],
    *,
    T_seq: int,
    target_lengths: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    fix: str = "warn_clamp",
) -> torch.Tensor:
    """Derive per-sample CTC input lengths (long tensor) from collate `extra`.

    Resolution order:
      1. If `extra` holds a non-None tensor under "input_lengths",
         "feat_lengths", "feature_lengths" or "ctc_input_lengths" (first match
         wins), return it clamped to [1, T_seq]. No target-length check is
         applied on this path.
      2. Otherwise compute ceil(valid_w / target_w * T_seq), clamped to [1, T_seq].

    When `target_lengths` is given and some computed length is shorter than its
    target, `fix` decides what happens:
      - "error": raise ValueError
      - "clamp": raise those lengths to the target length
      - "warn_clamp": same, after printing a warning
      - anything else: raise ValueError

    The result lives on `device` if given, else on target_lengths' device, else on CPU.

    Raises:
      TypeError:  `extra` is not a dict, or a length/width entry is not a tensor.
      KeyError:   neither explicit lengths nor (valid_w, target_w) are present.
      ValueError: size mismatch with target_lengths, scalar target_w <= 1,
                  or a length violation as described above.
    """
    if extra is None or not isinstance(extra, dict):
        raise TypeError(f"infer_ctc_input_lengths: extra must be dict, got {type(extra)}")

    have_targets = target_lengths is not None
    batch_size = int(target_lengths.numel()) if have_targets else None
    if device is not None:
        dev = device
    elif have_targets:
        dev = target_lengths.device
    else:
        dev = torch.device("cpu")

    # 1) lengths supplied directly by the caller
    for key in ("input_lengths", "feat_lengths", "feature_lengths", "ctc_input_lengths"):
        explicit = extra.get(key)
        if explicit is None:
            continue
        if not torch.is_tensor(explicit):
            raise TypeError(f"infer_ctc_input_lengths: {key} must be tensor, got {type(explicit)}")
        explicit = explicit.to(device=dev, dtype=torch.long)
        if batch_size is not None and explicit.numel() != batch_size:
            raise ValueError(f"infer_ctc_input_lengths: {key} numel={explicit.numel()} != B={batch_size}")
        return explicit.clamp(min=1, max=int(T_seq))

    # 2) lengths derived from the valid/padded width ratio
    if "valid_w" not in extra or "target_w" not in extra:
        raise KeyError(f"infer_ctc_input_lengths: need (valid_w,target_w) or input_lengths, got keys={list(extra.keys())}")

    widths = extra["valid_w"]
    padded = extra["target_w"]

    if not torch.is_tensor(widths):
        raise TypeError(f"infer_ctc_input_lengths: valid_w must be tensor, got {type(widths)}")
    widths = widths.to(device=dev, dtype=torch.float32).clamp(min=1.0)

    if torch.is_tensor(padded):
        padded_f = padded.to(device=dev, dtype=torch.float32).clamp(min=1.0)
        if batch_size is not None and padded_f.numel() not in (1, batch_size):
            raise ValueError(
                f"infer_ctc_input_lengths: target_w shape invalid numel={padded_f.numel()} (expect 1 or {batch_size})"
            )
        divisor = float(padded_f.item()) if padded_f.numel() == 1 else padded_f
    else:
        divisor = float(padded)
        if divisor <= 1:
            raise ValueError(f"infer_ctc_input_lengths: invalid target_w scalar: {divisor}")
    ratio = widths / divisor

    lengths = torch.ceil(ratio * float(T_seq)).to(dtype=torch.long).clamp(min=1, max=int(T_seq))
    if target_lengths is None:
        return lengths

    # CTC needs input_len >= target_len for every sample
    tl = target_lengths.to(device=dev, dtype=torch.long)
    too_short = lengths < tl
    if not too_short.any():
        return lengths

    n_bad = int(too_short.sum().item())
    if fix == "error":
        idx = int(torch.nonzero(too_short, as_tuple=False)[0].item())
        raise ValueError(
            f"CTC length violation: {n_bad}/{tl.numel()} samples have input_len < target_len. "
            f"Example idx={idx}: input_len={int(lengths[idx])}, target_len={int(tl[idx])}."
        )
    if fix not in ("clamp", "warn_clamp"):
        raise ValueError(f"Unknown fix policy: {fix}")
    if fix == "warn_clamp":
        print(f"[WARN] CTC length violation: clamping {n_bad}/{tl.numel()} samples (input_len < target_len).")
    return torch.maximum(lengths, tl)



# ============================
# Dynamic-width pipeline (v2)
# ============================

def trim_horizontal_whitespace(
    img: Image.Image,
    *,
    bg_thresh: int = 250,
    margin: int = 2,
    min_width: int = 8,
) -> Image.Image:
    """Crop blank columns off the left and right of a line image.

    A column counts as ink when any pixel's gray value is below the threshold.
    `bg_thresh <= 0` selects an automatic threshold: the 95th percentile of the
    gray values, limited to [32, 250]. `margin` extra columns are kept on each
    side of the ink.

    The (RGB-converted) input is returned uncropped when no ink is found or
    when the crop would be narrower than `min_width`. Only the horizontal
    extent ever changes.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    gray = np.array(img.convert("L"))

    cutoff = int(bg_thresh)
    if cutoff <= 0:
        # percentile fallback (no extra deps)
        cutoff = max(32, min(250, int(np.percentile(gray, 95))))

    ink_columns = np.flatnonzero((gray < cutoff).any(axis=0))
    if ink_columns.size == 0:
        return img

    keep = int(margin)
    first = max(0, int(ink_columns[0]) - keep)
    last = min(gray.shape[1] - 1, int(ink_columns[-1]) + keep)
    if last - first + 1 < int(min_width):
        return img
    # crop box is right-exclusive, hence last + 1
    return img.crop((first, 0, last + 1, img.height))


class LineResizePadDynamic:
    """Resize to fixed height, cap width, and right-pad (optionally to a multiple).

    Unlike LineResizePad(target_width=...), this only downscales width when exceeding max_width,
    so long lines won't be crushed into a tiny width (the 'crowded' problem).

    Output height is always target_height. Output width is per sample: the
    resized width plus right_pad_min, rounded up to pad_to_multiple (when
    non-zero) and never larger than max_width.

    A degenerate input (zero width or height) yields a blank
    (max_width x target_height) canvas.
    """

    def __init__(
        self,
        target_height: int,
        max_width: int,
        *,
        pad_to_multiple: int = 32,
        right_pad_min: int = 8,
        fill: int = 255,
        interpolation: int = Image.BICUBIC,
    ) -> None:
        self.target_height = int(target_height)
        self.max_width = int(max_width)
        self.pad_to_multiple = int(pad_to_multiple) if pad_to_multiple else 0
        self.right_pad_min = int(right_pad_min)
        self.fill = int(fill)
        self.interpolation = interpolation

    def __call__(self, img: Image.Image) -> Image.Image:
        if img.mode != "RGB":
            img = img.convert("RGB")

        background = (self.fill,) * 3
        w, h = img.size
        if w <= 0 or h <= 0:
            # Nothing to resize; also guards the zero-width case, in line with LineResizePad.
            return Image.new("RGB", (self.max_width, self.target_height), color=background)

        # 1) resize to fixed height (keep aspect)
        scale = self.target_height / float(h)
        new_w = max(1, int(round(w * scale)))
        img = img.resize((new_w, self.target_height), resample=self.interpolation)

        if new_w > self.max_width:
            # 2a) too wide: downscale to max_width; height shrinks, then pad
            #     vertically back to target_height
            scale2 = self.max_width / float(new_w)
            new_h2 = max(1, int(round(self.target_height * scale2)))
            shrunk = img.resize((self.max_width, new_h2), resample=self.interpolation)
            if new_h2 < self.target_height:
                pad_top = (self.target_height - new_h2) // 2
                img = Image.new("RGB", (self.max_width, self.target_height), color=background)
                img.paste(shrunk, (0, pad_top))
            else:
                img = shrunk
            content_w = self.max_width
        else:
            # 2b) fits: reserve a little blank margin on the right, within the cap
            content_w = min(self.max_width, max(1, new_w + self.right_pad_min))

        # 3) decide padded width (per-sample)
        pad_w = int(content_w)
        if self.pad_to_multiple and pad_w % self.pad_to_multiple != 0:
            pad_w = int(math.ceil(pad_w / self.pad_to_multiple) * self.pad_to_multiple)
        # the cap wins over the multiple when max_width itself is not a multiple
        pad_w = min(self.max_width, max(1, pad_w))

        # 4) right-pad (or crop) to pad_w
        if img.width < pad_w:
            base = Image.new("RGB", (pad_w, img.height), color=background)
            base.paste(img, (0, 0))
            img = base
        elif img.width > pad_w:
            img = img.crop((0, 0, pad_w, img.height))

        return img




def _get_normalize(normalize: str):
    """Map a normalize option to a torchvision Normalize transform, or None.

    Matching is case-insensitive:
      - 'none' / 'no' / 'false' / '0' / 'identity': returns None (no normalization)
      - 'half' / '0.5' / 'pm1' / 'minus1_1' / 'old': mean=0.5, std=0.5 per channel
        (maps [0,1] -> [-1,1])
      - anything else, including None and '': ImageNet mean/std
    """
    key = (normalize or "imagenet").lower()

    if key in ("none", "no", "false", "0", "identity"):
        return None

    if key in ("half", "0.5", "pm1", "minus1_1", "old"):
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    else:
        # default: imagenet
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    return transforms.Normalize(mean=mean, std=std)




# -----------------------------
# lightweight augmentations
# -----------------------------

def random_width_stretch(img: Image.Image, min_factor: float = 0.7, max_factor: float = 1.3) -> Image.Image:
    """Resize the image horizontally by a random factor; the height is kept.

    The factor is uniform in [min_factor, max_factor] (the bounds are swapped
    if given in reverse). Unlike RandomWidthStretchPIL, the output width
    changes (never below 2 pixels). The input is returned untouched when a
    bound is non-positive or the rounded width does not change.
    """
    if min_factor <= 0 or max_factor <= 0:
        return img
    low, high = sorted((min_factor, max_factor))
    factor = random.uniform(float(low), float(high))

    width, height = img.size
    stretched = max(2, int(round(width * factor)))
    if stretched == width:
        return img
    return img.resize((stretched, height), resample=Image.BICUBIC)


def random_affine(
    img: Image.Image,
    degrees: float = 2.0,
    translate: float = 0.01,
    shear: float = 2.0,
) -> Image.Image:
    """Apply one random mild affine transform to a PIL image, filling with white.

    - rotation: uniform in [-degrees, degrees]
    - translation: whole pixels, up to `translate` x width / height per axis
    - shear: x-shear uniform in [-shear, shear]; no y-shear

    Any component whose parameter is zero, negative or falsy is left out.
    Scale is fixed at 1.0 and interpolation is bilinear.
    """
    width, height = img.size

    angle = 0.0
    if degrees and degrees > 0:
        angle = random.uniform(-float(degrees), float(degrees))

    span_x = float(translate) * width
    span_y = float(translate) * height
    shift_x = shift_y = 0
    if translate and translate > 0:
        shift_x = int(round(random.uniform(-span_x, span_x)))
        shift_y = int(round(random.uniform(-span_y, span_y)))

    shear_x = 0.0
    if shear and shear > 0:
        shear_x = random.uniform(-float(shear), float(shear))

    # torchvision expects shear as sequence [sx, sy]
    return TF.affine(
        img,
        angle=angle,
        translate=[shift_x, shift_y],
        scale=1.0,
        shear=[shear_x, 0.0],
        interpolation=transforms.InterpolationMode.BILINEAR,
        fill=(255, 255, 255),
    )


def random_stroke(
    img: Image.Image,
    k: int | None = None,
    kmin: int = 3,
    kmax: int = 5,
    **_kwargs,
) -> Image.Image:
    """Randomly thicken or thin strokes using morphological max/min filters.

    Compatibility:
      - supports calling as random_stroke(img, k=<odd_int>)
      - supports calling as random_stroke(img, kmin=..., kmax=...)

    Kernel size:
      - explicit `k`: used as given; an even value is bumped to the next odd
        one, and k < 3 returns the image unchanged.
      - otherwise: drawn uniformly from the ODD sizes in [max(3, kmin), kmax],
        so the result never exceeds kmax and no size is favoured. Only when
        that range holds no odd size (a single even value) is that value + 1
        used. kmax < 3 returns the image unchanged.

    MaxFilter or MinFilter is then chosen with equal probability. Extra
    keyword arguments are accepted and ignored.
    """
    kmin = int(kmin)
    kmax = int(kmax)
    if kmax < kmin:
        kmin, kmax = kmax, kmin

    if k is None:
        if kmax < 3:
            return img
        lo, hi = max(3, kmin), max(3, kmax)
        # PIL Max/MinFilter require odd kernel size >= 3; `lo | 1` is the first odd >= lo
        odd_sizes = range(lo | 1, hi + 1, 2)
        k = random.choice(odd_sizes) if odd_sizes else lo + 1
    else:
        k = int(k)
        if k < 3:
            return img
        if k % 2 == 0:
            k += 1

    # Randomly choose between the two rank filters
    if random.random() < 0.5:
        return img.filter(ImageFilter.MaxFilter(size=k))
    return img.filter(ImageFilter.MinFilter(size=k))



def random_sharpen(
    img: Image.Image,
    radius_min: float = 1.0,
    radius_max: float = 2.0,
    pct_min: int = 50,
    pct_max: int = 150,
) -> Image.Image:
    """Sharpen with an unsharp mask whose radius and percent are drawn at random.

    The radius is uniform in ``[radius_min, radius_max]`` and the percent is a
    random integer in ``[pct_min, pct_max]``; the mask threshold is fixed at 3.
    """
    radius = random.uniform(float(radius_min), float(radius_max))
    percent = random.randint(int(pct_min), int(pct_max))
    mask = ImageFilter.UnsharpMask(radius=radius, percent=percent, threshold=3)
    return img.filter(mask)


def add_gaussian_noise(x: torch.Tensor, std: float = 0.02) -> torch.Tensor:
    """Return ``x`` plus zero-mean gaussian noise, clamped to [0, 1].

    Intended for image tensors in [0, 1] before normalization. A non-positive
    ``std`` returns the input object unchanged; non-floating inputs are cast
    to float before the noise is added.
    """
    sigma = float(std)
    if sigma <= 0:
        return x
    base = x if torch.is_floating_point(x) else x.float()
    noisy = base + sigma * torch.randn_like(base)
    return noisy.clamp(0.0, 1.0)

def _white_pad_value(normalize: str) -> Tuple[float, float, float]:
    """Tensor-space value for a pure white pixel after normalization."""
    norm = (normalize or "").lower()
    if norm in ("none", "no", "false", "0"):
        return (1.0, 1.0, 1.0)
    if norm == "half":
        return (1.0, 1.0, 1.0)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    return tuple((1.0 - m) / s for m, s in zip(mean, std))


class LineCTCTransformDynamic:
    """CTC transform with dynamic width (cap + per-sample padding).

    Per call, in order: convert to RGB, optionally trim horizontal whitespace,
    apply geometric augmentation (width stretch, affine), resize/pad through
    ``LineResizePadDynamic``, apply appearance augmentation (stroke, sharpen,
    invert), convert to a tensor, optionally add gaussian noise, and finally
    normalize when a normalizer is configured.

    Calling the transform returns ``(tensor, meta)`` where ``meta`` contains:
      - ``valid_w``: content width estimated from ink when ``estimate_valid_w``
        is set, otherwise the full width of the resized/padded image
      - ``target_w``: width of the resized/padded image
      - ``pad_value``: per-channel tensor value to use when the image is
        extended to the right later (white normally, black for inverted samples)
    """

    def __init__(
        self,
        target_height: int = 128,
        max_width: int = 2048,
        *,
        pad_to_multiple: int = 32,
        right_pad_min: int = 8,
        trim_whitespace: bool = True,
        trim_margin: int = 2,
        normalize: str = "half",
        estimate_valid_w: bool = True,
        bg_thresh: int = 250,
        # augmentation knobs (same as create_line_transform)
        aug_affine_p: float = 0.15,
        aug_wstretch_p: float = 0.0,
        aug_wstretch_min: float = 0.7,
        aug_wstretch_max: float = 1.3,
        aug_degrees: float = 2.0,
        aug_translate: float = 0.01,
        aug_shear: float = 2.0,
        aug_stroke_p: float = 0.25,
        aug_stroke_kmin: int = 3,
        aug_stroke_kmax: int = 5,
        aug_sharpen_p: float = 0.15,
        aug_invert_p: float = 0.0,
        aug_noise_p: float = 0.15,
        aug_noise_std: float = 0.03,
    ) -> None:
        """Store the configuration and build the resize/pad and tensor helpers.

        Every ``aug_*_p`` argument is the per-sample probability of applying
        that augmentation; a value of 0 disables it.
        """
        self.normalize = normalize
        self.estimate_valid_w = bool(estimate_valid_w)
        self.bg_thresh = int(bg_thresh)

        self.trim_whitespace = bool(trim_whitespace)
        self.trim_margin = int(trim_margin)

        self.resize_pad = LineResizePadDynamic(
            target_height=target_height,
            max_width=max_width,
            pad_to_multiple=pad_to_multiple,
            right_pad_min=right_pad_min,
        )

        self.aug_affine_p = float(aug_affine_p)
        self.aug_wstretch_p = float(aug_wstretch_p)
        self.aug_wstretch_min = float(aug_wstretch_min)
        self.aug_wstretch_max = float(aug_wstretch_max)
        self.aug_degrees = float(aug_degrees)
        self.aug_translate = float(aug_translate)
        self.aug_shear = float(aug_shear)

        self.aug_stroke_p = float(aug_stroke_p)
        self.aug_stroke_kmin = int(aug_stroke_kmin)
        self.aug_stroke_kmax = int(aug_stroke_kmax)
        self.aug_sharpen_p = float(aug_sharpen_p)
        self.aug_invert_p = float(aug_invert_p)

        self.aug_noise_p = float(aug_noise_p)
        self.aug_noise_std = float(aug_noise_std)

        self._to_tensor = transforms.ToTensor()
        self._norm = _get_normalize(normalize)
        self._pad_value = _white_pad_value(normalize)

    def _black_pad_value(self) -> Tuple[float, ...]:
        """Tensor-space value of a pure black RGB pixel under this transform's normalizer."""
        if self._norm is None:
            return (0.0, 0.0, 0.0)
        # Push a single black pixel through the same normalizer the image gets.
        black = self._norm(torch.zeros(3, 1, 1))
        return tuple(float(v) for v in black.reshape(-1).tolist())

    def __call__(self, img: Image.Image) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Transform one line image into a tensor plus width metadata.

        Args:
            img: PIL image of a text line; converted to RGB when needed.

        Returns:
            ``(x, meta)`` with ``x`` the image tensor and ``meta`` a dict holding
            ``valid_w``, ``target_w`` and ``pad_value``.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")

        if self.trim_whitespace:
            img = trim_horizontal_whitespace(img, bg_thresh=self.bg_thresh, margin=self.trim_margin)

        # geometric aug (pre-resize)
        if self.aug_wstretch_p > 0 and random.random() < self.aug_wstretch_p:
            img = random_width_stretch(img, self.aug_wstretch_min, self.aug_wstretch_max)

        if self.aug_affine_p > 0 and random.random() < self.aug_affine_p:
            img = random_affine(img, degrees=self.aug_degrees, translate=self.aug_translate, shear=self.aug_shear)

        # resize + pad (dynamic width)
        img = self.resize_pad(img)

        # appearance aug (post-resize)
        if self.aug_stroke_p > 0 and random.random() < self.aug_stroke_p:
            # random.randint raises ValueError on a reversed range, so order the
            # configured bounds before sampling.
            k_lo, k_hi = sorted((self.aug_stroke_kmin, self.aug_stroke_kmax))
            img = random_stroke(img, k=random.randint(k_lo, k_hi))
        if self.aug_sharpen_p > 0 and random.random() < self.aug_sharpen_p:
            img = random_sharpen(img)
        # Decide on inversion here (keeps the random draw order), apply it below.
        invert = self.aug_invert_p > 0 and random.random() < self.aug_invert_p

        # valid_w from ink after all geometry-affecting steps, but before the
        # inversion: inversion turns the light background dark, so the
        # bg_thresh-based estimate is taken on the original polarity.
        if self.estimate_valid_w:
            valid_w = estimate_valid_width_from_ink(img, bg_thresh=self.bg_thresh)
        else:
            valid_w = img.width

        if invert:
            img = ImageOps.invert(img)
            # The background is now black, so later right-padding must be black too.
            pad_value = self._black_pad_value()
        else:
            pad_value = self._pad_value

        x = self._to_tensor(img)

        # Noise is added in [0, 1] space, i.e. before normalization.
        if self.aug_noise_p > 0 and random.random() < self.aug_noise_p:
            x = add_gaussian_noise(x, std=self.aug_noise_std)

        if self._norm is not None:
            x = self._norm(x)

        meta_u = {
            "valid_w": int(valid_w),
            "target_w": int(img.width),
            "pad_value": tuple(float(v) for v in pad_value),
        }
        return x, meta_u


def create_line_transform_dynamic(
    target_height: int = 128,
    max_width: int = 2048,
    *,
    pad_to_multiple: int = 32,
    right_pad_min: int = 8,
    trim_whitespace: bool = True,
    trim_margin: int = 2,
    normalize: str = "half",
    estimate_valid_w: bool = True,
    bg_thresh: int = 250,
    # augmentation
    aug_affine_p: float = 0.15,
    aug_wstretch_p: float = 0.0,
    aug_wstretch_min: float = 0.7,
    aug_wstretch_max: float = 1.3,
    aug_degrees: float = 2.0,
    aug_translate: float = 0.01,
    aug_shear: float = 2.0,
    aug_stroke_p: float = 0.25,
    aug_stroke_kmin: int = 3,
    aug_stroke_kmax: int = 5,
    aug_sharpen_p: float = 0.15,
    aug_invert_p: float = 0.0,
    aug_noise_p: float = 0.15,
    aug_noise_std: float = 0.03,
) -> LineCTCTransformDynamic:
    """Build a ``LineCTCTransformDynamic``, forwarding every option unchanged."""
    # Resize / pad / trim geometry.
    layout = dict(
        target_height=target_height,
        max_width=max_width,
        pad_to_multiple=pad_to_multiple,
        right_pad_min=right_pad_min,
        trim_whitespace=trim_whitespace,
        trim_margin=trim_margin,
    )
    # Tensor normalization and valid-width estimation.
    tensor_opts = dict(
        normalize=normalize,
        estimate_valid_w=estimate_valid_w,
        bg_thresh=bg_thresh,
    )
    # Augmentation probabilities and their parameters.
    augment = dict(
        aug_affine_p=aug_affine_p,
        aug_wstretch_p=aug_wstretch_p,
        aug_wstretch_min=aug_wstretch_min,
        aug_wstretch_max=aug_wstretch_max,
        aug_degrees=aug_degrees,
        aug_translate=aug_translate,
        aug_shear=aug_shear,
        aug_stroke_p=aug_stroke_p,
        aug_stroke_kmin=aug_stroke_kmin,
        aug_stroke_kmax=aug_stroke_kmax,
        aug_sharpen_p=aug_sharpen_p,
        aug_invert_p=aug_invert_p,
        aug_noise_p=aug_noise_p,
        aug_noise_std=aug_noise_std,
    )
    return LineCTCTransformDynamic(**layout, **tensor_opts, **augment)


def _pad_right_tensor(
    x: torch.Tensor, target_w: int, pad_value: Optional[Sequence[float]] = None
) -> torch.Tensor:
    c, h, w = x.shape
    if w >= target_w:
        return x
    delta = int(target_w - w)
    if pad_value is None:
        pad = x[:, :, -1:].expand(c, h, delta).clone()
        return torch.cat([x, pad], dim=2)
    if isinstance(pad_value, (list, tuple)) and len(pad_value) == c:
        pad = x.new_empty((c, h, delta))
        for i in range(c):
            pad[i].fill_(float(pad_value[i]))
    else:
        pad = x.new_full((c, h, delta), float(pad_value[0] if isinstance(pad_value, (list, tuple)) else pad_value))
    return torch.cat([x, pad], dim=2)


def ctc_collate_pad_to_max(
    batch: Sequence[Tuple[torch.Tensor, torch.Tensor, str, Dict[str, Any]]]
):
    """CTC collate that pads variable-width images to the max width within the batch.

    Args:
        batch: samples of ``(image_tensor, target_ids_1d, text, meta_dict)``.
            ``meta_dict`` may carry ``pad_value`` (fill for the added columns)
            and ``valid_w`` (content width of the sample).

    Returns:
        ``(images, targets_concat, target_lengths, texts, extra)`` where
        ``images`` is the stacked right-padded batch and ``extra`` holds the
        long tensors ``valid_w`` and ``target_w`` (the batch max width repeated
        once per sample).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("Empty batch")

    images, targets, texts, metas = zip(*batch)

    max_w = max(int(im.shape[-1]) for im in images)

    padded_images: List[torch.Tensor] = []
    valid_ws: List[int] = []
    for im, meta in zip(images, metas):
        padded_images.append(_pad_right_tensor(im, max_w, meta.get("pad_value")))
        # Without an explicit valid_w, fall back to the sample's own pre-pad
        # width; using the batch max would count the padding added here as
        # content for narrower samples.
        valid_ws.append(int(meta.get("valid_w", im.shape[-1])))

    images_t = torch.stack(padded_images, dim=0)

    target_lengths = torch.tensor([t.numel() for t in targets], dtype=torch.long)
    targets_concat = torch.cat(targets, dim=0) if len(targets) > 1 else targets[0]

    extra = {
        "valid_w": torch.tensor(valid_ws, dtype=torch.long),
        # IMPORTANT: target_w must match the width that produced this batch's logit time-dim
        "target_w": torch.full((len(valid_ws),), max_w, dtype=torch.long),
    }
    return images_t, targets_concat, target_lengths, list(texts), extra
