{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reproduce plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display\n",
    "import matplotlib as mpl\n",
    "import os.path as osp\n",
    "\n",
    "\n",
    "# Text width of the paper template in inches; the figure functions below derive\n",
    "# their figure sizes from it.\n",
    "template_textwidth_inches = 5.50107\n",
    "# NOTE(review): golden_ratio is not referenced by any other cell in this notebook.\n",
    "golden_ratio = 1.61803398875\n",
    "# Maps the model names found in the result frames to the shorter labels used\n",
    "# on the heatmap axes.\n",
    "long_to_short_name = {\n",
    "    \"RedPajama-INCITE-7B-Base\": \"RedPajama\",\n",
    "    \"bloom-7b1\": \"bloom\",\n",
    "    \"falcon-7b\": \"falcon\",\n",
    "    \"galactica-6.7b\": \"galactica\",\n",
    "    \"gpt-j-6b\": \"gpt-j\",\n",
    "    \"llama-7b\": \"llama\",\n",
    "    \"mpt-7b\": \"mpt\",\n",
    "    \"open-llama-7b\": \"open-llama\",\n",
    "    \"opt-6.7b\": \"opt\",\n",
    "    \"pythia-6.9b-deduped\": \"pythia-deduped\",\n",
    "    \"stablelm-base-alpha-7b\": \"stablelm-alpha\",\n",
    "    \"CodeLlama-7b-hf\": \"CodeLlama\",\n",
    "    \"CodeLlama-7b-Python-hf\": \"CodeLlama-Python\",\n",
    "}\n",
    "\n",
    "# TODO: change this to wherever you downloaded the parquetfiles\n",
    "parquet_basepath = \"results/zenodo\"\n",
    "\n",
    "def prepare_df(path):\n",
    "    \"\"\"Read the parquet file at ``path``, display its first rows, and return the frame.\"\"\"\n",
    "    loaded = pd.read_parquet(path)\n",
    "    display(loaded.head())\n",
    "    return loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Winogrande results (OTIS and OTISTR files), loaded in that order.\n",
    "df_otis, df_otistr = [\n",
    "    prepare_df(osp.join(parquet_basepath, filename))\n",
    "    for filename in (\"winogrande_otis.parquet\", \"winogrande_otistr.parquet\")\n",
    "]\n",
    "\n",
    "# HumanEval results (OTIS and OTISTR files), loaded in that order.\n",
    "human_otis, human_otistr = [\n",
    "    prepare_df(osp.join(parquet_basepath, filename))\n",
    "    for filename in (\"humaneval_otis.parquet\", \"humaneval_otistr.parquet\")\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fig 1a (Winogrande, OTIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this notebook-level alias is not read by the function defined\n",
    "# below, which receives its frame as an argument.\n",
    "df = df_otis\n",
    "\n",
    "\n",
    "def winogrande_otis_paper_figure(df):\n",
    "    \"\"\"Plot one lower-triangle heatmap of pairwise model scores per measure.\n",
    "\n",
    "    Args:\n",
    "        df: Long-format results with the columns ``measure``, ``model1``,\n",
    "            ``model2`` and ``score``.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure with one heatmap axis per plotted measure.\n",
    "    \"\"\"\n",
    "    measures = [\n",
    "        \"Pipeline(normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"aligned_cossim\",\n",
    "        \"rsm_norm_diff\",\n",
    "        \"Pipeline(+jaccard_similarity{'k': 10})\",\n",
    "    ]\n",
    "    measures_short_names = [\n",
    "        \"Orthogonal Procrustes\",\n",
    "        \"Aligned Cossim\",\n",
    "        \"Norm RSM-Diff (Cosine)\",\n",
    "        \"Jaccard (k=10)\",\n",
    "    ]\n",
    "    # Measures drawn with the reversed colormap.\n",
    "    reverse_cmap_measures = [\n",
    "        \"Pipeline(normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"rsm_norm_diff\",\n",
    "    ]\n",
    "    score = \"score\"\n",
    "    cbar_width_scaler = 1.3\n",
    "    width_one_axis = template_textwidth_inches / len(measures) * cbar_width_scaler * 2\n",
    "\n",
    "    fig, ax = plt.subplots(\n",
    "        1,\n",
    "        len(measures),\n",
    "        figsize=(len(measures) * width_one_axis, width_one_axis),\n",
    "        squeeze=False,\n",
    "    )\n",
    "    for i, (measure, measure_name) in enumerate(zip(measures, measures_short_names)):\n",
    "        # Filter this measure's rows once and reuse them for labels and edges.\n",
    "        rows = df.loc[df[\"measure\"] == measure, [\"model1\", \"model2\", score]]\n",
    "        ticklabels = sorted(\n",
    "            set(pd.unique(rows[\"model1\"])).union(pd.unique(rows[\"model2\"]))\n",
    "        )\n",
    "\n",
    "        G = nx.from_pandas_edgelist(\n",
    "            rows.sort_values(by=[\"model1\", \"model2\"], axis=0),\n",
    "            source=\"model1\",\n",
    "            target=\"model2\",\n",
    "            edge_attr=score,\n",
    "        )\n",
    "        # toarray() gives a plain ndarray for both scipy sparse matrices and\n",
    "        # sparse arrays (todense() returns np.matrix for the former); the float\n",
    "        # cast guarantees that NaN can be assigned below.\n",
    "        data = (\n",
    "            nx.adjacency_matrix(G, weight=score, nodelist=ticklabels)\n",
    "            .toarray()\n",
    "            .astype(float)\n",
    "        )\n",
    "        # we only want the lower triangle as the measures are symmetric\n",
    "        mask = np.triu(np.ones_like(data, dtype=bool), k=0)\n",
    "        data[mask] = np.nan  # NaN values won't show up\n",
    "        # drop the first row and the last column, which consist exclusively of NaNs\n",
    "        data = data[1:, :-1]\n",
    "\n",
    "        ticklabels = [long_to_short_name.get(l, l) for l in ticklabels]\n",
    "        xticklabels = ticklabels[:-1]\n",
    "        # only the leftmost heatmap carries row labels\n",
    "        yticklabels = ticklabels[1:] if i == 0 else [\"\"] * (len(ticklabels) - 1)\n",
    "        cmap = \"rocket_r\" if measure in reverse_cmap_measures else \"rocket\"\n",
    "        _ = sns.heatmap(\n",
    "            data,\n",
    "            ax=ax[0, i],\n",
    "            xticklabels=xticklabels,\n",
    "            yticklabels=yticklabels,\n",
    "            cmap=cmap,\n",
    "            annot=False,\n",
    "            annot_kws=dict(fontsize=\"xx-small\"),\n",
    "            square=False,\n",
    "        )\n",
    "        ax[0, i].set_title(measure_name)\n",
    "        # only the last heatmap's colorbar gets the explanatory label\n",
    "        if (i + 1) == len(measures):\n",
    "            with mpl.rc_context({\"text.usetex\": True}):\n",
    "                ax[0, i].collections[0].colorbar.set_label(\n",
    "                    r\"$\\leftarrow$ less similar       more similar $\\rightarrow$\"\n",
    "                    + \"\\n(darker)            (brighter)\"\n",
    "                )\n",
    "    return fig\n",
    "\n",
    "\n",
    "fig = winogrande_otis_paper_figure(df_otis)\n",
    "# fig.show() needs a GUI backend (Matplotlib warns otherwise); plt.show()\n",
    "# renders through whichever backend is active.\n",
    "plt.show()\n",
    "# fig.savefig(\"figures/repsim_otis_hm.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fig 1b (HumanEval, OTIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def humaneval_otis_paper_figure(df, score=\"score\"):\n",
    "    \"\"\"Plot one lower-triangle heatmap of pairwise model scores per measure.\n",
    "\n",
    "    Models appear in the order given by ``model_order``; models that are not\n",
    "    listed there are not plotted.\n",
    "\n",
    "    Args:\n",
    "        df: Long-format results with the columns ``measure``, ``model1``,\n",
    "            ``model2`` and the score column. The frame is copied, not modified.\n",
    "        score: Name of the column holding the values to plot.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure with one heatmap axis per plotted measure.\n",
    "    \"\"\"\n",
    "    measures = [\n",
    "        \"Pipeline(normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"aligned_cossim\",\n",
    "        \"rsm_norm_diff\",\n",
    "        \"Pipeline(+jaccard_similarity{'k': 10})\",\n",
    "    ]\n",
    "    measures_short_names = [\n",
    "        \"Orthogonal Procrustes\",\n",
    "        \"Aligned Cossim\",\n",
    "        \"Norm RSM-Diff (Cosine)\",\n",
    "        \"Jaccard (k=10)\",\n",
    "    ]\n",
    "    # Measures drawn with the reversed colormap.\n",
    "    reverse_cmap_measures = [\n",
    "        \"Pipeline(normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"rsm_norm_diff\",\n",
    "    ]\n",
    "    model_order = [\n",
    "        \"RedPajama-INCITE-7B-Base\",\n",
    "        \"bloom-7b1\",\n",
    "        \"falcon-7b\",\n",
    "        \"galactica-6.7b\",\n",
    "        \"gpt-j-6b\",\n",
    "        \"llama-7b\",\n",
    "        \"mpt-7b\",\n",
    "        \"open-llama-7b\",\n",
    "        \"opt-6.7b\",\n",
    "        \"pythia-6.9b-deduped\",\n",
    "        \"stablelm-base-alpha-7b\",\n",
    "        \"CodeLlama-7b-hf\",\n",
    "        \"CodeLlama-7b-Python-hf\",\n",
    "    ]\n",
    "\n",
    "    # Ordered categoricals make sort_values follow model_order.\n",
    "    df = df.copy()\n",
    "    df[\"model1\"] = pd.Categorical(df[\"model1\"], categories=model_order, ordered=True)\n",
    "    df[\"model2\"] = pd.Categorical(df[\"model2\"], categories=model_order, ordered=True)\n",
    "\n",
    "    cbar_width_scaler = 1.3\n",
    "    width_one_axis = template_textwidth_inches / len(measures) * cbar_width_scaler * 2\n",
    "    fig, ax = plt.subplots(\n",
    "        1,\n",
    "        len(measures),\n",
    "        figsize=(len(measures) * width_one_axis * 1.1, width_one_axis),\n",
    "        squeeze=False,\n",
    "    )\n",
    "    for i, (measure, measure_name) in enumerate(zip(measures, measures_short_names)):\n",
    "        print(measure)\n",
    "\n",
    "        # Filter this measure's rows once and reuse them for labels and edges.\n",
    "        rows = df.loc[df[\"measure\"] == measure, [\"model1\", \"model2\", score]]\n",
    "        # Label only the models this measure has rows for: a nodelist entry that\n",
    "        # is missing from the graph makes nx.adjacency_matrix raise.\n",
    "        present = set(pd.unique(rows[\"model1\"])).union(pd.unique(rows[\"model2\"]))\n",
    "        ticklabels = [name for name in model_order if name in present]\n",
    "\n",
    "        G = nx.from_pandas_edgelist(\n",
    "            rows.sort_values(by=[\"model1\", \"model2\"], axis=0),\n",
    "            source=\"model1\",\n",
    "            target=\"model2\",\n",
    "            edge_attr=score,\n",
    "        )\n",
    "        # toarray() gives a plain ndarray for both scipy sparse matrices and\n",
    "        # sparse arrays (todense() returns np.matrix for the former); the float\n",
    "        # cast guarantees that NaN can be assigned below.\n",
    "        data = (\n",
    "            nx.adjacency_matrix(G, weight=score, nodelist=ticklabels)\n",
    "            .toarray()\n",
    "            .astype(float)\n",
    "        )\n",
    "        # we only want the lower triangle as the measures are symmetric\n",
    "        mask = np.triu(np.ones_like(data, dtype=bool), k=0)\n",
    "        data[mask] = np.nan  # NaN values won't show up\n",
    "        # drop the first row and the last column, which consist exclusively of NaNs\n",
    "        data = data[1:, :-1]\n",
    "\n",
    "        ticklabels = [long_to_short_name.get(l, l) for l in ticklabels]\n",
    "        xticklabels = ticklabels[:-1]\n",
    "        # only the leftmost heatmap carries row labels\n",
    "        yticklabels = ticklabels[1:] if i == 0 else [\"\"] * (len(ticklabels) - 1)\n",
    "        cmap = \"rocket_r\" if measure in reverse_cmap_measures else \"rocket\"\n",
    "        _ = sns.heatmap(\n",
    "            data,\n",
    "            ax=ax[0, i],\n",
    "            xticklabels=xticklabels,\n",
    "            yticklabels=yticklabels,\n",
    "            cmap=cmap,\n",
    "            annot=False,\n",
    "            annot_kws=dict(fontsize=\"xx-small\"),\n",
    "            square=False,\n",
    "        )\n",
    "        ax[0, i].set_title(measure_name)\n",
    "        # only the last heatmap's colorbar gets the explanatory label\n",
    "        if (i + 1) == len(measures):\n",
    "            with mpl.rc_context({\"text.usetex\": True}):\n",
    "                ax[0, i].collections[0].colorbar.set_label(\n",
    "                    r\"$\\leftarrow$ less similar       more similar $\\rightarrow$\"\n",
    "                    + \"\\n(darker)            (brighter)\"\n",
    "                )\n",
    "    return fig\n",
    "\n",
    "\n",
    "fig = humaneval_otis_paper_figure(human_otis)\n",
    "# fig.show() needs a GUI backend (Matplotlib warns otherwise); plt.show()\n",
    "# renders through whichever backend is active.\n",
    "plt.show()\n",
    "# fig.savefig(\"figures/repsim_humaneval_otis_hm.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fig 2a (Winogrande, OTISTR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def winogrande_otistr_paper_figure(df):\n",
    "    \"\"\"Plot one lower-triangle heatmap of pairwise model scores per measure.\n",
    "\n",
    "    Args:\n",
    "        df: Long-format results with the columns ``measure``, ``model1``,\n",
    "            ``model2`` and ``score``.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure with one heatmap axis per plotted measure.\n",
    "    \"\"\"\n",
    "    measures = [\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"Pipeline(center_columns+aligned_cossim{})\",\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+rsm_norm_diff{'inner': 'euclidean'})\",\n",
    "        \"Pipeline(center_columns+jaccard_similarity{})\",\n",
    "        \"Pipeline(normalize_matrix_norm+representational_similarity_analysis{'inner': 'euclidean', 'outer': 'spearman'})\",\n",
    "        \"centered_kernel_alignment\",\n",
    "    ]\n",
    "    measures_short_names = [\n",
    "        \"Orthogonal\\nProcrustes\",\n",
    "        \"Aligned Cossim\",\n",
    "        \"Norm RSM-Diff\\n(Euclidean)\",\n",
    "        \"Jaccard (k=10)\",\n",
    "        \"RSA\\n(Euclidean, Spearman)\",\n",
    "        \"CKA\",\n",
    "    ]\n",
    "    # Measures drawn with the reversed colormap.\n",
    "    reverse_cmap_measures = [\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+rsm_norm_diff{'inner': 'euclidean'})\",\n",
    "    ]\n",
    "    score = \"score\"\n",
    "    cbar_width_scaler = 1.3\n",
    "    width_one_axis = template_textwidth_inches / len(measures) * cbar_width_scaler * 2\n",
    "\n",
    "    fig, ax = plt.subplots(\n",
    "        1,\n",
    "        len(measures),\n",
    "        figsize=(len(measures) * width_one_axis, width_one_axis),\n",
    "        squeeze=False,\n",
    "    )\n",
    "    for i, (measure, measure_name) in enumerate(zip(measures, measures_short_names)):\n",
    "        print(measure)\n",
    "        # Filter this measure's rows once and reuse them for labels and edges.\n",
    "        rows = df.loc[df[\"measure\"] == measure, [\"model1\", \"model2\", score]]\n",
    "        ticklabels = sorted(\n",
    "            set(pd.unique(rows[\"model1\"])).union(pd.unique(rows[\"model2\"]))\n",
    "        )\n",
    "\n",
    "        G = nx.from_pandas_edgelist(\n",
    "            rows.sort_values(by=[\"model1\", \"model2\"], axis=0),\n",
    "            source=\"model1\",\n",
    "            target=\"model2\",\n",
    "            edge_attr=score,\n",
    "        )\n",
    "        # toarray() gives a plain ndarray for both scipy sparse matrices and\n",
    "        # sparse arrays (todense() returns np.matrix for the former); the float\n",
    "        # cast guarantees that NaN can be assigned below.\n",
    "        data = (\n",
    "            nx.adjacency_matrix(G, weight=score, nodelist=ticklabels)\n",
    "            .toarray()\n",
    "            .astype(float)\n",
    "        )\n",
    "        # we only want the lower triangle as the measures are symmetric\n",
    "        mask = np.triu(np.ones_like(data, dtype=bool), k=0)\n",
    "        data[mask] = np.nan  # NaN values won't show up\n",
    "        # drop the first row and the last column, which consist exclusively of NaNs\n",
    "        data = data[1:, :-1]\n",
    "\n",
    "        ticklabels = [long_to_short_name.get(l, l) for l in ticklabels]\n",
    "        xticklabels = ticklabels[:-1]\n",
    "        # only the leftmost heatmap carries row labels\n",
    "        yticklabels = ticklabels[1:] if i == 0 else [\"\"] * (len(ticklabels) - 1)\n",
    "        cmap = \"rocket_r\" if measure in reverse_cmap_measures else \"rocket\"\n",
    "        _ = sns.heatmap(\n",
    "            data,\n",
    "            ax=ax[0, i],\n",
    "            xticklabels=xticklabels,\n",
    "            yticklabels=yticklabels,\n",
    "            cmap=cmap,\n",
    "            annot=False,\n",
    "            annot_kws=dict(fontsize=\"xx-small\"),\n",
    "            square=False,\n",
    "        )\n",
    "        ax[0, i].set_title(measure_name)\n",
    "        # only the last heatmap's colorbar gets the explanatory label\n",
    "        if (i + 1) == len(measures):\n",
    "            with mpl.rc_context({\"text.usetex\": True}):\n",
    "                ax[0, i].collections[0].colorbar.set_label(\n",
    "                    r\"$\\leftarrow$ less similar       more similar $\\rightarrow$\"\n",
    "                    + \"\\n(darker)            (brighter)\"\n",
    "                )\n",
    "    return fig\n",
    "\n",
    "\n",
    "fig = winogrande_otistr_paper_figure(df_otistr)\n",
    "# fig.show() needs a GUI backend (Matplotlib warns otherwise); plt.show()\n",
    "# renders through whichever backend is active.\n",
    "plt.show()\n",
    "# fig.savefig(\"figures/repsim_otistr_hm.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fig 2b (HumanEval, OTISTR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def humaneval_otistr_paper_figure(df):\n",
    "    \"\"\"Plot one lower-triangle heatmap of pairwise model scores per measure.\n",
    "\n",
    "    Models appear in the order given by ``model_order``; models that are not\n",
    "    listed there are not plotted.\n",
    "\n",
    "    Args:\n",
    "        df: Long-format results with the columns ``measure``, ``model1``,\n",
    "            ``model2`` and ``score``. The frame is copied, not modified.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure with one heatmap axis per plotted measure.\n",
    "    \"\"\"\n",
    "    measures = [\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"Pipeline(center_columns+aligned_cossim{})\",\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+rsm_norm_diff{'inner': 'euclidean'})\",\n",
    "        \"Pipeline(center_columns+jaccard_similarity{})\",\n",
    "        \"Pipeline(normalize_matrix_norm+representational_similarity_analysis{'inner': 'euclidean', 'outer': 'spearman'})\",\n",
    "        \"centered_kernel_alignment\",\n",
    "    ]\n",
    "    measures_short_names = [\n",
    "        \"Orthogonal\\nProcrustes\",\n",
    "        \"Aligned Cossim\",\n",
    "        \"Norm RSM-Diff\\n(Euclidean)\",\n",
    "        \"Jaccard (k=10)\",\n",
    "        \"RSA\\n(Euclidean, Spearman)\",\n",
    "        \"CKA\",\n",
    "    ]\n",
    "    # Measures drawn with the reversed colormap.\n",
    "    reverse_cmap_measures = [\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"Pipeline(center_columns+normalize_matrix_norm+rsm_norm_diff{'inner': 'euclidean'})\",\n",
    "    ]\n",
    "    model_order = [\n",
    "        \"RedPajama-INCITE-7B-Base\",\n",
    "        \"bloom-7b1\",\n",
    "        \"falcon-7b\",\n",
    "        \"galactica-6.7b\",\n",
    "        \"gpt-j-6b\",\n",
    "        \"llama-7b\",\n",
    "        \"mpt-7b\",\n",
    "        \"open-llama-7b\",\n",
    "        \"opt-6.7b\",\n",
    "        \"pythia-6.9b-deduped\",\n",
    "        \"stablelm-base-alpha-7b\",\n",
    "        \"CodeLlama-7b-hf\",\n",
    "        \"CodeLlama-7b-Python-hf\",\n",
    "    ]\n",
    "    score = \"score\"\n",
    "\n",
    "    # Ordered categoricals make sort_values follow model_order.\n",
    "    df = df.copy()\n",
    "    df[\"model1\"] = pd.Categorical(df[\"model1\"], categories=model_order, ordered=True)\n",
    "    df[\"model2\"] = pd.Categorical(df[\"model2\"], categories=model_order, ordered=True)\n",
    "\n",
    "    cbar_width_scaler = 1.3\n",
    "    width_one_axis = template_textwidth_inches / len(measures) * cbar_width_scaler * 2\n",
    "    fig, ax = plt.subplots(\n",
    "        1,\n",
    "        len(measures),\n",
    "        figsize=(len(measures) * width_one_axis * 1.1, width_one_axis),\n",
    "        squeeze=False,\n",
    "    )\n",
    "    for i, (measure, measure_name) in enumerate(zip(measures, measures_short_names)):\n",
    "        print(measure)\n",
    "        # Filter this measure's rows once and reuse them for labels and edges.\n",
    "        rows = df.loc[df[\"measure\"] == measure, [\"model1\", \"model2\", score]]\n",
    "        # Label only the models this measure has rows for: a nodelist entry that\n",
    "        # is missing from the graph makes nx.adjacency_matrix raise.\n",
    "        present = set(pd.unique(rows[\"model1\"])).union(pd.unique(rows[\"model2\"]))\n",
    "        ticklabels = [name for name in model_order if name in present]\n",
    "\n",
    "        G = nx.from_pandas_edgelist(\n",
    "            rows.sort_values(by=[\"model1\", \"model2\"], axis=0),\n",
    "            source=\"model1\",\n",
    "            target=\"model2\",\n",
    "            edge_attr=score,\n",
    "        )\n",
    "        # toarray() gives a plain ndarray for both scipy sparse matrices and\n",
    "        # sparse arrays (todense() returns np.matrix for the former); the float\n",
    "        # cast guarantees that NaN can be assigned below.\n",
    "        data = (\n",
    "            nx.adjacency_matrix(G, weight=score, nodelist=ticklabels)\n",
    "            .toarray()\n",
    "            .astype(float)\n",
    "        )\n",
    "        # we only want the lower triangle as the measures are symmetric\n",
    "        mask = np.triu(np.ones_like(data, dtype=bool), k=0)\n",
    "        data[mask] = np.nan  # NaN values won't show up\n",
    "        # drop the first row and the last column, which consist exclusively of NaNs\n",
    "        data = data[1:, :-1]\n",
    "\n",
    "        ticklabels = [long_to_short_name.get(l, l) for l in ticklabels]\n",
    "        xticklabels = ticklabels[:-1]\n",
    "        # only the leftmost heatmap carries row labels\n",
    "        yticklabels = ticklabels[1:] if i == 0 else [\"\"] * (len(ticklabels) - 1)\n",
    "        cmap = \"rocket_r\" if measure in reverse_cmap_measures else \"rocket\"\n",
    "        _ = sns.heatmap(\n",
    "            data,\n",
    "            ax=ax[0, i],\n",
    "            xticklabels=xticklabels,\n",
    "            yticklabels=yticklabels,\n",
    "            cmap=cmap,\n",
    "            annot=False,\n",
    "            annot_kws=dict(fontsize=\"xx-small\"),\n",
    "            square=False,\n",
    "        )\n",
    "        ax[0, i].set_title(measure_name)\n",
    "        # only the last heatmap's colorbar gets the explanatory label\n",
    "        if (i + 1) == len(measures):\n",
    "            with mpl.rc_context({\"text.usetex\": True}):\n",
    "                ax[0, i].collections[0].colorbar.set_label(\n",
    "                    r\"$\\leftarrow$ less similar       more similar $\\rightarrow$\"\n",
    "                    + \"\\n(darker)            (brighter)\"\n",
    "                )\n",
    "    return fig\n",
    "\n",
    "fig = humaneval_otistr_paper_figure(human_otistr)\n",
    "# fig.show() needs a GUI backend (Matplotlib warns otherwise); plt.show()\n",
    "# renders through whichever backend is active.\n",
    "plt.show()\n",
    "# fig.savefig(\"figures/repsim_humaneval_otistr_hm.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "import scipy.spatial.distance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Correlations across datasets per measure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The two result frames compared below: Winogrande OTIS vs. HumanEval OTIS.\n",
    "df1, df2 = df_otis, human_otis\n",
    "\n",
    "\n",
    "def all_models(df):\n",
    "    return set(pd.unique(df.model1)).union(set(pd.unique(df.model2)))\n",
    "\n",
    "\n",
    "# Models that have results in both frames.\n",
    "models_in_both_dfs = all_models(df1) & all_models(df2)\n",
    "\n",
    "\n",
    "def cross_dataset_correlation(df1, df2, distance_measures, score=\"score\"):\n",
    "    \"\"\"Correlate the pairwise model scores of two result frames, measure by measure.\n",
    "\n",
    "    For every measure present in both frames, the scores of all model pairs are\n",
    "    collected in the same pair order for each frame and then correlated. Only\n",
    "    models that occur in both frames are considered.\n",
    "\n",
    "    Args:\n",
    "        df1: Long-format results with the columns ``measure``, ``model1``,\n",
    "            ``model2`` and the score column.\n",
    "        df2: Second result frame with the same columns.\n",
    "        distance_measures: Measure names whose scores are negated before\n",
    "            correlating.\n",
    "        score: Name of the column holding the scores.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each shared measure name to a dict with the keys\n",
    "        ``\"spearman\"`` (float) and ``\"pearson\"`` (array of shape (1, 1)).\n",
    "    \"\"\"\n",
    "\n",
    "    def models_of(df):\n",
    "        # every model name that occurs on either side of a pair\n",
    "        return set(pd.unique(df[\"model1\"])).union(pd.unique(df[\"model2\"]))\n",
    "\n",
    "    # Derived from the arguments rather than the notebook-level\n",
    "    # ``models_in_both_dfs`` so the function is correct for any pair of frames.\n",
    "    shared_models = models_of(df1).intersection(models_of(df2))\n",
    "\n",
    "    def rows_for(df, measure):\n",
    "        # rows of one measure, restricted to pairs of shared models\n",
    "        keep = (\n",
    "            df[\"model1\"].isin(shared_models)\n",
    "            & df[\"model2\"].isin(shared_models)\n",
    "            & (df[\"measure\"] == measure)\n",
    "        )\n",
    "        return df.loc[keep, [\"model1\", \"model2\", score]].sort_values(\n",
    "            by=[\"model1\", \"model2\"], axis=0\n",
    "        )\n",
    "\n",
    "    def get_array(rows, nodelist):\n",
    "        # scores of all unordered model pairs as a 1-D float array\n",
    "        G = nx.from_pandas_edgelist(\n",
    "            rows,\n",
    "            source=\"model1\",\n",
    "            target=\"model2\",\n",
    "            edge_attr=score,\n",
    "        )\n",
    "        # toarray() always yields a plain ndarray (todense() returns np.matrix\n",
    "        # for scipy sparse matrices, which breaks the 1-D handling below).\n",
    "        data = (\n",
    "            nx.adjacency_matrix(G, weight=score, nodelist=nodelist)\n",
    "            .toarray()\n",
    "            .astype(float)\n",
    "        )\n",
    "        # we only want the lower triangle as the measures are symmetric;\n",
    "        # tril_indices walks it row by row\n",
    "        # NOTE(review): pairs without a row keep the adjacency default of 0.\n",
    "        return data[np.tril_indices(len(nodelist), k=-1)]\n",
    "\n",
    "    # sorted -> deterministic order of the result dict across runs\n",
    "    measures = sorted(\n",
    "        set(pd.unique(df1[\"measure\"])).intersection(pd.unique(df2[\"measure\"]))\n",
    "    )\n",
    "    corrs = {}\n",
    "    for measure in measures:\n",
    "        rows1 = rows_for(df1, measure)\n",
    "        rows2 = rows_for(df2, measure)\n",
    "        # One common node list keeps both score vectors the same length and in\n",
    "        # the same pair order even if a frame lacks a model for this measure.\n",
    "        nodelist = sorted(models_of(rows1).intersection(models_of(rows2)))\n",
    "\n",
    "        data1 = get_array(rows1, nodelist)\n",
    "        data2 = get_array(rows2, nodelist)\n",
    "        # Drop a pair from both vectors if either score is NaN, so the vectors\n",
    "        # stay aligned.\n",
    "        valid = ~(np.isnan(data1) | np.isnan(data2))\n",
    "        data1, data2 = data1[valid], data2[valid]\n",
    "        # NOTE(review): negating both vectors leaves both correlation\n",
    "        # coefficients unchanged; kept to preserve the original handling.\n",
    "        if measure in distance_measures:\n",
    "            data1, data2 = -1 * data1, -1 * data2\n",
    "\n",
    "        corrs[measure] = {\n",
    "            \"spearman\": spearmanr(data1, data2).statistic,\n",
    "            \"pearson\": 1\n",
    "            - scipy.spatial.distance.cdist(\n",
    "                data1.reshape(1, -1), data2.reshape(1, -1), metric=\"correlation\"\n",
    "            ),\n",
    "        }\n",
    "    return corrs\n",
    "\n",
    "\n",
    "# Measures passed as distance measures to cross_dataset_correlation.\n",
    "res = cross_dataset_correlation(\n",
    "    df1,\n",
    "    df2,\n",
    "    [\n",
    "        \"Pipeline(normalize_matrix_norm+orthogonal_procrustes{})\",\n",
    "        \"rsm_norm_diff\",\n",
    "    ],\n",
    ")\n",
    "for measure, corrs in res.items():\n",
    "    print(measure)\n",
    "    print(corrs)\n",
    "\n",
    "# Mean Spearman correlation over all compared measures.\n",
    "avg_spearman_corr = sum(entry[\"spearman\"] for entry in res.values()) / len(res)\n",
    "avg_spearman_corr"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
