# %%
from matplotlib import ticker
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# %%
# Model keys (as they appear in the result file names) paired with the display
# names used in figures; models/plot_models stay index-aligned by construction.
_MODEL_NAME_PAIRS = (
    ("afo", "AFO"),
    ("imrn50", "IM-RN50"),
    ("cliprn50x4", "CLIP-RN50x4"),
    ("vitb", "IM-ViTB"),
    ("dinov2", "DiNOv2-ViTB"),
)
models = [key for key, _ in _MODEL_NAME_PAIRS]
plot_models = [display for _, display in _MODEL_NAME_PAIRS]
# In-distribution subject(s) and holdout-dataset subjects; the results of all
# subjects in a group are concatenated before averaging.
id_sub = ["NSD_01"]
od_sub = ["B5K_01", "fMRI1_01"]
# %%
# Load each model's saved test Pearson's r for every subject, pool the subjects
# of a group (in-distribution vs holdout) and record the pooled mean and values.
datas = []
for model in models:
    for subs in [id_sub, od_sub]:
        # map_location="cpu": the .cpu() call below suggests these tensors may
        # have been saved from a GPU run; without it torch.load raises on a
        # CPU-only host.
        ps = [
            torch.load(
                f"/data/results/pca/{model}_{sub}_p_test.pt", map_location="cpu"
            )
            for sub in subs
        ]
        p = torch.cat(ps, dim=0)
        datas.append(
            {
                "model": model,
                "subject": ",".join(subs),
                # .item() requires the mean to be a scalar, i.e. p is 1-D
                # (presumably one r value per voxel — TODO confirm).
                "mean_p": p.mean(dim=0).item(),
                "p": p.cpu().numpy().tolist(),
            }
        )

# %%
# Tabulate: one row per (model, subject group).
df = pd.DataFrame(data=datas)
# %%
df  # notebook display of the unsorted table
# %%
# Order rows by model key, then subject group, so every subject group lists
# the models in the same (lexicographic) order.
df = df.sort_values(["model", "subject"])
# %%
# Extract, per model, the in-distribution (x) and holdout (y) mean r plus the
# underlying value lists, paired by model.
# Use new names for the joined subject-group keys: id_sub/od_sub must stay
# lists (the loading cell iterates them), and rebinding them to strings would
# make a notebook re-run join/iterate characters and silently select nothing.
id_key = ",".join(id_sub)
od_key = ",".join(od_sub)
# Index both groups by model and reorder the holdout rows to the
# in-distribution order, so x[i] and y[i] always belong to the same model;
# .loc raises KeyError if a model is missing from the holdout group.
id_rows = df[df["subject"] == id_key].set_index("model")
od_rows = df[df["subject"] == od_key].set_index("model").loc[id_rows.index]
x = id_rows["mean_p"].values  # mean r per model, in-distribution
orig_x = id_rows["p"].values  # list of r values per model, in-distribution
y = od_rows["mean_p"].values  # mean r per model, holdout group
orig_y = od_rows["p"].values  # list of r values per model, holdout group
labels = id_rows.index.values  # model keys, aligned with x / y
# %%
# Quick-look plot: each model's mean in-distribution r against its mean
# holdout r, annotated with the raw model key.
quick_fig, quick_ax = plt.subplots(figsize=(5, 5))
quick_ax.scatter(x, y)
for x_mean, y_mean, key in zip(x, y, labels):
    quick_ax.annotate(key, (x_mean, y_mean))
quick_ax.set_xlabel("In-distribution")
quick_ax.set_ylabel("Out-of-distribution")
plt.show()


# %%
# fig, axs = plt.subplots(1, 3, figsize=(10, 5))
# fig, ax = plt.subplots(figsize=(5, 5))
@ticker.FuncFormatter
def major_formatter(x, pos):
    """Render tick value ``x`` with two decimals; ``pos`` is ignored."""
    return format(x, ".2f")


@ticker.FuncFormatter
def large_formatter(x, pos):
    """Render tick value ``x`` with one decimal; ``pos`` is ignored."""
    return format(x, ".1f")


# Summary figure.
#   Left (spans both rows): per-model mean Pearson's r, in-distribution (x)
#     against the two holdout datasets (y).
#   Top right: in-distribution r values of each model as a horizontal violin,
#     placed at that model's mean holdout r.
#   Bottom right: holdout r values of each model as a vertical violin, placed
#     at that model's mean in-distribution r.
# The layout is computed once by fig.tight_layout() at the end. Do not also
# request constrained_layout: tight_layout() overrides it (matplotlib warns
# and switches layout engines), so that request would be dead.
fig = plt.figure(figsize=(7, 3))
gs = fig.add_gridspec(2, 2)
ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])


