import os
import re
import numpy as np
import torch
import json

import argparse
from archs import load_architecture
from utilities import get_gd_directory, get_flow_directory, get_hessian_eigenvalues, get_loss_and_acc, get_norm
import matplotlib.pyplot as plt
from data import load_dataset, take_first
import math


def get_lr_values(path):
    """Return the sorted learning rates encoded in ``lr_<value>`` run directories.

    Every entry of *path* whose name starts with ``lr_`` followed by a number
    is parsed, and the number is collected as a float. Run directories are
    named with ``"lr_" + str(eta)``, and ``str`` uses scientific notation for
    small floats (e.g. ``str(1e-05) == "1e-05"``). The pattern therefore
    accepts an optional fractional part and an optional exponent. Without
    the exponent, ``lr_1e-05`` would be misread as ``1.0``.

    Only directories are considered. This matches how the run folders are
    used elsewhere; for example, ``lr_schedule.json`` lives in the same
    folder and is not a run.

    Args:
        path: Directory containing the per-learning-rate run folders.

    Returns:
        The parsed learning rates as floats, in ascending order.
    """
    # Integer part, optional ".fraction", optional exponent such as "e-05".
    pattern = re.compile(r'lr_(\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)')
    lr_values = []
    for entry in os.listdir(path):
        match = pattern.match(entry)
        if match and os.path.isdir(os.path.join(path, entry)):
            lr_values.append(float(match.group(1)))
    return sorted(lr_values)

