from typing import Optional

import einops
import h5py
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from slot_attention.const import MOD_FILE_PATHS

# Data source: https://zenodo.org/record/4895643
# The HDF5 release is used instead of TFRecords, which are hard to read in
# parallel; see https://gitlab.com/generally-intelligent/simone/-/blob/main/data.py


class MOSpritesDataset(Dataset):
    """Multi-object dSprites images and masks, held entirely in memory.

    The HDF5 file at ``data_root`` is read once in ``__init__``. The
    ``'train'`` group is used when present; otherwise the ``'test'`` group is
    used. Images are rescaled from ``[0, 255]`` to ``[-1, 1]`` (this assumes
    8-bit pixel values). Masks are binarized: every value > 0 becomes 1.

    Each item is an ``(img, mask)`` pair of float32 tensors. ``img`` is
    converted from channels-last ``(H, W, C)`` to channels-first ``(C, H, W)``.
    """

    def __init__(self, data_root):
        super().__init__()
        # Cheap sanity check on the path name. It is kept as an assert so the
        # exception type callers may see is unchanged.
        assert 'sprites' in data_root, 'data_root must contain sprites'

        # Both arrays are fully read into memory below, so the HDF5 handle is
        # closed as soon as loading finishes instead of being leaked.
        with h5py.File(data_root, 'r') as h5_file:
            print(h5_file.keys())
            data = h5_file['train'] if 'train' in h5_file.keys() else h5_file['test']
            print(data.keys())
            # ``[:]`` already returns a fresh numpy array. Wrapping it in
            # ``np.array`` would only make a second full copy.
            imgs = data['imgs'][:]
            masks = data['masks'][:]

        if len(imgs) != len(masks):
            raise ValueError(
                f'imgs and masks have different lengths: {len(imgs)} vs {len(masks)}'
            )

        self.masks = self._binarize_masks(masks)
        print(f'imgs: {imgs.shape}')
        print(f'masks: {self.masks.shape}')
        self.imgs = self._normalize_imgs(imgs)

    @staticmethod
    def _normalize_imgs(imgs):
        """Map pixel values from ``[0, 255]`` to ``[-1, 1]`` (returns float64)."""
        return (imgs / 255.0) * 2.0 - 1.0

    @staticmethod
    def _binarize_masks(masks):
        """Drop size-1 axes other than axis 0, then set every value > 0 to 1.

        Axis 0 (the sample axis) is never squeezed, so a file holding a single
        sample still yields a dataset of length 1. The result may share memory
        with ``masks``, which is written to in place.
        """
        singleton_axes = tuple(
            axis for axis in range(1, masks.ndim) if masks.shape[axis] == 1
        )
        masks = np.squeeze(masks, axis=singleton_axes)
        masks[masks > 0] = 1
        return masks

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img = torch.from_numpy(self.imgs[index]).float()
        mask = torch.from_numpy(self.masks[index]).float()
        # Channels-last (H, W, C) -> channels-first (C, H, W).
        img = img.permute(2, 0, 1)
        return img, mask

class MOSpritesDataModule(pl.LightningDataModule):
    """Lightning data module serving a random train/val split of one HDF5 file.

    The full ``MOSpritesDataset`` is loaded eagerly in ``__init__`` and split
    with ``torch.utils.data.random_split``.

    Args:
        data_root: Path to the HDF5 file, forwarded to ``MOSpritesDataset``.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Number of worker processes used by both loaders.
        train_val_perc: Fraction of samples assigned to training, in
            ``(0, 1]``. The remainder becomes the validation set.
        seed: Optional seed for the split. ``None`` (the default) draws from
            torch's global RNG, as before. Pass an int to get the same split
            on every run, e.g. when resuming training.

    Raises:
        ValueError: If ``train_val_perc`` is outside ``(0, 1]``.
    """

    def __init__(
        self,
        data_root: str,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
        # Previously written as ``train_val_perc: 0.95``, which is only an
        # annotation. That silently made the argument required.
        train_val_perc: float = 0.95,
        seed: Optional[int] = None,
    ):
        super().__init__()
        if not 0.0 < train_val_perc <= 1.0:
            raise ValueError(
                f'train_val_perc must be in (0, 1], got {train_val_perc}'
            )

        self.data_root = data_root
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        full_dataset = MOSpritesDataset(
            data_root=self.data_root,
        )

        train_size = int(train_val_perc * len(full_dataset))
        val_size = len(full_dataset) - train_size
        # A dedicated seeded generator makes the split reproducible without
        # touching the global RNG state.
        generator = (
            torch.Generator().manual_seed(seed)
            if seed is not None
            else torch.default_generator
        )
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size], generator=generator
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Loader over the validation split, in fixed order."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )



if __name__ == "__main__":
    # Smoke test: build the data module, then print the shapes of a few samples.
    path = MOD_FILE_PATHS['multi_dsprites']['train']
    print(path)
    dm = MOSpritesDataModule(path, 64, 64, 4, train_val_perc=0.9)

    # Reuse the full dataset the data module already loaded, instead of
    # reading the whole HDF5 file into memory a second time.
    ds = dm.train_dataset.dataset
    print(len(ds))
    # Same 7 samples (indices 0..6) the previous enumerate/break loop printed.
    for i in range(min(7, len(ds))):
        img, mask = ds[i]
        print(f'img: {img.size()}')
        print(f'mask: {mask.size()}')
    
