import glob
import json
import os
from typing import Callable, Optional

import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import tv_tensors
from torchvision.transforms import v2
# Let Pillow decode image files whose data is cut short instead of raising
# an error while loading them. This is a process-wide Pillow setting.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class ExpertDataset(Dataset):
    """
    Dataset of images grouped into directories by caption.

    The images which have the same caption are grouped together in the same
    directory. The directory name is the hash of the caption.
    ```
    dirname = hashlib.sha1(caption.encode()).hexdigest()
    ```

    Every directory holds one `caption.txt` shared by all of its images. An
    image may additionally have an optional metadata file (JSON) stored next
    to it, named either `<image stem>.json` or `<image filename>.json`.

    Only `*.png` and `*.jpg` files located exactly one directory below `root`
    are indexed.

    Each item is a dict with the keys:
        path: The image path relative to `root`.
        prompt: The content of the directory's `caption.txt`.
        metadata: The parsed JSON metadata (only if a metadata file exists).
        image: The PIL image (only if `load_image` is enabled).

    Args:
        root: The root directory of the dataset.
        load_image: Whether to open the image and attach it to the item.
    """
    def __init__(
        self,
        root: str,
        load_image: Optional[bool] = True,
    ):
        self.root = root
        self.load_image = load_image

        # NOTE: `recursive=True` is not passed, so `**` behaves like `*` and
        # only files exactly one directory below `root` are matched.
        self.paths = []
        self.paths.extend(glob.glob(os.path.join(root, "**/*.png")))
        self.paths.extend(glob.glob(os.path.join(root, "**/*.jpg")))
        # Sort so that indices are stable across runs and file systems.
        self.paths = sorted(self.paths)

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _find_metadata_path(image_path: str) -> Optional[str]:
        """Return the metadata file that accompanies `image_path`, if any."""
        candidates = (
            os.path.splitext(image_path)[0] + ".json",  # e.g. 0001.json
            image_path + ".json",                       # e.g. 0001.png.json
        )
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate
        return None

    def __getitem__(self, idx):
        item = dict()

        # Attach the path of the image.
        image_path = self.paths[idx]
        item["path"] = os.path.relpath(image_path, self.root)

        # Load the prompt. UTF-8 matches the `caption.encode()` used to derive
        # the directory name.
        caption_path = os.path.join(os.path.dirname(image_path), "caption.txt")
        with open(caption_path, encoding="utf-8") as f:
            item["prompt"] = f.read()

        # Load the metadata if it exists. It is looked up next to the image,
        # never relative to the current working directory.
        metadata_path = self._find_metadata_path(image_path)
        if metadata_path is not None:
            with open(metadata_path, encoding="utf-8") as f:
                item["metadata"] = json.load(f)

        # Load the image if load_image is enabled
        if self.load_image:
            image = Image.open(image_path)
            # For Aestheticv2, some images are corrupted. We will replace them
            # with a blank image to avoid exceptions.
            if image.size[0] < 10 or image.size[1] < 10:
                image = Image.new("RGB", (512, 512))
            item["image"] = image

        return item


