from torch.utils.data import Dataset
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from pathlib import Path
from torchvision import transforms
from datasets import load_dataset, IterableDatasetDict
import requests
from io import BytesIO
import json

def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` into a fixed-length batch of PyTorch tensors.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        prompt: Text (or list of texts) to tokenize.
        tokenizer_max_length: Optional override for the padded/truncated length;
            ``tokenizer.model_max_length`` is used when it is ``None``.

    Returns:
        Whatever the tokenizer returns for ``return_tensors="pt"`` with
        truncation and ``"max_length"`` padding.
    """
    target_length = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=target_length,
        return_tensors="pt",
    )

def collate_fn(examples, with_prior_preservation=False):
    """Merge dataset examples into one training batch.

    Instance entries come first; with ``with_prior_preservation`` the class
    entries of the same examples are appended after them, so a single forward
    pass covers both.

    Args:
        examples: Sequence of example dicts produced by the dataset.
        with_prior_preservation: Also collect the ``class_*`` entries.

    Returns:
        Dict with ``input_ids`` (prompt ids concatenated along dim 0),
        ``pixel_values`` (stacked, contiguous, float) and, when the first example
        carries ``instance_attention_mask``, ``attention_mask`` (concatenated
        along dim 0).
    """
    uses_mask = "instance_attention_mask" in examples[0]

    key_groups = [("instance_prompt_ids", "instance_images", "instance_attention_mask")]
    if with_prior_preservation:
        key_groups.append(("class_prompt_ids", "class_images", "class_attention_mask"))

    ids_parts, image_parts, mask_parts = [], [], []
    for ids_key, images_key, mask_key in key_groups:
        ids_parts.extend(example[ids_key] for example in examples)
        image_parts.extend(example[images_key] for example in examples)
        if uses_mask:
            mask_parts.extend(example[mask_key] for example in examples)

    stacked_pixels = torch.stack(image_parts).to(memory_format=torch.contiguous_format).float()
    batch = {
        "input_ids": torch.cat(ids_parts, dim=0),
        "pixel_values": stacked_pixels,
    }
    if uses_mask:
        batch["attention_mask"] = torch.cat(mask_parts, dim=0)
    return batch

class BatDreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Besides the instance images (and optional class images for prior preservation),
    the dataset can interleave extra "bat" image/prompt pairs with the instance images.
    ``bat_data_root`` selects the bat source:

    * ``str`` not ending in ``.json``: a directory whose files, sorted by name,
      alternate image file / prompt text file;
    * ``datasets.IterableDatasetDict``: records of its ``"train"`` split with a
      ``"URL"`` (downloaded) and a ``"TEXT"`` field;
    * ``str`` ending in ``.json``: a metadata file mapping keys to records with a
      ``"filename"`` (relative to ``/laion_dataset_img/``) and a ``"text"`` field.

    Indices cycle through all instance images followed by all loaded bat images;
    class images are indexed independently by ``index % num_class_images``.
    """

    # Directory that ``filename`` entries of a JSON bat metadata file are relative to.
    _JSON_IMAGE_DIR = "/laion_dataset_img/"
    # Timeout in seconds for each bat image download from a Hugging Face dataset.
    _DOWNLOAD_TIMEOUT = 30
    # Returned by ``give_new_bat`` when no further bat batch can be provided.
    _NO_MORE_BATS = "No more bat images"

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        bat_data_root=None,
        bat_ratio=None,
        bat_data_root_size=1000,
        size=512,
        center_crop=True,  # for consistency
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        """
        Args:
            instance_data_root: Directory of instance images; must exist.
            instance_prompt: Prompt paired with every instance image.
            tokenizer: Tokenizer passed to ``tokenize_prompt``.
            class_data_root: Optional directory of class images (created if missing).
            class_prompt: Prompt paired with every class image.
            class_num: Optional cap on the number of class images used.
            bat_data_root: Optional bat source (see the class docstring).
            bat_ratio: Bats per instance image (``round(bat_ratio * num_instance_images)``);
                a single bat is used when ``None``.
            bat_data_root_size: Maximum number of records kept in the bat pool
                (Hugging Face and JSON sources).
            size: Resize / crop resolution.
            center_crop: Center-crop when True, random-crop otherwise.
            encoder_hidden_states: Precomputed value used as ``instance_prompt_ids``
                instead of tokenizing.
            class_prompt_encoder_hidden_states: Precomputed value used as
                ``class_prompt_ids`` instead of tokenizing.
            tokenizer_max_length: Optional override of ``tokenizer.model_max_length``.

        Raises:
            ValueError: If ``instance_data_root`` does not exist.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self.good_bat = False

        # Class images. ``num_class_images`` is always defined (0 without class data)
        # so the length computation below works for every configuration.
        self.num_class_images = 0
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.instance_images = [self._open_image(path) for path in self.instance_images_path]
        if self.class_data_root:
            # Only the first ``num_class_images`` are ever indexed in ``__getitem__``.
            self.class_images = [
                self._open_image(path) for path in self.class_images_path[: self.num_class_images]
            ]

        # Bat data: build the candidate pool, then load the first batch from it.
        self.bat_images = []
        self.bat_prompts = []
        if isinstance(bat_data_root, str) and not bat_data_root.endswith(".json"):
            self.bat_ratio = bat_ratio
            self.bat_data_root = Path(bat_data_root)
            self.bat_data_root.mkdir(parents=True, exist_ok=True)
            # NOTE: assumes each image file sorts immediately before its prompt text file.
            self.bat_images_pool = sorted(self.bat_data_root.iterdir(), key=lambda x: x.name)
            # Pool entries alternate image/prompt, so each bat consumes two entries.
            self.num_bat_images = self._bat_count(bat_ratio) * 2
            self.bat_images_path = self.bat_images_pool[: self.num_bat_images]
            self.bat_images, self.bat_prompts = self._load_dir_bats(self.bat_images_path)
        elif isinstance(bat_data_root, IterableDatasetDict):
            self.bat_ratio = bat_ratio
            self.bat_data_root = bat_data_root
            self.bat_data_root_size = bat_data_root_size
            self.bat_images_pool = self._take(self.bat_data_root["train"].shuffle(seed=42), bat_data_root_size)
            # Pool entries are whole records here, so each bat consumes one entry.
            self.num_bat_images = self._bat_count(bat_ratio)
            self.bat_images, self.bat_prompts = self._download_bats(self.bat_images_pool, self.num_bat_images)
        elif isinstance(bat_data_root, str) and bat_data_root.endswith(".json"):
            self.bat_ratio = bat_ratio
            self.bat_data_root = bat_data_root
            self.bat_data_path = Path(self.bat_data_root)
            with open(self.bat_data_path, encoding="utf-8") as metadata_file:
                self.bat_data_dict = json.load(metadata_file)
            self.bat_data_root_size = bat_data_root_size
            self.bat_images_pool = []
            for record in self._take(self.bat_data_dict.values(), bat_data_root_size):
                self.bat_images_pool.extend((record["filename"], record["text"]))
            # Pool entries alternate filename/text, so each bat consumes two entries.
            self.num_bat_images = self._bat_count(bat_ratio) * 2
            self.bat_images_path = self.bat_images_pool[: self.num_bat_images]
            self.bat_images, self.bat_prompts = self._load_json_bats(self.bat_images_path)
        else:
            self.bat_data_root = None

        # Cover a full instance+bat cycle of ``__getitem__`` (or all class images if
        # there are more); a shorter length would never visit some bat images.
        self._length = max(self.num_instance_images + self._num_active_bats(), self.num_class_images)

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _open_image(source):
        """Open an image path or file object, apply its EXIF orientation and close the source."""
        with Image.open(source) as image:
            return exif_transpose(image)

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` converted to RGB if it is in another mode."""
        return image if image.mode == "RGB" else image.convert("RGB")

    @staticmethod
    def _take(records, limit):
        """Collect items from ``records`` until ``limit`` have been gathered (checked after each append)."""
        taken = []
        for record in records:
            taken.append(record)
            if len(taken) == limit:
                break
        return taken

    def _bat_count(self, bat_ratio):
        """Bats per batch: ``round(bat_ratio * num_instance_images)``, or 1 when no ratio is given."""
        if bat_ratio is None:
            return 1
        return round(bat_ratio * self.num_instance_images)

    def _num_active_bats(self):
        """Number of bat image/prompt pairs currently usable by ``__getitem__``."""
        if self.bat_data_root is None:
            return 0
        return min(len(self.bat_images), len(self.bat_prompts))

    def _advance_bat_pool(self):
        """Drop the current batch from a file-based pool and select the next ``num_bat_images`` entries."""
        self.bat_images_pool = self.bat_images_pool[self.num_bat_images:]
        self.bat_images_path = self.bat_images_pool[: self.num_bat_images]

    def _load_dir_bats(self, entries):
        """Load images (even entries) and prompt text files (odd entries) of an interleaved path list."""
        images = [self._open_image(path) for path in entries[0::2]]
        prompts = [Path(path).read_text(encoding="utf-8").rstrip() for path in entries[1::2]]
        return images, prompts

    def _load_json_bats(self, entries):
        """Load images (even entries, relative filenames) and prompts (odd entries, text) of an interleaved list."""
        images = [self._open_image(self._JSON_IMAGE_DIR + name) for name in entries[0::2]]
        prompts = list(entries[1::2])
        return images, prompts

    def _download_bats(self, records, limit):
        """Download up to ``limit`` bats from ``"URL"``/``"TEXT"`` records, skipping ones that fail."""
        images, prompts = [], []
        for record in records:
            # Checked before downloading so ``limit == 0`` fetches nothing.
            if len(images) >= limit:
                break
            try:
                response = requests.get(record["URL"], timeout=self._DOWNLOAD_TIMEOUT)
                response.raise_for_status()
                image = self._open_image(BytesIO(response.content))
            except (requests.RequestException, OSError):
                # Dead links and undecodable payloads are skipped rather than aborting the load.
                continue
            images.append(image)
            prompts.append(record["TEXT"])
        return images, prompts

    # ------------------------------------------------------------ dataset API

    def __len__(self):
        """Number of examples per epoch (fixed at construction)."""
        return self._length

    def __getitem__(self, index):
        """
        Build one training example.

        Indices cycle through ``num_instance_images`` instance slots followed by one slot
        per loaded bat pair. Bat images are stored under the same ``instance_*`` keys as
        instance images. The example always has ``instance_images`` and
        ``instance_prompt_ids``; ``instance_attention_mask`` is added when the prompt is
        tokenized. With class data, ``class_images`` and ``class_prompt_ids`` (and
        ``class_attention_mask`` when tokenized) are added too.
        """
        example = {}

        position = index % (self.num_instance_images + self._num_active_bats())
        if position < self.num_instance_images:
            image = self.instance_images[index % self.num_instance_images]
            prompt = self.instance_prompt
        else:
            bat_index = position - self.num_instance_images
            image = self.bat_images[bat_index]
            prompt = self.bat_prompts[bat_index]

        example["instance_images"] = self.image_transforms(self._to_rgb(image))
        if self.encoder_hidden_states is not None:
            # NOTE(review): bat examples also receive the precomputed *instance* prompt
            # embedding here, as in the original code; their own prompt is unused.
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(self.tokenizer, prompt, tokenizer_max_length=self.tokenizer_max_length)
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        # Bat or not, all examples share the same class images.
        if self.class_data_root:
            class_image = self.class_images[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._to_rgb(class_image))
            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example

    def give_new_bat(self, was_bad: bool) -> "str | list[str] | None":
        """
        React to feedback on the current bat batch.

        Args:
            was_bad: True to discard the current bats and load the next batch from the
                pool; False to mark the current batch as good and keep it.

        Returns:
            When ``was_bad`` is True: the prompts of the newly loaded bats, or
            ``"No more bat images"`` when no usable batch could be loaded (pool
            exhausted, no bat source, or, for directory/Hugging Face sources, a load
            failure). ``None`` when ``was_bad`` is False.

        Raises:
            Exception: For JSON sources, if two consecutive batches fail to load.
        """
        if not was_bad:
            self.good_bat = True
            return None

        self.good_bat = False
        if self.bat_data_root is None:
            return self._NO_MORE_BATS

        if isinstance(self.bat_data_root, Path):
            self._advance_bat_pool()
            try:
                self.bat_images, self.bat_prompts = self._load_dir_bats(self.bat_images_path)
            except Exception:
                # Best effort: any failure to read the batch ends the bat supply.
                self.bat_images, self.bat_prompts = [], []
                return self._NO_MORE_BATS
        elif isinstance(self.bat_data_root, str) and self.bat_data_root.endswith(".json"):
            self._advance_bat_pool()
            self.bat_images, self.bat_prompts = [], []
            try:
                self.bat_images, self.bat_prompts = self._load_json_bats(self.bat_images_path)
            except Exception:
                # Skip the unreadable batch and retry once with the next one; a second failure propagates.
                self._advance_bat_pool()
                self.bat_images, self.bat_prompts = self._load_json_bats(self.bat_images_path)
        elif isinstance(self.bat_data_root, IterableDatasetDict):
            self.bat_images_pool = self.bat_images_pool[self.num_bat_images:]
            try:
                self.bat_images, self.bat_prompts = self._download_bats(self.bat_images_pool, self.num_bat_images)
            except Exception:
                self.bat_images, self.bat_prompts = [], []
                return self._NO_MORE_BATS

        if self._num_active_bats() == 0:
            return self._NO_MORE_BATS
        return self.bat_prompts