# dataset_builders.py（更新版）
from __future__ import annotations

import glob
import hashlib
import json
import os.path as osp
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from .sa1b import SA1BDataset
from .samed2d import SAMed2DTestingDataset, SAMed2DTrainingDataset


# ----------------- Collate：保持与你现有 pipeline 完全一致 -----------------
def _collate_generic(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    if len(batch) == 0:
        return {}
    passthrough = {"ori_label", "original_size", "label_path", "name", "id", "annot"}
    keys = batch[0].keys()
    out = {}
    for k in keys:
        vs = [x.get(k) for x in batch]
        if any(v is None for v in vs):
            out[k] = None
            continue
        if k in passthrough:
            out[k] = vs
            continue
        out[k] = default_collate(vs)
    return out


def _make_loader(
    ds,
    batch_size: int,
    is_train: bool,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
    prefetch_factor: Optional[int] = None,
    drop_last: Optional[bool] = None,
):
    if ds is None:
        return None
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset=ds,
        batch_size=batch_size,
        shuffle=is_train,
        drop_last=(True if is_train else False) if drop_last is None else bool(drop_last),
        num_workers=num_workers,
        persistent_workers=(num_workers > 0),
        pin_memory=pin_memory,
        prefetch_factor=(prefetch_factor if (num_workers > 0 and prefetch_factor is not None) else None),
        worker_init_fn=_worker_init_fn,
        collate_fn=_collate_generic,
    )


def _worker_init_fn(worker_id: int):
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)
    try:
        import cv2

        cv2.setNumThreads(0)
    except Exception:
        pass


# ----------------- 哈希划分工具：稳定、可复现 -----------------
def _sha1_int(s: str, salt: int) -> int:
    h = hashlib.sha1()
    h.update(f"{salt}::{s}".encode())
    return int(h.hexdigest(), 16)


def _pick_by_hash(keys: Sequence[str], k: int, salt: int) -> List[str]:
    keys = list(keys)
    if k > len(keys):
        raise ValueError(f"Requested {k} samples but only {len(keys)} available.")
    keys_sorted = sorted(keys, key=lambda x: _sha1_int(x, salt))
    return keys_sorted[:k]


# ----------------- SA-1B anchors（stem 交集） -----------------
def _sa1b_stems(root: str) -> List[str]:
    jpgs = {osp.splitext(osp.basename(p))[0] for p in glob.glob(osp.join(root, "*.jpg"))}
    jsons = {osp.splitext(osp.basename(p))[0] for p in glob.glob(osp.join(root, "*.json"))}
    stems = sorted(jpgs & jsons)
    if not stems:
        raise RuntimeError(f"No valid jpg/json pairs in {root}")
    return stems


# ----------------- SAMed2D 锚点加载 -----------------
def _load_json(path: str) -> dict:
    if not osp.exists(path):
        raise FileNotFoundError(path)
    with open(path, "r") as f:
        return json.load(f)


def _samed2d_image_anchors(root: str, mode: str) -> List[str]:
    """Return the sorted keys of ``<root>/image2label_<mode>.json``.

    Per the file naming, these keys are image relative paths.
    """
    index_path = osp.join(root, f"image2label_{mode}.json")
    image_index = _load_json(index_path)
    return sorted(image_index)


def _samed2d_val_mask_anchors_filtered_by_images(root: str, mode: str, allowed_images: set[str]) -> List[str]:
    """Return sorted keys of ``label2image_<mode>.json`` whose value is in ``allowed_images``.

    The JSON is read as a mapping mask relpath -> image relpath, so this keeps
    only the masks that belong to an allowed image.
    """
    mask_to_image = _load_json(osp.join(root, f"label2image_{mode}.json"))
    kept_masks = (mask for mask, image in mask_to_image.items() if image in allowed_images)
    return sorted(kept_masks)


