import argparse
import contextlib
import gc
import logging
import math
import os
import random
import shutil
import ast

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import json
import pickle
from transformers import AutoTokenizer, PretrainedConfig
import pandas as pd

class RealV40CustomDataset(torch.utils.data.Dataset):
    """Dataset of images with a conditioning image and two tokenized captions.

    A sample is only kept when its file name appears in the image directory,
    in the conditioning directory and in both caption JSON files. Each item
    is a dict with the normalized image, the conditioning image and the
    token ids of the "global" and "shape" captions.
    """

    def __init__(self,
                 tokenizer,
                 image_root_path,
                 condition_root,
                 txt_json_root,
                 resolution=512,
                 drop_txt_prob=0.02,
                 drop_all_prob=0.02,
                 ):
        """Index the usable samples and build the image transforms.

        Args:
            tokenizer: Callable tokenizer exposing ``model_max_length`` and
                returning an object with ``input_ids``.
            image_root_path: Directory containing the training images.
            condition_root: Directory containing the conditioning images.
            txt_json_root: Mapping with the keys ``'global'`` and ``'shape'``,
                each a path to a JSON file mapping image path -> caption.
            resolution: Side length images are resized/center-cropped to.
            drop_txt_prob: Probability of blanking only the shape caption.
            drop_all_prob: Probability of blanking the shape caption AND
                zeroing the conditioning image.
        """
        self.local_tasks = 'segmentation'
        self.tokenizer = tokenizer
        self.resolution = resolution

        # Normalize([0.5], [0.5]) maps the [0, 1] output of ToTensor to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        # Conditioning images are resized the same way but left in [0, 1].
        self.conditioning_image_transforms = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
        ])

        self.img_path = image_root_path
        self.condition_root = condition_root

        self.txt_dict_global = self._load_caption_dict(txt_json_root['global'])
        self.txt_dict_shape = self._load_caption_dict(txt_json_root['shape'])

        img_set = set(os.listdir(self.img_path))
        img_set.intersection_update(os.listdir(self.condition_root))
        img_set.intersection_update(self.txt_dict_global)
        img_set.intersection_update(self.txt_dict_shape)

        # `error_file` is expected to be a module-level iterable of file names
        # to exclude. Fall back to "nothing excluded" when the module does not
        # define it instead of failing with a NameError.
        img_set.difference_update(globals().get('error_file', ()))

        # Sort so that index -> sample is identical in every process; the
        # iteration order of a set of strings is not stable across processes.
        self.img_name_list = sorted(img_set)

        self.drop_txt_prob = drop_txt_prob
        self.drop_all_prob = drop_all_prob

    @staticmethod
    def _load_caption_dict(json_path):
        """Load a JSON caption mapping and re-key it by file basename."""
        with open(json_path, 'r') as handle:
            raw = json.load(handle)
        return {os.path.basename(key): value for key, value in raw.items()}

    def _condition_path(self, img_name):
        """Return the path of the conditioning image that belongs to ``img_name``.

        The ``.png`` variant of a ``.jpg`` name is preferred; when it does not
        exist the identical file name is used, which ``__init__`` already
        verified to be present in ``condition_root``.
        """
        png_name = img_name.replace('.jpg', '.png')
        candidate = os.path.join(self.condition_root, png_name)
        if png_name != img_name and not os.path.exists(candidate):
            return os.path.join(self.condition_root, img_name)
        return candidate

    def __getitem__(self, index):
        """Load sample ``index`` and apply the random caption/condition drop.

        Returns:
            dict with ``pixel_values``, ``img_name``,
            ``conditioning_pixel_values`` (a one-element list),
            ``input_ids_global`` and ``input_ids_shape``.
        """
        img_name = self.img_name_list[index]

        img = Image.open(os.path.join(self.img_path, img_name)).convert("RGB")
        img = self.transform(img)

        condition_image = Image.open(self._condition_path(img_name)).convert("RGB")
        condition_image = self.conditioning_image_transforms(condition_image)

        text_global = self.txt_dict_global[img_name]
        text_shape = self.txt_dict_shape[img_name]

        # One draw decides between "drop everything local", "drop only the
        # shape caption" and "keep as is". The global caption is never dropped.
        rand_num = random.random()
        if rand_num < self.drop_all_prob:
            text_shape = ''
            condition_image = torch.zeros_like(condition_image)
        elif rand_num < self.drop_all_prob + self.drop_txt_prob:
            text_shape = ''

        return {
            "pixel_values": img,
            "img_name": img_name,
            "conditioning_pixel_values": [condition_image],
            "input_ids_global": self.tokenize_captions([text_global]),
            "input_ids_shape": self.tokenize_captions([text_shape]),
        }

    def __len__(self):
        """Number of usable samples."""
        return len(self.img_name_list)

    def tokenize_captions(self, text, is_train=True):
        """Tokenize a list of captions with ``self.tokenizer``.

        Args:
            text: Iterable whose items are either a string or a list/ndarray
                of alternative strings.
            is_train: When True a random alternative is picked, otherwise
                the first one.

        Returns:
            ``input_ids`` of the tokenizer output, padded/truncated to
            ``tokenizer.model_max_length``.

        Raises:
            ValueError: If an item is neither a string nor a list/ndarray.
        """
        captions = []
        for caption in text:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{text}` should contain either strings or lists of strings."
                )
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

class RealV40CustomDataset_v4(torch.utils.data.Dataset):
    """Dataset of images with a global caption and one fine-grained caption.

    Global captions come from a JSON file, fine-grained captions from a CSV
    whose ``caption`` column holds a Python-literal dict per image. For every
    item one non-empty fine-grained caption is picked at random.
    """

    def __init__(self,
                 tokenizer,
                 TP_tokenizer,
                 image_root_path,
                 txt_root,
                 resolution=512,
                 drop_txt_prob=0.02,
                 drop_all_prob=0.02,
                 ):
        """Index the usable samples and build the image transform.

        Args:
            tokenizer: Tokenizer for the global caption; must expose
                ``model_max_length`` and return an object with ``input_ids``.
            TP_tokenizer: Callable applied to the fine-grained caption list.
            image_root_path: Directory containing the training images.
            txt_root: Mapping with ``'global'`` (JSON path) and
                ``'fine_grained'`` (CSV path with the columns ``img_path``
                and ``caption``).
            resolution: Side length images are resized/center-cropped to.
            drop_txt_prob: Probability of blanking only the fine-grained caption.
            drop_all_prob: Probability of blanking both captions.
        """
        self.tokenizer = tokenizer
        self.TP_tokenizer = TP_tokenizer
        self.resolution = resolution

        # Normalize([0.5], [0.5]) maps the [0, 1] output of ToTensor to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.img_path = image_root_path

        with open(txt_root['global'], 'r') as handle:
            global_captions = json.load(handle)
        self.txt_dict_global = {
            os.path.basename(key): value for key, value in global_captions.items()
        }

        fine_grained = pd.read_csv(txt_root['fine_grained']).set_index('img_path')['caption'].to_dict()
        self.txt_dict_shape = {
            os.path.basename(key): ast.literal_eval(value) for key, value in fine_grained.items()
        }
        # The 'lianxing' entry is dropped for now (presumably "face shape" --
        # TODO confirm). A missing key is tolerated.
        for caption_dict in self.txt_dict_shape.values():
            caption_dict.pop('lianxing', None)

        img_set = set(os.listdir(image_root_path))
        img_set.intersection_update(self.txt_dict_global)
        img_set.intersection_update(self.txt_dict_shape)

        # `error_file` is expected to be a module-level iterable of file names
        # to exclude. Fall back to "nothing excluded" when the module does not
        # define it instead of failing with a NameError.
        img_set.difference_update(globals().get('error_file', ()))

        # Sort so that index -> sample is identical in every process; the
        # iteration order of a set of strings is not stable across processes.
        self.img_name_list = sorted(img_set)

        self.drop_txt_prob = drop_txt_prob
        self.drop_all_prob = drop_all_prob

    @staticmethod
    def _pick_shape_text(caption_dict):
        """Pick one non-empty caption (stripped of spaces/periods) at random.

        Returns an empty string when every caption is empty, instead of
        letting ``random.choice`` fail on an empty list.
        """
        candidates = []
        for value in caption_dict.values():
            cleaned = value.strip(" .")
            if cleaned != "":
                candidates.append(cleaned)
        if not candidates:
            return ''
        return random.choice(candidates)

    def __getitem__(self, index):
        """Load sample ``index`` and apply the random caption drop.

        Returns:
            dict with ``pixel_values``, ``img_name``, ``input_ids_global``
            and ``input_ids_shape`` (first element of the ``TP_tokenizer``
            output).
        """
        img_name = self.img_name_list[index]

        img = Image.open(os.path.join(self.img_path, img_name)).convert("RGB")
        img = self.transform(img)

        text_global = self.txt_dict_global[img_name]
        text_shape = self._pick_shape_text(self.txt_dict_shape[img_name])

        # One draw decides between "drop fine-grained caption only",
        # "drop both captions" and "keep as is".
        rand_num = random.random()
        if rand_num < self.drop_txt_prob:
            text_shape = ''
        elif rand_num < self.drop_txt_prob + self.drop_all_prob:
            text_shape, text_global = '', ''

        return {
            "pixel_values": img,
            "img_name": img_name,
            "input_ids_global": self.tokenize_captions([text_global]),
            "input_ids_shape": self.tokenize_finegrained_captions([text_shape])[0],
        }

    def __len__(self):
        """Number of usable samples."""
        return len(self.img_name_list)

    def tokenize_captions(self, text, is_train=True):
        """Tokenize a list of captions with ``self.tokenizer``.

        Args:
            text: Iterable whose items are either a string or a list/ndarray
                of alternative strings.
            is_train: When True a random alternative is picked, otherwise
                the first one.

        Returns:
            ``input_ids`` of the tokenizer output, padded/truncated to
            ``tokenizer.model_max_length``.

        Raises:
            ValueError: If an item is neither a string nor a list/ndarray.
        """
        captions = []
        for caption in text:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{text}` should contain either strings or lists of strings."
                )
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def tokenize_finegrained_captions(self, text):
        """Run ``self.TP_tokenizer`` on ``text`` and return its raw output."""
        return self.TP_tokenizer(text)

