import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

import pandas
from channel import StochasticChannel, approximate_drops_and_util


# Global font sizes; the rc keys used below are taken from
# https://stackoverflow.com/questions/3899980/how-to-change-the-font-size-on-a-matplotlib-plot

SMALL_SIZE = 13.5  # default text, axes titles, tick labels and legend entries
MEDIUM_SIZE = 15  # x and y axis labels
BIGGER_SIZE = 17  # figure title

plt.rc('font', size=SMALL_SIZE)
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)
plt.rc('xtick', labelsize=SMALL_SIZE)
plt.rc('ytick', labelsize=SMALL_SIZE)
plt.rc('legend', fontsize=SMALL_SIZE)
plt.rc('figure', titlesize=BIGGER_SIZE)


def create_adaptive_return_fig(df, row_filter, msg_size_str_and_labels, msg_types, xlabel="Message sizes",
                               filename="pomnist_adaptive_return.pdf", color_inc=0, spacing=0.05,
                               ylim=(0.6, 0.9)):
    """
    Grouped bar plot that compares the mean agent return of different message types.

    One group of bars is drawn per entry of ``msg_size_str_and_labels``; inside each group there is
    one bar per message type, with the stored standard deviation as error bar.

    :param df: metrics table with (at least) the columns ``Settings``, ``Metric_Name``,
        ``Metric_MeanValue`` and ``Metric_StandardDeviation``.
    :param row_filter: format string filled with ``(msg_size_str, msg_type)`` to build a ``Settings`` value.
    :param msg_size_str_and_labels: list of ``(msg_size_str, tick_label)`` pairs, one per bar group.
    :param msg_types: message types to compare; each contributes one bar per group.
    :param xlabel: label of the x-axis.
    :param filename: path the figure is saved to.
    :param color_inc: number of colors of the color cycle to skip before the first bar.
    :param spacing: amount subtracted from the width of every bar.
    :param ylim: ``(min, max)`` of the y-axis.
    :raises KeyError: if a requested setting does not occur in ``df``.
    :raises ValueError: if the number of ``agent_mean_return`` rows does not match the number of settings.
    """
    plt.clf()
    labels = [label for _, label in msg_size_str_and_labels]
    x_pos = np.arange(len(msg_size_str_and_labels))
    bar_width = 1.0 / len(msg_types) - spacing

    # Advance the color cycle by plotting nothing, so the real bars start at a later color.
    for _ in range(color_inc):
        plt.bar([], [])

    # Indexing does not depend on the message type, so do it once.
    by_setting = df.set_index('Settings')
    known_settings = set(by_setting.index)

    for i, msg_type in enumerate(msg_types):
        msg_type_settings = [row_filter.format(size_str, msg_type) for size_str, _ in msg_size_str_and_labels]
        missing = [setting for setting in msg_type_settings if setting not in known_settings]
        if missing:
            raise KeyError("Settings not found in metrics table: {}".format(missing))

        # .loc with a list keeps the order of the input list (unlike .isin), see
        # https://stackoverflow.com/questions/23414161/pandas-isin-with-output-keeping-order-of-input-list
        res = by_setting.loc[msg_type_settings].reset_index()
        agent_return = res[res.Metric_Name == "agent_mean_return"]
        if len(agent_return) != len(msg_type_settings):
            raise ValueError("Expected one 'agent_mean_return' row per setting for {!r}, got {} rows for {} settings"
                             .format(msg_type, len(agent_return), len(msg_type_settings)))

        # Centre the group of len(msg_types) bars on each x tick; bar i is the i-th from the left.
        x_offset = -len(msg_types) * bar_width / 2.0 + (i + 0.5) * bar_width

        plt.bar(x_pos + x_offset, agent_return.Metric_MeanValue.values,
                yerr=agent_return.Metric_StandardDeviation.values, align='center',
                ecolor='black', capsize=5, width=bar_width, bottom=0, label=msg_type)

    plt.xticks(x_pos, labels)
    plt.ylim(*ylim)
    plt.grid(axis='y')

    plt.xlabel(xlabel)
    plt.ylabel("Mean agent return")

    # A legend only helps when there are several message types to tell apart.
    if len(msg_types) > 1:
        plt.legend(loc='upper left')

    plt.tight_layout()
    plt.savefig(filename, bbox_inches='tight')


