"""BSDS300 dataset: in-memory image patches and their blurred/downsampled measurements."""

import torch
from PIL import Image
import numpy as np

import os
import math

from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive

from .operators import Operator


class BSDS(VisionDataset):
    """Berkeley Segmentation Dataset (BSDS300) patches for image restoration.

    A dataset classically (ab)used for image denoising.
    See https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/

    Every image of the chosen split is loaded once and cut into non-overlapping
    ``patch_size`` x ``patch_size`` patches, which are kept in memory. A linear
    forward operator is applied to every patch once, at construction time. The
    operator is a depthwise Gaussian blur with stride ``downsampling``. It becomes
    the identity when ``sigma`` is not positive and ``downsampling`` is 1. Noise is
    added on the fly in ``__getitem__``.

    Args:
        root (string): Root directory of the dataset. The ``BSDS300`` folder either
            exists there or will be extracted there if ``download`` is True.
        split (string, optional): One of ``'train'``, ``'test'`` or ``'test68'``.
        transform (callable, optional): Applied to each (optionally grayscale) PIL
            image before patching, e.g. ``transforms.ToTensor()``. If the result is
            still a PIL image (including when ``transform`` is None), it is converted
            to a float ``(C, H, W)`` tensor in [0, 1].
        target_transform (callable, optional): Forwarded to ``VisionDataset``. It is
            not applied anywhere in this class.
        download (bool, optional): If True, downloads and extracts the archive into
            ``root``. Nothing is downloaded if the ``BSDS300`` folder already exists.
        patch_size (int, optional): Side length of the square patches. It must be
            divisible by ``downsampling``.
        grayscale (bool, optional): Convert images to single-channel (``'L'``) first.
        noise_level (float or None, optional): Standard deviation of the Gaussian noise
            added to each measurement in ``__getitem__``. ``None`` disables the noise.
        clip_to_realistic (bool, optional): Clamp patches and noisy measurements to [0, 1].
        sigma (float, optional): Standard deviation, in pixels, of the Gaussian blur kernel.
        downsampling (int, optional): Stride of the forward operator. It is cast to
            ``int`` and must be at least 1.
    """

    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-images.tgz"
    filename = "BSDS300-images.tgz"
    tgz_md5 = '5c3e983f27eb33a1d62ff55ceac61e93'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False, patch_size=17, grayscale=False,
                 noise_level=0.25, clip_to_realistic=True, sigma=5, downsampling=4):

        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split  # one of 'train', 'test', 'test68'
        self.patch_size = patch_size
        self.grayscale = grayscale
        self.noise_level, self.clip_to_realistic, self.sigma = noise_level, clip_to_realistic, sigma
        self.downsampling = int(downsampling)  # Conv2d strides must be integers

        if self.downsampling < 1:
            raise ValueError(f'downsampling must be a positive integer, got {downsampling}.')
        # The adjoint (Unstride) multiplies the measurement size by `downsampling`.
        # That only recovers patch_size when the division has no remainder.
        if self.patch_size % self.downsampling != 0:
            raise ValueError('For technical reasons, assure that patch_size is divisible by downsampling rate.')

        if download:
            self.download()

        self.path = os.path.join(self.root, 'BSDS300')

        if not self._check_existence():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

        self.files, self.folder, self.format = self._read_split(split)

        self.operator = self._create_operator()
        self.measurements, self.data = self._create_patches_in_memory()

    def _read_split(self, split):
        """Return ``(file_ids, image_subfolder, file_extension)`` for ``split``.

        Raises:
            ValueError: If ``split`` is not one of 'train', 'test', 'test68'.
        """
        if split == 'train':
            list_file, folder, extension = os.path.join(self.path, 'iids_train.txt'), 'train', '.jpg'
        elif split == 'test':
            list_file, folder, extension = os.path.join(self.path, 'iids_test.txt'), 'test', '.jpg'
        elif split == 'test68':
            # NOTE(review): this list is resolved relative to the current working
            # directory, not `root`. Its entries are used verbatim as file names
            # (empty extension), so they presumably carry their own suffix.
            list_file, folder, extension = os.path.join('inverse_problems', 'foe_test.txt'), 'test', ''
        else:
            raise ValueError(f"Unknown split '{split}'; expected 'train', 'test' or 'test68'.")

        with open(list_file, 'r') as file:
            files = file.read().splitlines()
        return files, folder, extension

    def _load_image(self, file):
        """Open one image of the split and return it after grayscale conversion and ``transform``.

        A result that is still a PIL image is converted to a float ``(C, H, W)``
        tensor in [0, 1]. Otherwise the transform's output is returned as-is.
        """
        img_path = os.path.join(self.path, 'images', self.folder, file + self.format)
        # The context manager closes the file handle even if the transform raises.
        with Image.open(img_path) as source:
            img = source.convert('L') if self.grayscale else source
            if self.transform is not None:
                img = self.transform(img)
            if isinstance(img, Image.Image):
                # Convert while the file is still open, because PIL loads pixels lazily.
                img = self._pil_to_tensor(img)
        return img

    @staticmethod
    def _pil_to_tensor(img):
        """Convert a PIL image to a float32 ``(C, H, W)`` tensor, dividing by 255.

        The 1/255 scaling assumes 8-bit modes such as 'L' or 'RGB'.
        """
        array = np.asarray(img, dtype=np.float32) / 255.0
        array = array[None] if array.ndim == 2 else array.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(array))

    def _create_patches_in_memory(self):
        """Cut every image into non-overlapping patches and precompute their measurements.

        Returns:
            tuple: ``(measurements, data)``. ``data`` is a tensor of shape
            ``(N, C, patch_size, patch_size)``. ``measurements`` is a list holding
            the operator output for each patch, in the same order.
        """
        size = self.patch_size
        patches = []
        for file in self.files:
            img = self._load_image(file)  # expected to be a (C, H, W) tensor
            # Unfolding with step == size tiles the image without overlap. This gives
            # (C, H//size, W//size, size, size); border pixels that do not fill a
            # whole patch are dropped.
            tiles = img.unfold(1, size, size).unfold(2, size, size)
            patches.append(tiles.reshape(img.shape[0], -1, size, size).permute(1, 0, 2, 3))

        data = torch.cat(patches)
        if self.clip_to_realistic:
            data.clamp_(0, 1)
        # clone() gives each measurement its own storage even when the operator is the identity.
        measurements = [self.operator(patch[None, ...])[0].clone() for patch in data]
        return measurements, data

    @staticmethod
    def _fixed_depthwise_conv(channels, weight, kernel_size, stride):
        """Build a depthwise ``Conv2d`` whose non-trainable kernel is ``weight``.

        The convolution uses replicate padding of ``k // 2``.
        """
        conv = torch.nn.Conv2d(channels, channels, kernel_size, stride=stride,
                               padding=[k // 2 for k in kernel_size], groups=channels, bias=False,
                               padding_mode='replicate')
        # Swap the learnable Parameter for a buffer. It is then excluded from
        # .parameters(), so optimizers never update it.
        delattr(conv, 'weight')
        conv.register_buffer('weight', weight)
        return conv

    def _create_operator(self):
        """Build the forward operator A and its adjoint, wrapped in an ``Operator``.

        A is a depthwise Gaussian blur with stride ``downsampling``. The adjoint is
        zero-insertion upsampling (``Unstride``, only when ``downsampling > 1``),
        followed by a stride-1 convolution with the transposed kernel. If ``sigma``
        is not positive and ``downsampling`` is 1, both operators are the identity
        and ``norm_val`` is set to 1.
        """
        channels = 1 if self.grayscale else 3

        if not (self.sigma > 0 or self.downsampling > 1):
            op = Operator(operator=torch.nn.Identity(), adjoint_operator=torch.nn.Identity(),
                          dimension=2, channels=channels)
            op.norm_val = 1
            return op

        kernel, kernel_size = self._gaussian_kernel(channels)
        operator = self._fixed_depthwise_conv(channels, kernel, kernel_size, self.downsampling)
        # NOTE(review): the exact adjoint of a cross-correlation uses the spatially
        # flipped kernel. For the symmetric Gaussian built here, transpose, flip and
        # identity coincide. Replicate padding also makes the pair only approximately
        # adjoint near patch borders.
        conv_transposed = self._fixed_depthwise_conv(channels, kernel.transpose(2, 3), kernel_size, 1)

        if self.downsampling > 1:
            unstrider = Unstride(shape=[channels, self.patch_size, self.patch_size], stride=self.downsampling)
            adj_operator = torch.nn.Sequential(unstrider, conv_transposed)
        else:
            adj_operator = conv_transposed
        return Operator(operator=operator, adjoint_operator=adj_operator, dimension=2, channels=channels)

    def _gaussian_kernel(self, channels):
        """Return a normalized 2-D Gaussian kernel shaped as a depthwise conv weight.

        Adapted from
        https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10

        Returns:
            tuple: ``(kernel, kernel_size)``. ``kernel`` has shape
            ``(channels, 1, k, k)`` and sums to 1 per channel. ``kernel_size`` is
            ``[k, k]`` with ``k = 2 * ceil(sigma) + 1``.
        """
        # Radius of ceil(sigma) pixels around the center, i.e. the Gaussian is truncated at about one std.
        kernel_size = [math.ceil(self.sigma) * 2 + 1] * 2
        sigma = [self.sigma] * 2
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])

        kernel = 1
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # The 1e-8 keeps sigma == 0 from dividing by zero (that case yields a 1x1 kernel).
            kernel *= torch.exp(-((mgrid - mean) / (std + 1e-8)) ** 2 / 2)  # unnormalized kernel

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight: (channels, 1, k, k).
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
        return kernel, kernel_size

    def __getitem__(self, index: int):
        """Return a noisy measurement and its clean patch.

        Args:
            index (int): Index of the patch.
        Returns:
            tuple: ``(measurement, patch)``. ``measurement`` is a copy of the
            precomputed operator output plus fresh Gaussian noise of std
            ``noise_level``, clamped to [0, 1] if ``clip_to_realistic``. ``patch``
            is a copy of the clean patch.
        """
        measurements_patch, data_patch = self.measurements[index].clone(), self.data[index].clone()
        # Add noise
        if self.noise_level is not None:
            measurements_patch += self.noise_level * torch.randn_like(measurements_patch)
        # Clip to [0,1]
        if self.clip_to_realistic:
            measurements_patch.clamp_(0, 1)
        return measurements_patch, data_patch

    def __len__(self) -> int:
        return len(self.data)

    def _check_existence(self):
        """Return True if the ``BSDS300`` folder exists under ``root``; its contents are not checked."""
        return os.path.isdir(os.path.join(self.root, 'BSDS300'))

    def download(self) -> None:
        """Download and extract the archive (checked against ``tgz_md5``) unless already present."""
        if self._check_existence():
            print('Files already downloaded')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return f"Split: {self.split}"


