import json
import os
import random
import pandas as pd
from torch.utils.data import Dataset
import torch
from PIL import Image
import numpy as np
from PIL import ImageFile
from PIL.Image import blend as blend
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
import ast
from data.utils import pre_caption
import os,glob

class pretrain_dataset(Dataset):
    """Text-pair dataset built from tab-separated, header-less annotation files.

    The files are read with ``header=None``, so columns are addressed by their
    integer position. ``__getitem__`` reads column 0 (the caption) and
    column 3 (an alternative text; presumably an object/region description --
    TODO confirm against the files that are actually used).

    No image is loaded here: ``transform`` and ``img_root`` are only stored on
    the instance.
    """

    def __init__(self, ann_file, transform, max_samples=1000):
        """Read every annotation file and keep the leading rows.

        Args:
            ann_file: A single path, or an iterable of paths, to tab-separated
                annotation files without a header row.
            transform: Stored as ``self.transform``; not applied by
                ``__getitem__``.
            max_samples: Number of leading rows to keep. ``None`` keeps every
                row. Defaults to 1000, the limit that used to be hard-coded.

        Raises:
            ValueError: If ``ann_file`` yields no paths, or if ``max_samples``
                is negative.
        """
        # A bare path would otherwise be iterated character by character.
        if isinstance(ann_file, (str, os.PathLike)):
            ann_file = [ann_file]
        if max_samples is not None and max_samples < 0:
            raise ValueError("max_samples must be None or a non-negative integer")

        self.img_root = "COCO2014"  # image root on the server; unused by this class

        frames = [pd.read_csv(path, sep='\t', header=None) for path in ann_file]
        if not frames:
            raise ValueError("ann_file must contain at least one annotation file")

        # Concatenate once instead of re-copying the accumulated frame per file.
        if len(frames) == 1:
            self.ann_pretrain = frames[0]
        else:
            self.ann_pretrain = pd.concat(frames, ignore_index=True, sort=False)

        self.annotation = self.ann_pretrain.iloc[:max_samples]
        self.max_words = 30
        self.transform = transform

    def __len__(self):
        """Return the number of annotation rows that were kept."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(image_caption, caption)`` for the row at ``index``.

        ``image_caption`` is built from column 3 when its string form is longer
        than five characters; otherwise it falls back to column 0. A missing
        value stringifies to ``"nan"`` and therefore takes the fallback.
        ``caption`` is always built from column 0.
        """
        row = self.annotation.iloc[index]
        # Work on a local copy: assigning into ``row`` would write into a
        # temporary (or a view) of the underlying DataFrame.
        object_text = str(row[3])
        raw_caption = row[0]

        if len(object_text) > 5:
            image_caption = pre_caption(object_text, self.max_words * 2)
        else:
            image_caption = pre_caption(raw_caption, self.max_words * 2)
        caption = pre_caption(raw_caption, self.max_words)
        return image_caption, caption