def _style_axes(ax, grid_axis):
    """Apply the style shared by all panels of the summary figure.

    Hides the top/right spines, greys the bottom/left ones, draws a dashed
    grid along ``grid_axis`` ("x", "y" or "both") and sets inward ticks with
    size-11 labels. Callers override per-axis tick lengths afterwards.
    """
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    for spine in ("bottom", "left"):
        ax.spines[spine].set_color("grey")
    ax.grid(axis=grid_axis, linestyle="--", alpha=0.5)
    ax.tick_params(
        axis="both", which="both", direction="in", length=5, width=1, labelsize=11
    )


# --- Left panel: one labelled point per model.
sns.scatterplot(x=x, y=y, hue=labels, ax=ax1, s=100, legend=False, alpha=0.5)
for x_mean, y_mean, key in zip(x, y, labels):
    # Use the display name and nudge the text off the marker.
    ax1.annotate(
        plot_models[models.index(key)],
        (x_mean + 0.001, y_mean + 0.001),
        fontsize=14,
    )
# Optional overlay, kept for reference (AFOTopyNeck point):
# addx = [0.48593956]
# addy = [0.259]
# sns.scatterplot(
#     x=addx, y=addy, ax=ax1, s=200, legend=False, alpha=0.75,
#     marker="*", color="purple",
# )
# ax1.annotate("AFOTopyNeck", (addx[0] - 0.003, addy[0] + 0.001), horizontalalignment="right")
ax1.set_xlim(0.36, 0.5)
ax1.set_ylim(0.17, 0.225)
ax1.set_xlabel("In-Distribution Dataset Pearson's R", fontsize=14)
ax1.set_ylabel("2 Holdout Datasets\nPearson's R", fontsize=14)
_style_axes(ax1, "both")
ax1.xaxis.set_major_formatter(major_formatter)
ax1.yaxis.set_major_formatter(major_formatter)
ax1.xaxis.set_major_locator(ticker.MultipleLocator(0.03))
ax1.yaxis.set_major_locator(ticker.MultipleLocator(0.01))
ax1.xaxis.set_tick_params(length=0)
ax1.yaxis.set_tick_params(length=0)

# --- Top-right panel: in-distribution values, one horizontal violin per model.
for id_values, y_mean in zip(orig_x, y):
    ax2.violinplot(
        id_values, positions=[y_mean], showmeans=True, vert=False, widths=0.01
    )
ax2.set_ylabel("2 Holdout Datasets \n Pearson's R", fontsize=14)
_style_axes(ax2, "x")
ax2.xaxis.set_major_formatter(large_formatter)
ax2.yaxis.set_major_formatter(major_formatter)
ax2.xaxis.set_major_locator(ticker.MultipleLocator(0.2))
ax2.yaxis.set_major_locator(ticker.MultipleLocator(0.02))
ax2.xaxis.set_tick_params(length=0)

# --- Bottom-right panel: holdout values, one vertical violin per model.
for od_values, x_mean in zip(orig_y, x):
    ax3.violinplot(
        od_values, positions=[x_mean], showmeans=True, vert=True, widths=0.01
    )
# NOTE(review): x-range starts at 0.35 here but 0.36 in the left panel,
# which shares the same quantity on x — confirm this is intended.
ax3.set_xlim(0.35, 0.5)
ax3.set_xlabel("In-Distribution Dataset Pearson's R", fontsize=14)
_style_axes(ax3, "y")
ax3.yaxis.set_major_formatter(large_formatter)
ax3.xaxis.set_major_formatter(major_formatter)
ax3.xaxis.set_major_locator(ticker.MultipleLocator(0.03))
ax3.yaxis.set_major_locator(ticker.MultipleLocator(0.3))
ax3.yaxis.set_tick_params(length=0)

fig.tight_layout()
fig.savefig("/workspace/figs/ood_scatter.pdf")
plt.show()
# %%
# import numpy as np
# vm = np.load("/data/ray_results/ood/run_09057_00000_0_DATAMODULE_BATCH_SIZE=64_2023-05-14_01-31-23/stage_1/lightning_logs/voxel_metric/stage=TEST.step=000000001296.pkl.npy", allow_pickle=True).item()
# # %%
# v1 = vm["B5K_01"]['TEST/PearsonCorrCoef/B5K_01/all']
# # %%
# v2 = vm["fMRI1_01"]['TEST/PearsonCorrCoef/fMRI1_01/all']

# # %%
# np.mean(np.concatenate([v1, v2]))
# # %%
