from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import collections as matcoll

# This script sets up a Matplotlib/seaborn plotting style and plots the best
# sub-optimality reached by FedSN-Lite and several GIANT variants against the
# number of communication rounds.

###################################################################################################
# Tweaking seaborn to make our curves more beautiful :)
# Seaborn lets us override Matplotlib rc parameters directly (via its `rc` argument)
# Inspired by: https://towardsdatascience.com/making-matplotlib-beautiful-by-default-d0d41e3534fd

# rc overrides applied on top of seaborn's defaults: transparent axes with
# light-grey edges, no grid, no top/right spines, no tick marks on any side,
# and dim-grey text, axis labels and tick labels.
_base_rc = {
    # Axes
    'axes.axisbelow': False,
    'axes.edgecolor': 'lightgrey',
    'axes.facecolor': 'None',
    'axes.grid': False,
    'axes.labelcolor': 'dimgrey',
    'axes.spines.right': False,
    'axes.spines.top': False,
    # Figure, lines and patches
    'figure.facecolor': 'white',
    'lines.solid_capstyle': 'round',
    'patch.edgecolor': 'w',
    'patch.force_edgecolor': True,
    # Text and ticks
    'text.color': 'dimgrey',
    'xtick.bottom': False,
    'xtick.color': 'dimgrey',
    'xtick.direction': 'out',
    'xtick.top': False,
    'ytick.color': 'dimgrey',
    'ytick.direction': 'out',
    'ytick.left': False,
    'ytick.right': False,
}
sns.set(rc=_base_rc, font='Franklin Gothic Book')

# Global font sizes layered on seaborn's "notebook" context:
# 14 pt base text, 16 pt axes titles and axis labels.
_font_sizes = {
    "font.size": 14,
    "axes.titlesize": 16,
    "axes.labelsize": 16,
}
sns.set_context("notebook", rc=_font_sizes)

# Named colours of the custom palette (hex RGB strings).
CB91_Blue = '#2CBDFE'
CB91_Green = '#47DBCD'
CB91_Pink = '#F3A0F2'
CB91_Purple = '#9D2EC5'
CB91_Violet = '#661D98'
CB91_Amber = '#F5B14C'
CB91_Black = '#000000'

# Order in which artists without an explicit colour cycle through the palette.
color_list = [
    CB91_Blue,
    CB91_Pink,
    CB91_Green,
    CB91_Purple,
    CB91_Black,
    CB91_Amber,
    CB91_Violet,
]
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=color_list),
    'lines.markeredgewidth': 1,
})
#############################################################################


# Reference optimum of the loss; FedSN-Lite losses are converted to relative
# sub-optimality (loss - optimal) / optimal.
optimal = 0.3226712387963549

R_values = [1, 5, 10, 50, 100]  # NOTE(review): unused below; `values` drives the loop
T = 1000
y_1 = []  # FedSN-Lite: one np.min(..., axis=1) array per (R, lr) in `values`
y_2 = []  # GIANT, files directly under runtime/{T}/ (plotted as "s=n/M")
y_3 = []  # GIANT, runtime/{T}/s_50/
y_4 = []  # GIANT, runtime/{T}/s_40/
y_5 = []  # GIANT, runtime/{T}/s_30/
y_6 = []  # GIANT, runtime/{T}/s_20/
y_7 = []  # GIANT, runtime/{T}/s_10/
# (R, lr) pairs; R selects the result files and lr is part of the FedSN-Lite
# file name.
values = [(1, 0.05), (5, 0.05), (10, 0.05), (25, 0.1),
          (50, 0.1), (100, 0.2), (500, 1), (1000, 2)]


def _load_best_giant_errors(path):
    """Return ``np.min(err, axis=1)`` for the ``err`` array stored in the .npz at *path*.

    GIANT results are not available for every R, so a missing or unreadable
    file, a file without an ``err`` entry, or an ``err`` array that cannot be
    reduced along axis 1 yields ``None`` instead of raising. Unrelated errors
    (e.g. NameError, KeyboardInterrupt) are no longer swallowed.
    """
    try:
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once `err` has been read.
        with np.load(path) as archive:
            return np.min(archive['err'], axis=1)
    except (OSError, KeyError, ValueError):
        return None


# Sub-directory of runtime/{T}/ holding each GIANT variant, and the list that
# collects its per-R results.
_GIANT_RESULTS = [
    ("", y_2),
    ("s_50/", y_3),
    ("s_40/", y_4),
    ("s_30/", y_5),
    ("s_20/", y_6),
    ("s_10/", y_7),
]

