import os
from typing import Iterator, Sequence
import torch
import numpy as np
import torchvision
from einops import rearrange
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, IterableDataset
import json
import glob
import PIL
from PIL import Image


class DummyDataset(IterableDataset):
    """Endless stream that repeats one fixed random tensor.

    The tensor is drawn once at construction time from a standard normal
    distribution and limited to the range [-1, 1] to simulate image data.
    Every iteration step yields that same tensor object.
    """

    def __init__(self, shape):
        """Sample the tensor that this dataset will serve.

        Args:
            shape: Shape passed to ``torch.randn`` for the stored tensor.
        """
        noise = torch.randn(shape)
        # Limit the Gaussian samples to [-1, 1] to simulate image data.
        self.data = torch.clamp(noise, min=-1, max=1)

    def __iter__(self) -> Iterator:
        """Yield ``{"x": self.data}`` forever; the stream never terminates."""
        while True:
            # A new dict is built per step, but it always wraps ``self.data``.
            yield dict(x=self.data)


class DummyDataModule:
    """Data module that serves random placeholder batches.

    Two ``DummyDataset`` instances are created up front: one for training and
    one shared by validation and testing. Each dataset is built with a shape
    whose leading dimension is the batch size, so the loaders are created with
    ``batch_size=None`` and hand the dataset's items through without batching
    them again.
    """

    def __init__(
        self,
        batch_size,
        size,
        num_workers=4,
        channel_last=False,
        val_batch_size=None,
        val_num_workers=None,
    ):
        """Build the training and validation datasets.

        Args:
            batch_size: Leading dimension of every training batch.
            size: Either a sequence giving the per-sample shape (used as-is,
                ``channel_last`` is not consulted), or a single value ``s``
                expanded to ``(3, s, s)`` or ``(s, s, 3)``.
            num_workers: Worker count for the training loader.
            channel_last: When ``size`` is not a sequence, place the
                3-channel axis last instead of first.
            val_batch_size: Leading dimension of validation/test batches;
                defaults to ``batch_size`` when ``None``.
            val_num_workers: Worker count for the validation/test loaders;
                defaults to ``num_workers`` when ``None``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.channel_last = channel_last

        # Resolve the per-sample shape (without the batch dimension).
        if isinstance(size, Sequence):
            sample_shape = tuple(size)
        elif self.channel_last:
            sample_shape = (size, size, 3)
        else:
            sample_shape = (3, size, size)

        # The training dataset is sampled first, then the validation one.
        self.train_dataset = DummyDataset(shape=(self.batch_size, *sample_shape))

        if val_batch_size is None:
            val_batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.val_dataset = DummyDataset(shape=(self.val_batch_size, *sample_shape))

        if val_num_workers is None:
            val_num_workers = num_workers
        self.val_num_workers = val_num_workers

    def _make_loader(self, dataset, workers):
        """Wrap ``dataset`` in a ``DataLoader`` with automatic batching off."""
        loader_kwargs = {
            "batch_size": None,  # the datasets are built already batched
            "num_workers": workers,
            "pin_memory": False,
        }
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return a loader over the training dataset."""
        return self._make_loader(self.train_dataset, self.num_workers)

    def val_dataloader(self):
        """Return a loader over the validation dataset."""
        return self._make_loader(self.val_dataset, self.val_num_workers)

    def test_dataloader(self):
        """Return a loader for testing; it reuses the validation dataset."""
        return self._make_loader(self.val_dataset, self.val_num_workers)