# Caption templates keyed by attribute name.
#   * list values are format strings; "{}" receives the attribute label.
#   * dict values map a label directly to a list of ready-made phrases.
# "fuzhi" is presumably pinyin for skin texture -- TODO confirm.
prompt_list = {
    "fuzhi": {
        "freckles": [
            "freckled skin",
            "a person with freckled skin",
            "freckles on the face",
            "obvious freckles on the face",
        ],
        "fair": [
            "smooth skin",
            "a person with smooth skin",
        ],
        "wrinkles": [
            "wrinkled skin",
            "serve wrinkles on the face",
            "a person with wrinkled face",
            "wrinkles on the face",
        ],
    },
    "hair color": [
        "The color of the hair is {}",
        "{} hair.",
        "the hair is in {}.",
        "A person with {} color hair.",
    ],
    "hair curliness": [
        "hair that is {}.",
        "hair which is {}.",
        "a {} hair is obvious.",
        "A person with {} hair.",
    ],
    "eyeglasses": [
        "a person with {}",
        "wearing {}",
        "the person is accessorized with {}.",
    ],
    "beards": [
        "a person with {}",
        "wearing {}",
        "the person is accessorized with {}.",
    ],
    "hat": [
        "a person with {}",
        "wearing {}",
        "the person is accessorized with {}.",
    ],
    "necklace": [
        "a person with {}",
        "wearing {}",
        "the person is accessorized with {}.",
    ],
    "makeup": {
        "makeup": [
            "a person with makeup",
            "wearing makeup",
            "the human has a makeup look",
        ],
        "no makeup": [
            "The bare face is the obvious.",
            "the person has a bare face",
            "a person with bare face.",
        ],
    },
}


