{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2f267b5b-d9bb-478b-9350-d92e825c82b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def corr(mat1, mat2):\n",
    "    \"\"\"Spearman rank correlation between the upper triangles of two matrices.\n",
    "\n",
    "    Only entries strictly above the main diagonal are compared. The index\n",
    "    set is sized from ``len(mat1)`` and applied to both inputs.\n",
    "\n",
    "    Returns the correlation coefficient; the p-value is discarded.\n",
    "    \"\"\"\n",
    "    upper = np.triu_indices(len(mat1), k=1)\n",
    "    lhs = mat1[upper].ravel()\n",
    "    rhs = mat2[upper].ravel()\n",
    "    rho, _pvalue = spearmanr(lhs, rhs)\n",
    "    return rho\n",
    "\n",
    "import pickle\n",
    "# NOTE(security): pickle.load can execute arbitrary code while deserializing;\n",
    "# only load human_sims.pkl when it comes from a trusted source.\n",
    "with open(\"human_sims.pkl\", \"rb\") as a_file:\n",
    "    datasets = pickle.load(a_file)\n",
    "# Materialize the keys as a list (not a live dict view) so `dnames` has the\n",
    "# same type as the hard-coded list of dataset names used in the next cell.\n",
    "dnames = list(datasets.keys())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b31922d4-7f47-45ea-81f0-c74c7bf69bc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2347995/2564760024.py:15: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t1['mean']=t1.mean(axis=1)\n",
      "/tmp/ipykernel_2347995/2564760024.py:16: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t2['mean']=t2.mean(axis=1)\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 560/560 [00:04<00:00, 122.86it/s]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Fixed list of dataset names; this replaces the `dnames` taken from the keys\n",
    "# of `datasets` in the first cell and fixes the column order of the results.\n",
    "dnames=['animals', 'automobiles', 'fruits', 'vegetables', 'furniture', 'various']\n",
    "def pairwise_alignment(model, simmat_dir='timm_simmats2'):\n",
    "    \"\"\"Correlate one model's similarity matrices with the human ones.\n",
    "\n",
    "    For every name in the global `dnames`, loads\n",
    "    `{simmat_dir}/{dname}_{model}.npy` and compares it with\n",
    "    `datasets[dname]['similarity']` using `corr`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : str\n",
    "        Model identifier used in the .npy file names.\n",
    "    simmat_dir : str, optional\n",
    "        Directory holding the model similarity matrices. Defaults to\n",
    "        'timm_simmats2', the location that was previously hard-coded.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        One correlation value per entry of `dnames`, in the same order.\n",
    "    \"\"\"\n",
    "    row=[]\n",
    "    for dname in dnames:\n",
    "        model_pairs=np.load(f'{simmat_dir}/{dname}_{model}.npy')\n",
    "        human_pairs=datasets[dname]['similarity']\n",
    "        row.append(corr(model_pairs, human_pairs))\n",
    "    return row\n",
    "\n",
    "t1 = pd.read_csv('timm_models_raw.csv', header=None)\n",
    "t2 = pd.read_csv('timm_models2_raw.csv', header=None)\n",
    "# Column 0 holds the model-name strings, so restrict the row mean to numeric\n",
    "# columns explicitly. Relying on pandas to drop the non-numeric column is\n",
    "# deprecated (FutureWarning) and raises TypeError in later pandas versions.\n",
    "t1['mean']=t1.mean(axis=1, numeric_only=True)\n",
    "t2['mean']=t2.mean(axis=1, numeric_only=True)\n",
    "t_dims=pd.read_csv('timm_model_dims.csv', header=None)\n",
    "# NOTE(review): after sorting on column 0 the three tables are combined by\n",
    "# position, which assumes every file lists the same models - confirm.\n",
    "t1=t1.sort_values(by=[0])\n",
    "t2=t2.sort_values(by=[0])\n",
    "t_dims=t_dims.sort_values(by=[0])\n",
    "# Element-wise average of the two per-model mean scores.\n",
    "timm_means=np.mean([t1['mean'], t2['mean']], axis=0)\n",
    "# Drop the first five characters of each name (presumably a 'timm_' prefix,\n",
    "# since that prefix is added back when the matrices are loaded below - confirm).\n",
    "timm_names=[t[5:] for t in t1[0].values]\n",
    "timm_dims=t_dims[1].values\n",
    "\n",
    "t3=pd.DataFrame([timm_names, timm_means, timm_dims]).T\n",
    "t3.columns=['model', 'mean', 'dims']\n",
    "# The frame was built from mixed string/number rows, so cast the numeric\n",
    "# columns explicitly; the mean is additionally scaled by 100.\n",
    "t3['mean']=t3['mean'].astype('float64')*100\n",
    "t3['dims']=t3['dims'].astype('float64')\n",
    "\n",
    "# Keep only models whose mean score is strictly positive.\n",
    "t3=t3[t3['mean']>0]\n",
    "\n",
    "# One row of per-dataset correlations for every remaining model.\n",
    "rows=[\n",
    "    pairwise_alignment('timm_'+model+'_sim_full2')\n",
    "    for model in tqdm(t3.model.values)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "014009b8-efa6-47d1-93b7-631607040846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset correlations (one column per entry of `dnames`), plus their\n",
    "# row-wise mean as `spearman` and the model name, written out as CSV.\n",
    "df = (\n",
    "    pd.DataFrame(rows, columns=dnames)\n",
    "    .astype('float64')\n",
    "    .assign(spearman=np.mean(rows, axis=1), model=t3.model.values)\n",
    ")\n",
    "df.to_csv('timm_models_spearman_align2.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a820687f-d504-4623-89d1-080bc972cd03",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "\n",
    "# Inner-join the score table (t3) with the alignment table (df) on 'model'.\n",
    "rf = pd.merge(t3, df, on=['model'], how='inner')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ce9385da-af97-4a56-b6a8-1c5e06735aa8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((0.8516659024213293, 1.0570044221323017e-158),\n",
       " SpearmanrResult(correlation=0.8420415425887383, pvalue=1.034618113263667e-151))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "# Pearson and Spearman correlation between the `mean` and `spearman` columns\n",
    "# of the merged table, displayed together as a tuple.\n",
    "(\n",
    "    pearsonr(rf['mean'], rf['spearman']),\n",
    "    spearmanr(rf['mean'], rf['spearman']),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dffccd23-2371-4150-a84d-1bbb7be9f386",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
