import os

import numpy as np
import matplotlib.pyplot as plt

import matplotlib as mpl

import pickle

import seaborn as sns


# Global figure styling: 12pt base font and seaborn's "colorblind" palette,
# which the colour tables below index into.
mpl.rcParams.update({'font.size': 12})
palette = sns.color_palette("colorblind")

# Experiment hyper-parameters. The plotting functions interpolate them into
# the names of the result files they open, so they must match the values the
# result files were saved under.
# `stream_length` is only used as `stream_length // local_steps` to obtain the
# `n_communications` component of the file name.
stream_length = 500000
global_lr = 0.01
local_lr = 1e-3
momentum_coef = 0.1
# NOTE(review): `h` is not referenced by any function visible in this file.
h = "heterogeneous"

# NOTE(review): `metric` is not referenced by any function visible in this
# file; the plotting functions read the "grad_norm" key via a literal.
metric = "grad_norm"
# Algorithms to plot; each name is used as a key into the loaded results and
# into the `colors`, `markers` and `label` tables below.
algos = ["local_sgd", "minibatch_sgd", "local_sgd_momentum"]
# Line/band colour per algorithm.
colors = {
    "local_sgd": palette[0],
    "minibatch_sgd": palette[1],
    "local_sgd_momentum": palette[2],
}

# Line/band colour per number of clients (used by `plot_speed_up`).
colors_speed_up = {
    10: palette[7],
    100: palette[8],
    1000: palette[9],
}

# Colour keyed by a boolean flag.
# NOTE(review): not referenced by any function visible in this file.
colors_independent_batch = {
    True: palette[0],
    False: palette[1],
}

# Matplotlib marker symbol per algorithm.
markers = {
    "local_sgd": "o",
    "minibatch_sgd": "x",
    "local_sgd_momentum": "+",
}

# Legend text per algorithm.
label = {
    "local_sgd": "Local SGD",
    "minibatch_sgd": "Minibatch SGD",
    "local_sgd_momentum": "Local SGD-M",
}

# Legend text keyed by the same boolean flag as `colors_independent_batch`.
# NOTE(review): not referenced by any function visible in this file.
label_independent_batch = {
    True: "Independent Batch",
    False: "Dependent Batch",
}


def ema(scalars, weight=0.5):
    """Return the exponential moving average of ``scalars``.

    The recursion is ``s_t = weight * s_{t-1} + (1 - weight) * x_t`` and is
    seeded with the first element, so the first output equals the first input.

    Args:
        scalars: Sequence (list or array) of values to smooth.
        weight: Weight given to the previous smoothed value. ``0`` returns the
            input unchanged; values closer to ``1`` smooth more strongly.

    Returns:
        A NumPy array with one smoothed value per input element. An empty
        input yields an empty array.
    """
    if len(scalars) == 0:
        # Nothing to seed the recursion with (``scalars[0]`` would raise), so
        # hand back an empty array instead of failing.
        return np.array(scalars)

    last = scalars[0]
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return np.array(smoothed)


def smooth(scalars, weight=0.1, start=0):
    """Smooth ``scalars`` with ``ema`` while leaving the first ``start`` entries raw.

    Args:
        scalars: Sequence of values to smooth.
        weight: Smoothing weight forwarded to ``ema``.
        start: Index at which smoothing begins; entries before it are copied
            through unchanged.

    Returns:
        The unsmoothed head and the smoothed tail joined with
        ``np.concatenate``.
    """
    raw_head = scalars[:start]
    smoothed_tail = ema(scalars[start:], weight)
    return np.concatenate((raw_head, smoothed_tail))


def smooth_upper(scalars, variance, weight=0.1, start=0):
    """Return the smoothed upper edge of a band, ``scalars + variance``.

    Args:
        scalars: Centre values of the band.
        variance: Half-width added to ``scalars`` (the callers in this file
            pass a standard deviation).
        weight: Smoothing weight forwarded to ``smooth``.
        start: Index at which smoothing begins, forwarded to ``smooth``.

    Returns:
        The result of ``smooth`` applied to ``scalars + variance``.
    """
    upper_edge = scalars + variance
    return smooth(upper_edge, weight=weight, start=start)


