import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision.transforms import v2
import numpy as np
import random
from PIL import Image
import h5py
import hdf5plugin
from tqdm import tqdm
import cv2

# Maps a dataset name (the "name" field of a dataset config entry) to the root
# directory handed to the reader functions in this module. The "phrase_cut"
# entry is a relative path, so it resolves against the current working directory.
dataset_map = {
    "phrase_cut": "train_data/VGPhraseCut_v0",
    "refcoco": "/public_dataset/RefCOCO_h5/refcoco_unc/",
    "refcoco+": "/public_dataset/RefCOCO_h5/refcoco+_unc/",
    "refcocog": "/public_dataset/RefCOCO_h5/refcocog_umd/",
}


def read_data(root: str, dataset_type: str = ""):
    """Index image/mask pairs stored under ``root``.

    Images are read from ``<root>/images``. Masks are looked up, in order of
    preference, in ``<root>/perfect_masks`` or ``<root>/masks`` (one ``.pth``
    file per image, sharing the image's file stem). When neither directory
    exists, ``<root>/masks.h5`` is used instead, keyed by the image stem.

    Args:
        root: Root directory path containing the dataset files.
        dataset_type: Optional subdirectory name (e.g., 'train', 'val') that
            is appended to ``root``. Defaults to empty string.

    Returns:
        List of dictionaries with an 'image' key (path to the image file) and
        a 'mask' key. 'mask' is the path to the ``.pth`` file in the directory
        layout, or the h5py object stored under the image stem in the
        ``masks.h5`` layout. Images without a usable mask are skipped.
    """
    if dataset_type:
        root = os.path.join(root, dataset_type)
    image_dir = os.path.join(root, "images")

    # Pick the first mask directory that exists; None means "no directory".
    mask_dir = None
    for mask_dir_name in ["perfect_masks", "masks"]:
        candidate = os.path.join(root, mask_dir_name)
        if os.path.exists(candidate):
            mask_dir = candidate
            break

    # glob order depends on the filesystem; sort so indices are reproducible.
    images = sorted(glob.glob(os.path.join(image_dir, "*")))
    if not images:
        return []

    data = []
    if mask_dir is None:
        mask_h5_path = os.path.join(root, "masks.h5")
        if not os.path.exists(mask_h5_path):
            return []
        # The file is intentionally left open: the returned items hold live
        # h5py objects that become invalid once the file is closed.
        mask_h5 = h5py.File(mask_h5_path, 'r')
        for img_path in images:
            # splitext strips only the final extension, so "a.b.jpg" -> "a.b".
            name = os.path.splitext(os.path.basename(img_path))[0]
            if name not in mask_h5:
                continue
            data.append({"image": img_path, "mask": mask_h5[name]})
        return data

    for img_path in images:
        stem = os.path.splitext(os.path.basename(img_path))[0]
        mask_path = os.path.join(mask_dir, stem + ".pth")
        if not os.path.isfile(mask_path):
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point this at trusted mask files.
        mask_dict = torch.load(mask_path, weights_only=False)
        if len(mask_dict) == 0:
            continue
        data.append({"image": img_path, "mask": mask_path})
    return data


def read_data_h5(root: str, dataset_type: str = ""):
    """Collect per-image ids and dimensions from ``image_masks.h5``.

    Args:
        root: Root directory path where the H5 file is located.
        dataset_type: Optional subdirectory name (e.g., 'train', 'val') that
            is appended to ``root``. Defaults to empty string.

    Returns:
        A ``(entries, h5_path)`` pair. ``entries`` is a list of
        ``(image_id, width, height)`` tuples, one per top-level key of the
        file, with width/height taken from that key's attributes. When the
        file does not exist the result is ``([], None)``.
    """
    split_root = os.path.join(root, dataset_type) if dataset_type else root
    h5_path = os.path.join(split_root, "image_masks.h5")
    if not os.path.exists(h5_path):
        return [], None

    with h5py.File(h5_path, 'r') as data_h5:
        progress = tqdm(data_h5, desc=f"h5 file:{h5_path}")
        data = [
            (
                image_id,
                int(data_h5[image_id].attrs["width"]),
                int(data_h5[image_id].attrs["height"]),
            )
            for image_id in progress
        ]
    return data, h5_path



