import os
import glob
import os
import glob
import json
import torch
import random
import numpy as np
from PIL import Image, ImageFile
from torchvision import transforms
import torchvision
from torch.utils.data import Dataset

ImageFile.LOAD_TRUNCATED_IMAGES = True

class clevrtex(Dataset):
    """CLEVRTex image / segmentation-mask dataset, split by shard directory.

    ``root`` is scanned for sub-directories, which are sorted
    lexicographically (string order). Slice ``[:48]`` of that list is
    ``'train'``, ``[48:49]`` is ``'valid'`` and ``[49:]`` is ``'test'``.
    Within each selected directory, files matching ``img_glob`` are the
    images. Each image's mask is expected next to it, at the image path
    with ``mask_suffix`` inserted before the extension.

    Each item is a dict:
        ``'image'``:   float tensor (3, img_size, img_size) in [0, 1].
        ``'segment'``: int64 tensor (img_size, img_size) holding the
                       mask's palette indices.

    Args:
        root: Dataset root directory containing the shard directories.
        phase: One of ``'train'``, ``'valid'``, ``'test'``.
        img_size: Side length of the square output.
        img_glob: Glob, relative to each shard directory, selecting the
            image files. The default ``CLEVRTEX_full_??????.png`` matches
            exactly six characters after the prefix, so it does not pick
            up longer names such as ``..._flat.png``.
        max_num_obj: Stored on the instance; not used by this class.
        keys_to_log: Stored on the instance; not used by this class.
        randomcrop: If truthy and ``phase == 'train'``, apply a random
            resized crop (scale 0.5-1.0, square aspect). Otherwise resize
            the shorter side to ``img_size`` and center-crop. The same
            geometry is always applied to the image and its mask.
        mask_suffix: Inserted before the image file's extension to form
            the mask path. ``'_flat'`` maps ``X.png`` -> ``X_flat.png``.
            This assumes the CLEVRTex ``*_flat.png`` mask naming; verify
            it against your copy of the data.

    Raises:
        ValueError: If ``phase`` is not a known split.
    """

    # Slices of the sorted shard-directory list used for each phase.
    _PHASE_SLICES = {
        'train': slice(None, 48),
        'valid': slice(48, 49),
        'test': slice(49, None),
    }

    def __init__(self, root, phase, img_size,
                 img_glob='CLEVRTEX_full_??????.png',
                 max_num_obj=10,
                 keys_to_log=None,
                 randomcrop=False,
                 mask_suffix='_flat'):
        # Explicit check instead of `assert`, which is stripped under -O.
        if phase not in self._PHASE_SLICES:
            raise ValueError(
                f"phase must be one of {sorted(self._PHASE_SLICES)}, "
                f"got {phase!r}")

        self.root = root
        self.img_size = img_size
        self.phase = phase
        self.max_num_obj = max_num_obj
        self.keys_to_log = keys_to_log
        self.mask_suffix = mask_suffix

        # Only directories can hold images. Ignoring stray files keeps
        # them from shifting the split boundaries.
        shards = sorted(
            path for path in glob.glob(os.path.join(root, '*'))
            if os.path.isdir(path))
        self.total_dirs = shards[self._PHASE_SLICES[phase]]

        # Flat list of image paths, ordered by shard and then by file name.
        self.episodes = []
        for shard in self.total_dirs:
            self.episodes.extend(
                sorted(glob.glob(os.path.join(shard, img_glob))))
        self.mask_episodes = [
            self._mask_path(path, mask_suffix) for path in self.episodes]

        # Random cropping is only ever enabled for the training split.
        self.use_randomcrop = bool(randomcrop) and phase == 'train'
        if self.use_randomcrop:
            self.randomcrop = transforms.RandomResizedCrop(
                (img_size, img_size), scale=(0.5, 1.0), ratio=(1.0, 1.0))
        else:
            self.randomcrop = transforms.Compose([
                transforms.Resize(
                    img_size,
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop((img_size, img_size))])

        self.transform = transforms.ToTensor()
        self.transform_img = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _mask_path(image_path, suffix):
        """Return the mask path: *image_path* with *suffix* before its extension."""
        stem, ext = os.path.splitext(image_path)
        return stem + suffix + ext

    def _crop_pair(self, image, segment):
        """Apply one identical crop/resize to *image* and *segment*.

        The image uses bilinear interpolation. The mask uses nearest
        interpolation so that label indices are never blended.
        """
        if image.size != segment.size:
            raise ValueError(
                f"image size {image.size} != mask size {segment.size}")
        F = transforms.functional
        modes = torchvision.transforms.InterpolationMode
        out_size = [self.img_size, self.img_size]
        if self.use_randomcrop:
            # Sample the crop once so the image and the mask stay aligned.
            top, left, height, width = transforms.RandomResizedCrop.get_params(
                image, scale=self.randomcrop.scale, ratio=self.randomcrop.ratio)
            image = F.resized_crop(image, top, left, height, width, out_size,
                                   interpolation=modes.BILINEAR)
            segment = F.resized_crop(segment, top, left, height, width,
                                     out_size, interpolation=modes.NEAREST)
        else:
            image = F.center_crop(
                F.resize(image, self.img_size, interpolation=modes.BILINEAR),
                out_size)
            segment = F.center_crop(
                F.resize(segment, self.img_size, interpolation=modes.NEAREST),
                out_size)
        return image, segment

    def __len__(self):
        """Return the number of image/mask pairs in this split."""
        return len(self.episodes)

    def __getitem__(self, idx):
        """Load, jointly crop, and tensorize the *idx*-th image/mask pair."""
        # Context managers close the file handles; convert() loads the pixels.
        with Image.open(self.episodes[idx]) as img_file:
            image = img_file.convert("RGB")
        with Image.open(self.mask_episodes[idx]) as mask_file:
            segment = mask_file.convert("P")

        image, segment = self._crop_pair(image, segment)
        image = self.transform_img(image)
        # Read the palette indices directly. The old ToTensor()*255 then
        # .long() path went through float division, and truncating the
        # result is not guaranteed to give back the exact integer.
        segment = torch.from_numpy(np.array(segment, dtype=np.int64))

        return {'image': image, 'segment': segment}
