#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import io
import os
from typing import Any, Dict, List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from PIL import Image

try:
    import pandas as pd
except Exception:
    pd = None

from .common import (
    CharsetMapper,
    compute_valid_width,
    create_line_transform,
    create_line_transform_dynamic,
    ctc_collate,
    ctc_collate_pad_to_max,
    DistributedShardSampler,
    normalize_transcription,
    merge_meta_keep_max_valid_w,
)


def _is_dist() -> bool:
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def _decode_parquet_image(obj: Any) -> Image.Image:
    if obj is None:
        raise ValueError("image field is None")

    if isinstance(obj, Image.Image):
        return obj.convert("RGB")

    if isinstance(obj, (bytes, bytearray, memoryview)):
        return Image.open(io.BytesIO(bytes(obj))).convert("RGB")

    if isinstance(obj, str):
        if not os.path.exists(obj):
            raise FileNotFoundError(f"image path not found: {obj}")
        return Image.open(obj).convert("RGB")

    if isinstance(obj, dict):
        # broaden compatibility
        for k in ("bytes", "data", "image", "raw", "path", "file", "filename"):
            if k in obj and obj[k] is not None:
                return _decode_parquet_image(obj[k])

    if hasattr(obj, "as_py"):
        return _decode_parquet_image(obj.as_py())

    raise TypeError(f"Unsupported image field type: {type(obj)}")


def _resolve_iam_parquet(data_dir: str, split: str) -> str:
    """
    Resolve IAM-line parquet path from split name.

    Supported (case-insensitive):
      - train
      - val / valid / validation / dev
      - test

    NOTE:
      We intentionally do NOT silently fall back from 'test' -> 'validation',
      because that produces "fake test" evaluations.
    """
    s = (split or "").strip().lower()
    alias = {
        "train": "train",
        "training": "train",
        "tr": "train",
        "val": "validation",
        "valid": "validation",
        "validation": "validation",
        "dev": "validation",
        "test": "test",
        "te": "test",
    }
    s = alias.get(s, s)

    cand = os.path.join(data_dir, f"{s}.parquet")
    if os.path.exists(cand):
        return cand

    # legacy compatibility for val naming
    if s == "validation":
        legacy = os.path.join(data_dir, "validation.parquet")
        if os.path.exists(legacy):
            return legacy

    raise FileNotFoundError(f"[IAM-line] parquet for split='{split}' not found. Tried: {cand}")


