from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
import einops
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import h5py
from slot_attention.data.mod_clevr6 import MOClevrDataset

from slot_attention.const import MOD_FILE_PATHS

# Data source: https://zenodo.org/record/4895643
# The data is read from HDF5 files here because TFRecords are hard to read in
# parallel; see https://gitlab.com/generally-intelligent/simone/-/blob/main/data.py


class MODataset(Dataset):
    """Map-style dataset of ``(img, mask)`` float tensors read from one HDF5 file.

    NOTE: construction currently always raises ``NotImplementedError`` because
    this loader is unfinished; use e.g. ``MOClevrDataset`` instead. The loading
    code after the guard is kept so the loader can be completed later.
    """

    def __init__(self, data_root: str, transforms: Optional[Callable] = None):
        """Load all images and masks from the HDF5 file at ``data_root``.

        Args:
            data_root: Path to an HDF5 file with a ``train`` group (or, if that
                is absent, a ``test`` group) containing ``imgs`` and ``masks``.
            transforms: Optional callable applied to each channels-first image
                tensor in ``__getitem__``. ``None`` means no transform.

        Raises:
            NotImplementedError: Always, until this loader is finished.
        """
        super().__init__()
        raise NotImplementedError(
            'MODataset not properly implemented; use e.g. MOClevrDataset instead'
        )
        # --- Everything below is unreachable until the guard above is removed. ---
        self.transforms = transforms
        # Context manager closes the HDF5 handle once the arrays are copied
        # into memory (the original left the file open).
        with h5py.File(data_root, 'r') as file:
            print(file.keys())
            data = file['train'] if 'train' in file.keys() else file['test']
            print(data.keys())
            self.imgs = np.array(data['imgs'][:])
            # NOTE(review): squeeze() drops *every* size-1 axis, including the
            # sample or object axis if either happens to be 1 — confirm layout.
            self.masks = np.array(data['masks'][:]).squeeze()
        # Binarize the masks: any positive value becomes 1.
        self.masks[self.masks > 0] = 1
        print(f'imgs: {self.imgs.shape}')
        print(f'masks: {self.masks.shape}')

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(img, mask)`` for sample ``index`` as float32 tensors.

        ``img`` is converted from channels-last ``(H, W, C)`` to channels-first
        ``(C, H, W)`` and then passed through ``transforms`` if one was given.
        """
        img = torch.from_numpy(self.imgs[index]).float()
        mask = torch.from_numpy(self.masks[index]).float()
        # Channels-last (H, W, C) -> channels-first (C, H, W).
        img = img.permute(2, 0, 1)
        # transforms defaults to None; calling it unconditionally would crash.
        if self.transforms is not None:
            img = self.transforms(img)
        return img, mask

class MODataModule(pl.LightningDataModule):
    """Lightning data module that randomly splits one ``MODataset`` into train/val."""

    def __init__(
        self,
        data_root: str,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
        train_val_perc: float = 0.95,
        transforms: Optional[Callable] = None,
    ):
        """Build the full dataset and split it into train and validation subsets.

        Args:
            data_root: Path passed to ``MODataset``.
            train_batch_size: Batch size of the training loader.
            val_batch_size: Batch size of the validation loader.
            num_workers: Worker processes for both loaders.
            train_val_perc: Fraction of samples put in the training split; must
                be in (0, 1]. The remainder forms the validation split.
            transforms: Optional image transform forwarded to ``MODataset``.

        Raises:
            ValueError: If ``train_val_perc`` is outside (0, 1].
        """
        super().__init__()
        # A fraction > 1 would make val_size negative, which random_split
        # silently turns into "all samples in train, empty val".
        if not 0.0 < train_val_perc <= 1.0:
            raise ValueError(
                f'train_val_perc must be in (0, 1], got {train_val_perc}'
            )
        self.data_root = data_root
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        full_dataset = MODataset(
            data_root=self.data_root,
            transforms=transforms,
        )

        train_size = int(train_val_perc * len(full_dataset))
        val_size = len(full_dataset) - train_size
        # No generator is passed, so the split follows torch's global RNG state.
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size]
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )



if __name__ == "__main__":
    import itertools

    def _preview(dataset, num_samples=7):
        """Print ``len(dataset)`` and the shapes of its first ``num_samples`` items."""
        print(len(dataset))
        for img, mask in itertools.islice(dataset, num_samples):
            print(f'img: {img.size()}')
            print(f'mask: {mask.size()}')

    path = MOD_FILE_PATHS['tetrominoes']['train']
    print(path)
    try:
        # Constructing the module exercises loading plus the train/val split.
        MODataModule(path, 64, 64, 4, train_val_perc=0.9, transforms=None)
        _preview(MODataset(path))
    except NotImplementedError as err:
        # MODataset deliberately refuses to load; continue with the CLEVR preview.
        print(f'skipping tetrominoes: {err}')

    path = MOD_FILE_PATHS['clevr6']['train']
    print(path)
    _preview(MOClevrDataset(path))
    
