import os
from PIL import Image

import numpy as np
import jax
import jax.numpy as jnp
import torch
from torch.utils.data import Dataset

import torchvision.transforms as transforms


class ImageDataset(Dataset):
    r"""
    A generic image dataloader over a flat directory of ``.png`` files.

    Item ``i`` is ``<root_dir_target>/<name>.png`` opened with PIL, converted
    to RGB and passed through the composed ``transform``.

    Args:
        root_dir_target: Directory scanned (non-recursively) for ``.png`` files.
        transform: A single transform or a list of transforms (composed in
            order). A list argument is copied, never mutated.
        grayscale: If True, ``transforms.Grayscale()`` is appended after the
            given transforms.
    """
    def __init__(self, root_dir_target, transform=transforms.ToTensor(), grayscale=False):
        self.root_dir_target = root_dir_target
        # Copy so appending Grayscale below never mutates a list the caller
        # may be sharing with another dataset.
        transform_list = list(transform) if isinstance(transform, list) else [transform]
        if grayscale:
            transform_list.append(transforms.Grayscale())
        self.transform = transforms.Compose(transform_list)

        # Strip exactly the '.png' suffix so stems containing dots survive the
        # round trip in __getitem__; sort because os.listdir order is arbitrary.
        suffix = '.png'
        self.image_names = sorted(
            name[:-len(suffix)] for name in os.listdir(root_dir_target) if name.endswith(suffix)
        )

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        target_image_path = os.path.join(self.root_dir_target, f"{image_name}.png")

        # The context manager releases the file handle; convert() returns a
        # new image object that remains usable after the file is closed.
        with Image.open(target_image_path) as image:
            target_image = image.convert('RGB')

        if self.transform:
            # Draw a seed from torch's global RNG and reseed with it, mirroring
            # CustomDataset (where the seed is reused to synchronise a pair).
            seed = torch.randint(0, 2 ** 32, size=(1,)).item()
            torch.manual_seed(seed)
            target_image = self.transform(target_image)

        return target_image


class CustomDataset(Dataset):
    r"""
    An image dataloader loading pairs of input / target images.

    Inputs are the ``.jpg`` files in ``root_dir_input``; each is paired with
    the ``.png`` of the same stem in ``root_dir_target``. Both images are
    converted to RGB and run through the same composed transform after
    reseeding torch with the same seed. Random transforms that draw from
    torch's global RNG therefore pick matching parameters for the pair.

    Args:
        root_dir_input: Directory scanned (non-recursively) for ``.jpg`` inputs.
        root_dir_target: Directory holding the matching ``.png`` targets.
        transform: A single transform or a list of transforms (composed in
            order). A list argument is copied, never mutated.
        grayscale: If True, ``transforms.Grayscale()`` is appended after the
            given transforms.
    """
    def __init__(self, root_dir_input, root_dir_target, transform=transforms.ToTensor(), grayscale=False):
        self.root_dir_input = root_dir_input
        self.root_dir_target = root_dir_target
        # Copy so appending Grayscale below never mutates a list the caller
        # may be sharing with another dataset (e.g. train and test).
        transform_list = list(transform) if isinstance(transform, list) else [transform]
        if grayscale:
            transform_list.append(transforms.Grayscale())
        self.transform = transforms.Compose(transform_list)

        # Strip exactly the '.jpg' suffix so stems containing dots survive the
        # round trip in __getitem__; sort because os.listdir order is arbitrary.
        suffix = '.jpg'
        self.image_names = sorted(
            name[:-len(suffix)] for name in os.listdir(root_dir_input) if name.endswith(suffix)
        )

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]

        input_image_path = os.path.join(self.root_dir_input, f"{image_name}.jpg")
        target_image_path = os.path.join(self.root_dir_target, f"{image_name}.png")

        # Context managers release the file handles; convert() returns new
        # image objects that remain usable after the files are closed.
        with Image.open(input_image_path) as image:
            input_image = image.convert('RGB')
        with Image.open(target_image_path) as image:
            target_image = image.convert('RGB')

        if self.transform:
            # Reseed with the same value before each call so random transforms
            # make identical draws for input and target.
            seed = torch.randint(0, 2 ** 32, size=(1,)).item()
            torch.manual_seed(seed)
            input_image = self.transform(input_image)
            torch.manual_seed(seed)
            target_image = self.transform(target_image)

        return input_image, target_image