class IAMLineParquetDataset(Dataset):
    """Line-level samples read from a single IAM-line parquet file.

    The parquet must provide an image column (``image`` or ``img``) and a text
    column (``text`` or ``transcription``). Transcriptions are normalised once
    in ``__init__``; rows whose normalised label is longer than
    ``max_label_len`` (when > 0) are filtered out.

    Each item is ``(img, target, text, meta)`` where ``target`` is
    ``charset.encode(text, oov=oov)`` and ``meta`` holds ``valid_w``,
    ``target_w``, ``raw_w`` and ``raw_h`` (plus whatever the transform merges in).
    """

    # Seed stored in meta["valid_w"] when the geometric estimate is saturated at
    # the target width, so that a larger estimate merged from the transform wins.
    _VALID_W_PLACEHOLDER = 1

    def __init__(
        self,
        parquet_path: str,
        *,
        charset: CharsetMapper,
        target_hw: Tuple[int, int],
        transform=None,
        oov: str = "error",
        # text norm
        text_norm: bool = True,
        text_norm_form: str = "NFKC",
        text_collapse_ws: bool = True,
        text_drop_chars: str = "¬",
        max_label_len: int = 0,
    ):
        """Load the parquet and pre-compute the normalised, filtered labels.

        Args:
            parquet_path: Parquet file to read.
            charset: Mapper used to encode labels in ``__getitem__``.
            target_hw: ``(height, width)`` the images are prepared for.
            transform: Optional callable applied to the decoded PIL image. It may
                return either the transformed image or an ``(image, meta_dict)`` pair.
            oov: Out-of-vocabulary policy forwarded to ``charset.encode``.
            text_norm: Whether to run ``normalize_transcription`` on each label.
            text_norm_form: Normalisation form passed to ``normalize_transcription``.
            text_collapse_ws: Passed as ``collapse_whitespace``.
            text_drop_chars: Passed as ``drop_chars``.
            max_label_len: When > 0, rows with longer labels are skipped.

        Raises:
            ImportError: pandas could not be imported.
            KeyError: The parquet lacks an image or a text column.
        """
        if pd is None:
            raise ImportError("pandas/pyarrow are required for IAM-line parquet. pip install pandas pyarrow")

        self.parquet_path = parquet_path
        df = pd.read_parquet(parquet_path)

        img_col = "image" if "image" in df.columns else ("img" if "img" in df.columns else None)
        txt_col = "text" if "text" in df.columns else ("transcription" if "transcription" in df.columns else None)
        if img_col is None or txt_col is None:
            raise KeyError(f"Parquet missing required columns. columns={list(df.columns)}")

        # All rows are kept here; self.keep maps dataset indices to row indices.
        self.images = df[img_col].tolist()
        raw_texts: List[str] = [str(x) for x in df[txt_col].tolist()]

        self.cm = charset
        self.transform = transform
        self.target_hw = (int(target_hw[0]), int(target_hw[1]))
        self.oov = oov

        self.text_norm = bool(text_norm)
        self.text_norm_form = str(text_norm_form or "")
        self.text_collapse_ws = bool(text_collapse_ws)
        self.text_drop_chars = str(text_drop_chars or "")
        self.max_label_len = int(max_label_len)

        # pre-normalize & filter
        self.texts: List[str] = []
        self.keep: List[int] = []
        for i, t in enumerate(raw_texts):
            if self.text_norm:
                s = normalize_transcription(
                    t,
                    form=self.text_norm_form,
                    collapse_whitespace=self.text_collapse_ws,
                    strip=True,
                    drop_chars=self.text_drop_chars,
                )
            else:
                s = t

            if len(s) == 0:
                # An empty label is always replaced by a single space (never
                # dropped) so the sample still carries a non-empty target.
                s = " "

            if self.max_label_len > 0 and len(s) > self.max_label_len:
                continue

            self.texts.append(s)
            self.keep.append(i)

    def __len__(self):
        """Number of samples that survived label filtering."""
        return len(self.keep)

    @staticmethod
    def _seed_valid_w(cheap_vw: int, raw_w: int, raw_h: int, target_h: int, target_w: int) -> Tuple[int, bool]:
        """Pick the initial ``valid_w``; return ``(seed, is_placeholder)``.

        When the height-scaled image is wider than ``target_w`` and the geometric
        estimate is clipped at ``target_w``, the placeholder is returned instead
        so that a later, larger-than-placeholder estimate can replace it.
        """
        try:
            scaled_w = max(1, int(round(raw_w * (target_h / float(raw_h)))))
        except (ArithmeticError, ValueError):
            # Degenerate geometry (e.g. zero height): assume the image just fits.
            scaled_w = int(target_w)

        if scaled_w > target_w and int(cheap_vw) >= int(target_w):
            return IAMLineParquetDataset._VALID_W_PLACEHOLDER, True
        return int(cheap_vw), False

    @staticmethod
    def _settle_valid_w(meta: Dict[str, Any], cheap_vw: int, seeded_placeholder: bool) -> None:
        """Replace a placeholder ``valid_w`` that nothing overrode with ``cheap_vw``.

        Without this, samples processed with no transform (or a transform that
        reports no width) would leave ``__getitem__`` with ``valid_w == 1``.
        """
        if not seeded_placeholder:
            return
        try:
            current = int(meta.get("valid_w", 0))
        except (TypeError, ValueError):
            # Unknown representation written by the transform: leave it alone.
            return
        if current <= IAMLineParquetDataset._VALID_W_PLACEHOLDER:
            meta["valid_w"] = int(cheap_vw)

    def __getitem__(self, idx: int):
        """Decode, transform and encode the sample at dataset index ``idx``.

        Returns:
            ``(img, target, text, meta)``; ``img`` is the transform output (or the
            RGB PIL image when no transform is set).
        """
        real = self.keep[idx]
        img = _decode_parquet_image(self.images[real])
        text = self.texts[idx]

        th, tw = self.target_hw
        rw, rh = img.size
        cheap_vw = int(compute_valid_width(rw, rh, th, tw))
        # If cheap_vw was clipped by tw (very long lines), seed a placeholder so an
        # estimate coming from the transform can shrink it below the full width.
        base_vw, seeded_placeholder = self._seed_valid_w(cheap_vw, rw, rh, th, tw)
        meta: Dict[str, Any] = {"valid_w": int(base_vw), "target_w": int(tw), "raw_w": int(rw), "raw_h": int(rh)}

        if self.transform is not None:
            out = self.transform(img)
            if isinstance(out, (tuple, list)) and len(out) == 2:
                img, meta_u = out
                if isinstance(meta_u, dict):
                    # presumably merges in place and keeps the larger valid_w — helper defined in .common
                    merge_meta_keep_max_valid_w(meta, meta_u)
            else:
                img = out

        # Never let the placeholder escape as the final valid width.
        self._settle_valid_w(meta, cheap_vw, seeded_placeholder)

        target = self.cm.encode(text, oov=self.oov)
        return img, target, text, meta


