# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import torch as to
from numpy import ndarray
from torch import Tensor
from scipy.stats import multivariate_normal
from skimage.util.shape import view_as_windows
from typing import List, Callable, Tuple, Any


def mean_merger(values: Tensor, position_indices: Tensor = None) -> Tensor:
    """Reduce all estimates of a pixel to their arithmetic mean.

    :param values: estimates from overlapping patches for given pixel in reconstructed image
    :param position_indices: indices accompanying the estimates (estimate based on top-left
                             patch has index 1); accepted for interface compatibility and
                             not used by this merger
    :returns: mean over all entries of `values`
    """
    merged = values.mean()
    return merged


def max_merger(values: Tensor, position_indices: Tensor = None) -> Tensor:
    """Reduce all estimates of a pixel to the largest one.

    :param values: estimates from overlapping patches for given pixel in reconstructed image
    :param position_indices: indices accompanying the estimates; accepted for interface
                             compatibility and not used by this merger
    :returns: maximum over all entries of `values`
    """
    merged = values.max()
    return merged


def min_merger(values: Tensor, position_indices: Tensor = None) -> Tensor:
    """Reduce all estimates of a pixel to the smallest one.

    :param values: estimates from overlapping patches for given pixel in reconstructed image
    :param position_indices: indices accompanying the estimates; accepted for interface
                             compatibility and not used by this merger
    :returns: minimum over all entries of `values`
    """
    merged = values.min()
    return merged


def median_merger(values: Tensor, position_indices: Tensor = None) -> Tensor:
    """Reduce all estimates of a pixel to their median.

    :param values: estimates from overlapping patches for given pixel in reconstructed image
    :param position_indices: indices accompanying the estimates; accepted for interface
                             compatibility and not used by this merger
    :returns: median over all entries of `values` as computed by torch
    """
    merged = values.median()
    return merged


def variance_merger(values: Tensor, position_indices: Tensor = None) -> Tensor:
    """Reduce all estimates of a pixel to their variance.

    :param values: estimates from overlapping patches for given pixel in reconstructed image
    :param position_indices: indices accompanying the estimates; accepted for interface
                             compatibility and not used by this merger
    :returns: variance over all entries of `values` (torch default settings)
    """
    merged = values.var()
    return merged


def gaussian2d(
    no_bins_dim1: int,
    no_bins_dim2: int,
    lim_dim1: List[float] = [-1.0, 1.0],
    lim_dim2: List[float] = [-1.0, 1.0],
    mu: float = 0.0,
    sigma: float = 1.0,
    precision: to.dtype = to.float64,
    device: to.device = to.device("cpu"),
) -> Tuple[Tensor, Tensor, Tensor]:
    """Returns the pdf of a two-dimensional multivariate Gaussian distribution.

    The pdf is evaluated on a regular `no_bins_dim1 x no_bins_dim2` grid. Along each dimension
    the grid points start at the lower limit and are equally spaced; the upper limit itself is
    not part of the grid.

    :param no_bins_dim1: grid size in first dimension
    :param no_bins_dim2: grid size in second dimension
    :param lim_dim1: domain limits (lower, upper) of pdf in first direction
    :param lim_dim2: domain limits (lower, upper) of pdf in second direction
    :param mu: mean of pdf (same value for both dimensions)
    :param sigma: variance of pdf (covariance matrix is `sigma` times identity)
    :param precision: torch.dtype of output Tensors, defaults to to.float64.
    :param device: torch.device of output Tensor, defaults to to.device('cpu').
    :returns: tuple `(pdf, grd_dim1, grd_dim2)`, each of shape (no_bins_dim1, no_bins_dim2):
              pdf values and the grid coordinates in first and second dimension
    """
    # NOTE: the list defaults are only read, never mutated, so sharing them across calls is safe.

    # `linspace(..., endpoint=False)` yields exactly `no_bins` points start + i * step, i.e. the
    # same half-open grid as a stepped range, without the float round-off that can change the
    # number of points when the count is derived from (upper - lower) / step.
    axis_dim1 = np.linspace(lim_dim1[0], lim_dim1[1], no_bins_dim1, endpoint=False)
    axis_dim2 = np.linspace(lim_dim2[0], lim_dim2[1], no_bins_dim2, endpoint=False)

    # grd[i, j] = (axis_dim1[i], axis_dim2[j])
    grd = np.stack(np.meshgrid(axis_dim1, axis_dim2, indexing="ij"), axis=-1)

    # scipy squeezes singleton dimensions of its output; restore the grid shape explicitly
    pdf_np = np.asarray(
        multivariate_normal.pdf(grd, mu * np.ones(2), sigma * np.eye(2))
    ).reshape(no_bins_dim1, no_bins_dim2)

    pdf = to.from_numpy(pdf_np).type(precision).to(device)
    grd_dim1 = to.from_numpy(np.ascontiguousarray(grd[:, :, 0])).type(precision).to(device)
    # second coordinate lives at index 1 (the last axis has size 2)
    grd_dim2 = to.from_numpy(np.ascontiguousarray(grd[:, :, 1])).type(precision).to(device)
    return pdf, grd_dim1, grd_dim2