def read_data_h5_mask(root: str, dataset_type: str = "", use_refine=False):
    """Index every (image, caption) mask entry stored in an H5 dataset file.

    Args:
        root: Root directory path where the H5 file is located.
        dataset_type: Optional subdirectory name (e.g., 'train', 'val') that
            is appended to ``root``. Defaults to empty string.
        use_refine: If True, read ``image_masks_refine.h5`` instead of
            ``image_masks.h5`` and enumerate captions from each image's
            ``mask_refine`` group when it exists (falling back to ``mask``).

    Returns:
        A ``(entries, h5_path)`` pair. ``entries`` is a list of
        ``(image_id, caption, width, height)`` tuples, one per caption whose
        ``<image_id>/mask/<caption>`` entry is an HDF5 dataset. When the file
        does not exist the result is ``([], None)``.

    Raises:
        ValueError: If an image group has neither a usable ``mask_refine``
            group nor a ``mask`` group.
    """
    if dataset_type:
        root = os.path.join(root, dataset_type)
    file_name = "image_masks_refine.h5" if use_refine else "image_masks.h5"
    h5_path = os.path.join(root, file_name)

    if not os.path.exists(h5_path):
        return [], None
    data = []
    with h5py.File(h5_path, 'r') as data_h5:
        for image_id in tqdm(data_h5, desc=f"h5 file:{h5_path}"):
            example_group = data_h5[f"{image_id}"]
            if use_refine and "mask_refine" in example_group:
                mask_group = example_group["mask_refine"]
            elif "mask" in example_group:
                mask_group = example_group["mask"]
            else:
                # The original built this exception without raising it, which
                # left mask_group undefined or stale from the previous image.
                raise ValueError(f"{image_id} has not mask or refine mask")
            for caption in mask_group:
                # The check deliberately targets "mask/<caption>" even when
                # captions come from "mask_refine": that entry is what gets
                # validated here before a caption is indexed.
                if isinstance(data_h5[f"{image_id}/mask/{caption}"], h5py.Dataset):
                    data.append((image_id, caption, int(example_group.attrs["width"]),
                                 int(example_group.attrs["height"])))
    return data, h5_path


class NormalImageOperator:
    """Resizes a group of images together and converts them with ``v2.ToTensor``.

    Args:
        size: Value handed to ``v2.Resize`` for the default (no target size) path.
        shuffle: Stored on the instance; the conversion pipeline built here is
            the same for both values. Defaults to False.
    """

    def __init__(self, size, shuffle=False):
        self.size = size
        self.shuffle = shuffle
        # Default resize used when the caller gives no explicit target size.
        self.transform_resize = v2.Compose([v2.Resize(self.size)])
        # Final conversion step applied to every value after resizing.
        self.transform = v2.Compose([v2.ToTensor()])

    def __call__(self, data, target_size=None):
        """Applies the resize and conversion steps to every value in ``data``.

        Args:
            data: Dictionary of images (e.g. image and mask) transformed together.
                The first value's ``.size`` drives the scale computation when
                ``target_size`` is given.
            target_size: Optional (width, height) tuple. When given, values are
                scaled by the larger of the two width/height ratios and then
                center-cropped to exactly this size. When None, the resize
                configured at construction is used.

        Returns:
            Dictionary with the same keys as ``data`` holding the transformed values.
        """
        names, values = zip(*data.items())
        if target_size is not None:
            src_w, src_h = values[0].size
            dst_w, dst_h = target_size
            # Scale by the larger ratio so the result covers the target box.
            scale = max(dst_w / src_w, dst_h / src_h)
            scaled_hw = (int(src_h * scale + 0.5), int(src_w * scale + 0.5))
            values = v2.Resize(scaled_hw)(values)
            values = v2.CenterCrop((dst_h, dst_w))(values)
        else:
            values = self.transform_resize(values)
        values = self.transform(values)
        return dict(zip(names, values))

