# sa1b_dataset.py
"""
SA-1B dataset (simple, deterministic, no-dist).

- Pair *.jpg and *.json by filename stem (intersection + natural sort).
- Optional `allowed_stems` restricts the dataset to a precomputed split (keeping splits non-overlapping is the caller's responsibility).
- Preprocess: RGB -> ResizeLongestSide -> normalize -> pad-right/bottom to a square.
"""

from __future__ import annotations

import glob
import os.path as osp
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset

from .transforms import ResizeLongestSide  # external, unchanged


def _natural_sort_key(text: str) -> List[Any]:
    """Return chunks of digits/non-digits for 'natural' ordering."""
    return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", text)]


class SA1BDataset(Dataset):
    """Single-image + JSON-prompt loader for a SA-1B-style folder.

    Every ``<stem>.jpg`` located directly in ``root_directory`` that has a matching
    ``<stem>.json`` becomes one sample; samples are ordered by natural sort of the stem.
    The JSON is not parsed here: each sample only carries the annotation path.

    Args:
        root_directory: Folder that directly contains the ``*.jpg`` / ``*.json`` files.
            It is taken literally (glob metacharacters in the folder name are escaped).
        image_size: Side length of the square tensor returned as ``input_image``.
        pixel_mean: Per-channel mean, applied in RGB order before padding.
        pixel_std: Per-channel standard deviation, applied in RGB order before padding.
        is_validation: Flag copied unchanged into every returned sample.
        allowed_stems: Optional whitelist of filename stems; pairs whose stem is not
            listed are dropped. Entries are converted with ``str`` before matching.
        force_imread_color: Read with ``cv2.IMREAD_COLOR`` when true, otherwise with
            ``cv2.IMREAD_UNCHANGED`` (non-3-channel results are then rejected at load time).
        disable_cv2_multithread: When true, call ``cv2.setNumThreads(0)`` (best effort).

    Raises:
        RuntimeError: If no jpg/json pair remains after applying ``allowed_stems``.
    """

    def __init__(
        self,
        root_directory: str,
        image_size: int = 1024,
        pixel_mean: Sequence[float] = (123.675, 116.28, 103.53),
        pixel_std: Sequence[float] = (58.395, 57.12, 57.375),
        is_validation: bool = False,
        allowed_stems: Optional[Sequence[str]] = None,  # whitelist for precomputed split
        force_imread_color: bool = True,  # force 3-channel read
        disable_cv2_multithread: bool = False,  # optionally set OpenCV threads to 0
    ) -> None:
        super().__init__()
        self.root_directory = root_directory
        self.image_size = int(image_size)
        self.is_validation = bool(is_validation)

        # Shaped (3, 1, 1) so they broadcast over a C×H×W image.
        self.pixel_mean_tensor = torch.tensor(pixel_mean, dtype=torch.float32).view(3, 1, 1)
        self.pixel_std_tensor = torch.tensor(pixel_std, dtype=torch.float32).view(3, 1, 1)
        self.resize_transform = ResizeLongestSide(self.image_size)
        self.force_imread_color = bool(force_imread_color)

        if disable_cv2_multithread:
            # Deliberately best-effort: a failure to change the OpenCV thread setting
            # must not prevent the dataset from being constructed.
            try:
                cv2.setNumThreads(0)
            except Exception:
                pass

        # ---- Robust jpg/json pairing by filename stem ----
        image_path_by_stem = self._index_files_by_stem(root_directory, ".jpg")
        json_path_by_stem = self._index_files_by_stem(root_directory, ".json")

        image_stems = set(image_path_by_stem)
        json_stems = set(json_path_by_stem)
        common_stems = image_stems & json_stems

        if allowed_stems is not None:
            common_stems &= set(map(str, allowed_stems))

        # Non-fatal diagnostics for mismatches (helpful in curation). Emitted before the
        # emptiness check so that a "no pairs" failure still shows what was found.
        missing_images = sorted(json_stems - image_stems, key=_natural_sort_key)
        missing_jsons = sorted(image_stems - json_stems, key=_natural_sort_key)
        if missing_images:
            print(f"[SA1BDataset] {len(missing_images)} prompts without image, e.g. {missing_images[:5]}")
        if missing_jsons:
            print(f"[SA1BDataset] {len(missing_jsons)} images without prompt, e.g. {missing_jsons[:5]}")

        if not common_stems:
            raise RuntimeError(f"No valid jpg/json pairs found in: {root_directory}")

        # Stable natural ordering
        ordered_stems = sorted(common_stems, key=_natural_sort_key)
        self.paired_items: List[Tuple[str, str]] = [(image_path_by_stem[s], json_path_by_stem[s]) for s in ordered_stems]

    @staticmethod
    def _index_files_by_stem(root_directory: str, extension: str) -> Dict[str, str]:
        """Map filename stem -> path for files named ``*<extension>`` directly in ``root_directory``."""
        # Escape the directory so names containing glob metacharacters ([, ], *, ?) are
        # matched literally instead of being interpreted as part of the pattern.
        pattern = osp.join(glob.escape(root_directory), f"*{extension}")
        return {osp.splitext(osp.basename(path))[0]: path for path in sorted(glob.glob(pattern))}

    # ----------------------------- Dataset API -----------------------------

    def __len__(self) -> int:  # type: ignore[override]
        """Return the number of jpg/json pairs."""
        return len(self.paired_items)

    def __getitem__(self, index: int) -> Dict[str, Any]:  # type: ignore[override]
        """Load and preprocess the image of pair ``index``.

        Returns:
            A dict with the preprocessed image tensor, its sizes before padding and
            before resizing, the image/annotation paths and the validation flag.

        Raises:
            FileNotFoundError: If OpenCV could not read the image file.
            ValueError: If the decoded image is not H×W×3, or the resized image is
                larger than ``image_size``.
        """
        image_path, annotation_path = self.paired_items[index]

        # --- Load image (BGR->RGB), enforce 3 channels if requested ---
        imread_flag = cv2.IMREAD_COLOR if self.force_imread_color else cv2.IMREAD_UNCHANGED
        image_bgr = cv2.imread(image_path, imread_flag)
        if image_bgr is None:
            # cv2.imread signals any read/decode failure by returning None.
            raise FileNotFoundError(f"Failed to read image: {image_path}")
        if image_bgr.ndim != 3 or image_bgr.shape[2] != 3:
            raise ValueError(f"Expect 3-channel image, got {image_bgr.shape} @ {image_path}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # --- Preprocess (unchanged pipeline) ---
        resized_uint8 = self.resize_transform.apply_image(image_rgb)  # H' W' C (uint8)
        chw_float = torch.as_tensor(resized_uint8).permute(2, 0, 1).contiguous().to(dtype=torch.float32)
        input_image = self._normalize_and_pad(chw_float)

        original_image_size: Tuple[int, int] = tuple(image_rgb.shape[:2])  # (H, W)
        resized_image_size: Tuple[int, int] = tuple(chw_float.shape[-2:])  # (H', W')

        return {
            "sample_id": image_path,  # image path as globbed (absolute only if root_directory is)
            "input_image": input_image,  # Tensor 3×S×S
            "input_image_size": np.array(resized_image_size),  # (H', W') before padding
            "original_image_size": np.array(original_image_size),  # (H, W)
            "annotation_path": annotation_path,  # json path (lazy load upstream)
            "is_validation": self.is_validation,  # passthrough flag
        }

    # ----------------------------- Helpers -----------------------------

    def _normalize_and_pad(self, rgb_chw_float: torch.Tensor) -> torch.Tensor:
        """Normalize (RGB) and zero-pad to square image_size (pad right/bottom).

        Args:
            rgb_chw_float: Float image tensor laid out as C×H×W with channels in RGB order.

        Returns:
            Tensor of shape C×image_size×image_size; the padding is zero in normalized space.

        Raises:
            ValueError: If the input height or width exceeds ``image_size``.
        """
        normalized = (rgb_chw_float - self.pixel_mean_tensor) / self.pixel_std_tensor
        height, width = normalized.shape[-2:]
        pad_h = self.image_size - height
        pad_w = self.image_size - width
        if pad_h < 0 or pad_w < 0:
            raise ValueError(f"Input larger than target size: got {(height, width)} > {self.image_size}. ResizeLongestSide may have failed.")
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(normalized, (0, pad_w, 0, pad_h))
