{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import natsort\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "# `numpy.core` is private (renamed `numpy._core` in NumPy 2.0, where accessing it warns);\n",
    "# `finfo` is part of the public top-level namespace in every NumPy version.\n",
    "from numpy import finfo\n",
    "from numpy.linalg import svd\n",
    "from scipy import stats\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Every relative path used in the cells below is resolved against ../data/.\n",
    "os.chdir(\"../data/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 1/25 [00:20<08:00, 20.03s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 6\u001b[0m\n\u001b[1;32m      4\u001b[0m     S \u001b[38;5;241m=\u001b[39m svd(representations, compute_uv\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, hermitian\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m      5\u001b[0m     tol \u001b[38;5;241m=\u001b[39m S\u001b[38;5;241m.\u001b[39mmax(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mmax\u001b[39m(representations\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m:]) \u001b[38;5;241m*\u001b[39m finfo(S\u001b[38;5;241m.\u001b[39mdtype)\u001b[38;5;241m.\u001b[39meps\n\u001b[0;32m----> 6\u001b[0m     ranks_sst2_test\u001b[38;5;241m.\u001b[39mappend(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinalg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmatrix_rank\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrepresentations\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m      8\u001b[0m ranks_mrpc_test \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m seed \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m25\u001b[39m)):\n",
      "File \u001b[0;32m~/Documents/encoder-equivalence/venv/lib/python3.10/site-packages/numpy/linalg/linalg.py:1922\u001b[0m, in \u001b[0;36mmatrix_rank\u001b[0;34m(A, tol, hermitian)\u001b[0m\n\u001b[1;32m   1920\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m A\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m   1921\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mall\u001b[39m(A\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m))\n\u001b[0;32m-> 1922\u001b[0m S \u001b[38;5;241m=\u001b[39m \u001b[43msvd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompute_uv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhermitian\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhermitian\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1923\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tol \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1924\u001b[0m     tol \u001b[38;5;241m=\u001b[39m S\u001b[38;5;241m.\u001b[39mmax(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mmax\u001b[39m(A\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m:]) \u001b[38;5;241m*\u001b[39m finfo(S\u001b[38;5;241m.\u001b[39mdtype)\u001b[38;5;241m.\u001b[39meps\n",
      "File \u001b[0;32m~/Documents/encoder-equivalence/venv/lib/python3.10/site-packages/numpy/linalg/linalg.py:1693\u001b[0m, in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, hermitian)\u001b[0m\n\u001b[1;32m   1690\u001b[0m     gufunc \u001b[38;5;241m=\u001b[39m _umath_linalg\u001b[38;5;241m.\u001b[39msvd_n\n\u001b[1;32m   1692\u001b[0m signature \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mD->d\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m isComplexType(t) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124md->d\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m-> 1693\u001b[0m s \u001b[38;5;241m=\u001b[39m \u001b[43mgufunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msignature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mextobj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1694\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mastype(_realType(result_t), copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m   1695\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m s\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def compute_ranks(task: str, n_seeds: int = 25) -> list:\n",
    "    \"\"\"Numerical rank of the stored representations for every seed of `task`.\n",
    "\n",
    "    For each seed in range(n_seeds), loads\n",
    "    `representations-large/seed-{seed}-task-{task}-train`, keeps the last entry\n",
    "    along the first axis (`[-1]`, presumably the final layer), and counts the\n",
    "    singular values above NumPy's default `matrix_rank` tolerance\n",
    "    (largest singular value * max(M, N) * machine epsilon).\n",
    "\n",
    "    The rank is counted from the singular values computed here. Calling\n",
    "    `np.linalg.matrix_rank` instead would run the SVD a second time, and passing\n",
    "    it this `keepdims`-shaped `tol` would return length-1 arrays, not scalars.\n",
    "\n",
    "    Returns:\n",
    "        One rank per seed, in seed order.\n",
    "    \"\"\"\n",
    "    ranks = []\n",
    "    for seed in tqdm(range(n_seeds), desc=task):\n",
    "        path = f\"representations-large/seed-{seed}-task-{task}-train\"\n",
    "        representations = torch.load(path, map_location=\"cpu\").numpy()[-1]\n",
    "        S = svd(representations, compute_uv=False, hermitian=False)\n",
    "        tol = S.max(axis=-1, keepdims=True) * max(representations.shape[-2:]) * finfo(S.dtype).eps\n",
    "        ranks.append(np.count_nonzero(S > tol, axis=-1))\n",
    "    return ranks\n",
    "\n",
    "\n",
    "# NOTE(review): despite the `_test` suffix, both lists are computed from the\n",
    "# `*-train` representation files; confirm which split is intended.\n",
    "ranks_sst2_test = compute_ranks(\"sst2\")\n",
    "ranks_mrpc_test = compute_ranks(\"mrpc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SignificanceResult(statistic=0.3210720298035125, pvalue=0.02487741662494852)\n",
      "SignificanceResult(statistic=0.42708737289872706, pvalue=0.033227853542484094)\n",
      "SignificanceResult(statistic=0.41999999999999993, pvalue=0.002838809542227846)\n",
      "SignificanceResult(statistic=0.5599999999999999, pvalue=0.0036013919598933035)\n"
     ]
    }
   ],
   "source": [
    "def get_distance_matrix(task: str) -> np.ndarray:\n",
    "    \"\"\"Load the symmetry metrics for `task` (\"MRPC\" or \"SST-2\").\n",
    "\n",
    "    Returns an array whose trailing axes are (24, 13, 3). The leading axis has one\n",
    "    entry per matching MRPC file, or 25 entries for SST-2.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no MRPC metric files match the glob pattern.\n",
    "        ValueError: if `task` is not supported.\n",
    "    \"\"\"\n",
    "    if task == \"MRPC\":\n",
    "        pattern = 'symmetry-data/v2-metrics-lr-mrpc-seedx-*.npy'\n",
    "        # natsort orders seed-2 before seed-10, so rows follow numeric seed order.\n",
    "        fnames = natsort.natsorted(glob.glob(pattern))\n",
    "        if not fnames:\n",
    "            raise FileNotFoundError(f\"No files match {pattern!r}\")\n",
    "        return np.array([np.load(fname).reshape(24, 13, 3) for fname in fnames])\n",
    "    elif task == \"SST-2\":\n",
    "        # NOTE(review): the [::24] stride was marked as a \"debug loading bug\" workaround;\n",
    "        # confirm it selects the intended rows before trusting the SST-2 results.\n",
    "        sst2_symmetry = np.load('symmetry-data/metrics-lr-sst2-full.npy')[::24]\n",
    "        return sst2_symmetry.reshape(25, 24, 13, 3)\n",
    "    else:\n",
    "        raise ValueError(\"Task not supported\")\n",
    "\n",
    "\n",
    "def seed_distance_totals(distance_matrix: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Column sums of a square matrix built from the last slice of the two trailing axes.\n",
    "\n",
    "    `distance_matrix[:, :, -1, -1]` is expected to have one column fewer than it has\n",
    "    rows. Row `ct` gets a 0 inserted at index `ct` (the diagonal) to make it square.\n",
    "    \"\"\"\n",
    "    off_diagonal = distance_matrix[:, :, -1, -1]\n",
    "    square = np.array([np.insert(row, ct, 0) for ct, row in enumerate(off_diagonal)])\n",
    "    return square.sum(axis=0)\n",
    "\n",
    "\n",
    "for task, ranks in ((\"SST-2\", ranks_sst2_test), (\"MRPC\", ranks_mrpc_test)):\n",
    "    totals = seed_distance_totals(get_distance_matrix(task))\n",
    "    assert len(totals) == len(ranks), f\"{task}: {len(totals)} seeds vs {len(ranks)} ranks\"\n",
    "    print(stats.kendalltau(totals, ranks))\n",
    "    print(stats.spearmanr(totals, ranks))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