# =============================================================================
# Public API（Accelerate 托管采样：这里不再塞 DistributedSampler）
# =============================================================================
def _split_sa1b_stems(
    root_train: str,
    root_val: str,
    *,
    train_size: int,
    val_size: int,
    grad_size: int,
    subset_seed: int,
) -> tuple[List[str], List[str], List[str]]:
    """Pick disjoint SA-1B train/val stems plus a grad subset of the train stems."""
    if root_train == root_val:
        print("SA-1B paths are the same. Splitting from a common pool...")
        pool = _sa1b_stems(root_train)
        if train_size + val_size > len(pool):
            raise ValueError(f"SA-1B train_size({train_size}) + val_size({val_size}) exceeds total available samples ({len(pool)}).")
        train_stems = _pick_by_hash(pool, train_size, subset_seed + 11)
        remain = sorted(set(pool) - set(train_stems))
        val_stems = _pick_by_hash(remain, val_size, subset_seed + 12)
    else:
        print("SA-1B paths are different. Sampling from separate pools...")
        train_candidates = _sa1b_stems(root_train)
        val_candidates = _sa1b_stems(root_val)
        if train_size > len(train_candidates):
            raise ValueError(f"SA-1B train_size={train_size} exceeds available={len(train_candidates)} in train_pre dir.")
        if val_size > len(val_candidates):
            raise ValueError(f"SA-1B val_size={val_size} exceeds available={len(val_candidates)} in val_pre dir.")
        train_stems = _pick_by_hash(train_candidates, train_size, subset_seed + 11)
        if not set(train_candidates).isdisjoint(val_candidates):
            print("Warning: train_pre and val_pre directories contain overlapping samples.")
        # Never validate on a stem that was picked for training. Removing stems
        # ranked outside val's top-k leaves the pick unchanged, so this only alters
        # results in cases that previously hit the disjointness assertion.
        train_set = set(train_stems)
        val_pool = [s for s in val_candidates if s not in train_set]
        if val_size > len(val_pool):
            raise ValueError(
                f"SA-1B val_size={val_size} exceeds available={len(val_pool)} in val_pre dir "
                f"after excluding samples picked for train_pre."
            )
        val_stems = _pick_by_hash(val_pool, val_size, subset_seed + 12)

    grad_stems = _pick_by_hash(train_stems, grad_size, subset_seed + 13) if grad_size > 0 else []
    assert set(train_stems).isdisjoint(val_stems)
    assert set(grad_stems).issubset(train_stems)
    return train_stems, val_stems, grad_stems


def _split_samed2d_images(
    root: str,
    *,
    train_size: int,
    val_size: int,
    grad_size: int,
    subset_seed: int,
    val_multiplier: int,
) -> tuple[List[str], List[str], List[str]]:
    """Split the SAMed2D images of ``image2label_train.json`` into train/val/grad lists."""
    print("SAMed2D: Loading from a global pool and splitting dynamically...")
    pool = _samed2d_image_anchors(root, "train")
    # The validation split deliberately draws val_size * val_multiplier images;
    # validate that real requirement up front instead of failing later inside
    # _pick_by_hash with an unrelated message.
    n_val = val_size * val_multiplier
    if train_size + n_val > len(pool):
        raise ValueError(
            f"SAMed2D train_size({train_size}) + val_size({val_size}) x {val_multiplier} "
            f"exceeds total available samples ({len(pool)})."
        )
    train_images = _pick_by_hash(pool, train_size, subset_seed + 21)
    train_set = set(train_images)
    remain = sorted(set(pool) - train_set)
    val_images = _pick_by_hash(remain, n_val, subset_seed + 22)
    grad_images = _pick_by_hash(train_images, grad_size, subset_seed + 23) if grad_size > 0 else []
    assert train_set.isdisjoint(val_images)
    assert set(grad_images).issubset(train_set)
    return train_images, val_images, grad_images


