import runner as R
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime
from pathlib import Path
from experiment_runner import run_experiment

def plot_from_results(results, name=None, verison=None):
    """Plot the final squared operator norm against gamma for each method.

    Draws one log-scale curve per method ("PFL+23" and "Alg.4") and, when
    *name* is given, saves the figure as a timestamped PNG in a ``figs``
    directory next to this file.

    Parameters
    ----------
    results : dict
        Maps method labels ("PFL+23", "Alg.4") to iterables of
        ``(gamma, value)`` pairs. Pairs are sorted by gamma before plotting.
        A method that is absent or has no pairs is skipped.
    name : str or None
        Noise-model tag. ``"t"`` draws a marker at every point instead of
        every third point. Any non-None value also triggers saving and is
        used as the file-name prefix.
    verison : str or None
        Run version (parameter name kept as-is for backward compatibility).
        ``"quick"`` prefixes the saved file name with ``quick_``.

    Returns
    -------
    matplotlib.figure.Figure
        The drawn figure. It is not closed here.

    Notes
    -----
    This updates the global ``plt.rcParams``, so the style persists for
    figures created afterwards.
    """
    FIG_W, FIG_H = 3.6, 2.6
    DPI = 300

    FONTSIZE_BASE = 9
    FONTSIZE_TICKS = 8
    FONTSIZE_LABEL = 10
    FONTSIZE_LEGEND = 8

    GRID_ALPHA = 0.25

    plt.rcParams.update({
        "figure.dpi": DPI,
        "savefig.dpi": DPI,
        "font.size": FONTSIZE_BASE,
        "axes.titlesize": FONTSIZE_LABEL,
        "axes.labelsize": FONTSIZE_LABEL,
        "xtick.labelsize": FONTSIZE_TICKS,
        "ytick.labelsize": FONTSIZE_TICKS,
        "legend.fontsize": FONTSIZE_LEGEND,
        "axes.linewidth": 0.8,
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
        "xtick.minor.visible": True,
        "ytick.minor.visible": True,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "grid.linestyle": ":",
        "grid.linewidth": 0.6,
    })

    # Per-method line style. Alg.4 is drawn opaque and on top of PFL+23.
    styles = {
        "PFL+23": dict(marker="^", linewidth=1.0, markersize=3.0,
                       alpha=0.8, color="#F2A100", zorder=5),
        "Alg.4": dict(marker="o", linewidth=1.4, markersize=3.0,
                      alpha=1.0, color="#0077BB", zorder=10),
    }

    # Thin the markers to every third point, except for the "t" noise model.
    markevery = None if name == "t" else 3

    fig, ax = plt.subplots(figsize=(FIG_W, FIG_H))

    for m, style in styles.items():
        pairs = sorted(results.get(m) or (), key=lambda t: t[0])
        if not pairs:
            # zip(*[]) cannot be unpacked into xs, ys, so skip empty series.
            continue
        xs, ys = zip(*pairs)
        ax.plot(xs, ys, markevery=markevery, label=m, **style)

    ax.set_yscale("log")
    ax.set_xlabel(r"$\gamma$")
    ax.set_ylabel(r"squared operator norm")
    ax.grid(True, which="both", alpha=GRID_ALPHA)

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        # List Alg.4 first. The stable sort keeps every other label in plot
        # order, and it works even when Alg.4 was skipped.
        order = sorted(range(len(labels)), key=lambda i: labels[i] != "Alg.4")
        ax.legend([handles[i] for i in order], [labels[i] for i in order],
                  loc="best", frameon=False)

    if name is not None:
        out_dir_plot = Path(__file__).resolve().parent / "figs"
        out_dir_plot.mkdir(parents=True, exist_ok=True)
        prefix = f"quick_{name}" if verison == "quick" else name

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fig.savefig(out_dir_plot / f"{prefix}_{timestamp}_noise.png",
                    bbox_inches="tight", dpi=DPI)

    return fig

def result_from_run(noise_model, version):
    """Sweep gamma for each method and average the final squared operator norm.

    For every method ("PFL+23", "Alg.4") and every gamma on the grid, this
    runs ``run_experiment`` once per seed. It takes the last entry of the
    run's ``'squared operator norm'`` history and averages that value over
    the seeds. One progress line is printed per (method, gamma) pair.

    Parameters
    ----------
    noise_model : str
        Forwarded to the experiment configuration as ``noise_model``. With
        ``"t"``, the gamma grid starts at 0.3 instead of 1.0.
    version : str or None
        ``"quick"`` selects a coarser grid (step 0.1, or the fixed list
        ``[0.3, 0.2, 0.15, 0.1]`` for ``"t"``) and 3 seeds instead of 7.

    Returns
    -------
    dict
        Maps each method label to a list of ``(gamma, mean_final_value)``
        tuples, in grid order (descending gamma).
    """
    quick = version == "quick"
    n = 3 if quick else 7
    base_seed = 1234  # seed i is base_seed - i; the same seeds are reused for every (method, gamma)

    if quick and noise_model == "t":
        gammas = [0.3, 0.2, 0.15, 0.1]
    else:
        start = 0.3 if noise_model == "t" else 1.0
        interval = -0.1 if quick else -0.01
        # Descending grid that stops just above 0.008 (the stop is exclusive).
        gammas = np.arange(start, 0.008, interval).tolist()

    methods = ["PFL+23", "Alg.4"]

    results = {m: [] for m in methods}

    common = dict(
        T=1_000_000,
        d=1,
        batch_size=1,
        alpha0=0.12,
        beta0=1.0,
        decrease="sqrt",
        decrease_factor=116,
        projection=None,
        radius=1,
        problem="quadratic",
        L=1.0, rho=-0.1,
        init=(1.0, 1.0),
        offset=0.0,
        noise=0.03,
        noise_model=noise_model,
        plot_field=False,
    )

    for m in methods:
        for g in gammas:
            total = 0
            for i in range(n):
                args = R.Args(gamma=g, method=m, name=f"{m}|g={g}|",
                              seed=base_seed - i, **common)
                # NOTE(review): assumes hist maps metric names to per-step sequences; confirm.
                hist, _ = run_experiment(args)
                total += np.asarray(hist['squared operator norm'])[-1]

            y = total / n
            results[m].append((g, y))
            print(f"{m:7s}  γ={g:<7g}  final(||F||^2)≈ {y:.6g}")

    return results

def run_and_plot(noise_model: str, version=None):
    """Run the gamma sweep for *noise_model*, then plot and save the results.

    Parameters
    ----------
    noise_model : str
        Forwarded to ``result_from_run`` and used as the plot/file name tag.
    version : str or None
        ``"quick"`` selects the reduced sweep and the ``quick_`` file prefix.

    Returns
    -------
    matplotlib.figure.Figure
        The figure from ``plot_from_results``, returned so the caller can
        show it or close it with ``plt.close(fig)``. pyplot keeps it open
        otherwise.
    """
    results = result_from_run(noise_model, version)
    return plot_from_results(results, name=noise_model, verison=version)