{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2f267b5b-d9bb-478b-9350-d92e825c82b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3360/3360 [09:37<00:00,  5.82it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def pairs2triplets(simmat):\n",
    "    \"\"\"Convert a similarity matrix into binary triplet judgements.\n",
    "\n",
    "    For every anchor i and every pair (j, k) with j < k and i not in (j, k),\n",
    "    emit simmat[i, j] < simmat[i, k], i.e. True when the anchor is more\n",
    "    similar to k than to j. Triplets are ordered by anchor i, then j, then k\n",
    "    (the order of the equivalent triple loop), so outputs computed from\n",
    "    different matrices of the same size line up element-wise.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    simmat : array-like, 2-D\n",
    "        Similarity matrix; only the first n = simmat.shape[0] columns are used.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of bool with n * (n - 1) * (n - 2) / 2 entries.\n",
    "    \"\"\"\n",
    "    simmat = np.asarray(simmat)\n",
    "    n = simmat.shape[0]\n",
    "    # (j, k) index pairs with j < k over the n - 1 items left after removing\n",
    "    # the anchor; triu_indices yields them in row-major (j, then k) order.\n",
    "    j_idx, k_idx = np.triu_indices(max(n - 1, 0), k=1)\n",
    "    per_anchor = []\n",
    "    for i in range(n):\n",
    "        others = np.delete(simmat[i, :n], i)\n",
    "        per_anchor.append(others[j_idx] < others[k_idx])\n",
    "    if not per_anchor:\n",
    "        return np.zeros(0, dtype=bool)\n",
    "    return np.concatenate(per_anchor)\n",
    "\n",
    "\n",
    "# Convert every model similarity matrix in SIMMAT_DIR into a triplet file.\n",
    "SIMMAT_DIR = 'timm_simmats'\n",
    "TRIPLET_DIR = 'timm_triplets'\n",
    "os.makedirs(TRIPLET_DIR, exist_ok=True)\n",
    "for mat in tqdm(sorted(os.listdir(SIMMAT_DIR))):\n",
    "    sims = np.load(os.path.join(SIMMAT_DIR, mat))\n",
    "    triplets = pairs2triplets(sims)\n",
    "    np.save(os.path.join(TRIPLET_DIR, mat), triplets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f470618f-0985-423c-8337-e78d4d7472a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3360/3360 [09:30<00:00,  5.89it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Same conversion for the second set of similarity matrices, reusing\n",
    "# pairs2triplets from the first cell rather than redefining it.\n",
    "SIMMAT_DIR = 'timm_simmats2'\n",
    "TRIPLET_DIR = 'timm_triplets2'\n",
    "os.makedirs(TRIPLET_DIR, exist_ok=True)\n",
    "for mat in tqdm(sorted(os.listdir(SIMMAT_DIR))):\n",
    "    sims = np.load(os.path.join(SIMMAT_DIR, mat))\n",
    "    triplets = pairs2triplets(sims)\n",
    "    np.save(os.path.join(TRIPLET_DIR, mat), triplets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b31922d4-7f47-45ea-81f0-c74c7bf69bc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load this file if it\n",
    "# comes from a trusted source.\n",
    "with open('human_sims.pkl', 'rb') as a_file:\n",
    "    datasets = pickle.load(a_file)\n",
    "\n",
    "# Convert each dataset's human similarity matrix with pairs2triplets, the\n",
    "# same function used for the model matrices, so the arrays are comparable.\n",
    "HUMAN_TRIPLET_DIR = 'human_triplets'\n",
    "os.makedirs(HUMAN_TRIPLET_DIR, exist_ok=True)\n",
    "dnames = datasets.keys()\n",
    "for dname in dnames:\n",
    "    sims = datasets[dname]['similarity']\n",
    "    triplets = pairs2triplets(sims)\n",
    "    np.save(f'{HUMAN_TRIPLET_DIR}/{dname}', triplets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "60cae101-5c95-4dc2-9d5c-01d7fc323142",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "842520"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the last triplet array produced above; pairs2triplets yields\n",
    "# n * (n - 1) * (n - 2) / 2 entries for an n-item similarity matrix.\n",
    "n_triplets = len(triplets)\n",
    "n_triplets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "000f4730-448e-4736-b206-29d8678d6fd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_24559/3349608220.py:18: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t1['mean']=t1.mean(axis=1)\n",
      "/tmp/ipykernel_24559/3349608220.py:19: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t2['mean']=t2.mean(axis=1)\n",
      "100%|█████████████████████████████████████████| 560/560 [00:09<00:00, 62.09it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "dnames = ['animals', 'automobiles', 'fruits', 'vegetables', 'furniture', 'various']\n",
    "\n",
    "\n",
    "def triplet_alignment(model, triplet_dir='timm_triplets', human_dir='human_triplets'):\n",
    "    \"\"\"Per-dataset agreement between a model's triplets and the human triplets.\n",
    "\n",
    "    For every dataset name in ``dnames``, loads\n",
    "    ``{triplet_dir}/{dname}_{model}.npy`` and ``{human_dir}/{dname}.npy``\n",
    "    and computes the fraction of positions where the two arrays are equal.\n",
    "\n",
    "    Returns a list of floats in ``dnames`` order.\n",
    "    Raises ValueError if the two arrays for a dataset differ in shape, which\n",
    "    would otherwise make the element-wise comparison meaningless.\n",
    "    \"\"\"\n",
    "    row = []\n",
    "    for dname in dnames:\n",
    "        model_triplets = np.load(f'{triplet_dir}/{dname}_{model}.npy')\n",
    "        human_triplets = np.load(f'{human_dir}/{dname}.npy')\n",
    "        if model_triplets.shape != human_triplets.shape:\n",
    "            raise ValueError(\n",
    "                f'{dname} / {model}: triplet shapes differ '\n",
    "                f'({model_triplets.shape} vs {human_triplets.shape})'\n",
    "            )\n",
    "        row.append(np.mean(model_triplets == human_triplets))\n",
    "    return row\n",
    "\n",
    "\n",
    "# Two CSVs of per-model results (column 0 = model name, remaining columns\n",
    "# numeric) and a CSV of per-model dims (column 0 = name, column 1 = dims).\n",
    "t1 = pd.read_csv('timm_models_raw.csv', header=None)\n",
    "t2 = pd.read_csv('timm_models2_raw.csv', header=None)\n",
    "# numeric_only=True skips the name column explicitly (pandas >= 2.0 raises\n",
    "# a TypeError instead of silently dropping it).\n",
    "t1['mean'] = t1.mean(axis=1, numeric_only=True)\n",
    "t2['mean'] = t2.mean(axis=1, numeric_only=True)\n",
    "t_dims = pd.read_csv('timm_model_dims.csv', header=None)\n",
    "t1 = t1.sort_values(by=[0])\n",
    "t2 = t2.sort_values(by=[0])\n",
    "t_dims = t_dims.sort_values(by=[0])\n",
    "# The tables are combined by position after sorting on the name column, so\n",
    "# they must describe the same models. NOTE(review): only lengths are checked.\n",
    "assert len(t1) == len(t2) == len(t_dims), 'model tables differ in length'\n",
    "\n",
    "timm_means = np.mean([t1['mean'], t2['mean']], axis=0)\n",
    "# Drops the first 5 characters of each name (presumably a 'timm_' prefix,\n",
    "# which is re-added when building file names below).\n",
    "timm_names = [t[5:] for t in t1[0].values]\n",
    "timm_dims = t_dims[1].values\n",
    "\n",
    "t3 = pd.DataFrame({\n",
    "    'model': timm_names,\n",
    "    'mean': np.asarray(timm_means, dtype='float64') * 100,\n",
    "    'dims': np.asarray(timm_dims).astype('float64'),\n",
    "})\n",
    "t3 = t3[t3['mean'] > 0]\n",
    "\n",
    "rows = [triplet_alignment('timm_' + model + '_sim_full') for model in tqdm(t3.model.values)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "80f7a90d-80b3-4b9b-a9e2-4285d014f885",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([  1.,   5.,  20.,  83., 106., 156.,  93.,  44.,  41.,  11.]),\n",
       " array([0.5294448 , 0.54468693, 0.55992906, 0.57517119, 0.59041332,\n",
       "        0.60565545, 0.62089759, 0.63613972, 0.65138185, 0.66662398,\n",
       "        0.68186611]),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAARLUlEQVR4nO3df4xld13G8fdjlxYKkpbstJZuYYrZIi2BgGPlh5JKJW0sYfuHmEXBFUs2koqAP3ALkfpPk1WIYlQwm1JYQm3dlEo3IEhdgWqirdMfSLdL6UprO3TtDqIgaAoLH/+4pzoMMzsz99zZe/vd9yvZ3Hu+55x7nszsee6Zc+89N1WFJKktPzDuAJKk0bPcJalBlrskNchyl6QGWe6S1KAN4w4AsHHjxpqenh53DEl6XLn99tu/UlVTS82biHKfnp5mdnZ23DEk6XElyb8uN2/F0zJJrklyOMndi8bflOTeJPuT/P6C8SuSHOzmXdQvuiRpGKs5cv8g8CfAhx4bSPJTwBbgeVX1aJLTuvFzga3AecDTgb9Jck5VfWfUwSVJy1vxyL2qbgG+umj4jcDOqnq0W+ZwN74FuL6qHq2q+4GDwPkjzCtJWoVh3y1zDvCTSW5N8tkkP9aNnwk8tGC5uW7s+yTZnmQ2yez8/PyQMSRJSxm23DcApwIvAn4L2JMkQJZYdsmL11TVrqqaqaqZqaklX+yVJA1p2HKfA26sgduA7wIbu/GzFiy3CXi4X0RJ0loNW+4fBV4OkOQc4ETgK8BeYGuSk5KcDWwGbhtBTknSGqz4bpkk1wEXABuTzAFXAtcA13Rvj/wWsK0G1w7en2QPcA9wBLjcd8pI0rGXSbie+8zMTPkhJklamyS3V9XMUvMm4hOq0iSb3vHxsWz3gZ2XjGW7aoMXDpOkBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGrVjuSa5Jcrj7vtTF834zSSXZuGDsiiQHk9yb5KJRB5YkrWw1R+4fBC5ePJjkLOAVwIMLxs4FtgLndeu8N8kJI0kqSVq1Fcu9qm4BvrrErD8E3gYs/IbtLcD1VfVoVd0PHATOH0VQSdLqDXXOPcmrgC9X1ecWzToTeGjB9Fw3ttRjbE8ym2R2fn5+mBiSpGWsudyTnAy8A3jnUrOXGKslxqiqXVU1U1UzU1NTa40hSTqKDUOs88PA2cDnkgBsAu5Icj6DI/WzFiy7CXi4b0hJ0tqs+ci9qj5fVadV1XRVTTMo9BdW1b8Be4GtSU5KcjawGbhtpIklSStazVshrwP+AXh2krkkly23bFXtB/YA9wCfBC6vqu+MKqwkaXVWPC1TVa9ZYf70oumrgKv6xZIk9eEnVCWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBq/mavWuSHE5y94KxdyX5QpJ/TvKXSU5ZMO+KJAeT3JvkonXKLUk6itUcuX8QuHjR2M3Ac6vqecAXgSsAkpwLbAXO69Z5b5ITRpZWkrQqK5Z7Vd0CfHXR2Keq6kg3+Y/Apu7+FuD6qnq0qu4HDgLnjzCvJGkVRnHO/ZeBT3T3zwQeWjBvrhuTJB1Dvco9yTuAI8C1jw0tsVgts+72JLNJZufn5/vEkCQtMnS5J9kGvBL4hap6rMDngLMWLLYJeHip9atqV1XNVNXM1NTUsDEkSUsYqtyTXAz8NvCqqvrvBbP2AluTnJTkbGAzcFv/mJKktdiw0gJJrgMuADYmmQOuZPDumJOAm5MA/GNV/UpV7U+yB7iHwemay6vqO+sVXpK0tBXLvapes8Tw+4+y/FXAVX1CSZL68ROqktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1
yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDVrxqpDSJJje8fFxR5AeVzxyl6QGWe6S1CDLXZIaZLlLUoNWLPck1yQ5nOTuBWNPS3Jzkvu621MXzLsiycEk9ya5aL2CS5KWt5oj9w8CFy8a2wHsq6rNwL5umiTnAluB87p13pvkhJGllSStyorlXlW3AF9dNLwF2N3d3w1cumD8+qp6tKruBw4C548mqiRptYY95356VR0C6G5P68bPBB5asNxcN/Z9kmxPMptkdn5+fsgYkqSljPoF1SwxVkstWFW7qmqmqmampqZGHEOSjm/DlvsjSc4A6G4Pd+NzwFkLltsEPDx8PEnSMIYt973Atu7+NuCmBeNbk5yU5GxgM3Bbv4iSpLVa8doySa4DLgA2JpkDrgR2AnuSXAY8CLwaoKr2J9kD3AMcAS6vqu+sU3ZJ0jJWLPeqes0ysy5cZvmrgKv6hJIk9eMnVCWpQZa7JDXIcpekBvllHdKEGucXlDyw85KxbVuj4ZG7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSg3pdFTLJW4E3AAV8Hng9cDLwF8A08ADwc1X1H71SamKM80qFklZv6CP3JGcCvwbMVNVzgROArcAOYF9VbQb2ddOSpGOo72mZDcCTkmxgcMT+MLAF2N3N3w1c2nMbkqQ1Grrcq+rLwLuBB4FDwNeq6lPA6VV1qFvmEHDaKIJKklavz2mZUxkcpZ8NPB14cpLXrmH97Ulmk8zOz88PG0OStIQ+p2V+Gri/quar6tvAjcBLgEeSnAHQ3R5eauWq2lVVM1U1MzU11SOGJGmxPuX+IPCiJCcnCXAhcADYC2zrltkG3NQvoiRprYZ+K2RV3ZrkBuAO4AhwJ7ALeAqwJ8llDJ4AXj2KoJKk1ev1PvequhK4ctHwowyO4iVJY+InVCWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNahXuSc5JckNSb6Q5ECSFyd5WpKbk9zX3Z46qrCSpNXpe+T+R8Anq+pHgOcDB4AdwL6q2gzs66YlScfQ0OWe5KnAy4D3A1TVt6rqP4EtwO5usd3Apf0iSpLWqs+R+7OAeeADSe5McnWSJwOnV9UhgO72tKVWTrI9yWyS2fn5+R4xJEmL9Sn3DcALgfdV1QuAb7KGUzBVtauqZqpqZmpqqkcMSdJifcp9Dpirqlu76RsYlP0jSc4A6G4P94soSVqrocu9qv4NeCjJs7uhC4F7gL3Atm5sG3BTr4SSpDXb0HP9NwHXJjkR+BLwegZPGHuSXAY8CLy65zYkSWvUq9yr6i5gZolZF/Z5XElSP35CVZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAb1/bIOSQ2a3vHxsWz3gZ2XjGW7LfLIXZIaZLlLUoN6l3uSE5LcmeRj3fTTktyc5L7u9tT+MSVJazGKc+5vBg4AT+2mdwD7qmpnkh3d9G+PYDuSGue5/tHpdeSeZBNwCXD1guEtwO7u/m7g0j7bkCStXd/TMu8B3gZ8d8HY6VV1CKC7PW2pFZNsTzKbZHZ+fr5nDEnSQkOXe5JXAoer6vZh1q+qXVU1U1UzU1NTw8aQJC2hzzn3lwKvSvIzwBOBpyb5MPBIkjOq6lCSM4DDowgqSVq9oY/cq+qKqtpUVdPAVuBvq+q1wF5gW7fYNuCm3iklSWuyHu9z3wm8Isl9wCu6aUnSMTSSyw9U1WeAz3T3/x24cBSPK0kajp9QlaQGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwl
qUGWuyQ1yHKXpAYNXe5Jzkry6SQHkuxP8uZu/GlJbk5yX3d76ujiSpJWo8+R+xHgN6rqOcCLgMuTnAvsAPZV1WZgXzctSTqGhi73qjpUVXd09/8LOACcCWwBdneL7QYu7ZlRkrRGI/mC7CTTwAuAW4HTq+oQDJ4Akpy2zDrbge0Az3jGM0YR47gxvePj444gacL1fkE1yVOAjwBvqaqvr3a9qtpVVTNVNTM1NdU3hiRpgV7lnuQJDIr92qq6sRt+JMkZ3fwzgMP9IkqS1qrPu2UCvB84UFV/sGDWXmBbd38bcNPw8SRJw+hzzv2lwOuAzye5qxt7O7AT2JPkMuBB4NW9EkqS1mzocq+qvweyzOwLh31cSVJ/fkJVkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNWgk15aRpMezcV6v6YGdl6zL43rkLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSg3yfew9+l6mkSeWRuyQ1yHKXpAZZ7pLUoHUr9yQXJ7k3ycEkO9ZrO5Kk77cuL6gmOQH4U+AVwBzwT0n2VtU967E9X9iUpO+1Xkfu5wMHq+pLVfUt4HpgyzptS5K0yHq9FfJM4KEF03PAjy9cIMl2YHs3+Y0k965TlmFtBL4y7hBHMcn5JjkbmK8v8/XzPfnye70e65nLzVivcs8SY/U9E1W7gF3rtP3eksxW1cy4cyxnkvNNcjYwX1/m6+dY5Vuv0zJzwFkLpjcBD6/TtiRJi6xXuf8TsDnJ2UlOBLYCe9dpW5KkRdbltExVHUnyq8BfAycA11TV/vXY1jqa2FNGnUnON8nZwHx9ma+fY5IvVbXyUpKkxxU/oSpJDbLcJalBx125r3RZhCQXJPlakru6f+9cNP+EJHcm+dik5UtySpIbknwhyYEkL56wfG9Nsj/J3UmuS/LEY51vQca7uiyfXcu648qX5Kwkn+5+r/uTvHmS8i2YN9b942j5JmH/WCHfaPePqjpu/jF4cfdfgGcBJwKfA85dtMwFwMeO8hi/Dvz50ZYZVz5gN/CG7v6JwCmTko/BB9vuB57UTe8BfmkM+U4B7gGe0U2fttp1x5zvDOCF3f0fBL44SfkWzB/3/rFsvgnZP5b7/Y58/zjejtx7XRYhySbgEuDqScuX5KnAy4D3A1TVt6rqPyclX2cD8KQkG4CTGf1nH1aT7+eBG6vqQYCqOryGdceWr6oOVdUd3f3/Ag4wKISJyAcTs38smW+C9o9lf36MeP843sp9qcsiLLWDvDjJ55J8Isl5C8bfA7wN+O4E5nsWMA98oPuz+OokT56UfFX1ZeDdwIPAIeBrVfWpMeQ7Bzg1yWeS3J7kF9ew7jjz/Z8k08ALgFsnLN97GP/+sVy+Sdk/lsy3HvvH8VbuK14WAbgDeGZVPR/4Y+CjAEleCRyuqtsnMR+DZ/0XAu+rqhcA3wRGfd64z8/vVAZHMWcDTweenOS1Y8i3AfhRBkeYFwG/k+ScVa7bV598gwdIngJ8BHhLVX19UvJN0P6x3M9vUvaP5X5+I98/jrdyX/GyCFX19ar6Rnf/r4AnJNkIvBR4VZIHGPy59fIkH56gfHPAXFU9djR3A4P/zJOS76eB+6tqvqq+DdwIvORY5+uW+WRVfbOqvgLcAjx/leuOMx9JnsCg2K+tqhtHnK1vvonYP46SbyL2j6PkG/3+McoXFCb9H4NnzS8xeHZ87AWP8xYt80P8/4e7zmfwZ1IWLXMB6/OCUa98wN8Bz+7u/y7wrknJx+CqoPsZnEsMgxe33jSGfM8B9nXLngzcDTx3NeuOOV+ADwHvGfP+sWS+Cdo/ls03IfvHcr/fke8f63VVyIlUy1wWIcmvdPP/DPhZ4I1JjgD/A2yt7rfyOMj3JuDaDK7n8yXg9ROU79YkNzA4bXMEuJMRfwx7Nfmq6kCSTwL/zODc8NVVdTfAUutOSr4kPwG8Dvh8kru6h3x7Df46Gnu+UWVY53xj3z9W+P830v3Dyw9I
UoOOt3PuknRcsNwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSg/4XRhaAgMTM6j8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Distribution over models of the triplet alignment averaged across datasets.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(np.mean(rows, axis=1))\n",
    "ax.set(\n",
    "    title='Model-human triplet alignment (mean over datasets)',\n",
    "    xlabel='Fraction of triplets agreeing with humans',\n",
    "    ylabel='Number of models',\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ac113442-9c92-420d-b49a-538af285fccf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2266969/3612553724.py:16: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t1['mean']=t1.mean(axis=1)\n",
      "/tmp/ipykernel_2266969/3612553724.py:17: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  t2['mean']=t2.mean(axis=1)\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 560/560 [00:04<00:00, 133.92it/s]\n"
     ]
    }
   ],
   "source": [
    "# Triplet alignment for the second set of model triplets (timm_triplets2,\n",
    "# *_sim_full2 files). Reuses dnames and the model table t3 from the set-1\n",
    "# cell above (they are built from the same CSVs) and stores the results in\n",
    "# rows2, so rows (set 1) stays intact for the table cells below.\n",
    "def triplet_alignment(model, triplet_dir='timm_triplets2', human_dir='human_triplets'):\n",
    "    \"\"\"Per-dataset agreement between a model's triplets and the human triplets.\n",
    "\n",
    "    For every dataset name in ``dnames``, loads\n",
    "    ``{triplet_dir}/{dname}_{model}.npy`` and ``{human_dir}/{dname}.npy``\n",
    "    and returns, in ``dnames`` order, the fraction of equal positions.\n",
    "    Raises ValueError if the two arrays for a dataset differ in shape.\n",
    "    \"\"\"\n",
    "    row = []\n",
    "    for dname in dnames:\n",
    "        model_triplets = np.load(f'{triplet_dir}/{dname}_{model}.npy')\n",
    "        human_triplets = np.load(f'{human_dir}/{dname}.npy')\n",
    "        if model_triplets.shape != human_triplets.shape:\n",
    "            raise ValueError(\n",
    "                f'{dname} / {model}: triplet shapes differ '\n",
    "                f'({model_triplets.shape} vs {human_triplets.shape})'\n",
    "            )\n",
    "        row.append(np.mean(model_triplets == human_triplets))\n",
    "    return row\n",
    "\n",
    "\n",
    "rows2 = [triplet_alignment('timm_' + model + '_sim_full2') for model in tqdm(t3.model.values)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "014009b8-efa6-47d1-93b7-631607040846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set-1 table: per-dataset alignment from rows (timm_triplets).\n",
    "df = pd.DataFrame(rows, columns=dnames).astype('float64')\n",
    "df['tri_align'] = np.mean(rows, axis=1)\n",
    "df['model'] = t3.model.values\n",
    "df.to_csv('timm_models_tri_align.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6202f26d-6fdb-4b81-bc78-0e7a13c48d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set-2 table: per-dataset alignment from rows2 (timm_triplets2).\n",
    "df = pd.DataFrame(rows2, columns=dnames).astype('float64')\n",
    "df['tri_align'] = np.mean(rows2, axis=1)\n",
    "df['model'] = t3.model.values\n",
    "df.to_csv('timm_models_tri_align2.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a820687f-d504-4623-89d1-080bc972cd03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>mean</th>\n",
       "      <th>dims</th>\n",
       "      <th>animals</th>\n",
       "      <th>automobiles</th>\n",
       "      <th>fruits</th>\n",
       "      <th>vegetables</th>\n",
       "      <th>furniture</th>\n",
       "      <th>various</th>\n",
       "      <th>tri_align</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>adv_inception_v3</td>\n",
       "      <td>35.756273</td>\n",
       "      <td>2048.0</td>\n",
       "      <td>0.584104</td>\n",
       "      <td>0.591522</td>\n",
       "      <td>0.593054</td>\n",
       "      <td>0.582707</td>\n",
       "      <td>0.577612</td>\n",
       "      <td>0.575722</td>\n",
       "      <td>0.584120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bat_resnext26ts</td>\n",
       "      <td>39.886754</td>\n",
       "      <td>2048.0</td>\n",
       "      <td>0.599077</td>\n",
       "      <td>0.587027</td>\n",
       "      <td>0.614438</td>\n",
       "      <td>0.613877</td>\n",
       "      <td>0.577618</td>\n",
       "      <td>0.592325</td>\n",
       "      <td>0.597394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>beit_base_patch16_224</td>\n",
       "      <td>20.048360</td>\n",
       "      <td>768.0</td>\n",
       "      <td>0.561812</td>\n",
       "      <td>0.561946</td>\n",
       "      <td>0.586062</td>\n",
       "      <td>0.594768</td>\n",
       "      <td>0.539661</td>\n",
       "      <td>0.552402</td>\n",
       "      <td>0.566109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>beit_base_patch16_224_in22k</td>\n",
       "      <td>16.855483</td>\n",
       "      <td>768.0</td>\n",
       "      <td>0.554215</td>\n",
       "      <td>0.554995</td>\n",
       "      <td>0.581844</td>\n",
       "      <td>0.588717</td>\n",
       "      <td>0.533255</td>\n",
       "      <td>0.542781</td>\n",
       "      <td>0.559301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>beit_base_patch16_384</td>\n",
       "      <td>14.778900</td>\n",
       "      <td>768.0</td>\n",
       "      <td>0.534540</td>\n",
       "      <td>0.552358</td>\n",
       "      <td>0.574196</td>\n",
       "      <td>0.579596</td>\n",
       "      <td>0.528611</td>\n",
       "      <td>0.539796</td>\n",
       "      <td>0.551516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>555</th>\n",
       "      <td>xcit_tiny_24_p16_224_dist</td>\n",
       "      <td>56.122795</td>\n",
       "      <td>192.0</td>\n",
       "      <td>0.573288</td>\n",
       "      <td>0.585396</td>\n",
       "      <td>0.645252</td>\n",
       "      <td>0.638077</td>\n",
       "      <td>0.659485</td>\n",
       "      <td>0.581444</td>\n",
       "      <td>0.613824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>556</th>\n",
       "      <td>xcit_tiny_24_p16_384_dist</td>\n",
       "      <td>57.107617</td>\n",
       "      <td>192.0</td>\n",
       "      <td>0.584019</td>\n",
       "      <td>0.582702</td>\n",
       "      <td>0.657753</td>\n",
       "      <td>0.643385</td>\n",
       "      <td>0.662596</td>\n",
       "      <td>0.575150</td>\n",
       "      <td>0.617601</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>557</th>\n",
       "      <td>xcit_tiny_24_p8_224</td>\n",
       "      <td>55.611786</td>\n",
       "      <td>192.0</td>\n",
       "      <td>0.591451</td>\n",
       "      <td>0.599987</td>\n",
       "      <td>0.630421</td>\n",
       "      <td>0.639280</td>\n",
       "      <td>0.636418</td>\n",
       "      <td>0.577339</td>\n",
       "      <td>0.612483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>558</th>\n",
       "      <td>xcit_tiny_24_p8_224_dist</td>\n",
       "      <td>56.257066</td>\n",
       "      <td>192.0</td>\n",
       "      <td>0.575970</td>\n",
       "      <td>0.595559</td>\n",
       "      <td>0.645185</td>\n",
       "      <td>0.658898</td>\n",
       "      <td>0.655048</td>\n",
       "      <td>0.573138</td>\n",
       "      <td>0.617299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>559</th>\n",
       "      <td>xcit_tiny_24_p8_384_dist</td>\n",
       "      <td>56.826449</td>\n",
       "      <td>192.0</td>\n",
       "      <td>0.578658</td>\n",
       "      <td>0.596437</td>\n",
       "      <td>0.640198</td>\n",
       "      <td>0.653075</td>\n",
       "      <td>0.661811</td>\n",
       "      <td>0.571672</td>\n",
       "      <td>0.616975</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>560 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                           model       mean    dims   animals  automobiles  \\\n",
       "0               adv_inception_v3  35.756273  2048.0  0.584104     0.591522   \n",
       "1                bat_resnext26ts  39.886754  2048.0  0.599077     0.587027   \n",
       "2          beit_base_patch16_224  20.048360   768.0  0.561812     0.561946   \n",
       "3    beit_base_patch16_224_in22k  16.855483   768.0  0.554215     0.554995   \n",
       "4          beit_base_patch16_384  14.778900   768.0  0.534540     0.552358   \n",
       "..                           ...        ...     ...       ...          ...   \n",
       "555    xcit_tiny_24_p16_224_dist  56.122795   192.0  0.573288     0.585396   \n",
       "556    xcit_tiny_24_p16_384_dist  57.107617   192.0  0.584019     0.582702   \n",
       "557          xcit_tiny_24_p8_224  55.611786   192.0  0.591451     0.599987   \n",
       "558     xcit_tiny_24_p8_224_dist  56.257066   192.0  0.575970     0.595559   \n",
       "559     xcit_tiny_24_p8_384_dist  56.826449   192.0  0.578658     0.596437   \n",
       "\n",
       "       fruits  vegetables  furniture   various  tri_align  \n",
       "0    0.593054    0.582707   0.577612  0.575722   0.584120  \n",
       "1    0.614438    0.613877   0.577618  0.592325   0.597394  \n",
       "2    0.586062    0.594768   0.539661  0.552402   0.566109  \n",
       "3    0.581844    0.588717   0.533255  0.542781   0.559301  \n",
       "4    0.574196    0.579596   0.528611  0.539796   0.551516  \n",
       "..        ...         ...        ...       ...        ...  \n",
       "555  0.645252    0.638077   0.659485  0.581444   0.613824  \n",
       "556  0.657753    0.643385   0.662596  0.575150   0.617601  \n",
       "557  0.630421    0.639280   0.636418  0.577339   0.612483  \n",
       "558  0.645185    0.658898   0.655048  0.573138   0.617299  \n",
       "559  0.640198    0.653075   0.661811  0.571672   0.616975  \n",
       "\n",
       "[560 rows x 10 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attach the per-dataset triplet alignment (df) to the model table (t3).\n",
    "rf = pd.merge(t3, df, on='model', how='inner')\n",
    "rf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "ce9385da-af97-4a56-b6a8-1c5e06735aa8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((0.8364165316311885, 7.70783479868535e-148),\n",
       " SpearmanrResult(correlation=0.8240179109536155, pvalue=8.4129281070516e-140))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "# Correlate each model's 'mean' column with its averaged triplet alignment.\n",
    "score, alignment = rf['mean'], rf['tri_align']\n",
    "pearsonr(score, alignment), spearmanr(score, alignment)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