for R, lr in values:
    # NOTE: allow_pickle=True lets np.load unpickle object arrays; only use it
    # on locally produced result files, never on untrusted input.
    fedsn = ((np.load(f"runtime/{T}/losses_newton_lsgd_{R}_{lr}.npy",
                      allow_pickle=True) - optimal)/optimal)
    y_1.append(np.min(fedsn, axis=1))

    # NOTE(review): a skipped file leaves no placeholder, so a GIANT list only
    # lines up with `values` if its available files form a prefix of `values`.
    for subdir, results in _GIANT_RESULTS:
        best = _load_best_giant_errors(
            f"runtime/{T}/{subdir}demo_logis_{R}_a9a.npz")
        if best is not None:
            results.append(best)


# y = np.log10(np.min(fedsn, axis=0))[1:]
# y_up = np.log10(np.max(fedsn, axis=0))[1:]
# plt.plot(
#     np.log10(x_1), y, label="FedSN, T = 10000, K=1", c=color_list[0], zorder=1)
# plt.fill_between(np.log10(x_1), y, y_up,
#                  alpha=0.1, facecolor=color_list[0], zorder=2)


# y = np.log10(np.min(giant, axis=0))[1:21]
# y_up = np.log10(np.max(giant, axis=0))[1:21]
# plt.plot(
#     np.log10(x_2), y, label=r"GIANT, T = 20, $s = \lfloor n/M \rfloor$, q = 20", c=color_list[1], zorder=1)
# plt.fill_between(np.log10(x_2), y, y_up,
#                  alpha=0.1, facecolor=color_list[1], zorder=2)

def _plot_mean_band(x, best_per_run, label, **line_kwargs):
    """Plot one method's log10 mean best sub-optimality with a shaded spread band.

    Parameters
    ----------
    x : list
        Candidate x positions; only the first ``len(curve)`` are used, so a
        method with results for fewer configurations is drawn on a prefix.
    best_per_run : sequence of 1-D arrays
        One row per loaded configuration, averaged along axis 1. Row 0 (the
        first configuration, R=1 in ``values``) is dropped before plotting.
    label : str
        Legend label.
    **line_kwargs
        Extra keyword arguments forwarded to ``plt.plot`` (e.g. ``color``).

    Returns
    -------
    The plotted ``Line2D``, or ``None`` when ``best_per_run`` is empty.
    """
    samples = np.asarray(best_per_run)
    if samples.size == 0:
        # No result files were loaded for this method; averaging along
        # axis=1 would fail, so there is nothing to draw.
        return None
    mean = np.mean(samples, axis=1)
    curve = np.log10(mean)[1:]
    # Band half-width in log space: log10(mean + std) - log10(mean).
    spread = np.log10(1 + np.std(samples, axis=1) / mean)[1:]
    xs = x[:len(curve)]
    line, = plt.plot(xs, curve, label=label, **line_kwargs)
    # Reuse the line's colour: left to itself, fill_between picks a colour from
    # the property cycle independently of the line, so the band need not match
    # it (notably after a line drawn with an explicit colour).
    plt.fill_between(xs, curve - spread, curve + spread,
                     alpha=0.1, color=line.get_color())
    return line


# x positions for every configuration in `values` except the first (R=1),
# labelled with that configuration's number of communication rounds.
a = list(range(2, 1 + len(values)))
round_labels = [str(rounds) for rounds, _lr in values[1:]]

# (per-configuration results, legend label, extra plot kwargs)
series = [
    (y_1, "FedSN-Lite", {}),
    (y_2, "GIANT, s=n/M", {}),
    (y_3, "GIANT, s=50", {}),
    (y_4, "GIANT, s=40", {"color": "red"}),
    (y_5, "GIANT, s=30", {}),
    # Loaded but currently not plotted:
    # (y_6, "GIANT, s=20", {}),
    # (y_7, "GIANT, s=10", {}),
]
for samples, label, line_kwargs in series:
    _plot_mean_band(a, samples, label, **line_kwargs)

plt.xticks(a, round_labels)
plt.legend()
plt.suptitle("Convergence for M = 500 Machines")
plt.xlabel(r"Communication Rounds")
# The plotted value is log10 of the mean best sub-optimality (one log only).
plt.ylabel(r"$\log_{10}$(Best Sub-optimality)")
figure_path = Path("figures") / "test.png"
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, bbox_inches="tight", dpi=400)
