import os
from datasets import load_from_disk
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    """Map-style dataset built from parallel, index-aligned columns.

    Indexing yields a 5-tuple
    ``(latents, img, text, gt_patch, watermarking_mask)``:

    * ``latents`` -- ``latents[idx]`` converted to a tensor, with all
      size-1 dimensions squeezed out.
    * ``img`` -- ``imgs[idx]`` passed through ``ToTensor`` and then resized
      to 512x512.
    * ``text`` -- ``text[idx]`` returned unchanged.
    * ``gt_patch`` -- complex tensor assembled from the real and imaginary
      columns at ``idx``.
    * ``watermarking_mask`` -- ``watermarking_mask[idx]`` converted to a
      tensor, with all size-1 dimensions squeezed out.

    Args:
        latents: Indexable collection; its length defines the dataset size.
        imgs: Indexable collection of images accepted by ``ToTensor``.
        text: Iterable of per-sample text; copied into a list.
        gt_patch_real: Indexable collection holding the real parts.
        gt_patch_imag: Indexable collection holding the imaginary parts.
        watermarking_mask: Indexable collection of masks.
    """

    def __init__(self, latents, imgs, text,
                 gt_patch_real, gt_patch_imag, watermarking_mask):
        self.latents = latents
        self.imgs = imgs
        self.text = list(text)

        self.gt_patch_real = gt_patch_real
        self.gt_patch_imag = gt_patch_imag
        self.watermarking_mask = watermarking_mask

        # Bring every image to a common 512x512 tensor so samples can be
        # batched together.
        to_tensor = transforms.ToTensor()
        resize = transforms.Resize((512, 512))
        self.transform = transforms.Compose([to_tensor, resize])

    def __len__(self):
        """Return the number of samples (the length of the latents column)."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return ``(latents, img, text, gt_patch, watermarking_mask)`` for ``idx``."""
        latent = torch.tensor(self.latents[idx]).squeeze()
        image = self.transform(self.imgs[idx])
        prompt = self.text[idx]

        # The patch is stored as two separate columns; recombine them into a
        # single complex-valued tensor.
        real_part = torch.tensor(self.gt_patch_real[idx])
        imag_part = torch.tensor(self.gt_patch_imag[idx])
        patch = torch.complex(real_part, imag_part)

        mask = torch.squeeze(torch.tensor(self.watermarking_mask[idx]))

        return latent, image, prompt, patch, mask

def get_dataset_base(dataset_path, dataset_name, is_train):
    """Load one split of a dataset previously saved to disk.

    Args:
        dataset_path: Root directory that contains the dataset folders.
        dataset_name: Name of the dataset folder under ``dataset_path``.
        is_train: Select the ``'train'`` split when truthy, otherwise the
            ``'test'`` split.

    Returns:
        The selected split of the object returned by ``load_from_disk``.
    """
    split_name = 'train' if is_train else 'test'
    ds = load_from_disk(os.path.join(dataset_path, dataset_name))[split_name]
    print(f'Loaded {dataset_name} of size {len(ds)} images')
    return ds

def get_dataset(args, is_train):
    """Build the dataset object selected by ``args.dataset``.

    The split is loaded through ``get_dataset_base`` so that path handling,
    split selection and the "Loaded ..." message live in a single place.

    Args:
        args: Object exposing ``dataset`` (the dataset name) and
            ``dataset_path`` (the root directory holding the datasets).
        is_train: Use the ``'train'`` split when truthy, otherwise ``'test'``.

    Returns:
        A ``CustomDataset`` for ``"diffusiondb"``; ``None`` when
        ``args.dataset`` is not one of the names handled below.

    Raises:
        NotImplementedError: If ``args.dataset`` is ``"coco"`` or
            ``"wikiart"``. As before, this is raised only after the split
            has been loaded from disk.
    """
    ds = get_dataset_base(args.dataset_path, args.dataset, is_train)

    if args.dataset == "diffusiondb":
        # NOTE(review): every column is pulled out of the split up front;
        # presumably this materializes whole columns in memory -- confirm
        # against the installed `datasets` version.
        return CustomDataset(
            ds['latents'], ds['image'], ds['prompt'],
            ds['gt_patch_real'], ds['gt_patch_imag'], ds['watermarking_mask'],
        )

    if args.dataset in ("coco", "wikiart"):
        raise NotImplementedError(
            f"dataset '{args.dataset}' is not supported yet"
        )

    # Unknown names fall through without an error, matching prior behavior.
    return None

#-----------------------------------------------------------------------------------------#
def test_get_dataset():
    """Manually exercise ``get_dataset`` on both splits.

    Parses ``--dataset`` and ``--dataset_path`` from the command line, then
    builds the train split followed by the test split. The resulting dataset
    objects are discarded; the function only checks that loading completes.
    """
    import argparse

    cli = argparse.ArgumentParser(description='diffusion watermark')
    cli.add_argument('--dataset', default='diffusiondb',
                     choices=['coco', 'diffusiondb', 'wikiart'])
    cli.add_argument('--dataset_path',
                     default='/localhome/data/datasets/watermarking')
    args = cli.parse_args()

    # Train first, then test -- same order as the two explicit calls.
    for is_train in (True, False):
        get_dataset(args, is_train=is_train)

# Script entry point: when this file is executed directly (not imported),
# run the manual loading check defined above.
if __name__ == '__main__':
    test_get_dataset()
