# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import os
import numpy as np
import torch as to
import matplotlib.pyplot as plt
from typing import Tuple, Union
from sklearn.utils import as_float_array
from sklearn.preprocessing import MinMaxScaler


def _rescale(data: np.ndarray, feature_range: Tuple[float, float]):
    """Rescale values of data to fill feature_range"""
    data_shape = data.shape
    data_float = as_float_array(data, copy=True)
    scaler = MinMaxScaler(feature_range=feature_range, copy=True)
    scaler.set_params()

    data_float_flattened = data_float.reshape(-1, 1)
    scaler.fit(data_float_flattened)
    data_float_flattened_scaled = scaler.transform(data_float_flattened)
    return data_float_flattened_scaled.reshape(data_shape)


def _data_as_np(data: Union[np.ndarray, to.Tensor]):
    if type(data) == to.Tensor:
        return data.detach().cpu().numpy()
    else:
        return data


class matrix_as_image:
    """Render a 2-dim. matrix (or a 3-dim. RGB array) as a frameless matplotlib figure.

    The created figure and image handle are stored in `self.fig` and `self.im`, so the
    displayed data can later be replaced via `update`.
    """

    def __init__(
        self,
        cdata: Union[np.ndarray, to.Tensor],
        figsize: Tuple[float, float] = None,
        dpi: int = None,
        clim: Tuple[float, float] = None,
        colormap: str = "jet",
        figure_name: str = None,
        output_directory: str = None,
        output_file_type: str = "png",
        interpolation: str = "nearest",
    ):
        """
        Visualize 2- or 3-dim matrix as image. If input is 3-dim., last dim. will be treated
        as color channel.

        :param cdata: Input array or tensor. Can be 2-dim. or 3-dim. with three (RGB) channels
            in the last dimension.
        :param figsize: plt.figure size. If `dpi` is given as well, `figsize` is interpreted
            in pixels and divided by `dpi` (one of `figsize` and `dpi` must be different to
            `None`).
        :param dpi: Resolution. If `figsize` is `None`, the figure gets one pixel per matrix
            entry (one of `figsize` and `dpi` must be different to `None`).
        :param clim: Color limits. For 3-dim. input, the data is clipped to `clim` and then
            rescaled to [0, 1].
        :param colormap: Color map
        :param figure_name: Name of the figure (window title and output file name)
        :param output_directory: Directory to store the figure (created if missing; requires
            `figure_name`)
        :param output_file_type: Figure extension (jpg, png, ...)
        :param interpolation: Interpolation method passed to `imshow`

        Example:
        matrix_as_image(np.eye(500))
        plt.show()
        """
        cdata = _data_as_np(cdata)  # TODO: Support torch.Tensors in all places
        assert np.ndim(cdata) == 2 or np.ndim(cdata) == 3, "Input must be a 2- or 3-dim. matrix"
        height, width = cdata.shape[:2]
        if np.ndim(cdata) == 3:
            assert cdata.shape[2] == 3, "input must provide three (RGB) color channels"
        assert not (
            figsize is None and dpi is None
        ), "One of figsize and dpi must be different to None"
        if dpi is not None:
            # With an explicit dpi, sizes are given in pixels (default: one pixel per
            # matrix entry) and converted to the inches matplotlib expects.
            if figsize is None:
                figsize = (width / dpi, height / dpi)
            else:
                figsize = (figsize[0] / dpi, figsize[1] / dpi)

        fig = plt.figure(figsize=figsize, edgecolor="none", dpi=dpi)
        # Axes covering the whole figure without axis decorations: the image fills the canvas.
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")

        if np.ndim(cdata) == 3:
            cdata = self._clip_and_rescale_rgb(cdata, clim)

        im = ax.imshow(
            cdata,
            interpolation=interpolation,
            vmin=clim[0] if clim is not None else None,
            vmax=clim[1] if clim is not None else None,
            cmap=colormap,
        )

        self.fig = fig
        self.im = im
        self.interpolation = interpolation
        self._label_and_save(figure_name, output_directory, output_file_type)

    def update(
        self,
        cdata: Union[np.ndarray, to.Tensor],
        clim: Tuple[float, float] = None,
        figure_name: str = None,
        output_directory: str = None,
        output_file_type: str = "png",
    ):
        """Replace the displayed data and optionally retitle/save the figure.

        :param cdata: New 2- or 3-dim. (RGB) data.
        :param clim: Color limits; defaults to the NaN-ignoring min/max of `cdata`.
            For 3-dim. input, the data is clipped to these limits and rescaled to [0, 1].
        :param figure_name: New window title and output file name.
        :param output_directory: Directory to store the figure (created if missing;
            requires `figure_name`).
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        cdata = _data_as_np(cdata)
        clim_ = (np.nanmin(cdata), np.nanmax(cdata)) if clim is None else clim
        if np.ndim(cdata) == 3:
            cdata = self._clip_and_rescale_rgb(cdata, clim_)
        self.im.set_data(cdata)
        self.im.set_clim(clim_)
        self._label_and_save(figure_name, output_directory, output_file_type)

    @staticmethod
    def _clip_and_rescale_rgb(cdata: np.ndarray, clim: Tuple[float, float] = None) -> np.ndarray:
        """Clip RGB data to `clim` (if given) and rescale it to the [0, 1] range."""
        if clim is not None:
            cdata = np.clip(cdata, clim[0], clim[1])
        return _rescale(data=cdata, feature_range=(0.0, 1.0))

    def _label_and_save(
        self, figure_name: str = None, output_directory: str = None, output_file_type: str = "png"
    ):
        """Set the window title and, if `output_directory` is given, write the figure to disk."""
        if figure_name is not None:
            self.fig.canvas.manager.set_window_title(figure_name)

        if output_directory is not None:
            assert figure_name is not None, "Must provide figure name"
            # exist_ok avoids the check-then-create race and lets `update` save to new dirs.
            os.makedirs(output_directory, exist_ok=True)
            path = f"{output_directory}/{figure_name}.{output_file_type}"
            self.fig.savefig(path)
            print(f"Wrote {path}")


if __name__ == "__main__":
    # Demo: render a 200x200 identity matrix (one pixel per entry at 96 dpi) and show it.
    identity = np.eye(200)
    matrix_as_image(identity, figure_name="demo", dpi=96)
    plt.show()
