import h5py
import torch

from torch.utils.data import Dataset
from torchvision.transforms.functional import resize, InterpolationMode
from torchvision import transforms

from chip.utils.fourier import fft_2D, ifft_2D
from chip.utils import add_defects
import torch.nn.functional as F
from PIL import Image
import math
import os
import numpy as np
from chip.utils.utils import create_gaussian_filter
from chip.datasets.base_dataset import BaseImageDataset

# Setting this to None disables Pillow's decompression-bomb size check, so very
# large TIFFs can be opened by Image.open below (presumably the source images
# exceed Pillow's default pixel limit — confirm against the data).
Image.MAX_IMAGE_PIXELS = None

class tiff_wrapper():
    """Expose one or more (possibly very large) TIFF files as a flat sequence of square tiles.

    Each image is covered by a regular grid of ``im_size`` x ``im_size`` crops.
    Along each axis the number of tiles is ``ceil(length / im_size)``. When there
    is more than one tile, they are spaced evenly so the first starts at 0 and the
    last starts at ``length - im_size``. Neighbouring tiles therefore overlap
    whenever the length is not an exact multiple of ``im_size``. An axis shorter
    than or equal to ``im_size`` gets a single tile starting at 0.

    Args:
        path: Either a directory, in which case every ``.tif``/``.tiff`` file
            directly inside it is used in sorted filename order, or the path of a
            single TIFF file.
        im_size: Side length, in pixels, of each square tile.

    Attributes:
        path: Directory containing the TIFF files.
        folder: File names (relative to ``path``) of the wrapped images.
        im_size: Tile side length in pixels.
        idx_to_file: For every global tile index, the index into ``folder``.
        global_index_to_local: For every global tile index, the tile's index
            within its own file.
        coordinates: ``(num_tiles, 2)`` float tensor of top-left ``(x, y)``
            tile offsets.

    Raises:
        ValueError: If ``path`` is a directory containing no TIFF files.
    """

    def __init__(self, path, im_size=512):
        if os.path.isdir(path):
            # Sort so the global tile index -> file mapping is reproducible;
            # os.listdir returns entries in arbitrary, filesystem-dependent order.
            folder = sorted(filename for filename in os.listdir(path)
                            if filename.endswith(('.tiff', '.tif')))
        else:
            folder = [os.path.basename(path)]
            path = os.path.dirname(path)

        if not folder:
            raise ValueError(f'No .tif/.tiff files found in {path!r}')

        self.path = path
        self.folder = folder
        self.im_size = im_size
        idx_to_file_list = []
        global_index_to_local_list = []
        coordinates_list = []

        for file_id, filename in enumerate(folder):
            # Only the image size is needed here; close the file right away.
            with Image.open(os.path.join(path, filename)) as image:
                w, h = image.size
            num_cols, stride_x = self._axis_layout(w, im_size)
            num_rows, stride_y = self._axis_layout(h, im_size)
            num_images = num_cols * num_rows

            # Tiles are enumerated row-major within each file.
            local = torch.arange(num_images)
            row = local // num_cols
            col = local % num_cols
            coordinates_list.append(torch.stack([col * stride_x, row * stride_y], -1))
            idx_to_file_list.append(torch.full((num_images,), file_id, dtype=torch.int))
            global_index_to_local_list.append(local)

        self.idx_to_file = torch.cat(idx_to_file_list)
        self.global_index_to_local = torch.cat(global_index_to_local_list)
        self.coordinates = torch.cat(coordinates_list)

    @staticmethod
    def _axis_layout(length, im_size):
        """Return ``(num_tiles, stride)`` for covering ``length`` pixels with ``im_size`` tiles."""
        num_tiles = max(1, math.ceil(length / im_size))
        if num_tiles == 1:
            # A single tile starts at 0; this also avoids dividing by zero below.
            return num_tiles, 0.0
        return num_tiles, (length - im_size) / (num_tiles - 1)

    def __getitem__(self, idx):
        """Return tile ``idx`` as a PIL image of size ``im_size`` x ``im_size``.

        If the crop box extends past the source image, which happens only when
        the image is smaller than ``im_size``, PIL fills the out-of-bounds area.
        """
        file_id = int(self.idx_to_file[idx])
        x, y = (int(c) for c in self.coordinates[idx].int())
        filename = self.folder[file_id]
        # crop() loads the pixel data before returning, so the file can be
        # closed safely as soon as the crop has been taken.
        with Image.open(os.path.join(self.path, filename)) as image:
            return image.crop((x, y, x + self.im_size, y + self.im_size))

    def __len__(self):
        """Return the total number of tiles across all wrapped files."""
        return len(self.idx_to_file)

