# -*- coding: utf-8 -*-
# Copyright (C) 2020 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import numpy as np
import torch as to
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from typing import Tuple, Dict, Any, Union
import os


def _get_rcparams(fontsize: int, ticksize: int, use_tex_fonts: bool):
    params: Dict[str, Any] = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "xtick.labelsize": ticksize,
    }
    if use_tex_fonts:
        params = {
            "text.usetex": True,
            "font.family": "sans-serif",
        }
    return params


def _get_ylim_of_last_perc(ydata: Union[np.ndarray, to.Tensor], factor: float):
    """Return the y-limits matplotlib autoscales to for the trailing part of ``ydata``.

    A throwaway figure is used so that matplotlib's own autoscaling rules decide the
    limits; that figure is closed again before returning.

    :param ydata: 1-dim sequence of y values.
    :param factor: Fraction of trailing samples to consider; ``int(factor * len(ydata))``
        samples are taken from the end (a count of 0 selects every sample).
    :return: ``(ymin, ymax)`` as reported by ``Axes.get_ylim``.
    """
    scratch_fig, scratch_ax = plt.subplots()
    n_tail = int(factor * len(ydata))
    scratch_ax.plot(ydata[-n_tail:])
    limits = scratch_ax.get_ylim()
    plt.close(scratch_fig)
    return limits


