import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from scipy import spatial

# One row per activation function: (lookup key, plot legend label, plot color).
# Keeping the three attributes side by side guarantees that ACTFUN, key2legend
# and colors stay index-aligned.
_ACTFUN_STYLE = (
    ("relu", "ReLU", "#1a1a1a"),
    ("maxout", "Max", "#34a02c"),
    ("max_min_dup", "Max, Min (d)", "#b2df8a"),
    ("signedgeomean", "signed_geomean", "#ccb974"),
    ("ail_xnor", "XNOR$_{AIL}$", "#b15828"),
    ("ail_or", "OR$_{AIL}$", "#1f78b4"),
    ("ail_and_or_dup", "OR$_{AIL}$, AND$_{AIL}$ (d)", "#a6cee3"),
    ("ail_or_xnor_part", "OR$_{AIL}$, XNOR$_{AIL}$ (p)", "#e31a1d"),
    ("ail_or_xnor_dup", "OR$_{AIL}$, XNOR$_{AIL}$ (d)", "#fb9b99"),
    ("ail_and_or_xnor_part", "OR$_{AIL}$, AND$_{AIL}$, XNOR$_{AIL}$ (p)", "#6a3d9a"),
    ("ail_and_or_xnor_dup", "OR$_{AIL}$, AND$_{AIL}$, XNOR$_{AIL}$ (d)", "#cab2d6"),
)

# Ordered list of activation-function keys.
ACTFUN = [key for key, _legend, _color in _ACTFUN_STYLE]
# Maps each key to the label shown as its subplot title.
key2legend = {key: legend for key, legend, _color in _ACTFUN_STYLE}
# Hex colors, in the same order as ACTFUN.
colors = [color for _key, _legend, color in _ACTFUN_STYLE]

# Load every file in the weights directory as a CPU tensor, keyed by the
# activation-function name embedded in its file name.
weights = {}
for fname in os.listdir("logs/logs_jsb_actfun_weights/"):
    path = os.path.join("logs", "logs_jsb_actfun_weights", fname)
    with open(path, "rb") as stream:
        loaded = torch.load(stream, map_location="cpu")
    # The key is the file name without its first 7 and last 3 characters.
    weights[fname[7:-3]] = loaded.detach().cpu()

sns.set_theme(style="darkgrid")

# Adjacent features
# For every activation function, compute the cosine similarity between each
# consecutive (even, odd) pair of rows among the first 1536 rows of its weight
# tensor: rows (0, 1), (2, 3), ..., (1534, 1535).
# Assumes each weight tensor is 2-D with one row per preactivation unit
# (the per-row scalar similarity only makes sense in that layout) — TODO confirm.
even_rows = torch.arange(0, 1536, 2)
sims = {}
for actfun in ACTFUN:
    w = weights[actfun]
    # One batched call replaces 768 scalar calls. Indexing with explicit index
    # tensors (rather than slices) still raises IndexError if a tensor has
    # fewer than 1536 rows, instead of silently truncating.
    sims[actfun] = torch.nn.functional.cosine_similarity(
        w[even_rows], w[even_rows + 1], dim=1
    ).tolist()

fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(15, 5))
fig.suptitle("Cosine similarities between paired preactivations (1st layer)")
# Fill the panels row by row; zip stops after the last activation function,
# so no hard-coded count or manual counter is needed.
for panel, actfun, color in zip(ax.flat, ACTFUN, colors):
    panel.set(title=key2legend[actfun], xlim=(-1, 1), ylim=(0, 8))
    panel.hist(sims[actfun], bins=40, range=(-1, 1), density=True, color=color)
# The 2x6 grid has one more panel than there are activation functions.
fig.delaxes(ax[1, 5])
fig.tight_layout()
# savefig does not create missing directories; make sure the target exists so
# the script does not fail after all the computation is done.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/jsb_cosine_pair.eps", format="eps")
plt.show()


# All pairwise features
# For every activation function and each of the first 1536 rows of its weight
# tensor, draw 10 distinct partner rows at random (from the same 1536 rows) and
# record the cosine similarity between the row and each partner.
# Assumes each weight tensor is 2-D with one row per preactivation unit
# (the per-row scalar similarity only makes sense in that layout) — TODO confirm.
# NOTE(review): a row can be drawn as its own partner, which adds a
# self-similarity sample to the histogram — confirm whether that is intended.
# NOTE(review): the NumPy RNG is not seeded here, so the figure differs
# slightly from run to run.
anchor_idx = torch.arange(1536).repeat_interleave(10)  # 0,0,...,0,1,1,...,1535
sims = {}
for actfun in ACTFUN:
    w = weights[actfun]
    # One draw per anchor row, in row order. Passing the int 1536 samples from
    # np.arange(1536), the same population as range(1536), without rebuilding
    # an array on every call.
    partner_idx = torch.from_numpy(
        np.concatenate(
            [np.random.choice(1536, (10,), replace=False) for _ in range(1536)]
        )
    ).long()
    # One batched call replaces 15360 scalar calls. Index tensors raise
    # IndexError if a tensor has fewer than 1536 rows.
    sims[actfun] = torch.nn.functional.cosine_similarity(
        w[anchor_idx], w[partner_idx], dim=1
    ).tolist()

fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(15, 5))
fig.suptitle("Cosine similarities between all preactivations (1st layer)")
# Fill the panels row by row; zip stops after the last activation function,
# so no hard-coded count or manual counter is needed.
for panel, actfun, color in zip(ax.flat, ACTFUN, colors):
    panel.set(title=key2legend[actfun], xlim=(-1, 1), ylim=(0, 8))
    panel.hist(sims[actfun], bins=40, range=(-1, 1), density=True, color=color)
# The 2x6 grid has one more panel than there are activation functions.
fig.delaxes(ax[1, 5])
fig.tight_layout()
# savefig does not create missing directories; make sure the target exists so
# the script does not fail after all the computation is done.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/jsb_cosine_all.eps", format="eps")
plt.show()