def main(dataset: str, arch_id: str, loss: str, loss_goal: float = None, seed: int = 0, use_flow = True, eta_min=0,eta_max=1000000,
         es=-1,gf_mode="max",load_lr="all", max_smoothing: int = 0, show:bool = False, plot_es: bool = False, do_max: bool = False,
         no_legend: bool = False, enable_titles: bool = False, general_captions: bool = False): #plot_optimal_lr: bool = True,
    
    if plot_es and es == -1:
        for lg in [10**i for i in range(0, int(np.log10(loss_goal))-1, -1)]:
            main(dataset=dataset, arch_id=arch_id, loss=loss, loss_goal=lg, seed=seed, use_flow=use_flow, eta_min=eta_min,
                 eta_max=eta_max, es=lg, gf_mode=gf_mode, load_lr=load_lr, max_smoothing=max_smoothing, show=show, 
                 plot_es=False, do_max=do_max, no_legend=no_legend, enable_titles=enable_titles,general_captions=general_captions)

    directory = os.path.expanduser(os.path.dirname(get_gd_directory(dataset, 0, arch_id, seed, "gd", loss, None)))
    if(es==loss_goal):
        print("Early stopping: " + str(es))
        prefix=str(loss_goal).replace(".","_")+"_"
    else:
        prefix=""

    if(load_lr=="all"):
        etas = get_lr_values(directory)
    else:
        directory = os.path.expanduser(os.path.dirname(get_gd_directory(dataset, 0, arch_id, seed, "gd", loss, None)))
        with open(directory+'/lr_schedule.json', 'r') as f:
            ls_results = json.load(f)
        if(load_lr=="coarse"):
            smallest_eta=get_lr_values(directory)[0]
            if(smallest_eta not in ls_results["coarse_rounded"]):
                etas_ = [smallest_eta] + ls_results["coarse_rounded"]
            else:
                etas_ = ls_results["coarse_rounded"]
        elif(load_lr=="fine"):
            etas_ = ls_results["fine_rounded"]
        elif(load_lr=="old"):
            etas_ = ls_results["data_points_rounded"]
        elif(load_lr=="cf"):
            etas_ = ls_results["fine_rounded"]+ls_results["coarse_rounded"]
        etas=[]
        not_run=[]
        for eta in etas_:
            if(os.path.isdir(directory+"/"+"lr_"+str(eta))):
                etas.append(eta)
            else:
                not_run.append(eta)
        print("You still need to run: \n",not_run)

    gf_mode = "max"

    # Filter for eta_max
    etas = [eta for eta in etas if eta_min <= eta <= eta_max]
    if len(etas) == 0:
        print("No etas :/")
        return
    eta_max=max(etas) # for the name
    print("These etas are included in the plot: \n",etas)

    torch.manual_seed(seed)
    model = load_architecture(arch_id, dataset).cuda()
    train_dataset, _ = load_dataset(dataset, loss)
    # when to abridge the training set
    if(dataset[-2:]=="5k"):
        print("took first 5000")
        abridged_train = take_first(train_dataset, 5000)
    else:
        print("take full")
        abridged_train = train_dataset
    loss_fn, _ = get_loss_and_acc(loss)
    init_shap = get_hessian_eigenvalues(model, loss_fn, abridged_train, neigs=1, physical_batch_size=1000)


    smooth_all = True

    shaps = []
    maxshaps = []
    norms1 = []
    norms2 = []
    normsnuc = []
    losses = []
    tests = []
    traccs = []
    teaccs = []
    dists = []
    iterations = []
    maxnorms1 = []
    maxnorms2 = []
    maxnormsnuc = []
    maxlosses = []
    maxtests = []
    maxtraccs = []
    maxteaccs = []
    maxiterations = []
    for eta in etas:
        directory = os.path.expanduser(get_gd_directory(dataset, eta, arch_id, seed, "gd", loss, None))
        if os.path.exists(directory+"/"+prefix+"eigs"):

            shap = torch.load(directory+"/"+prefix+"eigs", weights_only=True)
            if len(shap) > 0 and not math.isnan(shap[-1]):
                print("eta=",eta)
                model = torch.load(directory+"/"+prefix+"full_model_final", weights_only=False)
                computed_sharpness=get_hessian_eigenvalues(model, loss_fn, abridged_train, neigs=1, physical_batch_size=1000)
                shaps.append(computed_sharpness[0])
                shap = torch.cat([shap,computed_sharpness.view(1,1)],dim=0)
                max_idx = np.argmax([s[0] for s in shap]).item()
                if(max_smoothing>0):
                    windowl = max(0, max_idx - max_smoothing)
                    windowr = min(len(shap), max_idx + max_smoothing + 1)
                    values = [float(s[0]) for s in shap[windowl:windowr] if np.isfinite(float(s[0]))]
                    maxshaps.append(np.mean(values) if values else 0.0)                                    
                else:
                    maxshaps.append(shap[max_idx][0])
            else:
                shaps.append(torch.tensor(-1))
                maxshaps.append(torch.tensor(-1))
        else:
            shaps.append(torch.tensor(-1))
            maxshaps.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"norms"):
            nrm = torch.load(directory+"/"+prefix+"norms", weights_only=True)
            if not math.isnan(nrm[-1,0]):           
                model = torch.load(directory+"/"+prefix+"full_model_final", weights_only=False)
                total_dist = 0.0
                for param in model.parameters():
                    total_dist += torch.norm(param, p=1).item()
                norms1.append(torch.tensor(total_dist))

                total_dist_2 = 0.0
                for param in model.parameters():
                    if (param.data.dim() == 2):
                        total_dist_2 += torch.norm(param, p="fro").item()

                norms2.append(total_dist_2)

                total_dist_nuc= 0.0
                for param in model.parameters():
                    if (param.data.dim() == 2):
                        total_dist_nuc += torch.norm(param, p="nuc").item()

                normsnuc.append(torch.tensor(total_dist_nuc))

                if smooth_all and (max_smoothing>0):
                    values1 = [float(n[0]) for n in nrm[windowl:windowr] if np.isfinite(float(n[0]))]
                    values2 = [float(n[1]) for n in nrm[windowl:windowr] if np.isfinite(float(n[1]))]
                    valuesnuc = [float(n[2]) for n in nrm[windowl:windowr] if np.isfinite(float(n[2]))]
                    maxnorms1.append(np.mean(values1) if values else 0.0)
                    maxnorms2.append(np.mean(values2) if values else 0.0)
                    maxnormsnuc.append(np.mean(valuesnuc) if values else 0.0)                                  
                else:
                    maxnorms1.append(nrm[min(max_idx, len(nrm)-1),0].clone().detach())
                    maxnorms2.append(nrm[min(max_idx, len(nrm)-1),1].clone().detach())
                    maxnormsnuc.append(nrm[min(max_idx, len(nrm)-1),2].clone().detach())
            else:
                norms1.append(torch.tensor(-1))
                norms2.append(torch.tensor(-1))
                normsnuc.append(torch.tensor(-1))
                maxnorms1.append(torch.tensor(-1))
                maxnorms2.append(torch.tensor(-1))
                maxnormsnuc.append(torch.tensor(-1))
        else:
            norms1.append(torch.tensor(-1))
            norms2.append(torch.tensor(-1))
            normsnuc.append(torch.tensor(-1))
            maxnorms1.append(torch.tensor(-1))
            maxnorms2.append(torch.tensor(-1))
            maxnormsnuc.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"train_loss"):
            ls = torch.load(directory+"/"+prefix+"train_loss",weights_only=True)
            
            if not math.isnan(ls[-1]):
                max_idx = math.ceil((max_idx+1)*len(ls)/len(shap))-1
                losses.append(ls[-1].clone().detach())
                if smooth_all and (max_smoothing>0):
                    windowl = math.ceil((windowl+1)*len(ls)/len(shap))-1
                    windowr = math.ceil((windowr)*len(ls)/len(shap))
                    values = [float(l) for l in ls[windowl:windowr] if np.isfinite(float(l))]
                    maxlosses.append(np.mean(values) if values else 0.0)
                else:
                    maxlosses.append(ls[max_idx].clone().detach())
            else:
                losses.append(torch.tensor(-1))
                maxlosses.append(torch.tensor(-1))
        else:
            losses.append(torch.tensor(-1))
            maxlosses.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"test_loss"):
            tst = torch.load(directory+"/"+prefix+"test_loss",weights_only=True)
            
            if not math.isnan(tst[-1]):
                tests.append(tst[-1].clone().detach())
                if smooth_all and (max_smoothing>0):
                    values = [float(t) for t in tst[windowl:windowr] if np.isfinite(float(t))]
                    maxtests.append(np.mean(values) if values else 0.0)
                else:
                    maxtests.append(tst[max_idx].clone().detach())
            else:
                tests.append(torch.tensor(-1))
                maxtests.append(torch.tensor(-1))
        else:
            tests.append(torch.tensor(-1))
            maxtests.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"train_acc"):
            trac = torch.load(directory+"/"+prefix+"train_acc",weights_only=True)#, weights_only=False)
            if not math.isnan(trac[-1]):
                traccs.append(trac[-1].clone().detach())
                if smooth_all and (max_smoothing>0):
                    values = [float(t) for t in trac[windowl:windowr] if np.isfinite(float(t))]
                    maxtraccs.append(np.mean(values) if values else 0.0)
                else:
                    maxtraccs.append(trac[-1].clone().detach())
            else:
                traccs.append(torch.tensor(-1))
                maxtraccs.append(torch.tensor(-1))
        else:
            traccs.append(torch.tensor(-1))
            maxtraccs.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"test_acc"):
            teac = torch.load(directory+"/"+prefix+"test_acc",weights_only=True)#, weights_only=False)
            if not math.isnan(teac[-1]):
                teaccs.append(teac[-1].clone().detach())
                if smooth_all and (max_smoothing>0):
                    values = [float(t) for t in teac[windowl:windowr] if np.isfinite(float(t))]
                    maxteaccs.append(np.mean(values) if values else 0.0)
                else:
                    maxteaccs.append(teac[max_idx].clone().detach())
            else:
                teaccs.append(torch.tensor(-1))
                maxteaccs.append(torch.tensor(-1))
        else:
            teaccs.append(torch.tensor(-1))
            maxteaccs.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"full_model_final"):
            model = torch.load(directory+"/"+prefix+"full_model_final", weights_only=False)
            dir2 = os.path.expanduser(get_flow_directory(dataset, arch_id, seed, loss, 1.0))
            if os.path.exists(dir2+"/"+prefix+"full_model_final"):
                model2 = load_architecture(arch_id, dataset).cuda()
                model2 = torch.load(dir2+"/"+prefix+"full_model_final", weights_only=False)
                total_dist = 0.0
                for param, param2 in zip(model.parameters(), model2.parameters()):
                    total_dist += torch.norm(param - param2, p=1).item()
                dists.append(torch.tensor(total_dist))
            else:
                dists.append(torch.tensor(-1))
        else:
            dists.append(torch.tensor(-1))
        if os.path.exists(directory+"/"+prefix+"num_iterations"):
            iterations.append(torch.load(directory+"/"+prefix+"num_iterations",weights_only=True))
            maxiterations.append(max_idx+1)
        else:
            iterations.append(torch.tensor(-1))
            maxiterations.append(torch.tensor(-1))
        
    if all([s == -1 for s in shaps]):
        return

    tick=1.0
    directory = os.path.expanduser(get_flow_directory(dataset, arch_id, seed, loss, tick))
    #print(f"flow import directory: {directory}")
    if use_flow and os.path.exists(directory+"/"+prefix+"eigs"):
        shap = torch.load(directory+"/"+prefix+"eigs",weights_only=True)
        if gf_mode == "max":
            values = [s[0] for s in shap]
            max_value = max(values)
            max_idx = len(values) - 1 - values[::-1].index(max_value)
            flow_shap = shap[max_idx]
            flow_shap_last = shap[-1]

            print(flow_shap)
            print(max_idx)
        else:
            flow_shap = shap[-1]

    else:
        flow_shap_last = [s for s in shaps if s != -1][0]
        flow_shap = [s for s in shaps if s != -1][0]
    if use_flow and os.path.exists(directory+"/"+prefix+"norms"):
        nrm = torch.load(directory+"/"+prefix+"norms",weights_only=True)
        if gf_mode == "max":
            flow_norm1 = nrm[max_idx,0]
            flow_norm_last = nrm[-1,0]

            flow_norm2 = nrm[max_idx,1]
            flow_norm2_last = nrm[-1,1]

            flow_normnuc = nrm[max_idx,2]
            flow_normnuc_last = nrm[-1,2]
        else:
            flow_norm1 = nrm[-1,0]
            flow_norm2 = nrm[-1,1]
            flow_normnuc = nrm[-1,2]
    else:
        flow_norm_last = [n for n in norms1 if n !=-1][0]
        flow_norm1 = [n for n in norms1 if n !=-1][0]
        flow_norm2 = [n for n in norms2 if n !=-1][0]
        flow_normnuc = [n for n in normsnuc if n !=-1][0]
    if use_flow and os.path.exists(directory+"/"+prefix+"train_loss"):
        ls = torch.load(directory+"/"+prefix+"train_loss",weights_only=True)
        if gf_mode == "max":
            flow_loss = ls[max_idx]
        else:
            flow_loss = ls[-1]
    else:
        flow_loss = losses[0]
    if use_flow and os.path.exists(directory+"/"+prefix+"test_loss"):
        tst = torch.load(directory+"/"+prefix+"test_loss",weights_only=True)
        if gf_mode == "max":
            flow_test = tst[max_idx]
            flow_test_last = tst[-1]
        else:
            flow_test = tst[-1]
    else:
        flow_test_last = [t for t in tests if t != -1][0]
        flow_test = [t for t in tests if t != -1][0]
    if use_flow and os.path.exists(directory+"/"+prefix+"train_acc"):
        tracc = torch.load(directory+"/"+prefix+"train_acc",weights_only=True)
        if gf_mode == "max":
            flow_tr_acc = tracc[-1]#max_idx]
        else:
            flow_tr_acc = tracc[-1]
    else:
        flow_tr_acc = traccs[0]
    if use_flow and os.path.exists(directory+"/"+prefix+"test_acc"):
        teacc = torch.load(directory+"/"+prefix+"test_acc",weights_only=True)
        if gf_mode == "max":
            flow_tst_acc = teacc[-1]
        else:
            flow_tst_acc = teacc[-1]
    else:
        flow_tst_acc = teaccs[0]

    etas = np.array(etas)
    shaps = np.array(shaps)
    maxshaps = np.array(maxshaps)
    norms1 = np.array(norms1)
    norms2 = np.array(norms2)
    normsnuc = np.array(normsnuc)
    losses = np.array(losses)
    tests = np.array(tests)
    traccs = np.array(traccs)
    teaccs = np.array(teaccs)
    dists = np.array(dists)
    iterations = np.array(iterations)
    maxnorms1 = np.array(maxnorms1)
    maxnorms2 = np.array(maxnorms2)
    maxnormsnuc = np.array(maxnormsnuc)
    maxlosses = np.array(maxlosses)
    maxtests = np.array(maxtests)
    maxtraccs = np.array(maxtraccs)
    maxteaccs = np.array(maxteaccs)
    maxiterations = np.array(maxiterations)

    if load_lr == "coarse":
        if etas[-1] < 1/init_shap:

            plot_optimal_lr = False
        else:
            plot_optimal_lr = True

    tmp=np.abs(losses)
    mask = tmp<=loss_goal
    np.set_printoptions(suppress=True, precision=5)
    print("Not converged for\n",etas[~mask])
    colors = {
        "converged": "#00CC03", 
        "converged_last": "darkgreen", 

        "maximum": "black",#"#FA0400", 
        "bound": "#0081D1",  
        "flow": "#FF7F0F",  
        "heuristic": "#CC00F5", 
        "goal": "#70C8FF", 
    }
    sizes = {
        "converged": 100,
        "maximum": 133,
        "maxwidth": 1.5,
        "bound": 160
    }
    linew = 3
    """
    colors = {
        "converged": "#006600", #"green",
        "maximum": "#FFC533",#"#E4BA4E", #"goldenrod",
        "bound": "#8585FF", #"blue",
        "flow": "#81D2D9", #"lightblue",
        "heuristic": "#8DA300", #"lightblue",
        "goal": "#9CFF9C", #"green"
    }"""
    
    if es == -1:
        plot_dirname = f"figures/{args.dataset}_{args.arch_id}_{args.loss}_seed{args.seed}/{load_lr}/"
    else:
        plot_dirname = f"figures/{args.dataset}_{args.arch_id}_{args.loss}_seed{args.seed}/es_{es}/{load_lr}/"
    print("saving to ", plot_dirname)
    prefix=f"_{args.dataset}_{args.arch_id}_{args.loss}_seed{args.seed}_mode_{gf_mode}_{max_smoothing}_{load_lr}_es_{es}"
    os.makedirs(plot_dirname, exist_ok=True)

    plt.rcParams['font.size'] = 16
    plt.rcParams['mathtext.fontset'] = 'stix'
    plt.rcParams['font.family'] = 'STIXGeneral'

    #do_max = True
    if (max_smoothing>0):
        smooth_label = " (smoothed)"
        if smooth_all:
            smooth_label_all = " (smoothed)"
        else:
            smooth_label_all = ""
    else:
        smooth_label = ""
        smooth_label_all = ""

    if no_legend:
        nltext = "_no_legend"
    else:
        nltext = ""

    nltext = nltext + prefix #prefix after subplot name

    plt.figure(figsize=(7,4))
    mask2 = np.array(shaps<=0)
    if len(etas[mask & ~mask2]) == 0:
        print("DID YOU CORRECTLY SET THE LOSS GOAL?!")
    if general_captions:
        glabel = r"final value"
        mlabel = "value at max sharpness"+smooth_label
        if gf_mode == "max":
            flabel = r"GF value at max sharpness"
            flabel_last = r"final GF value"

        else:
            flabel = r"final GF value"
    else:
        glabel = r"final sharpness"
        mlabel = "maximum value"+smooth_label
        if gf_mode == "max":
            flabel = r"max GF sharpness ($s_{GF}$)"
            flabel_last = r"final GF sharpness"

        else:
            flabel = r"final GF sharpness"
            flabel_last = r"final GF sharpness"

    if load_lr == "coarse" and plot_optimal_lr:
        t1 = np.linspace(np.max((0,np.min(etas[~mask2]) - 0.03 * (np.max(np.append(etas[~mask2], 1/init_shap)) - np.min(etas[~mask2])))), np.max(np.append(etas[~mask2], 1/init_shap)) + 0.03 * (np.max(np.append(etas[~mask2], 1/init_shap)) - np.min(etas[~mask2])), 100)
    else:
        t1 = np.linspace(np.max((0,np.min(etas[~mask2]) - 0.03 * (np.max(etas[~mask2]) - np.min(etas[~mask2])))), np.max(etas[~mask2]) + 0.03 * (np.max(etas[~mask2]) - np.min(etas[~mask2])), 100)

    if(t1[0]==0):
        t1=t1[1:]
    plt.plot(t1, 2/t1, color=colors["bound"],zorder=0,linewidth=linew)
    plt.scatter(etas[~mask2], 2/etas[~mask2],color=colors["bound"],marker="x",zorder=3,s=sizes["bound"])
    plt.plot([],ls="-", marker="x", color=colors["bound"], label = r"$2/\eta$",zorder=3,ms=8, linewidth=linew)
    plt.axhline(flow_shap_last, label = flabel_last,color=colors["converged_last"],alpha=.5, zorder=2,linewidth=linew,linestyle=":")
    plt.axhline(flow_shap, label = flabel,color=colors["flow"],zorder=1,linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$", color=colors["flow"],zorder=1,linestyle="--",linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--",linewidth=linew)
    plt.scatter(etas[mask & ~mask2], shaps[mask & ~mask2], c=colors["converged"], label=glabel, zorder=4,s=sizes["converged"])
    plt.scatter(etas[~mask2], maxshaps[~mask2], marker="o", label=mlabel,zorder=5,s=sizes["converged"], facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"])#, c=colors["maximum"]
    plt.xlabel(r"$\eta$")
    plt.ylabel("sharpness")
    if enable_titles:
        plt.title("Sharpness")
    if isinstance(flow_shap, torch.Tensor):
        flow_shap_np = np.array(flow_shap.detach().numpy())
    else:
        flow_shap_np = np.array([flow_shap])
    lims = np.concatenate((shaps[~mask2], maxshaps[~mask2],flow_shap_np))
    plt.ylim(np.min(lims) - 0.1*(np.max(lims) - np.min(lims)), np.max(lims) + 0.2*(np.max(lims) - np.min(lims)))
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="upper right")
    plt.tight_layout()
    plt.savefig(plot_dirname + "sharpness"+nltext+".png", dpi=300)  # Save the plot as a high-quality PNG file
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(norms1<=0)  
    if not general_captions:
        glabel = r"final $\ell 1$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "

        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")
    plt.axhline(flow_norm1, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)

    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms1[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxnorms1[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 1$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 1$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l1"+nltext+".png", dpi=300) 
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(norms2<=0)  
    if not general_captions:
        glabel = r"final $\ell 2$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm2, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms2[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxnorms2[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 2$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 2$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l2"+nltext+".png", dpi=300)
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(normsnuc<=0)  
    if not general_captions:
        glabel = r"final nuc norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
        else:
            flabel = r"final GF norm"
    plt.axhline(flow_normnuc, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], normsnuc[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxnormsnuc[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel("nuclear norm")
    else:
        plt.ylabel("norm")
        plt.title("Total Nuclear Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_nuc"+nltext+".png", dpi=300) 
    if show:
        plt.show()


    plt.figure(figsize=(7,4))
    mask2 = np.array(losses<0)
    if not general_captions:
        glabel = r"final train loss"
        mlabel = "loss at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"flow train loss at max sharpness"
        else:
            flabel = r"final flow train loss"
    plt.axhline(loss_goal,color=colors["goal"], label="loss goal "+str(loss_goal),zorder=1, linewidth=linew)
    plt.axhline(flow_loss, label = flabel,color=colors["flow"], alpha=0.7,zorder=2, linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], losses[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxlosses[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("train loss")
    if enable_titles:
        plt.title("Train Loss")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "train_loss"+nltext+".png", dpi=300)  
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(traccs<0) 
    if not general_captions:
        glabel = r"final train acc"
        mlabel = "acc at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"flow train acc at max sharpness"
        else:
            flabel = r"final flow train acc"
    plt.axhline(flow_tr_acc, label = flabel_last,color=colors["converged_last"],alpha=.5, zorder=2,linewidth=linew,linestyle=":")
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], traccs[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxtraccs[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("train accuracy")
    if enable_titles:
        plt.title("Train Accuracy")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "train_acc"+nltext+".png", dpi=300) 
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(teaccs<0) 
    if not general_captions:
        glabel = r"final test acc"
        mlabel = "acc at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"flow test acc at max sharpness"
        else:
            flabel = r"final flow test acc"
    plt.axhline(flow_tst_acc, label = flabel_last,color=colors["converged_last"],alpha=.5, zorder=2,linewidth=linew,linestyle=":")
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], teaccs[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxteaccs[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("test accuracy")
    if enable_titles:
        plt.title("Test Accuracy")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "test_acc"+nltext+".png", dpi=300) 
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(dists<0) 
    if not general_captions:
        glabel = "solution distance"
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], dists[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("distance from GF")
    if enable_titles:
        plt.title("Distance from the GF solution")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "distance_GF"+nltext+".png", dpi=300) 
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(iterations<0) 
    if not general_captions:
        glabel = r"num of iterations"
        mlabel = "iterations till max sharpness"
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], iterations[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxiterations[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("#iterations")
    if enable_titles:
        plt.title("Number of iterations till loss " + str(loss_goal))
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="upper right")
    plt.tight_layout()
    plt.savefig(plot_dirname + "iterations"+nltext+".png", dpi=300)  
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(iterations<0) 
    if not general_captions:
        glabel = r"curve length"
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=0,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    iterationss = etas*iterations
    plt.scatter(etas[mask & ~mask2], iterationss[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel(r"$\eta \cdot$#iterations")
    if enable_titles:
        plt.title("Total curve length of GD")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "curve_length"+nltext+".png", dpi=300) 
    if show:
        plt.show()
    

    plt.figure(figsize=(7,4))
    mask2 = np.array(tests<0) 
    if not general_captions:
        glabel = r"final test loss"
        mlabel = "loss at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"flow test loss at max sharpness"
            flabel_last = r"final flow test loss"

        else:
            flabel = r"final flow test loss"
    plt.axhline(flow_test_last, label = flabel_last,color=colors["converged_last"],zorder=2,alpha=.5, linewidth=linew,linestyle=":")
    plt.axhline(flow_test, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], tests[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxtests[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    plt.ylabel("test loss")
    if enable_titles:
        plt.title("Test Loss")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="upper right")
        else:
            plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "test_loss"+nltext+".png", dpi=300)
    if show:
        plt.show()

    plt.close("all")

    plt.figure(figsize=(7,4))
    mask2 = np.array(tests<0) 
    if not general_captions:
        glabel = r"final test loss"
        mlabel = "loss at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"flow test loss at max sharpness"
            flabel_last = r"final flow test loss"

        else:
            flabel = r"final flow test loss"
    plt.axhline(flow_test_last, label = flabel_last,color=colors["converged_last"],zorder=2,alpha=.5, linewidth=linew,linestyle=":")
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], tests[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxtests[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
   
    plt.xlabel(r"$\eta$")
    plt.ylabel("test loss")
    if enable_titles:
        plt.title("Test Loss")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="upper right")
        else:
            plt.legend(loc="lower left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "test_loss_no_max_no_line_"+nltext+".png", dpi=300)  
    if show:
        plt.show()
    plt.close("all")

    plt.figure(figsize=(7,4))
    mask2 = np.array(norms1<=0)  
    if not general_captions:
        glabel = r"final $\ell 1$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "

        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")

    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms1[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxnorms1[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    #plt.scatter(etas[~mask & ~mask2], norms[~mask & ~mask2],c=colors["not_converged"], label="final value, loss not converged")
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 1$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 1$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l1_no_max_no_line_"+nltext+".png", dpi=300) 
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(norms1<=0)  
    if not general_captions:
        glabel = r"final $\ell 1$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "

        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")

    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms1[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    if do_max:
        plt.scatter(etas[~mask2], maxnorms1[~mask2], marker="o", facecolors='none', edgecolors=colors["maximum"],linewidth=sizes["maxwidth"], label=mlabel,zorder=5,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 1$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 1$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l1_no_line_"+nltext+".png", dpi=300)  
    if show:
        plt.show()


    plt.figure(figsize=(7,4))
    mask2 = np.array(norms2<=0)  
    if not general_captions:
        glabel = r"final $\ell 2$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "
        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm2, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)
    plt.axhline(flow_norm2_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")

    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms2[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 2$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 2$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l2_no_max_"+nltext+".png", dpi=300)  
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(norms2<=0)  
    if not general_captions:
        glabel = r"final $\ell 2$-norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "
        else:
            flabel = r"final GF norm"
    plt.axhline(flow_norm2_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], norms2[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel(r"$\ell 2$-norm")
    else:
        plt.ylabel("norm")
        plt.title(r"Total $\ell 2$-Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_l2_no_max_no_line"+nltext+".png", dpi=300)  
    if show:
        plt.show()



    plt.figure(figsize=(7,4))
    mask2 = np.array(normsnuc<=0)  
    if not general_captions:
        glabel = r"final nuc norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "

        else:
            flabel = r"final GF norm"
    plt.axhline(flow_normnuc_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")
    plt.axhline(flow_normnuc, label = flabel,color=colors["flow"],zorder=1, linewidth=linew)
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], normsnuc[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel("nuclear norm")
    else:
        plt.ylabel("norm")
        plt.title("Total Nuclear Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_nuc_no_max"+nltext+".png", dpi=300) 
    if show:
        plt.show()

    plt.figure(figsize=(7,4))
    mask2 = np.array(normsnuc<=0)  
    if not general_captions:
        glabel = r"final nuc norm"
        mlabel = "norm at max sharpness"+smooth_label_all
        if gf_mode == "max":
            flabel = r"GF norm at max sharpness"
            flabel_last = r"GF norm at final value "
        else:
            flabel = r"final GF norm"
    plt.axhline(flow_normnuc_last, label = flabel_last,color=colors["converged_last"],alpha=.5,zorder=2, linewidth=linew,linestyle=":")
    plt.axvline(2/flow_shap, label = "$2/s_{GF}$",color=colors["flow"],zorder=2,linestyle="--", linewidth=linew)
    if load_lr == "coarse" and plot_optimal_lr:
        plt.axvline(1/init_shap, label = "$1/s_0$", color=colors["heuristic"],zorder=0,linestyle="--", linewidth=linew)
    plt.scatter(etas[mask & ~mask2], normsnuc[mask & ~mask2],c=colors["converged"], label=glabel,zorder=3,s=sizes["converged"])
    plt.xlabel(r"$\eta$")
    if not enable_titles:
        plt.ylabel("nuclear norm")
    else:
        plt.ylabel("norm")
        plt.title("Total Nuclear Norm")
    plt.grid(True, linestyle = '-')
    if not no_legend:
        if load_lr == "coarse":
            plt.legend(loc="lower right")
        else:
            plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dirname + "norm_nuc_no_max_no_line"+nltext+".png", dpi=300)  
    if show:
        plt.show()

def int_or_str(value):
    """Argparse ``type=`` converter accepting an integer or a norm name.

    Args:
        value: The raw command-line token.

    Returns:
        ``int(value)`` when the token parses as an integer, otherwise the token
        itself if it is one of the literal strings ``"nuc"`` or ``"fro"``.

    Raises:
        argparse.ArgumentTypeError: if the token is neither an integer nor one
            of the accepted norm names (argparse turns this into a usage error).
    """
    try:
        parsed = int(value)
    except ValueError:
        # Not an integer: only the two named norms are accepted as strings.
        if value not in {"nuc", "fro"}:
            raise argparse.ArgumentTypeError(f"Invalid value: {value}. Must be an integer, 'nuc', or 'fro'.")
        parsed = value
    return parsed


if __name__ == "__main__":
    # Command-line entry point: parse the options below and forward them to main().
    parser = argparse.ArgumentParser(description="Plot results for different learning rates.")
    parser.add_argument("dataset", type=str, help="which dataset")
    parser.add_argument("arch_id", type=str, help="which network architecture")
    parser.add_argument("loss", type=str, choices=["ce", "mse"], help="which loss function")
    parser.add_argument("--seed", type=int, help="the random seed used when initializing the network weights", default=42)
    parser.add_argument("--loss_goal", type=float, help="converged if the train loss reached this value", default=0.0001)
    # main() receives use_flow = not no_flow, so setting this flag turns the RK GF results *off*.
    parser.add_argument("--no_flow", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', use smallest GD value as GF, otherwise use RK GF results")
    parser.add_argument("--eta_min", type=float, help="lowest eta to consider in the plot",default=0)
    parser.add_argument("--eta_max", type=float, help="highest eta to consider in the plot",default=1000000.0)
    parser.add_argument("--es", type=float, help="early stopping loss",default=-1)
    parser.add_argument("--gf_mode", type=str, choices=["last","max"],help="use last or max value",default="last")
    parser.add_argument("--load_lr", type=str, choices=["fine","coarse","all","old","cf"],help="load coarse or old lr schedule from json",default="all")
    parser.add_argument("--max_smoothing", type=int,help="how many values on each side of sharpness maximum should be used for smoothing",default=0)
    # Boolean switches default to False (rather than argparse's None) so main() always
    # receives real bools for its bool-annotated parameters.
    parser.add_argument("--show", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', shows each plot")
    parser.add_argument("--plot_es", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', also calls all available es instances between 1 and loss_goal")
    parser.add_argument("--do_max", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', also displays values at maximum sharpness for all plots where possible")
    parser.add_argument("--no_legend", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', doesn't display legend in any plot")
    parser.add_argument("--enable_titles", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', does plot titles")
    parser.add_argument("--general_captions", action=argparse.BooleanOptionalAction, default=False,
                        help="if 'True', makes captions in legend nonspecific to concrete plot")

    args = parser.parse_args()

    main(dataset=args.dataset, arch_id=args.arch_id, loss=args.loss, loss_goal=args.loss_goal, seed=args.seed,
         use_flow=(not args.no_flow), eta_min=args.eta_min, eta_max=args.eta_max, es=args.es,
         gf_mode=args.gf_mode, load_lr=args.load_lr, max_smoothing=args.max_smoothing, show=args.show,
         plot_es=args.plot_es, do_max=args.do_max, no_legend=args.no_legend,
         enable_titles=args.enable_titles, general_captions=args.general_captions)
    