# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import os
import numpy as np
import torch as to
import matplotlib.pyplot as plt
from typing import Tuple, Union, Optional
from sklearn.utils import as_float_array
from sklearn.preprocessing import MinMaxScaler


def _rescale(data: np.ndarray, feature_range: Tuple[float, float]):
    """Linearly map all values of ``data`` onto ``feature_range``.

    The array is flattened into a single column before scaling, so one common
    minimum/maximum is used for every entry rather than one per row or column.

    :param data: Array to rescale; it is copied, not modified in place.
    :param feature_range: Target (min, max) interval handed to MinMaxScaler.
    :return: Scaled float array with the same shape as ``data``.
    """
    original_shape = data.shape

    # One column == one "feature" for MinMaxScaler, i.e. a single global range.
    column = as_float_array(data, copy=True).reshape(-1, 1)

    scaler = MinMaxScaler(feature_range=feature_range, copy=True)
    scaled_column = scaler.fit(column).transform(column)
    return scaled_column.reshape(original_shape)


def _get_to_show(subplots_yx: Tuple, no_subplots: int):
    no_show = np.min([subplots_yx[0] * subplots_yx[1], no_subplots])
    if no_show < no_subplots:
        print(f"Displaying {no_show} of {no_subplots} subplots.")
        suf = "_sel"
    else:
        suf = ""
    return no_show, suf


def _get_height_width(height_width: Union[None, Tuple[int, int]], no_pixel: int):
    if height_width is None:
        height = int(np.sqrt(no_pixel))
        if height == np.double(np.sqrt(no_pixel)):
            width = height
        else:
            raise ValueError("height * width not a perfect square number.")
    else:
        assert isinstance(height_width, tuple)  # to make mypy happy
        height, width = height_width
    return int(height), int(width)


def _get_observed(data: np.ndarray):
    return data[np.logical_not(np.isnan(data))]


def _get_is_binary(
    values: np.ndarray, observed: np.ndarray, global_clim: bool, sym_clim: bool, cmap: str
):
    is_binary = values.dtype == "uint8"
    if not is_binary:
        val_min = np.nanmin(observed)
        val_max = np.nanmax(observed)

        if sym_clim:
            val_max = np.nanmax([np.abs(val_min), val_max])
            val_min = -val_max
    else:
        val_min, val_max = 0, 1
        global_clim = True
        cmap = "gray"
    return is_binary, val_min, val_max, global_clim, cmap


def _get_clim_suff(clim: Union[None, Tuple[float, float]], global_clim: bool, sym_clim: bool):
    suff = ""
    if clim is None:
        if global_clim:
            suff += "_global"
        else:
            suff += "_local"
        if sym_clim:
            suff += "_sym"
        else:
            suff += "_unsym"
    else:
        assert isinstance(clim, tuple)  # to make mypy happy
    return suff


def _get_figure(
    no_show: int,
    figsize: Optional[Tuple[float, float]],
    subplots_yx: Tuple[int, int],
    height: int,
    width: int,
    pxsize: float,
    facecolor: str,
):
    figsize_ = (
        (pxsize * width * subplots_yx[1], pxsize * height * subplots_yx[0])
        if figsize is None
        else figsize
    )
    if no_show == 1:
        fig = plt.figure(figsize=figsize_, facecolor=facecolor, edgecolor="none")
        axarr = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        axarr.set_axis_off()
        fig.add_axes(axarr)
    elif no_show > 1:
        fig, axarr = plt.subplots(
            subplots_yx[0], subplots_yx[1], figsize=figsize_, facecolor=facecolor, edgecolor="none"
        )
        axarr = axarr.ravel()
    return fig, axarr


