"""
adapted from FreeStyleRet: https://github.com/CuriseJia/FreeStyleRet
"""
import os
import json
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class T2ITestDataset(Dataset):
    """Text-to-image retrieval test set yielding ``[caption, image_tensor, index]``.

    Each JSON entry must provide ``'caption'`` (a file name under
    ``<root_path>/text/``) and ``'image'`` (a file name under
    ``<root_path>/images/``).

    Args:
        root_path: Dataset root directory.
        json_path: Path to a JSON file containing a list of entries.
        image_transform: Image processor exposing
            ``preprocess(image, return_tensors='pt')`` and returning a mapping
            with a ``'pixel_values'`` tensor (presumably a Hugging Face image
            processor -- inferred from the call signature, confirm with caller).
        image_size: Side length, in pixels, that images are resized to before
            preprocessing.
    """

    def __init__(self, root_path, json_path, image_transform, image_size=224):
        self.root_path = root_path
        # Context manager so the JSON file handle is closed immediately.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)
        self.image_transform = image_transform
        self.image_size = image_size

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _read_caption(caption_path):
        """Return the first line of *caption_path* with newline characters removed."""
        with open(caption_path, 'r', encoding='utf-8') as f:
            return f.readline().replace('\n', '')

    def _load_image(self, image_path):
        """Load *image_path* as RGB, resize it to a square, and preprocess it to a tensor."""
        with Image.open(image_path) as img:
            # convert()/resize() return new images, so they remain valid after the file closes.
            image = img.convert("RGB").resize((self.image_size, self.image_size))
        pixel_values = self.image_transform.preprocess(image, return_tensors='pt')['pixel_values']
        # squeeze() drops size-1 dims; presumably the leading batch-of-one dim -- confirm.
        return pixel_values.squeeze()

    def __getitem__(self, index):
        entry = self.dataset[index]
        caption_path = os.path.join(self.root_path, 'text/' + entry['caption'])
        image_path = os.path.join(self.root_path, 'images/' + entry['image'])

        caption = self._read_caption(caption_path)
        pair_image = self._load_image(image_path)

        return [caption, pair_image, index]
    
class I2ITestDataset(Dataset):
    """Image-to-image retrieval test set yielding ``[ori_image, pair_image, index]``.

    Each JSON entry's ``'image'`` file name is looked up under
    ``<root_path>/images/`` (original image) and under ``<root_path>/<style>/``
    (the styled counterpart).

    Args:
        style: Name of the style sub-directory under ``root_path``.
        root_path: Dataset root directory.
        json_path: Path to a JSON file containing a list of entries.
        image_transform: Image processor exposing
            ``preprocess(image, return_tensors='pt')`` and returning a mapping
            with a ``'pixel_values'`` tensor (presumably a Hugging Face image
            processor -- confirm with caller).
        image_size: Side length, in pixels, that images are resized to before
            preprocessing.
    """

    def __init__(self, style, root_path, json_path, image_transform, image_size=224):
        self.style = style
        self.root_path = root_path
        # Context manager so the JSON file handle is closed immediately.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)
        self.image_transform = image_transform
        self.image_size = image_size

    def __len__(self):
        return len(self.dataset)

    def _load_image(self, image_path):
        """Load *image_path* as RGB, resize it to a square, and preprocess it to a tensor."""
        with Image.open(image_path) as img:
            # convert()/resize() return new images, so they remain valid after the file closes.
            image = img.convert("RGB").resize((self.image_size, self.image_size))
        pixel_values = self.image_transform.preprocess(image, return_tensors='pt')['pixel_values']
        # squeeze() drops size-1 dims; presumably the leading batch-of-one dim -- confirm.
        return pixel_values.squeeze()

    def __getitem__(self, index):
        image_name = self.dataset[index]['image']
        ori_path = os.path.join(self.root_path, 'images/' + image_name)
        pair_path = os.path.join(self.root_path, '{}/'.format(self.style) + image_name)

        ori_image = self._load_image(ori_path)
        pair_image = self._load_image(pair_path)

        return [ori_image, pair_image, index]
    