class BucketDataset(Dataset):
    """Dataset that groups items into aspect-ratio buckets and yields whole batches.

    Every item is assigned to the bucket whose width/height ratio is closest to
    its own, and batches are built from items of a single bucket so that all
    samples in a batch are resized/cropped to the same target size. Indexing
    the dataset returns one batch (a list of samples), not a single sample.

    Args:
        datasets: A dataset name, or a list whose entries are either a dataset
            name (treated as an image-folder dataset) or a dictionary with
            'type' ('image' or 'h5') and 'name' (key of ``dataset_map``).
        max_batchsize: Maximum number of items per batch. Defaults to 2.
        shuffle: If True, shuffles items within buckets and the batch order.
            Defaults to False.
        dataset_type: Dataset split to load (e.g., 'train', 'val').
            Defaults to "train".
        strict_filter: If True, a description replaces the prompt only when it
            contains the mask's key. Defaults to False.
        except_sentence: If True, drops the 'sentence' entry from mask
            descriptions before picking a prompt. Defaults to True.
        unfold: If True (and not training), emits every mask of an image.
            Defaults to False.
        only_word: If True, the mask's key is used as the prompt. Defaults to True.
        max_num_data: Upper bound on the reported dataset length (in batches).
            Defaults to 8000000.
    """

    def __init__(self, datasets, max_batchsize=2, shuffle=False, dataset_type="train", strict_filter=False,
                 except_sentence=True, unfold=False, only_word=True, max_num_data=8000000):

        base_size = 512
        patch_size = 32
        # Square bucket plus progressively wider and taller ones, as (width, height).
        buckets = [(base_size, base_size)]
        buckets.extend((base_size + patch_size * r, base_size) for r in range(1, 20))
        buckets.extend((base_size, base_size + patch_size * r) for r in range(1, 20))
        buckets = sorted(set(buckets))

        self.data = []
        self.h5_file = None
        self.h5_path = None
        if isinstance(datasets, str):
            datasets = [datasets]
        for item in datasets:
            if isinstance(item, str):
                # A bare name is an image-folder dataset; indexing the string
                # with ["name"] (as the code previously did) raises TypeError.
                item = {"type": "image", "name": item}
            if item["type"] == "image":
                data_path = dataset_map[item["name"]]
                self.data.extend(read_data(data_path, dataset_type))
            elif item["type"] == "h5":
                assert self.h5_path is None, "only one h5 dataset is supported"
                data_path = dataset_map[item["name"]]
                h5_data, self.h5_path = read_data_h5(data_path, dataset_type)
                self.data.extend(h5_data)
            else:
                raise ValueError(f"dataset_type {item['type']} not supported")
        self.buckets = buckets
        self.asp_buckets = np.array([w / h for w, h in buckets])
        self.bid2row = {}
        self.item2bid = []
        self.max_batchsize = max_batchsize
        self.shuffle = shuffle
        self.batch_indices = []
        self.set_bucket()
        self.build_batch_indices()
        self.lambda_func = NormalImageOperator(base_size, shuffle)
        self.dataset_type = dataset_type
        self.epoch = 1
        self.strict_filter = strict_filter
        self.except_sentence = except_sentence
        self.unfold = unfold
        self.only_word = only_word
        self.max_num_data = max_num_data

    def process(self, idx):
        """Loads one data item and turns it into model-input samples.

        Args:
            idx: Index into ``self.data``.

        Returns:
            List of dictionaries, one per selected mask, each holding 'prompt'
            plus the transformed 'image' and 'mask' produced by
            ``self.lambda_func`` for the item's bucket size.

        Raises:
            ValueError: If the stored item is neither a dict nor a tuple.
        """
        data = self.data[idx]
        bid = self.item2bid[idx]
        w, h = self.buckets[bid]
        if isinstance(data, dict):
            # convert() returns a loaded copy, so the file can be closed at once.
            with Image.open(data["image"]) as src:
                image = src.convert("RGB")
            mask_dict = torch.load(data["mask"], weights_only=False)
        elif isinstance(data, tuple):
            image_id = data[0]
            # Opened lazily so that each process using the dataset gets its own handle.
            if self.h5_file is None:
                self.h5_file = h5py.File(self.h5_path, 'r')
            image_group = self.h5_file[image_id]
            mode = image_group["image/image"].attrs.get("mode", "BGR")
            image = np.array(image_group["image/image"][:])
            if mode == "BGR":
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image.copy())
            mask_dict = {}
            for k, mask in image_group["mask"].items():
                try:
                    mask = np.array(mask[:]).copy()
                except TypeError:
                    # Entry cannot be sliced like an array; skip it.
                    continue
                if not np.any(mask > 128):
                    continue
                mask_dict[k] = mask
        else:
            raise ValueError(f"item type {type(data)} not supported")

        if self.dataset_type == "train":
            # Training: one random mask per image.
            mask_cls, mask = random.choice(list(mask_dict.items()))
            mask_cls = [mask_cls]
            mask = [mask]
        elif self.unfold:
            mask_cls = list(mask_dict.keys())
            mask = list(mask_dict.values())
        else:
            # Evaluation without unfolding: always the first mask.
            mask_cls, mask = list(mask_dict.items())[0]
            mask_cls = [mask_cls]
            mask = [mask]

        output_list = []
        for mask_cls_per, mask_per in zip(mask_cls, mask):

            if isinstance(mask_per, dict):
                # Dict-style entry: the array sits under "mask", the remaining
                # keys are alternative descriptions of the same region.
                mask_description = mask_per
                mask_per = mask_description.pop("mask")
                if not self.only_word:
                    if self.except_sentence and "sentence" in mask_description:
                        mask_description.pop("sentence")
                    if self.dataset_type == "train":
                        mask_type = random.choice(list(mask_description.keys()))
                    else:
                        # Deterministic pick so evaluation is repeatable.
                        prompts = list(mask_description.keys())
                        pick_idx = idx % len(prompts)
                        mask_type = prompts[pick_idx]
                    if self.strict_filter:
                        mask_candidate = mask_description[mask_type]
                        if mask_cls_per in mask_candidate:
                            mask_cls_per = mask_candidate
                    else:
                        mask_cls_per = mask_description[mask_type]
            assert np.any(mask_per > 128)
            item = {"image": image.copy(), "mask": Image.fromarray(mask_per)}
            data_images = self.lambda_func(item, (w, h))
            output = {"prompt": mask_cls_per}
            output.update(data_images)
            output_list.append(output)
        return output_list

    def set_bucket(self):
        """Assigns each data item to the bucket with the closest aspect ratio.

        Rebuilds both ``item2bid`` (item index -> bucket id) and ``bid2row``
        (bucket id -> item indices) from scratch, so repeated calls do not
        accumulate duplicate rows.
        """
        self.item2bid = []
        self.bid2row = {}
        for idx, item in enumerate(self.data):
            if isinstance(item, dict):
                # Only the header is needed for the size; close the file right away.
                with Image.open(item["image"]) as image:
                    width, height = image.size
            elif isinstance(item, tuple):
                width, height = item[1], item[2]
            else:
                raise ValueError(f"item type {type(item)} not supported")
            ratio = width / height
            bid = int(np.argmin(np.abs(self.asp_buckets - ratio)))
            self.item2bid.append(bid)
            self.bid2row.setdefault(bid, []).append(idx)

    def build_batch_indices(self):
        """Splits every bucket into batches of at most ``max_batchsize`` items.

        Buckets with an aspect ratio above 3 or below 0.33 use half the batch
        size (rounded half up). Existing batches are kept when shuffling is
        off or when there is exactly one batch.
        """
        print("build_batch...")
        if self.batch_indices and (not self.shuffle or len(self.batch_indices) == 1):
            return
        self.batch_indices = []
        for bid, bucket in self.bid2row.items():
            if self.shuffle:
                random.shuffle(bucket)
            max_batchsize = self.max_batchsize
            if self.asp_buckets[bid] > 3 or self.asp_buckets[bid] < 0.33:
                max_batchsize = int(self.max_batchsize / 2 + 0.5)
            for start_idx in range(0, len(bucket), max_batchsize):
                # Slicing clamps at the end of the list, so the last batch may be short.
                self.batch_indices.append(bucket[start_idx:start_idx + max_batchsize])
        if self.shuffle:
            random.shuffle(self.batch_indices)

    def __getitem__(self, bid):
        """Returns the processed samples of one batch.

        Args:
            bid: Index into ``self.batch_indices`` (a batch index). In training,
                when ``max_num_data`` is smaller than the number of batches, the
                index is replaced by a random one so all batches can be drawn.

        Returns:
            List of sample dictionaries for every item of the batch.
        """
        assert self.batch_indices
        if self.dataset_type == "train" and self.max_num_data < len(self.batch_indices):
            bid = random.randint(0, len(self.batch_indices) - 1)
        data = []
        for idx in self.batch_indices[bid]:
            data.extend(self.process(idx))
        return data

    def __len__(self):
        """Returns the number of batches, capped at ``max_num_data``."""
        return min(len(self.batch_indices), self.max_num_data)

    def __del__(self):
        """Closes the lazily opened H5 file, if any, when the dataset is destroyed."""
        print(f"Closing file in process {os.getpid()}")
        # getattr: __del__ also runs on instances whose __init__ never got far
        # enough to create the attribute.
        h5_file = getattr(self, "h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.h5_file = None

class MaskBucketDataset(BucketDataset):
    """Bucketed dataset whose items are individual (image, caption) mask entries.

    Unlike the base class, every item is one caption of one image read from an
    H5 file, and each sample also carries a 'mask_refine' entry. The refined
    mask is read from the file when ``use_refine`` is set and a non-empty one
    exists; otherwise it is a copy of the original mask.

    Args:
        datasets: A list of dataset configuration dictionaries with 'type' and
            'name'. Only 'h5' entries are loaded; plain names and 'image'
            entries are skipped.
        max_batchsize: Maximum number of items per batch. Defaults to 2.
        shuffle: If True, shuffles the data during batch construction. Defaults to False.
        dataset_type: Dataset split to load (e.g., 'train', 'val').
            Defaults to "train".
        strict_filter: Stored on the instance. Defaults to False.
        except_sentence: Stored on the instance. Defaults to True.
        unfold: Stored on the instance. Defaults to False.
        only_word: Stored on the instance. Defaults to True.
        use_refine: If True, loads refined masks when available. Defaults to False.
        max_num_data: Upper bound on the reported dataset length (in batches).
            Defaults to 8000000.
    """

    def __init__(self,
                 datasets, max_batchsize=2, shuffle=False, dataset_type="train",
                 strict_filter=False, except_sentence=True, unfold=False, only_word=True,
                 use_refine=False,
                 max_num_data=8000000
                 ):

        base_size = 512
        patch_size = 32
        # Square bucket plus progressively wider and taller ones, as (width, height).
        buckets = [(base_size, base_size)]
        buckets.extend((base_size + patch_size * r, base_size) for r in range(1, 60))
        buckets.extend((base_size, base_size + patch_size * r) for r in range(1, 40))
        buckets = sorted(set(buckets))

        self.data = []
        self.h5_file = None
        self.h5_path = None
        if isinstance(datasets, str):
            datasets = [datasets]
        for item in datasets:
            if isinstance(item, str) or item["type"] == "image":
                # Image-folder datasets are not supported by this class.
                continue
            elif item["type"] == "h5":
                assert self.h5_path is None, "only one h5 dataset is supported"
                data_path = dataset_map[item["name"]]
                h5_data, self.h5_path = read_data_h5_mask(data_path, dataset_type, use_refine)
                self.data.extend(h5_data)
            else:
                raise ValueError(f"dataset_type {item['type']} not supported")
        self.buckets = buckets
        self.asp_buckets = np.array([w / h for w, h in buckets])
        self.bid2row = {}
        self.item2bid = []
        self.max_batchsize = max_batchsize
        self.shuffle = shuffle
        self.batch_indices = []
        self.set_bucket()
        self.build_batch_indices()
        self.lambda_func = NormalImageOperator(base_size, shuffle)
        self.dataset_type = dataset_type
        self.epoch = 1
        self.strict_filter = strict_filter
        self.except_sentence = except_sentence
        self.unfold = unfold
        self.only_word = only_word
        self.use_refine = use_refine
        self.max_num_data = max_num_data

    def process(self, idx):
        """Loads one (image, caption) entry and turns it into a model-input sample.

        Args:
            idx: Index into ``self.data``.

        Returns:
            Single-element list with a dictionary holding 'prompt' (the
            caption) plus the transformed 'image', 'mask' and 'mask_refine'.

        Raises:
            NotImplementedError: If the stored item is a dict.
            ValueError: If the stored item is neither a dict nor a tuple.
        """
        data = self.data[idx]
        bid = self.item2bid[idx]
        w, h = self.buckets[bid]
        if isinstance(data, dict):
            raise NotImplementedError
        elif isinstance(data, tuple):
            image_id, mask_cls = data[:2]
            # Opened lazily so that each process using the dataset gets its own handle.
            if self.h5_file is None:
                self.h5_file = h5py.File(self.h5_path, 'r')
            image_group = self.h5_file[image_id]
            mode = image_group["image/image"].attrs.get("mode", "BGR")
            image = np.array(image_group["image/image"][:])
            if mode == "BGR":
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image.copy())

            mask_refine = None
            # The refined entry may be missing for this image or caption; guard
            # the lookup instead of failing with KeyError.
            if (self.use_refine and "mask_refine" in image_group
                    and mask_cls in image_group["mask_refine"]):
                mask_refine = np.array(image_group["mask_refine"][mask_cls][:]).copy()
                if np.all(mask_refine < 128):
                    # An empty refined mask is treated as missing.
                    mask_refine = None

            mask = np.array(image_group[f"mask/{mask_cls}"][:]).copy()
            if mask_refine is None:
                mask_refine = mask.copy()
        else:
            raise ValueError(f"item type {type(data)} not supported")

        assert np.any(mask > 128), f"{image_id}/{mask_cls}, mask_per有问题: {mask.mean()}"
        # "image" must stay first: the transform reads the size of the first value.
        item = {
            "image": image.copy(),
            "mask": Image.fromarray(mask),
            "mask_refine": Image.fromarray(mask_refine),
        }
        data_images = self.lambda_func(item, (w, h))
        output = {"prompt": mask_cls}
        output.update(data_images)
        return [output]

    def set_bucket(self):
        """Assigns each item to the bucket with the closest aspect ratio.

        Tuple items carry width and height as their last two fields. Both
        ``item2bid`` and ``bid2row`` are rebuilt from scratch, so repeated
        calls do not accumulate duplicate rows.
        """
        self.item2bid = []
        self.bid2row = {}
        for idx, item in enumerate(self.data):
            if isinstance(item, dict):
                # Only the header is needed for the size; close the file right away.
                with Image.open(item["image"]) as image:
                    width, height = image.size
            elif isinstance(item, tuple):
                width, height = item[-2], item[-1]
            else:
                raise ValueError(f"item type {type(item)} not supported")
            ratio = width / height
            bid = int(np.argmin(np.abs(self.asp_buckets - ratio)))
            self.item2bid.append(bid)
            self.bid2row.setdefault(bid, []).append(idx)


