import torch
import os
import numpy as np

from chip.utils import add_defects
from chip.utils.sinogram import Sinogram
from torch.utils.data import Dataset
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur, resize, rotate, InterpolationMode
import h5py
import torch
from torchvision import transforms




class SuperresolutionDS(torch.utils.data.Dataset):
    """Dataset yielding (blurred, sharp) image pairs loaded from ``.npy`` files.

    Every item is read from ``<data_path>/high_res_90/<stem>.npy`` where
    ``<stem>`` is the entry of ``files`` with its extension removed. The image
    is optionally rotated and passed through ``add_defects``; the result is the
    high-resolution target, and a Gaussian-blurred copy of it is the
    low-resolution input.

    Args:
        files: Sequence of file names; only the part before the extension is
            used to locate the ``.npy`` file.
        data_path: Directory that contains the ``high_res_90`` folder.
        kernel_size: Kernel size forwarded to the Gaussian blur.
        sigma: Sigma forwarded to the Gaussian blur.
        num_defects: Forwarded to ``add_defects``; a falsy value skips it.
        rotate_angle: ``'random'`` draws an integer angle in ``[0, 180)`` for
            every item, a number applies that fixed angle, anything else
            disables rotation.
        generator: ``torch.Generator`` used when drawing random angles.
        interpolation_mode: ``'bilinear'`` or ``'nearest'``; used by the
            rotation.

    Raises:
        KeyError: If ``interpolation_mode`` is not one of the supported names.
    """

    def __init__(self, files, data_path="", kernel_size=21, sigma=50, num_defects=None, rotate_angle=None,
                 generator=None, interpolation_mode='bilinear'):
        super().__init__()
        self.files = files
        self.data_path = data_path
        self.generator = generator
        # The blur expects a leading channel dimension, which is added for the
        # call and removed again afterwards.
        self.low_res_filter = lambda image: torch_gaussian_blur(image.unsqueeze(0), kernel_size=kernel_size,
                                                                sigma=sigma).squeeze(0)
        self.interpolation_mode = {'bilinear': InterpolationMode.BILINEAR,
                                   'nearest': InterpolationMode.NEAREST}[interpolation_mode]
        self.num_defects = num_defects

        # ``None`` means "do not rotate"; otherwise a zero-argument callable.
        self.get_rotate_angle = self._make_angle_sampler(rotate_angle)

    def _make_angle_sampler(self, rotate_angle):
        """Return a zero-argument callable producing the angle, or ``None``."""
        if rotate_angle == 'random':
            # torch.randint needs an explicit size; draw one value per call.
            # ``self.generator`` is looked up at call time on purpose.
            return lambda: torch.randint(0, 180, (1,), generator=self.generator).item()
        if isinstance(rotate_angle, (float, int)):
            return lambda: rotate_angle
        return None

    def _image_path(self, idx):
        """Return the path of the ``.npy`` file backing item ``idx``."""
        stem = os.path.splitext(self.files[idx])[0]
        # Joining the components individually keeps an empty ``data_path``
        # relative instead of producing a path at the filesystem root.
        return os.path.join(self.data_path, 'high_res_90', stem + ".npy")

    def __getitem__(self, idx):
        """Return ``(low_res_image, high_res_image, info)`` for item ``idx``.

        ``info`` is a dict with the keys ``file`` (the entry of ``files``) and
        ``rotation_angle`` (the applied angle, or ``None`` without rotation).
        """
        image = np.load(self._image_path(idx))
        tensor_img = torch.as_tensor(image, dtype=torch.float32)

        with torch.no_grad():
            rotation_angle = None
            if self.get_rotate_angle is not None:
                rotation_angle = self.get_rotate_angle()
                # Only the dimension added for the call is removed, so the
                # output shape does not depend on whether rotation is enabled.
                tensor_img = rotate(
                    tensor_img.unsqueeze(0),
                    angle=-rotation_angle,
                    interpolation=self.interpolation_mode
                ).squeeze(0)

            if self.num_defects:
                tensor_img = add_defects(target=tensor_img, num_defects=self.num_defects)

            low_res_image = self.low_res_filter(tensor_img)

        return low_res_image, tensor_img, dict(file=self.files[idx], rotation_angle=rotation_angle)

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)


