import os
import json
import numpy as np

# Root folder that holds one sub-directory per optimizer. Both it and the
# figure output folder are created on demand so the script also runs against
# an empty working directory.
results_root = "results"
os.makedirs(results_root, exist_ok=True)
output_dir = os.path.join(results_root, "comparison_curves")
os.makedirs(output_dir, exist_ok=True)

optimizers = ["DNSGD", "DSGD", "DSGT", "DNASA"]
# How many consecutive x-axis slots each averaged validation point is
# repeated for, per optimizer.
steps_per_comm = {
    "DNSGD": 4,  # K=2
    "DSGD": 1,
    "DSGT": 2,
    "DNASA": 3
}

all_data = {}
training_data = {}
min_comm_train = float("inf")
min_comm_val = float("inf")


def _collect_val_records(opt_root):
    """Return every entry of the ``"val"`` lists found under ``opt_root``.

    Looks at ``<opt_root>/<run>/rank_*/train_val_results.json`` for each run
    sub-directory and concatenates the lists stored under the ``"val"`` key.
    Missing files and files without a ``"val"`` key contribute nothing.
    """
    records = []
    for run_name in os.listdir(opt_root):
        run_dir = os.path.join(opt_root, run_name)
        if not os.path.isdir(run_dir):
            continue
        for rank_name in os.listdir(run_dir):
            if not rank_name.startswith("rank_"):
                continue
            file_path = os.path.join(run_dir, rank_name, "train_val_results.json")
            if not os.path.exists(file_path):
                continue
            with open(file_path, "r", encoding="utf-8") as f:
                records.extend(json.load(f).get("val", []))
    return records


def _average_by_epoch(records):
    """Average ``(epoch, accuracy)`` pairs per epoch, ordered by epoch."""
    by_epoch = {}
    for epoch, acc in records:
        by_epoch.setdefault(epoch, []).append(acc)
    return [sum(by_epoch[ep]) / len(by_epoch[ep]) for ep in sorted(by_epoch)]


for opt in optimizers:
    opt_root = os.path.join(results_root, opt)
    # isdir (not exists): a stray file with the optimizer's name must not be
    # handed to os.listdir.
    if not os.path.isdir(opt_root):
        continue

    avg_val_acc = _average_by_epoch(_collect_val_records(opt_root))
    if not avg_val_acc:
        # No validation data for this optimizer. Recording an empty series
        # would drive min_comm_val to 0 and truncate every other curve to
        # nothing, so leave the optimizer out entirely.
        continue

    # Repeat each averaged accuracy steps_per_comm[opt] times and index the
    # resulting slots 0, 1, 2, ...
    repeat = steps_per_comm[opt]
    expanded_val_acc = [acc for acc in avg_val_acc for _ in range(repeat)]
    expanded_comm_val = list(range(len(expanded_val_acc)))

    all_data[opt] = {
        "val_acc": expanded_val_acc,
        "comm_val": expanded_comm_val
    }

    min_comm_val = min(min_comm_val, len(expanded_comm_val))

# Cut every series to the shortest one so all optimizers cover the same range.
for opt in all_data:
    all_data[opt]["val_acc"] = all_data[opt]["val_acc"][:min_comm_val]
    all_data[opt]["comm_val"] = all_data[opt]["comm_val"][:min_comm_val]


import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os

def smooth_curve_left(x, y, window):
    """Smooth ``y`` with a forward-looking moving average of ``window`` samples.

    Output element ``i`` is the mean of ``y[i:i + window]`` for every position
    where a full window fits. The final ``window - 1`` positions, where it
    does not, keep their raw values so the result has one entry per input.

    Args:
        x: X coordinates; returned untouched.
        y: Sequence of numeric values to smooth.
        window: Number of consecutive samples averaged per output point.

    Returns:
        A tuple ``(x, y_out)`` where ``y_out`` is a NumPy array with the same
        length as ``y``. When ``window`` is at most 1, or larger than
        ``len(y)``, ``y`` is returned unsmoothed (converted to an array).
    """
    y = np.array(y)
    # A window of 1 (or less) averages nothing. It must be handled here
    # because y[-(window - 1):] with window == 1 is y[0:], i.e. the whole
    # array, which would double the output length.
    if window <= 1 or len(y) < window:
        return x, y

    # Prefix sums with a leading 0 turn each window sum into one subtraction.
    cumsum = np.cumsum(np.insert(y, 0, 0))
    y_smooth = (cumsum[window:] - cumsum[:-window]) / window

    # Pad with the raw tail so the output lines up with x.
    y_out = np.concatenate([y_smooth, y[-(window - 1):]])

    return x, y_out