def create_adaptive_drops_util_fig(df, row_filter, msg_size_str_and_labels, msg_type,
                                   filename="pomnist_adaptive_drops_util.pdf", xlabel="Message sizes"):
    """
    Bar plot that compares the drop count and the channel utilisation of one message type.

    Every entry of ``msg_size_str_and_labels`` gets a pair of bars: the mean number of drops
    (``num_drops``, left y-axis) and the mean channel utilisation (``channel_util``, right y-axis,
    labelled "Mean throughput"). Black dots next to the last pair mark a baseline for random
    message sizes, computed with ``approximate_drops_and_util``.

    :param df: metrics table with (at least) the columns ``Settings``, ``Metric_Name``,
        ``Metric_MeanValue`` and ``Metric_StandardDeviation``.
    :param row_filter: format string filled with ``(msg_size_str, msg_type)`` to build a ``Settings`` value.
    :param msg_size_str_and_labels: non-empty list of ``(msg_size_str, tick_label)`` pairs, one per bar pair.
    :param msg_type: message type whose runs are plotted.
    :param filename: path the figure is saved to.
    :param xlabel: label of the x-axis.
    :raises ValueError: if ``msg_size_str_and_labels`` is empty or a metric does not have exactly
        one row per setting.
    :raises KeyError: if a requested setting does not occur in ``df``.
    """
    if not msg_size_str_and_labels:
        raise ValueError("msg_size_str_and_labels must not be empty")

    plt.clf()
    labels = [label for _, label in msg_size_str_and_labels]
    x_pos = np.arange(len(msg_size_str_and_labels))

    msg_type_settings = [row_filter.format(size_str, msg_type) for size_str, _ in msg_size_str_and_labels]
    by_setting = df.set_index('Settings')
    known_settings = set(by_setting.index)
    missing = [setting for setting in msg_type_settings if setting not in known_settings]
    if missing:
        raise KeyError("Settings not found in metrics table: {}".format(missing))

    # Select with .loc (not .isin) so rows follow the order of msg_type_settings and therefore
    # line up with the tick labels, regardless of the row order in the CSV.
    res = by_setting.loc[msg_type_settings].reset_index()

    def metric_rows(metric_name):
        """Return the rows of ``metric_name``, checking there is exactly one per setting."""
        rows = res[res.Metric_Name == metric_name]
        if len(rows) != len(msg_type_settings):
            raise ValueError("Expected one {!r} row per setting, got {} rows for {} settings"
                             .format(metric_name, len(rows), len(msg_type_settings)))
        return rows

    drops = metric_rows("num_drops")
    util = metric_rows("channel_util")

    drops_color = mcolors.CSS4_COLORS["darkmagenta"]
    util_color = mcolors.CSS4_COLORS["darkcyan"]

    # plot values: drops on the left axis, utilisation on a twin axis sharing the x-axis
    bar_width = 1.0 / 2.0 - 0.15
    ax1 = plt.gca()
    ax1.bar(x_pos - bar_width / 2, drops.Metric_MeanValue.values, yerr=drops.Metric_StandardDeviation.values,
            align='center', ecolor='black', capsize=5, width=bar_width, bottom=0, color=drops_color)
    ax2 = ax1.twinx()
    ax2.bar(x_pos + bar_width / 2, util.Metric_MeanValue.values, yerr=util.Metric_StandardDeviation.values,
            align='center', ecolor='black', capsize=5, width=bar_width, bottom=0, color=util_color)

    # Baseline for random message sizes, drawn next to the last bar pair.
    # NOTE(review): assumes the caller puts the "(0, 1, 2, 4)" setting last and that nb_slots=8
    # matches the runs being plotted (commSize=8 in the caller's row_filter) — confirm.
    channel = StochasticChannel(
        nb_slots=8,
        use_msg_size_spacing=True,
    )
    baseline_drops, baseline_util = approximate_drops_and_util(4, (0, 1, 2, 4), channel, num_tests=int(1e6))
    ax1.plot((x_pos - bar_width / 2)[-1], baseline_drops, 'o', color='black')
    ax2.plot((x_pos + bar_width / 2)[-1], baseline_util, 'o', color='black')

    # The x-axis is shared by ax1 and ax2, so the tick labels only need to be set once.
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(labels)
    # Horizontal grid lines follow the utilisation (right-hand) ticks.
    ax2.grid(axis='y')

    ax1.set_xlabel(xlabel)
    ax1.set_ylabel("Mean number of drops", color=drops_color)
    ax1.tick_params(axis='y', colors=drops_color)
    ax2.set_ylabel("Mean throughput", color=util_color)
    ax2.tick_params(axis='y', colors=util_color)

    plt.tight_layout()
    plt.savefig(filename, bbox_inches='tight')


