import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class DiffusionMNIST(Dataset):
    """Single-channel image dataset backed by ``.npy`` shard files on disk.

    Only one entry of ``image_directory_paths`` is held in memory at a time.
    Shards are read from ``<image_directory>/npy/<name>/<name>.npy`` for every
    sub-folder ``<name>`` that contains a ``-``. Calling
    ``load_next_dataset`` swaps in the next directory. Once the end of the list
    (or ``image_directory_cutoff_index``) is reached, it wraps back to the
    first directory.

    Args:
        image_directory_paths: Indexable sequence of root directories, each
            containing an ``npy`` sub-directory with shard folders.
        transform: Optional callable applied to the ``PIL.Image`` produced by
            ``__getitem__``.
        image_directory_cutoff_index: Optional exclusive upper bound on the
            directory index. ``None`` means every directory is used.
    """

    def __init__(
        self, image_directory_paths, transform=None, image_directory_cutoff_index=None
    ):
        self.image_directory_paths = image_directory_paths
        # infinity if None, else cutoff index
        self.image_directory_cutoff_index = (
            np.inf
            if image_directory_cutoff_index is None
            else image_directory_cutoff_index
        )
        self.transform = transform
        # start at -1 so the first load_next_dataset() call selects index 0
        self.image_directory_index = -1
        self.data = np.array([])
        # load first dataset
        self.load_next_dataset()

    def __len__(self):
        """Return the number of images in the currently loaded directory."""
        return self.data.shape[0]

    def load_next_dataset(self):
        """Replace ``self.data`` with the shards of the next image directory.

        Raises:
            RuntimeError: If a shard folder lacks its ``.npy`` file, or if the
                directory contains no shard folders at all.
        """
        self._advance_directory_index()
        next_image_directory_path = self.image_directory_paths[
            self.image_directory_index
        ]
        self.data = self._load_shards(os.path.join(next_image_directory_path, "npy"))

    def _advance_directory_index(self):
        """Step to the next directory index, wrapping to 0 at the end or cutoff."""
        self.image_directory_index += 1
        if (
            self.image_directory_index >= len(self.image_directory_paths)
            or self.image_directory_index >= self.image_directory_cutoff_index
        ):
            print(
                f"{self.__class__.__name__}: "
                f"no more image directories to load or cutoff index reached, "
                f"looping back to first directory"
            )
            self.image_directory_index = 0

    @staticmethod
    def _load_shards(path):
        """Load and concatenate every ``<name>/<name>.npy`` shard under ``path``.

        Shards are concatenated in sorted folder-name order.
        """
        shard_arrays = []
        # os.listdir order is arbitrary; sort it so the index -> image mapping
        # is reproducible across runs and filesystems
        for sub_directory in sorted(os.listdir(path)):
            # folder structure: ".../<from_seed>-<to_seed>/<from_seed>-<to_seed>.npy"
            if "-" not in sub_directory:
                continue
            data_file_path = os.path.join(path, sub_directory, f"{sub_directory}.npy")
            if not os.path.isfile(data_file_path):
                raise RuntimeError(f"no data found in path '{data_file_path}'")
            shard_arrays.append(np.load(data_file_path))
        if not shard_arrays:
            # np.concatenate([]) would raise an opaque "need at least one array"
            # ValueError; report the actual problem instead
            raise RuntimeError(f"no data folders found in path '{path}'")
        return np.concatenate(shard_arrays)

    def __getitem__(self, idx):
        """Return image ``idx`` as a grayscale PIL image, transformed if configured."""
        # indexing with [idx, :, :, 0] requires self.data to be 4-D and keeps the
        # first channel of the last axis
        img = self.data[idx, :, :, 0]
        # NOTE(review): Pillow 11.3 deprecates the `mode` argument of fromarray;
        # uint8 input infers "L" on its own. Confirm the stored dtype before dropping it.
        img = Image.fromarray(img, mode="L")

        if self.transform is not None:
            img = self.transform(img)

        return img


