from functools import lru_cache
import os
import torch
from tqdm import tqdm
import numpy as np
import yaml
from transformers import BertTokenizer, BertModel
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset
import argparse
import re
import json
from PIL import Image

def pre_caption(caption,max_words=50):
    caption = re.sub(
        r"([.!\"()*#:;~])",       
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n') 
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])
            
    return caption


class flickr30k_train(Dataset):
    def __init__(self, transform, data_root, max_words=30, prompt=''):        
        '''
        data_root (string): Root directory of data (e.g. flickr30k/)
        '''        
        filename = 'flickr30k_train.json'
        
        self.annotation = json.load(open(os.path.join(data_root,filename),'r'))
        self.transform = transform
        self.image_root = data_root
        self.max_words = max_words      
        self.prompt = prompt
        
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1    
        
    def __len__(self):
        return len(self.annotation)
    
    @lru_cache(maxsize=100)
    def read_image(self, image_path):      
        image = Image.open(image_path).convert('RGB')   
        if self.transform is not None:
            image = self.transform(image)
        return image
    
    def __getitem__(self, index):    
        
        ann = self.annotation[index]
        
        image_path = os.path.join(self.image_root,ann['image'])
        image = self.read_image(image_path)      
        
        caption = self.prompt+pre_caption(ann['caption'], self.max_words) 

        return image, caption, self.img_ids[ann['image_id']]
        
    def get_all_captions(self):
        captions = []
        for ann in self.annotation:
            caption = self.prompt + pre_caption(ann['caption'], self.max_words)
            captions.append(caption)
        return captions

    
    
class flickr30k_retrieval_eval(Dataset):
    def __init__(self, transform, data_root, split, max_words=30):  
        '''
        data_root (string): Root directory of data (e.g. flickr30k/)
        split (string): val or test
        '''
        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
        
        self.annotation = json.load(open(os.path.join(data_root,filenames[split]),'r'))
        self.transform  = transform
        self.image_root = data_root
        self.max_words = max_words   
        
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}
        
        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            if img_id not in self.img2txt.keys():
                self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption,max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1
                                    
    def __len__(self):
        return len(self.annotation)
    
    def __getitem__(self, index):    
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])       
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, index


