import sys
from typing import List, Optional

import matplotlib
import numpy as np

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import matplotlib.ticker as mticker  # noqa: E402


def plot_multi_dataset_metrics(
    x_label: str,
    y_label: str,
    fname: str,
    xs: np.ndarray,
    metric_means: np.ndarray,
    metric_stds: Optional[np.ndarray],
    datasets: List[str],
    custom_colors: Optional[List[str]] = None,
):
    """Plot one metric curve per dataset on a shared axis and save it to disk.

    Each dataset gets a line of its mean values and, when ``metric_stds`` is
    given, a translucent band spanning mean +/- one standard deviation.

    Note: this function sets several global matplotlib rc parameters (font
    sizes for axes, ticks, legend and figure titles); they remain in effect
    after it returns.

    Args:
        x_label (str): X-axis label.
        y_label (str): Y-axis label.
        fname (str): Path of the image file to write (passed to ``savefig``).
        xs (np.ndarray): Shared x-axis values for every curve. If it is a
            NumPy array of any integer dtype, the x-axis gets integer-only
            ticks (one per unit when the value range is small).
        metric_means (np.ndarray): Per-dataset mean curves;
            ``metric_means[i]`` is plotted against ``xs`` for ``datasets[i]``.
        metric_stds (np.ndarray | None): Per-dataset standard deviations,
            indexed like ``metric_means``, or ``None`` to skip the bands.
        datasets (List[str]): Legend label for each curve; must have the same
            length as ``metric_means``.
        custom_colors (List[str] | None): Curve colors indexed like
            ``datasets`` (needs at least ``len(datasets)`` entries), or
            ``None`` to use matplotlib's default cycle (``"C0"``, ``"C1"``...).

    Raises:
        SystemExit: If ``datasets`` and ``metric_means`` differ in length
            (raised via ``sys.exit``, kept for compatibility with callers).
    """
    if len(datasets) != len(metric_means):
        sys.exit("Length of datasets and metrics arrays must be the same.")

    # Widest integer range (max - min) that still gets a tick at every unit;
    # beyond this a unit locator yields an unreadable, possibly huge tick set.
    max_unit_ticks = 20

    # Plot parameters (global rc). Applied before the axes are created so the
    # axis-label and tick font sizes take effect on this plot.
    plt.rc("axes", titlesize=18, labelsize=18)
    plt.rc("xtick", labelsize=15)
    plt.rc("ytick", labelsize=15)
    plt.rc("legend", fontsize=18)
    plt.rc("figure", titlesize=18)

    fig, ax = plt.subplots(figsize=(9, 7))
    try:
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        for i, dataset in enumerate(datasets):
            color = custom_colors[i] if custom_colors is not None else f"C{i}"
            # asarray lets list rows support the +/- arithmetic below.
            mean = np.asarray(metric_means[i])
            ax.plot(xs, mean, label=dataset, color=color)
            if metric_stds is not None:
                # One std area around each curve.
                std = np.asarray(metric_stds[i])
                ax.fill_between(
                    xs,
                    mean - std,
                    mean + std,
                    facecolor=color,
                    alpha=0.2,
                )

        if len(datasets) > 1:
            ax.legend(loc="lower left")

        # Integer x-values (any integer dtype, not just int32) get integer
        # ticks: one per unit for short ranges, auto-spaced integers otherwise.
        if (
            isinstance(xs, np.ndarray)
            and xs.size > 0
            and np.issubdtype(xs.dtype, np.integer)
        ):
            span = int(xs.max()) - int(xs.min())
            if span <= max_unit_ticks:
                locator = mticker.MultipleLocator(1)
            else:
                locator = mticker.MaxNLocator(integer=True)
            ax.xaxis.set_major_locator(locator)

        fig.tight_layout()
        fig.savefig(fname)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls leak memory and trigger matplotlib's
        # "too many open figures" warning.
        plt.close(fig)
