import os
from PIL import Image
import random
import json
import torch
import glob
from torch.utils.data import Dataset
from lavis.processors.blip_processors import BlipCaptionProcessor
import torchvision.transforms as transforms

class CLEVRMaskDataset(Dataset):
    """Image triplets, per-image masks and captions read from one data directory.

    Expected layout under ``data_path`` (taken from the paths built below):

    * ``images/CLEVR_default_XXXXXX.png`` - the reference ("default") image
    * ``sc_images/CLEVR_semantic_XXXXXX.png`` - image paired with the change captions
    * ``nsc_images/CLEVR_nonsemantic_XXXXXX.png`` - image paired with the no-change captions
    * ``sc_images_with_mask/`` and ``nsc_images_with_mask/`` - ``*_mask.png`` files
    * ``splits.json`` - mapping from split name to a list of image ids
    * ``change_captions.json`` / ``no_change_captions.json`` - mapping from the
      default image file name to a list of caption strings
    """

    def __init__(self, data_path, transform, split='train', max_words=40, prompt="", scale=1.0):
        """Read the split ids and both caption files.

        Args:
            data_path: Root directory of the dataset (see the class docstring).
            transform: Callable applied to every RGB image, or ``None`` to
                return the PIL image unchanged.
            split: Key looked up in ``splits.json``; ``'train'`` additionally
                enables captions and the random before/distractor swap.
            max_words: Forwarded to ``BlipCaptionProcessor``.
            prompt: Forwarded to ``BlipCaptionProcessor``.
            scale: Stored on the instance; not used by any method of this class.
        """
        super().__init__()
        self.data_path = data_path
        self.default_image_path = os.path.join(self.data_path, 'images')
        self.nsc_image_path = os.path.join(self.data_path, 'nsc_images')
        self.sc_image_path = os.path.join(self.data_path, 'sc_images')
        self.sc_mask_path = os.path.join(self.data_path, 'sc_images_with_mask')
        self.nsc_mask_path = os.path.join(self.data_path, 'nsc_images_with_mask')
        self.split = split
        self.transform = transform
        self.max_words = max_words
        self.prompt = prompt
        self.scale = scale
        self.text_process = BlipCaptionProcessor(prompt=prompt, max_words=max_words)

        with open(os.path.join(self.data_path, "splits.json"), 'r') as fp:
            total_image_ids = json.load(fp)
            self.image_ids = total_image_ids[split]

        with open(os.path.join(self.data_path, "change_captions.json"), 'r') as fp:
            self.change_captions = json.load(fp)

        with open(os.path.join(self.data_path, "no_change_captions.json"), 'r') as fp:
            self.no_change_captions = json.load(fp)

        # Masks are only converted to tensors; no resizing is applied here.
        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of image ids in the selected split."""
        return len(self.image_ids)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with ``bef_image``, ``aft_image``, ``nsc_image`` and the four
            masks ``sc_bef_mask``, ``sc_aft_mask``, ``nsc_bef_mask``,
            ``nsc_aft_mask``. For ``split == 'train'`` it also holds
            ``sc_caption`` and ``nsc_caption``, and with probability 0.5 the
            default and non-semantic images (and their masks) trade places.
            For any other split it holds ``img_id`` and the three source paths
            ``bef_path``, ``aft_path``, ``nsc_path`` instead.
        """
        image_id = self.image_ids[index]
        image_name = "CLEVR_default_%06d.png" % int(image_id)
        sc_name = image_name.replace('default', 'semantic')
        nsc_name = image_name.replace('default', 'nonsemantic')

        # Captions are drawn before any image is loaded so the order of calls
        # into the `random` module stays the same as it has always been.
        sc_caption = random.choice(self.change_captions[image_name])
        nsc_caption = random.choice(self.no_change_captions[image_name])

        bef_image_path = os.path.join(self.default_image_path, image_name)
        aft_image_path = os.path.join(self.sc_image_path, sc_name)
        no_image_path = os.path.join(self.nsc_image_path, nsc_name)

        bef_image_data = self.get_image_data(bef_image_path)
        aft_image_data = self.get_image_data(aft_image_path)
        no_image_data = self.get_image_data(no_image_path)

        # The default image has one mask per pairing; both files share a name
        # and differ only in the directory they live in.
        default_mask_name = image_name.replace('.png', '_mask.png')
        sc_default_mask = self.get_mask(os.path.join(self.sc_mask_path, default_mask_name))
        sc_image_mask = self.get_mask(
            os.path.join(self.sc_mask_path, sc_name.replace('.png', '_mask.png')))
        nsc_default_mask = self.get_mask(os.path.join(self.nsc_mask_path, default_mask_name))
        nsc_image_mask = self.get_mask(
            os.path.join(self.nsc_mask_path, nsc_name.replace('.png', '_mask.png')))

        is_train = self.split == 'train'
        out = {}
        if is_train:
            out['sc_caption'] = self.text_process(sc_caption)
            out['nsc_caption'] = self.text_process(nsc_caption)

        # Short-circuit keeps the RNG untouched outside of training.
        swapped = is_train and random.random() >= 0.5
        if swapped:
            # The non-semantic image acts as "before", so the masks that
            # belong to that image travel with it.
            out.update(
                bef_image=no_image_data,
                aft_image=aft_image_data,
                nsc_image=bef_image_data,
                sc_bef_mask=nsc_image_mask,
                sc_aft_mask=sc_image_mask,
                nsc_bef_mask=nsc_image_mask,
                nsc_aft_mask=nsc_default_mask,
            )
        else:
            out.update(
                bef_image=bef_image_data,
                aft_image=aft_image_data,
                nsc_image=no_image_data,
                sc_bef_mask=sc_default_mask,
                sc_aft_mask=sc_image_mask,
                nsc_bef_mask=nsc_default_mask,
                nsc_aft_mask=nsc_image_mask,
            )

        if not is_train:
            out['img_id'] = "%06d.png" % int(image_id)
            out['bef_path'] = bef_image_path
            out['aft_path'] = aft_image_path
            out['nsc_path'] = no_image_path
        return out

    def get_image_data(self, image_path):
        """Open ``image_path`` as RGB and apply ``self.transform`` when it is set."""
        with open(image_path, 'rb') as f:
            # Convert while the file is still open: this forces the pixel data
            # to be read before the handle is closed.
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_mask(self, mask_path):
        """Load a single-channel mask, or an all-zero 224x224 one if the file is absent.

        Only a missing path triggers the fallback; any other I/O or decoding
        error propagates to the caller.
        """
        try:
            f = open(mask_path, 'rb')
        except (FileNotFoundError, NotADirectoryError):
            # Best effort: a sample without a mask file gets an empty mask.
            img = Image.new('L', (224, 224), color=0)
        else:
            with f:
                img = Image.open(f).convert('L')
        if self.mask_transform is not None:
            img = self.mask_transform(img)
        return img

    
