import os
import pickle
import math

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

from main import main, SAVE_DIR
from plot import compare_plot


# Keyword arguments forwarded unchanged to main() for every training run.
main_kwargs = dict(
    num_clients=2,
    client_samples=2,
    input_d=2,
    data_type="hard",
    rounds=2048,
)
# Communication intervals K to sweep: 4, 16, 64, 256, 1024 (powers of four).
intervals = [4 ** power for power in range(1, 6)]
# Reference learning rate; the "vanilla" schedule uses it for every K.
base_lr = 4.0

def tuned_stages(K):
    """Return the hand-tuned stage value for the two-stage schedule at interval K.

    Tuned values exist only for K in {4, 16, 64, 256, 1024}; every other K
    yields None. The caller turns a None here into "no two-stage schedule".
    """
    tuned_pairs = (
        (4, 16),
        (16, 64),
        (64, 256),
        (256, 1024),
        (1024, 4096),
    )
    for candidate, stages in tuned_pairs:
        if K == candidate:
            return stages
    return None

def _small_lr(K):
    """First-stage learning rate for interval K: base_lr * 4 ** (1 - log_4 K).

    Algebraically this equals base_lr * 4 / K. The original expression is kept
    so the learning rates (and any cached results) are numerically unchanged.
    """
    return base_lr * 4 ** (1 - math.log(K, 4))


# Learning-rate schedules keyed by algorithm name. Each value maps a
# communication interval K to a tuple (lr, r1, lr2), which the training loop
# unpacks. r1/lr2 are None for single-stage schedules. For "twostage", r1 comes
# from tuned_stages(K) (presumably the round at which training switches to lr2;
# confirm against main()).
# The lambdas must use their own argument K: reading the global loop variable
# `interval` instead silently depends on whatever value it last held.
schedules = {
    "small": lambda K: (_small_lr(K), None, None),
    "twostage": lambda K: (_small_lr(K), tuned_stages(K), base_lr),
    "vanilla": lambda K: (base_lr, None, None),
}

# One 6-inch-wide panel per schedule.
FIGURE_SIZE = (6 * len(schedules), 5)
# Subplot titles (matplotlib mathtext) for each schedule. Raw strings keep the
# LaTeX backslashes literal; "\e" in a normal string is an invalid escape
# sequence that triggers a Deprecation/SyntaxWarning.
DISPLAY_NAMES = {
    "small": r"$\eta$ = 1/(KH)",
    "twostage": r"$\eta_1 = 1/(KH), \eta_2 = 1/H$",
    "vanilla": r"$\eta$ = 1/H",
}
# Global font size for every figure produced below.
plt.rcParams["font.size"] = 14

# Cached per-run pickles and the final figure are both written under here.
results_dir = os.path.join(SAVE_DIR, "synthetic")

# Run training (or load cached results).
# all_results[alg][K] holds whatever main() returned for that schedule/interval.
all_results = {}
for alg, schedule in schedules.items():
    alg_dir = os.path.join(results_dir, alg)
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
    os.makedirs(alg_dir, exist_ok=True)
    alg_results = all_results[alg] = {}

    for interval in intervals:
        results_file = os.path.join(alg_dir, f"interval_{interval}.pkl")
        if os.path.isfile(results_file):
            # NOTE: pickle.load executes arbitrary code; this is only safe
            # because the cache files are written by this script itself.
            with open(results_file, "rb") as f:
                alg_results[interval] = pickle.load(f)
            continue

        lr, r1, lr2 = schedule(interval)
        print(f"K = {interval}")
        # The two-stage schedule is passed to main() as the string "r1,lr2",
        # or None when the schedule has a single stage.
        two_stage = f"{r1},{lr2}" if r1 is not None else None
        alg_results[interval] = main(
            interval=interval, lr=lr, two_stage=two_stage, **main_kwargs
        )
        print("")
        with open(results_file, "wb") as f:
            pickle.dump(alg_results[interval], f)

# Plot results: one log-scale loss panel per schedule, one curve per interval K.
# squeeze=False keeps axs two-dimensional even if only one schedule is defined.
fig, axs = plt.subplots(
    1, len(schedules), figsize=FIGURE_SIZE, sharey=True, squeeze=False
)
for i, alg in enumerate(schedules):
    ax = axs[0, i]

    # Create a zoomed inset window for vanilla only.
    axins = None
    if alg == "vanilla":
        window_left, window_bottom, window_width, window_height = 0.45, 0.65, 0.45, 0.3
        zoom_x1, zoom_x2, zoom_y1, zoom_y2 = -10, 150, 0.05, 5.0
        axins = ax.inset_axes(
            [window_left, window_bottom, window_width, window_height],
            xlim=(zoom_x1, zoom_x2), ylim=(zoom_y1, zoom_y2),
            xticklabels=[], yticklabels=[],
        )

    for interval in intervals:
        losses = all_results[alg][interval]["loss"]
        rounds = sorted(losses)

        # Add the initial point (round 0, loss log 2) and shift every recorded
        # round index up by one.
        xs = [0] + [r + 1 for r in rounds]
        ys = [math.log(2)] + [losses[r] for r in rounds]

        ax.plot(xs, ys, label=f"K = {interval}")
        if axins is not None:
            # Mirror the curve into the zoomed window.
            axins.plot(xs, ys)

    # Per-axes configuration only needs to happen once, after all curves exist.
    ax.set_xlabel("Rounds")
    ax.set_yscale("log")
    # sharey=True hides the y tick labels of inner panels; turn them back on.
    ax.yaxis.set_tick_params(labelleft=True)
    ax.set_title(DISPLAY_NAMES[alg])

    if axins is not None:
        # Draw lines between the zoomed region and the inset window. The inset
        # occupies the panel where a legend would go, so vanilla gets none.
        ax.indicate_inset_zoom(axins, alpha=0.5)
    else:
        ax.legend()

fig.savefig(os.path.join(results_dir, "synthetic.eps"), bbox_inches="tight")