class DiffusionCIFAR(Dataset):
    """RGB image dataset backed by ``.npy`` shard files on disk.

    Only one entry of ``image_directory_paths`` is held in memory at a time.
    Shards are read from ``<image_directory>/npy/<name>/<name>.npy`` for every
    sub-folder ``<name>`` that contains a ``-``. Calling
    ``load_next_dataset`` swaps in the next directory. Once the end of the list
    (or ``image_directory_cutoff_index``) is reached, it wraps back to the
    first directory.

    Args:
        image_directory_paths: Indexable sequence of root directories, each
            containing an ``npy`` sub-directory with shard folders.
        transform: Optional callable applied to the ``PIL.Image`` produced by
            ``__getitem__``.
        size_cap: Optional maximum number of images kept per loaded directory.
            Leading rows, in sorted shard order, are kept.
        image_directory_cutoff_index: Optional exclusive upper bound on the
            directory index. ``None`` means every directory is used.
    """

    def __init__(
        self,
        image_directory_paths,
        transform=None,
        size_cap=None,
        image_directory_cutoff_index=None,
    ):
        self.image_directory_paths = image_directory_paths
        # infinity if None, else cutoff index
        self.image_directory_cutoff_index = (
            np.inf
            if image_directory_cutoff_index is None
            else image_directory_cutoff_index
        )
        self.transform = transform
        self.size_cap = size_cap
        # start at -1 so the first load_next_dataset() call selects index 0
        self.image_directory_index = -1
        self.data = np.array([])
        # load first dataset
        self.load_next_dataset()

    def __len__(self):
        """Return the number of images in the currently loaded directory."""
        return self.data.shape[0]

    def load_next_dataset(self):
        """Replace ``self.data`` with the shards of the next image directory.

        If ``size_cap`` is set, the result is truncated to at most that many rows.

        Raises:
            RuntimeError: If a shard folder lacks its ``.npy`` file, or if the
                directory contains no shard folders at all.
        """
        self._advance_directory_index()
        next_image_directory_path = self.image_directory_paths[
            self.image_directory_index
        ]
        data = self._load_shards(os.path.join(next_image_directory_path, "npy"))
        # cap dataset size
        if self.size_cap is not None and data.shape[0] > self.size_cap:
            # slicing yields a view that would keep the whole concatenated array
            # alive; copy so the cap actually frees the discarded rows
            data = data[: self.size_cap].copy()
        self.data = data

    def _advance_directory_index(self):
        """Step to the next directory index, wrapping to 0 at the end or cutoff."""
        self.image_directory_index += 1
        if (
            self.image_directory_index >= len(self.image_directory_paths)
            or self.image_directory_index >= self.image_directory_cutoff_index
        ):
            print(
                f"{self.__class__.__name__}: "
                f"no more image directories to load or cutoff index reached, "
                f"looping back to first directory"
            )
            self.image_directory_index = 0

    @staticmethod
    def _load_shards(path):
        """Load and concatenate every ``<name>/<name>.npy`` shard under ``path``.

        Shards are concatenated in sorted folder-name order.
        """
        shard_arrays = []
        # os.listdir order is arbitrary; sort it so the index -> image mapping
        # (and which rows survive size_cap) is reproducible
        for sub_directory in sorted(os.listdir(path)):
            # folder structure: ".../<from_seed>-<to_seed>/<from_seed>-<to_seed>.npy"
            if "-" not in sub_directory:
                continue
            data_file_path = os.path.join(path, sub_directory, f"{sub_directory}.npy")
            if not os.path.isfile(data_file_path):
                raise RuntimeError(f"no data found in path '{data_file_path}'")
            shard_arrays.append(np.load(data_file_path))
        if not shard_arrays:
            # np.concatenate([]) would raise an opaque "need at least one array"
            # ValueError; report the actual problem instead
            raise RuntimeError(f"no data folders found in path '{path}'")
        return np.concatenate(shard_arrays)

    def __getitem__(self, idx):
        """Return image ``idx`` as an RGB PIL image, transformed if configured."""
        img = self.data[idx]
        # NOTE(review): Pillow 11.3 deprecates the `mode` argument of fromarray;
        # uint8 (H, W, 3) input infers "RGB" on its own. Confirm the stored dtype
        # and shape before dropping it.
        img = Image.fromarray(img, mode="RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img
