import os
import json
import matplotlib.pyplot as plt
import re

import numpy as np
from matplotlib.lines import Line2D

# Output locations: figures are written to <results_root>/worldsize_curves.
results_root = "results"
output_dir = os.path.join(results_root, "worldsize_curves")
for _needed_dir in (results_root, output_dir):
    os.makedirs(_needed_dir, exist_ok=True)

# Sub-path (relative to results_root) holding the run directories to plot.
optimizer = "DNSGD/world_size"

# Maps (world_size, step_size) -> {"val_steps": ..., "avg_val_acc": ...}.
all_data = {}


def moving_average(data, window_size=5):
    """Return the trailing moving average of *data*.

    Element ``i`` of the result is the mean of the last ``window_size``
    values ending at ``i``; near the start, where fewer values exist, the
    mean is taken over what is available.

    Args:
        data: Sequence of numbers (must support slicing).
        window_size: Number of trailing samples to average. Values below 2
            disable smoothing and *data* itself is returned unchanged.

    Returns:
        A new list with ``len(data)`` averaged values, or *data* itself when
        ``window_size < 2``.
    """
    if window_size < 2:
        return data
    return [
        sum(data[max(0, end - window_size):end]) / min(end, window_size)
        for end in range(1, len(data) + 1)
    ]


# Number of training iterations between consecutive validation entries, keyed
# by world size. Used to convert a validation index into an iteration count.
#
# The ratios are rounded to the nearest integer: plain floor division (`//`)
# truncates 9.78..9.97 down to 9, contradicting the intended "≈ 10" and
# compressing the x-axis by ~10%.
# NOTE(review): the operands are presumably (total iterations, number of
# validation evaluations) for each world size — confirm against the trainer.
steps_per_val = {
    16: round(450 / 46),  # ≈ 10
    8: round(900 / 91),  # ≈ 10
    4: round(1770 / 178),  # ≈ 10
    2: round(3540 / 355),  # ≈ 10
}

# Load validation curves. Every sub-directory of results/<optimizer> whose
# name contains "<N>_agents" (with N a key of steps_per_val) is one run; the
# per-rank results inside it are averaged per step and stored in all_data
# under the key (world_size, step_size).
opt_root = os.path.join(results_root, optimizer)
if os.path.exists(opt_root):
    opt_dirs = [
        os.path.join(opt_root, d)
        for d in os.listdir(opt_root)
        if os.path.isdir(os.path.join(opt_root, d))
    ]

    for d in opt_dirs:
        run_name = os.path.basename(d)

        m = re.search(r"(\d+)_agents", run_name)
        if not m:
            continue
        world_size = int(m.group(1))

        # Only the exact form "lr_0.dddddd" is recognised; any other spelling
        # leaves step_size as None.
        ms = re.search(r"lr_(0\.\d{6})", run_name)
        step_size = float(ms.group(1)) if ms else None

        if world_size not in steps_per_val:
            continue

        print(f"目录: {d}, world_size={world_size}, step_size={step_size}")

        # Pool the (val_index, acc) pairs from every rank_* directory.
        val_res = []
        rank_dirs = [os.path.join(d, rd) for rd in os.listdir(d) if rd.startswith("rank_")]
        for rank_dir in rank_dirs:
            file_path = os.path.join(rank_dir, "train_val_results.json")
            if not os.path.exists(file_path):
                continue
            with open(file_path, "r", encoding="utf-8") as f:
                rank_results = json.load(f)
            val_res.extend(rank_results.get("val", []))

        # Group accuracies by iteration so ranks can be averaged per step.
        val_dict = {}
        for val_idx, acc in val_res:
            step = val_idx * steps_per_val[world_size]
            val_dict.setdefault(step, []).append(acc)

        if not val_dict:
            # A run without validation data would break the min/max and
            # truncation below, so leave it out of the plot entirely.
            print(f"skipping {d}: no validation results found")
            continue

        val_steps = sorted(val_dict)
        avg_val_acc = [sum(val_dict[st]) / len(val_dict[st]) for st in val_steps]

        all_data[(world_size, step_size)] = {
            "val_steps": val_steps,
            "avg_val_acc": avg_val_acc,
        }

    # Cut every curve at the last step reached by the shortest run so that
    # all curves span the same x-range. Skipped when nothing was loaded
    # (min() of an empty sequence raises ValueError).
    if all_data:
        min_max_val_step = min(max(run["val_steps"]) for run in all_data.values())

        for key in list(all_data):
            run = all_data[key]
            kept = [
                (st, val)
                for st, val in zip(run["val_steps"], run["avg_val_acc"])
                if st <= min_max_val_step
            ]
            if not kept:
                # Nothing of this run lies inside the common range.
                del all_data[key]
                continue
            run["val_steps"] = [st for st, _ in kept]
            run["avg_val_acc"] = [val for _, val in kept]

