# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import torch as to
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from typing import Tuple, Dict, Any, Union
import os


def _get_rcparams(fontsize: int, ticksize: int, use_tex_fonts: bool):
    params: Dict[str, Any] = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "xtick.labelsize": ticksize,
    }
    if use_tex_fonts:
        params = {
            "text.usetex": True,
            "font.family": "sans-serif",
        }
    return params


class barlike_plot:
    """
    Barplot-like visualization of a 1-dim array that can be updated in place.

    Every (optionally reordered) value is drawn as a single black marker with
    no connecting line at x-position ``1, 2, ..., len(ydata)``. Call
    :meth:`update` to replace the values, e.g. to animate how they evolve.
    """

    def __init__(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        figure_name: str,
        xlabel: str,
        ylabel: str,
        inds_sorted: Union[np.ndarray, to.Tensor] = None,
        ylim: Tuple[float, float] = None,
        text: str = None,
        figsize: Tuple[int, int] = None,
        fontsize: int = 18,
        ticksize: int = 10,
        use_tex_fonts: bool = False,
        marker: str = ".",
        markersize: int = 2,
        output_directory: str = None,
        transparent: bool = False,
        output_file_type: str = "png",
    ):
        """
        Barplot-like visualization of values in ydata

        :param ydata: Values to plot (1-dim array or tensor)
        :param figure_name: Figure name (window title and output file stem)
        :param xlabel: x-axis label
        :param ylabel: y-axis label
        :param inds_sorted: Reorder values according to these indices
        :param ylim: Fixed y-axis limits; if None, limits follow the data range
            extended by 5% on each side
        :param text: Optional annotation drawn inside the axes (upper right)
        :param figsize: Figure size (if not specified matplotlib's default is used)
        :param fontsize: Fontsize of axis labels.
        :param ticksize: Fontsize of x- and y-axis ticks.
        :param use_tex_fonts: Set rcParams to use LaTeX fonts
        :param marker: Marker style
        :param markersize: Marker size
        :param output_directory: Directory to save the figure, e.g. as .png file
            (created if missing).
        :param transparent: Store figure with transparency.
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        assert np.ndim(ydata) == 1, "ydata must be a 1-dim np.array"
        self.ylim = ylim
        # Fraction of the data range added below/above the data when ylim is not fixed.
        self.default_extend_ylim = 0.05
        self.fontsize = fontsize
        self.inds_sorted = inds_sorted
        self.marker = marker
        self.markersize = markersize
        self.xdata = np.arange(len(ydata)) + 1
        self.figure_name = figure_name
        plt.rcParams.update(_get_rcparams(fontsize, ticksize, use_tex_fonts))

        self.fig, self.ax = plt.subplots(figsize=figsize, facecolor="white", edgecolor="none")

        self.update(ydata, text, figure_name=figure_name)

        self.ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        self.ax.set_xlabel(xlabel, fontsize=fontsize)
        self.ax.set_ylabel(ylabel, fontsize=fontsize)
        if ticksize is not None:
            # Target this plot's axes, not whatever pyplot considers "current".
            self.ax.tick_params(labelsize=ticksize)
        self.ax.set_frame_on(True)
        self.fig.canvas.manager.set_window_title(figure_name)
        self.fig.tight_layout()

        if output_directory is not None:
            self._save(output_directory, figure_name, output_file_type, transparent)

    @staticmethod
    def _to_numpy(data) -> np.ndarray:
        """Return *data* as a numpy array; torch tensors are detached and moved to CPU first."""
        if isinstance(data, to.Tensor):
            # .numpy() alone fails for tensors that require grad or live on a GPU.
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _padded_limits(ydata: np.ndarray, extend: float) -> Tuple[float, float]:
        """Return (min, max) of *ydata* (ignoring NaN), each widened by ``extend`` times the range."""
        lo, hi = float(np.nanmin(ydata)), float(np.nanmax(ydata))
        pad = (hi - lo) * extend
        return lo - pad, hi + pad

    def _apply_ylim(self, ydata: np.ndarray):
        """Set the y-limits: the fixed ``self.ylim`` if given, else the padded data range."""
        if self.ylim is not None:
            ylim = self.ylim
        else:
            if ydata.size == 0:
                return  # nothing to scale to
            ylim = self._padded_limits(ydata, self.default_extend_ylim)
            if not np.all(np.isfinite(ylim)):
                return  # all-NaN/inf data; set_ylim would raise
        # Equal limits (e.g. constant data) would make matplotlib warn about a singular range.
        if ylim[0] != ylim[1]:
            self.ax.set_ylim(ylim)

    def _save(self, output_directory: str, figure_name: str, output_file_type: str, transparent: bool = None):
        """Write the figure to ``<output_directory>/<figure_name>.<output_file_type>``, creating the directory if needed."""
        os.makedirs(output_directory, exist_ok=True)
        path = os.path.join(output_directory, f"{figure_name}.{output_file_type}")
        # transparent=None lets matplotlib fall back to rcParams["savefig.transparent"].
        self.fig.savefig(path, transparent=transparent, bbox_inches="tight")
        print(f"Wrote {path}")

    def update(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        text: str = None,
        figure_name: str = None,
        output_directory: str = None,
        output_file_type: str = "png",
        transparent: bool = None,
    ):
        """
        Replace the plotted values (and optionally the annotation/window title).

        :param ydata: New values to plot (same length as at construction)
        :param text: New annotation text; unchanged if None
        :param figure_name: New window title and output file stem for this call;
            defaults to the name given at construction
        :param output_directory: If given, save the figure there (created if missing)
        :param output_file_type: Figure extension (jpg, png, ...)
        :param transparent: Store figure with transparency (None: matplotlib's rcParams default)
        """
        ydata = self._to_numpy(ydata)
        if self.inds_sorted is not None:
            ydata = ydata[self._to_numpy(self.inds_sorted)]

        if not hasattr(self, "l"):
            (self.l,) = self.ax.plot(
                self.xdata,
                ydata,
                color="k",
                linestyle="None",
                marker=self.marker,
                markersize=self.markersize,
            )
        else:
            self.l.set_ydata(ydata)
        # Apply limits on every call, including the first, so a constructor-supplied ylim takes effect.
        self._apply_ylim(ydata)

        if text is not None:
            if not hasattr(self, "t"):
                self.t = self.ax.text(
                    0.8,
                    0.86,
                    text,
                    horizontalalignment="center",
                    verticalalignment="center",
                    transform=self.ax.transAxes,
                    fontdict={
                        "family": "serif",
                        "color": "black",
                        "weight": "normal",
                        "size": self.fontsize,
                    },
                )
            else:
                self.t.set_text(text)

        self.fig.tight_layout()
        if figure_name is not None:
            self.fig.canvas.manager.set_window_title(figure_name)
        _figure_name = self.figure_name if figure_name is None else figure_name

        if output_directory is not None:
            self._save(output_directory, _figure_name, output_file_type, transparent)


if __name__ == "__main__":
    # Demo: show five random 500-element value vectors one after another,
    # all reordered so that the last vector appears in descending order.
    pause_seconds = 0.5
    n_steps, n_values = 5, 500
    samples = to.rand((n_steps, n_values))
    descending = to.argsort(samples[-1]).flip(dims=(0,))
    caption = r"$\sum_{h}\,\pi_h = {%.2f}$"

    def _refresh():
        # Redraw the interactive figure and give the GUI event loop time to render it.
        plt.draw()
        plt.show()
        plt.pause(pause_seconds)

    plt.ion()
    barplot = barlike_plot(
        ydata=samples[0],
        figure_name="step0",
        xlabel=r"$h$",
        ylabel=r"$\pi_h$",
        text=caption % samples[0].sum(),
        inds_sorted=descending,
    )
    _refresh()
    for step, sample in enumerate(samples[1:], start=1):
        barplot.update(
            ydata=sample,
            figure_name=f"step{step}",
            text=caption % sample.sum(),
        )
        _refresh()
    plt.ioff()
    plt.show()
