from torch.utils.data import Dataset
from torchvision import transforms
import json
import os
from PIL import Image
import random
from torchvision.transforms.functional import crop
import torch
class LoraDataset(Dataset):
    """Image/caption dataset driven by a JSON index file.

    The index at ``file_path`` is a JSON list of objects, each providing a
    ``"file_name"`` (joined onto ``root_dir``) and a ``"prompt"``. Each sample
    is passed through ``transforms.Resize(size)``, cropped to a
    ``size`` x ``size`` square (center or random), optionally flipped
    horizontally, then converted to a tensor normalized with mean 0.5 / std 0.5.

    Args:
        root_dir: Directory that the index's ``file_name`` entries are joined onto.
        file_path: Path to the JSON index file.
        size: Crop side length in pixels (also passed to ``Resize``).
        center_crop: Use a center crop instead of a random crop.
        random_flip: Horizontally flip each sample with probability 0.5.
        num_samples: If positive, keep only the first ``num_samples`` entries.
        repeat: How many times the (possibly truncated) file list is repeated
            in ``__len__`` / indexing.
    """

    def __init__(
        self,
        root_dir,
        file_path,
        size=1024,
        center_crop=False,
        random_flip=False,
        num_samples=-1,
        repeat: int = 1,
    ):
        self.size = size
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.root_dir = root_dir
        self.repeat = repeat

        # Read the index inside a context manager so the handle is closed, and
        # decode as UTF-8 (the JSON standard encoding) instead of the locale
        # default, since prompts may contain non-ASCII text.
        with open(file_path, "r", encoding="utf-8") as fh:
            records = json.load(fh)
        # (image path, prompt) pairs.
        self.files = [
            (os.path.join(root_dir, item["file_name"]), item["prompt"])
            for item in records
        ]
        if num_samples > 0:
            self.files = self.files[:num_samples]

        self.resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        # p=1.0: the 50% coin flip happens in __getitem__ so that the recorded
        # crop offset can be updated together with the flip.
        self.flip = transforms.RandomHorizontalFlip(p=1.0)
        self.transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of files times ``repeat``."""
        return len(self.files) * self.repeat

    def __getitem__(self, index):
        """Load, augment and return sample ``index`` (wrapped modulo the file count).

        Returns:
            A dict with ``file_name``, ``prompt``, ``image`` (the normalized
            tensor), ``original_size`` (``[height, width]`` before resizing, as a
            tensor) and ``crop_top_left`` (``[top, left]`` of the crop in the
            resized image, as a tensor).
        """
        file_name, prompt = self.files[index % len(self.files)]
        # Close the underlying file handle; convert() returns a separate image
        # object, so it stays usable after the source is closed.
        with Image.open(file_name) as raw:
            image = raw.convert("RGB")
        original_size = [image.height, image.width]

        image = self.resize(image)
        resized_width = image.width
        if self.center_crop:
            # Top-left offset of a centered size x size crop, recorded for the caller.
            y1 = max(0, int(round((image.height - self.size) / 2.0)))
            x1 = max(0, int(round((image.width - self.size) / 2.0)))
            image = self.crop(image)
        else:
            y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))
            image = crop(image, y1, x1, h, w)

        if self.random_flip and random.random() < 0.5:
            # Mirror the crop offset: in the horizontally flipped resized image
            # the crop's left edge sits at resized_width - x1 - crop_width.
            # (Using the cropped image's width here would not give the mirrored
            # offset and can even go negative for wide images.)
            x1 = resized_width - x1 - image.width
            image = self.flip(image)

        crop_top_left = [y1, x1]
        image = self.transforms(image)
        return {
            "file_name": file_name,
            "original_size": torch.tensor(original_size),
            "crop_top_left": torch.tensor(crop_top_left),
            "prompt": prompt,
            "image": image,
        }
