import re
import torch
import random
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer
import openai


# Module-level OpenAI client configuration, applied as a side effect of
# importing this file. No active code in this file calls the client.
# NOTE(review): the key is blank here; supply real credentials from the
# environment at runtime rather than committing them to the source.
openai.api_key = ""
# Requests are directed to this base URL instead of the library default.
openai.api_base = "https://api.openai-proxy.com/v1"


class StoryDataset(Dataset):
    """Single-frame dataset of (image, mask, caption) samples.

    Two sources under ``root`` are indexed and concatenated, main source
    first:

    * main source: ``image_inpainted_finally_checked/``, ``mask/`` and
      ``text_caption/``; files inside each sub-folder are ordered
      lexicographically;
    * story-book source: ``image_inpainted/``, ``mask/`` and
      ``text_caption/`` inside ``StoryBook_finally_checked/``; files inside
      each sub-folder are ordered with a natural (number-aware) sort.

    Sub-folder names are read from the image tree only and reused for the
    mask and caption trees. Images, masks and captions are then matched
    purely by position, so the three trees must hold the same number of
    files in the same order.

    Each source is split on its own: ``dataset_name='train'`` keeps the first
    90% of its files, ``'test'`` keeps the remainder, and any other value
    keeps everything.

    Args:
        root: Dataset root directory.
        dataset_name: ``'train'``, ``'test'``, or any other value for no
            split.
        tokenizer: Tokenizer object that is callable with
            ``truncation``/``padding``/``max_length``/``return_tensors``
            keyword arguments, exposes ``model_max_length`` and returns an
            object with an ``input_ids`` attribute.
    """

    _TRAIN_FRACTION = 0.9
    _IMAGE_SIZE = (512, 512)

    def __init__(self, root, dataset_name, tokenizer):
        self.root = root
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer

        # ---- main source -------------------------------------------------
        self.image_dir = os.path.join(self.root, 'image_inpainted_finally_checked')
        self.mask_dir = os.path.join(self.root, 'mask')
        self.text_dir = os.path.join(self.root, 'text_caption')
        folders = sorted(os.listdir(self.image_dir))
        self.image_folders = [os.path.join(self.image_dir, folder) for folder in folders]
        self.mask_folders = [os.path.join(self.mask_dir, folder) for folder in folders]
        self.text_folders = [os.path.join(self.text_dir, folder) for folder in folders]

        self.image_list = self._collect_files(self.image_folders)
        self.mask_list = self._collect_files(self.mask_folders)
        self.text_list = self._collect_files(self.text_folders)

        # The split point is derived from the image list and applied to all
        # three lists so that they stay index-aligned.
        cnt = int(len(self.image_list) * self._TRAIN_FRACTION)
        self.image_list = self._select_split(self.image_list, cnt)
        self.mask_list = self._select_split(self.mask_list, cnt)
        self.text_list = self._select_split(self.text_list, cnt)

        # ---- story-book source ---------------------------------------------
        book_root = os.path.join(self.root, 'StoryBook_finally_checked')
        self.pdf_image_dir = os.path.join(book_root, 'image_inpainted')
        self.pdf_mask_dir = os.path.join(book_root, 'mask')
        self.pdf_text_dir = os.path.join(book_root, 'text_caption')
        pdf_folders = sorted(os.listdir(self.pdf_image_dir))
        self.pdf_image_folders = [os.path.join(self.pdf_image_dir, folder) for folder in pdf_folders]
        self.pdf_mask_folders = [os.path.join(self.pdf_mask_dir, folder) for folder in pdf_folders]
        self.pdf_text_folders = [os.path.join(self.pdf_text_dir, folder) for folder in pdf_folders]

        self.pdf_image_list = self._collect_files(self.pdf_image_folders, self._natural_key)
        self.pdf_mask_list = self._collect_files(self.pdf_mask_folders, self._natural_key)
        self.pdf_text_list = self._collect_files(self.pdf_text_folders, self._natural_key)

        pdf_cnt = int(len(self.pdf_image_list) * self._TRAIN_FRACTION)
        self.pdf_image_list = self._select_split(self.pdf_image_list, pdf_cnt)
        self.pdf_mask_list = self._select_split(self.pdf_mask_list, pdf_cnt)
        self.pdf_text_list = self._select_split(self.pdf_text_list, pdf_cnt)

        # ---- final index -------------------------------------------------
        self.image_list = self.image_list + self.pdf_image_list
        self.mask_list = self.mask_list + self.pdf_mask_list
        self.text_list = self.text_list + self.pdf_text_list

    @staticmethod
    def _natural_key(name):
        """Return a sort key that compares digit runs in ``name`` as integers.

        With this key ``'2.png'`` sorts before ``'10.png'``.
        """
        key = ()
        # The name is wrapped as 'a<name>0' so that the pattern always finds
        # a leading non-digit run and a trailing digit run.
        for text, number in re.findall(r'(\D+)(\d+)', 'a%s0' % name):
            key += (text, int(number))
        return key

    @staticmethod
    def _collect_files(folders, sort_key=None):
        """Return the paths of all entries in ``folders``, folder by folder.

        Entries of each folder are sorted with ``sort_key`` (plain
        lexicographic order when it is ``None``).
        """
        paths = []
        for folder in folders:
            for name in sorted(os.listdir(folder), key=sort_key):
                paths.append(os.path.join(folder, name))
        return paths

    def _select_split(self, items, cnt):
        """Return the part of ``items`` that belongs to ``self.dataset_name``."""
        if self.dataset_name == 'train':
            return items[:cnt]
        if self.dataset_name == 'test':
            return items[cnt:]
        return items

    @staticmethod
    def _load_rgb_tensor(path, size):
        """Load ``path`` as RGB, resize it to ``size`` and return a float tensor.

        The result is what ``transforms.ToTensor`` produces for the resized
        image, cast to ``float32``.
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by ``convert``.
        with Image.open(path) as handle:
            picture = handle.convert('RGB')
        picture = picture.resize(size)
        return transforms.ToTensor()(picture).float()

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to ``model_max_length``."""
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        return encoded.input_ids.squeeze(0)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            dict with keys ``image_path``, ``image`` (rescaled with
            ``x * 2 - 1``), ``mask_path``, ``mask`` (left as produced by
            ``ToTensor``), ``prompt_ids`` and ``prompt`` (raw caption text).
        """
        image_path = self.image_list[index]
        mask_path = self.mask_list[index]
        text_path = self.text_list[index]

        image = self._load_rgb_tensor(image_path, self._IMAGE_SIZE)
        mask = self._load_rgb_tensor(mask_path, self._IMAGE_SIZE)

        with open(text_path, "r") as f:
            prompt = f.read()

        # normalize the image only; the mask keeps its ToTensor range
        image = image * 2. - 1.

        prompt_ids = self._tokenize(prompt)
        return {"image_path": image_path, "image": image, "mask_path": mask_path, "mask": mask, "prompt_ids": prompt_ids, "prompt": prompt}


class StorySeqDataset(Dataset):
    """Dataset of consecutive frame pairs (reference frame -> target frame).

    Three sources under ``root`` are indexed and concatenated in this order:

    * main source: ``image_inpainted_finally_checked/``, ``mask/`` and
      ``text_caption/``;
    * story-book source: ``image_inpainted/``, ``mask/`` and
      ``text_caption/`` inside ``StoryBook_finally_checked/``;
    * sample source: the first 100 sub-folders of ``Samples/image`` and
      ``Samples/text``, followed by one ``prev``/``current`` pair for every
      entry of ``NewSamples/``. This source has no masks of its own; every
      pair uses the fixed placeholder mask file.

    Inside every sub-folder the files are ordered with a natural
    (number-aware) sort and each file is paired with its successor.
    Sub-folders with fewer than two files contribute nothing. Sub-folder
    names are read from the image tree only and reused for the other trees;
    images, masks and captions are matched purely by position.

    The main and story-book sources are each split 90/10 (``'train'`` keeps
    the first 90% of the pairs, ``'test'`` the rest). For the sample source
    ``'train'`` keeps all pairs but the last one and ``'test'`` keeps only
    the last one. Any other ``dataset_name`` keeps everything.

    Args:
        root: Dataset root directory.
        dataset_name: ``'train'``, ``'test'``, or any other value for no
            split.
        tokenizer: Tokenizer object that is callable with
            ``truncation``/``padding``/``max_length``/``return_tensors``
            keyword arguments, exposes ``model_max_length`` and returns an
            object with an ``input_ids`` attribute.
    """

    # Mask used for every pair of the sample source. The path is relative to
    # the current working directory, not to ``root``.
    _PLACEHOLDER_MASK = "./Data/mask/000000/000000_mask_0-3-57-36.jpg"
    _TRAIN_FRACTION = 0.9
    _MAX_SAMPLE_FOLDERS = 100
    _REF_SIZE = (224, 224)
    _IMAGE_SIZE = (512, 512)

    def __init__(self, root, dataset_name, tokenizer):
        self.root = root
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer

        # ---- main source -------------------------------------------------
        self.image_dir = os.path.join(self.root, 'image_inpainted_finally_checked')
        self.mask_dir = os.path.join(self.root, 'mask')
        self.text_dir = os.path.join(self.root, 'text_caption')
        folders = sorted(os.listdir(self.image_dir))
        self.image_folders = [os.path.join(self.image_dir, folder) for folder in folders]
        self.mask_folders = [os.path.join(self.mask_dir, folder) for folder in folders]
        self.text_folders = [os.path.join(self.text_dir, folder) for folder in folders]

        self.image_list = self._consecutive_pairs(self.image_folders, report_skipped=True)
        self.mask_list = self._consecutive_pairs(self.mask_folders)
        self.text_list = self._consecutive_pairs(self.text_folders)

        # The split point is derived from the image pairs and applied to all
        # three lists so that they stay index-aligned.
        cnt = int(len(self.image_list) * self._TRAIN_FRACTION)
        self.image_list = self._select_split(self.image_list, cnt)
        self.mask_list = self._select_split(self.mask_list, cnt)
        self.text_list = self._select_split(self.text_list, cnt)

        # ---- story-book source ---------------------------------------------
        book_root = os.path.join(self.root, 'StoryBook_finally_checked')
        self.pdf_image_dir = os.path.join(book_root, 'image_inpainted')
        self.pdf_mask_dir = os.path.join(book_root, 'mask')
        self.pdf_text_dir = os.path.join(book_root, 'text_caption')
        pdf_folders = sorted(os.listdir(self.pdf_image_dir))
        self.pdf_image_folders = [os.path.join(self.pdf_image_dir, folder) for folder in pdf_folders]
        self.pdf_mask_folders = [os.path.join(self.pdf_mask_dir, folder) for folder in pdf_folders]
        self.pdf_text_folders = [os.path.join(self.pdf_text_dir, folder) for folder in pdf_folders]

        self.pdf_image_list = self._consecutive_pairs(self.pdf_image_folders, report_skipped=True)
        self.pdf_mask_list = self._consecutive_pairs(self.pdf_mask_folders)
        self.pdf_text_list = self._consecutive_pairs(self.pdf_text_folders)

        pdf_cnt = int(len(self.pdf_image_list) * self._TRAIN_FRACTION)
        self.pdf_image_list = self._select_split(self.pdf_image_list, pdf_cnt)
        self.pdf_mask_list = self._select_split(self.pdf_mask_list, pdf_cnt)
        self.pdf_text_list = self._select_split(self.pdf_text_list, pdf_cnt)

        # ---- sample source: Samples/ ---------------------------------------
        self.hf_image_dir = os.path.join(self.root, 'Samples', 'image')
        self.hf_text_dir = os.path.join(self.root, 'Samples', 'text')
        hf_folders = sorted(os.listdir(self.hf_image_dir))
        self.hf_image_folders = [os.path.join(self.hf_image_dir, folder) for folder in hf_folders]
        self.hf_text_folders = [os.path.join(self.hf_text_dir, folder) for folder in hf_folders]
        self.hf_image_folders = self.hf_image_folders[:self._MAX_SAMPLE_FOLDERS]
        self.hf_text_folders = self.hf_text_folders[:self._MAX_SAMPLE_FOLDERS]

        self.hf_image_list = self._consecutive_pairs(self.hf_image_folders, report_skipped=True)
        # One placeholder mask pair per image pair; a fresh list per entry so
        # that entries never alias each other.
        self.hf_mask_list = [
            [self._PLACEHOLDER_MASK, self._PLACEHOLDER_MASK] for _ in self.hf_image_list
        ]
        self.hf_text_list = self._consecutive_pairs(self.hf_text_folders)

        # ---- sample source: NewSamples/ ------------------------------------
        self.hf_dir = os.path.join(self.root, 'NewSamples')
        self.hf_folders = sorted(os.listdir(self.hf_dir), key=self._natural_key)
        for video in self.hf_folders:
            sample_dir = os.path.join(self.hf_dir, video)
            self.hf_image_list.append([os.path.join(sample_dir, "prev.png"), os.path.join(sample_dir, "current.png")])
            self.hf_mask_list.append([self._PLACEHOLDER_MASK, self._PLACEHOLDER_MASK])
            self.hf_text_list.append([os.path.join(sample_dir, "prev.txt"), os.path.join(sample_dir, "current.txt")])

        # Only the very last sample pair is held out for 'test'.
        hf_cnt = len(self.hf_image_list) - 1
        self.hf_image_list = self._select_split(self.hf_image_list, hf_cnt)
        self.hf_mask_list = self._select_split(self.hf_mask_list, hf_cnt)
        self.hf_text_list = self._select_split(self.hf_text_list, hf_cnt)

        # ---- final index -------------------------------------------------
        self.image_list = self.image_list + self.pdf_image_list + self.hf_image_list
        self.mask_list = self.mask_list + self.pdf_mask_list + self.hf_mask_list
        self.text_list = self.text_list + self.pdf_text_list + self.hf_text_list

    @staticmethod
    def _natural_key(name):
        """Return a sort key that compares digit runs in ``name`` as integers.

        With this key ``'2.png'`` sorts before ``'10.png'``.
        """
        key = ()
        # The name is wrapped as 'a<name>0' so that the pattern always finds
        # a leading non-digit run and a trailing digit run.
        for text, number in re.findall(r'(\D+)(\d+)', 'a%s0' % name):
            key += (text, int(number))
        return key

    @classmethod
    def _consecutive_pairs(cls, folders, report_skipped=False):
        """Return ``[path_i, path_i+1]`` for every adjacent file pair.

        Files are naturally sorted per folder. Folders holding fewer than two
        entries are skipped; their path is printed when ``report_skipped`` is
        true.
        """
        pairs = []
        for folder in folders:
            names = sorted(os.listdir(folder), key=cls._natural_key)
            if len(names) <= 1:
                if report_skipped:
                    print(folder)
                continue
            paths = [os.path.join(folder, name) for name in names]
            for first, second in zip(paths, paths[1:]):
                pairs.append([first, second])
        return pairs

    def _select_split(self, items, cnt):
        """Return the part of ``items`` that belongs to ``self.dataset_name``."""
        if self.dataset_name == 'train':
            return items[:cnt]
        if self.dataset_name == 'test':
            return items[cnt:]
        return items

    @staticmethod
    def _load_rgb_tensor(path, size):
        """Load ``path`` as RGB, resize it to ``size`` and return a float tensor.

        The result is what ``transforms.ToTensor`` produces for the resized
        image, cast to ``float32``.
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by ``convert``.
        with Image.open(path) as handle:
            picture = handle.convert('RGB')
        picture = picture.resize(size)
        return transforms.ToTensor()(picture).float()

    @staticmethod
    def _random_word_dropout(prompt):
        """Randomly delete words from ``prompt`` as caption augmentation.

        With probability 0.3 the prompt is split on whitespace and a fraction
        of its words, drawn uniformly from [0.2, 0.3), is removed at random
        positions; at least one word always survives. Otherwise the prompt is
        returned untouched.
        """
        if random.uniform(0, 1) >= 0.3:
            return prompt
        words = prompt.split()
        num_words = len(words)
        p_drop = random.uniform(0, 1) * 0.1 + 0.2
        # Never remove every word; an empty prompt removes nothing.
        num_words_to_remove = min(int(num_words * p_drop), max(num_words - 1, 0))
        dropped = set(random.sample(range(num_words), num_words_to_remove))
        return " ".join(word for i, word in enumerate(words) if i not in dropped)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to ``model_max_length``."""
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        return encoded.input_ids.squeeze(0)

    def __len__(self):
        """Return the number of indexed frame pairs."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Load the frame pair at ``index``.

        The first element of the pair is the reference (resized to 224x224),
        the second is the target (resized to 512x512). Only the target's mask
        is loaded. The target prompt may be shortened by random word dropout;
        the reference prompt is never altered.

        Returns:
            dict with keys ``ref_image`` and ``image`` (both rescaled with
            ``x * 2 - 1``), ``mask`` (left as produced by ``ToTensor``),
            ``ref_prompt_ids``, ``prompt_ids``, ``ref_prompt`` and ``prompt``.
        """
        ref_image_path, image_path = self.image_list[index][0], self.image_list[index][1]
        mask_path = self.mask_list[index][1]
        ref_text_path, text_path = self.text_list[index][0], self.text_list[index][1]

        ref_image = self._load_rgb_tensor(ref_image_path, self._REF_SIZE)
        image = self._load_rgb_tensor(image_path, self._IMAGE_SIZE)
        mask = self._load_rgb_tensor(mask_path, self._IMAGE_SIZE)

        with open(ref_text_path, "r") as f:
            ref_prompt = f.read()
        with open(text_path, "r") as f:
            prompt = f.read()

        prompt = self._random_word_dropout(prompt)

        # normalize the images only; the mask keeps its ToTensor range
        ref_image = ref_image * 2. - 1.
        image = image * 2. - 1.

        ref_prompt_ids = self._tokenize(ref_prompt)
        prompt_ids = self._tokenize(prompt)

        return {"ref_image": ref_image, "image": image, "mask": mask, "ref_prompt_ids": ref_prompt_ids, "prompt_ids": prompt_ids, "ref_prompt": ref_prompt, "prompt": prompt}

class Story4FrameDataset(Dataset):
    """Dataset of sliding 4-frame windows (three references -> one target).

    Two sources under ``root`` are indexed and concatenated, main source
    first:

    * main source: ``image_inpainted_finally_checked/``, ``mask/`` and
      ``text_caption/``;
    * story-book source: ``image_inpainted/``, ``mask/`` and
      ``text_caption/`` inside ``StoryBook_finally_checked/``.

    Inside every sub-folder the files are ordered with a natural
    (number-aware) sort and every run of four consecutive files forms one
    window. Sub-folders with fewer than four files contribute nothing.
    Sub-folder names are read from the image tree only and reused for the
    other trees; images, masks and captions are matched purely by position.

    Each source is split on its own: ``dataset_name='train'`` keeps the first
    90% of its windows, ``'test'`` keeps the remainder, and any other value
    keeps everything.

    Args:
        root: Dataset root directory.
        dataset_name: ``'train'``, ``'test'``, or any other value for no
            split.
        tokenizer: Tokenizer object that is callable with
            ``truncation``/``padding``/``max_length``/``return_tensors``
            keyword arguments, exposes ``model_max_length`` and returns an
            object with an ``input_ids`` attribute.
    """

    _WINDOW = 4
    _TRAIN_FRACTION = 0.9
    _REF_SIZE = (224, 224)
    _IMAGE_SIZE = (512, 512)

    def __init__(self, root, dataset_name, tokenizer):
        self.root = root
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer

        # ---- main source -------------------------------------------------
        self.image_dir = os.path.join(self.root, 'image_inpainted_finally_checked')
        self.mask_dir = os.path.join(self.root, 'mask')
        self.text_dir = os.path.join(self.root, 'text_caption')
        folders = sorted(os.listdir(self.image_dir))
        self.image_folders = [os.path.join(self.image_dir, folder) for folder in folders]
        self.mask_folders = [os.path.join(self.mask_dir, folder) for folder in folders]
        self.text_folders = [os.path.join(self.text_dir, folder) for folder in folders]

        self.image_list = self._sliding_windows(self.image_folders, report_skipped=True)
        self.mask_list = self._sliding_windows(self.mask_folders)
        self.text_list = self._sliding_windows(self.text_folders)

        # The split point is derived from the image windows and applied to
        # all three lists so that they stay index-aligned.
        cnt = int(len(self.image_list) * self._TRAIN_FRACTION)
        self.image_list = self._select_split(self.image_list, cnt)
        self.mask_list = self._select_split(self.mask_list, cnt)
        self.text_list = self._select_split(self.text_list, cnt)

        # ---- story-book source ---------------------------------------------
        book_root = os.path.join(self.root, 'StoryBook_finally_checked')
        self.pdf_image_dir = os.path.join(book_root, 'image_inpainted')
        self.pdf_mask_dir = os.path.join(book_root, 'mask')
        self.pdf_text_dir = os.path.join(book_root, 'text_caption')
        pdf_folders = sorted(os.listdir(self.pdf_image_dir))
        self.pdf_image_folders = [os.path.join(self.pdf_image_dir, folder) for folder in pdf_folders]
        self.pdf_mask_folders = [os.path.join(self.pdf_mask_dir, folder) for folder in pdf_folders]
        self.pdf_text_folders = [os.path.join(self.pdf_text_dir, folder) for folder in pdf_folders]

        self.pdf_image_list = self._sliding_windows(self.pdf_image_folders, report_skipped=True)
        self.pdf_mask_list = self._sliding_windows(self.pdf_mask_folders)
        self.pdf_text_list = self._sliding_windows(self.pdf_text_folders)

        pdf_cnt = int(len(self.pdf_image_list) * self._TRAIN_FRACTION)
        self.pdf_image_list = self._select_split(self.pdf_image_list, pdf_cnt)
        self.pdf_mask_list = self._select_split(self.pdf_mask_list, pdf_cnt)
        self.pdf_text_list = self._select_split(self.pdf_text_list, pdf_cnt)

        # ---- final index -------------------------------------------------
        self.image_list = self.image_list + self.pdf_image_list
        self.mask_list = self.mask_list + self.pdf_mask_list
        self.text_list = self.text_list + self.pdf_text_list

    @staticmethod
    def _natural_key(name):
        """Return a sort key that compares digit runs in ``name`` as integers.

        With this key ``'2.png'`` sorts before ``'10.png'``.
        """
        key = ()
        # The name is wrapped as 'a<name>0' so that the pattern always finds
        # a leading non-digit run and a trailing digit run.
        for text, number in re.findall(r'(\D+)(\d+)', 'a%s0' % name):
            key += (text, int(number))
        return key

    @classmethod
    def _sliding_windows(cls, folders, report_skipped=False):
        """Return every run of ``_WINDOW`` consecutive file paths per folder.

        Files are naturally sorted per folder. Folders holding fewer than
        ``_WINDOW`` entries are skipped; their path is printed when
        ``report_skipped`` is true.
        """
        windows = []
        for folder in folders:
            names = sorted(os.listdir(folder), key=cls._natural_key)
            if len(names) < cls._WINDOW:
                if report_skipped:
                    print(folder)
                continue
            paths = [os.path.join(folder, name) for name in names]
            for start in range(len(paths) - cls._WINDOW + 1):
                windows.append(paths[start:start + cls._WINDOW])
        return windows

    def _select_split(self, items, cnt):
        """Return the part of ``items`` that belongs to ``self.dataset_name``."""
        if self.dataset_name == 'train':
            return items[:cnt]
        if self.dataset_name == 'test':
            return items[cnt:]
        return items

    @staticmethod
    def _load_rgb_tensor(path, size):
        """Load ``path`` as RGB, resize it to ``size`` and return a float tensor.

        The result is what ``transforms.ToTensor`` produces for the resized
        image, cast to ``float32``.
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by ``convert``.
        with Image.open(path) as handle:
            picture = handle.convert('RGB')
        picture = picture.resize(size)
        return transforms.ToTensor()(picture).float()

    @staticmethod
    def _to_signed_range(tensor):
        """Rescale ``tensor`` with ``x * 2 - 1`` (maps [0, 1] onto [-1, 1])."""
        return tensor * 2. - 1.

    @staticmethod
    def _random_word_dropout(prompt):
        """Randomly delete words from ``prompt`` as caption augmentation.

        With probability 0.3 the prompt is split on whitespace and a fraction
        of its words, drawn uniformly from [0.2, 0.3), is removed at random
        positions; at least one word always survives. Otherwise the prompt is
        returned untouched.
        """
        if random.uniform(0, 1) >= 0.3:
            return prompt
        words = prompt.split()
        num_words = len(words)
        p_drop = random.uniform(0, 1) * 0.1 + 0.2
        # Never remove every word; an empty prompt removes nothing.
        num_words_to_remove = min(int(num_words * p_drop), max(num_words - 1, 0))
        dropped = set(random.sample(range(num_words), num_words_to_remove))
        return " ".join(word for i, word in enumerate(words) if i not in dropped)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to ``model_max_length``."""
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        return encoded.input_ids.squeeze(0)

    def __len__(self):
        """Return the number of indexed 4-frame windows."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Load the 4-frame window at ``index``.

        The first three frames are references (each resized to 224x224 and
        stacked along a new leading dimension); the fourth is the target
        (resized to 512x512). Only the target's mask is loaded. The target
        prompt may be shortened by random word dropout; reference prompts are
        never altered.

        Returns:
            dict with keys ``ref_image`` (stack of the three references) and
            ``image`` (both rescaled with ``x * 2 - 1``), ``mask`` (left as
            produced by ``ToTensor``), ``ref_prompt_ids`` (stack of the three
            reference token id tensors), ``prompt_ids``, ``ref_prompt`` (list
            of three strings) and ``prompt``.
        """
        frame_paths = self.image_list[index]
        text_paths = self.text_list[index]
        mask_path = self.mask_list[index][-1]

        ref_images = torch.stack(
            [self._load_rgb_tensor(path, self._REF_SIZE) for path in frame_paths[:-1]]
        )
        image = self._load_rgb_tensor(frame_paths[-1], self._IMAGE_SIZE)
        mask = self._load_rgb_tensor(mask_path, self._IMAGE_SIZE)

        ref_prompts = []
        for ref_text in text_paths[:-1]:
            with open(ref_text, "r") as f:
                ref_prompts.append(f.read())
        with open(text_paths[-1], "r") as f:
            prompt = f.read()

        prompt = self._random_word_dropout(prompt)

        # normalize. The whole reference stack is rescaled in one expression:
        # rebinding a loop variable over its slices would leave the stack
        # itself untouched.
        ref_images = self._to_signed_range(ref_images)
        image = self._to_signed_range(image)

        ref_prompt_ids = torch.stack([self._tokenize(ref_prompt) for ref_prompt in ref_prompts])
        prompt_ids = self._tokenize(prompt)

        return {"ref_image": ref_images, "image": image, "mask": mask, "ref_prompt_ids": ref_prompt_ids, "prompt_ids": prompt_ids, "ref_prompt": ref_prompts, "prompt": prompt}



if __name__ == '__main__':
    # Manual smoke test: build the training split of the pair dataset, report
    # its size and wrap it in a DataLoader.
    model_dir = "./ckpt/stable-diffusion-v1-5/"
    tok = AutoTokenizer.from_pretrained(model_dir, subfolder="tokenizer", use_fast=False)

    dataset = StorySeqDataset(root="./Dataset/", dataset_name='train', tokenizer=tok)

    print(len(dataset))

    loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
    # Batches are laid out as B C H W. To inspect the first batch, uncomment:
    # for step, batch in enumerate(loader):
    #     print(step)
    #     print(batch["prompt"])
    #     print(batch["ref_prompt"])
    #     print(batch["ref_prompt_ids"].shape)
    #     print(batch["prompt_ids"].shape)
    #     print(batch["ref_image"].shape)
    #     break
