import einops
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import h5py

from slot_attention.const import MOD_FILE_PATHS, OBJECTSROOM_ROOT_PATH
# from multi_object_datasets_torch import ObjectsRoom  # install via: pip install git+https://github.com/JohannesTheo/multi_object_datasets_torch.git@main
from slot_attention.data.external_datasets import ObjectsRoom
from slot_attention.utils import objectsroom_image_transform, objectsroom_mask_transform


# https://github.com/JohannesTheo/multi_object_datasets_torch

class ObjectsRoomDataset(Dataset):
    """Thin ``Dataset`` wrapper around ``ObjectsRoom`` yielding ``(image, mask)`` pairs.

    The wrapped dataset is stored on ``self.ds``; each of its items is a
    dict from which the ``'image'`` and ``'mask'`` entries are read.
    """

    def __init__(self, split, ttv) -> None:
        """Build the underlying ``ObjectsRoom`` dataset.

        Args:
            split: str selecting the train, test or val split.
            ttv: list with the sizes of [train, test, val].
        """
        super().__init__()

        # Per-feature transforms, keyed by the feature name they apply to.
        feature_transforms = dict(
            mask=objectsroom_mask_transform,
            image=objectsroom_image_transform,
        )

        # download/convert are always enabled for the wrapped dataset.
        self.ds = ObjectsRoom(
            root=OBJECTSROOM_ROOT_PATH,
            split=split,
            ttv=ttv,
            transforms=feature_transforms,
            download=True,
            convert=True,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``.

        The mask has every size-1 dimension removed via ``squeeze()``;
        the image is returned as stored in the sample dict.
        """
        sample = self.ds[index]
        return sample['image'], sample['mask'].squeeze()

class ObjectsRoomDataModule(pl.LightningDataModule):
    """Lightning data module providing train/val loaders over ObjectsRoom.

    The ``'Train'`` split of ``ObjectsRoomDataset`` is built once in the
    constructor and randomly partitioned into a training subset and a
    validation subset according to ``train_val_perc``.
    """

    def __init__(
        self,
        train_batch_size: int,
        val_batch_size: int,
        num_workers: int,
        train_val_perc: float = 0.95,
    ):
        """Create the datasets and remember the loader settings.

        Args:
            train_batch_size: Batch size used by ``train_dataloader``.
            val_batch_size: Batch size used by ``val_dataloader``.
            num_workers: Number of worker processes for both loaders.
            train_val_perc: Fraction of the full dataset assigned to the
                training subset; the remainder becomes the validation
                subset. Must lie in ``[0, 1]``.

        Raises:
            ValueError: If ``train_val_perc`` is outside ``[0, 1]``.
        """
        super().__init__()

        # A fraction outside [0, 1] yields a negative split length, which
        # random_split does not turn into a usable split; fail clearly here.
        if not 0.0 <= train_val_perc <= 1.0:
            raise ValueError(
                f"train_val_perc must be in [0, 1], got {train_val_perc!r}"
            )

        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        full_dataset = ObjectsRoomDataset(
            split='Train',     # str:  select train, test or val
            ttv=[50000, 5000, 5000], # list: the size of [train, test, val]
        )

        # Truncate towards zero so train_size + val_size == len(full_dataset).
        train_size = int(train_val_perc * len(full_dataset))
        val_size = len(full_dataset) - train_size
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size]
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffled loader over the validation subset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )



if __name__ == "__main__":
    # Manual smoke test: build the 'Train' split, report its length and
    # print the tensor sizes of the first few (image, mask) samples.
    dataset = ObjectsRoomDataset(split='Train', ttv=[50000, 5000, 5000])
    print(len(dataset))

    for sample_idx, (img, mask) in enumerate(dataset):
        print(f'img: {img.size()}')
        print(f'mask: {mask.size()}')
        # Stop after the sample at index 6 has been printed.
        if sample_idx >= 6:
            break
    
