import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn/matplotlib theme. sns.set() requests a custom look:
# "Franklin Gothic Book" font, dim-grey text, light-grey axis edges,
# hidden right/top spines and no tick marks.
sns.set(
    rc={
        # Axes frame and background.
        "axes.axisbelow": False,
        "axes.edgecolor": "lightgrey",
        "axes.facecolor": "None",
        "axes.grid": False,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "figure.facecolor": "white",
        # Text and tick-label colours.
        "axes.labelcolor": "dimgrey",
        "text.color": "dimgrey",
        "xtick.color": "dimgrey",
        "ytick.color": "dimgrey",
        # Tick placement: outward-facing, with the marks themselves switched off.
        "xtick.bottom": False,
        "xtick.top": False,
        "xtick.direction": "out",
        "ytick.left": False,
        "ytick.right": False,
        "ytick.direction": "out",
        # Line and patch rendering.
        "lines.solid_capstyle": "round",
        "patch.edgecolor": "w",
        "patch.force_edgecolor": True,
    },
    font="Franklin Gothic Book",
)

# Global font sizes, layered on the "notebook" context scaling.
sns.set_context(
    context="notebook",
    rc={"font.size": 14, "axes.titlesize": 16, "axes.labelsize": 16},
)

# NOTE(review): set_style() re-applies seaborn's complete "darkgrid" style,
# which resets style parameters such as axes.facecolor, axes.grid, the spine
# and tick flags, the text/tick colours and font.family. That appears to undo
# most of the rc overrides (and the font) passed to sns.set() above -- confirm
# which look is actually intended.
sns.set_style(style="darkgrid")

# Named palette colours (hex RGB strings).
CB91_Blue = "#2CBDFE"
CB91_Green = "#47DBCD"
CB91_Pink = "#F3A0F2"
CB91_Purple = "#9D2EC5"
CB91_Violet = "#661D98"
CB91_Amber = "#F5B14C"

# Default colour order: successive artists on an Axes take these colours in
# turn, wrapping around after the last one.
color_list = [CB91_Violet, CB91_Blue, CB91_Green, CB91_Amber, CB91_Purple, CB91_Pink]
plt.rcParams.update({"axes.prop_cycle": plt.cycler(color=color_list)})

# Illustrative consensus-error curves over R rounds (outer loop) of K steps
# each (inner loop); the time axis is the flattened step index t = r * K + k.
K = 10
R = 10
x = list(range(K * R))

# One "communication round" marker per round at t = K, 2K, ..., R*K, i.e. the
# step where the next round starts and both curves drop back to 0. The last
# marker (t = R*K) lies one step past the final sample in x.
round_ends = [K * (r + 1) for r in range(R)]
y_0 = [0] * R

# Without personalization the curve climbs back towards 1 in every round.
y_1 = [1 - 0.5**k for _ in range(R) for k in range(K)]
# With personalization the within-round growth is further damped by 0.5**r.
y_2 = [0.5 * (1 - 0.5**k) * (0.5**r) for r in range(R) for k in range(K)]

print(x)
print(y_1)

plt.scatter(
    round_ends,
    y_0,
    marker="H",
    color=CB91_Amber,
    label="Communication Round",
)
plt.plot(x, y_1, label="Local SGD w/o Personalization", linewidth=1.5)
plt.plot(x, y_2, label="Local SGD w/ Personalization", linewidth=1.5)
plt.xlabel("Time(t)")
# In mathtext `\|` is already the double-bar norm delimiter; writing `\|\|`
# rendered four bars on each side of the norm.
plt.ylabel(r"$\xi_t=\frac{1}{M}\sum_{m\in[M]}\|w_t^m-w_t\|^2$")
plt.legend()
plt.tight_layout()
plt.savefig("consensus_hard.png", dpi=300)