class single_lineplot:
    """Line plot of a 1-dim series that can be refreshed in place (e.g. once per iteration).

    The series is drawn as a solid black line. If ``ydata_target`` is given, a dashed
    black line is drawn at that constant value.
    """

    def __init__(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        ylabel: str,
        xlabel: str,
        figure_name: str,
        xdata: np.ndarray = None,
        ydata_target: float = None,
        xlim: Tuple[float, float] = None,
        ylim: Tuple[float, float] = None,
        focus_last_perc_factor: float = None,
        figsize: Tuple[int, int] = None,
        fontsize: int = 18,
        ticksize: int = 10,
        use_tex_fonts: bool = False,
        output_directory: str = None,
        transparent: bool = False,
        output_file_type: str = "png",
    ):
        """
        Simple lineplot of ydata. Optional dashed straight line at ydata_target.

        Note: matplotlib's global rcParams are updated (font sizes, optional TeX).

        :param ydata: 1-dim array or tensor containing ydata to be plotted
        :param ylabel: y-axis label
        :param xlabel: x-axis label
        :param figure_name: Figure name (window title and file stem when saving).
        :param xdata: Defaults to np.arange(len(ydata))
        :param ydata_target: Target value (scalar)
        :param xlim: x-axis limits
        :param ylim: y-axis limits
        :param focus_last_perc_factor: If given, fit the initial y-axis limits to the last
            ``focus_last_perc_factor`` fraction of samples (e.g. 0.3 for the last 30 percent)
        :param figsize: Figure size (if not specified matplotlib's default is used)
        :param fontsize: Fontsize of axis labels.
        :param ticksize: Fontsize of x- and y-axis ticks.
        :param use_tex_fonts: Set rcParams to use LaTeX fonts
        :param output_directory: Directory to save the figure, e.g. as .png file
            (created if it does not exist).
        :param transparent: Store figure with transparency.
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        assert np.ndim(ydata) == 1, "ydata must be a 1-dim np.array"
        plt.rcParams.update(_get_rcparams(fontsize, ticksize, use_tex_fonts))
        self.ydata_target = ydata_target
        self.xlim = xlim
        self.ylim = ylim
        self.default_extend_ylim = 0.05  # fraction of the data range added above and below
        self.figure_name = figure_name

        self.fig, self.ax = plt.subplots(figsize=figsize, facecolor="white", edgecolor="none")

        self.update(
            ydata,
            xdata,
            focus_last_perc_factor=focus_last_perc_factor,
            figure_name=figure_name,
        )

        self.ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        self.ax.set_ylabel(ylabel, fontsize=fontsize)
        self.ax.set_xlabel(xlabel, fontsize=fontsize)
        if ticksize is not None:
            # Use the owned axes rather than whatever pyplot considers "current".
            self.ax.tick_params(labelsize=ticksize)
        self.ax.set_frame_on(True)
        self.fig.tight_layout()

        if output_directory is not None:
            self._save(output_directory, figure_name, transparent, output_file_type)

    @staticmethod
    def _to_numpy(data: Union[np.ndarray, to.Tensor]) -> np.ndarray:
        """Return ``data`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if isinstance(data, to.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _padded_range(values: np.ndarray, extend: float) -> Tuple[float, float]:
        """Return ``(min, max)`` of ``values`` (NaNs ignored), widened by ``extend`` * range on both sides."""
        lo = float(np.nanmin(values))
        hi = float(np.nanmax(values))
        pad = (hi - lo) * extend
        return lo - pad, hi + pad

    def _update_limits(
        self,
        xdata_np: np.ndarray,
        ydata_np: np.ndarray,
        focus_last_perc_factor: float = None,
    ):
        """Set x/y axis limits from the current data, honoring user-provided xlim/ylim."""
        if len(ydata_np) == 0:
            return  # nothing to fit the limits to; keep the current view

        # Include the target value so the dashed target line stays in view.
        all_ydata = (
            ydata_np
            if self.ydata_target is None
            else np.append(ydata_np, float(self.ydata_target))
        )

        if self.xlim is not None:
            self.ax.set_xlim(self.xlim)
        elif xdata_np[0] != xdata_np[-1]:
            # A single x value would give a singular range; leave the limits untouched then.
            self.ax.set_xlim((xdata_np[0], xdata_np[-1]))

        if focus_last_perc_factor is not None:
            ylim = _get_ylim_of_last_perc(all_ydata, focus_last_perc_factor)
        elif self.ylim is not None:
            ylim = self.ylim
        else:
            ylim = self._padded_range(all_ydata, self.default_extend_ylim)
            if not np.all(np.isfinite(ylim)):
                return  # all-NaN or infinite data: matplotlib rejects such limits
        if ylim[0] != ylim[1]:
            self.ax.set_ylim(ylim)

    def _save(
        self,
        output_directory: str,
        figure_name: str,
        transparent: bool,
        output_file_type: str,
    ):
        """Save the figure to ``<output_directory>/<figure_name>.<output_file_type>``."""
        os.makedirs(output_directory, exist_ok=True)
        path = f"{output_directory}/{figure_name}.{output_file_type}"
        self.fig.savefig(path, transparent=transparent, bbox_inches="tight")
        print(f"Wrote {path}")

    def update(
        self,
        ydata: Union[np.ndarray, to.Tensor],
        xdata: Union[np.ndarray, to.Tensor] = None,
        focus_last_perc_factor: float = None,
        figure_name: str = None,
        output_directory: str = None,
        transparent: bool = False,
        output_file_type: str = "png",
    ):
        """Replace the plotted data, refresh the axis limits and optionally save the figure.

        :param ydata: 1-dim array or tensor with the new y values.
        :param xdata: 1-dim x values of the same length as ydata;
            defaults to np.arange(len(ydata)).
        :param focus_last_perc_factor: If given, fit the y-axis limits to the last
            ``focus_last_perc_factor`` fraction of samples (target value included);
            takes precedence over the ylim given at construction.
        :param figure_name: Window title and file stem for this update; the name given
            at construction is used for saving when None.
        :param output_directory: If given, save the figure there (created if missing).
        :param transparent: Store figure with transparency.
        :param output_file_type: Figure extension (jpg, png, ...).
        """
        if xdata is not None:
            assert np.ndim(xdata) == 1, "xdata must be None or a 1-dim np.array"
            assert len(xdata) == len(ydata), "xdata must be of same length as ydata"

        ydata_np = self._to_numpy(ydata)
        xdata_np = np.arange(len(ydata_np)) if xdata is None else self._to_numpy(xdata)

        if not hasattr(self, "l1"):
            (self.l1,) = self.ax.plot(
                xdata_np,
                ydata_np,
                color="k",
                linestyle="-",
                linewidth=1,
            )
        else:
            self.l1.set_data(xdata_np, ydata_np)

        if self.ydata_target is not None:
            target_ydata = np.full(ydata_np.shape, float(self.ydata_target))
            if not hasattr(self, "l2"):
                (self.l2,) = self.ax.plot(
                    xdata_np,
                    target_ydata,
                    color="k",
                    linestyle="--",
                    linewidth=1,
                )
            else:
                self.l2.set_data(xdata_np, target_ydata)

        self._update_limits(xdata_np, ydata_np, focus_last_perc_factor)

        self.fig.tight_layout()
        if figure_name is not None:
            manager = self.fig.canvas.manager
            if manager is not None:  # figures not created via pyplot have no manager
                manager.set_window_title(figure_name)
        _figure_name = self.figure_name if figure_name is None else figure_name

        if output_directory is not None:
            self._save(output_directory, _figure_name, transparent, output_file_type)


if __name__ == "__main__":
    # Demo: animate a log-shaped curve growing point by point towards its final value.
    n_values = 100
    pause_seconds = 0.001
    focus_factor = None  # e.g. 0.7 to fit the y-axis to the last 70 percent of points
    curve = to.from_numpy(np.log(np.linspace(0.001, 0.9, n_values)))

    def _refresh_display():
        """Redraw the figure and briefly hand control to the GUI event loop."""
        plt.draw()
        plt.show()
        plt.pause(pause_seconds)

    plt.ion()
    lineplot = single_lineplot(
        ydata=curve[:1],
        ydata_target=curve[-1],
        ylabel=r"$\mathcal{F}(\mathcal{K},\Theta)\,/\,N$",
        xlabel=r"$\mathrm{Iteration}$",
        figure_name="step0",
    )
    _refresh_display()
    for step in range(n_values):
        lineplot.update(
            ydata=curve[: step + 1],
            figure_name=f"step{step}",
            focus_last_perc_factor=focus_factor,
        )
        _refresh_display()
    plt.ioff()