def _load_split_texts(
    parquet_path: str,
    split_label: str,
    *,
    text_norm: bool,
    text_norm_form: str,
    text_collapse_ws: bool,
    text_drop_chars: str,
    required: bool = True,
) -> List[str]:
    """Read one split's transcriptions and normalise them for charset building.

    Empty normalised labels become a single space, mirroring what
    ``IAMLineParquetDataset`` does with its own labels.

    Args:
        parquet_path: Parquet file to read.
        split_label: Name used in the error message (e.g. ``"train"``).
        text_norm, text_norm_form, text_collapse_ws, text_drop_chars:
            Normalisation settings forwarded to ``normalize_transcription``.
        required: When False, a parquet without a text column yields ``[]``
            instead of raising.

    Raises:
        KeyError: ``required`` is True and no ``text``/``transcription`` column exists.
    """
    df = pd.read_parquet(parquet_path)
    txt_col = next((c for c in ("text", "transcription") if c in df.columns), None)
    if txt_col is None:
        if required:
            raise KeyError(f"{split_label} parquet missing text column. columns={list(df.columns)}")
        return []

    texts = [str(x) for x in df[txt_col].tolist()]
    if text_norm:
        texts = [
            normalize_transcription(
                t,
                form=text_norm_form,
                collapse_whitespace=text_collapse_ws,
                strip=True,
                drop_chars=text_drop_chars,
            )
            for t in texts
        ]
    return [t if t else " " for t in texts]


