import h5py
import torch

from torch.utils.data import Dataset
from torchvision.transforms.functional import resize, InterpolationMode
from torchvision import transforms
from chip.utils.fourier import fft_2D, ifft_2D
import torch.nn.functional as F
import numpy as np

class PairedTransform:
    """Apply one transform to two images using identical randomness.

    Before each of the two transform calls the global torch RNG is re-seeded
    with the same value, so any random parameters the transform draws are
    shared by both images. Note that this leaves the global torch RNG in a
    seeded state after the call.
    """

    def __init__(self, transform):
        # Callable applied to each image of the pair.
        self.transform = transform

    def __call__(self, img1, img2, seed=None):
        # Draw a shared seed from the global RNG unless the caller fixed one.
        shared_seed = torch.randint(0, 2 ** 32, ()) if seed is None else seed
        results = []
        for image in (img1, img2):
            torch.manual_seed(shared_seed)
            results.append(self.transform(image))
        return tuple(results)


def contrast_transform(target, scale=10):
    """Boost contrast with a sigmoid centered at 0.5, then rescale to [0, 1].

    Args:
        target: tensor of intensities. The sigmoid is centered at 0.5, so
            values are presumably expected roughly in [0, 1] (not enforced).
        scale: steepness of the sigmoid; larger values push the output
            toward a binary image.

    Returns:
        A new tensor with the same shape, min-max normalized over all
        elements. A constant input (zero dynamic range) yields all zeros
        instead of NaN.
    """
    cont_target = torch.sigmoid((target - 0.5) * scale)
    cont_target = cont_target - torch.min(cont_target)
    peak = torch.max(cont_target)
    # Guard against a flat image: dividing by a zero range would produce NaN.
    if peak > 0:
        cont_target = cont_target / peak
    return cont_target


def add_gray_background(img, mask=None):
    """Blend ``img`` halfway toward 1.0 wherever ``mask`` is set.

    Args:
        img: image tensor. When no mask is given, a centered disk of diameter
            ``img.shape[1]`` is used, so the image is presumably square in its
            last two dims (e.g. (C, W, W) or (W, W)) — TODO confirm with callers.
        mask: optional boolean/0-1 tensor broadcastable to ``img``.

    Returns:
        ``img + 0.5 * mask * (1 - img)``, i.e. masked pixels move halfway to 1.
    """
    if mask is None:
        w = img.shape[1]
        axis = torch.arange(w, device=img.device)
        # Explicit 'ij' indexing: same result as the legacy default, without
        # torch's deprecation warning for omitting ``indexing``.
        rows, cols = torch.meshgrid(axis, axis, indexing='ij')
        mask = (rows - w / 2) ** 2 + (cols - w / 2) ** 2 <= (w / 2) ** 2
    return img + 0.5 * mask * (1. - img)


def transform_to_gray(image, alpha=0.25, sigma=50, noise_std=0.1, size=512):
    """Map an image onto a gray level, add Fourier-domain noise and low-pass it.

    Args:
        image: real image tensor; its last two dims must broadcast against
            the ``size`` x ``size`` Gaussian filter — TODO confirm filter shape.
        alpha: gray offset; intensities become ``alpha + (1 - alpha) * (image + 0.1)``
            before normalization by their maximum.
        sigma: width parameter passed to ``create_gaussian_filter``.
        noise_std: scale of the random noise added to the Fourier coefficients.
        size: size passed to ``create_gaussian_filter``.

    Returns:
        Real part of the inverse FFT of the noisy, Gaussian-filtered spectrum.
    """
    # Imported locally: the helper is not imported at module level (only in
    # the ``__main__`` block), so without this the function fails with
    # NameError when used from another module.
    from chip.utils.utils import create_gaussian_filter

    gaussian_filter = create_gaussian_filter(size=size, sigma=sigma)
    # Out-of-place normalization so the helper's returned tensor is not mutated.
    gaussian_filter = gaussian_filter / torch.max(gaussian_filter)
    gray = alpha + (1 - alpha) * (image + 0.1)
    gray /= torch.max(gray)
    fft_gray = fft_2D(gray, ortho=True)
    fft_gray += noise_std * torch.randn_like(fft_gray)
    gray = ifft_2D(fft_gray * gaussian_filter, ortho=True).real
    return gray


