# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import os
import numpy as np
import torch as to
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from typing import Tuple, Union
from sklearn.utils import as_float_array
from sklearn.preprocessing import MinMaxScaler


def _rescale(data: np.ndarray, feature_range: Tuple[float, float]):
    """Min-max rescale all entries of ``data`` jointly so they span ``feature_range``.

    All elements are treated as one feature (the array is flattened to a single
    column before fitting), and the result is returned as a float array with the
    same shape as ``data``. ``data`` itself is not modified.
    """
    original_shape = data.shape
    # One column => a single min/max pair for the whole array.
    as_column = as_float_array(data, copy=True).reshape(-1, 1)
    scaler = MinMaxScaler(feature_range=feature_range, copy=True)
    scaled_column = scaler.fit_transform(as_column)
    return scaled_column.reshape(original_shape)


def _get_to_show(subplots_yx: Tuple, no_subplots: int):
    no_show = np.min([subplots_yx[0] * subplots_yx[1], no_subplots])
    if no_show < no_subplots:
        print(f"Displaying {no_show} of {no_subplots} subplots.")
        suf = "_sel"
    else:
        suf = ""
    return no_show, suf


def _get_obs_resh(data: np.ndarray, no_subplots: int):
    if np.ndim(data) == 3:
        return data[-1, :, :].transpose(1, 0).reshape(no_subplots, -1)
    elif np.ndim(data) == 4:
        return data[-1, :, :, :].transpose(1, 0, 2).reshape(no_subplots, -1)


def _get_height_width(height_width: Union[None, Tuple[int, int]], no_pixel: int):
    if height_width is None:
        height = int(np.sqrt(no_pixel))
        if height == np.double(np.sqrt(no_pixel)):
            width = height
        else:
            raise ValueError("height * width not a perfect square number.")
    else:
        height, width = height_width
    return int(height), int(width)


def _get_to_plot(
    values: np.ndarray,
    subplot_ind: int,
    height: int,
    width: int,
    global_clim: bool,
    val_min_last_glob: float,
    val_max_last_glob: float,
    val_min_last_ind: Union[Tuple[float, ...], float],
    val_max_last_ind: Union[Tuple[float, ...], float],
    sym_clim: bool,
    clim: Union[None, Tuple[float, float]],
):
    if np.ndim(values) == 3:
        ToPlot = values[0, :, subplot_ind].reshape(height, width)
    elif np.ndim(values) == 4:
        ToPlot = values[0, :, subplot_ind, :].reshape(height, width, 3)

    if global_clim:
        my_val_min = val_min_last_glob
        my_val_max = val_max_last_glob
    else:
        assert isinstance(val_min_last_ind, tuple)
        assert isinstance(val_max_last_ind, tuple)
        my_val_min = val_min_last_ind[subplot_ind]
        my_val_max = val_max_last_ind[subplot_ind]
        if sym_clim:
            my_val_max = np.nanmax([np.abs(my_val_min), my_val_max])
            my_val_min = -my_val_max

    if clim is None:
        my_clim = (my_val_min, my_val_max)
    else:
        my_clim = clim

    if np.ndim(values) == 4:
        ToPlot = np.minimum(np.maximum(ToPlot, my_clim[0]), my_clim[1])
        my_clim = (0.0, 1.0)
        ToPlot = _rescale(ToPlot, my_clim)

    return ToPlot, my_clim


def _rescale_toplot(values: np.ndarray, ToPlot: np.ndarray, clim: Tuple):
    if np.ndim(values) == 4:
        ToPlot = np.minimum(np.maximum(ToPlot, clim[0]), clim[1])
        clim = (0.0, 1.0)
        ToPlot = _rescale(ToPlot, clim)
    return ToPlot, clim


def _values_as_np(data: Union[np.ndarray, to.Tensor]):
    if type(data) == to.Tensor:
        return data.detach().cpu().numpy()
    else:
        return data


