# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import torch as to
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from typing import Tuple, Union, Dict, Any, List
import os


def _check_ydata(ydata: Union[np.ndarray, to.Tensor]):
    if type(ydata) == np.ndarray:
        assert np.ndim(ydata) == 2, "ydata must be a 2-dim np.array"
    elif type(ydata) == to.Tensor:
        assert ydata.dim() == 2, "ydata must be a 2-dim to.Tensor"


def _get_rcparams(fontsize: int, ticksize: int, use_tex_fonts: bool):
    params: Dict[str, Any] = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "xtick.labelsize": ticksize,
    }
    if use_tex_fonts:
        params = {
            "text.usetex": True,
            "font.family": "sans-serif",
        }
    return params


def _get_target_to_plot(
    ydata_target: Union[None, np.ndarray, to.Tensor],
    ydata: Union[np.ndarray, to.Tensor],
):
    if ydata_target is None:
        t = None
    else:
        if type(ydata_target) == np.ndarray:
            assert np.ndim(ydata_target) == 1, "ydata_target must be a 1-dim np.ndarray"
            assert type(ydata) == np.ndarray
            t = np.ones((ydata.shape[0], len(ydata_target))) * ydata_target
        elif type(ydata_target) == to.Tensor:
            assert ydata_target.dim() == 1, "ydata_target must be a 1-dim to.Tensor"
            assert type(ydata) == to.Tensor
            t = to.ones((ydata.shape[0], len(ydata_target))) * ydata_target
        else:
            t = None
    return t


