#!/usr/bin/env python3
# -*- coding: utf-8 -*-
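"""
Data loaders for the READ2016 line-level dataset.

Expected directory layout::

    root/
      train.ln
      val.ln
      test.ln
      lines/
        (A) flat:   lines/train_0.jpeg, lines/train_0.txt, lines/valid_0.jpeg, lines/valid_0.txt, ...
        (B) nested: lines/img/*.jpeg + lines/img/*.txt
"""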

from __future__ import annotations

import os
import pickle
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from .common import (
    CharsetMapper,
    DistributedShardSampler,
    create_line_transform,
    ctc_collate,
    normalize_transcription,
)

try:
    from .common import create_line_transform_dynamic, ctc_collate_pad_to_max  # type: ignore
    _HAS_DYNAMIC = True
except Exception:  # pragma: no cover
    create_line_transform_dynamic = None  # type: ignore
    ctc_collate_pad_to_max = None  # type: ignore
    _HAS_DYNAMIC = False


# File extensions (lower-case, with the leading dot) that are treated as line images.
# Callers lower-case a file's extension before testing membership in this tuple.
_IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def _is_dist() -> bool:
    """Return True only when torch.distributed is available and initialized."""
    dist = torch.distributed
    if not dist.is_available():
        return False
    return dist.is_initialized()


def _read_list(path: str) -> List[str]:
    """Read a split list file and return its non-empty entries.

    Each line is stripped of surrounding whitespace and of any remaining
    carriage returns; lines that end up empty are dropped. Undecodable bytes
    are ignored while reading.

    Raises:
        FileNotFoundError: if ``path`` is not an existing regular file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"READ2016 split file not found: {path}")
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        cleaned = (raw.strip().replace("\r", "") for raw in fh)
        return [entry for entry in cleaned if entry]


def _pick_img_dir(lines_root: str) -> str:
    """Choose the directory that holds the line images.

    Returns ``<lines_root>/img`` when that directory exists and contains at
    least one file with a known image extension; otherwise returns
    ``lines_root`` itself.
    """
    nested = os.path.join(lines_root, "img")
    if not os.path.isdir(nested):
        return lines_root
    has_image = any(
        os.path.splitext(entry)[1].lower() in _IMG_EXTS for entry in os.listdir(nested)
    )
    return nested if has_image else lines_root


def _build_stem_index(img_dir: str) -> Dict[str, str]:
    """Index the images directly inside ``img_dir`` by lower-cased file stem.

    Returns a dict mapping ``stem.lower()`` to ``os.path.join(img_dir, name)``
    for every entry whose extension is a known image extension. An empty dict
    is returned when ``img_dir`` is not a directory. If several files share a
    stem, the one listed last by ``os.listdir`` wins.
    """
    index: Dict[str, str] = {}
    if os.path.isdir(img_dir):
        for name in os.listdir(img_dir):
            stem, suffix = os.path.splitext(name)
            if suffix.lower() not in _IMG_EXTS:
                continue
            index[stem.lower()] = os.path.join(img_dir, name)
    return index


def _resolve_image(token: str, *, lines_root: str, img_dir: str, stem_index: Dict[str, str]) -> str:
    """Resolve one split-list token to the path of an existing line image.

    Resolution order:
      1. the token as a relative path, then its basename, under ``lines_root``
         and then under ``img_dir``;
      2. if the token has no extension, the basename plus each known image
         extension under ``img_dir`` and ``lines_root``;
      3. a lookup of the lower-cased stem in ``stem_index``.

    Args:
        token: entry read from a split list file (may use ``\\`` separators
            and may start with ``./``, ``../`` or ``/``).
        lines_root: the dataset's ``lines`` directory.
        img_dir: directory that actually holds the images.
        stem_index: mapping of lower-cased stem -> image path.

    Returns:
        The first candidate path that is an existing file.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    tok = token.strip().replace("\\", "/").replace("\r", "")
    # Drop leading "", "." and ".." path components only. The previous
    # `tok.lstrip("./")` stripped a *character set*, which also ate the leading
    # dot of real file names (".scan.jpg" -> "scan.jpg").
    parts = tok.split("/")
    while parts and parts[0] in ("", ".", ".."):
        parts.pop(0)
    tok = "/".join(parts)

    base = os.path.basename(tok)
    base_stem, base_ext = os.path.splitext(base)

    candidates: List[str] = []

    # 1) Use the token as-is as a relative path (relative to lines_root or img_dir).
    for b in (lines_root, img_dir):
        candidates.append(os.path.join(b, tok))
        candidates.append(os.path.join(b, base))

    # 2) If the token has no extension, try the common image extensions.
    if base_ext == "":
        for ext in _IMG_EXTS:
            for b in (img_dir, lines_root):
                candidates.append(os.path.join(b, base_stem + ext))

    for p in candidates:
        if os.path.isfile(p):
            return p

    # 3) Fall back to the case-insensitive stem index.
    indexed = stem_index.get(base_stem.lower())
    if indexed is not None and os.path.isfile(indexed):
        return indexed

    raise FileNotFoundError(
        f"READ2016 line image not found for token='{token}'. "
        f"Tried under lines_root={lines_root} and img_dir={img_dir}."
    )


def _load_labels_pkl(lines_root: str) -> Dict[str, str]:
    """Load transcriptions from ``<lines_root>/labels.pkl`` when it exists.

    Expected pickle layout:
    ``{'ground_truth': {split: {img_name: {'text': ...}}}, ...}``.

    Returns:
        Mapping of lower-cased image stem -> text, merged across all splits.
        An empty dict is returned when the file is missing, cannot be
        unpickled, or does not have the expected shape; malformed entries are
        skipped individually.
    """
    pkl_path = os.path.join(lines_root, "labels.pkl")
    if not os.path.isfile(pkl_path):
        return {}

    # NOTE: unpickling executes arbitrary code; only use with trusted dataset files.
    try:
        with open(pkl_path, "rb") as fh:
            payload = pickle.load(fh)
    except Exception:
        # Best effort: an unreadable labels file is treated as absent.
        return {}

    ground_truth = payload.get("ground_truth", {}) if isinstance(payload, dict) else {}
    labels: Dict[str, str] = {}
    if not isinstance(ground_truth, dict):
        return labels

    for per_split in ground_truth.values():
        if not isinstance(per_split, dict):
            continue
        for name, record in per_split.items():
            if not (isinstance(name, str) and isinstance(record, dict)):
                continue
            text = record.get("text", "")
            if isinstance(text, str):
                key = os.path.splitext(os.path.basename(name))[0].lower()
                labels[key] = text
    return labels


def _normalize_txt(
    txt: str,
    *,
    text_norm: bool,
    text_norm_form: str,
    text_collapse_ws: bool,
    text_drop_chars: str,
    allow_empty: bool,
    max_label_len: int,
) -> Optional[str]:
    """Normalize a raw transcription and decide whether the sample is usable.

    With ``text_norm`` the text goes through ``normalize_transcription``
    (using the given form / whitespace / drop-chars settings); otherwise it is
    only stripped. A falsy ``txt`` is treated as the empty string.

    Returns:
        The cleaned text, a single space for an empty text when
        ``allow_empty`` is set, or ``None`` when the sample must be skipped
        (empty and not allowed, or longer than a positive ``max_label_len``).
    """
    text = txt or ""
    if not text_norm:
        text = text.strip()
    else:
        text = normalize_transcription(
            text,
            form=text_norm_form,
            collapse_whitespace=text_collapse_ws,
            strip=True,
            drop_chars=text_drop_chars,
        )

    if text == "":
        if not allow_empty:
            return None
        # Empty labels are represented by one space so the target is never zero-length.
        text = " "

    if 0 < max_label_len < len(text):
        return None
    return text


def _read_text_for_image(
    img_path: str,
    *,
    lines_root: str,
    text_source: str,
    pkl_map: Dict[str, str],
    text_norm: bool,
    text_norm_form: str,
    text_collapse_ws: bool,
    text_drop_chars: str,
    allow_empty: bool,
    max_label_len: int,
) -> Optional[str]:
    """Find and normalize the transcription that belongs to ``img_path``.

    Lookup order depends on ``text_source``:
      * ``"txt"``  - ``<stem>.txt`` next to the image, then under ``lines_root``;
      * ``"pkl"``  - ``pkl_map`` keyed by the lower-cased stem;
      * ``"auto"`` - the txt lookup first, then ``pkl_map``.

    The remaining keyword arguments are forwarded unchanged to
    ``_normalize_txt``.

    Returns:
        The normalized text, or ``None`` when no transcription was found or
        ``_normalize_txt`` rejected it.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]

    raw: Optional[str] = None
    if text_source in ("auto", "txt"):
        txt_paths = (
            os.path.join(os.path.dirname(img_path), stem + ".txt"),
            os.path.join(lines_root, stem + ".txt"),
        )
        for tp in txt_paths:
            if os.path.isfile(tp):
                # Context manager so the handle is closed deterministically;
                # this runs once per sample while scanning a whole split.
                with open(tp, "r", encoding="utf-8", errors="ignore") as f:
                    raw = f.read()
                break

    if raw is None and text_source in ("auto", "pkl"):
        raw = pkl_map.get(stem.lower())

    if raw is None:
        return None

    return _normalize_txt(
        raw,
        text_norm=text_norm,
        text_norm_form=text_norm_form,
        text_collapse_ws=text_collapse_ws,
        text_drop_chars=text_drop_chars,
        allow_empty=allow_empty,
        max_label_len=max_label_len,
    )


@dataclass
class READLineSample:
    """One usable line sample: an image path paired with its transcription."""

    # Path of the resolved line image file.
    img_path: str
    # Transcription after normalization/filtering.
    text: str
    # Sample id: the image file name without its extension.
    sid: str


class READ2016LineDataset(Dataset):
    """Map-style dataset over READ2016 line samples.

    Each item is ``(img, target, text, meta)`` where ``img`` is the (optionally
    transformed) RGB image, ``target`` is ``charset.encode(text, oov=...)``,
    ``text`` is the sample's transcription and ``meta`` is a dict holding at
    least ``id``, ``raw_w``, ``raw_h``, ``target_w`` and ``valid_w``.
    """

    def __init__(
        self,
        samples: List[READLineSample],
        *,
        charset: CharsetMapper,
        transform=None,
        oov: str = "error",
    ):
        """Store the sample list, the charset mapper, the transform and the OOV policy.

        Args:
            samples: samples to serve, in index order.
            charset: mapper whose ``encode`` turns text into a target.
            transform: optional callable applied to the PIL image; it may
                return either the transformed image or an
                ``(image, meta_dict)`` pair.
            oov: out-of-vocabulary policy passed through to ``charset.encode``.
        """
        self.samples = samples
        self.cm = charset
        self.transform = transform
        self.oov = str(oov)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load, transform and encode the sample at ``idx``."""
        s = self.samples[idx]
        # Close the source file deterministically; convert() returns an
        # independent, fully loaded image, so it is safe to use afterwards.
        with Image.open(s.img_path) as src:
            img = src.convert("RGB")
        meta: Dict[str, Any] = {"id": s.sid, "raw_w": int(img.size[0]), "raw_h": int(img.size[1])}

        if self.transform is not None:
            out = self.transform(img)
            if isinstance(out, (tuple, list)) and len(out) == 2:
                # Transform returned (image, extra metadata).
                img_t, meta_u = out
                if isinstance(meta_u, dict):
                    meta.update(meta_u)
                img = img_t
            else:
                img = out

        if "target_w" not in meta:
            try:
                # Last dimension of the transformed image is taken as its width.
                meta["target_w"] = int(img.shape[-1])  # type: ignore[attr-defined]
            except Exception:
                # Best effort: e.g. an untransformed PIL image has no `.shape`.
                meta["target_w"] = 1
        if "valid_w" not in meta:
            # Without an estimate from the transform, the whole width counts as valid.
            meta["valid_w"] = int(meta["target_w"])

        target = self.cm.encode(s.text, oov=self.oov)
        return img, target, s.text, meta


def _cap_width_by_ctc(args, req_w: int, logger=None) -> int:
    """Limit the requested image width to ``max_seq_len * downsample_w``.

    The cap only applies when both ``args.max_seq_len`` and
    ``args.downsample_w`` are present and positive. Missing, ``None`` or zero
    values disable it.

    Args:
        args: config object read via ``getattr``.
        req_w: requested maximum image width.
        logger: optional logger; receives a warning when the width is capped
            or when the two settings cannot be parsed as integers.

    Returns:
        ``req_w`` as an int, reduced to the cap when it exceeds it.
    """
    req_w = int(req_w)
    try:
        max_seq_len = int(getattr(args, "max_seq_len", 0) or 0)
        downsample_w = int(getattr(args, "downsample_w", 0) or 0)
    except (TypeError, ValueError, OverflowError) as exc:
        # Keep the best-effort behaviour (no cap) but make the bad config visible
        # instead of swallowing it silently.
        if logger:
            logger.warning(
                f"[READ2016] could not parse max_seq_len/downsample_w ({exc}); "
                f"img_max_width={req_w} left uncapped"
            )
        return req_w

    if max_seq_len <= 0 or downsample_w <= 0:
        return req_w

    cap = max_seq_len * downsample_w
    if req_w <= cap:
        return req_w
    if logger:
        logger.warning(f"[READ2016] img_max_width={req_w} capped to {cap} (max_seq_len*downsample_w)")
    return cap


def _resolve_split_file(root: str, name: str) -> str:
    """Locate the ``.ln`` list file for split ``name``.

    Tries ``<root>/<name>.ln`` first, then the known spellings of the
    validation split (``val`` / ``valid`` / ``validation``), and finally
    ``name`` itself as a path to an existing file.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    validation_aliases = {
        "val": ["valid", "validation"],
        "valid": ["val", "validation"],
        "validation": ["val", "valid"],
    }
    for stem in [name, *validation_aliases.get(name, [])]:
        path = os.path.join(root, f"{stem}.ln")
        if os.path.isfile(path):
            return path

    # Last resort: the caller passed a direct path to the list file.
    if os.path.isfile(name):
        return name
    raise FileNotFoundError(f"READ2016 split file not found for split='{name}' under root={root}")


def _line_transform_kwargs(args, *, target_h: int, train: bool) -> Dict[str, Any]:
    """Collect the keyword arguments shared by both line-transform factories.

    For ``train=True`` the augmentation probabilities/strengths are read from
    ``args`` (with the module's defaults); for ``train=False`` they are all
    forced to ``0.0`` so evaluation is augmentation-free. The width keyword is
    deliberately left out because the fixed-width and dynamic-width factories
    name it differently.
    """

    def _aug(name: str, default: float) -> float:
        return float(getattr(args, name, default)) if train else 0.0

    return {
        "target_height": target_h,
        "normalize": getattr(args, "normalize", "half"),
        "estimate_valid_w": bool(getattr(args, "estimate_valid_w", True)),
        "bg_thresh": int(getattr(args, "bg_thresh", 250)),
        "aug_affine_p": _aug("aug_affine_p", 0.12),
        "aug_wstretch_p": _aug("aug_wstretch_p", 0.25),
        "aug_wstretch_min": float(getattr(args, "aug_wstretch_min", 0.7)),
        "aug_wstretch_max": float(getattr(args, "aug_wstretch_max", 1.3)),
        "aug_degrees": _aug("aug_degrees", 2.0),
        "aug_translate": _aug("aug_translate", 0.01),
        "aug_shear": _aug("aug_shear", 2.0),
        "aug_stroke_p": _aug("aug_stroke_p", 0.25),
        "aug_stroke_kmin": int(getattr(args, "aug_stroke_kmin", 3)),
        "aug_stroke_kmax": int(getattr(args, "aug_stroke_kmax", 5)),
        "aug_sharpen_p": _aug("aug_sharpen_p", 0.15),
        "aug_invert_p": _aug("aug_invert_p", 0.0),
        "aug_noise_p": _aug("aug_noise_p", 0.15),
        "aug_noise_std": float(getattr(args, "aug_noise_std", 0.03)),
    }


def build_read2016_dataloaders(args, logger=None):
    """Build the READ2016 train/validation data loaders and the charset.

    All settings are read from ``args`` with ``getattr`` and a default, so any
    attribute may be absent. The dataset root is ``args.data_path`` (or a
    built-in default path) and must contain a ``lines`` directory plus the
    split list files.

    Args:
        args: config object (e.g. an argparse namespace).
        logger: optional logger for informational and warning messages.

    Returns:
        ``(train_loader, val_loader, charset_mapper, blank_id, num_classes)``
        where ``blank_id`` is always ``0`` and ``num_classes`` is
        ``charset_mapper.num_classes``.

    Raises:
        FileNotFoundError: if the ``lines`` directory, the train/val split
            file, or an image listed in a split cannot be found.
    """
    default_root = "/workspace/data/OCR/READ2016"
    root = getattr(args, "data_path", "") or default_root

    lines_root = os.path.join(root, "lines")
    if not os.path.isdir(lines_root):
        raise FileNotFoundError(f"READ2016 lines_root not found: {lines_root}")

    split_train = str(getattr(args, "train_split", "train"))
    split_val = str(getattr(args, "val_split", "val"))
    split_test = str(getattr(args, "test_split", "test"))

    train_file = _resolve_split_file(root, split_train)
    val_file = _resolve_split_file(root, split_val)
    # The test split is optional, so it is not resolved through aliases.
    test_file = os.path.join(root, f"{split_test}.ln")

    img_dir = _pick_img_dir(lines_root)
    stem_index = _build_stem_index(img_dir)

    text_source = str(getattr(args, "read2016_text_source", "auto")).lower().strip()
    if text_source not in ("auto", "txt", "pkl"):
        text_source = "auto"
    pkl_map = _load_labels_pkl(lines_root) if text_source in ("auto", "pkl") else {}

    allow_empty = bool(getattr(args, "allow_empty", False))
    text_norm = bool(getattr(args, "text_norm", True))
    text_norm_form = getattr(args, "text_norm_form", "NFKC")
    text_collapse_ws = bool(getattr(args, "text_collapse_ws", True))
    text_drop_chars = getattr(args, "text_drop_chars", "¬")
    # `or 0` so an explicit None means "no limit" instead of crashing int().
    max_label_len = int(getattr(args, "max_label_len", 0) or 0)
    oov_policy = getattr(args, "oov", "error")

    def build_samples(list_path: str) -> List[READLineSample]:
        """Resolve every token of a split list; samples without usable text are skipped."""
        samples: List[READLineSample] = []
        for tok in _read_list(list_path):
            ip = _resolve_image(tok, lines_root=lines_root, img_dir=img_dir, stem_index=stem_index)
            txt = _read_text_for_image(
                ip,
                lines_root=lines_root,
                text_source=text_source,
                pkl_map=pkl_map,
                text_norm=text_norm,
                text_norm_form=text_norm_form,
                text_collapse_ws=text_collapse_ws,
                text_drop_chars=text_drop_chars,
                allow_empty=allow_empty,
                max_label_len=max_label_len,
            )
            if txt is None:
                continue
            sid = os.path.splitext(os.path.basename(ip))[0]
            samples.append(READLineSample(img_path=ip, text=txt, sid=sid))
        return samples

    train_samples = build_samples(train_file)
    val_samples = build_samples(val_file)
    if os.path.isfile(test_file):
        # The result is discarded: this pass only surfaces unresolved test
        # images early. NOTE(review): confirm this validation pass is intended.
        build_samples(test_file)

    charset_from = str(getattr(args, "charset_from", "trainval"))
    charset_texts = [s.text for s in train_samples]
    if charset_from in ("trainval", "all"):
        charset_texts.extend(s.text for s in val_samples)
    cm = CharsetMapper.from_texts(charset_texts, sort=True, add_unk=bool(getattr(args, "add_unk", False)))

    target_h = int(getattr(args, "img_height", 64))
    req_w = int(getattr(args, "img_max_width", 1024))
    max_w = _cap_width_by_ctc(args, req_w, logger)

    dynamic_requested = bool(getattr(args, "dynamic_width", False))
    dynamic_width = dynamic_requested and _HAS_DYNAMIC
    if dynamic_requested and not _HAS_DYNAMIC and logger:
        logger.warning("[READ2016] dynamic_width requested but dynamic helpers are unavailable; using fixed width")

    train_kwargs = _line_transform_kwargs(args, target_h=target_h, train=True)
    eval_kwargs = _line_transform_kwargs(args, target_h=target_h, train=False)
    if dynamic_width:
        transform_train = create_line_transform_dynamic(max_width=max_w, **train_kwargs)
        transform_eval = create_line_transform_dynamic(max_width=max_w, **eval_kwargs)
        collate_fn = ctc_collate_pad_to_max
    else:
        transform_train = create_line_transform(target_width=max_w, **train_kwargs)
        transform_eval = create_line_transform(target_width=max_w, **eval_kwargs)
        collate_fn = ctc_collate

    ds_train = READ2016LineDataset(train_samples, charset=cm, transform=transform_train, oov=oov_policy)
    ds_val = READ2016LineDataset(val_samples, charset=cm, transform=transform_eval, oov=oov_policy)

    if logger:
        logger.info(f"READ2016 root       = {root}")
        logger.info(f"  lines_root        = {lines_root}")
        logger.info(f"  img_dir           = {img_dir}")
        logger.info(f"  split_files       = train:{train_file} val:{val_file} test:{test_file}")
        logger.info(f"  picked            = train:{len(ds_train)} val:{len(ds_val)}")
        logger.info(f"  charset size      = {len(cm.charset)}, CTC classes = {cm.num_classes}")
        logger.info(
            f"  text_norm={text_norm} form={text_norm_form} collapse_ws={text_collapse_ws} drop_chars={repr(text_drop_chars)}"
        )
        logger.info(f"  dynamic_width={dynamic_width} img_height={target_h} img_max_width={max_w} (requested={req_w})")

    batch_size = int(getattr(args, "batch_size", 8))
    val_batch_size = int(getattr(args, "val_batch_size", batch_size))
    workers = int(getattr(args, "workers", 4))
    pin_mem = bool(getattr(args, "pin_mem", True))

    if _is_dist():
        # Under distributed training the samplers own the ordering, so the
        # loader-level shuffle must be off.
        train_sampler = DistributedSampler(ds_train, shuffle=True, drop_last=True)
        val_sampler = DistributedShardSampler(ds_val, shuffle=False)
        shuffle_train = False
    else:
        train_sampler = None
        val_sampler = None
        shuffle_train = True

    train_loader = DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=shuffle_train,
        sampler=train_sampler,
        num_workers=workers,
        pin_memory=pin_mem,
        drop_last=True,
        collate_fn=collate_fn,
        # persistent_workers is only valid when worker processes are used.
        persistent_workers=(workers > 0),
    )

    val_loader = DataLoader(
        ds_val,
        batch_size=val_batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=workers,
        pin_memory=pin_mem,
        drop_last=False,
        collate_fn=collate_fn,
        persistent_workers=(workers > 0),
    )

    # NOTE(review): assumes CharsetMapper reserves index 0 for the CTC blank - confirm.
    blank_id = 0
    return train_loader, val_loader, cm, blank_id, cm.num_classes


# Alias exposing the READ2016 factory under the generic name `build_dataloaders`.
# NOTE(review): presumably so callers can use a dataset-agnostic entry point - confirm.
build_dataloaders = build_read2016_dataloaders