class OverlappingPatches:
    def __init__(
        self,
        image: Tensor,
        patch_height: int,
        patch_width: int,
        patch_shift: int,
    ):
        """Back and forth transformation for image segmentation into overlapping patches.
        Makes use of `skimage.util.view_as_windows`.

        If the image contains NaNs, these are treated as missing values and only the missing
        pixels are reconstructed by `merge`; otherwise all pixels are reconstructed.

        :param image: Tensor to be cut into patches and reconstructed.
                      Must be 2-dim. tensor (height x width).
        :param patch_height: Will be passed as `window_shape[0]` to `skimage.util.view_as_windows`.
        :param patch_width: Will be passed as `window_shape[1]` to `skimage.util.view_as_windows`.
        :param patch_shift: Step between neighbouring patches. Patches are cut with step 1 and
                            sub-selected afterwards; the last patch in each direction is always
                            kept.
        """
        assert image.dim() == 2, "image tensor must be two-dimensional (height x width)"
        self.verbose = False  # dis-/enable prints (for debugging purposes)
        self._patch_height = patch_height
        self._patch_width = patch_width
        self._patch_shift = patch_shift
        device, precision = image.device, image.dtype
        image_np = image.to(device="cpu").numpy()

        # infer some parameters
        image_isnan = np.isnan(image_np)
        image_not_incomplete = not image_isnan.any()
        image_height, image_width = image_np.shape[0], image_np.shape[1]
        no_pixels_in_patch = patch_height * patch_width
        no_patches_vert = int(
            np.ceil(float(image_height - patch_height) / patch_shift) + 1
        )  # no patches in vertical direction
        no_patches_horz = int(
            np.ceil(float(image_width - patch_width) / patch_shift) + 1
        )  # no patches in horizontal direction
        no_patches = no_patches_vert * no_patches_horz
        no_patches_vert_shift_1 = image_height - patch_height + 1  # vertical dir, step=1
        no_patches_horz_shift_1 = image_width - patch_width + 1  # horizontal dir, step=1
        no_patches_shift_1 = (
            no_patches_vert_shift_1 * no_patches_horz_shift_1
        )  # no patches for step=1

        ninds_ = np.arange(no_patches_shift_1).reshape(
            no_patches_vert_shift_1, no_patches_horz_shift_1
        )  # patch index for every patch location in image for step = 1
        # indices of relevant patches for step=patch_shift in vertical / horizontal direction
        hinds = self._kept_patch_positions(no_patches_vert_shift_1, patch_shift)
        winds = self._kept_patch_positions(no_patches_horz_shift_1, patch_shift)
        ninds = (ninds_[hinds, :][:, winds]).flatten()  # indices of relevant patches
        assert len(ninds) == no_patches
        dinds = np.arange(no_pixels_in_patch).reshape(
            patch_height, patch_width
        )  # spatial order of pixel indices, is (patch_height, patch_width),
        # indexed l->r and then top->bottom
        to_be_synthesized = (
            image_isnan if not image_not_incomplete else np.ones_like(image_np, dtype=bool)
        )  # indicates which pixels of the input image are to be reconstructed
        ind_rows_to_synthesize, ind_cols_to_synthesize = np.where(
            to_be_synthesized
        )  # index tuples of pixels to reconstruct, is (total # pixels to reconstruct)
        no_pixels_to_synthesize = ind_rows_to_synthesize.size

        # cut patches
        print("Cutting image into patches...", end="", flush=True)
        patches_np = view_as_windows(
            image_np, window_shape=[patch_height, patch_width], step=1
        )  # moves sliding window left->right and then top->bottom
        # is (image_height-patch_height+1, image_width-patch_width+1, patch_height, patch_width)
        patches_np = patches_np.reshape(
            no_patches_shift_1, no_pixels_in_patch
        ).T  # is (no_pixels_in_patch,no_patches_shift_1)
        patches_np = patches_np[
            :, ninds
        ]  # remove patches not satisfying `patch_shift` is (no_pixels_in_patch,no_patches)
        patches_np_not_isnan = np.logical_not(np.isnan(patches_np))
        print("Done", flush=True)

        # compute indices required to merge patches back to image
        print("Initialize back-transformation...", end="", flush=True)
        # pixels without any usable patch keep this empty placeholder
        no_inds = np.empty(0, dtype=np.int64)
        all_inds_relevant_patches = [no_inds] * no_pixels_to_synthesize
        all_inds_relevant_values_in_patch = [no_inds] * no_pixels_to_synthesize
        # builtin `bool`: the `np.bool` alias has been removed from NumPy
        restorable = (
            None if patches_np_not_isnan.all() else np.zeros(no_pixels_to_synthesize, dtype=bool)
        )
        for p in range(no_pixels_to_synthesize):

            # location of pixel to reconstruct in original image
            r, c = ind_rows_to_synthesize[p], ind_cols_to_synthesize[p]

            # location of relevant patches for patch_shift = 1(rows and columns of ninds_)
            r_ = (
                np.arange(
                    max(r - patch_height + 2, 1),
                    min(r + 1, no_patches_vert_shift_1) + 1,
                )
                - 1
            )
            c_ = np.arange(max(c - patch_width + 2, 1), min(c + 1, no_patches_horz_shift_1) + 1) - 1

            ns_ = ninds_[r_, :][
                :, c_
            ].flatten()  # indices of patches (step=1) covering the given pixel, ascending
            # position of the pixel inside each of these patches; descending, because later
            # patches see the pixel closer to their top-left corner
            ds_ = np.sort(dinds[r - r_, :][:, c - c_].flatten())[::-1]

            # only use patches compatible with given patch_shift
            if patch_shift > 1:
                nsinds = np.isin(ns_, ninds)
                inds_relevant_patches, inds_relevant_values_in_patch = (
                    ns_[nsinds],
                    ds_[nsinds],
                )
                # translate step=1 patch indices into column indices of `patches_np`
                inds_relevant_patches = np.where(np.isin(ninds, inds_relevant_patches))[0]
            else:
                inds_relevant_patches, inds_relevant_values_in_patch = ns_, ds_

            if image_not_incomplete:
                all_inds_relevant_patches[p] = inds_relevant_patches
                all_inds_relevant_values_in_patch[p] = inds_relevant_values_in_patch
            else:
                # drop patches consisting of missing values only
                relevant_patches = patches_np_not_isnan[:, inds_relevant_patches]
                ind_nonempty_patches = relevant_patches.any(axis=0)
                if ind_nonempty_patches.any():
                    restorable[p] = True
                    all_inds_relevant_patches[p] = inds_relevant_patches[ind_nonempty_patches]
                    all_inds_relevant_values_in_patch[p] = inds_relevant_values_in_patch[
                        ind_nonempty_patches
                    ]
        print("Done", flush=True)

        def to_index_tensor(inds: ndarray) -> Tensor:
            # copy first: some index arrays are reversed views, which `to.from_numpy`
            # cannot wrap
            return to.from_numpy(inds.copy()).type(to.int64).to(device)

        # original image and patch matrix, is (no_pixels_in_patch, no_patches)
        self._image, self._patches = image, to.from_numpy(patches_np).type(precision).to(device)
        # row / column indices of the pixels to reconstruct
        self._ind_rows_to_synthesize = to_index_tensor(ind_rows_to_synthesize)
        self._ind_cols_to_synthesize = to_index_tensor(ind_cols_to_synthesize)
        self._no_pixels_to_synthesize = no_pixels_to_synthesize
        # per pixel to reconstruct: is at least one non-empty patch available? (None if the
        # image has no missing values)
        self._restorable = (
            to.from_numpy(restorable).type(to.bool).to(device) if restorable is not None else None
        )
        # per pixel to reconstruct: columns of `self._patches` and rows therein providing estimates
        self._all_inds_relevant_patches = [to_index_tensor(x) for x in all_inds_relevant_patches]
        self._all_inds_relevant_values_in_patch = [
            to_index_tensor(x) for x in all_inds_relevant_values_in_patch
        ]

    @staticmethod
    def _kept_patch_positions(no_positions_shift_1: int, patch_shift: int) -> ndarray:
        """Return the (0-based) step=1 patch positions kept along one direction.

        Every `patch_shift`-th position is kept; the last position is always included.
        """
        return (
            np.unique(
                np.append(
                    np.arange(1, no_positions_shift_1, patch_shift),
                    [no_positions_shift_1],
                )
            )
            - 1
        )

    def get_image_shape(self) -> Tuple[Any, ...]:
        """Returns the shape of the image passed at construction as tuple."""
        return tuple(self._image.shape)

    def get_number_of_patches(self, discard_empty: bool = True) -> int:
        """Returns the number of patches.

        :param discard_empty: If true, patches in which all values are missing, are not counted.
        """
        if not discard_empty:
            # entries of a tensor shape are plain ints
            return self._patches.shape[1]
        not_isnan: to.Tensor = to.logical_not(to.isnan(self._patches))
        return to.sum(not_isnan.any(dim=0)).item()

    def get_patch_height_width_shift(self) -> Tuple[int, int, int]:
        """Returns patch height, patch width and patch shift as passed at construction."""
        return self._patch_height, self._patch_width, self._patch_shift

    def get(self, discard_empty: bool = True) -> Tensor:
        """Returns tensor containing patchified image.

        If no patch needs to be discarded the internal tensor itself (not a copy) is returned.

        :param discard_empty: If true, patches in which all values are missing, are not returned.
        """
        if not discard_empty:
            return self._patches
        not_isnan: to.Tensor = to.logical_not(to.isnan(self._patches))
        if not_isnan.all():
            return self._patches
        return self._patches[:, not_isnan.any(dim=0)]

    def set(self, new_patches: Tensor, discarded_empty: bool = True):
        """Updates values of `self._patches` to `new_patches`.

        :param new_patches: Values in `self._patches` will be replaces with values in
                            `new_patches`.
        :param discarded_empty: Needs to be set corresondingly to how `get` has been called
                                (compare docs of `get`).
        """
        nonempty = None  # column mask of non-empty patches, only set if a subset is addressed
        if discarded_empty:
            not_isnan: to.Tensor = to.logical_not(to.isnan(self._patches))
            if not not_isnan.all():
                nonempty = not_isnan.any(dim=0)

        if nonempty is None:
            assert (
                new_patches.shape == self._patches.shape
            ), "shape of new and internal patches does not match"
            self._patches[:, :] = new_patches
        else:
            # expected shape computed from the mask, avoids copying the selected patches
            expected_shape = (self._patches.shape[0], int(nonempty.sum().item()))
            assert (
                tuple(new_patches.shape) == expected_shape
            ), "shape of new and non-empty internal patches does not match"
            self._patches[:, nonempty] = new_patches

    def merge(self, merge_method: Callable = mean_merger) -> Tensor:
        """Merge patches to obtain new image.

        Pixels which are not to be reconstructed, or for which no non-empty patch is available,
        keep the value of the original image.

        :param merge_method: Function defining how pixel estimates from different patches are to be
                             merged, defaults to unweighted averaging. Called with the estimates
                             and their pixel indices within the respective patches.
        """
        new_image = self._image.clone()
        for p in range(self._no_pixels_to_synthesize):
            if self._restorable is not None and not self._restorable[p]:
                continue

            r, c = self._ind_rows_to_synthesize[p], self._ind_cols_to_synthesize[p]
            inds_relevant_patches = self._all_inds_relevant_patches[p]
            inds_relevant_values_in_patch = self._all_inds_relevant_values_in_patch[p]

            # one estimate per relevant patch
            restored = self._patches[inds_relevant_values_in_patch, inds_relevant_patches]

            if self.verbose:
                print(f"Processing image pixel at ({r},{c}) \n=============================")
                print(f"Estimates from all patches \n  {restored}\n")

            estimate = merge_method(restored, inds_relevant_values_in_patch)

            if self.verbose:
                print(f"Merged estimate is \n  {estimate}\n")

            new_image[r, c] = estimate

        return new_image

    def set_and_merge(
        self,
        new_patches: Tensor,
        discarded_empty: bool = True,
        merge_method: Callable = mean_merger,
    ) -> Tensor:
        """Sequentially calls `set` and `merge`.

        :param new_patches: see docs of `set`
        :param discarded_empty: see docs of `set`
        :param merge_method: see docs of `merge`
        """
        self.set(new_patches, discarded_empty)
        return self.merge(merge_method)


