#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""ocr_datasets.lam (LAM Lines) — optimized for dynamic width + strict CTC.

This dataset file supports **two split/label loading modes**:

1) custom (DEFAULT)
   - Uses list files at dataset root (recommended for your current workflow):
       <root>/train.ln, <root>/val.ln, <root>/test.ln
   - Each line is an image token, e.g.:
       img/002_02_00.jpg
       002_02_00.jpg
       /abs/path/to/002_02_00.jpg
   - Images are resolved under:
       <root>/lines/img/
       <root>/lines/
   - Labels are read from sidecar files next to images:
       <root>/lines/img/002_02_00.txt
      (If the sidecar is missing or empty and transcriptions.json exists, the label falls back to JSON when lam_text_source='auto'.)

2) legacy
   - Uses original LAM structure:
       <root>/lines/split/<protocol>/{train,val,test}.{txt|json}
       <root>/lines/transcriptions.json

How to switch:
  - args.lam_split_mode = 'custom' | 'legacy' | 'auto'
     * custom: use <root>/<split>.{ln,txt,json}, e.g. train.ln/val.ln (default; no legacy fallback)
    * legacy: use lines/split/<protocol>/... (original behavior)
    * auto  : try custom first; if missing, fall back to legacy
  - args.lam_text_source = 'auto' | 'txt' | 'json'
    * auto: prefer sidecar .txt, then fall back to transcriptions.json
    * txt : only sidecar .txt (missing => keep as single space)
    * json: only transcriptions.json text (ignores sidecar .txt)

Notes:
  - This module keeps the dynamic-width pipeline and CTC safety caps.
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from .common import (
    CharsetMapper,
    DistributedShardSampler,
    create_line_transform,
    ctc_collate,
    normalize_transcription,
)

# optional (only exists in optimized common)
try:
    from .common import create_line_transform_dynamic, ctc_collate_pad_to_max  # type: ignore

    _HAS_DYNAMIC = True
except Exception:  # pragma: no cover
    create_line_transform_dynamic = None  # type: ignore
    ctc_collate_pad_to_max = None  # type: ignore
    _HAS_DYNAMIC = False


def _is_dist() -> bool:
    """Return True only when torch.distributed is built in AND a process group is initialized."""
    dist = torch.distributed
    if not dist.is_available():
        return False
    return dist.is_initialized()


@dataclass
class LAMSample:
    """One record parsed from ``transcriptions.json``."""

    # sample identifier (explicit id from the JSON, or the image file stem)
    sid: str
    # image path as written in the JSON; resolved against lines/ by the builder
    rel_img: str
    # transcription text (whitespace-stripped, not yet normalized)
    text: str


def _safe_get(d: Dict[str, Any], keys: Tuple[str, ...]) -> Any:
    """Return ``d[k]`` for the first ``k`` in *keys* that is present with a non-None value.

    Keys are tried in the given order; returns None when none qualifies.
    Falsy-but-not-None values (``0``, ``""``, ``[]``) count as present.
    """
    return next((d[key] for key in keys if key in d and d[key] is not None), None)


def _norm_key(x: Any) -> str:
    """Reduce an id or path-like value to a lowercase lookup key.

    ``None`` and blank values map to ``""``. Otherwise backslashes are treated
    as path separators, and the key is the lowercased file stem. The last
    extension and any directory part are dropped.
    """
    text = "" if x is None else str(x).strip()
    if text == "":
        return ""
    stem, _ext = os.path.splitext(os.path.basename(text.replace("\\", "/")))
    return stem.lower()


def _parse_dataset_spec(spec: str) -> Tuple[str, str, str]:
    """Split a dataset spec into ``(dataset_name, protocol, split_name)``.

    Accepted forms::

        lam                                     -> ("lam", "basic", "")
        lam:basic                               -> ("lam", "basic", "")
        lam:leave_decade_out/leave_decade_4_out -> ("lam", "leave_decade_out", "leave_decade_4_out")

    Every part is whitespace-stripped. An empty/None spec means ``"lam"``.
    An empty protocol (``"lam:"``, ``"lam:/train"``) falls back to ``"basic"``,
    the same default as a spec without ``:``. This keeps the legacy split
    directory from collapsing to ``lines/split/``.

    An empty split_name means the caller should use args.train_split /
    args.val_split instead.
    """
    name = (spec or "").strip() or "lam"
    base, sep, rest = name.partition(":")
    base = base.strip() or "lam"
    if not sep:
        return base, "basic", ""
    protocol, _slash, split_name = rest.strip().partition("/")
    return base, protocol.strip() or "basic", split_name.strip()