class SpotMaskDataset(Dataset):
    """Before/after image pairs with masks and captions, driven by a JSON annotation file.

    The annotation file is ``<anno_path>/filter_<split>.json``. Each entry is
    read for its ``img_id`` and ``sentences`` fields. Images are looked up as
    ``<image_path>/<img_id>.png`` and ``<image_path>/<img_id>_2.png``; masks as
    ``<mask_path>/<img_id>_mask.png`` and ``<mask_path>/<img_id>_2_mask.png``.
    """

    def __init__(self, image_path, anno_path, mask_path, transform=None, prompt="", split='train') -> None:
        """Load the annotations and pre-process every sentence into ``self.texts``.

        Args:
            image_path: Directory holding the ``.png`` images.
            anno_path: Directory holding ``filter_<split>.json``.
            mask_path: Directory holding the ``*_mask.png`` files.
            transform: Callable applied to every RGB image, or ``None``.
            prompt: Forwarded to ``BlipCaptionProcessor`` (``max_words`` is fixed at 40).
            split: Selects the annotation file; ``'train'`` also selects the
                training layout of the returned sample.
        """
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.split = split
        self.transform = transform
        self.text_preprocess = BlipCaptionProcessor(prompt=prompt, max_words=40)

        annotation_file = os.path.join(anno_path, 'filter_{}.json'.format(split))
        with open(annotation_file, 'r') as handle:
            self.captions = json.load(handle)

        # Flat list of every processed sentence, in annotation order.
        self.texts = [
            self.text_preprocess(sentence)
            for entry in self.captions
            for sentence in entry['sentences']
        ]

        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.captions)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            For ``split == 'train'``: ``bef_img``, ``aft_img``, ``caption``
            (one randomly chosen, processed sentence), ``bef_mask``, ``aft_mask``.
            Otherwise: ``bef_path``, ``aft_path``, ``bef_img``, ``aft_img``,
            ``img_id`` (with a ``.png`` suffix), ``bef_mask``, ``aft_mask``.
        """
        entry = self.captions[index]
        before_id = entry['img_id']
        after_id = before_id + '_2'

        # The sentence is drawn first, ahead of any image loading.
        sentence = self.text_preprocess(random.choice(entry['sentences']))

        before_image = self.get_image(before_id)
        after_image = self.get_image(after_id)
        before_mask = self.get_mask(before_id)
        after_mask = self.get_mask(after_id)

        if self.split == 'train':
            return {
                'bef_img': before_image,
                'aft_img': after_image,
                'caption': sentence,
                'bef_mask': before_mask,
                'aft_mask': after_mask,
            }

        return {
            'bef_path': os.path.join(self.image_path, '%s.png' % before_id),
            'aft_path': os.path.join(self.image_path, '%s.png' % after_id),
            'bef_img': before_image,
            'aft_img': after_image,
            'img_id': "%s.png" % before_id,
            'bef_mask': before_mask,
            'aft_mask': after_mask,
        }

    def get_image(self, img_id):
        """Open ``<image_path>/<img_id>.png`` as RGB and apply ``self.transform`` if set."""
        source = os.path.join(self.image_path, '%s.png' % img_id)
        with open(source, 'rb') as handle:
            picture = Image.open(handle).convert('RGB')
        if self.transform is None:
            return picture
        return self.transform(picture)

    def get_mask(self, img_id):
        """Open ``<mask_path>/<img_id>_mask.png`` as single-channel and run ``mask_transform``.

        There is no fallback here: a missing mask file raises.
        """
        source = os.path.join(self.mask_path, '%s_mask.png' % img_id)
        with open(source, 'rb') as handle:
            mask = Image.open(handle).convert('L')
        return self.mask_transform(mask)
    

class LEVIRMaskDataset(Dataset):
    """Before/after image pairs with masks and captions from an ``A``/``B`` folder layout.

    Images are read from ``<data_path>/images/<split>/A`` (before) and
    ``<data_path>/images/<split>/B`` (after) under the same file name. Masks
    are read from ``<data_path>/images_with_mask`` as ``<stem>_A_mask.png`` and
    ``<stem>_B_mask.png``. Captions come from a JSON file that is indexed by
    image file name and yields a list of sentences per image.
    """

    def __init__(self, data_path, transform=None, split='train', prompt="", caption_file="gt.json") -> None:
        """Load the captions and list the images that have one.

        Args:
            data_path: Root directory of the dataset (see the class docstring).
            transform: Callable applied to every RGB image, or ``None``.
            split: Sub-directory of ``images``; ``'train'`` also selects the
                training layout of the returned sample.
            prompt: Forwarded to ``BlipCaptionProcessor`` (``max_words`` is fixed at 40).
            caption_file: Path of the caption JSON. The default, ``"gt.json"``,
                is relative and therefore resolved against the current working
                directory (not ``data_path``); pass an explicit path to avoid
                depending on where the process was started.
        """
        super().__init__()
        self.data_path = data_path
        self.split = split
        self.transform = transform
        self.text_preprocess = BlipCaptionProcessor(prompt=prompt, max_words=40)
        self.caption_file = caption_file
        with open(caption_file, 'r') as f:
            self.captions = json.load(f)
        self.bef_image_path = os.path.join(data_path, "images", split, 'A')
        self.aft_image_path = os.path.join(data_path, "images", split, 'B')
        self.mask_path = os.path.join(data_path, "images_with_mask")

        # The "before" folder defines the candidate set; a missing folder
        # yields an empty dataset rather than an error.
        names = []
        if os.path.exists(self.bef_image_path):
            for ext in ('*.png', '*.jpg', '*.jpeg'):
                pattern = os.path.join(self.bef_image_path, ext)
                names.extend(os.path.basename(found) for found in glob.glob(pattern))
        names.sort()
        # Images without a caption entry are dropped.
        self.image_names = [name for name in names if name in self.captions]

        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of captioned images found in the "before" folder."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            For ``split == 'train'``: ``bef_img``, ``aft_img``, ``caption``
            (one randomly chosen, processed sentence), ``bef_mask``, ``aft_mask``.
            Otherwise: ``bef_path``, ``aft_path``, ``bef_img``, ``aft_img``,
            ``img_id`` (the image file name), ``bef_mask``, ``aft_mask``.
        """
        img_name = self.image_names[index]
        # The sentence is drawn before any image is loaded so the order of
        # calls into the `random` module is unchanged.
        text = self.text_preprocess(random.choice(self.captions[img_name]))
        img_id = os.path.splitext(img_name)[0]
        bef_img = self.get_image(img_name, 'before')
        aft_img = self.get_image(img_name, 'after')
        bef_mask = self.get_mask(img_id + "_A")
        aft_mask = self.get_mask(img_id + "_B")

        if self.split == 'train':
            return {
                'bef_img': bef_img,
                'aft_img': aft_img,
                'caption': text,
                'bef_mask': bef_mask,
                'aft_mask': aft_mask,
            }
        return {
            'bef_path': os.path.join(self.bef_image_path, img_name),
            'aft_path': os.path.join(self.aft_image_path, img_name),
            'bef_img': bef_img,
            'aft_img': aft_img,
            'img_id': img_name,
            'bef_mask': bef_mask,
            'aft_mask': aft_mask,
        }

    def get_image(self, img_name, img_type):
        """Open ``img_name`` as RGB from the "before" folder when ``img_type == 'before'``,
        otherwise from the "after" folder, and apply ``self.transform`` if set."""
        folder = self.bef_image_path if img_type == 'before' else self.aft_image_path
        with open(os.path.join(folder, img_name), 'rb') as f:
            # Convert while the file is still open so the pixel data is read
            # before the handle is closed.
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_mask(self, img_id):
        """Load ``<mask_path>/<img_id>_mask.png`` as single-channel, or an all-zero
        256x256 mask if the file is absent, then run ``mask_transform``.

        Only a missing path triggers the fallback; any other I/O or decoding
        error propagates to the caller.
        """
        img_path = os.path.join(self.mask_path, '%s_mask.png' % img_id)
        try:
            f = open(img_path, 'rb')
        except (FileNotFoundError, NotADirectoryError):
            # Best effort: a sample without a mask file gets an empty mask.
            img = Image.new('L', (256, 256), color=0)
        else:
            with f:
                img = Image.open(f).convert('L')
        return self.mask_transform(img)
    

