from torch.utils.data import Dataset
from PIL import Image
import os
import json
import random
import torch
import numpy as np
from einops import rearrange
from xtuner.registry import BUILDER
from src.datasets.utils import crop2square
from glob import glob

from datasets import load_from_disk, interleave_datasets, concatenate_datasets
from itertools import permutations

class Text2ImageDatasetCustom(Dataset):
    """Text-to-image dataset mixing a synthetic attribute dataset with an auxiliary one.

    The primary dataset at ``data_path`` is loaded with ``datasets.load_from_disk``.
    It must provide the ``image``, ``color``, ``pattern``, ``position`` and ``shape``
    columns, plus the four ``synthetic_*`` string columns. Each row is expanded
    into up to ``num_permutations`` rows. Each new row's ``prompt`` is one ordering
    of the four ``synthetic_*`` strings, joined with spaces.

    The auxiliary dataset at ``aux_data_path`` is reduced to its ``image`` and
    ``prompt`` columns, optionally truncated, and mixed in. Every resulting sample
    is then eagerly converted into ``pixel_values`` / ``input_ids`` tensors.

    Args:
        data_path: Path of the primary dataset saved with ``save_to_disk``.
        aux_data_path: Path of the auxiliary dataset saved with ``save_to_disk``.
        image_size: Side length, in pixels, of the square output images.
        unconditional: Probability of replacing a prompt with the caption-free
            ``"Generate an image."`` instruction.
        tokenizer: Config passed to ``BUILDER.build`` to create the tokenizer.
        prompt_template: Mapping whose ``'INSTRUCTION'`` entry is a format string
            with an ``{input}`` placeholder.
        max_length: Maximum number of tokens kept per prompt.
        crop_image: If True, square images with the project's ``crop2square``
            helper. Otherwise stretch them to ``max(width, height)`` on both sides.
        shuffle_permutations: Shuffle the 24 attribute orderings before taking
            the first ``num_permutations`` of them.
        num_permutations: How many prompt orderings to emit per primary row.
        aux_data_ratio: Only applies when the auxiliary dataset is larger than
            the primary one. In that case keep the first
            ``int(len(primary) * aux_data_ratio)`` auxiliary rows, capped at the
            auxiliary size. A ratio of exactly 1 interleaves the two datasets;
            any other value concatenates and shuffles them.
        local_folder: Base directory that ``_read_image`` joins with file names.
    """

    def __init__(self,
                 data_path,
                 aux_data_path,
                 image_size,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=True,
                 shuffle_permutations=False,
                 num_permutations=1,
                 aux_data_ratio=1.0,
                 local_folder='',
                 ):
        super().__init__()
        self.data_path = data_path
        self.aux_data_path = aux_data_path
        # Read by `_read_image`; it must exist or that method raises AttributeError.
        self.local_folder = local_folder

        self.unconditional = unconditional

        self.image_size = image_size

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image

        def map_1(batch):
            # Run with batch_size=1 (see the map call below), so every column
            # holds exactly one value.
            permutations_list = list(permutations((batch["synthetic_color"][0], batch["synthetic_pattern"][0], batch["synthetic_position"][0], batch["synthetic_shape"][0])))

            prompts = [" ".join(permutation) for permutation in permutations_list]

            if shuffle_permutations:
                random.shuffle(prompts)

            prompts = prompts[:num_permutations]

            # Repeat the single source row once per prompt so that every output
            # column has the same length.
            return {
                "prompt": prompts,
                "image": batch["image"] * len(prompts),
                "color": batch["color"] * len(prompts),
                "pattern": batch["pattern"] * len(prompts),
                "position": batch["position"] * len(prompts),
                "shape": batch["shape"] * len(prompts),
            }

        def map_2(data_sample):
            pixel_values = self._process_image(data_sample["image"])["pixel_values"]
            input_ids = self._process_text(data_sample["prompt"])["input_ids"]
            return {
                "pixel_values": pixel_values,
                "input_ids": input_ids,
                "type": "text2image",
            }

        self.data_list = load_from_disk(data_path)
        self.data_list = self.data_list.map(map_1, batched=True, batch_size=1, remove_columns=["synthetic_color", "synthetic_pattern", "synthetic_position", "synthetic_shape"])

        self.aux_data_list = load_from_disk(aux_data_path)
        self.aux_data_list = self.aux_data_list.remove_columns([col for col in self.aux_data_list.column_names if col not in ["image", "prompt"]])

        num_aux = self._aux_subset_size(len(self.data_list), len(self.aux_data_list), aux_data_ratio)
        if num_aux is not None:
            self.aux_data_list = self.aux_data_list.select(range(num_aux))

        if aux_data_ratio == 1:
            self.data_list = self.data_list.shuffle(seed=123456789)
            self.aux_data_list = self.aux_data_list.shuffle(seed=123456789)
            # NOTE(review): interleave_datasets is called with its default
            # stopping strategy. Per the `datasets` docs this stops at the first
            # exhausted dataset, so surplus rows of the longer one are dropped.
            # Confirm this is intended.
            self.data_list = interleave_datasets([self.data_list, self.aux_data_list])
        else:
            self.data_list = concatenate_datasets([self.data_list, self.aux_data_list])
            self.data_list = self.data_list.shuffle(seed=123456789)

        # Eagerly preprocess every sample (kept in memory, cache disabled).
        # NOTE(review): `_process_text` draws its unconditional-dropout coin here,
        # once per sample. The same samples therefore stay caption-free for the
        # lifetime of this dataset object instead of being resampled per access.
        self.data_list = self.data_list.map(map_2, batched=False, keep_in_memory=True, load_from_cache_file=False)
        self.data_list.set_format(type='torch', columns=["pixel_values", "input_ids"], output_all_columns=True)

    @staticmethod
    def _aux_subset_size(num_data, num_aux, aux_data_ratio):
        """Return how many auxiliary rows to keep, or ``None`` to keep them all.

        The auxiliary set is only truncated when it is larger than the primary
        set. The requested size ``int(num_data * aux_data_ratio)`` is capped at
        ``num_aux``, so a ratio above 1 cannot make ``select`` index past the end.
        """
        if num_data >= num_aux:
            return None
        return min(int(num_data * aux_data_ratio), num_aux)

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        """Open ``image_file`` relative to ``self.local_folder`` and sanity-check its size.

        Raises:
            AssertionError: if either side is <= 8 pixels or the aspect ratio
                is outside the open interval (0.1, 10).
        """
        image = Image.open(os.path.join(self.local_folder, image_file))
        assert image.width > 8 and image.height > 8, f"Image: {image.size}"
        assert image.width / image.height > 0.1, f"Image: {image.size}"
        assert image.width / image.height < 10, f"Image: {image.size}"
        return image

    def _process_text(self, text):
        """Build the instruction prompt for ``text`` and tokenize it.

        With probability ``self.unconditional`` the caption is dropped and the
        bare ``"Generate an image."`` instruction is used instead.

        Returns:
            dict with ``input_ids``: the first row of the tokenizer's ``'pt'``
            output, truncated to ``self.max_length`` tokens.
        """
        if random.uniform(0, 1) < self.unconditional:
            prompt = "Generate an image."
        else:
            prompt = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]

        return dict(input_ids=input_ids[:self.max_length])

    def _process_image(self, image):
        """Convert an image (assumed PIL) into a ``(3, image_size, image_size)`` tensor.

        Returns:
            dict with ``pixel_values``: a float32, channels-first tensor
            scaled to ``[-1, 1]``.
        """
        # Force three channels. Grayscale/palette images would otherwise give a
        # 2-D array (breaking the rearrange below), and RGBA a 4-channel one.
        image = image.convert('RGB')

        if self.crop_image:
            image = crop2square(image)
        else:
            target_size = max(image.size)
            image = image.resize(size=(target_size, target_size))

        image = image.resize(size=(self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float()
        pixel_values = pixel_values / 255
        pixel_values = 2 * pixel_values - 1  # [0, 1] -> [-1, 1]
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')

        return dict(pixel_values=pixel_values)

    def _retry(self):
        """Return a uniformly random sample (e.g. as a fallback for a bad index)."""
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        """Return row ``idx``.

        ``pixel_values`` and ``input_ids`` come back as torch tensors. All other
        columns are passed through unchanged (``output_all_columns=True``).
        """
        return self.data_list[idx]