# dataset/samed2d.py
from __future__ import annotations

import json
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
import torch
from albumentations import Compose, PadIfNeeded, Resize
from albumentations.pytorch import ToTensorV2
from skimage.measure import label as sk_label
from skimage.measure import regionprops
from torch.utils.data import Dataset


# -----------------------------------------------------------------------------
# Transforms & prompt helpers (aligned with the v2 logic)
# -----------------------------------------------------------------------------
def train_transforms(img_size: int, ori_h: int, ori_w: int):
    """
    Build the Albumentations pipeline that brings an image/mask pair to a
    square of side `img_size` and then converts it to tensors.

    - When both `ori_h` and `ori_w` are strictly smaller than `img_size`,
      the pair is padded (constant border, fill value 0 for image and mask).
    - In every other case the pair is resized straight to
      `img_size` x `img_size`, with `cv2.INTER_NEAREST` as the interpolation.

    `ToTensorV2` is always the final step.
    """
    fits_inside = ori_h < img_size and ori_w < img_size

    if fits_inside:
        spatial_op = PadIfNeeded(
            min_height=img_size,
            min_width=img_size,
            border_mode=cv2.BORDER_CONSTANT,
            fill=0,  # New API in Albumentations 2.x
            fill_mask=0,  # New API in Albumentations 2.x
        )
    else:
        spatial_op = Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST)

    # Exactly one spatial op, then tensor conversion.
    return Compose([spatial_op, ToTensorV2(p=1.0)], p=1.0)


