import math
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from tigramite import plotting as tp

# Font sizes (in points) shared by every figure built in this module:
# figure-level suptitles, per-subplot titles, and legend/label text.
FONT_SIZE_SUPTITLE: int = 20
FONT_SIZE_AXTITLE: int = 16
FONT_SIZE_LEGEND: int = 12


def get_plot_ts(
    ts_synthetic: np.ndarray, graph: np.ndarray, val_matrix: np.ndarray, vmin_edges: float, vmax_edges: float,
    title_plot: str, feature_names: List[str],
) -> Figure:
    """Draw synthetic time series above their tigramite time-series graph.

    The figure has two stacked panels (10 x 7 inches):

    * top: ``ts_synthetic`` plotted with line width 2, with ``feature_names``
      passed as the line labels and shown in a legend
      (presumably ``ts_synthetic`` is ``[seq_len, n_features]`` — TODO confirm);
    * bottom: ``tigramite.plotting.plot_time_series_graph`` rendered from
      ``graph`` and ``val_matrix`` with the ``'Reds'`` edge colormap.
      ``vmin_edges`` and ``vmax_edges`` are forwarded unchanged, and the panel's
      axis decorations are turned off afterwards.

    Returns:
        The new figure, with ``title_plot`` as its suptitle.
    """
    fig, (series_ax, graph_ax) = plt.subplots(nrows=2, ncols=1, figsize=(10, 7))

    # Top panel: the raw series with one legend entry per feature name.
    series_ax.plot(ts_synthetic, linewidth=2, label=feature_names)
    series_ax.legend(fontsize=FONT_SIZE_LEGEND)

    # Bottom panel: tigramite draws into the Axes supplied through fig_ax.
    graph_kwargs = dict(
        graph=graph,
        val_matrix=val_matrix,
        var_names=feature_names,
        fig_ax=(fig, graph_ax),
        vmin_edges=vmin_edges,
        vmax_edges=vmax_edges,
        cmap_edges='Reds',
        label_fontsize=FONT_SIZE_LEGEND,
        link_colorbar_label='Coefficient',
    )
    tp.plot_time_series_graph(**graph_kwargs)
    graph_ax.axis('off')

    fig.suptitle(f'{title_plot}', fontsize=FONT_SIZE_SUPTITLE)
    return fig


def _shared_bin_edges(first: np.ndarray, second: np.ndarray, bins) -> np.ndarray:
    """Return one set of histogram bin edges spanning the finite values of both samples.

    ``bins`` is anything ``np.histogram_bin_edges`` accepts (an int, a
    binning-strategy name, or explicit edges).
    """
    combined = np.concatenate((np.ravel(first), np.ravel(second)))
    # NaN/inf would make the autodetected range non-finite, so they are left out of
    # the edge computation only. The data passed to hist() is unchanged.
    combined = combined[np.isfinite(combined)]
    return np.histogram_bin_edges(combined, bins=bins)


def get_plot_hist(
    ts_real: np.ndarray,
    ts_predicted: np.ndarray,
    title_plot: str,
    feature_names: List[str],
) -> Figure:
    """Plot per-feature density histograms of real vs. predicted values on a grid.

    Each feature gets its own subplot with two step histograms: real in red,
    predicted in blue. Both use the same bin edges, so the outlines are directly
    comparable. The grid has ``ceil(sqrt(n_features))`` columns and only as many
    rows as needed. Grid cells beyond ``n_features`` are hidden.

    Args:
        ts_real: Real values, shape ``[seq_len, n_features]``. A 1-D array is
            treated as a single feature.
        ts_predicted: Predicted values with the same number of features as
            ``ts_real``. Its ``seq_len`` may differ, because each histogram is
            normalised to a density.
        title_plot: Figure suptitle.
        feature_names: One name per feature, used as the subplot titles.

    Returns:
        The new figure. It has a single figure-level legend ("Real"/"Predicted").

    Raises:
        ValueError: If ``feature_names`` is empty, or either array does not have
            exactly ``len(feature_names)`` feature columns.
    """
    n_features = len(feature_names)
    if n_features == 0:
        raise ValueError('feature_names must contain at least one name')

    ts_real = np.asarray(ts_real)
    ts_predicted = np.asarray(ts_predicted)
    # A 1-D series is a single feature column.
    if ts_real.ndim == 1:
        ts_real = ts_real[:, np.newaxis]
    if ts_predicted.ndim == 1:
        ts_predicted = ts_predicted[:, np.newaxis]
    # zip() would silently drop features on a mismatch, so reject it explicitly.
    for arr_name, arr in (('ts_real', ts_real), ('ts_predicted', ts_predicted)):
        if arr.ndim != 2 or arr.shape[1] != n_features:
            raise ValueError(
                f'{arr_name} must have shape [seq_len, {n_features}] to match feature_names, got {arr.shape}'
            )

    n_cols = math.ceil(math.sqrt(n_features))
    n_rows = math.ceil(n_features / n_cols)
    # squeeze=False always yields a 2-D array of Axes (even for a 1x1 grid),
    # so .ravel() also works when there is only one feature.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    axes = axes.ravel()

    # Same default bin count hist() itself would use when bins is not given.
    default_bins = plt.rcParams['hist.bins']
    for idx, (ax, feature_name) in enumerate(zip(axes, feature_names)):
        real_col = ts_real[:, idx]
        predicted_col = ts_predicted[:, idx]
        edges = _shared_bin_edges(real_col, predicted_col, default_bins)
        # Label only the first subplot, so fig.legend shows one entry per series.
        is_first = idx == 0
        ax.hist(
            real_col, bins=edges, color='r', label='Real' if is_first else None,
            density=True, histtype="step", linewidth=3
        )
        ax.hist(
            predicted_col, bins=edges, color='b', label='Predicted' if is_first else None,
            density=True, histtype="step", linewidth=3
        )
        ax.set_title(feature_name, fontsize=FONT_SIZE_AXTITLE)

    # Leftover grid cells hold no data; hide their empty frames.
    for ax in axes[n_features:]:
        ax.set_axis_off()

    fig.legend(fontsize=FONT_SIZE_LEGEND, loc='upper left')
    fig.suptitle(f'{title_plot}', fontsize=FONT_SIZE_SUPTITLE)

    return fig
