import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib import collections as matcoll
import os

# This file sets up a plotting style using Matplotlib and seaborn, and then
# plots the relative sub-optimality of several methods (FedSN-Lite, FedAC,
# MB SGD and Local SGD, with and without momentum) versus the number of
# communication rounds.

###################################################################################################
# Tweaking seaborn to make our curves more beautiful :)
# Seaborn lets us change Matplotlib rc parameters directly through it
# Inspired by: https://towardsdatascience.com/making-matplotlib-beautiful-by-default-d0d41e3534fd

sns.set(
    rc={
        # Axes: transparent face, no grid, only the left/bottom spines drawn.
        'axes.facecolor': 'None',
        'axes.edgecolor': 'black',
        'axes.labelcolor': 'black',
        'axes.grid': False,
        'axes.axisbelow': False,
        'axes.spines.top': False,
        'axes.spines.right': False,
        # Figure background and default text colour.
        'figure.facecolor': 'white',
        'text.color': 'black',
        # Lines and patches (bars, histogram bins, ...).
        'lines.solid_capstyle': 'round',
        'patch.edgecolor': 'w',
        'patch.force_edgecolor': True,
        # Tick marks are disabled on every side; tick colour stays black.
        'xtick.bottom': False,
        'xtick.top': False,
        'xtick.color': 'black',
        'xtick.direction': 'out',
        'ytick.left': False,
        'ytick.right': False,
        'ytick.color': 'black',
        'ytick.direction': 'out',
    },
    font='Franklin Gothic Book',
)

# setting some global font sizes
# Use seaborn's "notebook" scaling, but pin the base, title and axis-label
# font sizes to the same value (15).
sns.set_context(
    context="notebook",
    rc={key: 15 for key in ("font.size", "axes.titlesize", "axes.labelsize")},
)

# Defining colour names
# Blues, purples and violets.
Blue = '#0000FF'
Light_Blue = '#add8e6'
Purple = '#800080'
Light_Purple = '#CBC3E3'
Violet = '#8F00FF'
Light_Violet = '#CC99FF'

# Greens.
Green = '#008000'
Light_Green = '#32CD32'

# Reds and pinks.
Red = '#FF0000'
Light_Red = '#ffcccb'
Pink = '#FFC0CB'

# Oranges, ambers and yellows.
Orange = '#FFA500'
Light_Orange = '#FED8B1'
CB91_Amber = '#F5B14C'
Yellow = '#FFFF00'

# Neutrals.
Black = '#000000'
Gray = '#808080'

# Setting default colour for plotting and cycling through them
# Each plotted series takes the next colour in this list; consecutive pairs
# are a strong shade followed by a lighter/neutral companion.
color_list = [
    Red, Light_Red,
    Black, Gray,
    Green, Light_Green,
    Purple, Light_Purple,
]
# Alternative palettes, kept for quick switching:
# color_list = [Red, Black, Gray,
#               Green, Purple]
# color_list = sns.color_palette("colorblind")
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=color_list),
    'lines.markeredgewidth': 1,
})

#############################################################################

# Experiment configuration.
T = 100  # appears in the results/{T}/ and figures/{T}/ paths
R_values = [1, 5, 10, 25, 50, 100]  # values shown on the communication-round axis
M_values = [20, 50, 100, 200]  # one subplot per M value
# M_values = [100, 200]
# Presumably the optimal loss; used to normalise losses into relative
# sub-optimality (loss - optimal) / optimal.
optimal = 0.32262070790219677

# Plot positions for every R value except the first one (R=1 is not plotted).
x = np.arange(len(R_values[1:]))

# One subplot per M value, so the column count always matches M_values.
# squeeze=False + row 0 keeps `axs` a 1-D array even when there is a single M.
fig, axs = plt.subplots(nrows=1, ncols=len(M_values), figsize=(20, 3),
                        squeeze=False)
axs = axs[0]
# fig, axs = plt.subplots(nrows=1, ncols=len(M_values), figsize=(12, 4.5),
#                         squeeze=False)

# Series drawn in every subplot, in plotting (and therefore colour-cycle /
# legend) order: (key in the Losses dict, legend label, errorbar style).
# ':o' is a dotted line with circle markers, '-o' a solid line with markers.
_SERIES = [
    ("Newton_w_M", "FedSN-Lite w/ Mom. (Our Method)", dict(fmt='-o', linewidth=2.5)),
    ("Newton", "FedSN-Lite (Our Method)", dict(fmt=':o', linewidth=1)),
    ("FEDAC2", "FedAC-2 (Optimally Regularized)", dict(fmt='-o', linewidth=2)),
    ("FEDAC1", "FedAC-1 (Optimally Regularized)", dict(fmt='-o', linewidth=2)),
    ("MBSGD_w_M", "MB SGD w/ Mom.", dict(fmt='-o', linewidth=2)),
    ("MBSGD", "MB SGD", dict(fmt=':o', linewidth=1)),
    ("LSGD_w_M", "Local SGD w/ Mom.", dict(fmt='-o', linewidth=2)),
    ("LSGD", "Local SGD", dict(fmt=':o', linewidth=1)),
]

# The first R value is dropped both from the data (the [1:] slice below) and
# from the tick labels, so each of the len(x) points gets exactly its own label.
# NOTE(review): assumes axis 0 of each Losses entry follows R_values order and
# axis 1 holds repetitions (mean/std are taken over axis 1) -- TODO confirm.
_plotted_R = R_values[1:]

for ax, M in zip(axs, M_values):
    Losses = np.load(
        f"results/{T}/Losses_{M}.npy", allow_pickle=True).item()

    for key, label, style in _SERIES:
        # Relative sub-optimality (loss - optimal) / optimal, first R dropped.
        rel = ((np.array(Losses[key]) - optimal) / optimal)[1:]
        mean = np.mean(rel, axis=1)
        y = np.log10(mean)
        # Symmetric bar in log10 space using the upper half-width
        # log10((mean + std) / mean). NOTE(review): the lower end of
        # mean - std would be log10(mean) - log10(mean - std), which differs
        # (and is undefined when std >= mean), so this is an approximation.
        y_err = np.log10(1 + np.std(rel, axis=1) / mean)
        ax.errorbar(x, y, yerr=y_err, label=label, capsize=3, **style)

    # Pin one tick per plotted point; a bare set_ticklabels on an automatic
    # locator pairs labels with whatever tick positions it chose.
    ax.set_xticks(x)
    ax.set_xticklabels(_plotted_R)
    ax.set_title(f"M={M}")


# plt.legend(loc='upper center', bbox_to_anchor=(
#     -0.2, 1.55), fancybox=False, shadow=False, ncol=2, prop={'size': 15})

plt.subplots_adjust(wspace=0.20,
                    hspace=0.05, left=0.2)

# Shared axis labels, placed in figure coordinates (positions tuned by hand).
fig.text(0.55, -0.05, 'Number of Communication Rounds', ha='center')
fig.text(0.165, 0.5, 'Best Log(Rel. Sub-optimality)',
         va='center', rotation='vertical')

# savefig does not create missing directories; make sure figures/{T}/ exists.
os.makedirs(f"figures/{T}", exist_ok=True)
plt.savefig(f"figures/{T}/plot_w_rep.png", dpi=500,
            orientation="portrait", bbox_inches='tight')

# plt.savefig(f"figures/{T}/plot_w_rep_small.png", dpi=500,
#             orientation="portrait", bbox_inches='tight')