def _resolve_split_file(split_dir: str, split_name: str, *, extra_dirs: Tuple[str, ...] = ()) -> str:
    """Locate the file that lists the samples of one split.

    Resolution rules:
      - an absolute ``split_name`` that exists is returned unchanged;
      - a name already ending in ``.ln``/``.txt``/``.json`` is looked up verbatim;
      - otherwise ``.ln``, ``.txt`` and ``.json`` are appended, in that order.

    Directories are searched as ``split_dir`` followed by the non-empty
    ``extra_dirs``. If ``split_name`` is ``val`` or ``validation``, the other
    spelling is tried afterwards in all directories.

    Raises:
        ValueError: ``split_name`` is empty.
        FileNotFoundError: no candidate exists.
    """
    if not split_name:
        raise ValueError("split_name is empty")

    if os.path.isabs(split_name) and os.path.exists(split_name):
        return split_name

    exts = (".ln", ".txt", ".json")
    search_dirs: List[str] = [split_dir] + [extra for extra in extra_dirs if extra]

    names = [split_name]
    if split_name == "val":
        names.append("validation")
    elif split_name == "validation":
        names.append("val")

    for name in names:
        has_ext = name.lower().endswith(exts)
        for folder in search_dirs:
            if has_ext:
                options = [os.path.join(folder, name)]
            else:
                options = [os.path.join(folder, name + suffix) for suffix in exts]
            for option in options:
                if os.path.exists(option):
                    return option

    raise FileNotFoundError(
        f"LAM split file not found: {split_name}. searched={search_dirs} (extensions tried: {exts})"
    )