def to_synthetic(target, size=None, scale=20):
    """Turn a 2-D image into a high-contrast synthetic target in [0, 1].

    Args:
        target: 2-D tensor (H, W); the ``[None, None]`` / ``[0, 0]`` round
            trip through ``F.interpolate`` only supports rank 2.
        size: optional output (H, W). Defaults to the input's own size, in
            which case the bilinear ``align_corners=True`` resample is an
            identity copy.
        scale: steepness of the sigmoid contrast curve (centered at 0.5).

    Returns:
        A new tensor, min-max normalized to [0, 1]. All-zero or flat inputs
        yield zeros instead of NaN. ``target`` itself is not modified.
    """
    if size is None:
        size = target.shape[-2:]
    out = F.interpolate(target[None, None], size=size, mode='bilinear', align_corners=True)[0, 0]
    peak = torch.max(out)
    # Skip normalization for an all-zero image to avoid 0/0 -> NaN.
    if peak != 0:
        out = out / peak
    out = torch.sigmoid((out - 0.5) * scale)
    out = out - torch.min(out)
    span = torch.max(out)
    if span > 0:
        out = out / span
    return out


# class ProjectionDataset(Dataset):
#     def __init__(self, file):
#         super().__init__()
#         data = h5py.File(file, 'r')
#         self.projections = data.get('sinogram')
#         self.theta = torch.Tensor(np.array(data.get('theta'))).flatten()
#         self._sorted_order = torch.argsort(self.theta)
#
#     def __getitem__(self, idx):
#         idx = self._sorted_order[idx]
#         hr_projection = torch.Tensor(self.projections[idx])
#         return hr_projection, idx
#
#     def __len__(self):
#         return self.projections.shape[0]



