import pandas as pd
# Globally disable pandas' chained-assignment check (SettingWithCopyWarning) for this script.
pd.options.mode.chained_assignment = None  # default='warn'
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
# Global plotting theme: seaborn "whitegrid" style with enlarged fonts.
# `set_theme` is the documented interface; `sns.set` is only an alias for it.
sns.set_theme(
    style="whitegrid",
    rc={
        "figure.figsize": (12, 4),
        "font.size": 18,
        "axes.titlesize": 20,
        "axes.labelsize": 18,
        "legend.fontsize": 20,
        "legend.title_fontsize": 20,
    },
)

# Results table plotted below. Presumably the per-epsilon columns hold
# "mean±std" strings (that is how plot_spider_rolled_vtab parses them) — confirm.
vtab = pd.read_csv("vtab_final_results_spider.csv")


def plot_spider_rolled_vtab(dfplot, sortby, fill=False):
    """Plot per-dataset accuracy for each epsilon as an "unrolled" spider chart.

    Each x position is one dataset. The aggregate rows "All", "Natural",
    "Specialized" and "Structured" are dropped when present. Each line is one
    epsilon column: "1", "2", "4", "8" and "∞".

    Args:
        dfplot: results table with a "dataset" column, a "num_classes" column
            and one column per epsilon. Cells are "mean±std" strings; plain
            numbers are accepted too, and only the mean is plotted. The
            caller's frame is not modified. NOTE(review): values and labels
            are matched by row position, so this assumes one row per dataset —
            confirm against the CSV.
        sortby: column to sort the datasets by (cast to int first), or None
            to keep the input order.
        fill: if True, fill the area under each line.

    Returns:
        (fig, ax): the created Figure and its single Axes.
    """
    aggregate_rows = {"All", "Natural", "Specialized", "Structured"}
    # Explicit copy: the int cast below must not write into a slice of the caller's frame.
    dfplot = dfplot[~dfplot["dataset"].isin(aggregate_rows)].copy()

    if sortby is not None:
        dfplot[sortby] = dfplot[sortby].astype(int)
        dfplot = dfplot.sort_values(by=sortby)

    categories = list(dfplot["dataset"].unique())
    n_categories = len(categories)
    x_positions = range(n_categories)

    # Use a single Axes handle throughout. Re-fetching it via plt.subplot(111)
    # depends on pyplot state and can replace the Axes on some matplotlib versions.
    fig, ax = plt.subplots(1, 1, figsize=(12, 4), facecolor="white")

    # One tick per dataset, labelled "name(num_classes)" from its first row.
    first_rows = dfplot.drop_duplicates("dataset")
    category_labels = [
        f"{name}({n_cls})"
        for name, n_cls in zip(first_rows["dataset"], first_rows["num_classes"])
    ]
    ax.set_xticks(x_positions, category_labels, color='black', rotation=30, ha='right', fontsize='medium')

    # Accuracy axis, in percent.
    accuracies = [0, 20, 40, 60, 80, 100]
    ax.set_yticks(accuracies, [str(a) for a in accuracies], color="black", size=20)
    ax.set_ylim(0, 100)

    # Light reference grid, drawn once so that alpha=0.2 does not compound per series.
    if n_categories:
        ax.hlines(accuracies, x_positions[0], x_positions[-1], color="lightgrey", alpha=0.2)
        ax.vlines(list(x_positions), 0, accuracies[-1], color="lightgrey", alpha=0.2)

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    epsilons = ["1", "2", "4", "8", "∞"]
    for eps, color in zip(epsilons, colors):
        # Keep only the mean of "mean±std". astype(str) also accepts numeric or NaN cells.
        values = dfplot[eps].astype(str).str.split("±").str[0].astype(float).to_numpy()
        ax.plot(x_positions, values, linestyle='solid',
                label=r'$\infty$' if eps == "∞" else eps, linewidth=2, c=color)
        if fill:
            ax.fill_between(x_positions, values, color=color, alpha=1)

    ax.set_ylabel("Accuracy (%)")
    # Reverse the legend so the largest epsilon is listed first.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[::-1], labels=labels[::-1], title=r"$ϵ$")

    # Tick marks on the bottom and left axes.
    ax.tick_params(bottom=True, left=True)
    # Solid black border. The whitegrid style hides spines, so turn them on explicitly.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1)
    fig.tight_layout()

    return fig, ax


import os

# Render one figure per (backbone, trained-parameter subset) pair.
# The short tag is the backbone name used in the output file name.
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
for backbone, tag in (("vit-b-16", "vit-b"), ("ResNet50", "R-50")):
    for p in ["all", "FiLM", "head"]:
        subset = vtab[(vtab["backbone"] == backbone) & (vtab["params"] == p)]
        fig, ax = plot_spider_rolled_vtab(subset, "num_classes", fill=False)
        # Move the legend to the top-left, partly outside the axes.
        # bbox_inches='tight' keeps it inside the saved file.
        sns.move_legend(ax, "upper left", bbox_to_anchor=(-0.25, 1.05))
        fig.savefig("plots/vtab_class_{}_{}.pdf".format(tag, p), bbox_inches='tight')
        # Release the figure; otherwise pyplot keeps every figure open.
        plt.close(fig)
