from re import M
import einops
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import h5py

from slot_attention.const import MOD_FILE_PATHS

# Data source: https://zenodo.org/record/4895643
# Read from HDF5 rather than TFRecords, which are not easily parallelizable:
# https://gitlab.com/generally-intelligent/simone/-/blob/main/data.py


class MOTetroDataset(Dataset):
    """Multi-object Tetrominoes images and masks read from an HDF5 file.

    One split (a top-level HDF5 group) is loaded fully into memory when the
    dataset is constructed. Images are rescaled with ``x / 255 * 2 - 1``
    (i.e. into [-1, 1], assuming 0-255 pixel values) and masks are binarized.
    """

    def __init__(self, data_root, mode='train'):
        """Load the ``imgs`` and ``masks`` arrays of one split.

        Args:
            data_root: Path (str or path-like) to the HDF5 file; its string
                form must contain ``'tetro'``.
            mode: Name of the top-level HDF5 group to read, e.g. ``'train'``
                or ``'test'``.

        Raises:
            AssertionError: If ``data_root`` does not contain ``'tetro'`` or
                ``mode`` is not a group in the file.
        """
        super().__init__()
        # str() so pathlib.Path arguments work as well as plain strings.
        assert 'tetro' in str(data_root), 'data_root must contain tetro'
        # Both arrays are copied into memory below, so the file can (and
        # should) be closed as soon as reading is done.
        with h5py.File(data_root, 'r') as file:
            print(file.keys())
            assert mode in file.keys(), f'{mode} not in {file.keys()}, must be "train" or "test"'
            data = file[mode]

            print(data.keys())
            # h5py's Dataset[:] already returns a fresh in-memory ndarray;
            # wrapping it in np.array() would make a second full copy.
            self.imgs = data['imgs'][:]
            # NOTE(review): squeeze() drops *every* size-1 axis (including the
            # sample axis if the file holds a single sample) - confirm the
            # intended axis.
            self.masks = data['masks'][:].squeeze()

        # Binarize: any positive mask value becomes 1.
        self.masks[self.masks > 0] = 1
        print(f'imgs: {self.imgs.shape}')
        print(f'masks: {self.masks.shape}')

        # normalize into [-1,1]
        self.imgs = (self.imgs / 255.0) * 2.0 - 1.0

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(img, mask)`` pair for sample ``index`` as float32 tensors.

        ``img`` is converted from channels-last ``(H, W, C)`` to channels-first
        ``(C, H, W)``; ``mask`` keeps the layout stored in ``self.masks``.
        """
        img = torch.from_numpy(self.imgs[index]).float()
        mask = torch.from_numpy(self.masks[index]).float()
        # channels-last -> channels-first (same result as einops 'h w c -> c h w')
        img = img.permute(2, 0, 1)
        return img, mask

class MOTetroDataModule(pl.LightningDataModule):
    """Lightning data module for the multi-object Tetrominoes HDF5 files.

    File paths come from ``MOD_FILE_PATHS['tetrominoes']``. When
    ``params.is_test`` is truthy, the whole training file is used for training
    and the separate test file for validation; otherwise the training file is
    randomly split into train/validation subsets.
    """

    def __init__(
        self,
        params,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
        train_val_perc: float = 0.95,
        seed=None,
    ):
        """Build the train/validation datasets.

        Args:
            params: Config object; only its ``is_test`` attribute is read.
            train_batch_size: Batch size for the training loader.
            val_batch_size: Batch size for the validation loader.
            num_workers: Worker processes for both loaders.
            train_val_perc: Fraction of the training file used for training
                when ``params.is_test`` is falsy. Previously this was written
                as an annotation (``train_val_perc: 0.95``) rather than a
                default, so the argument was silently required.
            seed: Optional int seed for the train/val split. ``None`` (the
                default) keeps the previous behavior of using torch's global
                RNG; pass an int to get the same split in every run/process.

        Raises:
            ValueError: If the split is used and ``train_val_perc`` is not in
                ``[0, 1]``.
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        # NOTE(review): data is loaded eagerly in __init__ rather than in
        # Lightning's setup() hook - confirm this is intended.
        full_dataset = MOTetroDataset(
            data_root=MOD_FILE_PATHS['tetrominoes']['train'],
        )

        if params.is_test:
            self.train_dataset = full_dataset
            self.val_dataset = MOTetroDataset(
                data_root=MOD_FILE_PATHS['tetrominoes']['test'],
                mode="test",
            )
            return

        if not 0.0 <= train_val_perc <= 1.0:
            raise ValueError(
                f'train_val_perc must be in [0, 1], got {train_val_perc}'
            )
        train_size = int(train_val_perc * len(full_dataset))
        val_size = len(full_dataset) - train_size
        # Only pass a generator when a seed is given, so the default path is
        # exactly the previous (global-RNG) behavior.
        split_kwargs = (
            {} if seed is None
            else {'generator': torch.Generator().manual_seed(seed)}
        )
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size], **split_kwargs
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation dataset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )



if __name__ == "__main__":
    # Local import: only this demo needs a lightweight stand-in for the
    # params object (MOTetroDataModule reads just `params.is_test`).
    from types import SimpleNamespace

    path = MOD_FILE_PATHS['tetrominoes']['train']
    print(path)
    # The first argument is the params object, not a path: passing `path`
    # here raised AttributeError on `params.is_test`.
    dm = MOTetroDataModule(
        SimpleNamespace(is_test=False), 64, 64, 4, train_val_perc=0.9
    )
    print(f'train: {len(dm.train_dataset)}, val: {len(dm.val_dataset)}')

    ds = MOTetroDataset(path)
    print(len(ds))
    # Inspect the first 7 samples (same count as the original `i > 5` break),
    # indexing explicitly instead of relying on IndexError-terminated iteration.
    for i in range(min(7, len(ds))):
        img, mask = ds[i]
        print(f'img: {img.size()}')
        print(f'mask: {mask.size()}')
    
