"""
Provides the MNIST and KMNIST image datasets.
"""

import os
import torch
from torch import Tensor
from torchvision import datasets, transforms
from bayes_dip.utils import get_original_cwd


from chip.datasets.tomogram_dataset import TomogramDataset
from chip.models.forward_models import fourier_filtering
from chip.utils.utils import create_circle_filter, create_gaussian_filter

class ChipImageDataset(torch.utils.data.Dataset):
    """
    Torch dataset wrapper around ``TomogramDataset`` that yields one image per item.

    The wrapped dataset is built with a forward model that applies a Gaussian filter
    (``create_gaussian_filter``) through ``fourier_filtering``; ``__getitem__`` returns
    only element ``1`` of each wrapped item.
    """
    def __init__(self, path: str, im_size: int, rotation_angle=30, sigma=15):
        """
        Parameters
        ----------
        path : str
            Root path of the data; passed unchanged to ``TomogramDataset``.
            NOTE(review): the previous docstring claimed relative paths are resolved against
            the original working directory, but no such resolution happens here (the
            ``get_original_cwd`` import is unused) — confirm whether ``TomogramDataset``
            handles this.
        im_size : int
            Image size, passed as ``size`` to ``create_gaussian_filter``.
        rotation_angle : optional
            Forwarded to ``TomogramDataset``. The default is ``30``.
        sigma : optional
            Width parameter of the Gaussian filter used by the forward model; it determines
            the quality of the filtered images. The default is ``15`` (the previously
            hard-coded value).
        """
        # Only the Gaussian filter is used. A circular filter
        # (``create_circle_filter(radius=30, size=im_size)``) used to be computed here too,
        # but it was immediately overwritten, so that dead computation was removed.
        low_res_filter = create_gaussian_filter(sigma=sigma, size=im_size)

        def lr_forward_function(x):
            # Forward model: filter ``x`` in the Fourier domain with ``low_res_filter``;
            # the filter's parameters determine the quality of the result.
            return fourier_filtering(x, low_res_filter)

        self.dataset = TomogramDataset(
            path=path,
            lr_forward_function=lr_forward_function,
            rotation_angle=rotation_angle,
            gray_background=False,
        )
        # Optional flip along dim 2 (disabled); must run after ``self.dataset`` exists:
        # self.dataset.transforms.append(lambda x: torch.flip(x, dims=[2]))

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tensor:
        """Return element ``1`` of the wrapped dataset's item at ``idx``."""
        # Element 1 is presumably the high-resolution image (the original comment read
        # "only the hr image, not the label") — TODO confirm against TomogramDataset.
        return self.dataset[idx][1]


# def get_mnist_testset(path='.'):
#     """
#     Return the MNIST image test dataset.

#     Parameters
#     ----------
#     path : str, optional
#         Root path for storing the dataset, either absolute or relative to the original current
#         working directory. The default is ``'.'``, i.e. the original current working directory.
#     """
#     return MNISTImageDataset(datasets.MNIST, path, train=False)


# def get_mnist_trainset(path='.'):
#     """
#     Return the MNIST image training dataset.

#     Parameters
#     ----------
#     path : str, optional
#         Root path for storing the dataset, either absolute or relative to the original current
#         working directory. The default is ``'.'``, i.e. the original current working directory.
#     """
#     return MNISTImageDataset(datasets.MNIST, path, train=True)


# def get_kmnist_testset(path='.'):
#     """
#     Return the KMNIST image test dataset.

#     Parameters
#     ----------
#     path : str, optional
#         Root path for storing the dataset, either absolute or relative to the original current
#         working directory. The default is ``'.'``, i.e. the original current working directory.
#     """
#     return MNISTImageDataset(datasets.KMNIST, path, train=False)


# def get_kmnist_trainset(path='.'):
#     """
#     Return the KMNIST image training dataset.

#     Parameters
#     ----------
#     path : str, optional
#         Root path for storing the dataset, either absolute or relative to the original current
#         working directory. The default is ``'.'``, i.e. the original current working directory.
#     """
#     return MNISTImageDataset(datasets.KMNIST, path, train=True)
