import matplotlib.pyplot as plt
import seaborn as sns


def _series_label(regime, aligned, yaxis):
    """Return the mathtext legend label for one plotted series.

    ``yaxis`` selects the quantity shown in the subscript: ``"err_test"``
    maps to ``gen`` and ``"sob2"`` maps to ``rob``. The parenthesised tag
    identifies the model:

    * ``"sgd"`` regime -> ``(f_{SGD})``;
    * any regime containing ``"rf"`` -> ``(\\Gamma_*)`` when ``aligned`` is
      truthy, else ``(I_d)``;
    * otherwise ``(f_{<name>})`` where ``"lazy_nt"`` becomes ``NTL`` and
      ``"nt"`` becomes ``NT``.

    Raises
    ------
    RuntimeError
        If ``yaxis`` is not ``"err_test"`` or ``"sob2"``.
    """
    metrics = {"err_test": "gen", "sob2": "rob"}
    if yaxis not in metrics:
        raise RuntimeError("unsupported yaxis: %r" % (yaxis,))
    if regime == "sgd":
        tag = "(f_{SGD})"
    elif "rf" in regime:
        tag = "(\\Gamma_*)" if aligned else "(I_d)"
    else:
        # "lazy_nt" must be rewritten before the bare "nt" substitution,
        # otherwise it would become "lazy_NT".
        name = regime.replace("lazy_nt", "NTL").replace("nt", "NT")
        tag = "(f_{%s})" % name
    return "$\\widetilde{\\varepsilon}_{%s}%s$" % (metrics[yaxis], tag)


def plot_results(df, out_file=None, figsize=(3., 2.2), legend_loc="best",
                 ylim=None, frameon=False, legend_ncol=1, handlelength=.2,
                 mag=1, ax=None):
    """Plot error curves versus width ``m`` for every regime in ``df``.

    For each ``(regime, aligned)`` group of ``df`` and each metric in
    ``("err_test", "sob2")`` the column ``<metric>_theory`` is drawn. For
    every regime except ``"sgd"``, the column ``<metric>`` is also drawn as
    a thin solid line, and the theory curve is then drawn thick and dashed
    without a legend entry. All groups are drawn on a single axes, which is
    then decorated once: reference lines at y = 0, 0.5 and 1, labels,
    legend, optional ylim. Finally the figure is laid out and optionally
    saved.

    Parameters
    ----------
    df : DataFrame
        Must provide the columns ``regime``, ``aligned`` (used as a
        True/False key), ``m``, ``err_test_theory`` and ``sob2_theory``.
        It must also provide ``err_test`` and ``sob2`` for non-``"sgd"``
        regimes.
    out_file : path-like, optional
        If given, the figure is saved there (dpi=200) and the path printed.
    figsize : tuple
        Figure size, used only when ``ax`` is None and a figure is created.
    legend_loc, frameon, legend_ncol, handlelength
        Forwarded to ``ax.legend`` (as ``loc``, ``frameon``, ``ncol``,
        ``handlelength``).
    ylim : tuple, optional
        Unpacked into ``ax.set_ylim``.
    mag : float
        Multiplier for the line width of the empirical curves.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.

    Returns
    -------
    matplotlib Axes or None
        The axes drawn on, or the passed ``ax`` unchanged (possibly None)
        when ``df`` contains no regimes and nothing is plotted.
    """
    colors = {False: {"err_test": "b", "sob2": "g"},
              True: {"err_test": "m", "sob2": "c"}}
    regimes = df.regime.unique()
    if len(regimes) == 0:
        # Nothing to draw: do not create an empty figure or save a file.
        return ax

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.gca()

    for regime in regimes:
        for aligned in df.aligned.unique():
            subdf = df.loc[(df.regime == regime) & (df.aligned == aligned)]
            for yaxis in ("err_test", "sob2"):
                label = _series_label(regime, aligned, yaxis)
                c = colors[aligned][yaxis]
                if regime == "sgd":
                    # SGD: only the theory column is plotted, default style,
                    # and it carries the legend entry.
                    linestyle = linewidth = None
                else:
                    # Empirical curve carries the legend entry; the theory
                    # curve overlaid below is dashed, thick and unlabeled.
                    sns.lineplot(data=subdf, x="m", y=yaxis, color=c,
                                 label=label, linewidth=1 * mag, ax=ax)
                    linestyle, linewidth, label = "--", 3, None
                sns.lineplot(data=subdf, x="m", y="%s_theory" % yaxis,
                             color=c, linestyle=linestyle, label=label,
                             linewidth=linewidth, ax=ax)

    # Decorate once, after every series is on the axes, so reference lines
    # are not duplicated and the file is written only with the full figure.
    ax.set_ylabel("")
    ax.set_xlabel("width $m$")
    for level in (0, .5, 1):
        ax.axhline(level, linestyle="--", c="k")
    ax.tick_params(axis='both', labelsize=12)
    ax.legend(loc=legend_loc, frameon=frameon, ncol=legend_ncol,
              handlelength=handlelength, fontsize=11)
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Operate on the figure owning ``ax`` rather than plt's "current" figure,
    # which may differ when the caller supplies ``ax``.
    fig = ax.figure
    fig.tight_layout()
    if out_file is not None:
        fig.savefig(out_file, dpi=200, bbox_inches="tight")
        print(out_file)
    return ax