class StyleImageTextDataset(Dataset):
    """Training set pairing each image with styled variants, a caption and a negative.

    For every JSON entry, the ``'image'`` file name is looked up under
    ``<root_path>/images/`` (original) and under the ``sketch/``, ``art/`` and
    ``mosaic/`` sub-directories (style 1/2/3). ``'caption'`` names a text file
    under ``<root_path>/text/``. A different, uniformly sampled entry supplies
    the negative image in all four domains.

    Args:
        root_path: Dataset root directory.
        json_path: Path to a JSON file containing a list of entries.
        image_transform: Image processor exposing
            ``preprocess(image, return_tensors='pt')`` and returning a mapping
            with a ``'pixel_values'`` tensor. Despite the ``None`` default, it
            is required by ``__getitem__``.
        image_size: Side length, in pixels, that images are resized to before
            preprocessing.
    """

    # Style sub-directories, in style1 / style2 / style3 order.
    _STYLE_DIRS = ('sketch/', 'art/', 'mosaic/')

    def __init__(self, root_path, json_path, image_transform = None, image_size = 224):
        self.root_path = root_path
        # Context manager so the JSON file handle is closed immediately.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)
        self.image_transform = image_transform
        self.image_size = image_size

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _read_caption(caption_path):
        """Return the first line of *caption_path* with newline characters removed."""
        with open(caption_path, 'r', encoding='utf-8') as f:
            return f.readline().replace('\n', '')

    def _sample_negative(self, index):
        """Return a uniformly random dataset index different from *index*.

        Raises:
            ValueError: if the dataset has fewer than two entries, so no
                negative exists.
        """
        n = len(self.dataset)
        if n < 2:
            raise ValueError(
                'StyleImageTextDataset needs at least 2 entries to sample a negative, got %d' % n)
        if index < 0:
            index += n
        # Draw from the n-1 other indices directly. This avoids a rejection loop
        # (which could spin forever) and keeps index 0 eligible as a negative.
        neg = np.random.randint(0, n - 1)
        if neg >= index:
            neg += 1
        return neg

    def _load_image(self, image_path):
        """Load *image_path* as RGB, resize it to a square, and preprocess it to a tensor."""
        with Image.open(image_path) as img:
            # convert()/resize() return new images, so they remain valid after the file closes.
            image = img.convert("RGB").resize((self.image_size, self.image_size))
        pixel_values = self.image_transform.preprocess(image, return_tensors='pt')['pixel_values']
        # squeeze() drops size-1 dims; presumably the leading batch-of-one dim -- confirm.
        return pixel_values.squeeze()

    def _load_all_domains(self, image_name):
        """Return the original image followed by its three styled variants, as tensors."""
        paths = [os.path.join(self.root_path, 'images/' + image_name)]
        paths += [os.path.join(self.root_path, style_dir + image_name)
                  for style_dir in self._STYLE_DIRS]
        return [self._load_image(path) for path in paths]

    def __getitem__(self, index):
        if self.image_transform is None:
            raise ValueError('StyleImageTextDataset requires an image_transform to load images')

        entry = self.dataset[index]

        ### caption
        caption_path = os.path.join(self.root_path, 'text/' + entry['caption'])
        caption = self._read_caption(caption_path)

        ### neg sample (a different entry, same four domains)
        neg_entry = self.dataset[self._sample_negative(index)]

        ### get images
        ori_image, style1_pair_image, style2_pair_image, style3_pair_image = \
            self._load_all_domains(entry['image'])
        negative_image, style1_neg_image, style2_neg_image, style3_neg_image = \
            self._load_all_domains(neg_entry['image'])

        return {
            "ori_image": ori_image,
            "style1_pair_image": style1_pair_image,
            "style2_pair_image": style2_pair_image,
            "style3_pair_image": style3_pair_image,
            "neg_image": negative_image,
            "caption": caption,
            "style1_neg_image": style1_neg_image,
            "style2_neg_image": style2_neg_image,
            "style3_neg_image": style3_neg_image,
        }
