

import h5py
import numpy as np
from PIL import Image
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from slot_attention.const import MOD_FILE_PATHS

from slot_attention.model.model_utils import clevr_rescale
from torchvision import transforms
from torch.utils.data import Dataset


class MOClevrDataset(Dataset):
    """Multi-object CLEVR images and binarised object masks read from an HDF5 file.

    The whole split is loaded into memory at construction time and the HDF5
    file is closed before the dataset is used. No open h5py handle is stored
    on the instance.

    Args:
        data_root: Path to an HDF5 file; the path must contain ``'clevr'``.
            The file is expected to hold a ``'train'`` group or, failing that,
            a ``'test'`` group, each with ``'imgs'`` and ``'masks'`` datasets.
    """

    def __init__(self, data_root):
        super().__init__()
        assert 'clevr' in data_root, 'data_root must be a clevr dataset'
        # Read everything eagerly inside a context manager so the file handle
        # is released as soon as loading finishes.
        with h5py.File(data_root, 'r') as file:
            print(file.keys())
            data = file['train'] if 'train' in file else file['test']
            print(data.keys())
            # Slicing an h5py Dataset with [:] already yields a fresh ndarray;
            # np.asarray avoids a second full copy of the data.
            self.imgs = np.asarray(data['imgs'][:])
            # NOTE(review): squeeze() drops *every* size-1 axis, including the
            # sample axis if the file holds a single sample — confirm that the
            # stored mask layout always has a trailing singleton channel only.
            self.masks = np.asarray(data['masks'][:]).squeeze()
        # Binarise: any positive mask value becomes 1.
        self.masks[self.masks > 0] = 1
        print(f'imgs: {self.imgs.shape}')
        print(f'masks: {self.masks.shape}')

        self.to_tensor = transforms.ToTensor()
        self.resize = transforms.Resize([128, 128], antialias=True)
        # Nearest-neighbour keeps mask values binary after resizing. Use the
        # InterpolationMode enum; torchvision deprecated passing PIL integer
        # interpolation codes such as Image.NEAREST.
        self.resize_mask = transforms.Resize(
            [128, 128], interpolation=transforms.InterpolationMode.NEAREST
        )
        self.rescale = transforms.Lambda(clevr_rescale)
        # Not used by __getitem__; kept so existing attribute access still works.
        self.center_crop = transforms.CenterCrop(192)

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(img, mask)`` tensor pair for sample *index*.

        The image goes through ToTensor, then ``clevr_rescale``, then a resize
        to 128x128. The mask is converted to a float tensor and resized to
        128x128 with nearest-neighbour interpolation.
        """
        img = self.imgs[index]
        mask = self.masks[index]

        # ToTensor turns an HWC ndarray into a CHW float tensor.
        img = self.to_tensor(img)
        mask = torch.from_numpy(mask).float()

        # clevr_rescale presumably maps pixel values to [-1, 1] — TODO confirm
        # against its definition in slot_attention.model.model_utils.
        img = self.rescale(img)

        img = self.resize(img)
        mask = self.resize_mask(mask)

        return img, mask


class MODClevrModule(pl.LightningDataModule):
    """Lightning data module that randomly splits one CLEVR HDF5 file into train/val.

    Args:
        data_root: Path to the HDF5 file, forwarded to ``MOClevrDataset``.
        train_batch_size: Batch size of the training DataLoader.
        val_batch_size: Batch size of the validation DataLoader.
        num_workers: Number of worker processes for both DataLoaders.
        train_val_perc: Fraction of samples assigned to the training split.
            Must lie in (0, 1]; defaults to 0.9.
        seed: Optional integer seed for the train/val split. ``None`` (the
            default) draws from torch's global RNG. An int gives the same
            split on every run.

    Raises:
        ValueError: If ``train_val_perc`` is not in (0, 1].
    """

    def __init__(
        self,
        data_root: str,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
        # Previously written as ``train_val_perc: 0.9`` — an annotation, not a
        # default — which made this argument silently mandatory.
        train_val_perc: float = 0.9,
        seed=None,
    ):
        super().__init__()
        # Validate before the (potentially expensive) dataset load. A value
        # above 1 would otherwise produce a negative validation length.
        if not 0.0 < train_val_perc <= 1.0:
            raise ValueError(
                f'train_val_perc must be in (0, 1], got {train_val_perc!r}'
            )
        self.data_root = data_root
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        full_dataset = MOClevrDataset(
            data_root=self.data_root,
        )

        train_size = int(train_val_perc * len(full_dataset))
        val_size = len(full_dataset) - train_size
        # torch.default_generator is what random_split uses when no generator
        # is given, so seed=None preserves the unseeded behaviour.
        generator = (
            torch.default_generator
            if seed is None
            else torch.Generator().manual_seed(seed)
        )
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size], generator=generator
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation split."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        
if __name__ == '__main__':
    # Smoke test: load the CLEVR6 training split and report the tensor shapes
    # of the first few samples.
    clevr6_train_path = MOD_FILE_PATHS['clevr6']['train']
    print(clevr6_train_path)

    dataset = MOClevrDataset(clevr6_train_path)
    num_samples = len(dataset)
    print(num_samples)

    # Inspect the first seven samples (fewer if the dataset is shorter).
    for idx in range(min(num_samples, 7)):
        sample_img, sample_mask = dataset[idx]
        print(f'img: {sample_img.size()}')
        print(f'mask: {sample_mask.size()}')