{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dee1ca4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3278c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from rae import PROJECT_ROOT\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "\n",
    "try:\n",
    "    # be ready for 3.10 when it drops\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes\n",
    "\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "DEVICE: str = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99d998ac",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.modules.text.encoder import GensimEncoder\n",
    "\n",
    "ENCODERS = {\n",
    "    model_name: GensimEncoder(language=\"en\", lemmatize=False, model_name=model_name)\n",
    "    for model_name in (\n",
    "        \"local_fasttext\",\n",
    "        \"word2vec-google-news-300\",\n",
    "    )\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d13cd066",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "assert len({frozenset(encoder.model.key_to_index.keys()) for encoder in ENCODERS.values()}) == 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e385241",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "seed_everything(4)\n",
    "\n",
    "NUM_ANCHORS = 300\n",
    "NUM_TARGETS = 200\n",
    "NUM_CLUSTERS = 4\n",
    "WORDS = sorted(ENCODERS[\"local_fasttext\"].model.key_to_index.keys())\n",
    "WORDS = [word for word in WORDS if word.isalpha() and len(word) >= 4]\n",
    "TARGET_WORDS = [\"school\", \"ferrari\", \"water\", \"martial\"]  # words to take the neighborhoods from\n",
    "# TARGET_WORDS = random.sample(WORDS, NUM_CLUSTERS)\n",
    "print(f\"{TARGET_WORDS=}\")\n",
    "word2index = {word: i for i, word in enumerate(WORDS)}\n",
    "TARGETS = torch.zeros(len(WORDS), device=\"cpu\")\n",
    "target_cluster = [\n",
    "    [word for word, sim in ENCODERS[\"local_fasttext\"].model.most_similar(target_word, topn=NUM_TARGETS)]\n",
    "    for target_word in TARGET_WORDS\n",
    "]\n",
    "\n",
    "valid_words, valid_targets = [], []\n",
    "for i, target_cluster in enumerate(target_cluster):\n",
    "    valid_words.append(TARGET_WORDS[i])\n",
    "    valid_targets.append(i + 1)\n",
    "    for word in target_cluster:\n",
    "        if word in word2index:\n",
    "            valid_words.append(word)\n",
    "            valid_targets.append(i + 1)\n",
    "\n",
    "WORDS = valid_words\n",
    "TARGETS = valid_targets\n",
    "\n",
    "ANCHOR_WORDS = sorted(random.sample(WORDS, NUM_ANCHORS))  # TODO: stratified\n",
    "\n",
    "ANCHOR_WORDS[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148843f7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "\n",
    "def latents_distance(latents):\n",
    "    assert len(latents) == 2\n",
    "    for x in latents:\n",
    "        assert x.shape[1] == 300\n",
    "\n",
    "    dist = F.pairwise_distance(latents[0], latents[1], p=2).mean()\n",
    "    return f\"{dist:.2f}\"\n",
    "\n",
    "\n",
    "def get_latents(words, encoder: GensimEncoder):\n",
    "    latents = torch.tensor([encoder.model.get_vector(word) for word in words], device=DEVICE)\n",
    "    return latents\n",
    "\n",
    "\n",
    "def to_df(latents, fit_pca: bool = True):\n",
    "    if fit_pca:\n",
    "        latents2d = PCA(n_components=2).fit_transform(latents.cpu())\n",
    "    else:\n",
    "        latents2d = latents[:, [0, 1]]\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"x\": latents2d[:, 0].tolist(),\n",
    "            \"y\": latents2d[:, 1].tolist(),\n",
    "            \"target\": TARGETS,\n",
    "        }\n",
    "    )\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed0388f8",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Plot stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61c6b02f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def plot_bg(\n",
    "    ax,\n",
    "    df,\n",
    "    cmap,\n",
    "    norm,\n",
    "    size,\n",
    "    bg_alpha,\n",
    "):\n",
    "    \"\"\"Create and return a plot of all our movie embeddings with very low opacity.\n",
    "    (Intended to be used as a basis for further - more prominent - plotting of a\n",
    "    subset of movies. Having the overall shape of the map space in the background is\n",
    "    useful for context.)\n",
    "    \"\"\"\n",
    "    ax.scatter(df.x, df.y, c=cmap(norm(df[\"target\"])), alpha=bg_alpha, s=size)\n",
    "    return ax\n",
    "\n",
    "\n",
    "def hightlight_cluster(\n",
    "    ax,\n",
    "    df,\n",
    "    target,\n",
    "    alpha,\n",
    "    cmap,\n",
    "    norm,\n",
    "    size=0.5,\n",
    "):\n",
    "    cluster_df = df[df[\"target\"] == target]\n",
    "    ax.scatter(cluster_df.x, cluster_df.y, c=cmap(norm(cluster_df[\"target\"])), alpha=alpha, s=size)\n",
    "\n",
    "\n",
    "def plot_latent_space(ax, df, targets, size, cmap, norm, bg_alpha, alpha):\n",
    "    ax = plot_bg(ax, df, bg_alpha=bg_alpha, cmap=cmap, norm=norm, size=size)\n",
    "    for target in targets:\n",
    "        hightlight_cluster(ax, df, target, alpha=alpha, size=size, cmap=cmap, norm=norm)\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13da9076",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47ff8cd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "ae_latents = {}\n",
    "anchors_latents = {}\n",
    "for enc_name, encoder in ENCODERS.items():\n",
    "    ae_latents[enc_name] = get_latents(words=WORDS, encoder=encoder)\n",
    "    anchors_latents[enc_name] = get_latents(words=ANCHOR_WORDS, encoder=encoder)\n",
    "\n",
    "import copy\n",
    "\n",
    "original_ae_latents = copy.deepcopy(ae_latents)\n",
    "original_anchor_latents = copy.deepcopy(anchors_latents)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e466ea65",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Rel Attention NO Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf7c14e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from rae.modules.attention import *\n",
    "\n",
    "\n",
    "col_config = ((None, None),)\n",
    "N_ROWS = len(ENCODERS)\n",
    "N_COLS = len(col_config) + 1\n",
    "print(\n",
    "    N_ROWS,\n",
    "    N_COLS,\n",
    ")\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "import matplotlib as mpl\n",
    "\n",
    "num_colors = len(TARGET_WORDS)\n",
    "cmap = mpl.colors.ListedColormap(plt.cm.get_cmap(\"Set1\", 10).colors[:num_colors], name=\"rgb\", N=num_colors)\n",
    "norm = plt.Normalize(min(TARGETS), max(TARGETS))\n",
    "\n",
    "fig, axes = plt.subplots(dpi=300, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=True)\n",
    "\n",
    "S = 7\n",
    "BG_ALPHA = 0.35\n",
    "ALPHA = 0.5\n",
    "\n",
    "TARGETS_HIGHTLIGHT = [1]\n",
    "for ax_encoders, (_, latents) in zip(axes[0], ae_latents.items()):\n",
    "\n",
    "    plot_latent_space(\n",
    "        ax_encoders,\n",
    "        to_df(latents),\n",
    "        targets=TARGETS_HIGHTLIGHT,\n",
    "        size=S,\n",
    "        bg_alpha=BG_ALPHA,\n",
    "        alpha=ALPHA,\n",
    "        cmap=cmap,\n",
    "        norm=norm,\n",
    "    )\n",
    "\n",
    "distances = {\"absolute\": latents_distance(list(ae_latents.values()))}\n",
    "\n",
    "for col_i, (quant_mode, bin_size) in enumerate(col_config):\n",
    "    rel_attention = RelativeAttention(\n",
    "        n_anchors=NUM_ANCHORS,\n",
    "        n_classes=len(set(TARGETS)),\n",
    "        similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "        values_mode=ValuesMethod.SIMILARITIES,\n",
    "        normalization_mode=NormalizationMode.L2,\n",
    "    )\n",
    "    assert sum(x.numel() for x in rel_attention.parameters()) == 0\n",
    "    rels = []\n",
    "    for row_ax, (enc_name, latents), (a_enc_name, a_latents) in zip(\n",
    "        axes[1], ae_latents.items(), anchors_latents.items()\n",
    "    ):\n",
    "        assert enc_name == a_enc_name\n",
    "        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]\n",
    "        rels.append(rel)\n",
    "        plot_latent_space(\n",
    "            row_ax,\n",
    "            to_df(rel),\n",
    "            targets=TARGETS_HIGHTLIGHT,\n",
    "            size=S,\n",
    "            bg_alpha=BG_ALPHA,\n",
    "            alpha=ALPHA,\n",
    "            cmap=cmap,\n",
    "            norm=norm,\n",
    "        )\n",
    "    distances[f\"relative({quant_mode}, {bin_size})\"] = latents_distance(rels)\n",
    "\n",
    "distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3633f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"word-embeddings-spaces-no-quant.svg\", bbox_inches=\"tight\")\n",
    "!rsvg-convert -f pdf -o 'word-embeddings-spaces-no-quant.pdf' 'word-embeddings-spaces-no-quant.svg'\n",
    "!rm 'word-embeddings-spaces-no-quant'.svg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "237f1152",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Rel Attention Quantized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd726f9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from rae.modules.attention import *\n",
    "\n",
    "\n",
    "col_config = (\n",
    "    (None, None),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.0001),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.05),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.1),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.3),\n",
    "    (\"cluster\", 1),\n",
    "    #     (\"cluster\", 0.5),\n",
    "    (\"cluster\", 1.5),\n",
    "    (\"cluster\", 2),\n",
    "    # ('kmeans', 3),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.6),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.7),\n",
    "    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.8),\n",
    "    #    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.9),\n",
    ")\n",
    "N_ROWS = len(ENCODERS)\n",
    "N_COLS = len(col_config) + 1\n",
    "\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))\n",
    "\n",
    "import matplotlib as mpl\n",
    "\n",
    "num_colors = len(TARGET_WORDS)\n",
    "cmap = mpl.colors.ListedColormap(plt.cm.get_cmap(\"Set1\", 10).colors[:num_colors], name=\"rgb\", N=num_colors)\n",
    "# cmap = plt.cm.get_cmap(\"Set1\", 5)\n",
    "norm = plt.Normalize(min(TARGETS), max(TARGETS))\n",
    "\n",
    "fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=True)\n",
    "\n",
    "\n",
    "TARGETS_HIGHTLIGHT = [1]\n",
    "for ax_encoders, (_, latents) in zip(axes, ae_latents.items()):\n",
    "\n",
    "    plot_latent_space(\n",
    "        ax_encoders[0],\n",
    "        to_df(latents),\n",
    "        targets=TARGETS_HIGHTLIGHT,\n",
    "        size=0.75,\n",
    "        bg_alpha=0.25,\n",
    "        alpha=1,\n",
    "        cmap=cmap,\n",
    "        norm=norm,\n",
    "    )\n",
    "\n",
    "distances = {\"absolute\": latents_distance(list(ae_latents.values()))}\n",
    "\n",
    "for col_i, (quant_mode, bin_size) in enumerate(col_config):\n",
    "    rel_attention = RelativeAttention(\n",
    "        n_anchors=NUM_ANCHORS,\n",
    "        n_classes=len(set(TARGETS)),\n",
    "        similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "        values_mode=ValuesMethod.SIMILARITIES,\n",
    "        normalization_mode=NormalizationMode.L2,\n",
    "        #  output_normalization_mode=OutputNormalization.L2,\n",
    "        #         similarities_quantization_mode=quant_mode,\n",
    "        #         similarities_bin_size=bin_size,\n",
    "        #         similarities_num_clusters=bin_size,\n",
    "        absolute_quantization_mode=quant_mode,\n",
    "        absolute_bin_size=bin_size,\n",
    "        absolute_num_clusters=bin_size,\n",
    "    )\n",
    "    assert sum(x.numel() for x in rel_attention.parameters()) == 0\n",
    "    rels = []\n",
    "    for row_axes, (enc_name, latents), (a_enc_name, a_latents) in zip(\n",
    "        axes, ae_latents.items(), anchors_latents.items()\n",
    "    ):\n",
    "        assert enc_name == a_enc_name\n",
    "        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]\n",
    "        rels.append(rel)\n",
    "        plot_latent_space(\n",
    "            row_axes[col_i + 1],\n",
    "            to_df(rel),\n",
    "            targets=TARGETS_HIGHTLIGHT,\n",
    "            size=0.75,\n",
    "            bg_alpha=0.25,\n",
    "            alpha=1,\n",
    "            cmap=cmap,\n",
    "            norm=norm,\n",
    "        )\n",
    "    distances[f\"relative({quant_mode}, {bin_size})\"] = latents_distance(rels)\n",
    "\n",
    "distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb02efa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"word-embeddings-spaces.svg\", bbox_inches=\"tight\")\n",
    "!rsvg-convert -f pdf -o 'word-embeddings-spaces.pdf' 'word-embeddings-spaces.svg'\n",
    "!rm 'word-embeddings-spaces'.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15fdd2e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "random_words = list(list(ENCODERS.values())[0].model.key_to_index.keys())[400:]\n",
    "random_words = [word for word in random_words if word.isalpha() and len(word) >= 4]\n",
    "# random.shuffle(random_words)\n",
    "SEARCH_WORDS = random_words[:20_000]\n",
    "SEARCH_WORDS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6710e11c",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Faiss Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c770cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "RETRIEVAL_ANCHORS_NUM = 300\n",
    "RETRIEVAL_ANCHORS = sorted(random.sample(SEARCH_WORDS, RETRIEVAL_ANCHORS_NUM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a45bb30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.openfaiss import FaissIndex\n",
    "\n",
    "enc_type2enc_name2faiss_index = {\n",
    "    enc_type: {\n",
    "        enc_name: FaissIndex(d=300 if enc_type == \"absolute\" else RETRIEVAL_ANCHORS_NUM) for enc_name in ENCODERS.keys()\n",
    "    }\n",
    "    for enc_type in (\"absolute\", \"relative\")\n",
    "}\n",
    "\n",
    "\n",
    "for enc_name, encoder in ENCODERS.items():\n",
    "    print(enc_name)\n",
    "    latents = encoder.model.vectors_for_all(keys=SEARCH_WORDS).vectors\n",
    "    enc_type2enc_name2faiss_index[\"absolute\"][enc_name].add_vectors(\n",
    "        embeddings=list(zip(SEARCH_WORDS, latents)), normalize=True\n",
    "    )\n",
    "\n",
    "    rel_attention = RelativeAttention(\n",
    "        n_anchors=NUM_ANCHORS,\n",
    "        n_classes=len(set(TARGETS)),\n",
    "        similarity_mode=RelativeEmbeddingMethod.INNER,\n",
    "        values_mode=ValuesMethod.SIMILARITIES,\n",
    "        normalization_mode=NormalizationMode.L2,\n",
    "        #  output_normalization_mode=OutputNormalization.L2,\n",
    "        #          similarities_quantization_mode='differentiable_round',\n",
    "        #          similarities_bin_size=0.01,\n",
    "        #          similarities_num_clusters=,\n",
    "        #         absolute_quantization_mode=\"cluster\",\n",
    "        #         absolute_bin_size=2,  # ignored\n",
    "        #         absolute_num_clusters=2,\n",
    "    )\n",
    "    anchors = get_latents(words=RETRIEVAL_ANCHORS, encoder=encoder)\n",
    "    latents = torch.as_tensor(latents, dtype=torch.float32)\n",
    "    relative_representation = rel_attention(x=latents, anchors=anchors.cpu())\n",
    "\n",
    "    enc_type2enc_name2faiss_index[\"relative\"][enc_name].add_vectors(\n",
    "        embeddings=list(zip(SEARCH_WORDS, relative_representation[AttentionOutput.SIMILARITIES].cpu().numpy())),\n",
    "        normalize=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb68ff9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "enc_type2enc_name2faiss_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead1cdfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "K = 5\n",
    "enc_type2enc_names2word2topk = {enc_type: {} for enc_type in (\"absolute\", \"relative\")}\n",
    "\n",
    "for enc_type, enc_name2faiss_index in enc_type2enc_name2faiss_index.items():\n",
    "    for enc_name1, enc_name2 in itertools.product(enc_name2faiss_index.keys(), repeat=2):\n",
    "        faiss_index1 = enc_name2faiss_index[enc_name1]\n",
    "        faiss_index2 = enc_name2faiss_index[enc_name2]\n",
    "\n",
    "        enc1_vectors = np.asarray(faiss_index1.reconstruct_n(SEARCH_WORDS), dtype=\"float32\")\n",
    "        enc2_neighbors = faiss_index2.search_by_vectors(query_vectors=enc1_vectors, k_most_similar=K, normalize=False)\n",
    "\n",
    "        enc_type2enc_names2word2topk[enc_type][(enc_name1, enc_name2)] = {\n",
    "            word: topk for word, topk in zip(SEARCH_WORDS, enc2_neighbors)\n",
    "        }\n",
    "# enc_type2enc_names2word2topk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efea7647",
   "metadata": {},
   "outputs": [],
   "source": [
    "performance_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3048d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "performance = {key: [] for key in (\"src_enc\", \"tgt_enc\", \"enc_type\", \"topk_jaccard\", \"mrr\", \"semantic_horizon\")}\n",
    "\n",
    "for enc_type, enc_names2word2topk in enc_type2enc_names2word2topk.items():\n",
    "    for (enc_name1, enc_name2), word2topk in enc_names2word2topk.items():\n",
    "        target_words = {\n",
    "            search_word: set(enc_names2word2topk[(enc_name2, enc_name2)][search_word].keys())\n",
    "            for search_word in SEARCH_WORDS\n",
    "        }\n",
    "        actual_words = {\n",
    "            search_word: set(enc_names2word2topk[(enc_name1, enc_name2)][search_word].keys())\n",
    "            for search_word in SEARCH_WORDS\n",
    "        }\n",
    "\n",
    "        topk_jaccard = {\n",
    "            search_word: len(set.intersection(target_words[search_word], actual_words[search_word]))\n",
    "            / len(set.union(target_words[search_word], actual_words[search_word]))\n",
    "            for search_word in SEARCH_WORDS\n",
    "        }\n",
    "        topk_jaccard = np.mean(list(topk_jaccard.values()))\n",
    "\n",
    "        search_word2word2rank = {\n",
    "            search_word: {key: index for index, key in enumerate(word2sim.keys(), start=1)}\n",
    "            for search_word, word2sim in enc_names2word2topk[(enc_name1, enc_name2)].items()\n",
    "        }\n",
    "        mrr = {\n",
    "            search_word: (\n",
    "                #                 word2rank.get(search_word, K)\n",
    "                0\n",
    "                if search_word not in word2rank\n",
    "                else 1 / word2rank[search_word]\n",
    "            )\n",
    "            for search_word, word2rank in search_word2word2rank.items()\n",
    "        }\n",
    "        mrr = np.mean(list(mrr.values()))\n",
    "\n",
    "        semantic_horizon = []\n",
    "        for search_word, neighbors in actual_words.items():\n",
    "            neighbor2ranking = {\n",
    "                neighbor: {\n",
    "                    key: index\n",
    "                    for index, key in enumerate(\n",
    "                        enc_type2enc_names2word2topk[\"absolute\"][(enc_name2, enc_name2)][neighbor].keys(), start=1\n",
    "                    )\n",
    "                }\n",
    "                for neighbor in neighbors\n",
    "            }\n",
    "            neighbor2mrr = {\n",
    "                neighbor: (\n",
    "                    #                 topk.get(search_word, K)\n",
    "                    0\n",
    "                    if search_word not in ranking\n",
    "                    else 1 / ranking[search_word]\n",
    "                )\n",
    "                for neighbor, ranking in neighbor2ranking.items()\n",
    "            }\n",
    "            semantic_horizon.append(np.mean(list(neighbor2mrr.values())))\n",
    "\n",
    "        semantic_horizon = np.mean(semantic_horizon)\n",
    "\n",
    "        performance[\"src_enc\"].append(enc_name1)\n",
    "        performance[\"tgt_enc\"].append(enc_name2)\n",
    "        performance[\"enc_type\"].append(enc_type)\n",
    "        performance[\"topk_jaccard\"].append(topk_jaccard)\n",
    "        performance[\"mrr\"].append(mrr)\n",
    "        performance[\"semantic_horizon\"].append(semantic_horizon)\n",
    "\n",
    "performance_df = pd.DataFrame(performance)\n",
    "performance_df.to_csv(\n",
    "    PROJECT_ROOT / \"experiments\" / \"fig:latent-rotation-comparison\" / \"semantic_horizon.tsv\", sep=\"\\t\"\n",
    ")\n",
    "performance_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74732009",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_type2enc_names2word2topk[\"absolute\"][(\"local_fasttext\", \"word2vec-google-news-300\")][\"student\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62bcd982",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_type2enc_names2word2topk[\"relative\"][(\"local_fasttext\", \"word2vec-google-news-300\")][\"student\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa636699",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_type2enc_names2word2topk[\"absolute\"][(\"local_fasttext\", \"local_fasttext\")][\"student\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42a02dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_type2enc_names2word2topk[\"relative\"][(\"local_fasttext\", \"local_fasttext\")][\"student\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