def build_dataloaders(
    *,
    datasets_root: Dict[str, Path | str],
    batch_size: int,
    train_size: int = 10_000,
    grad_size: int = 10_000,
    val_size: int = 50,
    num_workers: int = 4,
    prefetch_factor: Optional[int] = None,
    pin_memory: Optional[bool] = None,
    subset_seed: int = 0,
    down_val_multiplier: int = 20,
) -> Dict[str, Optional[DataLoader]]:
    """Build the SA-1B ("pre") and SAMed2D ("down") train/grad/val DataLoaders.

    Subsets are chosen by salted SHA-1 ranking (``_pick_by_hash``), so splits are
    reproducible for a given ``subset_seed``. No DistributedSampler is attached;
    sampling across processes is left to Accelerate.

    Args:
        datasets_root: Root directories keyed by role. ``train_pre``, ``val_pre``
            and ``train_down`` (or ``train``) are required. ``val_down`` (or
            ``val``) falls back to the downstream train root when absent/empty.
        batch_size: Batch size of the train/grad loaders (val loaders use 1).
        train_size: Number of train samples drawn from each source.
        grad_size: Size of the grad subset, drawn from the train samples;
            ``<= 0`` gives an empty grad subset.
        val_size: Number of SA-1B validation samples; SAMed2D validation draws
            ``val_size * down_val_multiplier`` images.
        num_workers, prefetch_factor, pin_memory: Forwarded to ``_make_loader``
            (val loaders use at most 2 workers and no prefetch_factor).
        subset_seed: Base salt for every hash-based selection.
        down_val_multiplier: Multiplier for the SAMed2D validation split size;
            the default of 20 reproduces the previously hard-coded value.

    Returns:
        Mapping with keys ``grad_pre``, ``train_pre``, ``val_pre``, ``grad_down``,
        ``train_down`` and ``val_down``.

    Raises:
        ValueError: Missing required roots, negative sizes, ``grad_size >
            train_size``, or a pool too small for the requested sizes.
        RuntimeError / FileNotFoundError: Propagated from the anchor loaders
            when an SA-1B dir has no pairs or a SAMed2D index JSON is missing.
    """
    # --------- Required roots: both pre and down must be provided ---------
    root_pre_train = datasets_root.get("train_pre")
    root_down_train = datasets_root.get("train_down") or datasets_root.get("train")
    root_pre_val = datasets_root.get("val_pre")
    if not all([root_pre_train, root_pre_val, root_down_train]):
        raise ValueError("'train_pre', 'val_pre', and 'train_down' roots are all required.")

    # `or` chain: a "val_down"/"val" key that is present but None/empty also falls
    # back to the downstream train root (instead of becoming the string "None").
    root_down_val = datasets_root.get("val_down") or datasets_root.get("val") or root_down_train

    root_pre_train = str(root_pre_train)
    root_pre_val = str(root_pre_val)
    root_down_train = str(root_down_train)
    root_down_val = str(root_down_val)

    # --------- Size constraints: fail early ---------
    if train_size < 0 or val_size < 0 or down_val_multiplier < 0:
        raise ValueError(
            f"train_size({train_size}), val_size({val_size}) and "
            f"down_val_multiplier({down_val_multiplier}) must be non-negative."
        )
    if grad_size > train_size:
        raise ValueError(f"grad_size({grad_size}) must be <= train_size({train_size}).")

    # ===================== SA-1B (common pool or separate dirs) =====================
    train_pre_stems, val_pre_stems, grad_pre_stems = _split_sa1b_stems(
        root_pre_train,
        root_pre_val,
        train_size=train_size,
        val_size=val_size,
        grad_size=grad_size,
        subset_seed=subset_seed,
    )
    ds_grad_pre = SA1BDataset(root_pre_train, image_size=1024, allowed_stems=grad_pre_stems)
    ds_train_pre = SA1BDataset(root_pre_train, image_size=1024, allowed_stems=train_pre_stems)
    ds_val_pre = SA1BDataset(root_pre_val, image_size=1024, allowed_stems=val_pre_stems, is_validation=True)

    # ===================== SAMed2D (single global index + dynamic split) =====================
    train_images, val_images, grad_images = _split_samed2d_images(
        root_down_train,
        train_size=train_size,
        val_size=val_size,
        grad_size=grad_size,
        subset_seed=subset_seed,
        val_multiplier=down_val_multiplier,
    )
    # NOTE(review): val_images come from root_down_train's index, but the val
    # dataset below is built on root_down_val; if those roots differ, confirm the
    # val root actually contains these image paths.
    ds_grad_down = SAMed2DTrainingDataset(
        root_down_train,
        image_size=256,
        mode="all",
        point_num=1,
        mask_num=5,
        requires_name=True,
        allowed_stems=grad_images,
    )
    ds_train_down = SAMed2DTrainingDataset(
        root_down_train,
        image_size=256,
        mode="all",
        point_num=1,
        mask_num=5,
        requires_name=True,
        allowed_stems=train_images,
    )
    ds_val_down = SAMed2DTestingDataset(
        root_down_val,
        image_size=256,
        mode="all",
        requires_name=True,
        point_num=1,
        return_original_mask=True,
        allowed_stems=val_images,  # image list; the dataset resolves its own masks
    )

    # ----------------- DataLoaders -----------------
    train_opts = dict(
        is_train=True,
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        pin_memory=pin_memory,
    )
    val_opts = dict(
        is_train=False,
        batch_size=1,
        num_workers=min(2, num_workers),
        pin_memory=pin_memory,
    )
    loaders: Dict[str, Optional[DataLoader]] = {
        "grad_pre": _make_loader(ds_grad_pre, **train_opts),
        "train_pre": _make_loader(ds_train_pre, **train_opts),
        "val_pre": _make_loader(ds_val_pre, **val_opts),
        "grad_down": _make_loader(ds_grad_down, **train_opts),
        "train_down": _make_loader(ds_train_down, **train_opts),
        "val_down": _make_loader(ds_val_down, **val_opts),
    }
    return loaders
