# -*- coding: utf-8 -*-

import numpy as np
import torch as to
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from typing import Tuple, Union, List, Dict, Any, Optional
import os


def _get_rcparams(fontsize: int, ticksize: int, use_tex_fonts: bool):
    params: Dict[str, Any] = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "xtick.labelsize": ticksize,
    }
    if use_tex_fonts:
        params = {
            "text.usetex": True,
            "font.family": "sans-serif",
        }
    return params


def _get_ylim_of_last_x(ydata: Union[np.ndarray, to.Tensor], focus_last_x_factor: float):
    """Return the y-limits matplotlib would autoscale to for the tail of *ydata*.

    :param ydata: Sequence of values to inspect (list, ndarray or tensor that
        matplotlib can plot).
    :param focus_last_x_factor: Fraction of trailing points to consider,
        e.g. 0.2 for the last 20 %; values >= 1 use the whole series.
    :return: ``(bottom, top)`` as returned by ``Axes.get_ylim``.
    """
    from matplotlib.figure import Figure

    # Keep at least the final point: int() truncation yields 0 for small
    # factors, and ``ydata[-0:]`` would select the *whole* series instead.
    n_last = max(1, int(focus_last_x_factor * len(ydata)))
    # A bare Figure is not registered with pyplot: it does not become the
    # current figure, cannot pop up a window in interactive mode, and is
    # garbage-collected without an explicit plt.close().
    ax = Figure().subplots()
    ax.plot(ydata[-n_last:])
    return ax.get_ylim()


