import h5py
import torch
import os

from torch.utils.data import Dataset
from torchvision.transforms.functional import resize, InterpolationMode
from torchvision import transforms

from chip.utils.fourier import fft_2D, ifft_2D
from chip.utils import add_defects
import torch.nn.functional as F
from chip.utils.utils import create_gaussian_filter
from chip.datasets.base_dataset import BaseImageDataset, PairedTransform


class h5_wrapper():
    """Expose the image stacks of one or more HDF5 files as one indexable sequence.

    ``path`` may point to a single file or to a directory. For a directory,
    every file ending in ``.h5``, ``.hdf5`` or ``.mat`` is included, in sorted
    filename order. Inside each file, the first dataset found among
    ``_DATASET_NAMES`` is used, and its first axis is the sample axis. Global
    index ``g`` maps to ``(idx_to_file[g], global_index_to_local[g])``.

    Each read opens the file only for its own duration, so no file handles
    stay open between calls.
    """

    # Candidate dataset names, checked in this order; the first one present wins.
    _DATASET_NAMES = ('images', 'tomogram_delta', 'data')
    # File extensions picked up when ``path`` is a directory.
    _EXTENSIONS = ('.h5', '.hdf5', '.mat')

    def __init__(self, path):
        if os.path.isdir(path):
            # Sort the names: os.listdir order is arbitrary, and the global
            # index -> file mapping must be reproducible. Paired HR/LR
            # directories must also enumerate their files in the same order.
            folder = sorted(
                filename for filename in os.listdir(path)
                if filename.endswith(self._EXTENSIONS)
            )
        else:
            folder = [os.path.basename(path)]
            path = os.path.dirname(path)

        if not folder:
            raise ValueError(f"No .h5/.hdf5/.mat files found in {path!r}")

        self.path = path
        self.folder = folder
        # Name of the image dataset inside each file (parallel to self.folder).
        self._dataset_names = []
        idx_to_file_list = []
        global_index_to_local_list = []
        for i, filename in enumerate(folder):
            with h5py.File(os.path.join(path, filename), 'r') as hr_data:
                dataset_name = self._find_dataset_name(hr_data, filename)
                num_samples = len(hr_data[dataset_name])
            self._dataset_names.append(dataset_name)
            idx_to_file_list.append(torch.full((num_samples,), i, dtype=torch.int))
            global_index_to_local_list.append(torch.arange(num_samples))

        # idx_to_file[g]: index into self.folder of the file holding sample g.
        # global_index_to_local[g]: position of sample g within that file.
        self.idx_to_file = torch.cat(idx_to_file_list)
        self.global_index_to_local = torch.cat(global_index_to_local_list)

    @classmethod
    def _find_dataset_name(cls, h5_file, filename):
        """Return the first name in ``_DATASET_NAMES`` that exists in ``h5_file``.

        Raises:
            KeyError: if none of the candidate names is present.
        """
        for dataset_name in cls._DATASET_NAMES:
            if dataset_name in h5_file:
                return dataset_name
        raise KeyError(f"None of {cls._DATASET_NAMES} found in {filename!r}")

    def __getitem__(self, idx):
        """Read and return sample ``idx`` as loaded by h5py."""
        file_id = self.idx_to_file[idx]
        local_index = self.global_index_to_local[idx]
        filename = self.folder[file_id]
        dataset_name = self._dataset_names[file_id]
        with h5py.File(os.path.join(self.path, filename), 'r') as hr_data:
            # Slicing an h5py Dataset reads the data into memory, so the
            # result stays valid after the file is closed.
            return hr_data[dataset_name][local_index]

    def __len__(self):
        return len(self.idx_to_file)

    @property
    def shape(self):
        """``(num_samples, *sample_shape)`` for the concatenated stack."""
        # NOTE(review): this indexes with the list [0] rather than 0. Depending
        # on how h5py treats the resulting one-element index tensor, this may
        # add a leading axis of size 1 to the sample shape. Confirm intended.
        return (len(self), *self.__getitem__([0]).shape)

class TomogramDataset(BaseImageDataset):
    """Dataset of (low-res, high-res) image pairs read from HDF5 tomogram stacks.

    Each item is ``(lr_image, hr_image, idx)``. High-resolution images come
    from ``path``. Low-resolution images come from one of two sources:

    - if ``lr_path`` is given, they are read from ``lr_path`` at the same
      global index and transformed jointly with the HR image through
      ``PairedTransform(self.transform)``;
    - otherwise they are produced by applying ``lr_forward_function`` to the
      already transformed HR image.

    All remaining arguments are forwarded unchanged to ``BaseImageDataset``.
    """

    def __init__(self, path, lr_forward_function, lr_path=None,
                 rescale=None, clip_range=None, normalize_range=False, rotation_angle=None, num_defects=None,
                 contrast=None, train_transform=False, crop=None, gray_background=False, to_gray=False,
                 to_synthetic=False):
        # Assigned before BaseImageDataset.__init__; presumably the base
        # constructor reads them (TODO confirm against base_dataset).
        self.hr_tomogram = h5_wrapper(path)
        self.lr_tomogram = None
        super().__init__(path, lr_forward_function, lr_path,
                 rescale, clip_range, normalize_range, rotation_angle, num_defects,
                 contrast, train_transform, crop, gray_background, to_gray,
                 to_synthetic)

        if lr_path:
            self.lr_tomogram = h5_wrapper(lr_path)

        self.lr_forward_function = lr_forward_function

    @staticmethod
    def _to_tensor_with_channel(image):
        """Convert ``image`` to a float tensor, prepending a channel axis if it is 2-D."""
        image = torch.Tensor(image)
        if image.dim() == 2:
            image = image.unsqueeze(0)
        return image

    def __getitem__(self, idx):
        hr_image = self._to_tensor_with_channel(self.hr_tomogram[idx])

        if self.lr_tomogram is not None:
            # Give the LR image the same channel handling as the HR image, so
            # both reach PairedTransform with matching ranks.
            lr_image = self._to_tensor_with_channel(self.lr_tomogram[idx])
            hr_image, lr_image = PairedTransform(self.transform)(hr_image, lr_image)
        else:
            hr_image = self.transform(hr_image)
            lr_image = self.lr_forward_function(hr_image)

        return lr_image, hr_image, int(idx)

    def __len__(self):
        return len(self.hr_tomogram)


if __name__ == '__main__':
    # Smoke test: build a TomogramDataset from a local tomogram file and print
    # the first (lr_image, hr_image, idx) item.
    # `os` and `create_gaussian_filter` are already imported at module level.
    import lovely_tensors as lt
    from chip.utils.utils import create_circle_filter
    from chip.models.forward_models import fourier_filtering

    lt.monkey_patch()  # compact tensor summaries when printing

    DATA_PATH = '/mydata/chip/shared/data' if torch.cuda.is_available() else 'data'

    frequency_cut_out_radius = 30
    circle_filter = create_circle_filter(frequency_cut_out_radius, 512)
    gaussian_filter = create_gaussian_filter(sigma=15, size=512)

    # Switch to gaussian_filter to compare the two frequency-domain filters.
    current_filter = circle_filter

    kwargs = {
        'path': os.path.join(DATA_PATH, 'p17299/tomogram_delta.mat'),
        'lr_forward_function': lambda x: fourier_filtering(x, current_filter),
        'gray_background': False,
        'train_transform': False,
        'to_gray': False,
        'rotation_angle': 30,
        'rescale': 512
    }

    train_set = TomogramDataset(**kwargs)
    print(train_set[0])