class TomogramsMAT(torch.utils.data.Dataset):
    """Dataset of (blurred, sharp) slices read from the ``tomogram_delta`` entry.

    Every item is one slice ``files[idx]`` (optionally cropped), which is then
    optionally resized, rotated, range-normalized and passed through
    ``add_defects``. A Gaussian-blurred copy of the result is the
    low-resolution input.

    Args:
        mat_file: Object whose ``get('tomogram_delta')`` returns the indexable
            volume (the script at the bottom of the file passes an open
            ``h5py.File``). It must stay usable while items are read.
        kernel_size: Kernel size forwarded to the Gaussian blur.
        sigma: Sigma forwarded to the Gaussian blur.
        num_defects: Forwarded to ``add_defects``; a falsy value skips it.
        crop_xoffset: Start of the crop along axis 1 of the volume.
        crop_yoffset: Start of the crop along axis 2 of the volume.
        crop_width: Side length of the square crop. Cropping is applied only
            when all three crop arguments are given.
        normalize_range: If true, shift and scale each slice to ``[0, 1]``.
        rescale_size: Target size for resizing; an int ``n`` means ``(n, n)``.
        rotate_angle: Fixed rotation angle, or ``None`` for no rotation.
        with_transform: If true, the pair is passed through ``self.transform``;
            otherwise a leading dimension is added to both images.
    """

    def __init__(self, mat_file, kernel_size=21, sigma=50, num_defects=None, crop_xoffset=None, crop_yoffset=None,
                 crop_width=None, normalize_range=False, rescale_size=None, rotate_angle=None, with_transform=True):
        super().__init__()
        self.files = mat_file.get('tomogram_delta')
        # The blur expects a leading channel dimension, which is added for the
        # call and removed again afterwards.
        self.low_res_filter = lambda image: torch_gaussian_blur(image.unsqueeze(0), kernel_size=kernel_size,
                                                                sigma=sigma).squeeze(0)
        self.num_defects = num_defects
        self.crop_range = None
        self.normalize_range = normalize_range

        if None not in (crop_xoffset, crop_yoffset, crop_width):
            self.crop_range = (crop_xoffset, crop_xoffset + crop_width, crop_yoffset, crop_yoffset + crop_width)

        self.rescale_size = self._normalize_rescale_size(rescale_size)
        self.rotate_angle = rotate_angle

        self.with_transform = with_transform
        # NOTE(review): PairedTransform is neither defined nor imported in the
        # visible part of this file - confirm where it comes from.
        self.transform = PairedTransform(
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToPILImage(),
                transforms.RandomAffine(degrees=90, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=0,
                                        interpolation=InterpolationMode.BILINEAR),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.RandomResizedCrop(size=(1024, 1024), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
                transforms.ToTensor()
            ])
        )

    @staticmethod
    def _normalize_rescale_size(rescale_size):
        """Expand an int size to ``(size, size)``; return anything else as is."""
        if isinstance(rescale_size, int):
            return (rescale_size, rescale_size)
        return rescale_size

    @staticmethod
    def _normalize_to_unit_range(tensor_img):
        """Shift the minimum to 0 and, unless the image is constant, scale the maximum to 1."""
        shifted = tensor_img - torch.min(tensor_img)
        peak = torch.max(shifted)
        # A constant image has a zero peak; dividing would yield NaN.
        if peak > 0:
            shifted = shifted / peak
        return shifted

    def __getitem__(self, idx):
        """Return ``(low_res_image, high_res_image, str(idx))`` for slice ``idx``."""
        if self.crop_range is None:
            tensor_img = torch.tensor(self.files[idx])
        else:
            # The "x" bounds index axis 1 and the "y" bounds axis 2 of the volume.
            x_start, x_stop, y_start, y_stop = self.crop_range
            tensor_img = torch.tensor(self.files[idx, x_start:x_stop, y_start:y_stop])

        with torch.no_grad():
            # resize/rotate get a temporary leading dimension; only that
            # dimension is removed again afterwards.
            if self.rescale_size is not None:
                tensor_img = resize(tensor_img.unsqueeze(0), size=self.rescale_size, antialias=True).squeeze(0)

            if self.rotate_angle is not None:
                tensor_img = rotate(
                    tensor_img.unsqueeze(0),
                    angle=self.rotate_angle,
                    interpolation=InterpolationMode.BILINEAR
                ).squeeze(0)

            if self.normalize_range:
                tensor_img = self._normalize_to_unit_range(tensor_img)

            if self.num_defects:
                tensor_img = add_defects(target=tensor_img, num_defects=self.num_defects)
            low_res_image = self.low_res_filter(tensor_img)

        if self.with_transform:
            low_res_image, tensor_img = self.transform(low_res_image, tensor_img)
        else:
            low_res_image, tensor_img = low_res_image.unsqueeze(0), tensor_img.unsqueeze(0)

        return low_res_image, tensor_img, str(idx)

    def __len__(self):
        """Return the number of slices in the volume."""
        return len(self.files)


if __name__ == '__main__':
    # Manual smoke test: open the tomogram file and print the first sample.
    import lovely_tensors as lt

    lt.monkey_patch()
    import h5py
    # Not referenced below; presumably imported for its side effect of
    # registering HDF5 compression filters - TODO confirm.
    import hdf5plugin

    h5filepath = "../../data/p17299/tomogram_delta.mat"

    # The dataset reads from the file on access, so the sample is fetched
    # while the file is still open.
    with h5py.File(h5filepath, "r") as mat_file:
        dataset = TomogramsMAT(mat_file)
        first_sample = dataset[0]
        print(first_sample)