class twin_lineplot:
    """Two line plots sharing one x-axis, each with its own y-axis.

    Data1 is drawn in black against the left y-axis (``ax1``); Data2 is drawn
    in blue against a twin right y-axis (``ax2``). Call :meth:`update` to
    replace both curves' data in place.
    """

    def __init__(
        self,
        ydata1: Union[List[int], to.Tensor, np.ndarray],
        ydata2: Union[List[int], to.Tensor, np.ndarray],
        xdata1: Union[List[int], to.Tensor, np.ndarray] = None,
        xdata2: Union[List[int], to.Tensor, np.ndarray] = None,
        figure_name: Optional[str] = None,
        xlabel: str = "",
        ylabel1: str = "",
        ylabel2: str = "",
        linestyle1: str = "--",
        linestyle2: str = "--",
        marker1: str = "*",
        marker2: str = "o",
        markersize: int = 1,
        focus1_last_x: float = 1.0,
        focus2_last_x: float = 1.0,
        fontsize: int = 18,
        ticksize: int = 10,
        xlim: Tuple[float, float] = None,
        use_tex_fonts: bool = False,
        output_directory: str = None,
        output_file_type: str = "png",
        figsize: Tuple[int, int] = None,
        dpi: int = 96,
    ):
        """
        Two lineplots with separate y-axis.

        :param ydata1: Data1 values to plot (1-dim sequence of length no_iterations).
        :param ydata2: Data2 values to plot (1-dim sequence of length no_iterations).
        :param xdata1: Data1 iteration numbers to plot. Defaults to np.arange(len(ydata1)).
        :param xdata2: Data2 iteration numbers to plot. Defaults to np.arange(len(ydata2)).
        :param figure_name: Figure name (window title and output file stem).
        :param xlabel: x-axis label
        :param ylabel1: label of Data1 y-axis
        :param ylabel2: label of Data2 y-axis
        :param linestyle1: Line style for Data1 curve
        :param linestyle2: Line style for Data2 curve
        :param marker1: Marker style for Data1 curve
        :param marker2: Marker style for Data2 curve
        :param markersize: Size of markers illustrating Data1 and Data2 values
        :param focus1_last_x: Fit ylim of the Data1 curve to its last fraction of
            points (1.0 = all); None keeps matplotlib's autoscaling.
        :param focus2_last_x: Same as focus1_last_x for the Data2 curve.
        :param fontsize: Fontsize of axis labels.
        :param ticksize: Fontsize of x- and y-axis ticks.
        :param xlim: Limits for x-axis.
        :param use_tex_fonts: Set rcParams to use LaTeX fonts
        :param output_directory: Directory to save the figure, e.g. as .png file
            (created if missing).
        :param output_file_type: Figure extension (jpg, png, ...)
        :param figsize: Figure size (if not specified matplotlib's default is used)
        :param dpi: Resolution in DPI passed to plt.subplots
        :raises ValueError: If output_directory is given without figure_name.
        """
        # Note: this mutates matplotlib's global rcParams.
        plt.rcParams.update(_get_rcparams(fontsize, ticksize, use_tex_fonts))

        self.fig, self.ax1 = plt.subplots(
            figsize=figsize, dpi=dpi, facecolor="white", edgecolor="none"
        )

        xdata1 = np.arange(len(ydata1)) if xdata1 is None else xdata1
        (self.l1,) = self.ax1.plot(
            xdata1,
            ydata1,
            color="k",
            linestyle=linestyle1,
            marker=marker1,
            markersize=markersize,
        )
        if focus1_last_x is not None:
            self.ax1.set_ylim(_get_ylim_of_last_x(ydata1, focus1_last_x))

        self.figure_name = figure_name
        self._set_window_title(figure_name)

        self.ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
        if xlim is not None:
            self.ax1.set_xlim(xlim)
        self.ax1.set_ylabel(ylabel1, fontsize=fontsize)
        self.ax1.set_xlabel(xlabel, fontsize=fontsize)
        self.ax1.tick_params(axis="y", labelcolor="k")
        if ticksize is not None:
            self.ax1.tick_params(labelsize=ticksize)

        self.ax1.set_frame_on(True)

        self.ax2 = self.ax1.twinx()
        # Default Data2's x values from ydata2 only when xdata2 is missing.
        xdata2 = np.arange(len(ydata2)) if xdata2 is None else xdata2
        (self.l2,) = self.ax2.plot(
            xdata2,
            ydata2,
            color="b",
            linestyle=linestyle2,
            marker=marker2,
            markersize=markersize,
        )
        if focus2_last_x is not None:
            self.ax2.set_ylim(_get_ylim_of_last_x(ydata2, focus2_last_x))

        self.ax2.set_ylabel(ylabel2, color="b", fontsize=fontsize)
        self.ax2.tick_params(axis="y", labelcolor="b")
        if ticksize is not None:
            # The twin axis has its own tick labels; size them like ax1's.
            self.ax2.tick_params(labelsize=ticksize)

        if output_directory is not None:
            if figure_name is None:
                raise ValueError("must provide figure_name")
            self._save(output_directory, figure_name, output_file_type)

    def update(
        self,
        ydata1: Union[List[int], to.Tensor, np.ndarray],
        ydata2: Union[List[int], to.Tensor, np.ndarray],
        xdata1: Union[List[int], to.Tensor, np.ndarray] = None,
        xdata2: Union[List[int], to.Tensor, np.ndarray] = None,
        figure_name: Optional[str] = None,
        xlim: Tuple[float, float] = None,
        focus1_last_x: float = 1.0,
        focus2_last_x: float = 1.0,
        output_directory: str = None,
        output_file_type: str = "png",
    ):
        """Replace both curves' data, rescale the axes and optionally save.

        :param ydata1: New Data1 values.
        :param ydata2: New Data2 values.
        :param xdata1: New Data1 x values. Defaults to np.arange(len(ydata1)).
        :param xdata2: New Data2 x values. Defaults to np.arange(len(ydata2)).
        :param figure_name: New window title / output file stem for this call;
            falls back to the name given at construction when saving.
        :param xlim: Explicit x-limits. If None, the combined range of xdata1 and
            xdata2 is used (left unchanged if that range is empty or degenerate).
        :param focus1_last_x: See ``__init__``; None re-autoscales ax1 to the new data.
        :param focus2_last_x: See ``__init__``; None re-autoscales ax2 to the new data.
        :param output_directory: If given, save the figure there (created if missing).
        :param output_file_type: Figure extension (jpg, png, ...)
        :raises ValueError: If saving is requested but no figure name is known.
        """
        xdata1 = np.arange(len(ydata1)) if xdata1 is None else xdata1
        self.l1.set_xdata(xdata1)
        self.l1.set_ydata(ydata1)

        xdata2 = np.arange(len(ydata2)) if xdata2 is None else xdata2
        self.l2.set_xdata(xdata2)
        self.l2.set_ydata(ydata2)

        self._rescale_y(self.ax1, ydata1, focus1_last_x)
        self._rescale_y(self.ax2, ydata2, focus2_last_x)

        if xlim is None:
            xlim = self._x_extent(xdata1, xdata2)
        if xlim is not None:
            self.ax1.set_xlim(xlim)

        self.fig.tight_layout()
        self._set_window_title(figure_name)

        # A figure name is only needed when actually writing a file.
        if output_directory is not None:
            name = figure_name if figure_name is not None else self.figure_name
            if name is None:
                raise ValueError("must provide figure_name")
            self._save(output_directory, name, output_file_type)

    def _set_window_title(self, title: Optional[str]) -> None:
        """Set the GUI window title if a title is given and a manager exists."""
        manager = self.fig.canvas.manager
        if title is not None and manager is not None:
            manager.set_window_title(title)

    def _save(self, output_directory: str, figure_name: str, output_file_type: str) -> None:
        """Write the figure to ``output_directory/figure_name.output_file_type``."""
        os.makedirs(output_directory, exist_ok=True)
        path = os.path.join(output_directory, f"{figure_name}.{output_file_type}")
        self.fig.savefig(path, bbox_inches="tight")
        print(f"Wrote {path}")

    @staticmethod
    def _rescale_y(ax, ydata, focus_last_x: Optional[float]) -> None:
        """Fit *ax*'s y-limits to *ydata* (its tail if *focus_last_x* is set)."""
        if focus_last_x is not None:
            ax.set_ylim(_get_ylim_of_last_x(ydata, focus_last_x))
        else:
            # Line2D.set_*data does not refresh the axes' data limits, so
            # recompute them before re-enabling y autoscaling.
            ax.relim()
            ax.autoscale(enable=True, axis="y")

    @staticmethod
    def _x_extent(*xs) -> Optional[Tuple[float, float]]:
        """Return (min, max) over all finite x values, or None if empty/degenerate."""
        arrays = [np.asarray(x, dtype=float).ravel() for x in xs if len(x) > 0]
        if not arrays:
            return None
        values = np.concatenate(arrays)
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None
        lo, hi = float(values.min()), float(values.max())
        # Equal limits would produce a singular (zero-width) x-axis.
        return (lo, hi) if lo != hi else None


if __name__ == "__main__":
    # Interactive demo: plot two random series, then refresh them five times.
    plt.ion()
    rng = np.random.default_rng()
    demo_plot = twin_lineplot(
        ydata1=rng.random(10),
        ydata2=rng.random(10),
        # A figure name is required by update() in older versions of this class.
        figure_name="twin_lineplot_demo",
        ylabel1=r"$-\mathrm{ELBO}\,/\,N$",
        ylabel2=r"$\mathrm{PSNR}$",
        xlabel="Epoch",
    )

    # plt.pause redraws the active figure, shows it (non-blocking) and runs
    # the GUI event loop, so separate draw()/show() calls are unnecessary.
    plt.pause(1.0)
    for _ in range(5):
        demo_plot.update(ydata1=rng.random(10), ydata2=rng.random(10))
        plt.pause(1.0)

    plt.ioff()