class IERMaskDataset(Dataset):
    """Image pairs with masks and captions described by ``<data_path>/<split>.json``.

    Each annotation entry is read for ``img0`` and ``img1`` (file names under
    ``<data_path>/images``), ``sents`` (a list of sentences) and ``uid``.
    Masks are read from ``<data_path>/images_with_mask`` as ``<stem>_mask.png``
    and resized to 224x224.
    """

    def __init__(self, data_path, transform, split="split", max_words=40, prompt="") -> None:
        """Load the annotation file for ``split``.

        Args:
            data_path: Root directory of the dataset (see the class docstring).
            transform: Callable applied to every RGB image, or ``None``.
            split: Base name of the annotation JSON; ``'train'`` also selects
                the training layout of the returned sample.
            max_words: Forwarded to ``BlipCaptionProcessor``.
            prompt: Forwarded to ``BlipCaptionProcessor``.
        """
        # NOTE(review): the default split="split" makes the constructor look
        # for "split.json"; this may be a typo for "train" - confirm with
        # callers before changing it.
        super().__init__()
        self.data_path = data_path
        self.image_path = os.path.join(data_path, "images")
        self.mask_path = os.path.join(data_path, "images_with_mask")
        self.split = split
        self.transform = transform
        self.max_words = max_words
        self.prompt = prompt
        self.text_process = BlipCaptionProcessor(prompt=prompt, max_words=max_words)
        self.word_list = None  # never populated by this class

        with open(os.path.join(self.data_path, "%s.json" % self.split), "r") as f:
            self.captions = json.load(f)

        # Unlike the images, masks get a fixed size here regardless of `transform`.
        self.mask_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self,):
        """Return the number of annotation entries."""
        return len(self.captions)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            Always ``bef_img``, ``aft_img``, ``bef_mask``, ``aft_mask``. For
            ``split == 'train'`` also ``caption`` (one randomly chosen,
            processed sentence); otherwise ``img_id`` (the entry's ``uid``),
            ``bef_path`` and ``aft_path``.
        """
        caption = self.captions[index]
        # The sentence is drawn before any image is loaded so the order of
        # calls into the `random` module is unchanged.
        text = self.text_process(random.choice(caption['sents']))
        bef_img = self.get_image(caption['img0'])
        aft_img = self.get_image(caption['img1'])

        # Mask files are keyed by the image file name without its extension.
        bef_mask = self.get_mask(os.path.splitext(caption['img0'])[0])
        aft_mask = self.get_mask(os.path.splitext(caption['img1'])[0])

        out = {
            'bef_img': bef_img,
            'aft_img': aft_img,
            'bef_mask': bef_mask,
            'aft_mask': aft_mask,
        }
        if self.split == 'train':
            out['caption'] = text
        else:
            out['img_id'] = caption['uid']
            out['bef_path'] = os.path.join(self.image_path, caption['img0'])
            out['aft_path'] = os.path.join(self.image_path, caption['img1'])
        return out

    def get_image(self, image_name):
        """Open ``<image_path>/<image_name>`` as RGB and apply ``self.transform`` if set."""
        img_path = os.path.join(self.image_path, image_name)
        with open(img_path, 'rb') as f:
            # Convert while the file is still open so the pixel data is read
            # before the handle is closed.
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_mask(self, img_id):
        """Load ``<mask_path>/<img_id>_mask.png`` as single-channel, or an all-zero
        224x224 mask if the file is absent, then run ``mask_transform``.

        Only a missing path triggers the fallback; any other I/O or decoding
        error propagates to the caller.
        """
        img_path = os.path.join(self.mask_path, '%s_mask.png' % img_id)
        try:
            f = open(img_path, 'rb')
        except (FileNotFoundError, NotADirectoryError):
            # Best effort: a sample without a mask file gets an empty mask.
            img = Image.new('L', (224, 224), color=0)
        else:
            with f:
                img = Image.open(f).convert('L')
        return self.mask_transform(img)
    