def smooth_lower(scalars, variance, min_val, weight=0.1, start=0):
    """Return the smoothed lower edge of a band, floored at ``min_val``.

    Args:
        scalars: Centre values of the band.
        variance: Half-width subtracted from ``scalars`` (the callers in this
            file pass a standard deviation).
        min_val: Element-wise floor applied before smoothing, so the edge
            never drops below it.
        weight: Smoothing weight forwarded to ``smooth``.
        start: Index at which smoothing begins, forwarded to ``smooth``.

    Returns:
        The result of ``smooth`` applied to
        ``np.maximum(scalars - variance, min_val)``.
    """
    lower_edge = scalars - variance
    floored_edge = np.maximum(lower_edge, min_val)
    return smooth(floored_edge, weight=weight, start=start)


def plot_grad_norm_trajectories(
    res_dir: str,
):
    """Plot smoothed gradient-norm trajectories, one figure per configuration.

    For every ``(n_clients, mixing_time, local_steps)`` combination listed in
    the body, the matching results pickle is read from ``res_dir`` and, for
    each algorithm in ``algos``, the mean of ``results[algo]["grad_norm"]``
    over axis 0 is drawn with a shaded band (mean + std above, mean - std
    floored at the minimum below). Figures are saved as
    ``../figures/grad_norm_trajectories/mixing_time=<T>/M=<M>,K=<K>.pdf``.

    NOTE(review): axis 0 of ``grad_norm`` is presumably the independent-run
    axis and the remaining axis the communication round -- confirm against
    the code that writes the pickles.

    Args:
        res_dir: Directory containing the result pickles. ``pickle.load`` can
            execute arbitrary code, so only point this at trusted results.

    Raises:
        FileNotFoundError: If an expected results file does not exist.
    """
    n_clients_list = [10]
    mixing_time_list = [100]
    local_steps_list = [10, 100, 1000]

    plot_dir = "../figures/grad_norm_trajectories"
    os.makedirs(plot_dir, exist_ok=True)

    for n_clients in n_clients_list:
        for mixing_time in mixing_time_list:
            for local_steps in local_steps_list:
                f_name = (
                    f"mixing_time={mixing_time},local_lr={local_lr},global_lr={global_lr},"
                    f"momentum={momentum_coef},local_steps={local_steps},"
                    f"n_communications={stream_length // local_steps},n_clients={n_clients}.pkl"
                )

                with open(os.path.join(res_dir, f_name), "rb") as f:
                    results = pickle.load(f)

                grad_norm_mean, grad_norm_std, grad_norm_min = {}, {}, {}
                for algo in algos:
                    grad_norm = results[algo]["grad_norm"]
                    grad_norm_mean[algo] = grad_norm.mean(axis=0)
                    grad_norm_std[algo] = grad_norm.std(axis=0)
                    grad_norm_min[algo] = grad_norm.min(axis=0)

                # plot
                fig, ax = plt.subplots(1, 1, figsize=(4, 3))
                ax.set_yscale("log", base=10)
                ax.set_xlabel("Communication rounds")
                # set ylim by hand here
                if n_clients == 10:
                    ax.set_ylim(3e-6, 5)
                else:
                    ax.set_ylim(3e-7, 5)

                # Sub-sample to at most 200 evenly spaced points to keep the
                # PDF small and the curves readable.
                n_rounds = len(grad_norm_mean["local_sgd"])
                idx = np.linspace(0, n_rounds - 1, min(n_rounds, 200), dtype=int)
                # Roughly ten markers per curve. The stride must be at least 1:
                # with fewer than ten points the integer division yields 0,
                # which is not a valid marker stride.
                marker_stride = max(1, len(idx) // 10)

                for algo in algos:
                    mean_at_idx = grad_norm_mean[algo][idx]
                    std_at_idx = grad_norm_std[algo][idx]
                    min_at_idx = grad_norm_min[algo][idx]

                    ax.plot(idx, smooth(mean_at_idx, weight=0.5),
                            color=colors[algo], marker=markers[algo],
                            label=label[algo], markevery=marker_stride)

                    # The lower edge is floored at the observed minimum so the
                    # band stays positive on the logarithmic y-axis.
                    ax.fill_between(idx,
                                    smooth_lower(mean_at_idx, std_at_idx, min_at_idx, weight=0.5),
                                    smooth_upper(mean_at_idx, std_at_idx, weight=0.5),
                                    color=colors[algo], edgecolor=colors[algo], alpha=0.5)

                # Only the K=1000 figure carries the legend.
                if local_steps == 1000:
                    ax.legend(loc="best")

                plot_subdir = f"{plot_dir}/mixing_time={mixing_time}"
                os.makedirs(plot_subdir, exist_ok=True)
                fig.savefig(f"{plot_subdir}/M={n_clients},K={local_steps}.pdf", bbox_inches='tight')
                plt.close(fig)


def plot_mixing_time(
    res_dir: str
):
    """Plot the final gradient norm as a function of the mixing time.

    For every ``(n_clients, local_steps)`` pair listed in the body, one figure
    is produced with the mixing time on a logarithmic x-axis. For each
    algorithm in ``algos`` and each mixing time, ``grad_norm`` is averaged
    over its last 10 entries along the final axis; the mean, standard
    deviation and minimum of those per-row averages give the curve and its
    shaded band. Figures are saved as
    ``../figures/effect_mixing_time/M=<M>/K=<K>.pdf``.

    NOTE(review): ``grad_norm`` is presumably shaped (runs, communication
    rounds) -- confirm against the code that writes the pickles.

    Args:
        res_dir: Directory containing the result pickles. ``pickle.load`` can
            execute arbitrary code, so only point this at trusted results.

    Raises:
        FileNotFoundError: If an expected results file does not exist.
    """
    plot_dir = "../figures/effect_mixing_time"

    mixing_times_effect_list = [2, 10, 100, 1000]
    n_clients_list_mixing_time = [10, 100]
    local_steps_list = [10, 100, 1000]
    n_mixing_times = len(mixing_times_effect_list)

    for n_clients in n_clients_list_mixing_time:
        for local_steps in local_steps_list:
            fig, ax = plt.subplots(1, 1, figsize=(4, 3))
            ax.set_yscale("log", base=10)
            ax.set_xscale("log", base=10)
            ax.set_xlabel("Mixing time")
            if n_clients == 100:
                ax.set_ylim(1e-6, 1e-1)

            # One slot per mixing time, filled in below.
            last_grad_norm_mean = {algo: np.zeros(n_mixing_times) for algo in algos}
            last_grad_norm_std = {algo: np.zeros(n_mixing_times) for algo in algos}
            last_grad_norm_min = {algo: np.zeros(n_mixing_times) for algo in algos}

            for i, mixing_time in enumerate(mixing_times_effect_list):
                f_name = (
                    f"mixing_time={mixing_time},local_lr={local_lr},global_lr={global_lr},"
                    f"momentum={momentum_coef},local_steps={local_steps},"
                    f"n_communications={stream_length // local_steps},n_clients={n_clients}.pkl"
                )
                with open(os.path.join(res_dir, f_name), "rb") as f:
                    results = pickle.load(f)

                for algo in algos:
                    grad_norm = results[algo]["grad_norm"]
                    # Average the last 10 entries of every row. This must be a
                    # slice (`-10:`): indexing a single column (`-10`) would
                    # collapse to one scalar after the mean, making the std
                    # identically 0 and the minimum equal to that scalar.
                    tail_average = grad_norm[:, -10:].mean(axis=-1)
                    last_grad_norm_mean[algo][i] = tail_average.mean()
                    last_grad_norm_std[algo][i] = tail_average.std()
                    last_grad_norm_min[algo][i] = tail_average.min()

            for algo in algos:
                ax.plot(mixing_times_effect_list, last_grad_norm_mean[algo],
                        color=colors[algo], marker=markers[algo], label=label[algo])
                # The lower edge is floored at the observed minimum so the band
                # stays positive on the logarithmic y-axis.
                ax.fill_between(mixing_times_effect_list,
                                last_grad_norm_mean[algo] + last_grad_norm_std[algo],
                                np.maximum(last_grad_norm_mean[algo] - last_grad_norm_std[algo],
                                           last_grad_norm_min[algo]),
                                color=colors[algo], edgecolor=colors[algo], alpha=0.5)
            # Only the K=1000 figure carries the legend.
            if local_steps == 1000:
                ax.legend(loc="best")

            plot_subdir = f"{plot_dir}/M={n_clients}"
            os.makedirs(plot_subdir, exist_ok=True)
            fig.savefig(f"{plot_subdir}/K={local_steps}.pdf", bbox_inches='tight')
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)