class TestSegDataset(Dataset):
    """Dataset for segmentation testing that yields one image's samples per index.

    No bucketing is used: every item is resized with the operator's default
    resize and returned on its own.

    Args:
        datasets: A dataset name, or a list whose entries are either a dataset
            name (treated as an image-folder dataset) or a dictionary with
            'type' ('image' or 'h5') and 'name' (key of ``dataset_map``).
        dataset_type: Dataset split to load (e.g., 'test', 'val').
            Defaults to "train".
        unfold: If True, emits every mask of an image. Defaults to False.
        strict_filter: If True, a description replaces the prompt only when it
            contains the mask's key. Defaults to False.
        except_sentence: If True, drops the 'sentence' entry from mask
            descriptions before picking a prompt. Defaults to True.
        only_word: If True, the mask's key is used as the prompt. Defaults to True.
    """

    def __init__(self, datasets, dataset_type="train", unfold=False,
                 strict_filter=False, except_sentence=True, only_word=True):

        base_size = 512

        self.data = []
        self.h5_file = None
        self.h5_path = None
        # process() reads these attributes for dict-style masks; they were
        # previously never set, which raised AttributeError.
        self.dataset_type = dataset_type
        self.unfold = unfold
        self.strict_filter = strict_filter
        self.except_sentence = except_sentence
        self.only_word = only_word
        if isinstance(datasets, str):
            datasets = [datasets]
        for item in datasets:
            if isinstance(item, str):
                # A bare name is an image-folder dataset; indexing the string
                # with ["name"] (as the code previously did) raises TypeError.
                item = {"type": "image", "name": item}
            if item["type"] == "image":
                self.data.extend(read_data(dataset_map[item["name"]], dataset_type))
            elif item["type"] == "h5":
                assert self.h5_path is None, "only one h5 dataset is supported"
                data_path = dataset_map[item["name"]]
                h5_data, self.h5_path = read_data_h5(data_path, dataset_type)
                self.data.extend(h5_data)
            else:
                raise ValueError(f"dataset_type {item['type']} not supported")
        self.lambda_func = NormalImageOperator(base_size, False)

    def process(self, idx):
        """Loads one test item and turns it into model-input samples.

        Args:
            idx: Index into ``self.data``.

        Returns:
            List of dictionaries, one per selected mask, each holding 'prompt'
            plus the transformed 'image' and 'mask'.

        Raises:
            ValueError: If the stored item is neither a dict nor a tuple.
        """
        data = self.data[idx]
        if isinstance(data, dict):
            # convert() returns a loaded copy, so the file can be closed at once.
            with Image.open(data["image"]) as src:
                image = src.convert("RGB")
            mask_dict = torch.load(data["mask"], weights_only=False)
        elif isinstance(data, tuple):
            image_id = data[0]
            # Opened lazily so that each process using the dataset gets its own handle.
            if self.h5_file is None:
                self.h5_file = h5py.File(self.h5_path, 'r')
            image_group = self.h5_file[image_id]
            mode = image_group["image/image"].attrs.get("mode", "BGR")
            image = np.array(image_group["image/image"][:])
            if mode == "BGR":
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image.copy())
            mask_dict = {}
            for k, mask in image_group["mask"].items():
                try:
                    mask = np.array(mask[:]).copy()
                except TypeError:
                    # Entry cannot be sliced like an array; skip it.
                    continue
                if not np.any(mask > 128):
                    continue
                mask_dict[k] = mask
        else:
            raise ValueError(f"item type {type(data)} not supported")

        if self.unfold:
            mask_cls = list(mask_dict.keys())
            mask = list(mask_dict.values())
        else:
            # Without unfolding only the first mask is evaluated.
            mask_cls, mask = list(mask_dict.items())[0]
            mask_cls = [mask_cls]
            mask = [mask]

        output_list = []
        for mask_cls_per, mask_per in zip(mask_cls, mask):

            if isinstance(mask_per, dict):
                # Dict-style entry: the array sits under "mask", the remaining
                # keys are alternative descriptions of the same region.
                mask_description = mask_per
                mask_per = mask_description.pop("mask")
                if not self.only_word:
                    if self.except_sentence and "sentence" in mask_description:
                        mask_description.pop("sentence")
                    if self.dataset_type == "train":
                        mask_type = random.choice(list(mask_description.keys()))
                    else:
                        # Deterministic pick so evaluation is repeatable.
                        prompts = list(mask_description.keys())
                        pick_idx = idx % len(prompts)
                        mask_type = prompts[pick_idx]
                    if self.strict_filter:
                        mask_candidate = mask_description[mask_type]
                        if mask_cls_per in mask_candidate:
                            mask_cls_per = mask_candidate
                    else:
                        mask_cls_per = mask_description[mask_type]
            assert np.any(mask_per > 128)
            item = {"image": image.copy(), "mask": Image.fromarray(mask_per)}
            data_images = self.lambda_func(item)
            output = {"prompt": mask_cls_per}
            output.update(data_images)
            output_list.append(output)
        return output_list

    def __getitem__(self, idx):
        """Returns the processed samples for the item at ``idx``.

        Args:
            idx: Index of the test item to retrieve.

        Returns:
            List of sample dictionaries produced by ``process``.
        """
        return self.process(idx)

    def __len__(self):
        """Returns the total number of test data items."""
        return len(self.data)

    def __del__(self):
        """Closes the lazily opened H5 file, if any, when the dataset is destroyed."""
        # getattr: __del__ also runs on instances whose __init__ never got far
        # enough to create the attribute.
        h5_file = getattr(self, "h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.h5_file = None
