import os
import json
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# Input/output locations: per-optimizer results are read from
# results/<optimizer>/..., and the comparison figure is written to
# results/comparison_curves (created here if it does not exist).
results_root = "results"
output_dir = os.path.join(results_root, "comparison_curves")
os.makedirs(output_dir, exist_ok=True)

# Optimizers to compare, in the order they are loaded (and plotted).
optimizers = ["DNSGD", "DSGD", "DSGT", "DNASA"]

# NOTE(review): generic marker/colour palettes; the plotting code below uses
# its own per-optimizer markers and colours and does not reference these.
markers = ['o', 's', '^', 'D', 'v', '*', 'P', 'X']
colors = plt.cm.tab10.colors

# Marker spacing (sample_interval) and the smoothing window (smooth_window)
# are configured together with the other plot settings further down; they
# are deliberately not assigned here so there is a single source of truth.


def moving_average(y, window):
    """Return the valid-mode moving average of ``y`` over ``window`` points.

    Element ``i`` of the result is ``mean(y[i:i + window])``, so the output
    has ``len(y) - window + 1`` entries. When ``y`` is shorter than
    ``window``, a copy of ``y`` is returned unsmoothed.

    Args:
        y: 1-D sequence of numbers.
        window: Number of consecutive points averaged; must be >= 1.

    Returns:
        A NumPy array with the averaged values (or the unsmoothed copy).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        # Without this check window == 0 divides by zero / fails to broadcast
        # and a negative window silently produces a meaningless value.
        raise ValueError(f"window must be >= 1, got {window!r}")
    y = np.array(y)
    if len(y) < window:
        return y
    # Prefix sums: cumsum[k] == sum(y[:k]), so each window sum is a difference.
    cumsum = np.cumsum(np.insert(y, 0, 0))
    return (cumsum[window:] - cumsum[:-window]) / window


def _load_val_accuracy(opt_root):
    """Return the per-epoch mean validation accuracy for one optimizer.

    Every sub-directory of ``opt_root`` is treated as one run. Inside each
    run, every entry whose name starts with ``rank_`` may hold a
    ``train_val_results.json`` file; entries without that file are skipped.
    The ``"val"`` list of each file is unpacked as ``(epoch, accuracy)``
    pairs, and accuracies reported for the same epoch are averaged over all
    runs and ranks. The ``"train"`` entries are not used by this script.

    Returns:
        List of mean accuracies ordered by ascending epoch; empty when no
        validation entries were found.
    """
    per_epoch = {}
    # Sorted listings make the accumulation order (and thus the float sums)
    # deterministic; os.listdir order is arbitrary.
    run_dirs = sorted(
        os.path.join(opt_root, d)
        for d in os.listdir(opt_root)
        if os.path.isdir(os.path.join(opt_root, d))
    )
    for run_dir in run_dirs:
        for rank_name in sorted(os.listdir(run_dir)):
            if not rank_name.startswith("rank_"):
                continue
            file_path = os.path.join(run_dir, rank_name, "train_val_results.json")
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for epoch, acc in data.get("val", []):
                per_epoch.setdefault(epoch, []).append(acc)
    return [sum(per_epoch[ep]) / len(per_epoch[ep]) for ep in sorted(per_epoch)]


# Mapping optimizer name -> {"val_acc": [mean accuracy per epoch, ...]}.
all_data = {}
# Number of epochs recorded for DNSGD; every curve is truncated to it.
dns_epochs = None

for opt in optimizers:
    opt_root = os.path.join(results_root, opt)
    if not os.path.isdir(opt_root):
        continue

    avg_val_acc = _load_val_accuracy(opt_root)
    if not avg_val_acc:
        # An empty curve would crash the plotting stage and, for DNSGD, would
        # set dns_epochs to 0 and wipe out every other curve.
        continue

    all_data[opt] = {"val_acc": avg_val_acc}

    if opt == "DNSGD":
        dns_epochs = len(avg_val_acc)

if dns_epochs is not None:
    for opt in all_data:
        all_data[opt]["val_acc"] = all_data[opt]["val_acc"][:dns_epochs]

def smooth_curve_left(x, y, window):
    """Smooth ``y`` with a forward moving average while keeping its length.

    For each index ``i`` with a full window ahead of it, the output is
    ``mean(y[i:i + window])``. The last ``window - 1`` points, which have no
    full window, are copied unchanged, so the result has ``len(y)`` entries
    and ``x`` can be returned as-is. When ``y`` is shorter than ``window``,
    ``y`` is returned unsmoothed.

    Args:
        x: X coordinates, passed through untouched.
        y: 1-D sequence of values to smooth.
        window: Number of points per average; must be >= 1 (1 = no smoothing).

    Returns:
        Tuple ``(x, y_out)`` where ``y_out`` is a NumPy array.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window!r}")
    y = np.array(y)
    if len(y) < window:
        return x, y

    # Prefix sums: cumsum[k] == sum(y[:k]), so each window sum is a difference.
    cumsum = np.cumsum(np.insert(y, 0, 0))
    y_smooth = (cumsum[window:] - cumsum[:-window]) / window

    # Positive start index for the unsmoothed tail: the former
    # y[-(window - 1):] became y[-0:] (the whole array) when window == 1,
    # doubling the output length and mismatching it with x.
    y_out = np.concatenate([y_smooth, y[len(y) - window + 1:]])

    return x, y_out