def build_iam_line_dataloaders(args, logger=None):
    """
    Build the IAM-line train / eval dataloaders and the shared charset.

    Behaviour:
      - the eval split is controlled by args.val_split (e.g., 'validation' or 'test')
      - charset_from='trainval' builds the charset from train + the canonical
        validation split, even when the eval split is 'test' (prevents charset
        leakage from test); charset_from='all' additionally includes test texts
        when a test parquet with a text column exists
      - the validation parquet is only required when charset_from needs it
      - eval transforms disable every random augmentation in both the dynamic
        and the legacy fixed-width pipeline

    Args:
        args: Namespace-like object; every option is read with ``getattr`` and a default.
        logger: Optional logger; a configuration summary is written with ``logger.info``.

    Returns:
        ``(train_loader, eval_loader, cm, cm.blank_id, cm.num_classes)``.

    Raises:
        FileNotFoundError: A required split parquet cannot be resolved.
        ImportError: pandas could not be imported.
        KeyError: The train (or required validation) parquet has no text column.
    """
    data_dir = getattr(args, "data_path", "") or "/mnt/sdb/datasets/OCR/IAM-line/data"

    train_split = getattr(args, "train_split", "train")
    eval_split = getattr(args, "val_split", "validation")  # in test script we override to "test"

    train_path = _resolve_iam_parquet(data_dir, str(train_split))
    eval_path = _resolve_iam_parquet(data_dir, str(eval_split))

    # The canonical validation path is only needed (and only required to exist)
    # when validation texts take part in charset building.
    charset_from = str(getattr(args, "charset_from", "trainval"))
    use_val_for_charset = charset_from in ("trainval", "all")
    val_path_for_charset = _resolve_iam_parquet(data_dir, "validation") if use_val_for_charset else None

    target_h = int(getattr(args, "img_height", 128))

    # --- Width policy ---
    # IAM-line (already height-normalized to 128) has very wide lines (median ~1.7k-1.9k px, p95 ~2.8k-3.0k).
    # Using a tiny fixed width (e.g., 512) will squeeze many samples into "crowded" images.
    dynamic_width = bool(getattr(args, "dynamic_width", True))
    pad_to_multiple = int(getattr(args, "pad_to_multiple", 32))
    right_pad_min = int(getattr(args, "right_pad_min", 8))
    trim_whitespace = bool(getattr(args, "trim_whitespace", True))
    trim_margin = int(getattr(args, "trim_margin", 2))

    max_w = int(getattr(args, "img_max_width", 2048))
    # Safety: do not exceed what the model's max_seq_len can support (if provided).
    max_seq_len = int(getattr(args, "max_seq_len", 0))
    downsample_w = int(getattr(args, "downsample_w", getattr(args, "model_downsample_w", 4)))
    if max_seq_len > 0 and downsample_w > 0:
        max_w = min(max_w, max_seq_len * downsample_w)

    target_w = int(max_w)

    # ---- transforms ----
    shared_kwargs = {
        "normalize": getattr(args, "normalize", "half"),
        "estimate_valid_w": bool(getattr(args, "estimate_valid_w", True)),
        "bg_thresh": int(getattr(args, "bg_thresh", 250)),
    }
    train_aug = {
        "aug_affine_p": float(getattr(args, "aug_affine_p", 0.15)),
        "aug_wstretch_p": float(getattr(args, "aug_wstretch_p", 0.0)),
        "aug_wstretch_min": float(getattr(args, "aug_wstretch_min", 0.7)),
        "aug_wstretch_max": float(getattr(args, "aug_wstretch_max", 1.3)),
        "aug_degrees": float(getattr(args, "aug_degrees", 2.0)),
        "aug_translate": float(getattr(args, "aug_translate", 0.01)),
        "aug_shear": float(getattr(args, "aug_shear", 2.0)),
        "aug_stroke_p": float(getattr(args, "aug_stroke_p", 0.25)),
        "aug_stroke_kmin": int(getattr(args, "aug_stroke_kmin", 3)),
        "aug_stroke_kmax": int(getattr(args, "aug_stroke_kmax", 5)),
        "aug_sharpen_p": float(getattr(args, "aug_sharpen_p", 0.15)),
        "aug_invert_p": float(getattr(args, "aug_invert_p", 0.0)),
        "aug_noise_p": float(getattr(args, "aug_noise_p", 0.15)),
        "aug_noise_std": float(getattr(args, "aug_noise_std", 0.03)),
    }
    # validation / test: NO random augmentation (avoid eval noise). Every
    # probability / magnitude is zeroed explicitly — including width stretch,
    # which the legacy eval transform previously left to the factory default.
    eval_aug = {
        **train_aug,
        "aug_affine_p": 0.0,
        "aug_wstretch_p": 0.0,
        "aug_degrees": 0.0,
        "aug_translate": 0.0,
        "aug_shear": 0.0,
        "aug_stroke_p": 0.0,
        "aug_sharpen_p": 0.0,
        "aug_invert_p": 0.0,
        "aug_noise_p": 0.0,
    }

    if dynamic_width:
        make_transform = create_line_transform_dynamic
        geometry_kwargs = {
            "target_height": target_h,
            "max_width": target_w,
            "pad_to_multiple": pad_to_multiple,
            "right_pad_min": right_pad_min,
            "trim_whitespace": trim_whitespace,
            "trim_margin": trim_margin,
        }
        collate_fn = ctc_collate_pad_to_max
    else:
        # Legacy fixed-width pipeline
        make_transform = create_line_transform
        geometry_kwargs = {"target_height": target_h, "target_width": target_w}
        collate_fn = ctc_collate

    transform_train = make_transform(**geometry_kwargs, **shared_kwargs, **train_aug)
    transform_eval = make_transform(**geometry_kwargs, **shared_kwargs, **eval_aug)

    if pd is None:
        raise ImportError("pandas/pyarrow are required for IAM-line parquet. pip install pandas pyarrow")

    # -------- charset --------
    text_norm = bool(getattr(args, "text_norm", True))
    text_norm_form = getattr(args, "text_norm_form", "NFKC")
    text_collapse_ws = bool(getattr(args, "text_collapse_ws", True))
    text_drop_chars = getattr(args, "text_drop_chars", "¬")
    norm_kwargs = {
        "text_norm": text_norm,
        "text_norm_form": text_norm_form,
        "text_collapse_ws": text_collapse_ws,
        "text_drop_chars": text_drop_chars,
    }

    charset_texts = _load_split_texts(train_path, "train", **norm_kwargs)

    # For trainval: always use canonical validation split, not eval split (prevents leakage)
    if use_val_for_charset:
        charset_texts.extend(_load_split_texts(val_path_for_charset, "validation", **norm_kwargs))

    # For all: include test texts if present (explicitly requested)
    if charset_from == "all":
        try:
            test_path = _resolve_iam_parquet(data_dir, "test")
        except FileNotFoundError:
            test_path = None  # best effort: a missing test split is simply skipped
        if test_path:
            charset_texts.extend(_load_split_texts(test_path, "test", required=False, **norm_kwargs))

    cm = CharsetMapper.from_texts(charset_texts, sort=True, add_unk=bool(getattr(args, "add_unk", False)))
    oov_policy = getattr(args, "oov", "error")
    max_label_len = int(getattr(args, "max_label_len", 0))

    # -------- datasets --------
    dataset_kwargs = {
        "charset": cm,
        "target_hw": (target_h, target_w),
        "oov": oov_policy,
        "text_norm": text_norm,
        "text_norm_form": text_norm_form,
        "text_collapse_ws": text_collapse_ws,
        "text_drop_chars": text_drop_chars,
        "max_label_len": max_label_len,
    }
    ds_train = IAMLineParquetDataset(train_path, transform=transform_train, **dataset_kwargs)
    ds_eval = IAMLineParquetDataset(eval_path, transform=transform_eval, **dataset_kwargs)

    if logger:
        logger.info(f"IAM-line data_dir = {data_dir}")
        logger.info(f"  train_split   = {train_split} -> {train_path}")
        logger.info(f"  eval_split    = {eval_split} -> {eval_path}")
        logger.info(f"  charset_from  = {charset_from} (val_for_charset={val_path_for_charset})")
        logger.info(f"  charset size  = {len(cm.charset)}, CTC classes = {cm.num_classes}")
        logger.info(
            f"  text_norm={text_norm} form={text_norm_form} collapse_ws={text_collapse_ws} drop_chars={repr(text_drop_chars)}"
        )
        if max_label_len > 0:
            logger.info(f"  max_label_len={max_label_len}")

    # -------- samplers / loaders --------
    if _is_dist():
        train_sampler = DistributedSampler(ds_train, shuffle=True, drop_last=True)
        eval_sampler = DistributedShardSampler(ds_eval, shuffle=False)
        shuffle_train = False  # DataLoader forbids shuffle=True together with a sampler
    else:
        train_sampler = None
        eval_sampler = None
        shuffle_train = True

    num_workers = int(getattr(args, "workers", 4))
    loader_kwargs = {
        "num_workers": num_workers,
        "pin_memory": bool(getattr(args, "pin_mem", True)),
        # persistent_workers is only meaningful with worker processes
        "persistent_workers": bool(getattr(args, "persistent_workers", True)) if num_workers > 0 else False,
        "collate_fn": collate_fn,
    }

    train_loader = DataLoader(
        ds_train,
        batch_size=int(getattr(args, "batch_size", 8)),
        shuffle=shuffle_train,
        sampler=train_sampler,
        drop_last=True,
        **loader_kwargs,
    )

    eval_loader = DataLoader(
        ds_eval,
        batch_size=int(getattr(args, "val_batch_size", getattr(args, "batch_size", 8))),
        shuffle=False,
        sampler=eval_sampler,
        drop_last=False,
        **loader_kwargs,
    )

    return train_loader, eval_loader, cm, cm.blank_id, cm.num_classes