# Single square figure (5 x 5 inches, 300 dpi); every curve is drawn on `ax`.
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)

def smooth_curve_left(x, y, window):
    """Smooth *y* with a forward-looking (left-aligned) moving average.

    Output element ``i`` is the mean of ``y[i:i + window]``. The final
    ``window - 1`` points do not have a full window ahead of them and are
    passed through unsmoothed, so the result always has ``len(y)`` elements
    and stays aligned with *x*.

    Args:
        x: X-coordinates; returned untouched.
        y: Sequence of numbers to smooth.
        window: Number of samples per average. A value of 1 or less, or a
            value larger than ``len(y)``, disables smoothing.

    Returns:
        Tuple ``(x, y_out)`` where ``y_out`` is a NumPy array of ``len(y)``
        values.
    """
    y = np.array(y)
    # window <= 1 must short-circuit: for window == 1 the tail slice below
    # becomes y[-0:], i.e. the whole array, which would double the output
    # length; window <= 0 would divide by zero or fail to broadcast.
    if window <= 1 or len(y) < window:
        return x, y

    # Prefix sums let each window total be read off as a difference.
    cumsum = np.cumsum(np.insert(y, 0, 0))
    y_smooth = (cumsum[window:] - cumsum[:-window]) / window

    # Pad with the raw tail so the output has the same length as the input.
    y_out = np.concatenate([y_smooth, y[-(window - 1):]])

    return x, y_out

# Per-run appearance, keyed by (world_size, step_size):
# (line colour, marker symbol, line style).
_series_styles = {
    (16, 0.5): ("red", "o", "-"),
    (8, 0.25): ("blue", "s", "--"),
    (4, 0.1): ("green", "^", "-."),
    (2, 0.05): ("orange", "D", ":"),
}
opt_colors = {key: style[0] for key, style in _series_styles.items()}
opt_markers = {key: style[1] for key, style in _series_styles.items()}
opt_linestyles = {key: style[2] for key, style in _series_styles.items()}

# Shared drawing parameters.
line_width = 2
marker_size = 7
# A marker is drawn on every sample_interval-th point of each curve.
sample_interval = 15

# Draw one smoothed accuracy curve per loaded run. Runs whose key has no
# entry in the style tables fall back to a solid black line with "o" markers.
for (ws, step_size), data in all_data.items():
    style_key = (ws, step_size)
    color = opt_colors.get(style_key, "black")
    marker = opt_markers.get(style_key, "o")
    linestyle = opt_linestyles.get(style_key, "-")

    smoothed_steps, smoothed_val = smooth_curve_left(
        data["val_steps"], data["avg_val_acc"], window=4
    )

    # The curve itself, without markers.
    ax.plot(smoothed_steps, smoothed_val, color=color, linewidth=line_width, linestyle=linestyle)

    marker_style = dict(marker=marker, markersize=marker_size, linestyle='None', color=color)

    # Sparse markers along the curve ...
    ax.plot(smoothed_steps[::sample_interval], smoothed_val[::sample_interval], **marker_style)

    # ... plus one on the final point so every curve visibly ends on a marker.
    ax.plot([smoothed_steps[-1]], [smoothed_val[-1]], **marker_style)

# Legend entries are listed by ascending world size, independent of the order
# in which the runs were loaded.
legend_order = [(2, 0.05), (4, 0.1), (8, 0.25), (16, 0.5)]

# Proxy artists: each combines the line style and marker used for that run.
# The label needs both the opening and the closing "$" to be rendered as
# math; with a single "$" matplotlib shows the dollar sign literally.
legend_elements = [
    Line2D(
        [0], [0],
        color=opt_colors.get((ws, step_size), "black"),
        lw=line_width,
        linestyle=opt_linestyles.get((ws, step_size), '-'),
        marker=opt_markers.get((ws, step_size), "o"),
        markersize=marker_size,
        label=f"$n = {ws:d}$",
    )
    for (ws, step_size) in legend_order
]

ax.set_xlabel("Iterations", fontsize=16)
ax.set_ylabel("Test Accuracy", fontsize=16)
ax.tick_params(axis='both', labelsize=14)
# Tight manual margins so the axes fill the small square figure.
plt.subplots_adjust(top=0.99, bottom=0.11, left=0.14, right=0.99)
ax.grid(True)
ax.legend(handles=legend_elements, fontsize=16)

plt.savefig(os.path.join(output_dir, "val_acc_vs_steps_world_size.pdf"), dpi=300)
plt.close()