
import random

import torch.utils.data as data
from PIL import Image
import os
import torch
# from tqdm import tqdm
class ImageSet(data.Dataset):
    """Map-style dataset of images with optional per-image captions.

    Images come from either a folder (every ``.png``/``.jpg`` file, sorted by
    name) or a list of already-loaded images. Captions come from either a
    folder holding one ``<image stem>.txt`` per image, or a list aligned
    index-by-index with the images.

    ``__getitem__`` returns ``{"image", "id"}``, plus ``"caption"`` when
    captions are available; ``transform``, if set, receives that whole dict.
    """

    def __init__(self, folder, transform=None, keep_in_mem=True, caption=None):
        """Index (and optionally pre-load) the images and captions.

        Args:
            folder: image folder path, or a list of in-memory images.
            transform: optional callable applied to each item dict.
            keep_in_mem: if True and ``folder`` is a path, read every image
                and caption file up front. If False, files are read on each
                access and ``self.images`` is None (a list passed as
                ``folder`` is always kept).
            caption: caption folder path, a list of per-image captions, or
                None / ``""`` / ``"None"`` for no captions.

        Raises:
            ValueError: if a caption folder is combined with in-memory images,
                since there are no file names to pair captions with.
        """
        self.path = folder
        self.transform = transform
        self.keep_in_mem = keep_in_mem
        self.caption_path = None
        self.images = []
        self.captions = []
        # Always defined (possibly empty) so every method can rely on them,
        # including when images are supplied as an in-memory list.
        self.image_files = []
        self.caption_files = []

        if isinstance(folder, list):
            self.images = folder
        else:
            self.image_files = sorted(
                file for file in os.listdir(folder) if file.endswith((".png", ".jpg"))
            )

        if isinstance(caption, list):
            self.caption_path = True
            self.captions = caption
        elif caption not in [None, "", "None"]:
            if isinstance(folder, list):
                raise ValueError(
                    "A caption folder needs an image folder to pair file names with."
                )
            self.caption_path = caption
            # Derived from the already-sorted image list and deliberately NOT
            # re-sorted: sorting the .txt names on their own can reorder them
            # relative to the images (e.g. "a.png" / "a.q.jpg").
            self.caption_files = [
                os.path.join(caption, os.path.splitext(file)[0] + ".txt")
                for file in self.image_files
            ]

        if keep_in_mem:
            if not self.images:
                self.images = [
                    self.load_image(os.path.join(self.path, file))
                    for file in self.image_files
                ]
            if not self.captions and self.caption_files:
                self.captions = [self.load_caption(file) for file in self.caption_files]
        elif not isinstance(folder, list):
            # Lazy mode: images are re-read from disk in __getitem__.
            self.images = None

    def limit_num(self, n):
        """Truncate the dataset to its first ``n`` items (currently disabled).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def __len__(self):
        """Return the number of items: in-memory count if loaded, else file count."""
        # Truthiness covers both "not loaded" (None) and "nothing loaded" ([]).
        if self.images:
            return len(self.images)
        return len(self.image_files)

    def load_image(self, path):
        """Open ``path`` with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        return img

    def load_caption(self, path):
        """Return the non-blank lines of a UTF-8 caption file, stripped."""
        with open(path, 'r', encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    def __getitem__(self, index):
        """Return the item dict for ``index``, passed through ``transform`` if set."""
        if self.images:
            img = self.images[index]
        else:
            img = self.load_image(os.path.join(self.path, self.image_files[index]))

        if self.captions:
            ret = {"image": img, "caption": self.captions[index], "id": index}
        elif self.caption_files:
            caption = self.load_caption(self.caption_files[index])
            ret = {"image": img, "caption": caption, "id": index}
        else:
            ret = {"image": img, "id": index}

        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def subsample(self, n: int = 10):
        """Keep ``n`` items taken at an equal stride, in place, and return self.

        Images, file names and captions are sliced together so they stay
        aligned. ``n`` of None or -1 leaves the dataset unchanged.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert 0 < n <= ori_len, f"n must be in [1, {ori_len}], got {n}"
        step = ori_len // n
        # getattr: subclasses may not define every one of these attributes,
        # and lazily-loaded datasets hold None in ``images``.
        for attr in ("image_files", "images", "captions", "caption_files"):
            seq = getattr(self, attr, None)
            if seq:
                setattr(self, attr, seq[::step][:n])
        print(f"Dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        """Replace the item transform and return self for chaining."""
        self.transform = transform
        return self

    @staticmethod
    def collate_fn(examples):
        """Batch item dicts into lists; pick one random caption per example.

        ``random.choice`` raises IndexError if an example's caption list is empty.
        """
        images = [example["image"] for example in examples]
        ids = [example["id"] for example in examples]
        if "caption" in examples[0]:
            captions = [random.choice(example["caption"]) for example in examples]
            return {"images": images, "captions": captions, "id": ids}
        else:
            return {"images": images, "id": ids}


class ImagePair(ImageSet):
    """Dataset of same-named image pairs drawn from two parallel folders.

    File names are taken from ``folder1`` (``.png``/``.jpg``, sorted); the
    file with the same name is read from ``folder2``. Each item is
    ``{"image1", "image2", "id"}``; ``transform`` is applied to each image
    separately.
    """

    def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
        """
        Args:
            folder1: folder whose image files define the dataset entries.
            folder2: folder expected to contain a file of the same name for
                every entry in ``folder1``.
            transform: optional callable applied to each image individually.
            keep_in_mem: if True, load every pair in the constructor;
                otherwise read both images on each access and leave
                ``self.images`` as None.
        """
        self.path1 = folder1
        self.path2 = folder2
        self.transform = transform
        self.image_files = sorted(
            file for file in os.listdir(folder1) if file.endswith((".png", ".jpg"))
        )
        self.keep_in_mem = keep_in_mem
        if keep_in_mem:
            self.images = [
                (
                    self.load_image(os.path.join(self.path1, file)),
                    self.load_image(os.path.join(self.path2, file)),
                )
                for file in self.image_files
            ]
        else:
            self.images = None

    def __len__(self):
        """Return the number of pairs.

        Counted from ``image_files`` (one entry per pair, always populated) so
        it also works when ``self.images`` is None (``keep_in_mem=False``).
        """
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``{"image1", "image2", "id"}`` for ``index``."""
        if self.keep_in_mem:
            img1, img2 = self.images[index]
        else:
            img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
            img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return {"image1": img1, "image2": img2, "id": index}

    @staticmethod
    def collate_fn(examples):
        """Batch item dicts into parallel lists (images are not stacked)."""
        images1 = [example["image1"] for example in examples]
        images2 = [example["image2"] for example in examples]
        ids = [example["id"] for example in examples]
        return {"image1": images1, "image2": images2, "id": ids}

    def push_to_huggingface(self, hug_folder):
        """Upload the pairs to the Hugging Face Hub as a private dataset.

        Columns: ``photo`` (files from ``folder1``), ``sketch`` (same names
        from ``folder2``) and ``file_name``; ``hug_folder`` is the repo id
        passed to ``push_to_hub``. Requires the ``datasets`` package.
        """
        from datasets import Dataset
        from datasets import Image as HugImage
        photo_path = [os.path.join(self.path1, file) for file in self.image_files]
        sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
        dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
        dataset = dataset.cast_column("photo", HugImage())
        dataset = dataset.cast_column("sketch", HugImage())
        dataset.push_to_hub(hug_folder, private=True)

class ImageClass(ImageSet):
    """Image-classification dataset with one folder per class.

    Every ``.png``/``.jpg`` in ``folders[i]`` gets integer label ``i``. Each
    item is ``{"image", "label", "id"}``; ``transform`` is applied to the
    image only.
    """

    def __init__(self, folders: list, transform=None, keep_in_mem=True):
        """
        Args:
            folders: list of folder paths; the position of a folder in the
                list is the label of every image in it.
            transform: optional callable applied to each image.
            keep_in_mem: if True, load every image in the constructor;
                otherwise read from disk on each access and leave
                ``self.images`` as None.
        """
        self.paths = folders
        self.transform = transform
        self.keep_in_mem = keep_in_mem
        # (image path, label) pairs. Each folder listing is sorted because
        # os.listdir order is arbitrary, which would make index -> file
        # mapping differ between machines/runs.
        self.image_files = []
        for label, folder in enumerate(folders):
            names = sorted(
                file for file in os.listdir(folder) if file.endswith((".png", ".jpg"))
            )
            self.image_files += [(os.path.join(folder, name), label) for name in names]

        if keep_in_mem:
            print("Loading images to memory")
            self.images = [(self.load_image(path), label) for path, label in self.image_files]
            print("Loading images to memory done")
        else:
            self.images = None

    def __len__(self):
        """Return the number of images.

        Counted from ``image_files`` (always populated) so it also works when
        ``self.images`` is None (``keep_in_mem=False``).
        """
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``{"image", "label", "id"}`` for ``index``."""
        if self.keep_in_mem:
            img, label = self.images[index]
        else:
            img_path, label = self.image_files[index]
            img = self.load_image(img_path)

        if self.transform is not None:
            img = self.transform(img)
        return {"image": img, "label": label, "id": index}

    @staticmethod
    def collate_fn(examples):
        """Batch item dicts into parallel lists of images, labels and ids."""
        images = [example["image"] for example in examples]
        labels = [example["label"] for example in examples]
        ids = [example["id"] for example in examples]
        return {"images": images, "labels":labels, "id": ids}