class Unstride(torch.nn.Module):
    """Zero-insertion upsampling over the last two dimensions.

    Input element ``[..., i, j]`` is written to ``[..., i * stride, j * stride]`` of
    an all-zero output whose last two dimensions are ``stride`` times larger. This
    is the adjoint of keeping every ``stride``-th row and column.

    Args:
        shape: Stored on the module but not used by ``forward``.
        stride (int): Upsampling factor along the last two dimensions.
    """

    def __init__(self, shape, stride):
        super().__init__()
        self.stride = stride
        self.shape = shape  # kept for interface compatibility; forward ignores it

    def forward(self, inputs):
        factor = self.stride
        leading = tuple(inputs.shape[:-2])
        rows, cols = inputs.shape[-2], inputs.shape[-1]
        upsampled = inputs.new_zeros((*leading, rows * factor, cols * factor))
        upsampled[..., ::factor, ::factor] = inputs
        return upsampled


def augment_img(img, mode=0):
    '''Return one of the eight flip/rotation variants of ``img``.

    Adapted from Kai Zhang (github: https://github.com/cszn).

    Args:
        img (array_like): Image array. Flips and rotations act on its first two axes.
        mode (int): Which variant to return, 0-7.
            Modes 0, 5, 6 and 3 are ``np.rot90`` with k = 0, 1, 2 and 3.
            Modes 2, 1, 4 and 7 are the same rotations followed by ``np.flipud``.

    Returns:
        ``img`` itself for mode 0. For other modes, the numpy array produced by
        ``np.rot90``/``np.flipud``, which is typically a non-contiguous view.

    Raises:
        ValueError: If ``mode`` is not in 0-7, instead of silently returning None.
    '''
    if mode == 0:
        return img
    if mode == 1:
        return np.flipud(np.rot90(img))
    if mode == 2:
        return np.flipud(img)
    if mode == 3:
        return np.rot90(img, k=3)
    if mode == 4:
        return np.flipud(np.rot90(img, k=2))
    if mode == 5:
        return np.rot90(img)
    if mode == 6:
        return np.rot90(img, k=2)
    if mode == 7:
        return np.flipud(np.rot90(img, k=3))
    raise ValueError(f'mode must be an integer in [0, 7], got {mode!r}')