# Shorter caption templates keyed by attribute name; the dataset classes in
# this file that build fine-grained prompts read this table.
#   * list values are format strings; "{}" receives the attribute label.
#   * dict values map a label directly to a list of ready-made phrases.
# "fuzhi" is presumably pinyin for skin texture -- TODO confirm.
prompt_list2 = {
    "fuzhi": {
        "freckles": ["freckled skin", "freckled face"],
        "fair": ["smooth skin", "smooth face"],
        "wrinkles": ["wrinkled skin", "wrinkles face"],
    },
    "hair color": ["{} hair.", "hair in {}."],
    "hair curliness": ["{} hair", "hair is {}."],
    "eyeglasses": ["a person with {}", "wearing {}"],
    "beards": ["a person with {}", "wearing {}"],
    "hat": ["a person with {}", "wearing {}"],
    "necklace": ["a person with {}", "wearing {}"],
    "makeup": {
        "makeup": ["wearing makeup"],
        "no makeup": ["natural face."],
        "red lips": ["red lips", "the lips is red"],
    },
    "style": ["{}", "{} image", "{} human image"],
    "expression": ["a person is {}", "with {} facial expression", "{}"],
    "scene object": ["{}", "a person {}"],
    "cloth": ["wearing {}", "{}", "a person with {}"],
}

