import os
import json
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# --- Configuration ---------------------------------------------------------
# Results are read from <results_root>/<optimizer>/<run>/rank_*/ and the
# figure is written to <results_root>/training_curves.
results_root = "results"
output_dir = os.path.join(results_root, "training_curves")
os.makedirs(output_dir, exist_ok=True)

optimizers = ["DNSGD", "DSGD", "DSGT", "DNASA"]
markers = ['o', 's', '^', 'D', 'v', '*', 'P', 'X']


def _load_training_entries(opt_root):
    """Return every "training" record found below *opt_root*.

    Walks ``<opt_root>/<run>/rank_*/train_val_results.json`` and concatenates
    the list stored under the ``"training"`` key of each file. Runs and ranks
    are visited in sorted order so the result (and therefore the order in
    which values are later summed) does not depend on ``os.listdir`` order.
    Missing files, non-directories and files without a ``"training"`` list
    are skipped; an unreadable or invalid JSON file still raises.
    """
    entries = []
    for run_name in sorted(os.listdir(opt_root)):
        run_dir = os.path.join(opt_root, run_name)
        if not os.path.isdir(run_dir):
            continue
        for rank_name in sorted(os.listdir(run_dir)):
            if not rank_name.startswith("rank_"):
                continue
            file_path = os.path.join(run_dir, rank_name, "train_val_results.json")
            if not os.path.isfile(file_path):
                continue
            # JSON text is UTF-8; do not rely on the platform default encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # "or []" also covers an explicit null stored under "training".
            entries.extend(data.get("training") or [])
    return entries


def _average_by_step(entries):
    """Average loss and gradient norm over all records sharing a step.

    Each valid entry is a 3-item list ``[step, loss, grad_norm]``. Anything
    else (wrong length, or not a list at all) is reported on stdout and
    ignored instead of aborting the whole script.

    Returns ``(avg_grad, avg_loss)``: two lists ordered by ascending step.
    """
    loss_by_step = {}
    grad_by_step = {}
    for entry in entries:
        # The isinstance check keeps non-sequences (numbers, null) from
        # crashing len(), and keeps dicts/strings of length 3 out of the sums.
        if isinstance(entry, (list, tuple)) and len(entry) == 3:
            step, loss, grad_norm = entry
            loss_by_step.setdefault(step, []).append(loss)
            grad_by_step.setdefault(step, []).append(grad_norm)
        else:
            print(f"training error: {entry}")

    steps = sorted(grad_by_step)
    avg_grad = [sum(grad_by_step[s]) / len(grad_by_step[s]) for s in steps]
    avg_loss = [sum(loss_by_step[s]) / len(loss_by_step[s]) for s in steps]
    return avg_grad, avg_loss


# Per-optimizer averaged curves: {optimizer: {"train_grad": [...], "train_loss": [...]}}.
# Optimizers with no result directory or no training records are left out.
all_data = {}
# Truncating every curve to the DNSGD curve length is currently disabled,
# so this stays None.
dns_steps = None

for opt in optimizers:
    opt_root = os.path.join(results_root, opt)
    if not os.path.isdir(opt_root):
        continue

    training_res = _load_training_entries(opt_root)
    if not training_res:
        continue

    avg_train_grad, avg_train_loss = _average_by_step(training_res)
    all_data[opt] = {
        "train_grad": avg_train_grad,
        "train_loss": avg_train_loss,
    }

# --- Figure: averaged training loss and gradient norm per optimizer --------
# Draws two stacked panels sharing the x axis (loss on top, gradient norm
# below) from `all_data` and saves them as a single PDF in `output_dir`.

# Right edge of the x axis and spacing of its ticks.
step_max = 168960
tick_interval = 30000
natural_ticks = np.arange(0, step_max + 1, tick_interval)

line_width = 1
marker_size = 7  # not used below: curves are drawn without markers

# Single colour table shared by both panels and the legend, so they cannot
# drift apart. Optimizers not listed here fall back to `default_color`.
curve_colors = {"DNSGD": "red", "DNASA": "blue", "DSGT": "green"}
default_color = "orange"

fig, axes = plt.subplots(2, 1, figsize=(7, 7), sharex=True)

metrics = ["train_loss", "train_grad"]
ylabels = {
    "train_loss": "Training Loss",
    "train_grad": "Gradient Norm"
}

# Line artists drawn on the top panel, keyed by optimizer. They are reused as
# legend handles so the legend shows exactly the style that was plotted.
legend_lines = {}

for ax, metric in zip(axes, metrics):
    for opt, data in all_data.items():
        if metric not in data:
            continue
        y = data[metric]
        # NOTE(review): points are spread evenly over [0, step_max] by index;
        # the logged step values are not used, so every curve spans the full
        # axis regardless of its length. Confirm this is the intended scaling.
        steps_axis = np.linspace(0, step_max, len(y))

        (line,) = ax.plot(
            steps_axis,
            y,
            color=curve_colors.get(opt, default_color),
            linewidth=line_width,
            label=opt,
        )
        if ax is axes[0]:
            legend_lines[opt] = line

    ax.set_ylabel(ylabels[metric], fontsize=16)
    ax.tick_params(axis='both', labelsize=14)
    ax.grid(True)

# The x axis is shared, so ticks and label are configured on the bottom panel.
axes[-1].set_xticks(natural_ticks)
axes[-1].ticklabel_format(style="sci", axis="x", scilimits=(4, 4))
axes[-1].xaxis.get_offset_text().set_fontsize(14)
axes[-1].set_xlabel("Sample Complexity", fontsize=16)

# Legend entries follow this fixed order; optimizers without data are omitted.
legend_order = ["DSGD", "DSGT", "DNASA", "DNSGD"]
legend_elements = [legend_lines[opt] for opt in legend_order if opt in legend_lines]

# Skip the legend entirely when nothing was plotted instead of drawing an
# empty legend box.
if legend_elements:
    axes[0].legend(handles=legend_elements, fontsize=14, loc="best")

fig.tight_layout()
fig.savefig(os.path.join(output_dir, "loss_grad_vs_steps.pdf"), dpi=300)
plt.close(fig)
