from typing import Optional
import functools

from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from .data_sample import DataSample


# def wrap_data_loader(dataloader: DataLoader) -> DataLoader[DataSample]:
#     @functools.wraps(dataloader.__iter__)
#     def wrapped_iter(self):
#         for batch in iter(self):
#             print("returning batch")
#             yield DataSample(x=batch[0], y=batch[1])
#     dataloader.__iter__ = wrapped_iter


# def wrap_data_module(
#     datamodule: pl.LightningDataModule
# ) -> pl.LightningDataModule:
#     print("wrapping data module")

#     @functools.wraps(datamodule.train_dataloader)
#     def wrapped_train_loader(self):
#         return wrap_data_loader(self.train_dataloader())

#     @functools.wraps(datamodule.val_dataloader)
#     def wrapped_val_loader(self):
#         return wrap_data_loader(self.val_dataloader())

#     @functools.wraps(datamodule.test_dataloader)
#     def wrapped_test_loader(self):
#         return wrap_data_loader(self.test_dataloader())

#     datamodule.train_dataloader = wrapped_train_loader
#     datamodule.val_dataloader = wrapped_val_loader
#     datamodule.test_dataloader = wrapped_test_loader
#     return datamodule


class DatasetWrapper(Dataset):
    """Dataset adapter that returns each wrapped item as a ``DataSample``.

    Length and item lookups are delegated to the wrapped dataset. Every item
    fetched from it is indexed positionally: element 0 becomes the sample's
    ``input`` and element 1 its ``target``.
    """

    def __init__(self, wrapped_data: Dataset) -> None:
        """Remember the dataset whose items will be repackaged.

        Args:
            wrapped_data: Dataset that length and item lookups are
                delegated to.
        """
        self.wrapped_data = wrapped_data

    def __len__(self) -> int:
        """Return the length reported by the wrapped dataset."""
        source = self.wrapped_data
        return len(source)

    def __getitem__(self, idx: int) -> DataSample:
        """Fetch item ``idx`` from the wrapped dataset as a ``DataSample``.

        Args:
            idx: Index forwarded unchanged to the wrapped dataset.

        Returns:
            A ``DataSample`` built from elements 0 and 1 of the fetched item.
        """
        item = self.wrapped_data[idx]
        # Index positionally rather than unpacking, so any elements beyond
        # the first two are simply ignored.
        # NOTE(review): presumably items are (input, target) pairs -- confirm
        # against the datasets actually being wrapped.
        features, label = item[0], item[1]
        return DataSample(input=features, target=label)