class RealV40CustomDataset_v7(torch.utils.data.Dataset):
    """CSV-driven dataset yielding an image, a global prompt and a sampled
    fine-grained prompt.

    Each CSV row holds an image path and a Python-literal dict of captions.
    The dict must contain ``'global_prompt'`` plus at least one usable
    attribute; one or two attributes are sampled per item and rendered
    through the templates in ``prompt_list2``.
    """

    def __init__(self,
                 tokenizer,
                 TP_tokenizer,
                 data_csv_root,
                 img_key,
                 caption_key,
                 resolution=512,
                 drop_txt_prob=0.02,
                 drop_all_prob=0.02,
                 ):
        """Read the CSV, filter the caption dicts and build the transform.

        Args:
            tokenizer: Tokenizer for the global prompt; must expose
                ``model_max_length`` and return an object with ``input_ids``.
            TP_tokenizer: Tokenizer for the fine-grained prompt.
            data_csv_root: Path of the CSV file.
            img_key: CSV column holding the image path.
            caption_key: CSV column holding the caption dict literal.
            resolution: Side length images are resized/center-cropped to.
            drop_txt_prob: Probability of blanking only the fine-grained prompt.
            drop_all_prob: Probability of blanking both prompts.
        """
        self.tokenizer = tokenizer
        self.TP_tokenizer = TP_tokenizer
        self.resolution = resolution

        # Normalize([0.5], [0.5]) maps the [0, 1] output of ToTensor to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        df = pd.read_csv(data_csv_root)
        df[caption_key] = df[caption_key].apply(ast.literal_eval).apply(self._filter_caption)
        # Keep rows that still have a global prompt and at least one
        # attribute. Rows whose global prompt was filtered out would raise
        # KeyError in __getitem__, so they are excluded here.
        df = df[df[caption_key].apply(lambda d: 'global_prompt' in d and len(d) > 1)]
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.prompt_list = prompt_list2

        self.drop_txt_prob = drop_txt_prob
        self.drop_all_prob = drop_all_prob

        # Attributes whose label is inserted into a "{}" template ...
        self.fill_dimension = ["hair color", "hair curliness", 'eyeglasses', 'beards', 'hat', 'necklace', 'style', 'expression', 'scene object', 'cloth']
        # ... and attributes whose label selects a list of fixed phrases.
        self.unfill_dimension = ["makeup", "fuzhi"]

    @staticmethod
    def _filter_caption(caption):
        """Drop empty values and, except for the global prompt, values containing 'not '."""
        return {
            k: v for k, v in caption.items()
            if v != '' and ('not ' not in v or k == 'global_prompt')
        }

    @staticmethod
    def _gather_captions(text, is_train):
        """Flatten ``text`` into a list of strings (see ``tokenize_captions``)."""
        captions = []
        for caption in text:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{text}` should contain either strings or lists of strings."
                )
        return captions

    def _sample_texts(self, index):
        """Build ``(text_global, text_shape)`` for sample ``index`` (before dropping).

        One attribute is used with weight 0.8 and two with weight 0.2 when
        more than one is available. A sampled ``'cloth'`` label is removed
        from the global prompt.
        """
        caption_dict = dict(self.captions[index])
        text_global = caption_dict.pop("global_prompt")

        # random.sample needs a sequence: passing dict keys directly raises
        # TypeError on Python >= 3.11.
        attr_names = list(caption_dict)
        if len(attr_names) > 1:
            attr_num = random.choices([1, 2], [0.8, 0.2])[0]
            attr_list = random.sample(attr_names, attr_num)
        else:
            attr_list = [random.choice(attr_names)]

        text_list_finegrained = []
        for attr in attr_list:
            label = caption_dict[attr]
            if attr == 'cloth':
                text_global = text_global.replace(label, "")
                text_global = text_global.replace(', ,', ",")
            if attr in self.fill_dimension:
                template = random.choice(self.prompt_list[attr])
                text_list_finegrained.append(template.format(label).strip(" ."))
            elif attr in self.unfill_dimension:
                phrase = random.choice(self.prompt_list[attr][label])
                text_list_finegrained.append(phrase.strip(" ."))

        random.shuffle(text_list_finegrained)
        text_shape = ', '.join(text_list_finegrained).strip(' ,.') + "."
        text_global = text_global.strip(' ,.') + "."
        return text_global, text_shape

    def __getitem__(self, index):
        """Load sample ``index``, sample its prompts and apply the random drop.

        Returns:
            dict with ``pixel_values``, ``img_name``, ``input_ids_global``
            and ``input_ids_shape``.
        """
        image_path = self.images[index]
        img_name = os.path.basename(image_path)
        # The image is read from the "img" location instead of "crop_img".
        img = Image.open(str(image_path.strip().replace("crop_img", "img"))).convert("RGB")
        img = self.transform(img)

        text_global, text_shape = self._sample_texts(index)

        # One draw decides between "drop fine-grained prompt only",
        # "drop both prompts" and "keep as is".
        rand_num = random.random()
        if rand_num < self.drop_txt_prob:
            text_shape = ''
        elif rand_num < self.drop_txt_prob + self.drop_all_prob:
            text_shape, text_global = '', ''

        return {
            "pixel_values": img,
            "img_name": img_name,
            "input_ids_global": self.tokenize_captions([text_global]),
            "input_ids_shape": self.TP_tokenize_captions([text_shape]),
        }

    def check_data(self):
        """Touch every sample's ``'global_prompt'``; raises KeyError if one is missing."""
        for index in tqdm(range(len(self.images))):
            self.captions[index]['global_prompt']

    def __len__(self):
        """Number of usable samples."""
        return len(self.images)

    def tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.tokenizer`` and return ``input_ids``.

        Args:
            text: Iterable whose items are either a string or a list/ndarray
                of alternative strings.
            is_train: When True a random alternative is picked, otherwise
                the first one.

        Raises:
            ValueError: If an item is neither a string nor a list/ndarray.
        """
        captions = self._gather_captions(text, is_train)
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def TP_tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.TP_tokenizer`` and return ``input_ids``.

        Same contract as ``tokenize_captions``.
        """
        captions = self._gather_captions(text, is_train)
        # NOTE(review): max_length is taken from self.tokenizer, not from
        # self.TP_tokenizer -- kept as in the original, confirm it is intended.
        inputs = self.TP_tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

class RealV40CustomDataset_v7_stage2(torch.utils.data.Dataset):
    """CSV-driven dataset that additionally returns CLIP-processed pixels.

    Each CSV row holds an image path and a Python-literal dict of captions.
    One or two attributes are sampled per item and rendered through the
    templates in ``prompt_list2``; the global prompt is returned unchanged.
    """

    def __init__(self,
                 tokenizer,
                 TP_tokenizer,
                 clip_processor,
                 data_csv_root,
                 img_key,
                 caption_key,
                 resolution=512,
                 drop_txt_prob=0.02,
                 drop_all_prob=0.02,
                 ):
        """Read the CSV, filter the caption dicts and build the transform.

        Args:
            tokenizer: Tokenizer for the global prompt; must expose
                ``model_max_length`` and return an object with ``input_ids``.
            TP_tokenizer: Tokenizer for the fine-grained prompt.
            clip_processor: Callable invoked as
                ``clip_processor(images=img, return_tensors="pt")`` whose
                result exposes ``pixel_values``.
            data_csv_root: Path of the CSV file.
            img_key: CSV column holding the image path.
            caption_key: CSV column holding the caption dict literal.
            resolution: Side length images are resized/center-cropped to.
            drop_txt_prob: Probability of blanking only the fine-grained prompt.
            drop_all_prob: Probability of blanking both prompts.
        """
        self.tokenizer = tokenizer
        self.TP_tokenizer = TP_tokenizer
        self.clip_processor = clip_processor
        self.resolution = resolution

        # Normalize([0.5], [0.5]) maps the [0, 1] output of ToTensor to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        df = pd.read_csv(data_csv_root)
        df[caption_key] = df[caption_key].apply(ast.literal_eval).apply(self._filter_caption)
        # Keep rows that have a global prompt and at least one attribute.
        # Rows lacking the 'global_prompt' key would raise KeyError in
        # __getitem__, so they are excluded here.
        df = df[df[caption_key].apply(lambda d: 'global_prompt' in d and len(d) > 1)]
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.prompt_list = prompt_list2

        self.drop_txt_prob = drop_txt_prob
        self.drop_all_prob = drop_all_prob

        # Attributes whose label is inserted into a "{}" template ...
        self.fill_dimension = ["hair color", "hair curliness", 'eyeglasses', 'beards', 'hat', 'necklace', 'style', 'expression', 'scene object', 'cloth']
        # ... and attributes whose label selects a list of fixed phrases.
        self.unfill_dimension = ["makeup", "fuzhi"]

    @staticmethod
    def _filter_caption(caption):
        """Drop empty values; the global prompt is kept even when empty."""
        return {k: v for k, v in caption.items() if v != '' or k == 'global_prompt'}

    @staticmethod
    def _gather_captions(text, is_train):
        """Flatten ``text`` into a list of strings (see ``tokenize_captions``)."""
        captions = []
        for caption in text:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{text}` should contain either strings or lists of strings."
                )
        return captions

    def _sample_texts(self, index):
        """Build ``(text_global, text_shape)`` for sample ``index`` (before dropping).

        One attribute is used with weight 0.8 and two with weight 0.2 when
        more than one is available. The global prompt is returned as stored.
        """
        caption_dict = dict(self.captions[index])
        text_global = caption_dict.pop("global_prompt")

        # random.sample needs a sequence: passing dict keys directly raises
        # TypeError on Python >= 3.11.
        attr_names = list(caption_dict)
        if len(attr_names) > 1:
            attr_num = random.choices([1, 2], [0.8, 0.2])[0]
            attr_list = random.sample(attr_names, attr_num)
        else:
            attr_list = [random.choice(attr_names)]

        text_list_finegrained = []
        for attr in attr_list:
            label = caption_dict[attr]
            if attr in self.fill_dimension:
                template = random.choice(self.prompt_list[attr])
                text_list_finegrained.append(template.format(label).strip(" ."))
            elif attr in self.unfill_dimension:
                phrase = random.choice(self.prompt_list[attr][label])
                text_list_finegrained.append(phrase.strip(" ."))

        random.shuffle(text_list_finegrained)
        text_shape = '. '.join(text_list_finegrained) + "."
        return text_global, text_shape

    def __getitem__(self, index):
        """Load sample ``index``, sample its prompts and apply the random drop.

        Returns:
            dict with ``pixel_values``, ``img_name``, ``input_ids_global``,
            ``input_ids_shape`` and ``clip_img``.
        """
        image_path = self.images[index]
        img_name = os.path.basename(image_path)
        # The image is read from the "img" location instead of "crop_img".
        img = Image.open(str(image_path.strip().replace("crop_img", "img"))).convert("RGB")
        # The CLIP processor receives the untransformed PIL image.
        clip_img = self.clip_processor(images=img, return_tensors="pt").pixel_values
        img = self.transform(img)

        text_global, text_shape = self._sample_texts(index)

        # One draw decides between "drop fine-grained prompt only",
        # "drop both prompts" and "keep as is".
        rand_num = random.random()
        if rand_num < self.drop_txt_prob:
            text_shape = ''
        elif rand_num < self.drop_txt_prob + self.drop_all_prob:
            text_shape, text_global = '', ''

        return {
            "pixel_values": img,
            "img_name": img_name,
            "input_ids_global": self.tokenize_captions([text_global]),
            "input_ids_shape": self.TP_tokenize_captions([text_shape]),
            "clip_img": clip_img,
        }

    def __len__(self):
        """Number of usable samples."""
        return len(self.images)

    def tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.tokenizer`` and return ``input_ids``.

        Args:
            text: Iterable whose items are either a string or a list/ndarray
                of alternative strings.
            is_train: When True a random alternative is picked, otherwise
                the first one.

        Raises:
            ValueError: If an item is neither a string nor a list/ndarray.
        """
        captions = self._gather_captions(text, is_train)
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def TP_tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.TP_tokenizer`` and return ``input_ids``.

        Same contract as ``tokenize_captions``.
        """
        captions = self._gather_captions(text, is_train)
        # NOTE(review): max_length is taken from self.tokenizer, not from
        # self.TP_tokenizer -- kept as in the original, confirm it is intended.
        inputs = self.TP_tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

