import os

import matplotlib.pyplot as plt


# Metric keys every results mapping must provide, in subplot order.
METRICS = [
    "loss",
    "acc",
    "hetero",
]
# One 6-inch-wide panel per metric, 5 inches tall.
FIGURE_SIZE = (6 * len(METRICS), 5)
# Human-readable labels for plot titles / legend entries; names that are not
# listed here are displayed verbatim.
DISPLAY_NAMES = dict(
    loss="Train Loss",
    acc="Train Accuracy",
    hetero="Heterogeneity",
)


def plot(results, results_dir):
    """Plot every metric of a single run side by side and save the figure.

    Draws one subplot per entry in ``METRICS``. All subplots share the same
    x-axis range: the smallest and largest round seen across all metrics.
    The figure is written to ``<results_dir>/results.png``.

    Args:
        results: Mapping from metric name to a ``{round: value}`` mapping.
            It must contain every key in ``METRICS``.
        results_dir: Directory to write ``results.png`` into. It is created
            if it does not already exist.

    Raises:
        ValueError: If ``results`` is missing any of the ``METRICS`` keys.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    missing = [metric for metric in METRICS if metric not in results]
    if missing:
        raise ValueError(f"results is missing metrics: {missing}")

    # Shared x-axis limits across all metrics. Empty series contribute
    # nothing, and if there is no data at all the limits are left to
    # matplotlib (instead of passing +/-inf to set_xlim).
    all_rounds = [r for metric in METRICS for r in results[metric]]
    x_limits = [min(all_rounds), max(all_rounds)] if all_rounds else None

    os.makedirs(results_dir, exist_ok=True)

    # squeeze=False keeps the axes indexable even if METRICS has one entry.
    fig, axs = plt.subplots(1, len(METRICS), figsize=FIGURE_SIZE, squeeze=False)
    try:
        for ax, metric in zip(axs[0], METRICS):
            series = results[metric]
            xs = sorted(series)
            ax.plot(xs, [series[r] for r in xs])
            if x_limits is not None:
                ax.set_xlim(x_limits)
            ax.set_title(DISPLAY_NAMES.get(metric, metric))
        fig.savefig(os.path.join(results_dir, "results.png"), bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def compare_plot(all_results, results_dir):
    """Overlay several runs for each metric and save one figure per metric.

    For every metric in ``METRICS``, a separate figure is drawn with:

    - one line per run,
    - a log-scaled y-axis,
    - an x-axis range shared across all runs and metrics.

    Each figure is written to ``<results_dir>/<metric>_results.png``.

    Args:
        all_results: Mapping from run name to a ``{metric: {round: value}}``
            mapping. Each inner mapping must contain every key in
            ``METRICS``. Run names found in ``DISPLAY_NAMES`` are shown
            with their display label in the legend.
        results_dir: Output directory. It is created if it does not exist.

    Raises:
        ValueError: If any run's results are missing one of the ``METRICS``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    for name, results in all_results.items():
        missing = [metric for metric in METRICS if metric not in results]
        if missing:
            raise ValueError(f"results for {name!r} are missing metrics: {missing}")

    # Shared x-axis limits over every run and metric. Empty series are
    # skipped; with no data at all, the limits are left to matplotlib.
    all_rounds = []
    for results in all_results.values():
        for metric in METRICS:
            all_rounds.extend(results[metric])
    x_limits = [min(all_rounds), max(all_rounds)] if all_rounds else None

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(results_dir, exist_ok=True)

    for metric in METRICS:
        # A fresh figure per metric, instead of plt.clf() on whatever figure is
        # current. This means a previously created figure (e.g. a wide
        # multi-panel one) neither leaks its size into these plots nor gets
        # wiped, and closing it afterwards releases its memory.
        fig, ax = plt.subplots()
        try:
            for name, results in all_results.items():
                series = results[metric]
                xs = sorted(series)
                ax.plot(
                    xs,
                    [series[r] for r in xs],
                    label=DISPLAY_NAMES.get(name, name),
                )
            if x_limits is not None:
                ax.set_xlim(x_limits)
            ax.set_xlabel("Rounds")
            ax.set_yscale("log")
            ax.set_title(DISPLAY_NAMES.get(metric, metric))
            ax.legend()
            fig.savefig(
                os.path.join(results_dir, f"{metric}_results.png"),
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)
