import numpy as np
import matplotlib.pyplot as plt


def qqplot_from_one_file(txt_path, out_png="qq_from_file.png",
                         lim=4.5, ticks=(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5), delimiter=None,
                         *, show=True):
    """Draw a Q-Q plot comparing the first two columns of a text file.

    Column 0 is treated as the unaccelerated ("vanilla") sample and column 1
    as the accelerated ("Nesterov") sample. Each column is sorted
    independently and the sorted values are plotted against each other,
    together with the reference line y = x. Sample mean and variance
    (ddof=1) of each column are printed, the figure is saved to ``out_png``,
    and it is then displayed unless ``show`` is False.

    Parameters
    ----------
    txt_path : str or path-like
        File readable by ``np.loadtxt`` with at least two columns; columns
        beyond the second are ignored.
    out_png : str or path-like
        Output image path (saved at 200 dpi).
    lim : float
        Half-width of the square plotting window; both axes span [-lim, lim].
    ticks : sequence of float
        Tick positions for both axes. Ticks outside [-lim, lim] fall outside
        the window and are not drawn.
    delimiter : str or None
        Passed to ``np.loadtxt``; None splits on whitespace.
    show : bool, keyword-only
        If True (default), call ``plt.show()`` after saving. If False, the
        figure is closed instead, which suits batch/headless use.

    Raises
    ------
    ValueError
        If the file contains no data rows or fewer than two columns.
    """
    # ndmin=2 keeps a single-row file two-dimensional, so a valid one-row,
    # two-column file is not misreported as having too few columns.
    data = np.loadtxt(txt_path, delimiter=delimiter, ndmin=2)
    if data.shape[0] == 0:
        raise ValueError(f"No data rows found in {txt_path}")
    if data.ndim != 2 or data.shape[1] < 2:
        raise ValueError(f"Expected at least 2 columns in {txt_path}, got shape {data.shape}")

    y_vanilla = data[:, 0]
    y_nesterov = data[:, 1]
    # Both columns come from the same 2-D array, so they have equal length.
    n = data.shape[0]

    # Empirical quantiles: the sorted samples, paired rank by rank.
    xv = np.sort(y_vanilla)
    yn = np.sort(y_nesterov)

    # Quick stats, printed before plotting so they are visible even while a
    # blocking plt.show() window is open.
    print("Loaded:", txt_path)
    print("n =", n)
    print("Vanilla : mean =", float(np.mean(y_vanilla)), " var =", float(np.var(y_vanilla, ddof=1)))
    print("Nesterov: mean =", float(np.mean(y_nesterov)), " var =", float(np.var(y_nesterov, ddof=1)))

    fig = plt.figure(figsize=(6, 6))
    plt.plot(xv, yn, marker='.', linestyle='none')

    # Fixed ticks first: setting tick locations expands the view limits to
    # include every tick, so the limits must be applied afterwards for the
    # window to really be [-lim, lim].
    plt.xticks(ticks, fontsize=24)
    plt.yticks(ticks, fontsize=24)
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)

    # 45-degree reference line: points on y = x mean matching quantiles.
    plt.plot([-lim, lim], [-lim, lim], linestyle='--')

    plt.xlabel("Quantiles for Unaccelerated Sketching", fontsize=19)
    plt.ylabel("Quantiles for Accelerated Sketching", fontsize=19)
    ax = plt.gca()
    ax.set_aspect('equal', adjustable='box')
    plt.grid(True, linestyle="--", linewidth=0.5)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    if show:
        plt.show()
    else:
        # Release the figure when not displaying it, so repeated batch calls
        # do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Example run: one two-column results file -> one Q-Q plot image.
    run_settings = {
        "txt_path": "y_vanilla_y_nesterov_nx_10_tau_3_munu_true.txt",
        "out_png": "qq_linear_kazmaz_nx_10_tau_3_munu_true.png",
        "lim": 4.5,
        "ticks": (-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5),
        "delimiter": None,  # set delimiter=',' if CSV
    }
    qqplot_from_one_file(**run_settings)