class ToyfastMRI(Dataset):
    r"""
    Toy MRI dataset built from pre-extracted ``.npy`` slices.

    Files in ``root_dir_target`` whose name contains ``'slice_train'`` (or
    ``'slice_test'`` when ``train=False``) are used. Each item simulates
    undersampled acquisition with a fixed vertical-line k-space mask.

    Args:
        root_dir_target: Directory containing the ``.npy`` slices.
        transform: Stored on the instance; not applied by ``__getitem__``.
        train: Select training (True) or test (False) slices.
    """
    def __init__(self, root_dir_target='pth/to/dataset/fastMRI/',
                 transform=None, train=True):
        self.root_dir_target = root_dir_target
        self.transform = transform
        tag = 'slice_train' if train else 'slice_test'
        # splitext strips only the final extension; sorted for a deterministic
        # index order (os.listdir order is arbitrary).
        self.image_names = sorted(
            os.path.splitext(name)[0] for name in os.listdir(root_dir_target) if tag in name
        )

    def create_vertical_lines_mask(self, image_shape=(320, 320), acceleration_factor=4, prng=None):
        """Build a binary mask that selects whole columns (vertical lines).

        The mask keeps a contiguous central band of columns plus
        ``num_lines_side // 2`` columns drawn uniformly without replacement.
        The random columns may overlap the band.

        Args:
            image_shape: Mask shape; columns are indexed along the last axis.
            acceleration_factor: 4 or 8; selects the central/side fractions.
            prng: JAX PRNG key for the random columns. Defaults to
                ``jax.random.PRNGKey(0)``, the key ``__getitem__`` uses.

        Returns:
            A float32 array of shape ``image_shape`` with entries in {0, 1}.

        Raises:
            ValueError: If ``acceleration_factor`` is not 4 or 8.
        """
        if acceleration_factor == 4:
            central_lines_percent = 0.08
            side_lines_percent = 0.25 - central_lines_percent
        elif acceleration_factor == 8:
            central_lines_percent = 0.04
            side_lines_percent = 0.125 - central_lines_percent
        else:
            raise ValueError(f"acceleration_factor must be 4 or 8, got {acceleration_factor}")
        if prng is None:
            prng = jax.random.PRNGKey(0)

        num_cols = image_shape[-1]
        num_lines_center = int(central_lines_percent * num_cols)
        num_lines_side = int(side_lines_percent * num_cols)

        mask = jnp.zeros(image_shape, dtype=jnp.float32)
        # Central band start..stop, both inclusive. arange (rather than linspace
        # with its default 50 samples) keeps every column even for wide bands.
        start = num_cols // 2 - num_lines_center // 2
        stop = num_cols // 2 + num_lines_center // 2 + 1
        center_line_indices = jnp.arange(start, stop + 1, dtype=jnp.int32)
        mask = mask.at[:, center_line_indices].set(1.0)
        random_line_indices = jax.random.choice(prng, num_cols, shape=(num_lines_side // 2,), replace=False)
        mask = mask.at[:, random_line_indices].set(1.0)
        return mask

    def forward_op(self, data, mask, norm="ortho"):
        """Masked 2-D FFT over the last two axes.

        The mask is applied in the fftshift-ed (centred) domain; the output is
        returned in un-shifted layout.
        """
        data = jnp.fft.fftn(data, axes=(-2, -1), norm=norm)
        data = jnp.fft.fftshift(data, axes=[-2, -1])
        out = data * mask
        out = jnp.fft.ifftshift(out, axes=[-2, -1])
        return out

    def backward_op(self, data, mask, norm="ortho"):
        """Adjoint of ``forward_op`` (real part), from k-space to image.

        The steps undo ``forward_op`` in reverse order: re-centre with
        ``fftshift``, mask, undo the centring with ``ifftshift``, then inverse
        FFT. fftshift and ifftshift coincide on even-sized axes, but this order
        is needed for a true adjoint on odd-sized axes.
        """
        data = jnp.fft.fftshift(data, axes=[-2, -1])
        data = mask * data
        data = jnp.fft.ifftshift(data, axes=[-2, -1])
        # jnp (not np) so the operator stays traceable under jax.jit / grad.
        out = jnp.fft.ifftn(data, axes=(-2, -1), norm=norm)
        return jnp.real(out)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, measurements, mask, backprojection)`` as NumPy arrays.

        The stored slice gets two leading singleton axes and is then indexed at
        ``[0]``. Presumably the slice is 2-D and broadcast-compatible with the
        default 320x320 mask; verify against the data.
        """
        image_name = self.image_names[idx]
        target_image_path = os.path.join(self.root_dir_target, f"{image_name}.npy")

        im_slice = np.load(target_image_path)
        jnp_im_slice = jnp.array(im_slice[np.newaxis, np.newaxis, ...])

        # Fixed key on purpose: every item must share one mask, because the
        # downstream architecture depends on the mask.
        prng = jax.random.PRNGKey(0)
        mask = self.create_vertical_lines_mask(prng=prng)

        measurements = self.forward_op(jnp_im_slice, mask)
        backproj = self.backward_op(measurements, mask)

        return np.array(jnp_im_slice[0]), np.array(measurements[0]), np.array(mask), np.array(backproj[0])



def get_train_test_datasets(grayscale=False, fastMRI=False, set3C=False):
    """Build the ``(train, test)`` dataset pair for the selected benchmark.

    Args:
        grayscale: Forwarded to ``CustomDataset`` (BSDS500 and Set3C only).
        fastMRI: Use the toy fastMRI slices. Ignored when ``set3C`` is True.
        set3C: Use Set3C as a test-only dataset; the train dataset is ``None``.

    Returns:
        ``(dataset_train, dataset_test)``.
    """
    if not fastMRI and not set3C:
        root_dir_input_train = 'pth/to/BSDS500/0/'
        root_dir_target_train = 'pth/to/BSDS500/TV/'

        root_dir_input_test = 'pth/to/BSDS500/1/'
        root_dir_target_test = 'pth/to/BSDS500/TV_val/'

        # Define transformations
        transform_list = [
            transforms.RandomAffine(180),
            # BSDS images are 321x481, so a 5-15% area crop still holds meaningful content.
            transforms.RandomResizedCrop(size=(32, 32), scale=(0.05, 0.15)),
            transforms.ColorJitter(),
            transforms.ToTensor(),  # Convert images to tensors
        ]

        # Pass each dataset its own copy: with grayscale=True, CustomDataset may
        # append Grayscale to the list it receives. A shared list would then
        # collect one Grayscale per dataset.
        dataset_train = CustomDataset(root_dir_input_train, root_dir_target_train,
                                      transform=list(transform_list), grayscale=grayscale)
        dataset_test = CustomDataset(root_dir_input_test, root_dir_target_test,
                                     transform=list(transform_list), grayscale=grayscale)
    elif set3C:
        root_dir_input_test = '/pth/to/Set3C/'
        root_dir_target_test = '/pth/to/Set3C/'

        # Define transformations
        transform_list = [
            transforms.ToTensor(),  # Convert images to tensors
        ]

        # Set3C is evaluation-only: there is no training split.
        dataset_train = None
        dataset_test = CustomDataset(root_dir_input_test, root_dir_target_test,
                                     transform=list(transform_list), grayscale=grayscale)
    else:
        dataset_train = ToyfastMRI(train=True)
        dataset_test = ToyfastMRI(train=False)

    return dataset_train, dataset_test
