from ldm.util import instantiate_from_config
import torch 

from glob import glob 
from pathlib import Path
import os 
import json 
from PIL import Image 
import torchvision.transforms as transforms 


class Ai2D_Dataset():
    """Map-style dataset over AI2D records stored as a JSON list.

    Each item is a dict holding the caption(s), the record id, the resized
    image scaled to [-1, 1] (or '' when the record has no image path), and
    fixed-size, zero-padded box / mask tensors for the image, text and arrow
    entities of the record.
    """

    # Every entity group is padded or truncated to this many slots so that
    # all samples share the same tensor shapes.
    MAX_ENTITIES = 30

    def __init__(self, ROOT, image_size = (512, 512), test = False, customized = None):
        """Load the annotation records.

        Args:
            ROOT: dataset root; must contain an "images" directory and, for
                the default training set, "ai2d/ai2d_gligen_formated_data.json".
            image_size: size passed to ``PIL.Image.resize`` for every image.
            test: when True the default training JSON is not loaded, so
                ``customized`` must be given.
            customized: optional path to a JSON file of records; when given
                it is used instead of the default training JSON.

        Raises:
            ValueError: if ``test`` is True and ``customized`` is None, since
                there would be no records to serve.
        """
        # Fail early with a clear message instead of an AttributeError later.
        if test and customized is None:
            raise ValueError(
                "Ai2D_Dataset: test=True requires a `customized` JSON path; "
                "no annotation file would be loaded otherwise."
            )

        self.ROOT = ROOT

        self.images = os.listdir(os.path.join(ROOT, "images"))
        self.image_size = image_size

        if customized is not None:
            # A customized file always wins, so the default training JSON is
            # not read at all (it would be discarded anyway).
            print("load testing set...")
            with open(customized, "r", encoding="utf-8") as read_file:
                self.ai2d_json = json.load(read_file)
        else:
            ## ai2d json file
            print("load training set...")
            default_json = os.path.join(ROOT, "ai2d", "ai2d_gligen_formated_data.json")
            with open(default_json, "r", encoding="utf-8") as read_file:
                self.ai2d_json = list(json.load(read_file))

        self.datasets = self.ai2d_json
        self.total_length = len(self.datasets)

        self.transform = transforms.Compose([transforms.PILToTensor()])

    def total_images(self):
        """Return the number of records (same value as ``len(self)``)."""
        return self.total_length

    def _pad_tokens(self, entities):
        """Return a MAX_ENTITIES-long list of 'tokens_positive', padded with ''."""
        limit = self.MAX_ENTITIES
        tokens = [''] * limit
        for slot, entity in enumerate(entities[:limit]):
            tokens[slot] = entity['tokens_positive']
        return tokens

    def _pad_boxes(self, entities):
        """Return (boxes, masks) tensors of shape (MAX_ENTITIES, 4) and (MAX_ENTITIES,).

        Rows beyond the number of entities stay zero; the mask is 1.0 for
        every filled row. Entities past MAX_ENTITIES are dropped.
        """
        limit = self.MAX_ENTITIES
        boxes = torch.zeros((limit, 4))
        masks = torch.zeros(limit)
        kept = entities[:limit]
        if kept:
            boxes[:len(kept)] = torch.tensor([entity['bbox'] for entity in kept])
            masks[:len(kept)] = 1.0
        return boxes, masks

    def _load_image(self, image_path):
        """Open, resize and scale the image named by ``image_path`` to [-1, 1]."""
        # Only the file name is used; the image is looked up under ROOT/images.
        full_path = os.path.join(self.ROOT, "images", image_path.split("/")[-1])
        with Image.open(full_path) as raw_image:
            # Normalize the mode so palette / grayscale / RGBA files all give
            # a 3-channel tensor and samples can be batched together.
            pil_image = raw_image.convert("RGB").resize(self.image_size)
        tensor_image = self.transform(pil_image)
        # uint8 values in [0, 255] -> float values in [-1, 1]
        return tensor_image / 255 * 2 - 1

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx``.

        'entity_images', 'entity_texts' and 'entity_arrows' are all optional
        in a record; a missing group yields all-zero boxes and masks.
        """
        json_data = self.datasets[idx]

        dict_ = {}
        dict_['caption'] = json_data['caption']
        if "caption2" in json_data:
            dict_['caption2'] = json_data['caption2']

        dict_['id'] = json_data['data_id']

        image_path = json_data.get('image_path', "")
        dict_['image'] = self._load_image(image_path) if image_path != '' else ''

        ##### entity images #####
        entity_images = json_data.get('entity_images') or []
        dict_['entity_image_tokens_positive'] = self._pad_tokens(entity_images)
        dict_['entity_image_boxes'], dict_['entity_image_masks'] = self._pad_boxes(entity_images)

        ##### entity texts #####
        entity_texts = json_data.get('entity_texts') or []
        dict_['entity_text_boxes'], dict_['entity_text_masks'] = self._pad_boxes(entity_texts)

        ##### entity arrows #####
        entity_arrows = json_data.get('entity_arrows') or []
        dict_['entity_arrow_tokens_positive'] = self._pad_tokens(entity_arrows)
        dict_['entity_arrow_boxes'], dict_['entity_arrow_masks'] = self._pad_boxes(entity_arrows)

        return dict_

    def __len__(self):
        """Return the number of records in the dataset."""
        return self.total_length
            