# --- Plot configuration ---
sample_interval = 15  # a marker is drawn on every 15th point of each curve
smooth_window = 4     # window size handed to smooth_curve_left
markers = ['o', 's', '^', 'D', 'v', '*', 'P', 'X']
colors = plt.cm.tab10.colors

# Each curve's x values are rescaled so its largest retained index maps here.
max_step = 2640

# Fixed per-optimizer styling so colours/markers stay stable across figures.
opt_colors = {
    "DNSGD": "red",
    "DNASA": "blue",
    "DSGT": "green",
    "DSGD": "black"
}
opt_markers = {
    "DNSGD": "o",
    "DNASA": "s",
    "DSGT": "^",
    "DSGD": "D"
}
opt_linestyles = {
    "DNSGD": "-",
    "DNASA": "--",
    "DSGT": "-.",
    "DSGD": ":"
}

# Order in which curves are drawn and listed in the legend.
legend_order = ["DSGD", "DSGT", "DNASA", "DNSGD"]

line_width = 2
marker_size = 7

fig, ax = plt.subplots(figsize=(5, 5), dpi=300)

# Optimizers that actually produced a curve; the legend is built from this.
plotted_opts = []

for opt in legend_order:
    if opt not in all_data:
        continue
    data = all_data[opt]
    color = opt_colors.get(opt, "black")
    marker = opt_markers.get(opt, "o")
    linestyle = opt_linestyles.get(opt, "-")

    # Keep only the first sample of every run of equal consecutive
    # accuracies, together with its index.
    filtered_comm = []
    filtered_val = []
    last_val = None
    for c, v in zip(data["comm_val"], data["val_acc"]):
        if v != last_val:
            filtered_comm.append(c)
            filtered_val.append(v)
            last_val = v

    if not filtered_comm:
        # Nothing to draw; max() below would raise on an empty sequence.
        continue

    # Rescale indices onto [0, max_step]. A lone point at index 0 has a zero
    # span, so it is placed at 0 instead of dividing by zero.
    comm_span = max(filtered_comm)
    steps_axis = [
        c / comm_span * max_step if comm_span else 0.0
        for c in filtered_comm
    ]

    smoothed_steps, smoothed_val = smooth_curve_left(steps_axis, filtered_val, smooth_window)

    # The curve itself, without markers.
    ax.plot(smoothed_steps, smoothed_val, color=color, linewidth=line_width, linestyle=linestyle)

    # Sparse markers along the curve ...
    ax.plot(
        smoothed_steps[::sample_interval],
        smoothed_val[::sample_interval],
        marker=marker,
        markersize=marker_size,
        linestyle='None',
        color=color
    )
    # ... plus one on the final point, which the sampling may skip.
    ax.plot(
        [smoothed_steps[-1]],
        [smoothed_val[-1]],
        marker=marker,
        markersize=marker_size,
        linestyle='None',
        color=color
    )
    plotted_opts.append(opt)

# Proxy artists combining line style and marker, since curve and markers are
# drawn by separate plot calls above. The marker fallback is "o" (matching
# the plotting loop); a colour name is not a valid marker.
legend_elements = [
    Line2D([0], [0], color=opt_colors.get(opt, "black"), lw=line_width,
           linestyle=opt_linestyles.get(opt, '-'),
           marker=opt_markers.get(opt, "o"), markersize=marker_size, label=opt,
           )
    for opt in plotted_opts
]

ax.set_xlabel("Communication Rounds", fontsize=16)
ax.set_ylabel("Test Accuracy", fontsize=16)
ax.tick_params(axis='both', labelsize=14)
ax.grid(True)
ax.legend(handles=legend_elements, fontsize=16)
plt.subplots_adjust(top=0.99, bottom=0.11, left=0.14, right=0.99)
plt.savefig(os.path.join(output_dir, "val_acc_vs_comrounds.pdf"), dpi=300)
plt.close()