class matrix_columns_as_image_subplot_animations(animation.TimedAnimation):
    """Animated grid of image subplots with one animation frame per step of ``values``."""

    def __init__(
        self,
        values: Union[np.ndarray, to.Tensor],
        subplots_yx: Tuple[int, int],
        height_width: Tuple[int, int] = None,
        figsize: Tuple[float, float] = None,
        clim: Tuple[float, float] = None,
        cmap: str = "jet",
        sym_clim: bool = True,
        global_clim: bool = True,
        fname: str = None,
        out_dir: str = None,
        ms_per_step: int = 100,
        print_steps: bool = False,
        repeat: bool = True,
        pxsize: float = 0.05,
    ):
        """
        Create an animated figure of image subplots from a 3- or 4-dim array.

        ``values`` is indexed as (step, pixel, subplot) or (step, pixel, subplot, 3).
        Each animation frame shows one step. In every subplot the pixel axis is
        reshaped into a height_width[0] x height_width[1] image, which is an RGB image
        for 4-dim input. Color limits are derived from the last step.

        :param values: Input array or torch.Tensor, 3- or 4-dim (see above).
        :param subplots_yx: Number of subplots in y- and x-direction. If the grid holds
            fewer cells than ``values`` has subplots, only the first ones are shown and
            "_sel" is appended to ``fname``.
        :param height_width: Vertical and horizontal image size of each subplot; if None
            the number of pixels must be a perfect square.
        :param figsize: Figure size; defaults to ``pxsize`` times the image grid size.
        :param clim: Color limits; overrides the limits derived from the data.
        :param cmap: Color map (uint8 input is always drawn with "gray").
        :param sym_clim: Use color limits symmetric around zero.
        :param global_clim: Use the same color limits across subplots (forced for uint8).
        :param fname: Name of the figure (window title); required when ``out_dir`` is set.
        :param out_dir: If set, the animation is written to ``out_dir/<fname>.mp4``
            using the "ffmpeg" writer.
        :param ms_per_step: Duration per frame in milliseconds.
        :param print_steps: Print the step index whenever a frame is drawn.
        :param repeat: Repeat the animation.
        :param pxsize: Figure size per image pixel, used when ``figsize`` is None.

        Example:
        values = np.ones((10, 25, 9)) * np.arange(9)[None, None, :] * np.linspace(0, 1, 10)[:, None, None]
        anim = matrix_columns_as_image_subplot_animations(values, (1, 9))
        plt.show()
        """
        if out_dir is not None:
            # Validate before creating a figure so a bad call leaves nothing behind.
            assert fname is not None, "Figure name not defined"

        values = _values_as_np(values)  # TODO: Support torch.Tensors in all places

        # dimensions: values is (step, pixel, subplot[, 3])
        no_steps, no_pixel, no_subplots = values.shape[:3]
        no_show, suf = _get_to_show(subplots_yx, no_subplots)
        observed_resh = _get_obs_resh(values, no_subplots)
        height, width = _get_height_width(height_width, no_pixel)

        if values.dtype == "uint8":
            # Binary images: one fixed 0..1 gray scale shared by all subplots.
            val_min_last_ind = val_max_last_ind = None
            val_min_last_glob, val_max_last_glob = 0, 1
            global_clim = True
            cmap = "gray"
        else:
            (
                val_min_last_ind,
                val_max_last_ind,
                val_min_last_glob,
                val_max_last_glob,
            ) = self._last_step_limits(values, observed_resh, sym_clim)

        self.values = values
        self.no_subplots = no_subplots
        self.subplots_yx = subplots_yx
        self.no_show = no_show
        self.height = height
        self.width = width
        self.no_steps = no_steps
        self.curr_step = 0
        self.print_steps = print_steps

        n_rows, n_cols = subplots_yx[0], subplots_yx[1]
        # Data-space (vmin, vmax) per axis; rows of hidden axes stay uninitialized.
        # For 4-dim input these are the clipping limits applied before rescaling.
        self.my_clims = np.empty((n_rows * n_cols, 2))
        if figsize is None:
            figsize = (pxsize * width * n_cols, pxsize * height * n_rows)
        # squeeze=False always yields a 2-D array of Axes (also for a 1x1 grid).
        self.f, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        self.axarr = axes.ravel()
        self.spl = {}

        if self.print_steps:
            print("%d of %d" % (self.curr_step, self.no_steps))
        for h, ax in enumerate(self.axarr):
            if h >= no_show:
                ax.axis("off")
                continue
            ax.tick_params(
                axis="both",
                which="both",
                bottom=False,
                top=False,
                labelbottom=False,
                left=False,
                right=False,
                labelleft=False,
            )

            if clim is not None:
                self.my_clims[h] = clim
            elif global_clim:
                self.my_clims[h] = (val_min_last_glob, val_max_last_glob)
            else:
                # Per-subplot limits (already symmetrized when sym_clim is set).
                self.my_clims[h] = (val_min_last_ind[h], val_max_last_ind[h])

            image, (vmin, vmax) = self._frame_image(0, h)
            self.spl["i{0}".format(h)] = ax.imshow(
                image,
                interpolation="none",
                animated=True,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
            )

        self.f.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)

        if fname is not None:
            fname = "%s%s" % (fname, suf)
            self.f.canvas.manager.set_window_title(fname)

        super().__init__(self.f, interval=ms_per_step, blit=True, repeat=repeat)

        if out_dir is not None:
            os.makedirs(out_dir, exist_ok=True)
            out_path = os.path.join(out_dir, "%s.mp4" % fname)
            self.save(out_path, writer="ffmpeg")
            print("Wrote " + out_path)

    @staticmethod
    def _last_step_limits(values, observed_resh, sym_clim):
        """Return color limits computed from the last step of ``values``.

        :return: ``(ind_min, ind_max, glob_min, glob_max)``. The per-subplot limits are
            arrays of length ``observed_resh.shape[0]`` and the global limits are
            scalars. With ``sym_clim`` each pair becomes ``(-m, m)`` with
            ``m = max(|min|, max)``. NaNs are ignored throughout.
        """
        ind_min = np.nanmin(observed_resh, axis=1)
        ind_max = np.nanmax(observed_resh, axis=1)
        if sym_clim:
            ind_max = np.fmax(np.abs(ind_min), ind_max)
            ind_min = -ind_max

        last = values[-1]
        last = last[~np.isnan(last)]
        glob_max = np.nanmax(last)
        glob_min = np.nanmin(last)
        if sym_clim:
            glob_max = np.nanmax([np.abs(glob_min), glob_max])
            glob_min = -glob_max
        return ind_min, ind_max, glob_min, glob_max

    def _frame_image(self, step, h):
        """Return ``(image, (vmin, vmax))`` for subplot ``h`` at time index ``step``.

        3-dim values give a (height, width) image shown with the subplot's data
        limits. 4-dim values give a (height, width, 3) RGB image that is clipped to
        the data limits and then min-max rescaled into [0, 1], shown with (0, 1).
        """
        vmin, vmax = self.my_clims[h]
        if np.ndim(self.values) == 3:
            return self.values[step, :, h].reshape(self.height, self.width), (vmin, vmax)
        rgb = self.values[step, :, h, :].reshape(self.height, self.width, 3)
        rgb = np.minimum(np.maximum(rgb, vmin), vmax)
        return _rescale(rgb, (0.0, 1.0)), (0.0, 1.0)

    def _draw_frame(self, framedata):
        """Show time step ``framedata`` in every visible subplot.

        ``framedata`` is an index yielded by ``new_frame_seq``. Deriving the step
        from it keeps the displayed step aligned with the frame sequence in both
        interactive playback and ``save``.
        """
        step = int(framedata)
        self.curr_step = step + 1
        if self.print_steps:
            print("%d of %d" % (self.curr_step, self.no_steps))
        self._drawn_artists = []
        for h in range(self.no_show):
            image, _ = self._frame_image(step, h)
            artist = self.spl["i{0}".format(h)]
            artist.set_array(image)
            self._drawn_artists.append(artist)

    def new_frame_seq(self):
        """Yield every time step exactly once: 0, 1, ..., no_steps - 1."""
        return iter(range(self.no_steps))


if __name__ == "__main__":
    # Demo: 9 subplots of 5x5 pixels, each of constant value k (k = 0..8),
    # ramped linearly from 0 to k over 10 steps.
    levels = np.arange(9)[None, None, :]  # varies along the subplot axis
    ramp = np.linspace(0, 1, 10)[:, None, None]  # varies along the step axis
    values = np.ones((10, 25, 9)) * levels * ramp
    # Keep a reference so the animation is not garbage-collected before show().
    a = matrix_columns_as_image_subplot_animations(values, (1, 9))
    plt.show()
