from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from utils.configuration import Configuration
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler
import h5py
import numpy as np
import torch
from data.datasets.hdf5_lightning_uncertainty import HDF5_Dataset, ChainedHDF5_Dataset
from data.sampler.object_sampler import DistributedObjectSampler

class LociUncertaintyPretrainerDataModule(LightningDataModule):
    """Lightning data module serving train/val/test loaders over chained HDF5 datasets.

    For each of ``cfg.data.train``, ``cfg.data.val`` and ``cfg.data.test`` one
    ``HDF5_Dataset`` is created per entry; the datasets of a split are then
    combined into a single ``ChainedHDF5_Dataset`` together with the entries'
    ``weight`` values. Every loader takes its indices from a
    ``DistributedObjectSampler``.
    """

    def __init__(self, cfg: Configuration):
        """Build the train, validation and test datasets from ``cfg``.

        Args:
            cfg: Project configuration. Fields read by this class:
                ``data.train`` / ``data.val`` / ``data.test`` (iterables of
                entries exposing ``path``, ``split`` and ``weight``),
                ``model.crop_size``, ``model.num_objects``,
                ``model.batch_size``, ``seed``, ``num_workers`` and
                ``prefetch_factor``.
        """
        super().__init__()

        self.cfg = cfg

        # Only the training datasets receive the configured seed; the
        # validation and test datasets use HDF5_Dataset's own default.
        self.trainset = self._chain_datasets(cfg.data.train, "train", seed=cfg.seed)
        self.valset = self._chain_datasets(cfg.data.val, "val")
        self.testset = self._chain_datasets(cfg.data.test, "test")

        self.batch_size = self.cfg.model.batch_size

    def _chain_datasets(self, entries, split_name, **dataset_kwargs):
        """Wrap one ``HDF5_Dataset`` per config entry into a ``ChainedHDF5_Dataset``.

        Args:
            entries: Config entries of one split (``path``, ``split``, ``weight``).
            split_name: Value handed to ``HDF5_Dataset`` for entries whose
                ``split`` flag is truthy; other entries get ``None``.
            **dataset_kwargs: Extra keyword arguments forwarded unchanged to
                every ``HDF5_Dataset`` (e.g. ``seed``).

        Returns:
            The ``ChainedHDF5_Dataset`` built from the datasets and their weights.
        """
        datasets = [
            HDF5_Dataset(
                entry.path,
                self.cfg.model.crop_size,
                split_name if entry.split else None,
                max_num_mask_per_image=self.cfg.model.num_objects,
                **dataset_kwargs,
            )
            for entry in entries
        ]
        weights = [entry.weight for entry in entries]
        return ChainedHDF5_Dataset(datasets, weights)

    def _make_loader(self, dataset, sampler):
        """Create a ``DataLoader`` with the settings shared by all three splits.

        ``prefetch_factor`` and ``persistent_workers`` are only valid when
        worker processes are used; ``DataLoader`` raises ``ValueError`` if they
        are supplied with ``num_workers == 0``. They are therefore forwarded
        only when ``cfg.num_workers`` is positive, which keeps single-process
        (debug) runs working.

        NOTE(review): ``drop_last=True`` is applied to validation and test as
        well, so a trailing partial batch is discarded there too -- confirm
        this is intended.
        """
        num_workers = self.cfg.num_workers
        loader_kwargs = dict(
            pin_memory=True,
            num_workers=num_workers,
            batch_size=self.batch_size,
            sampler=sampler,
            drop_last=True,
        )
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.cfg.prefetch_factor
            loader_kwargs["persistent_workers"] = True

        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the training loader; its sampler shuffles using ``cfg.seed``."""
        sampler = DistributedObjectSampler(self.trainset, shuffle=True, seed=self.cfg.seed)
        return self._make_loader(self.trainset, sampler)

    def val_dataloader(self):
        """Return the validation loader.

        NOTE(review): the sampler is created with ``shuffle=True`` and no
        explicit seed -- confirm that shuffled validation is intended.
        """
        sampler = DistributedObjectSampler(self.valset, shuffle=True)
        return self._make_loader(self.valset, sampler)

    def test_dataloader(self):
        """Return the test loader; its sampler is created with ``shuffle=False``."""
        sampler = DistributedObjectSampler(self.testset, shuffle=False)
        return self._make_loader(self.testset, sampler)