class RealV40CustomDataset_v7_stage2_drop_head(torch.utils.data.Dataset):
    """CSV-driven dataset that skips the first rows of the CSV and also
    returns per-segment token ids of the global prompt.

    Each CSV row holds an image path and a Python-literal dict of captions.
    One or two attributes are sampled per item and rendered through the
    templates in ``prompt_list2``.
    """

    def __init__(self,
                 tokenizer,
                 TP_tokenizer,
                 clip_processor,
                 data_csv_root,
                 img_key,
                 caption_key,
                 resolution=512,
                 drop_txt_prob=0.02,
                 drop_all_prob=0.02,
                 drop_head_rows=32000,
                 ):
        """Read the CSV, filter the caption dicts and build the transform.

        Args:
            tokenizer: Tokenizer for the global prompt; must expose
                ``model_max_length`` and return an object with ``input_ids``.
            TP_tokenizer: Tokenizer for the fine-grained prompt.
            clip_processor: Callable invoked as
                ``clip_processor(images=img, return_tensors="pt")`` whose
                result exposes ``pixel_values``.
            data_csv_root: Path of the CSV file.
            img_key: CSV column holding the image path.
            caption_key: CSV column holding the caption dict literal.
            resolution: Side length images are resized/center-cropped to.
            drop_txt_prob: Probability of blanking only the fine-grained prompt.
            drop_all_prob: Probability of blanking both prompts.
            drop_head_rows: Number of leading CSV rows to skip; the default
                reproduces the previously hard-coded 32000.
        """
        self.tokenizer = tokenizer
        self.TP_tokenizer = TP_tokenizer
        self.clip_processor = clip_processor
        self.resolution = resolution

        # Normalize([0.5], [0.5]) maps the [0, 1] output of ToTensor to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        df = pd.read_csv(data_csv_root)
        # Same rows as the former `df.tail(len(df) - 32000)`: everything after
        # the first `drop_head_rows` rows (empty when the CSV is shorter).
        df = df.iloc[drop_head_rows:]
        captions = df[caption_key].apply(ast.literal_eval).apply(self._filter_caption)
        # Keep rows that still have a global prompt and at least one
        # attribute. Rows whose global prompt was filtered out would raise
        # KeyError in __getitem__, so they are excluded here.
        keep = captions.apply(lambda d: 'global_prompt' in d and len(d) > 1)
        self.images = df[img_key][keep].tolist()
        self.captions = captions[keep].tolist()
        self.prompt_list = prompt_list2

        self.drop_txt_prob = drop_txt_prob
        self.drop_all_prob = drop_all_prob

        # Attributes whose label is inserted into a "{}" template ...
        self.fill_dimension = ["hair color", "hair curliness", 'eyeglasses', 'beards', 'hat', 'necklace', 'style', 'expression', 'scene object', 'cloth']
        # ... and attributes whose label selects a list of fixed phrases.
        self.unfill_dimension = ["makeup", "fuzhi"]

    @staticmethod
    def _filter_caption(caption):
        """Drop empty values and, except for the global prompt, values containing 'not '."""
        return {
            k: v for k, v in caption.items()
            if v != '' and ('not ' not in v or k == 'global_prompt')
        }

    @staticmethod
    def _gather_captions(text, is_train):
        """Flatten ``text`` into a list of strings (see ``tokenize_captions``)."""
        captions = []
        for caption in text:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{text}` should contain either strings or lists of strings."
                )
        return captions

    def _sample_texts(self, index):
        """Build ``(text_global, text_shape)`` for sample ``index`` (before dropping).

        One attribute is used with weight 0.8 and two with weight 0.2 when
        more than one is available. A sampled ``'cloth'`` label followed by
        ", " is removed from the global prompt.
        """
        caption_dict = dict(self.captions[index])
        text_global = caption_dict.pop("global_prompt")

        # random.sample needs a sequence: passing dict keys directly raises
        # TypeError on Python >= 3.11.
        attr_names = list(caption_dict)
        if len(attr_names) > 1:
            attr_num = random.choices([1, 2], [0.8, 0.2])[0]
            attr_list = random.sample(attr_names, attr_num)
        else:
            attr_list = [random.choice(attr_names)]

        text_list_finegrained = []
        for attr in attr_list:
            label = caption_dict[attr]
            if attr == 'cloth':
                text_global = text_global.replace(label + ", ", "")
            if attr in self.fill_dimension:
                template = random.choice(self.prompt_list[attr])
                text_list_finegrained.append(template.format(label).strip(" ."))
            elif attr in self.unfill_dimension:
                phrase = random.choice(self.prompt_list[attr][label])
                text_list_finegrained.append(phrase.strip(" ."))

        random.shuffle(text_list_finegrained)
        text_shape = '. '.join(text_list_finegrained).strip(' ,.') + "."
        return text_global, text_shape

    def __getitem__(self, index):
        """Load sample ``index``, sample its prompts and apply the random drop.

        Returns:
            dict with ``pixel_values``, ``img_name``, ``input_ids_global``,
            ``input_ids_global_indice`` (the global prompt split on ", " and
            tokenized piece by piece), ``input_ids_shape`` and ``clip_img``.
        """
        image_path = self.images[index]
        img_name = os.path.basename(image_path)
        # The image is read from the "img" location instead of "crop_img".
        img = Image.open(str(image_path.strip().replace("crop_img", "img"))).convert("RGB")
        # The CLIP processor receives the untransformed PIL image.
        clip_img = self.clip_processor(images=img, return_tensors="pt").pixel_values
        img = self.transform(img)

        text_global, text_shape = self._sample_texts(index)

        # One draw decides between "drop fine-grained prompt only",
        # "drop both prompts" and "keep as is".
        rand_num = random.random()
        if rand_num < self.drop_txt_prob:
            text_shape = ''
        elif rand_num < self.drop_txt_prob + self.drop_all_prob:
            text_shape, text_global = '', ''

        return {
            "pixel_values": img,
            "img_name": img_name,
            "input_ids_global": self.tokenize_captions([text_global]),
            "input_ids_global_indice": self.tokenize_captions(text_global.split(", ")),
            "input_ids_shape": self.TP_tokenize_captions([text_shape]),
            "clip_img": clip_img,
        }

    def __len__(self):
        """Number of usable samples."""
        return len(self.images)

    def tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.tokenizer`` and return ``input_ids``.

        Args:
            text: Iterable whose items are either a string or a list/ndarray
                of alternative strings.
            is_train: When True a random alternative is picked, otherwise
                the first one.

        Raises:
            ValueError: If an item is neither a string nor a list/ndarray.
        """
        captions = self._gather_captions(text, is_train)
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def TP_tokenize_captions(self, text, is_train=True):
        """Tokenize captions with ``self.TP_tokenizer`` and return ``input_ids``.

        Same contract as ``tokenize_captions``.
        """
        captions = self._gather_captions(text, is_train)
        # NOTE(review): max_length is taken from self.tokenizer, not from
        # self.TP_tokenizer -- kept as in the original, confirm it is intended.
        inputs = self.TP_tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids