from typing import List, Optional, Union, Tuple
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mplt
import math
import numpy as np


# Width of the text area in our document, given in inches (about 13.97 cm).
# An earlier document layout used a 13 cm text area, i.e. 5.11811 inches.
TEXT_WIDTH = 5.50107
# Default aspect ratios of figures (3:4 and the wider 9:16)
ASPECT = 3 / 4
ASPECT_WIDE = 9 / 16

# Font sizes in points; SMALL_SIZE and MEDIUM_SIZE are the defaults of
# setup_matplotlib.
SMALL_SIZE = 9
MEDIUM_SIZE = 10
BIGGER_SIZE = 11


def setup_matplotlib(small_size: int = SMALL_SIZE, medium_size: int = MEDIUM_SIZE):
    """Configure global matplotlib rcParams for serif, LaTeX-typeset figures.

    All figure text is rendered through LaTeX (``text.usetex``), so a working
    LaTeX installation providing amsmath, amssymb and lmodern is required.

    :param small_size: Base font size (``font.size``).
    :type small_size: int
    :param medium_size: Font size for titles, axis labels, tick labels and
        legends.
    :type medium_size: int
    """
    # Every rcParam in this tuple uses the medium font size.
    medium_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'xtick.labelsize',
        'ytick.labelsize',
        'legend.fontsize',
        'figure.titlesize',
    )

    settings = {'font.family': 'serif', 'font.size': small_size}
    for key in medium_keys:
        settings[key] = medium_size

    settings['text.usetex'] = True
    settings['mathtext.fontset'] = 'cm'
    settings['mathtext.rm'] = 'serif'
    settings['text.latex.preamble'] = ('\\usepackage{amsmath}\n'
                                       '\\usepackage{amssymb}\n'
                                       '\\usepackage{lmodern}')

    plt.rcParams.update(settings)


def plot_single_mts(mts : pd.DataFrame, columns : int = 1, save : str = '') -> None:
    """Plot every column of a multivariate time series in its own subplot.

    The subplots form a grid with ``columns`` columns and as many rows as
    needed, the same layout as ``plot_multiple_mts``. All subplots share the
    x axis, and sample ``i`` is drawn at x position ``i``.

    :param mts: Multivariate time series; one subplot per column.
    :type mts: pd.DataFrame
    :param columns: Number of subplot columns in the grid.
    :type columns: int
    :param save: If non-empty, path the figure is saved to (300 dpi, tight
        bounding box); otherwise the figure is shown.
    :type save: str
    """
    # At least one row, so matplotlib does not reject an empty grid.
    rows = max(1, int(math.ceil(len(mts.columns) / columns)))

    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid,
    # so flattening works whatever the grid shape (a bare Axes is not iterable).
    _, axs = plt.subplots(rows, columns, sharex=True, squeeze=False)

    # One x position per sample. np.linspace(0, n) would default to 50 points
    # and mismatch every series whose length is not 50.
    x = np.arange(mts.shape[0])

    for ax, col in zip(axs.flatten(), mts.columns):
        ax.plot(x, mts[col])

    if save:
        plt.savefig(save, bbox_inches='tight', pad_inches=0.1, dpi=300)
    else:
        plt.show()
    plt.close()


def plot_multiple_mts(mts : List[pd.DataFrame], columns : int = 1, save : str = '',
                      colors : Optional[List[str]] = None, features : Optional[int] = None) -> None:
    """Overlay several multivariate time series, one subplot per feature.

    Feature names are taken from the first frame in ``mts``. Every frame is
    expected to contain those columns, and each frame is drawn against its
    own sample index, so frames may differ in length. Lines are drawn with
    width 0.1 and all axes are hidden.

    :param mts: Time series to overlay.
    :type mts: List[pd.DataFrame]
    :param columns: Number of subplot columns.
    :type columns: int
    :param save: If non-empty, path the figure is saved to (300 dpi, tight
        bounding box); otherwise the figure is shown.
    :type save: str
    :param colors: Optional line colour per series; must have the same
        length as ``mts``.
    :type colors: Optional[List[str]]
    :param features: Optional upper bound on the number of subplot rows, so
        at most ``features * columns`` features are drawn.
    :type features: Optional[int]
    :raises AssertionError: if ``colors`` is given with a length other than
        ``len(mts)``.
    """
    reference = mts[0]

    rows = int(math.ceil(len(reference.columns) / columns))
    if features is not None:
        rows = min(rows, features)

    if colors is not None:
        assert len(mts) == len(colors)
        line_styles = [{'color': c} for c in colors]
    else:
        # No explicit colour: matplotlib's property cycle picks one per line.
        line_styles = [{} for _ in mts]

    # squeeze=False keeps axs 2-D, so .flatten() also works for a 1x1 grid.
    _, axs = plt.subplots(rows, columns, sharex=True, squeeze=False)

    for ax, col in zip(axs.flatten(), reference.columns):
        for frame, style in zip(mts, line_styles):
            ax.plot(np.arange(frame.shape[0]), frame[col], linewidth=0.1, **style)
        ax.axis('off')

    if save:
        plt.savefig(save, bbox_inches='tight', pad_inches=0.1, dpi=300)
    else:
        plt.show()

    plt.close()


def plot_sparse_3d_matrix(data : np.ndarray, save : str = ''):
    """Scatter-plot the positions of all non-zero entries of a 3-D array.

    Each entry that does not compare equal to 0 becomes one point at its
    ``(x, y, z)`` index. NaN counts as non-zero.

    :param data: Three-dimensional array.
    :type data: np.ndarray
    :param save: If non-empty, path the figure is saved to (300 dpi, tight
        bounding box); otherwise the figure is shown.
    :type save: str
    :raises ValueError: if ``data`` is not three-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(f'expected a 3-D array, got shape {data.shape}')

    # np.nonzero returns the indices in row-major order, the same order as a
    # nested loop over x, y, z, but without Python-level iteration.
    x_data, y_data, z_data = np.nonzero(data)

    fig = plt.figure()
    # Figure.gca() no longer accepts a projection keyword (removed in
    # Matplotlib 3.6); add_subplot is the supported way to create 3-D axes.
    ax = fig.add_subplot(projection='3d')
    ax.scatter(x_data, y_data, z_data)

    if save:
        plt.savefig(save, bbox_inches='tight', pad_inches=0.1, dpi=300)
    else:
        plt.show()
    plt.close()


def save_figure(file: str, dpi: int = 300, **kwargs):
    """Simplified wrapper for matplotlib.pyplot.savefig.

    Saves the current figure with a tight bounding box and 0.1 inch padding,
    then closes it.

    :param file: Name of the file.
    :type file: str
    :param dpi: DPI of resulting figure.
    :type dpi: int
    :param kwargs: Further keyword arguments forwarded to ``savefig``.
        Passing ``bbox_inches`` or ``pad_inches`` overrides the defaults
        above instead of raising a duplicate-keyword TypeError.
    """
    options = {'bbox_inches': 'tight', 'pad_inches': 0.1}
    options.update(kwargs)

    plt.savefig(file, dpi=dpi, **options)
    plt.close()


def plot_ts_scores(scores : List[float], values : Optional[List[float]] = None, targets : Optional[List[int]] = None,
                   axes: mplt.axes.Axes = None):
    """Plots the predicted scores against the targets.

    The scores are min-max scaled to [0, 1] and drawn in black. If
    ``values`` are given, they are min-max scaled as well and drawn above the
    scores in the band [1.5, 2.5]. Every run of consecutive non-zero
    ``targets`` is shaded red, covering exactly the samples of that run. The
    axis frame is hidden.

    :param scores: List of scores
    :type scores: List[float]
    :param values: List of values of a time series of the same length.
    :type values: Optional[List[float]]
    :param targets: List of target values (non-zero marks an anomaly).
    :type targets: Optional[List[int]]
    :param axes: Axes object to plot image to (defaults to the current axes).
    :type axes: matplotlib.axes.Axes
    :raises AssertionError: if ``values`` or ``targets`` differ in length
        from ``scores``.
    """

    def min_max_scale(seq):
        arr = np.asarray(seq, dtype=float)
        low = arr.min()
        span = arr.max() - low
        # A constant series would divide by zero and give NaNs; draw it as a
        # flat line at 0 instead.
        return (arr - low) / span if span else np.zeros_like(arr)

    ax = plt.gca() if axes is None else axes
    n = len(scores)

    if targets is not None:

        assert len(scores) == len(targets)

        # Pad the binary mask with a zero on both sides. In its first
        # difference, a +1 marks the first index of a run and a -1 marks the
        # index one past the run's last element.
        mask = np.concatenate(([0], np.asarray(targets, dtype=bool).astype(int), [0]))
        steps = np.diff(mask)
        starts = np.flatnonzero(steps == 1)
        ends = np.flatnonzero(steps == -1)

        for start, end in zip(starts, ends):
            # `end` is exclusive, so the last shaded sample is end - 1.
            ax.axvspan(start - 0.5, end - 0.5, facecolor='r', alpha=0.75)

    if values is not None:

        assert len(scores) == len(values)

        ax.plot(np.arange(len(values)), min_max_scale(values) + 1.5, c='k')

    ax.plot(np.arange(n), min_max_scale(scores), c='k')

    # Use len(scores): targets is optional and may be None.
    ax.set_xlim(-0.5, n - 0.5)
    ax.axis('off')


def plot_error_bars(means : list, deviations : list, ax : Optional[plt.Axes] = None, offset : int = 0, **kwargs):
    """Draw ``means`` with error bars of size ``deviations``.

    Point ``i`` is placed at x position ``i + offset``. The top, right and
    bottom spines and the x ticks are removed, leaving only the vertical scale.

    :param means: Central values, one per point.
    :param deviations: Error bar sizes, passed as the third positional
        argument (``yerr``) of ``Axes.errorbar``.
    :param ax: Axes to draw on; a falsy value selects the current axes.
    :param offset: Shift applied to all x positions, e.g. to place several
        series side by side.
    :param kwargs: Forwarded to ``Axes.errorbar``.
    """
    ax = ax or plt.gca()

    positions = [index + offset for index in range(len(means))]
    ax.errorbar(positions, means, deviations, **kwargs)

    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().set_ticks([])


def plot_intervals(intervals : List[int], ax: Optional[mplt.axes.Axes] = None, **kwargs):
    """Shade every run of consecutive non-zero entries of ``intervals``.

    Entry ``i`` is the sample at x position ``i``. A run covering samples
    ``i..j`` is shaded from ``i - 0.5`` to ``j + 0.5`` with ``Axes.axvspan``.

    :param intervals: Per-sample labels; any non-zero value marks a sample
        to shade. Lists and numpy arrays are both accepted.
    :param ax: Axes to draw on; defaults to the current axes.
    :param kwargs: Forwarded to ``Axes.axvspan`` (e.g. ``facecolor``,
        ``alpha``).
    """

    if ax is None:
        ax = plt.gca()

    # Pad the binary mask with a zero on both sides. In its first difference,
    # a +1 marks the first index of a run and a -1 marks the index one past
    # the run's last element.
    mask = np.concatenate(([0], np.asarray(intervals, dtype=bool).astype(int), [0]))
    steps = np.diff(mask)
    starts = np.flatnonzero(steps == 1)
    ends = np.flatnonzero(steps == -1)

    for start, end in zip(starts, ends):
        # `end` is exclusive, so the last shaded sample is end - 1.
        ax.axvspan(start - 0.5, end - 0.5, **kwargs)


def plot_sequence(sequence : List[float], ax : Optional[mplt.axes.Axes] = None, scatter : bool = True):
    """Draw ``sequence`` against its indices ``0 .. len(sequence) - 1``.

    Styling comes from matplotlib rcParams, e.g. ``axes.prop_cycle`` for the
    colour and, for the line variant, ``lines.linewidth`` and
    ``lines.linestyle``.

    :param sequence: Values to draw.
    :param ax: Axes to draw on; a falsy value selects the current axes.
    :param scatter: Draw individual markers if True, a connected line otherwise.
    """
    ax = ax or plt.gca()

    positions = list(range(len(sequence)))
    draw = ax.scatter if scatter else ax.plot
    draw(positions, sequence)


def plot_histogram(data : Optional[List[int]] = None, resolution : int = 100,
                   yticks : Optional[Union[int, List]] = None, ax : Optional[mplt.axes.Axes] = None,
                   hist_range : Tuple[int, int] = (0, 1), xticks : Optional[List[int]] = None , **kwargs):
    """Draw a filled step histogram of ``data`` with a minimal frame.

    :param data: Values to bin.
    :param resolution: Number of bins.
    :param yticks: An int limits the y axis to about that many tick
        intervals; a list sets the y ticks explicitly; None hides the y axis
        ticks and the left spine.
    :param ax: Axes to draw on; defaults to the current axes.
    :param hist_range: Lower and upper range of the bins.
    :param xticks: Explicit x ticks; if empty or None, x ticks are removed.
    :param kwargs: Forwarded to ``Axes.hist``.
    """

    if ax is None:
        ax = plt.gca()

    ax.hist(data, bins=resolution, range=hist_range, histtype='stepfilled', **kwargs)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if xticks:
        ax.get_xaxis().set_ticks(xticks)
    else:
        ax.get_xaxis().set_ticks([])

    if isinstance(yticks, int):
        # Axes-level call: plt.locator_params would act on the *current*
        # axes, which need not be `ax`.
        ax.locator_params(axis='y', nbins=yticks)
    elif isinstance(yticks, list):
        # Explicit y ticks belong on the y axis (the x axis was set here before).
        ax.get_yaxis().set_ticks(yticks)
    else:
        ax.spines['left'].set_visible(False)
        ax.get_yaxis().set_ticks([])


def plot_feature_against_anomaly(feature : Optional[List[float]], targets : Optional[List[float]],
                                 ax : Optional[mplt.axes.Axes] = None, xticks : Optional[Union[int, List]] = None,
                                 scatter : bool = True, yticks : bool = True):
    """Plot a feature sequence on top of red shading for its anomalous samples.

    :param feature: Feature values, one per sample.
    :param targets: Per-sample anomaly labels of the same length; shaded via
        ``plot_intervals``.
    :param ax: Axes to draw on; defaults to the current axes.
    :param xticks: An int limits the x axis to about that many tick
        intervals; a list sets the x ticks explicitly; None hides the x ticks
        and the bottom spine.
    :param scatter: Draw the feature as markers if True, as a line otherwise.
    :param yticks: If True, label only the minimum and maximum of ``feature``
        (rounded to 2 decimals); otherwise hide the y ticks and the left spine.
    :raises AssertionError: if ``feature`` and ``targets`` differ in length.
    """

    if ax is None:
        ax = plt.gca()

    assert len(feature) == len(targets)

    plot_intervals(targets, ax, facecolor='r', alpha=0.5)
    # Draw on the same axes as the shading, not on whatever plt.gca() returns.
    plot_sequence(feature, ax=ax, scatter=scatter)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if yticks:

        bounds = [round(min(feature), 2), round(max(feature), 2)]
        # float() first: int.is_integer only exists from Python 3.12 on, so
        # integer-valued features would otherwise raise AttributeError.
        tick_values = [int(tick) if float(tick).is_integer() else tick for tick in bounds]

        ax.set_yticks(tick_values)
        ax.set_yticklabels(tick_values)

    else:

        ax.spines['left'].set_visible(False)
        ax.get_yaxis().set_ticks([])

    if isinstance(xticks, int):
        # Axes-level call so the setting applies to `ax`, not the current axes.
        ax.locator_params(axis='x', nbins=xticks)
    elif isinstance(xticks, list):
        ax.get_xaxis().set_ticks(xticks)
    else:

        ax.spines['bottom'].set_visible(False)
        ax.get_xaxis().set_ticks([])


def save_plot(path : str):
    """Save the current figure to ``path`` with a tight bounding box, then close it.

    Unlike ``save_figure``, no padding or DPI is passed, so matplotlib's
    rcParams defaults (``savefig.pad_inches``, ``savefig.dpi``) apply.
    """
    savefig_options = dict(bbox_inches='tight')

    plt.savefig(path, **savefig_options)
    plt.close()