def _get_to_plot(
    values: np.ndarray,
    subplot_ind: int,
    height: int,
    width: int,
    global_clim: bool,
    val_min: float,
    val_max: float,
    sym_clim: bool,
    clim: Union[None, Tuple[float, float]],
):
    if np.ndim(values) == 2:
        ToPlot = values[:, subplot_ind].reshape(height, width)
    elif np.ndim(values) == 3:
        ToPlot = values[:, subplot_ind, :].reshape(height, width, 3)

    if global_clim:
        my_val_min = val_min
        my_val_max = val_max
    else:
        my_val_min = np.nanmin(ToPlot)
        my_val_max = np.nanmax(ToPlot)
        if sym_clim:
            my_val_max = np.nanmax([np.abs(my_val_min), my_val_max])
            my_val_min = -my_val_max
    if clim is None:
        my_clim = (my_val_min, my_val_max)
    else:
        assert isinstance(clim, tuple)  # to make mypy happy
        my_clim = clim
    if np.ndim(values) == 3:
        ToPlot_ = np.minimum(np.maximum(ToPlot, my_clim[0]), my_clim[1])
        my_clim = (0.0, 1.0)
        ToPlot = _rescale(ToPlot_, my_clim)
    return ToPlot, my_clim


def _values_as_np(data: Union[np.ndarray, to.Tensor]):
    if type(data) == to.Tensor:
        return data.detach().cpu().numpy()
    else:
        return data


