import os
import pickle
import math

import matplotlib.pyplot as plt

from main import main, SAVE_DIR
from plot import compare_plot


# Keyword arguments forwarded unchanged to every main() call, whatever the
# schedule or communication interval.
main_kwargs = dict(
    num_clients=5,
    client_samples=200,
    input_d=784,
    data_type="mnist",
    rounds=256,
    label_het=0.95,
)
# Communication intervals K to sweep: successive powers of four, 4**1 .. 4**5.
intervals = [4 ** power for power in range(1, 6)]
# Reference learning rate; the "vanilla" schedule uses it unchanged and the
# other schedules are expressed relative to it.
base_lr = 4.0

def tuned_stages(K):
    """Return the tuned first-stage setting of the two-stage schedule for K.

    The value is formatted into main()'s ``two_stage`` argument as
    ``"<value>,<lr2>"``. NOTE(review): presumably it is the number of rounds
    run at the first-stage learning rate before switching to ``lr2``; confirm
    against main().

    Args:
        K: Communication interval; must be one of the tuned values
            (4, 16, 64, 256, 1024).

    Returns:
        The tuned integer for ``K``. For the tuned intervals this happens to
        equal ``K // 16``.

    Raises:
        ValueError: If no tuned value exists for ``K``. Previously this
            silently returned ``None``, which made the training loop drop the
            second stage and run a single-stage schedule without any warning.
    """
    tuned = {4: 0, 16: 1, 64: 4, 256: 16, 1024: 64}
    try:
        return tuned[K]
    except KeyError:
        raise ValueError(
            f"no tuned two-stage setting for K={K!r}; known K: {sorted(tuned)}"
        ) from None

# Learning-rate schedules, keyed by name. Each value maps a communication
# interval K to a tuple (lr, r1, lr2) that the training loop unpacks:
#   lr  - learning rate passed to main(),
#   r1  - tuned_stages(K) for the two-stage schedule, else None,
#   lr2 - second-stage learning rate, else None.
# The lambdas must use their own parameter K; the previous version read the
# module-level loop variable `interval` instead. That only worked by accident
# of how the loop called them.
# 4 ** (1 - log4(K)) == 4 / K. Writing it directly avoids the inexact
# floating-point math.log, so e.g. K=64 gives exactly base_lr / 16.
# NOTE(review): DISPLAY_NAMES labels this ratio as 1/K relative to vanilla,
# but the code uses 4/K (so K=4 runs at base_lr); confirm the factor 4.
schedules = {
    "small": lambda K: (base_lr * 4 / K, None, None),
    "twostage": lambda K: (base_lr * 4 / K, tuned_stages(K), base_lr),
    "vanilla": lambda K: (base_lr, None, None),
}

# Figure size in inches: 6 wide per schedule panel, 5 tall.
FIGURE_SIZE = (6 * len(schedules), 5)
# Panel titles (matplotlib mathtext), keyed by schedule name. Raw strings keep
# the backslash in "\eta" literal. In a plain string "\e" is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning since Python 3.12); the string
# values are unchanged.
DISPLAY_NAMES = {
    "small": r"$\eta$ = 1/(KH)",
    "twostage": r"$\eta_1 = 1/(KH), \eta_2 = 1/H$",
    "vanilla": r"$\eta$ = 1/H",
}
plt.rcParams.update({'font.size': 14})

results_dir = os.path.join(SAVE_DIR, "mnist")

# Run training (or load cached results).
# all_results[alg][interval] holds whatever main() returns for that run. Each
# run is cached as <results_dir>/<alg>/interval_<K>.pkl, so re-running the
# script only trains the configurations that have no cache file yet.
all_results = {}
for alg, schedule in schedules.items():
    all_results[alg] = {}
    alg_dir = os.path.join(results_dir, alg)
    os.makedirs(alg_dir, exist_ok=True)

    for interval in intervals:
        results_file = os.path.join(alg_dir, f"interval_{interval}.pkl")
        if os.path.isfile(results_file):
            # Unpickling can execute arbitrary code: only load caches this
            # script wrote itself, never untrusted files placed in results_dir.
            with open(results_file, "rb") as f:
                all_results[alg][interval] = pickle.load(f)
            continue

        lr, r1, lr2 = schedule(interval)
        print(f"K = {interval}")
        # main() receives the two-stage setting as "<r1>,<lr2>", or None when
        # the schedule has no second stage (presumably meaning a constant lr;
        # confirm in main()).
        two_stage = f"{r1},{lr2}" if r1 is not None else None
        result = main(
            interval=interval, lr=lr, two_stage=two_stage, **main_kwargs
        )
        all_results[alg][interval] = result
        print("")

        # Dump to a temporary file, then rename it into place. An interrupted
        # write then cannot leave a truncated cache that would make every
        # later run crash in pickle.load.
        tmp_file = results_file + ".tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump(result, f)
        os.replace(tmp_file, results_file)

# Plot results: one panel per schedule, one loss curve per interval K.
# squeeze=False keeps `axs` 2-D, so indexing also works with one schedule.
fig, axs = plt.subplots(
    1, len(schedules), figsize=FIGURE_SIZE, sharey=True, squeeze=False
)
for ax, alg in zip(axs[0], schedules):
    for interval in intervals:
        losses = all_results[alg][interval]["loss"]
        recorded_rounds = sorted(losses)

        # Prepend the initial point at x=0 and shift recorded round indices by
        # one. NOTE(review): assumes the keys of "loss" are 0-based round
        # indices and that the initial loss is log(2) (e.g. binary logistic
        # loss at zero init); confirm against main().
        xs = [0] + [r + 1 for r in recorded_rounds]
        ys = [math.log(2)] + [losses[r] for r in recorded_rounds]

        ax.plot(xs, ys, label=f"K = {interval}")

    # Per-panel decoration, done once after all curves are drawn.
    ax.set_xlabel("Rounds")
    ax.set_yscale("log")
    # sharey=True hides y tick labels on all but the first panel; show them
    # on every panel.
    ax.yaxis.set_tick_params(labelleft=True)
    ax.set_title(DISPLAY_NAMES[alg])
    ax.legend()

fig.savefig(os.path.join(results_dir, "mnist.eps"), bbox_inches="tight")
plt.close(fig)