def plot_speed_up(
    res_dir: str
):
    """Plot gradient-norm trajectories for several client counts per algorithm.

    For every ``(mixing_time, local_steps)`` pair listed in the body and every
    algorithm in ``algos``, one figure is produced that overlays the smoothed
    mean of ``results[algo]["grad_norm"]`` over axis 0 for each number of
    clients, with a shaded band (mean + std above, mean - std floored at the
    minimum below). Figures are saved as
    ``../figures/speed_up/mixing_time=<T>/<algo>/K=<K>.pdf``.

    NOTE(review): axis 0 of ``grad_norm`` is presumably the independent-run
    axis and the remaining axis the communication round -- confirm against
    the code that writes the pickles.

    Args:
        res_dir: Directory containing the result pickles. ``pickle.load`` can
            execute arbitrary code, so only point this at trusted results.

    Raises:
        FileNotFoundError: If an expected results file does not exist.
    """
    plot_dir = "../figures/speed_up"

    speed_up_n_clients_list = [10, 100, 1000]
    speed_up_mixing_time_list = [10, 100]
    speed_up_local_steps_list = [10, 100]

    for mixing_time in speed_up_mixing_time_list:
        for local_steps in speed_up_local_steps_list:
            # The results file does not depend on the algorithm, so read each
            # one once here instead of once per algorithm.
            results_by_n_clients = {}
            for n_clients in speed_up_n_clients_list:
                f_name = (
                    f"mixing_time={mixing_time},local_lr={local_lr},global_lr={global_lr},"
                    f"momentum={momentum_coef},local_steps={local_steps},"
                    f"n_communications={stream_length // local_steps},n_clients={n_clients}.pkl"
                )
                with open(os.path.join(res_dir, f_name), "rb") as f:
                    results_by_n_clients[n_clients] = pickle.load(f)

            for algo in algos:
                fig, ax = plt.subplots(1, 1, figsize=(4, 3))
                ax.set_yscale("log", base=10)
                ax.set_xlabel("Communication rounds")
                # set ylim by hand here
                ax.set_ylim(1e-6, 1e0)

                for n_clients in speed_up_n_clients_list:
                    grad_norm = results_by_n_clients[n_clients][algo]["grad_norm"]
                    grad_norm_mean = np.mean(grad_norm, axis=0)
                    grad_norm_std = np.std(grad_norm, axis=0)
                    grad_norm_min = np.min(grad_norm, axis=0)

                    # Sub-sample to at most 200 evenly spaced points.
                    plot_idx = np.linspace(0, len(grad_norm_mean) - 1,
                                           min(len(grad_norm_mean), 200), dtype=int)
                    # Roughly ten markers per curve. The stride must be at
                    # least 1: with fewer than ten points the integer division
                    # yields 0, which is not a valid marker stride.
                    marker_stride = max(1, len(plot_idx) // 10)

                    mean_at_idx = grad_norm_mean[plot_idx]
                    std_at_idx = grad_norm_std[plot_idx]
                    min_at_idx = grad_norm_min[plot_idx]

                    ax.plot(plot_idx, smooth(mean_at_idx, weight=0.5),
                            color=colors_speed_up[n_clients], marker=markers[algo],
                            label=f"M={n_clients}", markevery=marker_stride)

                    # The lower edge is floored at the observed minimum so the
                    # band stays positive on the logarithmic y-axis.
                    ax.fill_between(plot_idx,
                                    smooth_lower(mean_at_idx, std_at_idx, min_at_idx, weight=0.5),
                                    smooth_upper(mean_at_idx, std_at_idx, weight=0.5),
                                    color=colors_speed_up[n_clients],
                                    edgecolor=colors_speed_up[n_clients], alpha=0.5)

                # Only the Local SGD-M figure carries the legend.
                if algo == "local_sgd_momentum":
                    ax.legend(loc="best")

                plot_subdir = f"{plot_dir}/mixing_time={mixing_time}/{algo}"
                os.makedirs(plot_subdir, exist_ok=True)
                fig.savefig(f"{plot_subdir}/K={local_steps}.pdf", bbox_inches='tight')
                plt.close(fig)


if __name__ == "__main__":
    # Script entry point: build every figure family, in order, from the
    # result files stored under ``res_dir``.
    res_dir = "../results_synthetic"
    for make_figures in (plot_grad_norm_trajectories, plot_mixing_time, plot_speed_up):
        make_figures(res_dir)