class multi_lineplots:
    """Line plot of the columns of a 2-dim array/tensor that can be updated in place."""

    def __init__(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        ylabel: str,
        xlabel: str,
        figure_name: str,
        ydata_target: Union[np.ndarray, to.Tensor] = None,
        labels: Union[Tuple[str, ...], List[str]] = None,
        xlim: Tuple[float, float] = None,
        ylim: Tuple[float, float] = None,
        figsize: Tuple[int, int] = None,
        fontsize: int = 18,
        ticksize: int = 10,
        use_tex_fonts: bool = False,
        output_directory: str = None,
        transparent: bool = False,
        output_file_type: str = "png",
    ):
        """
        Simple lineplots of the columns of ydata. Optional dashed straight lines at ydata_target.

        :param ydata: Values to plot (2-dim tensor of shape (no_steps, no_lines)).
        :param ylabel: y-axis label
        :param xlabel: x-axis label
        :param ydata_target: Target values (1-dim tensor of length no_lines).
        :param labels: Line labels
        :param xlim: x-axis limits (kept fixed across updates if given)
        :param ylim: y-axis limits (kept fixed across updates if given)
        :param figsize: Figure size (if not specified matplotlib's default is used)
        :param fontsize: Fontsize of axis labels.
        :param ticksize: Fontsize of x- and y-axis ticks.
        :param use_tex_fonts: Set rcParams to use LaTeX fonts
        :param figure_name: Figure name.
        :param output_directory: Directory to save the figure, e.g. as .png file.
        :param transparent: Store figure with transparency.
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        _check_ydata(ydata)
        # NOTE: updates matplotlib's global rcParams, affecting later figures too.
        plt.rcParams.update(_get_rcparams(fontsize, ticksize, use_tex_fonts))
        self.ydata_target = ydata_target
        self.xlim = xlim
        self.ylim = ylim
        # Fraction of the data range added above and below when auto-scaling y.
        self.default_extend_ylim = 0.05
        self.figure_name = figure_name

        self.fig, self.ax = plt.subplots(figsize=figsize, facecolor="white", edgecolor="none")

        # The first call creates the line artists (self.lines / self.lines_target).
        self.update(ydata, figure_name=figure_name)

        if labels is not None:
            # Target this axes explicitly instead of pyplot's "current" axes.
            self.ax.legend(self.lines, labels)

        self.ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        self.ax.set_ylabel(ylabel, fontsize=fontsize)
        self.ax.set_xlabel(xlabel, fontsize=fontsize)
        if ticksize is not None:
            self.ax.tick_params(labelsize=ticksize)
        self.ax.set_frame_on(True)
        self.fig.tight_layout()

        if output_directory is not None:
            self._save(output_directory, figure_name, transparent, output_file_type)

    @staticmethod
    def _to_numpy(x: Union[np.ndarray, to.Tensor]) -> np.ndarray:
        """Return x as a NumPy array; tensors are detached and moved to CPU first."""
        if isinstance(x, to.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _save(
        self,
        output_directory: str,
        figure_name: str,
        transparent: bool,
        output_file_type: str,
    ):
        """Write this figure to <output_directory>/<figure_name>.<output_file_type>."""
        os.makedirs(output_directory, exist_ok=True)
        path = f"{output_directory}/{figure_name}.{output_file_type}"
        self.fig.savefig(path, transparent=transparent, bbox_inches="tight")
        print(f"Wrote {path}")

    def _rescale(
        self,
        xdata: np.ndarray,
        ydata_np: np.ndarray,
        target_np: Union[None, np.ndarray],
    ):
        """Fit the axis limits to the current data, unless the user fixed them."""
        if xdata.size == 0:
            return  # nothing plotted yet; keep the current limits

        xlim = self.xlim
        if xlim is None and xdata[0] != xdata[-1]:
            xlim = (xdata[0], xdata[-1])
        if xlim is not None:
            self.ax.set_xlim(xlim)

        if self.ylim is not None:
            ylim = self.ylim
        else:
            all_ydata = (
                ydata_np if target_np is None else np.concatenate((ydata_np, target_np), axis=1)
            )
            lo, hi = float(np.nanmin(all_ydata)), float(np.nanmax(all_ydata))
            if not (np.isfinite(lo) and np.isfinite(hi)):
                return  # matplotlib rejects NaN/inf limits
            pad = (hi - lo) * self.default_extend_ylim
            ylim = (lo - pad, hi + pad)
        # Identical bounds would produce a singular (zero-height) y-range.
        if ylim[0] != ylim[1]:
            self.ax.set_ylim(ylim)

    def update(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        figure_name: str = None,
        output_directory: str = None,
        transparent: bool = False,
        output_file_type: str = "png",
    ):
        """
        Redraw the lines with new data (number of lines must stay the same).

        :param ydata: Values to plot (2-dim, shape (no_steps, no_lines)).
        :param figure_name: New window title and file name; defaults to the initial name.
        :param output_directory: If given, save the figure there (created if missing).
        :param transparent: Store figure with transparency.
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        _check_ydata(ydata)
        target_to_plot = _get_target_to_plot(self.ydata_target, ydata)
        ydata_np = self._to_numpy(ydata)
        target_np = None if target_to_plot is None else self._to_numpy(target_to_plot)

        if not hasattr(self, "lines"):
            # First call: create the line artists and let matplotlib autoscale,
            # but honor user-provided limits from the start.
            self.lines = self.ax.plot(
                ydata_np,
                color=None,
                linestyle="-",
                linewidth=1,
            )
            if target_np is not None:
                self.ax.set_prop_cycle(None)  # reset color cycle so targets match their lines
                self.lines_target = self.ax.plot(
                    target_np,
                    color=None,
                    linestyle="--",
                    linewidth=1,
                )
            if self.xlim is not None:
                self.ax.set_xlim(self.xlim)
            if self.ylim is not None:
                self.ax.set_ylim(self.ylim)
        else:
            xdata = np.arange(ydata_np.shape[0])
            for i, line in enumerate(self.lines):
                line.set_data(xdata, ydata_np[:, i])
                if target_np is not None and i < len(self.lines_target):
                    self.lines_target[i].set_data(xdata, target_np[:, i])
            self._rescale(xdata, ydata_np, target_np)

        self.fig.tight_layout()
        if figure_name is not None:
            # Figures not created through pyplot have no manager.
            manager = self.fig.canvas.manager
            if manager is not None:
                manager.set_window_title(figure_name)
        _figure_name = figure_name if figure_name is not None else self.figure_name

        if output_directory is not None:
            self._save(output_directory, _figure_name, transparent, output_file_type)


if __name__ == "__main__":
    # Demo: animate four noisy series converging around their target values.
    plt_pause = 0.001
    no_vals = 100
    ydata_target = np.linspace(0, no_vals, 4)
    # One row of no_vals noisy samples per target value, centred on that target.
    ydata = np.random.normal(
        scale=5,
        loc=np.tile(ydata_target[:, None], (1, no_vals)),
    )
    ydata = to.from_numpy(ydata.T)  # (no_vals, no_lines): one column per line
    ydata_target = to.from_numpy(ydata_target)

    plt.ion()
    lineplots = multi_lineplots(
        ydata=ydata,
        ydata_target=ydata_target,
        labels=[f"{x}" for x in range(len(ydata_target))],
        ylabel=r"$\mathcal{F}(\mathcal{K},\Theta)\,/\,N$",
        xlabel=r"$\mathrm{Iteration}$",
        figure_name="step0",
    )
    plt.draw()
    plt.show()
    plt.pause(plt_pause)
    # Grow the displayed history step by step; the last frame (i = no_vals - 1)
    # already shows all no_vals samples, so stop there to avoid a duplicate frame.
    for i in range(1, no_vals):
        lineplots.update(
            ydata=ydata[: (i + 1)],
            figure_name=f"step{i}",
        )
        plt.draw()
        plt.show()
        plt.pause(plt_pause)
    plt.ioff()
    plt.show()