class ProjectionDataset(Dataset):
    """Dataset of ``(lr_image, hr_image, index)`` triplets read from HDF5.

    Samples are served in ascending order of their rotation angle. The
    low-resolution image comes either from a second HDF5 file (``lr_path``,
    transformed jointly with the high-resolution image) or, when no such file
    is given, from ``lr_forward_function`` applied to the transformed
    high-resolution image.
    """

    def __init__(self, path, complex_projections:bool=False, lr_forward_function=None, lr_path=None,
                 rescale=None, clip_range=False, normalize_range=False, rotation_angle=None,
                 contrast=None, train_transform=False, crop=None):
        """
        Args:
            path: HDF5 file with high-resolution projections. For complex data it
                is read via the keys 'complex_projections_r', 'complex_projections_i',
                'rotation_angle' and 'laminography_angle'; otherwise 'sinogram' and 'theta'.
            complex_projections: read complex projections and serve their phase.
            lr_forward_function: callable producing the lr image from the
                transformed hr image; used only when ``lr_path`` is not given.
            lr_path: optional HDF5 file with matching low-resolution projections,
                read with the same keys as ``path``.
            rotation_angle, crop, rescale, normalize_range, clip_range, contrast,
            train_transform: enable the matching ``add_*`` transform, appended in
                that order.
        """
        super().__init__()

        # Files stay open for the dataset's lifetime: h5py datasets are read
        # lazily in __getitem__. Handles are kept so ``close()`` can release them.
        hr_data = h5py.File(path, 'r')
        self._hr_file = hr_data
        self._lr_file = None

        self.complex_projections = complex_projections
        self.hr_imag = None

        if complex_projections:
            self.hr_real = hr_data.get('complex_projections_r')
            self.hr_imag = hr_data.get('complex_projections_i')
            self.theta = torch.Tensor(np.array(hr_data.get('rotation_angle'))).flatten()
            self.laminography_angle = torch.tensor(hr_data['laminography_angle'][0][0])
        else:
            self.hr_real = hr_data.get('sinogram')
            self.theta = torch.Tensor(np.array(hr_data.get('theta'))).flatten()

        # Sort by angle; _sorted_order maps sorted position -> row in the file.
        self.theta, self._sorted_order = torch.sort(self.theta)
        self.theta -= self.theta.min()

        self.lr_tomogram = None
        self.lr_imag = None
        if lr_path:
            lr_data = h5py.File(lr_path, 'r')
            self._lr_file = lr_data
            if complex_projections:
                self.lr_real = lr_data.get('complex_projections_r')
                self.lr_imag = lr_data.get('complex_projections_i')
            else:
                self.lr_real = lr_data.get('sinogram')
            # __getitem__ reads low-resolution images from ``lr_tomogram``;
            # without this assignment the lr file would be silently ignored.
            self.lr_tomogram = self.lr_real

        self.lr_forward_function = lr_forward_function

        self.transforms = []
        if rotation_angle: self.add_rotation(angle=rotation_angle)
        if crop is not None: self.add_crop(*crop)
        if rescale: self.add_scale(width=rescale)
        if normalize_range: self.add_normalize_range()
        if clip_range: self.add_clip_range()
        if contrast: self.add_contrast(scale=contrast)
        if train_transform: self.add_train_transform()

    @staticmethod
    def _load_projection(real, imag, idx):
        """Read row ``idx`` as a float tensor with a leading channel dimension.

        When ``imag`` is given, the phase of ``real + 1j * imag`` is returned
        (the phase is the quantity to be reconstructed).
        """
        image = torch.Tensor(real[idx])
        if imag is not None:
            image = torch.angle(torch.complex(image, torch.Tensor(imag[idx])))
        if image.dim() == 2:
            image = image.unsqueeze(0)
        return image

    def __getitem__(self, idx):
        # Convert the sorted position to a plain int row index for h5py.
        idx = int(self._sorted_order[idx])
        hr_image = self._load_projection(self.hr_real, self.hr_imag, idx)

        if self.lr_tomogram is not None:
            lr_image = self._load_projection(self.lr_tomogram, self.lr_imag, idx)
            # Same random transform parameters for both images.
            hr_image, lr_image = PairedTransform(self.transform)(hr_image, lr_image)
        else:
            hr_image = self.transform(hr_image)
            lr_image = self.lr_forward_function(hr_image)

        return lr_image, hr_image, idx

    @property
    def transform(self):
        """Compose of all registered transforms, in registration order."""
        return transforms.Compose(self.transforms)

    def add_rotation(self, angle=30, interpolation='bilinear'):
        """Append a fixed rotation by ``angle`` degrees ('bilinear' or 'nearest')."""
        interpolation_mode = {
            'bilinear': InterpolationMode.BILINEAR,
            'nearest': InterpolationMode.NEAREST
        }[interpolation]

        self.transforms.append(
            transforms.RandomAffine((angle, angle), (0, 0), (1., 1.), interpolation=interpolation_mode))

    def add_gray_background(self):
        """Append the module-level ``add_gray_background`` transform."""
        self.transforms.append(add_gray_background)

    def add_contrast(self, scale=10):
        """Append a sigmoid contrast boost (see ``contrast_transform``)."""
        def contrast(image):
            return contrast_transform(image, scale=scale)

        self.transforms.append(contrast)

    def add_scale(self, width=512):
        """Append an aspect-preserving resize.

        Note: ``width`` is used as the output size of dim -2 (rows); dim -1 is
        scaled by the same factor. All singleton dims are squeezed afterwards.
        """
        def scale(image):
            height = round(image.shape[-1] * width / image.shape[-2])
            return resize(image.unsqueeze(0), size=(width, height), antialias=True).squeeze()

        self.transforms.append(scale)

    def add_normalize_range(self):
        """Append in-place division by the image maximum."""
        def normalize_range(image):
            image /= torch.max(image)
            return image

        self.transforms.append(normalize_range)

    def add_clip_range(self):
        """Append clipping of values to [0, 1]."""
        def clip_range(image):
            return torch.clip(image, 0, 1)

        self.transforms.append(clip_range)

    def add_crop(self, xoffset, yoffset, width):
        """Append a square crop; ``xoffset`` indexes dim -2, ``yoffset`` dim -1.

        Singleton dims are squeezed before cropping.
        """
        def crop(image):
            return image.squeeze()[xoffset:xoffset + width, yoffset:yoffset + width]

        self.transforms.append(crop)

    def add_train_transform(self):
        """Append a random rotation in [-180, 180] degrees with scale in [0.6, 1]."""
        self.transforms.append(
            transforms.RandomAffine((-180, 180), (0, 0), (0.6, 1.), interpolation=InterpolationMode.BILINEAR))

    def close(self):
        """Close the underlying HDF5 files; the dataset is unusable afterwards."""
        for handle in (self._hr_file, self._lr_file):
            if handle is not None:
                handle.close()

    def __len__(self):
        return len(self.hr_real)


if __name__ == '__main__':
    # Smoke test: load the first sample, using a Fourier low-pass forward model
    # to synthesize the low-resolution image. Uses the ProjectionDataset
    # defined in this file (not an installed copy of the package).
    from chip.utils.utils import create_circle_filter, create_gaussian_filter
    from chip.models.forward_models import fourier_filtering

    import lovely_tensors as lt

    lt.monkey_patch()

    DATA_PATH = '/mydata/chip/shared/data' if torch.cuda.is_available() else 'data'

    # Two candidate low-pass filters; switch ``current_filter`` to compare them.
    frequency_cut_out_radius = 30
    circle_filter = create_circle_filter(frequency_cut_out_radius, 512)
    gaussian_filter = create_gaussian_filter(sigma=15, size=512)
    current_filter = gaussian_filter

    kwargs = {
        'path': f'{DATA_PATH}/PyXL_projections_chip_aligned.h5',
        'complex_projections': True,
        'lr_forward_function': lambda x: fourier_filtering(x, current_filter),
        'train_transform': False,
        'rescale': 512
    }

    ds = ProjectionDataset(**kwargs)
    print(ds[0])