if __name__ == "__main__":
    # Ad-hoc dot-product (adjoint) test for the dataset's operator: for a correct
    # adjoint, <A u, v> - <u, A^T v> should be close to 0.
    import torchvision.transforms as transforms

    # NOTE(review): patch_size=250 is not divisible by downsampling=4, so the BSDS
    # constructor raises ValueError here. A multiple of 4 (e.g. 248) is required.
    dataset = BSDS('~/data', split='test68', download=True, patch_size=250,
                   transform=transforms.ToTensor(), grayscale=False,
                   noise_level=0.00, sigma=4, downsampling=4)
    # __getitem__ returns (measurement, clean patch): x is an operator output,
    # y is an image-space patch.
    x, y = dataset[0]
    # NOTE(review): A should be applied to the clean patch y and the adjoint to the
    # measurement x. As written, the roles are swapped, so the spatial sizes of the
    # products below do not match when downsampling > 1.
    Ax = dataset.operator(x[None, ...])
    # NOTE(review): BSDS never sets an `adj_operator` attribute. The adjoint module
    # is only passed into Operator inside _create_operator, so this line raises
    # AttributeError. It must be reached through the Operator API instead.
    ATy = dataset.adj_operator(y[None, ...])

    print((Ax * y).sum() - (ATy * x).sum())
