# %%
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sklearn.metrics import balanced_accuracy_score

from joblib import Parallel, delayed

# Categorical tab10 colors; entries 1 and 2 color the BatchNorm and
# InstanceNorm reference lines in the averaged sensitivity figure.
palette = sns.color_palette("tab10")


def get_name(x):
    """Return the display label for one results row.

    Rows whose ``filter_size`` is zero keep their raw ``norm`` name; every
    other row is labelled as a PSDNorm variant that shows the filter size
    (as LaTeX ``$f$``) in the label.
    """
    size = x["filter_size"]
    if size != 0:
        return rf"PSDNorm($f$={size})"
    return x["norm"]


def compute_bacc(row):
    """Return the balanced accuracy of a single results row.

    ``row`` must expose ``y_true`` and ``y_pred`` attributes, as the
    namedtuples produced by ``DataFrame.itertuples`` do.
    """
    truth, predicted = row.y_true, row.y_pred
    return balanced_accuracy_score(truth, predicted)


# %%
# Columns needed by the rest of the script. Selecting them per file and then
# concatenating yields a fresh frame (no SettingWithCopyWarning when new
# columns are assigned later) and drops unused columns early.
COLUMNS = [
    "dataset",
    "seed",
    "subject",
    "y_true",
    "y_pred",
    "filter_size",
    "norm",
    "model_name",
]

# Sorted so the concatenated row order does not depend on filesystem listing
# order.
fnames = sorted(Path("pickles_sensitive_analysis").glob("results*.pkl"))
if not fnames:
    # pd.concat([]) would otherwise fail with "No objects to concatenate".
    raise FileNotFoundError(
        "No results*.pkl files found in 'pickles_sensitive_analysis'"
    )
df = pd.concat(
    [pd.read_pickle(fname)[COLUMNS] for fname in fnames],
    axis=0,
    ignore_index=True,  # per-file indices would otherwise be duplicated
)

# %%
# Replace raw `norm` values with display labels: rows with a non-zero
# filter_size become "PSDNorm($f$=...)".
df["norm"] = df.apply(get_name, axis=1)

# One balanced-accuracy score per results row, evaluated with joblib on all
# cores. Parallel returns results in input order, so the list lines up with
# the rows of `df`.
run_parallel = Parallel(n_jobs=-1)
bacc_jobs = (delayed(compute_bacc)(row) for row in df.itertuples(index=False))
df["bacc"] = run_parallel(bacc_jobs)
# %%
# PSDNorm filter-size sweep vs. BatchNorm/InstanceNorm baselines, averaged
# over datasets and subjects (LayerNorm is excluded from this figure).
# One row per (norm, seed, filter_size); seaborn then aggregates the
# remaining rows at each filter size.
df_plot = (
    df.query("norm != 'LayerNorm'")
    .groupby(["norm", "seed", "filter_size"])
    .agg(bacc=("bacc", "mean"))
    .reset_index()
)

fig, ax = plt.subplots(1, 1, figsize=(7, 2.5))
sns.lineplot(
    data=df_plot.query("norm not in ['BatchNorm', 'InstanceNorm']"),
    x="filter_size",
    y="bacc",
    # No `hue` here, so a `palette` would be ignored (seaborn warns about it);
    # use an explicit color matching the palette used for the baselines below.
    color=palette[0],
    linewidth=2,
    alpha=0.8,
    ax=ax,
    label="PSDNorm",
)
ax.set_ylim(0.77, 0.795)
ax.grid(axis="y", alpha=0.6)
ax.set_ylabel("Balanced Accuracy Score")

# Horizontal reference lines: mean balanced accuracy of the BatchNorm and
# InstanceNorm baselines.
mean_bacc_batch = df_plot.query("norm == 'BatchNorm'")["bacc"].mean()
mean_bacc_instance = df_plot.query("norm == 'InstanceNorm'")["bacc"].mean()
ax.axhline(mean_bacc_batch, color=palette[1], linestyle="--", label="BatchNorm")
ax.axhline(mean_bacc_instance, color=palette[2], linestyle="--", label="InstanceNorm")

ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ax.set_xlabel("Filter size $f$")
ax.set_xticks([1, 3, 5, 7, 9, 11, 13, 15, 17, 21])
ax.set_yticks([0.77, 0.78, 0.79])

sns.despine()
plt.tight_layout()
# savefig does not create missing directories.
Path("figures").mkdir(exist_ok=True)
fig.savefig("figures/sensitive.pdf", bbox_inches="tight")

# %%
# Per-dataset view of the PSDNorm filter-size sweep (LayerNorm excluded).
# One row per (norm, seed, dataset, filter_size), averaged over subjects.
df_plot = (
    df.query("norm != 'LayerNorm'")
    .groupby(["norm", "seed", "dataset", "filter_size"])
    .agg(bacc=("bacc", "mean"))
    .reset_index()
)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sns.lineplot(
    data=df_plot.query("norm not in ['BatchNorm', 'InstanceNorm']"),
    x="filter_size",
    y="bacc",
    hue="dataset",
    palette="tab10",
    linewidth=2,
    alpha=0.8,
    ax=ax,
    legend=False,
)
ax.set_ylim(0.65, 0.85)
ax.grid(axis="y", alpha=0.6)
ax.set_ylabel("Balanced Accuracy Score")

# Reference lines for the BatchNorm and InstanceNorm baselines, averaged over
# seeds and datasets. Drawn in black (told apart by line style) because, with
# the default color cycle, "C0"/"C1" are the colors tab10 gives the first two
# dataset curves, which made the baselines look like dataset lines.
mean_bacc_batch = df_plot.query("norm == 'BatchNorm'")["bacc"].mean()
mean_bacc_instance = df_plot.query("norm == 'InstanceNorm'")["bacc"].mean()
ax.axhline(mean_bacc_batch, color="k", linestyle="--", label="BatchNorm")
ax.axhline(mean_bacc_instance, color="k", linestyle=":", label="InstanceNorm")
ax.set_xlabel("Filter size $f$")
ax.set_xticks([1, 3, 5, 7, 9, 15, 17, 21])

sns.despine()
plt.tight_layout()
# savefig does not create missing directories.
Path("figures").mkdir(exist_ok=True)
# Distinct file name: "figures/sensitive.pdf" is the averaged figure from the
# previous cell and was being overwritten here.
fig.savefig("figures/sensitive_per_dataset.pdf", bbox_inches="tight")

# %%