class TIFFDataset(BaseImageDataset):
    """Dataset of ``(lr_image, hr_image, idx)`` samples cut from TIFF images.

    The high-resolution sample is an ``im_size`` x ``im_size`` tile served by
    :class:`tiff_wrapper`. It is converted to a channel-first float32 tensor and
    passed through ``self.transform``, which is provided by ``BaseImageDataset``.
    The paired low-resolution sample is ``lr_forward_function(hr_image)``.

    Every constructor argument except ``im_size`` is forwarded to
    ``BaseImageDataset`` unchanged; the meaning of those options is defined there.
    """

    def __init__(self, path, im_size, lr_forward_function,
                 rescale=None, clip_range=None, normalize_range=False, rotation_angle=None, num_defects=None,
                 contrast=None, train_transform=False, crop=None, gray_background=False, to_gray=False,
                 to_synthetic=False):
        super().__init__(path, lr_forward_function=lr_forward_function,
                         rescale=rescale, clip_range=clip_range, normalize_range=normalize_range,
                         rotation_angle=rotation_angle, num_defects=num_defects,
                         contrast=contrast, train_transform=train_transform, crop=crop,
                         gray_background=gray_background, to_gray=to_gray,
                         to_synthetic=to_synthetic)
        self.im_size = im_size
        self.images = tiff_wrapper(path, im_size)

        # NOTE(review): the base class may already store this (not visible from
        # this file); set here so __getitem__ can always rely on it.
        self.lr_forward_function = lr_forward_function

    def __getitem__(self, idx):
        """Return ``(lr_image, hr_image, idx)`` for tile ``idx``.

        The tile becomes a float32 tensor of shape ``(1, H, W)`` for
        single-channel images or ``(C, H, W)`` for multi-channel ones. It then
        goes through the transform and the forward model.
        """
        # Convert to float32 in NumPy first: torch cannot ingest every NumPy
        # integer dtype (e.g. uint16 from 16-bit TIFFs on older releases).
        # np.array copies, so the resulting tensor owns writable memory.
        array = np.array(self.images[idx], dtype=np.float32)
        hr_image = torch.from_numpy(array)

        if hr_image.ndim == 2:
            hr_image = hr_image.unsqueeze(0)
        else:
            # PIL/NumPy yield (H, W, C); move channels first to match the 2-D case.
            hr_image = hr_image.permute(2, 0, 1).contiguous()

        hr_image = self.transform(hr_image)
        lr_image = self.lr_forward_function(hr_image)

        return lr_image, hr_image, int(idx)

    def __len__(self):
        """Return the number of tiles available."""
        return len(self.images)


if __name__ == '__main__':
    # Manual smoke test: build a training dataset and print its first sample.
    # `os`, `torch` and `create_gaussian_filter` come from the module-level imports.
    import lovely_tensors as lt
    from chip.utils.utils import create_circle_filter
    from chip.models.forward_models import fourier_filtering

    lt.monkey_patch()

    # Use the shared data location on GPU machines, a local ./data otherwise.
    DATA_PATH = '/mydata/chip/shared/data' if torch.cuda.is_available() else 'data'

    frequency_cut_out_radius = 30
    circle_filter = create_circle_filter(frequency_cut_out_radius, 512)
    gaussian_filter = create_gaussian_filter(sigma=15, size=512)

    # Swap in `gaussian_filter` here to try the alternative filter.
    current_filter = circle_filter

    kwargs = {
        'path': os.path.join(DATA_PATH, 'DATASET_G7_170um_10nm_rect'),
        'im_size': 512,
        'lr_forward_function': lambda x: fourier_filtering(x, current_filter),
        'gray_background': False,
        'train_transform': True,
        'to_gray': False,
        'rotation_angle': 30,
        'rescale': 512,
        'crop': (200, 200, 128),
    }

    trainSet = TIFFDataset(**kwargs)
    print(trainSet[0])