class MultiDimOverlappingPatches:
    def __init__(
        self,
        image: Tensor,
        patch_height: int,
        patch_width: int,
        patch_shift: int,
    ):
        """Sequentially apply transformations implemented by `OverlappingPatches` for multi-channel
        data (e.g. RGB images).

        :param image: Tensor to be cut into patches and reconstructed.
                      Must be 3-dim. tensor (height x width x no_channels).
        :param patch_height: see `OverlappingPatches.__init__` docs
        :param patch_width: see `OverlappingPatches.__init__` docs
        :param patch_shift: see `OverlappingPatches.__init__` docs
        """
        assert (
            image.dim() == 3
        ), "image tensor must be three-dimensional (height x width x no_channels)"
        self.no_channels = image.shape[2]
        # one independent patch transformation per channel
        self.OVPs = [
            OverlappingPatches(
                image[:, :, ch],
                patch_height,
                patch_width,
                patch_shift,
            )
            for ch in range(self.no_channels)
        ]

    def _split_channels(self, new_patches: Tensor, concatenated: bool) -> List[Tensor]:
        """Validate `new_patches` and return one patch tensor (view) per channel."""
        reference = self.OVPs[0]._patches  # provides device and dtype of the internal patches
        assert new_patches.device == reference.device, "device of new_patches mismatched"
        assert new_patches.dtype == reference.dtype, "dtype of new_patches mismatched"
        if concatenated:
            assert new_patches.dim() == 2, (
                "patches tensor must be two-dimensional "
                "(patch_height*patch_width*no_channels x no_patches)"
            )
            # channels are stacked along the first dimension, as done by `get`
            px_per_ch = new_patches.shape[0] // self.no_channels
            return [
                new_patches[ch * px_per_ch : (ch + 1) * px_per_ch, :]
                for ch in range(self.no_channels)
            ]
        assert new_patches.dim() == 3, (
            "patches tensor must be three-dimensional "
            "(patch_height*patch_width x no_patches x no_channels)"
        )
        return [new_patches[:, :, ch] for ch in range(self.no_channels)]

    def get_image_shape(self) -> Tuple[Any, ...]:
        """Returns (height, width, no_channels) of the image passed at construction."""
        return self.OVPs[0].get_image_shape() + (self.no_channels,)

    def get_number_of_patches(self, discard_empty: bool = True) -> int:
        """Returns the number of patches of the first channel.

        :param discard_empty: see `OverlappingPatches.get_number_of_patches` docs
        """
        # NOTE(review): only the first channel is inspected; this presumably assumes that
        # missing values are located identically in all channels - confirm.
        return self.OVPs[0].get_number_of_patches(discard_empty)

    def get_patch_height_width_shift(self) -> Tuple[int, int, int]:
        """Returns patch height, patch width and patch shift as passed at construction."""
        return self.OVPs[0].get_patch_height_width_shift()

    def get(self, discard_empty: bool = True, concatenate: bool = True) -> Tensor:
        """Runs `OverlappingPatches.get` sequentially for each channel

        :param discard_empty: see `OverlappingPatches.get` docs
        :param concatenate: If true, the patches of all channels are concatenated along the first
                            dimension, otherwise they are stacked along a new last dimension.
        """
        p = [ovp.get(discard_empty) for ovp in self.OVPs]
        return to.cat(p, dim=0) if concatenate else to.stack(p, dim=-1)

    def set(self, new_patches: Tensor, discarded_empty: bool = True, concatenated: bool = True):
        """Runs `OverlappingPatches.set` sequentially for each channel

        :param new_patches: new patch values of all channels, laid out as returned by `get`
        :param discarded_empty: see `OverlappingPatches.set` docs
        :param concatenated: Needs to be set correspondingly to `concatenate` in `get`.
        """
        per_channel = self._split_channels(new_patches, concatenated)
        for ovp, channel_patches in zip(self.OVPs, per_channel):
            ovp.set(channel_patches, discarded_empty)

    def merge(self, merge_method: Callable = mean_merger) -> Tensor:
        """Runs `OverlappingPatches.merge` sequentially for each channel

        :param merge_method: see `OverlappingPatches.merge` docs
        :returns: merged image, channels stacked along the last dimension
        """
        return to.stack([ovp.merge(merge_method) for ovp in self.OVPs], dim=-1)

    def set_and_merge(
        self,
        new_patches: Tensor,
        discarded_empty: bool = True,
        concatenated: bool = True,
        merge_method: Callable = mean_merger,
    ) -> Tensor:
        """Sequentially calls `set` and `merge`.

        :param new_patches: see docs of `set`
        :param discarded_empty: see docs of `set`
        :param concatenated: see docs of `set`
        :param merge_method: see docs of `merge`
        """
        self.set(new_patches, discarded_empty, concatenated)
        return self.merge(merge_method)