def main():
    """Create the POMNIST adaptation figures from the CSV files in ./results."""
    row_filter = "1x1_msgSizes={}_2000iter_{}_StochasticSpacing_commSize=8"
    msg_size_str_and_labels = [
        ("(0,)", "{0}"),
        ("(1,)", "{1}"),
        ("(2,)", "{2}"),
        ("(4,)", "{4}"),
        ("(0, 1, 2, 4)", "{0, 1, 2, 4}"),
    ]

    msg_types = [
        "Continuous",
        "PseudoGradient",
        "Discrete",
        "DRU",
    ]

    metrics_file = "./results/paper_pomnist_c8_adaptation.csv"
    df = pandas.read_csv(metrics_file)

    create_adaptive_return_fig(df, row_filter, msg_size_str_and_labels, msg_types)
    create_adaptive_drops_util_fig(df, row_filter, msg_size_str_and_labels[1:], "PseudoGradient")
    # Close the figure so the ablation figure below is created fresh, with the squished figsize.
    plt.close()

    # The ablation runs live in separate CSVs. Rename their settings so that they fit row_filter
    # and can be looked up in one joint table together with the adaptive runs.
    df_random_size = pandas.read_csv("./results/paper_pomnist_c8_random_size.csv")
    df_random_size["Settings"] = df_random_size["Settings"].str.replace("(0, 1, 2, 4)_", "random_", regex=False)
    df_zero_content = pandas.read_csv("./results/paper_pomnist_c8_zero_content.csv")
    df_zero_content["Settings"] = (df_zero_content["Settings"]
                                   .str.replace("Zeros", "PseudoGradient", regex=False)
                                   .str.replace("(0, 1, 2, 4)_", "zeros_", regex=False))

    # The row index carries no meaning (lookups go through the Settings column), so renumber it.
    all_df = pandas.concat((df, df_random_size, df_zero_content), ignore_index=True)

    ablation_str_and_labels = [
        ("(0,)", "none"),
        ("random", "random"),
        ("zeros", "zeros"),
        ("(0, 1, 2, 4)", "adaptive"),
    ]
    # Squish the figure; rc_context restores the previous rcParams afterwards (also on errors),
    # including a figsize the user may have configured.
    with plt.rc_context({"figure.figsize": (4.1, 4.8)}):
        create_adaptive_return_fig(all_df, row_filter, ablation_str_and_labels, ["PseudoGradient"],
                                   "Communication method", filename="pomnist_adaptive_ablations.pdf",
                                   color_inc=1, spacing=0.3)


if __name__ == '__main__':
    main()