class TrainingDataset(ExpertDataset):
    """
    Dataset yielding resized, cropped and normalized image tensors.

    Support random horizontal flip and random crop. Each image is converted
    to RGB, scaled to float32, resized with `v2.Resize(resolution)`, optionally
    flipped, normalized with mean 0.5 and std 0.5, and finally cropped.

    In addition to the keys produced by `ExpertDataset`, each item carries
    `add_time_ids`, a tensor laid out as
    `(original_h, original_w, crop_top, crop_left, resolution, resolution)`.
    NOTE(review): this looks like the SDXL-style size/crop conditioning;
    confirm against the training code that consumes it.

    Args:
        root: The root directory of the dataset.
        resolution: The size passed to `v2.Resize` and the side length of the
            square crop. Must not be None.
        random_flip: Whether to apply random horizontal flip.
        random_crop: If True, a random `resolution` x `resolution` crop is
            taken. If False, the center crop is taken. Any other value (e.g.
            None) disables cropping.

    Raises:
        ValueError: If `resolution` is None.
    """
    def __init__(
        self,
        root: str,
        resolution: Optional[int] = None,
        random_flip: Optional[bool] = False,
        random_crop: Optional[bool] = False,
    ):
        # Fail fast with a clear message: the resize and crop steps below
        # cannot work without a target resolution.
        if resolution is None:
            raise ValueError("`resolution` must be specified for TrainingDataset.")
        super().__init__(root)
        self.random_flip = random_flip
        self.random_crop = random_crop
        self.resolution = resolution

        steps = [
            v2.RGB(),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(resolution),
        ]
        # Append the flip only when requested instead of inserting a lambda
        # no-op: lambdas cannot be pickled, which breaks DataLoader workers
        # that are started with the spawn method.
        if random_flip:
            steps.append(v2.RandomHorizontalFlip())
        steps.append(v2.Normalize([0.5], [0.5]))
        self.transform = v2.Compose(steps)

    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        image = item['image']
        # Record the size before any resizing.
        original_h = image.height
        original_w = image.width
        image = self.transform(image)
        _, resized_h, resized_w = image.shape

        # Apply random crop or center crop. The identity checks are
        # deliberate: only the exact values True/False select a crop mode.
        if self.random_crop is False:
            y1 = max(0, int(round((resized_h - self.resolution) / 2.0)))
            x1 = max(0, int(round((resized_w - self.resolution) / 2.0)))
            h, w = self.resolution, self.resolution
        elif self.random_crop is True:
            y1, x1, h, w = v2.RandomCrop.get_params(
                image, output_size=(self.resolution, self.resolution))
        else:
            # Cropping disabled: keep the whole resized image.
            y1, x1, h, w = 0, 0, resized_h, resized_w
        item['image'] = v2.functional.crop(image, y1, x1, h, w)
        # Layout: original size (h, w), crop top-left (y, x), target size (h, w).
        item["add_time_ids"] = torch.tensor(
            [original_h, original_w, y1, x1, self.resolution, self.resolution])

        return item


class ScoreDataset(ExpertDataset):
    """
    Dataset yielding transformed images that are checked to be RGB.

    Args:
        root: The root directory of the dataset.
        transform: The transformation applied to the loaded PIL image. The
            default converts the image to RGB, turns it into a float32 tensor
            scaled to [0, 1], resizes it with `v2.Resize(512)` and center
            crops it to 512 x 512. If None, the PIL image is returned as is.
    """
    def __init__(
        self,
        root: str,
        transform: Optional[Callable[[Image.Image], tv_tensors.Image]] = v2.Compose([
            # Convert to RGB while the image is still a PIL image so that
            # palette and RGBA files are converted by PIL; on a tensor,
            # `v2.RGB` only expands images with fewer than 3 channels.
            v2.RGB(),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(512),
            v2.CenterCrop(512),
        ]),
    ):
        super().__init__(root)
        self.transform = transform

    def __getitem__(self, idx):
        item = super().__getitem__(idx)

        image = item['image']
        if self.transform is not None:
            image = self.transform(image)
            item['image'] = image

        # Without a transform the image is still a PIL image, which has no
        # `shape`; check its mode instead of the channel dimension.
        if isinstance(image, Image.Image):
            is_rgb = image.mode == "RGB"
        else:
            is_rgb = image.shape[0] == 3
        assert is_rgb, f"Image {os.path.join(self.root, item['path'])} is not RGB."

        return item


class PromptDataset(ExpertDataset):
    """
    Caption-only variant of `ExpertDataset`: image loading is disabled, so
    items are built without opening the image files.

    Args:
        root: The root directory of the dataset.
    """
    def __init__(self, root: str):
        super().__init__(root=root, load_image=False)
