from types import SimpleNamespace

import polars as pl
import torch
from PIL import Image


class SimpleDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(target, resized image, dicom_id)`` per row.

    Images come either from a pre-loaded dict keyed by the row's ``dicom_id``
    or from disk via the row's ``img_path``, selected by ``cfg.load_memory``.
    """

    def __init__(
        self,
        cfg: SimpleNamespace,
        data: pl.DataFrame | tuple[pl.DataFrame, dict[str, Image.Image]],
    ):
        """Store the config and the metadata table (plus optional image dict).

        Args:
            cfg: Config namespace; ``load_memory`` and ``img_size`` are read
                on every item access.
            data: Either a DataFrame, or a ``(DataFrame, images)`` tuple where
                ``images`` maps ``dicom_id`` to an already-loaded image. The
                DataFrame needs ``dicom_id`` and ``target`` columns, plus
                ``img_path`` when images are read from disk.
        """
        self.cfg = cfg
        # Stays None when no image dict is supplied, so a mismatched
        # ``cfg.load_memory`` fails with a clear message in __getitem__
        # instead of an AttributeError.
        self.imgs: dict[str, Image.Image] | None = None
        if isinstance(data, tuple):
            self.data, self.imgs = data
        else:
            self.data = data

    def __len__(self) -> int:
        """Return the number of rows in the metadata table."""
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[object, Image.Image, str]:
        """Return ``(target, image, dicom_id)`` for row ``idx``.

        The image is resized to ``cfg.img_size``. ``resize`` returns a new
        image, so cached images in ``self.imgs`` are never modified.

        Raises:
            ValueError: ``cfg.load_memory`` is true but no image dict was
                given at construction time.
        """
        row = int(idx)  # coerce to a plain Python int before polars indexing
        dicom_id = self.data[row, "dicom_id"]
        if self.cfg.load_memory:
            if self.imgs is None:
                raise ValueError(
                    "cfg.load_memory is True but SimpleDataset was built "
                    "without an in-memory image dict"
                )
            img = self.imgs[dicom_id].resize(self.cfg.img_size)
        else:
            # Close the file handle Image.open keeps open; resize() loads the
            # pixel data, so the returned image does not depend on the file.
            with Image.open(self.data[row, "img_path"]) as src:
                img = src.resize(self.cfg.img_size)
        return self.data[row, "target"], img, dicom_id