class matrix_columns_as_image_subplots:
    """Figure showing each column of a matrix as an image in its own subplot.

    The figure is created on construction. ``update`` replaces the image data of the
    already existing subplots without creating a new figure.
    """

    def __init__(
        self,
        cdata: Union[np.ndarray, to.Tensor],
        subplots_yx: Tuple[int, int],
        height_width: Optional[Tuple[int, int]] = None,
        figsize: Optional[Tuple[float, float]] = None,
        clim: Optional[Tuple[float, float]] = None,
        cmap: str = "jet",
        sym_clim: bool = True,
        global_clim: bool = True,
        figure_name: Optional[str] = None,
        output_directory: Optional[str] = None,
        output_file_type: str = ".png",
        interpolation: str = "nearest",
        facecolor: str = "white",
        pxsize: float = 0.05,
    ):
        """
        Create subplot figure based on 2- or 3-dim tensor.
        The first dimension of cdata indexes pixels and the second dimension indexes
        subplots: each column cdata[:, i] is reshaped to a
        height_width[0] x height_width[1] image and shown in its own subplot.
        For 3-dim input, cdata[:, i, :] is reshaped to (height, width, 3).

        :param cdata: Input tensor. Can be 2 or 3-dim.
        :param subplots_yx: Number of subplots in y- and x-direction.
        :param height_width: Vertical and horizontal dimensions of each subplot.
                             If None, a square image is assumed.
        :param figsize: Figsize argument passed to plt.figure
        :param clim: Color limits
        :param cmap: Color map
        :param sym_clim: Use symmetric color limits
        :param global_clim: Use the same color limits across subplots
        :param figure_name: Name of the figure; the computed suffix is appended to it
        :param output_directory: Directory to store the figure (requires figure_name)
        :param output_file_type: Figure extension (.jpg, .png, ...)
        :param interpolation: Interpolation argument passed to imshow
        :param facecolor: Background color of the figure
        :param pxsize: Absolute size per pixel

        Example:
        matrix_columns_as_image_subplots(np.ones((25, 9))*np.arange(9)[None,:], (1,9))
        plt.show()
        """
        cdata = _values_as_np(cdata)  # TODO: Support torch.Tensors in all places

        # dimensions
        no_pixel, no_subplots = cdata.shape[:2]
        no_show, suf = _get_to_show(subplots_yx, no_subplots)
        height, width = _get_height_width(height_width, no_pixel)

        # to apply global colorscale
        observed = _get_observed(cdata)
        is_binary, val_min, val_max, global_clim, cmap = _get_is_binary(
            cdata,
            observed,
            global_clim,
            sym_clim,
            cmap,
        )

        suf += _get_clim_suff(clim, global_clim, sym_clim)

        fig, axarr = _get_figure(no_show, figsize, subplots_yx, height, width, pxsize, facecolor)

        # Store the state before drawing so that the private helpers can rely on it.
        self.suf = suf
        self.fig = fig
        self.axarr = axarr
        self.subplots_yx = subplots_yx
        self.no_show = no_show
        self.height = height
        self.width = width
        self.global_clim = global_clim
        self.sym_clim = sym_clim
        self.clim = clim
        self.cmap = cmap
        self.spl = {}
        self.interpolation = interpolation
        self.facecolor = facecolor

        for h in range(np.prod(subplots_yx)):
            this_axarr = self._axes_at(h)

            if h >= no_show:
                # Grid cell without data: hide it completely.
                this_axarr.axis("off")
                continue

            self._hide_ticks(this_axarr)

            to_plot, my_clim = _get_to_plot(
                cdata,
                h,
                height,
                width,
                global_clim,
                val_min,
                val_max,
                sym_clim,
                clim,
            )

            self.spl[f"p{h}"] = this_axarr.imshow(
                to_plot,
                interpolation=interpolation,
                vmin=my_clim[0],
                vmax=my_clim[1],
                cmap=cmap,
            )
            if not is_binary:
                for side in ("bottom", "right", "top", "left"):
                    this_axarr.spines[side].set_visible(False)

        plt.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)

        self._title_and_save(figure_name, output_directory, output_file_type)

    def _axes_at(self, h: int):
        """Return the axes of grid cell h (the single axes if only one subplot is shown)."""
        return self.axarr[h] if self.no_show > 1 else self.axarr

    @staticmethod
    def _hide_ticks(axes) -> None:
        """Remove all ticks and tick labels from the given axes."""
        axes.tick_params(
            axis="both",
            which="both",
            bottom=False,
            top=False,
            labelbottom=False,
            left=False,
            right=False,
            labelleft=False,
        )

    def _title_and_save(
        self,
        figure_name: Optional[str],
        output_directory: Optional[str],
        output_file_type: str,
    ) -> None:
        """Set the window title and, if a directory is given, write the figure to disk."""
        if figure_name is not None:
            figure_name = f"{figure_name}{self.suf}"
            self.fig.canvas.manager.set_window_title(figure_name)

        if output_directory is None:
            return

        assert figure_name is not None, "Figure name not defined"
        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(output_directory, exist_ok=True)
        # The path format is kept as-is so file locations and the message stay unchanged.
        target = output_directory + "/" + figure_name + output_file_type
        self.fig.savefig(target, facecolor=self.facecolor, edgecolor="none")
        print("Wrote " + target)

    def update(
        self,
        cdata: Union[np.ndarray, to.Tensor],
        figure_name: Optional[str] = None,
        output_directory: Optional[str] = None,
        output_file_type: str = ".png",
    ):
        """
        Replace the image data of the existing subplots with the columns of cdata.
        Image dimensions, color limit settings and color map chosen at construction
        are reused; global color limits are recomputed from the new data.

        :param cdata: New input tensor, laid out like the one passed to the constructor.
        :param figure_name: Name of the figure; the stored suffix is appended to it
        :param output_directory: Directory to store the figure (requires figure_name)
        :param output_file_type: Figure extension (.jpg, .png, ...)
        """
        cdata = _values_as_np(cdata)
        observed = _get_observed(cdata)
        _, val_min, val_max, _, _ = _get_is_binary(
            cdata,
            observed,
            self.global_clim,
            self.sym_clim,
            self.cmap,
        )

        # Only the first no_show grid cells hold images; no_show never exceeds the grid size.
        for h in range(self.no_show):
            self._hide_ticks(self._axes_at(h))

            to_plot, my_clim = _get_to_plot(
                cdata,
                h,
                self.height,
                self.width,
                self.global_clim,
                val_min,
                val_max,
                self.sym_clim,
                self.clim,
            )
            image = self.spl[f"p{h}"]
            image.set_array(to_plot)
            image.set_clim(my_clim)

        self._title_and_save(figure_name, output_directory, output_file_type)


if __name__ == "__main__":
    # Small demo: nine columns, one subplot per column in a single row.
    grid = (1, 9)

    # NumPy input with 30 pixels per column, shown as 10 x 3 images.
    np_columns = np.ones((30, 9)) * np.arange(1, 10)[None, :]
    _ = matrix_columns_as_image_subplots(np_columns, grid, height_width=(10, 3))

    # Torch input with 25 pixels per column, shown as square 5 x 5 images.
    first_frame = to.ones((25, 9)) * to.arange(1, 10)[None, :]
    subplots = matrix_columns_as_image_subplots(first_frame, grid, figure_name="v1")

    # Reuse the second figure and replace its image data.
    second_frame = to.ones((25, 9)) * to.arange(10, 19)[None, :]
    subplots.update(second_frame, figure_name="v2")

    plt.show()