def _load_split_ids(path: str) -> List[str]:
    """Read the list of sample tokens from a split file.

    Supported formats:
      - ``.ln`` / ``.txt``: one token per line; blank lines are skipped.
      - ``.json`` list: ``["id1", ...]``.
      - ``.json`` dict: the first non-None value under
        ``ids``/``items``/``samples``/``data``/``lines``. Failing that, the
        split key matching the file stem (e.g. ``"val"`` for ``val.json``) is
        tried first, then the first present of ``train``/``val``/``validation``/
        ``test``. This way a combined splits file does not hand the train list
        to every split.

    Files are decoded as ``utf-8-sig``. A BOM written by Windows tools
    therefore neither leaks a U+FEFF character into the first token nor makes
    ``json.load`` fail.

    Returns:
        Stripped, non-empty tokens as strings. The list is empty for JSON of
        any other shape.

    Raises:
        FileNotFoundError: ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    low = path.lower()
    if low.endswith((".txt", ".ln")):
        with open(path, "r", encoding="utf-8-sig", errors="ignore") as f:
            # Universal-newline text mode already turns "\r\n" into "\n".
            return [tok for tok in (line.strip() for line in f) if tok]

    with open(path, "r", encoding="utf-8-sig", errors="ignore") as f:
        obj = json.load(f)

    if isinstance(obj, dict):
        ids = _safe_get(obj, ("ids", "items", "samples", "data", "lines"))
        if ids is None:
            split_keys = ("train", "val", "validation", "test")
            stem = os.path.splitext(os.path.basename(path))[0].lower()
            ordered = ((stem,) if stem in split_keys else ()) + split_keys
            for k in ordered:
                if k in obj:
                    ids = obj[k]
                    break
        obj = ids

    if isinstance(obj, list):
        return [tok for tok in (str(x).strip() for x in obj) if tok]
    return []


def _parse_transcriptions(trans_json: str) -> List[LAMSample]:
    """Parse ``transcriptions.json`` into LAMSample records.

    The payload may be a list of item dicts, or a dict wrapping that list under
    ``items``/``samples``/``data``/``lines``. For each item, the first non-None
    value among several alias keys is used for the id, the image path, and the
    text. All three are stripped.

    A missing id is derived from the image file stem. Non-dict items and items
    without an image path (or still without an id) are skipped.

    The file is read as ``utf-8-sig`` so a leading BOM does not make
    ``json.load`` fail.
    """
    with open(trans_json, "r", encoding="utf-8-sig", errors="ignore") as f:
        obj = json.load(f)

    items = obj
    if isinstance(obj, dict):
        items = _safe_get(obj, ("items", "samples", "data", "lines")) or []
    if not isinstance(items, list):
        return []

    out: List[LAMSample] = []
    for it in items:
        if not isinstance(it, dict):
            continue
        sid = _safe_get(it, ("id", "sid", "name", "uid", "key"))
        img = _safe_get(it, ("img", "image", "path", "file", "filename", "file_name"))
        txt = _safe_get(it, ("text", "transcription", "label", "gt", "ground_truth", "unicode"))

        sid = str(sid).strip() if sid is not None else ""
        img = str(img).strip() if img is not None else ""
        txt = str(txt).strip() if txt is not None else ""

        if (not sid) and img:
            sid = os.path.splitext(os.path.basename(img))[0]
        if not sid or not img:
            continue

        out.append(LAMSample(sid=sid, rel_img=img, text=txt))
    return out


def _cap_width_by_ctc(args, req_w: int, logger=None) -> int:
    """Clamp the requested max image width to ``max_seq_len * downsample_w``.

    Both values are read from ``args``. The cap applies only when both are
    positive integers. Presumably ``max_seq_len`` is the largest CTC time-step
    count the model supports and ``downsample_w`` its horizontal stride
    (TODO confirm against the model config).

    Only parsing of the args is best-effort: unparseable values leave the width
    unchanged. The comparison and the logging run outside that guard, so an
    unrelated error can no longer silently skip the cap.

    Returns:
        ``int(req_w)`` or the cap, whichever is smaller (when a cap applies).
    """
    req_w = int(req_w)
    try:
        max_seq_len = int(getattr(args, "max_seq_len", 0) or 0)
        downsample_w = int(getattr(args, "downsample_w", 0) or 0)
    except (TypeError, ValueError):
        return req_w
    if max_seq_len <= 0 or downsample_w <= 0:
        return req_w

    cap = max_seq_len * downsample_w
    if req_w > cap:
        if logger:
            logger.warning(f"[LAM] img_max_width={req_w} capped to {cap} (max_seq_len*downsample_w)")
        return int(cap)
    return req_w


def _read_sidecar_text(img_path: str) -> str:
    """Return the raw contents of the ``.txt`` label stored next to an image.

    The sidecar path is the image path with its extension replaced by ``.txt``
    (``.../002_02_00.jpg`` -> ``.../002_02_00.txt``). Returns ``""`` when that
    file does not exist.

    The file is decoded as ``utf-8-sig``, so a BOM written by Windows editors is
    dropped instead of leaking a U+FEFF character into the label and hence into
    the charset. Undecodable bytes are ignored.
    """
    p = os.path.splitext(img_path)[0] + ".txt"
    try:
        with open(p, "r", encoding="utf-8-sig", errors="ignore") as f:
            return f.read()
    except FileNotFoundError:
        return ""


def _resolve_img_from_token(token: str, *, lines_root: str, img_dir: str) -> str:
    """Map one token from a split list to an existing image path.

    The token is stripped, backslashes become ``/``, and a single leading ``./``
    is removed. If the token is an existing absolute path, it is returned
    directly. Otherwise the first existing candidate wins:

      1) ``<lines_root>/<token>``
      2) ``<img_dir>/<token>``
      3) ``<img_dir>/<basename(token)>``
      4) extension-less tokens only: ``<img_dir>/<stem>`` + ``.jpg``/``.jpeg``/``.png``

    Raises:
        ValueError: the token is empty after stripping.
        FileNotFoundError: no candidate exists.
    """
    cleaned = (token or "").strip().replace("\\", "/")
    if not cleaned:
        raise ValueError("empty token")

    if os.path.isabs(cleaned) and os.path.exists(cleaned):
        return cleaned

    if cleaned.startswith("./"):
        cleaned = cleaned[2:]

    base_name = os.path.basename(cleaned)
    stem, ext = os.path.splitext(base_name)

    probes = [
        os.path.join(lines_root, cleaned),
        os.path.join(img_dir, cleaned),
        os.path.join(img_dir, base_name),
    ]
    if not ext:
        probes.extend(os.path.join(img_dir, stem + suffix) for suffix in (".jpg", ".jpeg", ".png"))

    for probe in probes:
        if os.path.exists(probe):
            return probe

    raise FileNotFoundError(f"LAM image not found for token='{token}' under {lines_root}")


class LAMLineDataset(Dataset):
    """Map-style dataset of text-line images and their encoded labels.

    ``__getitem__`` returns ``(image, target, text, meta)``:
      - ``image``: the transform output (the RGB PIL image when ``transform`` is None);
      - ``target``: ``charset.encode(text, oov=oov)``;
      - ``text``: the label, with an empty label replaced by a single space;
      - ``meta``: dict with ``id``, ``raw_w``, ``raw_h`` and always ``target_w``/``valid_w``.
    """

    def __init__(
        self,
        samples: List[Tuple[str, str, str]],
        *,
        charset: CharsetMapper,
        transform=None,
        oov: str = "error",
        max_label_len: int = 0,
    ):
        """
        Args:
            samples: ``(image_path, text, sample_id)`` triples.
            charset: mapper whose ``encode(text, oov=...)`` builds the target.
            transform: optional callable applied to the PIL image. It may return
                either the image/tensor or ``(image, meta_dict)``.
            oov: out-of-vocabulary policy forwarded to ``charset.encode``.
            max_label_len: if > 0, samples whose text is longer are dropped.
        """
        self.cm = charset
        self.transform = transform
        self.oov = str(oov)

        ml = int(max_label_len)
        # Always keep a private list so later mutation of the caller's list
        # cannot change this dataset's contents.
        if ml > 0:
            self.samples = [(p, t, sid) for (p, t, sid) in samples if len(t) <= ml]
        else:
            self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        img_path, text, sid = self.samples[idx]
        # Close the underlying file deterministically. convert() returns a
        # converted copy, so it stays valid after the handle is released.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        meta: Dict[str, Any] = {"id": sid, "raw_w": int(img.size[0]), "raw_h": int(img.size[1])}

        if self.transform is not None:
            out = self.transform(img)
            if isinstance(out, (tuple, list)) and len(out) == 2:
                img_t, meta_u = out
                if isinstance(meta_u, dict):
                    meta.update(meta_u)  # prefer transform's valid_w/target_w
                img = img_t
            else:
                img = out

        # ensure required keys for collate
        if "target_w" not in meta:
            try:
                meta["target_w"] = int(img.shape[-1])  # type: ignore[attr-defined]
            except Exception:
                # e.g. a PIL image without .shape when no tensor transform ran
                meta["target_w"] = 1
        if "valid_w" not in meta:
            meta["valid_w"] = int(meta["target_w"])

        text_use = text if text else " "
        target = self.cm.encode(text_use, oov=self.oov)
        return img, target, text_use, meta


def _lam_aug_kwargs(args, *, train: bool) -> Dict[str, Any]:
    """Collect the normalization/augmentation kwargs shared by both transform factories.

    For ``train=False`` every augmentation probability and geometric magnitude
    is forced to 0.0. Parameters that only shape an augmentation (stretch range,
    stroke kernel sizes, noise std) are still taken from ``args``.
    """
    kw: Dict[str, Any] = dict(
        normalize=getattr(args, "normalize", "half"),
        estimate_valid_w=bool(getattr(args, "estimate_valid_w", True)),
        bg_thresh=int(getattr(args, "bg_thresh", 250)),
        aug_affine_p=float(getattr(args, "aug_affine_p", 0.15)),
        aug_wstretch_p=float(getattr(args, "aug_wstretch_p", 0.0)),
        aug_wstretch_min=float(getattr(args, "aug_wstretch_min", 0.7)),
        aug_wstretch_max=float(getattr(args, "aug_wstretch_max", 1.3)),
        aug_degrees=float(getattr(args, "aug_degrees", 2.0)),
        aug_translate=float(getattr(args, "aug_translate", 0.01)),
        aug_shear=float(getattr(args, "aug_shear", 2.0)),
        aug_stroke_p=float(getattr(args, "aug_stroke_p", 0.25)),
        aug_stroke_kmin=int(getattr(args, "aug_stroke_kmin", 3)),
        aug_stroke_kmax=int(getattr(args, "aug_stroke_kmax", 5)),
        aug_sharpen_p=float(getattr(args, "aug_sharpen_p", 0.15)),
        aug_invert_p=float(getattr(args, "aug_invert_p", 0.0)),
        aug_noise_p=float(getattr(args, "aug_noise_p", 0.15)),
        aug_noise_std=float(getattr(args, "aug_noise_std", 0.03)),
    )
    if not train:
        for key in (
            "aug_affine_p",
            "aug_wstretch_p",
            "aug_degrees",
            "aug_translate",
            "aug_shear",
            "aug_stroke_p",
            "aug_sharpen_p",
            "aug_invert_p",
            "aug_noise_p",
        ):
            kw[key] = 0.0
    return kw


def _build_lam_transforms(args, *, target_h: int, max_w: int, dynamic_width: bool):
    """Return ``(transform_train, transform_eval, collate_fn)`` for the chosen width mode."""
    if dynamic_width:
        size_kw: Dict[str, Any] = dict(
            target_height=target_h,
            max_width=max_w,
            pad_to_multiple=int(getattr(args, "pad_to_multiple", 32) or 0),
            right_pad_min=int(getattr(args, "right_pad_min", 8) or 0),
            trim_whitespace=bool(getattr(args, "trim_whitespace", False)),
            trim_margin=int(getattr(args, "trim_margin", 2)),
        )
        factory = create_line_transform_dynamic
        collate_fn = ctc_collate_pad_to_max
    else:
        size_kw = dict(target_height=target_h, target_width=max_w)
        factory = create_line_transform
        collate_fn = ctc_collate

    transform_train = factory(**size_kw, **_lam_aug_kwargs(args, train=True))
    transform_eval = factory(**size_kw, **_lam_aug_kwargs(args, train=False))
    return transform_train, transform_eval, collate_fn


def build_lam_line_dataloaders(args, logger=None):
    """Build the LAM train/val DataLoaders and the charset derived from their labels.

    Split files come from ``<root>`` (custom mode) or ``<root>/lines/split/<protocol>``
    (legacy mode), per ``args.lam_split_mode``. Labels come from sidecar ``.txt``
    files and/or ``lines/transcriptions.json``, per ``args.lam_text_source``.

    Tokens or JSON entries whose image cannot be found are skipped and counted
    as misses; the counts are logged as warnings.

    Returns:
        ``(train_loader, val_loader, charset_mapper, blank_id, num_classes)``

    Raises:
        FileNotFoundError: ``lines/`` or ``lines/img/`` is missing, a split file
            cannot be resolved, or transcriptions.json is missing while legacy
            mode or ``lam_text_source='json'`` needs it.
    """
    default_root = "/mnt/sdb/datasets/OCR/LAM"
    root = getattr(args, "data_path", "") or default_root

    dataset_spec = getattr(args, "dataset", "lam")
    _name, protocol, embedded_split = _parse_dataset_spec(dataset_spec)

    lines_root = os.path.join(root, "lines")
    img_dir = os.path.join(lines_root, "img")
    legacy_split_dir = os.path.join(lines_root, "split", protocol)
    trans_json = os.path.join(lines_root, "transcriptions.json")

    # user-facing options
    split_mode = str(getattr(args, "lam_split_mode", "custom") or "custom").strip().lower()
    if split_mode in ("old", "original"):
        split_mode = "legacy"

    text_source = str(getattr(args, "lam_text_source", "auto") or "auto").strip().lower()
    if text_source not in ("auto", "txt", "json"):
        text_source = "auto"

    train_split = embedded_split or getattr(args, "train_split", "train")
    val_split = getattr(args, "val_split", "val")

    # text normalization
    text_norm = bool(getattr(args, "text_norm", True))
    text_norm_form = getattr(args, "text_norm_form", "NFKC")
    text_collapse_ws = bool(getattr(args, "text_collapse_ws", True))
    text_drop_chars = getattr(args, "text_drop_chars", "¬")
    # `or 0` tolerates an explicit None on args.
    max_label_len = int(getattr(args, "max_label_len", 0) or 0)

    if not os.path.isdir(lines_root):
        raise FileNotFoundError(f"LAM lines directory not found: {lines_root}")
    if not os.path.isdir(img_dir):
        raise FileNotFoundError(f"LAM img directory not found: {img_dir}")

    # Decide split files
    def _resolve_splits_for_mode(mode: str):
        if mode == "legacy":
            train_file = _resolve_split_file(legacy_split_dir, str(train_split))
            val_file = _resolve_split_file(legacy_split_dir, str(val_split))
            return train_file, val_file, "legacy"

        # custom: root/<split>.ln|txt|json
        train_file = _resolve_split_file(root, str(train_split))
        val_file = _resolve_split_file(root, str(val_split))
        return train_file, val_file, "custom"

    if split_mode == "auto":
        try:
            train_file, val_file, eff_mode = _resolve_splits_for_mode("custom")
        except (FileNotFoundError, ValueError):
            # _resolve_split_file only raises these; anything else is a real bug.
            train_file, val_file, eff_mode = _resolve_splits_for_mode("legacy")
    elif split_mode == "legacy":
        train_file, val_file, eff_mode = _resolve_splits_for_mode("legacy")
    else:
        train_file, val_file, eff_mode = _resolve_splits_for_mode("custom")

    # transcriptions.json is mandatory for legacy mode (it maps ids to images)
    # and for text_source == "json". Without this check, custom+json silently
    # produced blank labels for every sample. With "auto" it is only a fallback.
    json_exists = os.path.exists(trans_json)
    if (eff_mode == "legacy" or text_source == "json") and not json_exists:
        raise FileNotFoundError(f"LAM transcriptions.json not found: {trans_json}")
    want_json = (text_source in ("auto", "json")) and json_exists
    key_map: Dict[str, LAMSample] = {}
    all_samples: List[LAMSample] = []
    if want_json or eff_mode == "legacy":
        all_samples = _parse_transcriptions(trans_json)
        for s in all_samples:
            for k in {_norm_key(s.sid), _norm_key(s.rel_img)}:
                if k and k not in key_map:
                    key_map[k] = s

    # Helpers for legacy: resolve rel_img stored in JSON
    def _resolve_img_path_from_json(rel_img: str) -> str:
        rel = rel_img.replace("\\", "/")
        cand1 = os.path.join(lines_root, rel)
        if os.path.exists(cand1):
            return cand1
        cand2 = os.path.join(img_dir, rel)
        if os.path.exists(cand2):
            return cand2
        cand3 = os.path.join(img_dir, os.path.basename(rel))
        if os.path.exists(cand3):
            return cand3
        raise FileNotFoundError(f"LAM image not found for rel_img='{rel_img}' under {lines_root}")

    # pick samples for train/val
    def _normalize_text(txt: str) -> str:
        s = txt or ""
        if text_norm:
            s = normalize_transcription(
                s,
                form=text_norm_form,
                collapse_whitespace=text_collapse_ws,
                strip=True,
                drop_chars=text_drop_chars,
            )
        else:
            s = s.strip()
        if not s:
            s = " "
        return s

    def pick_custom(tokens: List[str]) -> Tuple[List[Tuple[str, str, str]], int]:
        picked: List[Tuple[str, str, str]] = []
        miss = 0
        for tok in tokens:
            try:
                img_path = _resolve_img_from_token(tok, lines_root=lines_root, img_dir=img_dir)
            except Exception:
                miss += 1
                continue

            sid = os.path.splitext(os.path.basename(img_path))[0]

            txt = ""
            if text_source in ("auto", "txt"):
                txt = _read_sidecar_text(img_path)

            if (not txt) and want_json and text_source in ("auto", "json"):
                s = key_map.get(_norm_key(tok)) or key_map.get(_norm_key(sid))
                if s is not None:
                    txt = s.text or ""

            txt = _normalize_text(txt)
            if max_label_len > 0 and len(txt) > max_label_len:
                continue

            picked.append((img_path, txt, sid))
        return picked, miss

    def pick_legacy(ids: List[str]) -> Tuple[List[Tuple[str, str, str]], int]:
        picked: List[Tuple[str, str, str]] = []
        miss = 0
        for tok in ids:
            s = key_map.get(_norm_key(tok))
            if s is None:
                miss += 1
                continue

            try:
                img_path = _resolve_img_path_from_json(s.rel_img)
            except FileNotFoundError:
                # Listed in JSON but missing on disk: count it like custom mode
                # does instead of aborting the whole build.
                miss += 1
                continue
            sid = s.sid

            if text_source == "json":
                txt = s.text or ""
            else:
                # auto/txt: prefer sidecar if exists
                txt = _read_sidecar_text(img_path)
                if (not txt) and text_source == "auto":
                    txt = s.text or ""

            txt = _normalize_text(txt)
            if max_label_len > 0 and len(txt) > max_label_len:
                continue

            picked.append((img_path, txt, sid))
        return picked, miss

    train_ids = _load_split_ids(train_file)
    val_ids = _load_split_ids(val_file)

    if eff_mode == "legacy":
        train_samples, train_miss = pick_legacy(train_ids)
        val_samples, val_miss = pick_legacy(val_ids)
    else:
        train_samples, train_miss = pick_custom(train_ids)
        val_samples, val_miss = pick_custom(val_ids)

    # charset
    train_texts = [t for _p, t, _sid in train_samples]

    charset_from = str(getattr(args, "charset_from", "trainval"))
    charset_texts = list(train_texts)
    if charset_from in ("trainval", "all"):
        charset_texts.extend([t for _p, t, _id in val_samples])

    cm = CharsetMapper.from_texts(charset_texts, sort=True, add_unk=bool(getattr(args, "add_unk", False)))
    oov_policy = getattr(args, "oov", "error")

    # transforms / collate
    target_h = int(getattr(args, "img_height", 128))
    req_w = int(getattr(args, "img_max_width", 1024))
    max_w = _cap_width_by_ctc(args, req_w, logger)

    want_dynamic = bool(getattr(args, "dynamic_width", False))
    dynamic_width = want_dynamic and _HAS_DYNAMIC
    if want_dynamic and not _HAS_DYNAMIC and logger:
        logger.warning("[LAM] dynamic_width requested but dynamic transforms are unavailable; using fixed width")

    transform_train, transform_eval, collate_fn = _build_lam_transforms(
        args, target_h=target_h, max_w=max_w, dynamic_width=dynamic_width
    )

    ds_train = LAMLineDataset(
        train_samples, charset=cm, transform=transform_train, oov=oov_policy, max_label_len=max_label_len
    )
    ds_val = LAMLineDataset(val_samples, charset=cm, transform=transform_eval, oov=oov_policy, max_label_len=max_label_len)

    if logger:
        logger.info(f"LAM root        = {root}")
        logger.info(f"  protocol      = {protocol}")
        logger.info(f"  split_mode    = {eff_mode} (requested={split_mode})")
        logger.info(f"  text_source   = {text_source}")
        logger.info(f"  lines_root    = {lines_root}")
        logger.info(f"  img_dir       = {img_dir}")
        logger.info(f"  train_file    = {train_file}")
        logger.info(f"  val_file      = {val_file}")
        if json_exists:
            logger.info(f"  trans_json    = {trans_json}")
        if all_samples:
            logger.info(f"LAM parsed {len(all_samples)} samples from transcriptions.json")
        if train_miss > 0:
            logger.warning(f"[LAM] {train_miss} tokens not resolved (train list size={len(train_ids)})")
        if val_miss > 0:
            logger.warning(f"[LAM] {val_miss} tokens not resolved (val list size={len(val_ids)})")
        logger.info(f"LAM picked train={len(ds_train)} val={len(ds_val)} dynamic_width={dynamic_width}")
        logger.info(f"  charset size  = {len(cm.charset)}, CTC classes = {cm.num_classes}")
        logger.info(
            f"  text_norm={text_norm} form={text_norm_form} collapse_ws={text_collapse_ws} drop_chars={repr(text_drop_chars)}"
        )
        if max_label_len > 0:
            logger.info(f"  max_label_len={max_label_len}")
        logger.info(f"  img_height={target_h} img_max_width={max_w} (requested={req_w})")

    if _is_dist():
        train_sampler = DistributedSampler(ds_train, shuffle=True, drop_last=True)
        val_sampler = DistributedShardSampler(ds_val, shuffle=False)
        shuffle_train = False
    else:
        train_sampler = None
        val_sampler = None
        shuffle_train = True

    num_workers = int(getattr(args, "workers", 4))
    # persistent_workers is only valid with at least one worker process.
    loader_kw: Dict[str, Any] = dict(
        num_workers=num_workers,
        pin_memory=bool(getattr(args, "pin_mem", True)),
        persistent_workers=bool(getattr(args, "persistent_workers", True)) if num_workers > 0 else False,
        collate_fn=collate_fn,
    )
    batch_size = int(getattr(args, "batch_size", 8))
    # A None/unset val_batch_size falls back to the training batch size.
    val_batch_size = int(getattr(args, "val_batch_size", None) or batch_size)

    train_loader = DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=shuffle_train,
        sampler=train_sampler,
        drop_last=True,
        **loader_kw,
    )

    val_loader = DataLoader(
        ds_val,
        batch_size=val_batch_size,
        shuffle=False,
        sampler=val_sampler,
        drop_last=False,
        **loader_kw,
    )

    return train_loader, val_loader, cm, cm.blank_id, cm.num_classes