# Plot configuration. step_max is the right edge of the sample-complexity
# x-axis; ticks are placed every tick_interval up to (and including) it.
step_max = 168960
tick_interval = 30000
natural_ticks = np.arange(0, step_max + 1, tick_interval)

# A marker is drawn every sample_interval points; smooth_window is the number
# of points averaged when smoothing each curve.
sample_interval = 15
smooth_window = 4

line_width, marker_size = 2, 7

# One row per optimizer: (name, marker, line style), so the two visual
# attributes of an optimizer are declared side by side.
_OPT_STYLE_ROWS = (
    ("DNSGD", "o", "-"),
    ("DNASA", "s", "--"),
    ("DSGT", "^", "-."),
    ("DSGD", "D", ":"),
)
opt_markers = {name: mk for name, mk, _ls in _OPT_STYLE_ROWS}
opt_linestyles = {name: ls for name, _mk, ls in _OPT_STYLE_ROWS}

# Fixed colour per optimizer; anything not listed falls back to black.
# Shared by the curves and the legend so the two cannot drift apart.
opt_colors = {"DNSGD": "red", "DNASA": "blue", "DSGT": "green"}

# Draw one figure per metric: every optimizer's smoothed curve against the
# sample-complexity axis, saved as <metric>_vs_samples.pdf in output_dir.
for metric in ["val_acc"]:
    fig, ax = plt.subplots(figsize=(5, 5))
    plotted = set()

    for opt, data in all_data.items():
        y = np.asarray(data[metric])
        if y.size == 0:
            # Nothing to draw; indexing the last point below would raise.
            continue

        color = opt_colors.get(opt, "black")
        marker = opt_markers.get(opt, "o")
        linestyle = opt_linestyles.get(opt, "-")

        # Spread the recorded points evenly over [0, step_max].
        steps_axis = np.linspace(0, step_max, len(y))
        smoothed_steps, smoothed_y = smooth_curve_left(steps_axis, y, smooth_window)
        smoothed_steps = np.asarray(smoothed_steps)
        smoothed_y = np.asarray(smoothed_y)

        ax.plot(smoothed_steps, smoothed_y, color=color, linewidth=line_width, linestyle=linestyle)

        # Sparse markers every `sample_interval` points, always including the
        # final point (union1d avoids drawing it twice when it coincides).
        last_index = len(smoothed_y) - 1
        marker_indices = np.union1d(
            np.arange(0, len(smoothed_y), sample_interval), [last_index]
        )
        ax.plot(
            smoothed_steps[marker_indices],
            smoothed_y[marker_indices],
            marker=marker,
            markersize=marker_size,
            linestyle="None",
            color=color,
        )
        plotted.add(opt)

    ax.set_xticks(natural_ticks)

    ax.ticklabel_format(style="sci", axis="x", scilimits=(4, 4))
    ax.xaxis.get_offset_text().set_fontsize(14)

    # Legend entries in a fixed order, only for optimizers actually drawn.
    legend_order = ["DSGD", "DSGT", "DNASA", "DNSGD"]
    legend_elements = [
        Line2D(
            [0], [0],
            color=opt_colors.get(opt, "black"),
            lw=line_width,
            marker=opt_markers.get(opt, "o"),
            linestyle=opt_linestyles.get(opt, '-'),
            markersize=marker_size,
            label=opt,
        )
        for opt in legend_order
        if opt in plotted
    ]

    ax.set_xlabel("Sample Complexity", fontsize=16)
    ax.set_ylabel("Test Accuracy", fontsize=16)
    ax.tick_params(axis='both', labelsize=14)
    ax.grid(True)
    ax.legend(handles=legend_elements, fontsize=16)
    fig.subplots_adjust(top=0.99, bottom=0.11, left=0.14, right=0.99)
    fig.savefig(os.path.join(output_dir, f"{metric}_vs_samples.pdf"), dpi=300)
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