def get_boxes_from_mask(
    mask: torch.Tensor | np.ndarray,
    box_num: int = 1,
    std: float = 0.1,
    max_pixel: int = 5,
) -> torch.Tensor:
    """
    Extract bounding boxes from a mask via connected components.

    Args:
        mask: 2-D mask (tensor or array); zero pixels are background.
        box_num: Number of boxes to return. The largest regions (by area)
            are used; if fewer exist, the found boxes are repeated cyclically
            until `box_num` boxes are available. The same padding applies to
            the empty-mask fallback so the output shape does not depend on
            the mask content.
        std: Fraction of the shorter box side used to scale the jitter cap.
        max_pixel: Upper bound (in pixels) on the jitter magnitude. Use 0 to
            disable jitter (e.g. for testing).

    Returns:
        float32 Tensor of shape (K, 4) in (x0, y0, x1, y1) order. Each box is
        translated as a whole by a random integer offset drawn uniformly from
        [-cap, cap] per axis, where cap = min(max_pixel, int(short_side * std * 5)).
        If the mask has no foreground, a whole-image box is used instead.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    # An all-zero mask has no connected components, so labeling can be skipped.
    regions = regionprops(sk_label(mask)) if np.any(mask) else []

    if not regions:
        # Fallback: whole-image box if there is no foreground.
        h, w = mask.shape[:2]
        boxes = [(0, 0, h, w)]  # (y0, x0, y1, x1)
    else:
        # Take up to `box_num` largest regions.
        regions = sorted(regions, key=lambda r: r.area, reverse=True)[:box_num]
        boxes = [tuple(r.bbox) for r in regions]  # (y0, x0, y1, x1)

    # Pad by cycling through the boxes found so far (fallback included).
    n_found = len(boxes)
    if 0 < n_found < box_num:
        boxes += [boxes[i % n_found] for i in range(box_num - n_found)]

    out: List[Tuple[int, int, int, int]] = []
    for y0, x0, y1, x1 in boxes:
        short_side = min(x1 - x0, y1 - y0)
        jitter_cap = min(max_pixel, int(short_side * std * 5))
        if jitter_cap > 0:
            # randint's upper bound is exclusive; +1 keeps the range symmetric.
            nx = int(np.random.randint(-jitter_cap, jitter_cap + 1))
            ny = int(np.random.randint(-jitter_cap, jitter_cap + 1))
        else:
            nx = ny = 0
        # Convert to (x0, y0, x1, y1) and shift the whole box by the jitter.
        out.append((x0 + nx, y0 + ny, x1 + nx, y1 + ny))
    return torch.as_tensor(out, dtype=torch.float32)


def init_point_sampling(mask: torch.Tensor | np.ndarray, get_point: int = 1):
    """
    Randomly sample prompt points from the foreground/background of a mask.

    Foreground is `mask == 1`, background is `mask == 0`.

    - If get_point == 1: sample one point (foreground if any exists,
      otherwise background).
    - Else: sample `get_point // 2` foreground points and the remainder from
      the background (with replacement), then shuffle. If one of the two
      classes is absent from the mask, all points are drawn from the class
      that is present, so the result always has `get_point` points whose
      labels match the sampled pixels.

    Returns:
      coords: Tensor (P, 2) float32 in (x, y) order
      labels: Tensor (P,) int64 where 1=FG, 0=BG

    Raises:
      ValueError: if `get_point` is negative, or the mask has no pixel equal
        to 0 or 1.
    """
    if get_point < 0:
        raise ValueError(f"get_point must be non-negative, got {get_point}")

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    # argwhere yields (row, col); reversing the columns gives (x, y).
    fg = np.argwhere(mask == 1)[:, ::-1]
    bg = np.argwhere(mask == 0)[:, ::-1]
    has_fg, has_bg = len(fg) > 0, len(bg) > 0

    if not (has_fg or has_bg):
        raise ValueError("mask has no pixel equal to 0 or 1; cannot sample prompt points")

    if get_point == 1:
        pool, label = (fg, 1) if has_fg else (bg, 0)
        coord = pool[np.random.randint(len(pool))].astype(np.float32)
        return torch.from_numpy(coord).unsqueeze(0), torch.tensor([label], dtype=torch.long)

    num_fg = get_point // 2
    num_bg = get_point - num_fg
    # Re-assign the quota of a missing class to the one that exists, so that
    # coords and labels always have the same length.
    if not has_fg:
        num_fg, num_bg = 0, get_point
    elif not has_bg:
        num_fg, num_bg = get_point, 0

    no_idx = np.array([], dtype=int)
    fg_idx = np.random.choice(len(fg), num_fg, replace=True) if num_fg > 0 else no_idx
    bg_idx = np.random.choice(len(bg), num_bg, replace=True) if num_bg > 0 else no_idx

    coords = np.concatenate([fg[fg_idx], bg[bg_idx]], axis=0).astype(np.float32)
    labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)], axis=0).astype(np.int64)

    perm = np.random.permutation(get_point)
    coords = coords[perm]
    labels = labels[perm]
    return torch.from_numpy(coords), torch.from_numpy(labels)


# -----------------------------------------------------------------------------
# Datasets (names include the dataset prefix for readability)
# -----------------------------------------------------------------------------
class SAMed2DTestingDataset(Dataset):
    """
    SA-Med2D testing/validation dataset.

    One sample per mask. The (mask -> image) pairs are read from
    `<root_directory>/label2image_<mode>.json`, a JSON object whose keys are
    mask paths and whose values are image paths, both relative to
    `root_directory`.

    Args:
        root_directory: Dataset root containing the mapping JSON and the files.
        image_size: Side length of the square model input.
        mode: Suffix selecting the mapping file (`label2image_<mode>.json`).
        requires_name: If True, each sample carries a `name` entry (mask file
            basename).
        point_num: Number of prompt points sampled per mask when no prompt
            JSON is given.
        return_original_mask: If True, each sample carries `ori_label`, the
            un-transformed mask as a [1, H, W] float32 tensor.
        prompt_json_path: Optional JSON file with precomputed prompts, keyed
            by mask file basename; each entry provides `boxes`,
            `point_coords` and `point_labels`.
        disable_cv2_multithread: If True, call `cv2.setNumThreads(0)`.
        allowed_stems: Optional whitelist. If given, only samples whose image
            relative path is in the whitelist are kept ("stems" here means
            image relative paths, i.e. the values of the mapping). The JSON
            order is preserved.

    Raises:
        ValueError: if `allowed_stems` filters out every sample.
    """

    def __init__(
        self,
        root_directory: str,
        image_size: int = 256,
        mode: str = "test",
        requires_name: bool = True,
        point_num: int = 1,
        return_original_mask: bool = True,
        prompt_json_path: str | None = None,
        disable_cv2_multithread: bool = False,
        allowed_stems: Optional[Sequence[str]] = None,
    ) -> None:
        super().__init__()
        self.root_directory = root_directory
        self.image_size = int(image_size)
        self.requires_name = bool(requires_name)
        self.point_num = int(point_num)
        self.return_original_mask = bool(return_original_mask)

        if disable_cv2_multithread:
            # Best-effort: a failure to change the thread count must not
            # prevent the dataset from being built.
            try:
                cv2.setNumThreads(0)
            except Exception:
                pass

        # NOTE(review): these look like the usual ImageNet RGB statistics, but
        # they are applied to the BGR image returned by cv2.imread — confirm
        # this matches how the model was trained before changing anything.
        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

        mapping_path = os.path.join(root_directory, f"label2image_{mode}.json")
        with open(mapping_path, "r") as f:
            mapping = json.load(f)  # dict: mask_rel -> image_rel

        # Image-level whitelist filtering (keeps the original JSON order).
        if allowed_stems is not None:
            allow = set(map(str, allowed_stems))
            items = [(m_rel, i_rel) for m_rel, i_rel in mapping.items() if i_rel in allow]
            if not items:
                raise ValueError("[SAMed2DTestingDataset] No samples left after applying allowed_stems.")
        else:
            items = list(mapping.items())

        # Parallel lists, index-aligned: sample i is (image_relpaths[i], mask_relpaths[i]).
        self.image_relpaths = [i for _, i in items]
        self.mask_relpaths = [m for m, _ in items]

        if prompt_json_path is None:
            self.prompt_json = {}
        else:
            # Context manager so the file handle is closed deterministically.
            with open(prompt_json_path, "r") as f:
                self.prompt_json = json.load(f)

    def __len__(self) -> int:
        """Return the number of (mask, image) samples."""
        return len(self.mask_relpaths)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """
        Load, normalize and transform sample `index`.

        Returns a dict with:
            input_image: float32 tensor [C, H', W'].
            label: int64 tensor [1, 1, H', W'] (transformed mask).
            point_coords / point_labels / boxes: prompts, either sampled from
                the transformed mask (boxes without jitter) or read from the
                prompt JSON.
            original_size: (H, W) of the mask before transforms.
            label_path: directory of the mask file.
            ori_label: [1, H, W] float32 mask (only if `return_original_mask`).
            name: mask file basename (only if `requires_name`).

        Raises:
            FileNotFoundError: if the image or mask cannot be read.
            AssertionError: if the mask is not binary.
            KeyError: if a prompt JSON is in use and lacks this mask's entry.
        """
        out: Dict[str, Any] = {}

        # --- image ---
        image_path = os.path.join(self.root_directory, self.image_relpaths[index])
        image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image_bgr is None:
            raise FileNotFoundError(f"Failed to read image: {image_path}")
        mean = np.array(self.pixel_mean, dtype=np.float32)
        std = np.array(self.pixel_std, dtype=np.float32)
        image = (image_bgr.astype(np.float32) - mean) / std

        # --- mask ---
        mask_rel = self.mask_relpaths[index]
        mask_path = os.path.join(self.root_directory, mask_rel)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask_np is None:
            raise FileNotFoundError(f"Failed to read mask: {mask_path}")
        if mask_np.max() == 255:
            mask_np = mask_np / 255
        # Strict binary check for testing/validation. Raised explicitly (not
        # via `assert`) so it still runs under `python -O`; the exception type
        # is kept as AssertionError for existing callers.
        if not np.array_equal(mask_np, mask_np.astype(bool)):
            raise AssertionError(f"Mask must be binary: {mask_rel}")

        ori_h, ori_w = mask_np.shape
        ori_mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0)  # [1, H, W]

        # --- transforms ---
        tfm = train_transforms(self.image_size, ori_h, ori_w)
        aug = tfm(image=image, mask=mask_np)
        input_image = aug["image"].to(torch.float32)  # [C, H', W'] float32
        mask_t = aug["mask"].to(torch.int64)  # [H', W'] int64

        # --- prompts ---
        if not self.prompt_json:
            boxes = get_boxes_from_mask(mask_t, max_pixel=0)  # no noise for testing
            point_coords, point_labels = init_point_sampling(mask_t, self.point_num)
        else:
            # Precomputed prompts are keyed by the mask file's basename.
            key = os.path.basename(mask_path)
            prm = self.prompt_json[key]
            boxes = torch.as_tensor(prm["boxes"], dtype=torch.float32)
            point_coords = torch.as_tensor(prm["point_coords"], dtype=torch.float32)
            point_labels = torch.as_tensor(prm["point_labels"], dtype=torch.int64)

        # Align shapes with v1
        mask_t = mask_t.unsqueeze(0).unsqueeze(0)  # [1, 1, H', W']
        if point_coords.dim() == 2:
            point_coords = point_coords.unsqueeze(0)  # [1, P, 2]
            point_labels = point_labels.unsqueeze(0)  # [1, P]
        if boxes.dim() == 1:
            boxes = boxes.unsqueeze(0)  # [1, 4]

        out.update(
            dict(
                input_image=input_image,
                label=mask_t,
                point_coords=point_coords,
                point_labels=point_labels,
                boxes=boxes,
                original_size=(ori_h, ori_w),
                label_path=os.path.dirname(mask_path),
            )
        )
        if self.return_original_mask:
            out["ori_label"] = ori_mask_tensor
        if self.requires_name:
            out["name"] = os.path.basename(mask_rel)
        return out


class SAMed2DTrainingDataset(Dataset):
    """
    SA-Med2D training dataset.

    One sample per image. The (image -> masks) mapping is read from
    `<root_directory>/image2label_<mode>.json`, a JSON object whose keys are
    image paths and whose values are lists of mask paths, all relative to
    `root_directory`. Each sample draws `mask_num` masks of the image at
    random (with replacement) and builds a box prompt and point prompts for
    each of them.

    Args:
        root_directory: Dataset root containing the mapping JSON and the files.
        image_size: Side length of the square model input.
        mode: Suffix selecting the mapping file (`image2label_<mode>.json`).
        requires_name: If True, each sample carries a `name` entry (image file
            basename).
        point_num: Number of prompt points sampled per mask.
        mask_num: Number of masks drawn per image; must be at least 1.
        disable_cv2_multithread: If True, call `cv2.setNumThreads(0)`.
        allowed_stems: Optional whitelist. If given, only images whose relative
            path is in the whitelist are kept; all instance masks of a kept
            image are kept with it (image-level filtering, to prevent the same
            image from leaking across domains). The JSON order is preserved.

    Raises:
        ValueError: if `mask_num` is smaller than 1, or `allowed_stems`
            filters out every image.
    """

    def __init__(
        self,
        root_directory: str,
        image_size: int = 256,
        mode: str = "train",
        requires_name: bool = True,
        point_num: int = 1,
        mask_num: int = 5,
        disable_cv2_multithread: bool = False,
        allowed_stems: Optional[Sequence[str]] = None,
    ) -> None:
        super().__init__()
        self.root_directory = root_directory
        self.image_size = int(image_size)
        self.requires_name = bool(requires_name)
        self.point_num = int(point_num)
        self.mask_num = int(mask_num)
        # Fail early: with fewer than one mask per sample __getitem__ has
        # nothing to stack and would only fail later with an opaque error.
        if self.mask_num < 1:
            raise ValueError(f"[SAMed2DTrainingDataset] mask_num must be >= 1, got {self.mask_num}.")

        if disable_cv2_multithread:
            # Best-effort: a failure to change the thread count must not
            # prevent the dataset from being built.
            try:
                cv2.setNumThreads(0)
            except Exception:
                pass

        # NOTE(review): these look like the usual ImageNet RGB statistics, but
        # they are applied to the BGR image returned by cv2.imread — confirm
        # this matches how the model was trained before changing anything.
        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

        mapping_path = os.path.join(root_directory, f"image2label_{mode}.json")
        with open(mapping_path, "r") as f:
            mapping = json.load(f)  # dict: image_rel -> [mask_rel, ...]

        # Image-level whitelist filtering (keeps the original JSON order).
        if allowed_stems is not None:
            allow = set(map(str, allowed_stems))
            keys = [k for k in mapping if k in allow]
            if not keys:
                raise ValueError("[SAMed2DTrainingDataset] No images left after applying allowed_stems.")
        else:
            keys = list(mapping)

        # Parallel lists, index-aligned: image i owns mask_lists_per_image[i].
        self.image_relpaths = keys
        self.mask_lists_per_image = [mapping[k] for k in keys]

    def __len__(self) -> int:
        """Return the number of images."""
        return len(self.image_relpaths)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """
        Load image `index`, draw `mask_num` of its masks and build prompts.

        Returns a dict with:
            input_image: float32 tensor [C, H', W'].
            label: int64 tensor [M, 1, H', W'].
            boxes: float32 tensor [M, K, 4] (K boxes per mask as returned by
                `get_boxes_from_mask`, jittered).
            point_coords: float32 tensor [M, P, 2].
            point_labels: int64 tensor [M, P].
            name: image file basename (only if `requires_name`).

        Raises:
            ValueError: if the image has no masks listed.
            FileNotFoundError: if the image or a mask cannot be read.
        """
        image_rel = self.image_relpaths[index]
        candidate_masks = self.mask_lists_per_image[index]
        # Checked before any I/O. random.choices would raise IndexError on an
        # empty list, which the sequence iteration protocol interprets as
        # "end of dataset" and would silently truncate a plain for-loop.
        if not candidate_masks:
            raise ValueError(f"[SAMed2DTrainingDataset] No masks listed for image: {image_rel}")

        out: Dict[str, Any] = {}

        # --- image ---
        image_path = os.path.join(self.root_directory, image_rel)
        image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image_bgr is None:
            raise FileNotFoundError(f"Failed to read image: {image_path}")
        mean = np.array(self.pixel_mean, dtype=np.float32)
        std = np.array(self.pixel_std, dtype=np.float32)
        image = (image_bgr.astype(np.float32) - mean) / std
        h, w, _ = image.shape

        tfm = train_transforms(self.image_size, h, w)

        # --- choose mask_num instances (with replacement) ---
        chosen_masks = random.choices(candidate_masks, k=self.mask_num)

        masks_t, boxes_t, pts_c, pts_l = [], [], [], []
        img_t: torch.Tensor | None = None

        for m_rel in chosen_masks:
            mask_path = os.path.join(self.root_directory, m_rel)
            m_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if m_np is None:
                raise FileNotFoundError(f"Failed to read mask: {mask_path}")
            if m_np.max() == 255:
                m_np = m_np / 255

            # The image is transformed together with every mask; the tensor
            # from the last iteration is the one returned.
            aug = tfm(image=image, mask=m_np)
            img_t = aug["image"].float()  # [C, H', W']
            m_t = aug["mask"].to(torch.int64)  # [H', W']

            masks_t.append(m_t)
            boxes_t.append(get_boxes_from_mask(m_t))  # noisy boxes for training
            pc, pl = init_point_sampling(m_t, self.point_num)
            pts_c.append(pc)
            pts_l.append(pl)

        label = torch.stack(masks_t, dim=0).unsqueeze(1)  # [M, 1, H', W']
        boxes = torch.stack(boxes_t, dim=0)  # [M, K, 4]
        point_coords = torch.stack(pts_c, dim=0)  # [M, P, 2]
        point_labels = torch.stack(pts_l, dim=0)  # [M, P]

        out.update(
            dict(
                input_image=img_t,  # [C, H', W'] float32
                label=label,
                boxes=boxes,
                point_coords=point_coords,
                point_labels=point_labels,
            )
        )
        if self.requires_name:
            out["name"] = os.path.basename(image_rel)
        return out
