{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b138d14b-c5da-4c45-ade9-18a0714f7cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_dci_on_fixed_data(observations, labels, representation_function,\n",
    "                              train_percentage=gin.REQUIRED, batch_size=100):\n",
    "  \"\"\"Computes the DCI scores on the fixed set of observations and labels.\n",
    "\n",
    "  Args:\n",
    "    observations: Observations on which to compute the score. Observations have\n",
    "      shape (num_observations, 64, 64, num_channels).\n",
    "    labels: Observed factors of variations.\n",
    "    representation_function: Function that takes observations as input and\n",
    "      outputs a dim_representation sized representation for each observation.\n",
    "    train_percentage: Percentage of observations used for training.\n",
    "    batch_size: Batch size used to compute the representation.\n",
    "\n",
    "  Returns:\n",
    "    DCI score.\n",
    "  \"\"\"\n",
    "  mus = utils.obtain_representation(observations, representation_function,\n",
    "                                    batch_size)\n",
    "  assert labels.shape[1] == observations.shape[0], \"Wrong labels shape.\"\n",
    "  assert mus.shape[1] == observations.shape[0], \"Wrong representation shape.\"\n",
    "  mus_train, mus_test = utils.split_train_test(\n",
    "      mus,\n",
    "      train_percentage)\n",
    "  ys_train, ys_test = utils.split_train_test(\n",
    "      labels,\n",
    "      train_percentage)\n",
    "  return _compute_dci(mus_train, ys_train, mus_test, ys_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b1633d-a654-4ffa-9a3e-0a674eca6b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_dci(mus_train, ys_train, mus_test, ys_test):\n",
    "  \"\"\"Computes score based on both training and testing codes and factors.\"\"\"\n",
    "  scores = {}\n",
    "  importance_matrix, train_err, test_err = compute_importance_gbt(\n",
    "      mus_train, ys_train, mus_test, ys_test)\n",
    "  assert importance_matrix.shape[0] == mus_train.shape[0]\n",
    "  assert importance_matrix.shape[1] == ys_train.shape[0]\n",
    "  scores[\"informativeness_train\"] = train_err\n",
    "  scores[\"informativeness_test\"] = test_err\n",
    "  scores[\"disentanglement\"] = disentanglement(importance_matrix)\n",
    "  scores[\"completeness\"] = completeness(importance_matrix)\n",
    "  return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "358bed98-9010-4bdd-9ced-02bb3e0444c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import ensemble\n",
    "import numpy as np\n",
    "import torch\n",
    "import scipy\n",
    "\n",
    "data_path = \"../results/reps_CelebA_erm_0.9.pth\"\n",
    "\n",
    "reps = torch.load(data_path)\n",
    "\n",
    "def disentanglement_per_code(importance_matrix):\n",
    "  \"\"\"Compute disentanglement score of each code.\"\"\"\n",
    "  # importance_matrix is of shape [num_codes, num_factors].\n",
    "  return 1. - scipy.stats.entropy(importance_matrix.T + 1e-11,\n",
    "                                  base=importance_matrix.shape[1])\n",
    "\n",
    "def disentanglement(importance_matrix):\n",
    "  \"\"\"Compute the disentanglement score of the representation.\"\"\"\n",
    "  per_code = disentanglement_per_code(importance_matrix)\n",
    "  if importance_matrix.sum() == 0.:\n",
    "    importance_matrix = np.ones_like(importance_matrix)\n",
    "  code_importance = importance_matrix.sum(axis=1) / importance_matrix.sum()\n",
    "\n",
    "  return np.sum(per_code*code_importance)\n",
    "\n",
    "def compute_importance_gbt(x_train, y_train, x_test, y_test):\n",
    "  \"\"\"Compute importance based on gradient boosted trees.\"\"\"\n",
    "  print(x_train.shape,y_train.shape)\n",
    "  models = []\n",
    "  num_factors = y_train.shape[0]\n",
    "  num_codes = x_train.T.shape[0]\n",
    "  importance_matrix = np.zeros(shape=[num_codes, num_factors],\n",
    "                               dtype=np.float64)\n",
    "\n",
    "  train_loss = []\n",
    "  test_loss = []\n",
    "  for i in range(num_factors):\n",
    "    print(f\"Training for factor {i+1}\")\n",
    "    model = ensemble.GradientBoostingClassifier(verbose=1)\n",
    "    model.fit(x_train, y_train[i, :])\n",
    "    models.append(model)\n",
    "    importance_matrix[:, i] = np.abs(model.feature_importances_)\n",
    "    train_loss.append(np.mean(model.predict(x_train) == y_train[i, :]))\n",
    "    test_loss.append(np.mean(model.predict(x_test) == y_test[i, :]))\n",
    "  return importance_matrix, np.mean(train_loss), np.mean(test_loss), models\n",
    "\n",
    "def completeness_per_factor(importance_matrix):\n",
    "  \"\"\"Compute completeness of each factor.\"\"\"\n",
    "  # importance_matrix is of shape [num_codes, num_factors].\n",
    "  return 1. - scipy.stats.entropy(importance_matrix + 1e-11,\n",
    "                                  base=importance_matrix.shape[0])\n",
    "\n",
    "\n",
    "def completeness(importance_matrix):\n",
    "  \"\"\"\"Compute completeness of the representation.\"\"\"\n",
    "  per_factor = completeness_per_factor(importance_matrix)\n",
    "  if importance_matrix.sum() == 0.:\n",
    "    importance_matrix = np.ones_like(importance_matrix)\n",
    "  factor_importance = importance_matrix.sum(axis=0) / importance_matrix.sum()\n",
    "  return np.sum(per_factor*factor_importance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "f80c8797-6a3f-4ba0-bb80-87a423f94732",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = models[1].predict(reps['test']['x'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "cf29ea13-a24b-406f-92fe-c0e0d37e319d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9225851915236322"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(preds == np.array(reps['test']['g']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "fe3b7ce7-6aea-46cc-9c48-01904825c2a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10000])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reps['train']['y'][idxs].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "2c959081-7c10-48fa-87d5-9c60e35a235f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000,)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "36e8fe49-dbeb-4059-afe0-37b2c1155ce5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.978817919476531"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disentanglement(matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c64ed7aa-9803-4b5c-a78b-f11aa470989d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 19867])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "980c8ff5-b05e-43f2-8273-c0e68fe907cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results/reps_CUB_frz1_0.0.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9880            6.12m\n",
      "         2           0.9142            6.03m\n",
      "         3           0.8543            5.95m\n",
      "         4           0.8013            5.90m\n",
      "         5           0.7573            5.83m\n",
      "         6           0.7185            5.76m\n",
      "         7           0.6864            5.70m\n",
      "         8           0.6569            5.63m\n",
      "         9           0.6297            5.57m\n",
      "        10           0.6058            5.51m\n",
      "        20           0.4651            4.91m\n",
      "        30           0.3926            4.32m\n",
      "        40           0.3315            3.70m\n",
      "        50           0.2879            3.09m\n",
      "        60           0.2585            2.47m\n",
      "        70           0.2354            1.86m\n",
      "        80           0.2116            1.24m\n",
      "        90           0.1948           37.24s\n",
      "       100           0.1791            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2983            6.03m\n",
      "         2           1.2245            5.93m\n",
      "         3           1.1603            5.86m\n",
      "         4           1.1080            5.82m\n",
      "         5           1.0620            5.75m\n",
      "         6           1.0186            5.68m\n",
      "         7           0.9810            5.64m\n",
      "         8           0.9491            5.57m\n",
      "         9           0.9220            5.51m\n",
      "        10           0.8942            5.45m\n",
      "        20           0.7001            4.83m\n",
      "        30           0.5951            4.24m\n",
      "        40           0.5265            3.63m\n",
      "        50           0.4786            3.03m\n",
      "        60           0.4389            2.43m\n",
      "        70           0.4087            1.83m\n",
      "        80           0.3809            1.22m\n",
      "        90           0.3582           36.60s\n",
      "       100           0.3422            0.00s\n",
      "../results/reps_CUB_frz2_0.0.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9887            6.34m\n",
      "         2           0.9195            6.28m\n",
      "         3           0.8633            6.23m\n",
      "         4           0.8153            6.16m\n",
      "         5           0.7760            6.11m\n",
      "         6           0.7411            6.05m\n",
      "         7           0.7112            5.99m\n",
      "         8           0.6850            5.92m\n",
      "         9           0.6615            5.86m\n",
      "        10           0.6413            5.80m\n",
      "        20           0.5135            5.15m\n",
      "        30           0.4465            4.52m\n",
      "        40           0.3990            3.87m\n",
      "        50           0.3618            3.23m\n",
      "        60           0.3290            2.59m\n",
      "        70           0.3054            1.95m\n",
      "        80           0.2818            1.30m\n",
      "        90           0.2621           39.03s\n",
      "       100           0.2486            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2710            6.36m\n",
      "         2           1.1754            6.29m\n",
      "         3           1.0942            6.23m\n",
      "         4           1.0205            6.17m\n",
      "         5           0.9575            6.10m\n",
      "         6           0.9001            6.03m\n",
      "         7           0.8483            5.97m\n",
      "         8           0.8058            5.91m\n",
      "         9           0.7654            5.85m\n",
      "        10           0.7303            5.79m\n",
      "        20           0.5068            5.16m\n",
      "        30           0.4012            4.50m\n",
      "        40           0.3375            3.86m\n",
      "        50           0.2948            3.22m\n",
      "        60           0.2635            2.57m\n",
      "        70           0.2397            1.93m\n",
      "        80           0.2210            1.29m\n",
      "        90           0.2063           38.81s\n",
      "       100           0.1949            0.00s\n",
      "../results/reps_CUB_frz1_0.25.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.0090            6.48m\n",
      "         2           0.9369            6.35m\n",
      "         3           0.8793            6.29m\n",
      "         4           0.8305            6.22m\n",
      "         5           0.7891            6.14m\n",
      "         6           0.7556            6.07m\n",
      "         7           0.7248            6.00m\n",
      "         8           0.6977            5.93m\n",
      "         9           0.6741            5.85m\n",
      "        10           0.6525            5.79m\n",
      "        20           0.5203            5.16m\n",
      "        30           0.4497            4.53m\n",
      "        40           0.3978            3.88m\n",
      "        50           0.3615            3.25m\n",
      "        60           0.3300            2.61m\n",
      "        70           0.2992            1.96m\n",
      "        80           0.2825            1.31m\n",
      "        90           0.2635           39.32s\n",
      "       100           0.2499            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2596            6.37m\n",
      "         2           1.1698            6.34m\n",
      "         3           1.0919            6.27m\n",
      "         4           1.0260            6.19m\n",
      "         5           0.9669            6.11m\n",
      "         6           0.9165            6.05m\n",
      "         7           0.8707            5.98m\n",
      "         8           0.8282            5.92m\n",
      "         9           0.7928            5.85m\n",
      "        10           0.7579            5.77m\n",
      "        20           0.5441            5.14m\n",
      "        30           0.4396            4.49m\n",
      "        40           0.3729            3.85m\n",
      "        50           0.3269            3.21m\n",
      "        60           0.2929            2.57m\n",
      "        70           0.2675            1.93m\n",
      "        80           0.2479            1.29m\n",
      "        90           0.2333           38.75s\n",
      "       100           0.2198            0.00s\n",
      "../results/reps_CUB_frz2_0.25.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.0079            6.58m\n",
      "         2           0.9378            6.41m\n",
      "         3           0.8791            6.30m\n",
      "         4           0.8302            6.22m\n",
      "         5           0.7905            6.15m\n",
      "         6           0.7548            6.09m\n",
      "         7           0.7240            6.02m\n",
      "         8           0.6962            5.95m\n",
      "         9           0.6722            5.89m\n",
      "        10           0.6502            5.84m\n",
      "        20           0.5181            5.17m\n",
      "        30           0.4470            4.52m\n",
      "        40           0.3963            3.88m\n",
      "        50           0.3562            3.24m\n",
      "        60           0.3284            2.59m\n",
      "        70           0.3015            1.95m\n",
      "        80           0.2788            1.30m\n",
      "        90           0.2564           39.02s\n",
      "       100           0.2420            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2596            6.44m\n",
      "         2           1.1681            6.33m\n",
      "         3           1.0888            6.26m\n",
      "         4           1.0221            6.20m\n",
      "         5           0.9638            6.12m\n",
      "         6           0.9126            6.09m\n",
      "         7           0.8652            6.02m\n",
      "         8           0.8250            5.98m\n",
      "         9           0.7861            5.91m\n",
      "        10           0.7519            5.85m\n",
      "        20           0.5360            5.17m\n",
      "        30           0.4319            4.52m\n",
      "        40           0.3654            3.87m\n",
      "        50           0.3218            3.22m\n",
      "        60           0.2891            2.58m\n",
      "        70           0.2645            1.93m\n",
      "        80           0.2433            1.29m\n",
      "        90           0.2287           38.86s\n",
      "       100           0.2153            0.00s\n",
      "../results/reps_CUB_frz1_0.5.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9777            6.39m\n",
      "         2           0.9095            6.33m\n",
      "         3           0.8539            6.27m\n",
      "         4           0.8075            6.19m\n",
      "         5           0.7682            6.20m\n",
      "         6           0.7347            6.13m\n",
      "         7           0.7064            6.06m\n",
      "         8           0.6811            5.98m\n",
      "         9           0.6577            5.93m\n",
      "        10           0.6362            5.86m\n",
      "        20           0.5106            5.18m\n",
      "        30           0.4407            4.53m\n",
      "        40           0.3900            3.89m\n",
      "        50           0.3456            3.24m\n",
      "        60           0.3157            2.60m\n",
      "        70           0.2877            1.95m\n",
      "        80           0.2687            1.30m\n",
      "        90           0.2513           39.09s\n",
      "       100           0.2332            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2061            6.39m\n",
      "         2           1.1220            6.30m\n",
      "         3           1.0508            6.23m\n",
      "         4           0.9925            6.17m\n",
      "         5           0.9387            6.11m\n",
      "         6           0.8925            6.05m\n",
      "         7           0.8509            5.99m\n",
      "         8           0.8158            5.95m\n",
      "         9           0.7831            5.89m\n",
      "        10           0.7526            5.83m\n",
      "        20           0.5616            5.16m\n",
      "        30           0.4651            4.51m\n",
      "        40           0.4013            3.86m\n",
      "        50           0.3535            3.22m\n",
      "        60           0.3175            2.57m\n",
      "        70           0.2881            1.93m\n",
      "        80           0.2658            1.29m\n",
      "        90           0.2462           38.76s\n",
      "       100           0.2301            0.00s\n",
      "../results/reps_CUB_frz2_0.5.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9787            6.39m\n",
      "         2           0.9091            6.30m\n",
      "         3           0.8540            6.23m\n",
      "         4           0.8073            6.16m\n",
      "         5           0.7681            6.10m\n",
      "         6           0.7333            6.05m\n",
      "         7           0.7037            6.02m\n",
      "         8           0.6773            5.96m\n",
      "         9           0.6548            5.89m\n",
      "        10           0.6338            5.82m\n",
      "        20           0.5078            5.15m\n",
      "        30           0.4383            4.51m\n",
      "        40           0.3852            3.88m\n",
      "        50           0.3487            3.24m\n",
      "        60           0.3202            2.59m\n",
      "        70           0.2934            1.95m\n",
      "        80           0.2718            1.30m\n",
      "        90           0.2552           39.05s\n",
      "       100           0.2413            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.2057            6.35m\n",
      "         2           1.1218            6.30m\n",
      "         3           1.0513            6.24m\n",
      "         4           0.9893            6.25m\n",
      "         5           0.9360            6.17m\n",
      "         6           0.8887            6.09m\n",
      "         7           0.8472            6.02m\n",
      "         8           0.8105            5.94m\n",
      "         9           0.7788            5.87m\n",
      "        10           0.7488            5.80m\n",
      "        20           0.5620            5.15m\n",
      "        30           0.4640            4.50m\n",
      "        40           0.4013            3.85m\n",
      "        50           0.3553            3.21m\n",
      "        60           0.3180            2.57m\n",
      "        70           0.2888            1.93m\n",
      "        80           0.2648            1.29m\n",
      "        90           0.2453           38.71s\n",
      "       100           0.2300            0.00s\n",
      "../results/reps_CUB_frz1_0.75.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9998            6.35m\n",
      "         2           0.9276            6.30m\n",
      "         3           0.8711            6.25m\n",
      "         4           0.8229            6.19m\n",
      "         5           0.7820            6.13m\n",
      "         6           0.7476            6.06m\n",
      "         7           0.7175            6.03m\n",
      "         8           0.6908            5.98m\n",
      "         9           0.6673            5.91m\n",
      "        10           0.6459            5.83m\n",
      "        20           0.5132            5.17m\n",
      "        30           0.4409            4.52m\n",
      "        40           0.3887            3.88m\n",
      "        50           0.3487            3.23m\n",
      "        60           0.3134            2.59m\n",
      "        70           0.2886            1.94m\n",
      "        80           0.2658            1.30m\n",
      "        90           0.2467           38.94s\n",
      "       100           0.2299            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.1350            6.46m\n",
      "         2           1.0650            6.33m\n",
      "         3           1.0057            6.30m\n",
      "         4           0.9560            6.22m\n",
      "         5           0.9119            6.24m\n",
      "         6           0.8742            6.18m\n",
      "         7           0.8412            6.12m\n",
      "         8           0.8107            6.03m\n",
      "         9           0.7825            5.97m\n",
      "        10           0.7575            5.90m\n",
      "        20           0.5940            5.21m\n",
      "        30           0.5087            4.55m\n",
      "        40           0.4511            3.90m\n",
      "        50           0.4063            3.25m\n",
      "        60           0.3735            2.60m\n",
      "        70           0.3466            1.95m\n",
      "        80           0.3266            1.30m\n",
      "        90           0.3053           39.06s\n",
      "       100           0.2877            0.00s\n",
      "../results/reps_CUB_frz2_0.75.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9989            6.46m\n",
      "         2           0.9285            6.35m\n",
      "         3           0.8713            6.30m\n",
      "         4           0.8235            6.20m\n",
      "         5           0.7840            6.12m\n",
      "         6           0.7489            6.04m\n",
      "         7           0.7183            5.96m\n",
      "         8           0.6912            5.89m\n",
      "         9           0.6672            5.82m\n",
      "        10           0.6464            5.76m\n",
      "        20           0.5181            5.10m\n",
      "        30           0.4396            4.45m\n",
      "        40           0.3883            3.84m\n",
      "        50           0.3466            3.20m\n",
      "        60           0.3126            2.57m\n",
      "        70           0.2851            1.93m\n",
      "        80           0.2597            1.29m\n",
      "        90           0.2413           38.67s\n",
      "       100           0.2223            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.1370            6.29m\n",
      "         2           1.0688            6.23m\n",
      "         3           1.0125            6.18m\n",
      "         4           0.9635            6.30m\n",
      "         5           0.9224            6.27m\n",
      "         6           0.8830            6.30m\n",
      "         7           0.8520            6.26m\n",
      "         8           0.8221            6.17m\n",
      "         9           0.7962            6.07m\n",
      "        10           0.7721            5.98m\n",
      "        20           0.6167            5.21m\n",
      "        30           0.5304            4.54m\n",
      "        40           0.4747            3.89m\n",
      "        50           0.4325            3.23m\n",
      "        60           0.4007            2.59m\n",
      "        70           0.3715            1.94m\n",
      "        80           0.3473            1.29m\n",
      "        90           0.3261           38.82s\n",
      "       100           0.3082            0.00s\n",
      "../results/reps_CUB_frz1_0.9.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9992            6.35m\n",
      "         2           0.9330            6.28m\n",
      "         3           0.8800            6.20m\n",
      "         4           0.8349            6.14m\n",
      "         5           0.7963            6.11m\n",
      "         6           0.7617            6.05m\n",
      "         7           0.7309            5.98m\n",
      "         8           0.7051            5.91m\n",
      "         9           0.6805            5.84m\n",
      "        10           0.6587            5.78m\n",
      "        20           0.5248            5.13m\n",
      "        30           0.4456            4.49m\n",
      "        40           0.3935            3.85m\n",
      "        50           0.3529            3.21m\n",
      "        60           0.3181            2.57m\n",
      "        70           0.2912            1.94m\n",
      "        80           0.2643            1.29m\n",
      "        90           0.2445           38.70s\n",
      "       100           0.2274            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.0593            6.36m\n",
      "         2           0.9945            6.27m\n",
      "         3           0.9393            6.21m\n",
      "         4           0.8933            6.31m\n",
      "         5           0.8545            6.26m\n",
      "         6           0.8215            6.16m\n",
      "         7           0.7925            6.07m\n",
      "         8           0.7636            5.99m\n",
      "         9           0.7404            5.91m\n",
      "        10           0.7186            5.83m\n",
      "        20           0.5775            5.14m\n",
      "        30           0.5007            4.50m\n",
      "        40           0.4467            3.85m\n",
      "        50           0.4014            3.21m\n",
      "        60           0.3676            2.57m\n",
      "        70           0.3375            1.93m\n",
      "        80           0.3142            1.29m\n",
      "        90           0.2902           38.68s\n",
      "       100           0.2742            0.00s\n",
      "../results/reps_CUB_frz2_0.9.pth\n",
      "torch.Size([4795, 2048]) (2, 4795)\n",
      "Training for factor 1\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           0.9984            6.32m\n",
      "         2           0.9334            6.28m\n",
      "         3           0.8795            6.22m\n",
      "         4           0.8340            6.15m\n",
      "         5           0.7937            6.09m\n",
      "         6           0.7586            6.03m\n",
      "         7           0.7283            5.96m\n",
      "         8           0.7015            5.90m\n",
      "         9           0.6765            5.83m\n",
      "        10           0.6538            5.77m\n",
      "        20           0.5162            5.15m\n",
      "        30           0.4397            4.50m\n",
      "        40           0.3883            3.86m\n",
      "        50           0.3476            3.22m\n",
      "        60           0.3131            2.58m\n",
      "        70           0.2864            1.94m\n",
      "        80           0.2629            1.30m\n",
      "        90           0.2423           39.02s\n",
      "       100           0.2245            0.00s\n",
      "Training for factor 2\n",
      "      Iter       Train Loss   Remaining Time \n",
      "         1           1.0602            6.32m\n",
      "         2           0.9929            6.26m\n",
      "         3           0.9371            6.20m\n",
      "         4           0.8908            6.14m\n",
      "         5           0.8501            6.09m\n",
      "         6           0.8161            6.03m\n",
      "         7           0.7852            5.97m\n",
      "         8           0.7586            5.90m\n",
      "         9           0.7351            5.84m\n",
      "        10           0.7129            5.77m\n",
      "        20           0.5674            5.15m\n",
      "        30           0.4902            4.50m\n",
      "        40           0.4375            3.86m\n",
      "        50           0.3933            3.22m\n",
      "        60           0.3569            2.57m\n",
      "        70           0.3287            1.93m\n",
      "        80           0.3050            1.29m\n",
      "        90           0.2844           38.72s\n",
      "       100           0.2654            0.00s\n"
     ]
    }
   ],
   "source": [
    "from os.path import join\n",
    "\n",
    "root = \"../results\"\n",
    "# Number of training rows to subsample for each dataset.\n",
    "n = {\"CelebA\": 10000,\n",
    "     'MNISTCIFAR': 10000,\n",
    "     \"CUB\": 4795}\n",
    "# `spur` values to iterate over for each dataset; they are part of the file names.\n",
    "spurs = {\"CelebA\": [0.9],\n",
    "         'MNISTCIFAR': [0.0, 0.25, 0.5, 0.75, 0.9],\n",
    "         \"CUB\": [0.0, 0.25, 0.5, 0.75, 0.9]}\n",
    "\n",
    "# Seeded generator so the random subsample below is identical on every run.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "for dataset in ['CUB']:  # alternative: ['MNISTCIFAR','CUB',\"CelebA\"]\n",
    "    for spur in spurs[dataset]:\n",
    "        for method in [\"frz1\", \"frz2\"]:  # alternative: [\"erm\",\"gdro\",\"rw\"]\n",
    "            data_path = join(root, f\"reps_{dataset}_{method}_{spur}.pth\")\n",
    "            print(data_path)\n",
    "            reps = torch.load(data_path)\n",
    "            filename = join(root, f\"imps_{method}_{dataset}_{spur}\")\n",
    "            # Draw n[dataset] distinct training indices (sampling without replacement).\n",
    "            idxs = rng.choice(len(reps['train']['x']), size=n[dataset], replace=False)\n",
    "            # Stack the two targets as rows (row 0: 'y', row 1: 'g'), then convert to NumPy.\n",
    "            y_train = torch.cat([reps['train']['y'][idxs].unsqueeze(dim=0),\n",
    "                                 reps['train']['g'][idxs].unsqueeze(dim=0)], dim=0)\n",
    "            y_test = torch.cat([reps['test']['y'].unsqueeze(dim=0),\n",
    "                                reps['test']['g'].unsqueeze(dim=0)], dim=0)\n",
    "            y_train = np.array(y_train)\n",
    "            y_test = np.array(y_test)\n",
    "            matrix, train_loss, test_loss, models = compute_importance_gbt(\n",
    "                reps['train']['x'][idxs], y_train, reps['test']['x'], y_test)\n",
    "            dis = disentanglement(matrix)\n",
    "            comp = completeness(matrix)\n",
    "            # Persist the scores, the importance matrix and the fitted models together.\n",
    "            np.savez(f\"{filename}.npz\",\n",
    "                     disentanglement=dis,\n",
    "                     completeness=comp,\n",
    "                     matrix=matrix,\n",
    "                     train_loss=train_loss,\n",
    "                     test_loss=test_loss,\n",
    "                     models=models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcd35d72-c5be-44fe-913e-5db354982ccf",
   "metadata": {},
   "source": [
    "## DCI Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "id": "a0e26504-2b73-4462-972c-b7cd71ee9e7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results/imps_erm_MNISTCIFAR_0.0.npz\n",
      "../results/imps_rw_MNISTCIFAR_0.0.npz\n",
      "../results/imps_gdro_MNISTCIFAR_0.0.npz\n",
      "../results/imps_unfreeze_MNISTCIFAR_0.0.npz\n",
      "../results/imps_frz1_MNISTCIFAR_0.0.npz\n",
      "../results/imps_frz2_MNISTCIFAR_0.0.npz\n",
      "../results/imps_frz3_MNISTCIFAR_0.0.npz\n",
      "../results/imps_frz4_MNISTCIFAR_0.0.npz\n",
      "../results/imps_erm_MNISTCIFAR_0.25.npz\n",
      "../results/imps_rw_MNISTCIFAR_0.25.npz\n",
      "../results/imps_gdro_MNISTCIFAR_0.25.npz\n",
      "../results/imps_unfreeze_MNISTCIFAR_0.25.npz\n",
      "../results/imps_frz1_MNISTCIFAR_0.25.npz\n",
      "../results/imps_frz2_MNISTCIFAR_0.25.npz\n",
      "../results/imps_frz3_MNISTCIFAR_0.25.npz\n",
      "../results/imps_frz4_MNISTCIFAR_0.25.npz\n",
      "../results/imps_erm_MNISTCIFAR_0.5.npz\n",
      "../results/imps_rw_MNISTCIFAR_0.5.npz\n",
      "../results/imps_gdro_MNISTCIFAR_0.5.npz\n",
      "../results/imps_unfreeze_MNISTCIFAR_0.5.npz\n",
      "../results/imps_frz1_MNISTCIFAR_0.5.npz\n",
      "../results/imps_frz2_MNISTCIFAR_0.5.npz\n",
      "../results/imps_frz3_MNISTCIFAR_0.5.npz\n",
      "../results/imps_frz4_MNISTCIFAR_0.5.npz\n",
      "../results/imps_erm_MNISTCIFAR_0.75.npz\n",
      "../results/imps_rw_MNISTCIFAR_0.75.npz\n",
      "../results/imps_gdro_MNISTCIFAR_0.75.npz\n",
      "../results/imps_unfreeze_MNISTCIFAR_0.75.npz\n",
      "../results/imps_frz1_MNISTCIFAR_0.75.npz\n",
      "../results/imps_frz2_MNISTCIFAR_0.75.npz\n",
      "../results/imps_frz3_MNISTCIFAR_0.75.npz\n",
      "../results/imps_frz4_MNISTCIFAR_0.75.npz\n",
      "../results/imps_erm_MNISTCIFAR_0.9.npz\n",
      "../results/imps_rw_MNISTCIFAR_0.9.npz\n",
      "../results/imps_gdro_MNISTCIFAR_0.9.npz\n",
      "../results/imps_unfreeze_MNISTCIFAR_0.9.npz\n",
      "../results/imps_frz1_MNISTCIFAR_0.9.npz\n",
      "../results/imps_frz2_MNISTCIFAR_0.9.npz\n",
      "../results/imps_frz3_MNISTCIFAR_0.9.npz\n",
      "../results/imps_frz4_MNISTCIFAR_0.9.npz\n",
      "../results/imps_erm_CUB_0.0.npz\n",
      "../results/imps_rw_CUB_0.0.npz\n",
      "../results/imps_gdro_CUB_0.0.npz\n",
      "../results/imps_unfreeze_CUB_0.0.npz\n",
      "../results/imps_frz1_CUB_0.0.npz\n",
      "../results/imps_frz2_CUB_0.0.npz\n",
      "../results/imps_frz3_CUB_0.0.npz\n",
      "../results/imps_frz4_CUB_0.0.npz\n",
      "../results/imps_erm_CUB_0.25.npz\n",
      "../results/imps_rw_CUB_0.25.npz\n",
      "../results/imps_gdro_CUB_0.25.npz\n",
      "../results/imps_unfreeze_CUB_0.25.npz\n",
      "../results/imps_frz1_CUB_0.25.npz\n",
      "../results/imps_frz2_CUB_0.25.npz\n",
      "../results/imps_frz3_CUB_0.25.npz\n",
      "../results/imps_frz4_CUB_0.25.npz\n",
      "../results/imps_erm_CUB_0.5.npz\n",
      "../results/imps_rw_CUB_0.5.npz\n",
      "../results/imps_gdro_CUB_0.5.npz\n",
      "../results/imps_unfreeze_CUB_0.5.npz\n",
      "../results/imps_frz1_CUB_0.5.npz\n",
      "../results/imps_frz2_CUB_0.5.npz\n",
      "../results/imps_frz3_CUB_0.5.npz\n",
      "../results/imps_frz4_CUB_0.5.npz\n",
      "../results/imps_erm_CUB_0.75.npz\n",
      "../results/imps_rw_CUB_0.75.npz\n",
      "../results/imps_gdro_CUB_0.75.npz\n",
      "../results/imps_unfreeze_CUB_0.75.npz\n",
      "../results/imps_frz1_CUB_0.75.npz\n",
      "../results/imps_frz2_CUB_0.75.npz\n",
      "../results/imps_frz3_CUB_0.75.npz\n",
      "../results/imps_frz4_CUB_0.75.npz\n",
      "../results/imps_erm_CUB_0.9.npz\n",
      "../results/imps_rw_CUB_0.9.npz\n",
      "../results/imps_gdro_CUB_0.9.npz\n",
      "../results/imps_unfreeze_CUB_0.9.npz\n",
      "../results/imps_frz1_CUB_0.9.npz\n",
      "../results/imps_frz2_CUB_0.9.npz\n",
      "../results/imps_frz3_CUB_0.9.npz\n",
      "../results/imps_frz4_CUB_0.9.npz\n",
      "../results/imps_erm_CelebA_0.9.npz\n",
      "../results/imps_rw_CelebA_0.9.npz\n",
      "../results/imps_gdro_CelebA_0.9.npz\n",
      "../results/imps_unfreeze_CelebA_0.9.npz\n",
      "../results/imps_frz1_CelebA_0.9.npz\n",
      "../results/imps_frz2_CelebA_0.9.npz\n",
      "../results/imps_frz3_CelebA_0.9.npz\n",
      "../results/imps_frz4_CelebA_0.9.npz\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from os.path import exists\n",
    "\n",
    "# One record per results file found on disk; turned into a DataFrame afterwards.\n",
    "data_list = []\n",
    "for dataset in ['MNISTCIFAR','CUB',\"CelebA\"]:\n",
    "    for spur in spurs[dataset]:\n",
    "        for method in [\"erm\",\"rw\",\"gdro\",\"unfreeze\", \"frz1\",\"frz2\",\"frz3\",\"frz4\"]:\n",
    "            filename = join(root,f\"imps_{method}_{dataset}_{spur}.npz\")\n",
    "            print(filename)\n",
    "            if not exists(filename):\n",
    "                # Combinations without a saved results file are skipped.\n",
    "                continue\n",
    "            # The archive keeps a file handle open; the with-block closes it\n",
    "            # as soon as the scalar scores have been read.\n",
    "            with np.load(filename) as npz:\n",
    "                # float(...) turns the stored array scalars into plain Python\n",
    "                # floats so the DataFrame columns are numeric.\n",
    "                data_list.append({\n",
    "                    'method': method,\n",
    "                    'dataset': dataset,\n",
    "                    'spur': float(spur),\n",
    "                    'disentanglement': float(npz['disentanglement']),\n",
    "                    'completeness': float(npz['completeness']),\n",
    "                    'train_loss': float(npz['train_loss']),\n",
    "                    'test_loss': float(npz['test_loss']),\n",
    "                })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "id": "2edc0330-19c6-478e-8c5d-ef5959cffb7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      method     dataset  spur      disentanglement         completeness  \\\n",
      "0        erm  MNISTCIFAR  0.00   0.9998801836004187   0.7670580188353394   \n",
      "1         rw  MNISTCIFAR  0.00   0.9999985294924588   0.7699469909819328   \n",
      "2       gdro  MNISTCIFAR  0.00   0.9996722370216437   0.7198756810420746   \n",
      "3        erm  MNISTCIFAR  0.25   0.9969461414565524   0.7366528493466923   \n",
      "4         rw  MNISTCIFAR  0.25   0.9982140213706219   0.7387641546096028   \n",
      "5       gdro  MNISTCIFAR  0.25   0.9979602305719759   0.7186259987877207   \n",
      "6        erm  MNISTCIFAR  0.50   0.8701771108087004   0.7485828434561188   \n",
      "7         rw  MNISTCIFAR  0.50   0.8927225632096434   0.7631101552763799   \n",
      "8       gdro  MNISTCIFAR  0.50   0.8932000233279621   0.7596261365244018   \n",
      "9        erm  MNISTCIFAR  0.75    0.405939933846305   0.8088998555409826   \n",
      "10        rw  MNISTCIFAR  0.75  0.29399359120532387   0.8360102522170587   \n",
      "11      gdro  MNISTCIFAR  0.75   0.4106810014643423   0.7188992994178913   \n",
      "12       erm  MNISTCIFAR  0.90  0.14329458811176515   0.8994688347948134   \n",
      "13        rw  MNISTCIFAR  0.90  0.28014388313439254   0.8141446246739557   \n",
      "14      gdro  MNISTCIFAR  0.90   0.2457719901596822   0.8364723185824245   \n",
      "15       erm         CUB  0.00   0.9814835220609995   0.5454469520176699   \n",
      "16        rw         CUB  0.00   0.9569092702838601  0.48394851354277557   \n",
      "17      gdro         CUB  0.00   0.9719528918943618   0.5275774900626482   \n",
      "18  unfreeze         CUB  0.00   0.9777928655500363   0.5949159750709458   \n",
      "19      frz1         CUB  0.00   0.9747710295749547   0.5424163219898255   \n",
      "20      frz2         CUB  0.00    0.977952051450252   0.5452595447189297   \n",
      "21      frz3         CUB  0.00   0.9848125663142442   0.5407857098900548   \n",
      "22      frz4         CUB  0.00   0.9395406598543063   0.5821368941217197   \n",
      "23       erm         CUB  0.25   0.9584689239617199   0.5344604975121883   \n",
      "24        rw         CUB  0.25   0.9547691303495897  0.45515680315691787   \n",
      "25      gdro         CUB  0.25   0.9241449761278479   0.4912713178964504   \n",
      "26  unfreeze         CUB  0.25   0.9569626264992802   0.5526298231610884   \n",
      "27      frz1         CUB  0.25   0.9729571648900551   0.5362351580110231   \n",
      "28      frz2         CUB  0.25   0.9691210032760884    0.530479957895551   \n",
      "29      frz3         CUB  0.25   0.9708427956010193   0.5335210937739683   \n",
      "30      frz4         CUB  0.25   0.9576957820321024   0.5764754189379051   \n",
      "31       erm         CUB  0.50   0.9020486238092311   0.5288651638212996   \n",
      "32        rw         CUB  0.50   0.8566692308110173    0.443136069460227   \n",
      "33      gdro         CUB  0.50   0.8913559139910819   0.5131165699136826   \n",
      "34  unfreeze         CUB  0.50   0.9310770362511457   0.5435130007933242   \n",
      "35      frz1         CUB  0.50   0.8918233621002882   0.5124518556663388   \n",
      "36      frz2         CUB  0.50   0.8853009100493267   0.5128039200158241   \n",
      "37      frz3         CUB  0.50   0.8825915009629379   0.5204988084854497   \n",
      "38      frz4         CUB  0.50   0.9279838590365495   0.5319150257974541   \n",
      "39       erm         CUB  0.75   0.7505892331287652   0.5129925795338759   \n",
      "40        rw         CUB  0.75   0.6561315988757279  0.45780427831232284   \n",
      "41      gdro         CUB  0.75   0.6970978327787012  0.48208483258749857   \n",
      "42  unfreeze         CUB  0.75    0.739159928812186  0.48982495621148237   \n",
      "43      frz1         CUB  0.75   0.7033109447034495  0.49621527327721765   \n",
      "44      frz2         CUB  0.75    0.746388410691055  0.49141149985231325   \n",
      "45      frz3         CUB  0.75   0.7585651694244104   0.4875937018518813   \n",
      "46      frz4         CUB  0.75   0.7047362467462107  0.49216744821840264   \n",
      "47       erm         CUB  0.90   0.4002080950023117    0.478453522546868   \n",
      "48        rw         CUB  0.90   0.4598517018717483   0.4303953701865538   \n",
      "49      gdro         CUB  0.90  0.29604631804405523  0.40688057324229165   \n",
      "50  unfreeze         CUB  0.90  0.47613625369052104   0.5238741021882181   \n",
      "51      frz1         CUB  0.90  0.49256117726732906   0.4572096488502839   \n",
      "52      frz2         CUB  0.90  0.45791963633980803   0.4621589293979554   \n",
      "53      frz3         CUB  0.90  0.48682336928425085   0.4642961551724103   \n",
      "54      frz4         CUB  0.90   0.3047869732555881   0.5287215928292461   \n",
      "55       erm      CelebA  0.90   0.9605967886025125   0.5610206996162619   \n",
      "56        rw      CelebA  0.90   0.9179635729021276   0.5375095101508569   \n",
      "57      gdro      CelebA  0.90   0.9400285258663984   0.5412908274943793   \n",
      "\n",
      "            train_loss           test_loss  \n",
      "0               0.9701              0.9525  \n",
      "1   0.9864999999999999  0.9670000000000001  \n",
      "2               0.9695              0.9545  \n",
      "3               0.9701  0.9450000000000001  \n",
      "4               0.9855               0.962  \n",
      "5              0.97125              0.9495  \n",
      "6   0.9731000000000001              0.9325  \n",
      "7   0.9875499999999999  0.9550000000000001  \n",
      "8              0.97475  0.9430000000000001  \n",
      "9               0.9777               0.881  \n",
      "10              0.9856              0.9165  \n",
      "11  0.9824999999999999              0.9065  \n",
      "12             0.98695               0.831  \n",
      "13  0.9892000000000001              0.8485  \n",
      "14              0.9891  0.8494999999999999  \n",
      "15   0.967778936392075  0.9199332777314428  \n",
      "16  0.9623566214807091  0.8861551292743953  \n",
      "17  0.9641293013555787  0.9236864053377816  \n",
      "18  0.9710114702815433  0.9303586321934946  \n",
      "19   0.964754953076121  0.9036697247706422  \n",
      "20  0.9692387904066736   0.920767306088407  \n",
      "21  0.9699687174139728  0.9157631359466222  \n",
      "22  0.9711157455683004  0.9261884904086739  \n",
      "23  0.9651720542231491  0.9228523769808173  \n",
      "24  0.9622523461939521  0.9036697247706422  \n",
      "25  0.9421272158498436  0.8452877397831526  \n",
      "26  0.9711157455683004  0.9153461217681401  \n",
      "27  0.9675703858185609  0.9224353628023353  \n",
      "28  0.9684045881126173  0.9216013344453711  \n",
      "29  0.9696558915537017  0.9220183486238531  \n",
      "30  0.9706986444212722  0.9174311926605505  \n",
      "31  0.9678832116788321  0.8798999165971644  \n",
      "32  0.9564129301355578   0.853628023352794  \n",
      "33  0.9654848800834201  0.8861551292743953  \n",
      "34  0.9724713242961418  0.8882402001668057  \n",
      "35   0.970281543274244  0.8786488740617181  \n",
      "36  0.9701772679874869  0.8711426188490409  \n",
      "37  0.9701772679874869  0.8790658882402002  \n",
      "38  0.9717413972888426  0.8886572143452878  \n",
      "39  0.9610010427528676  0.8557130942452043  \n",
      "40  0.9404588112617309  0.8185988323603003  \n",
      "41  0.9610010427528676  0.8552960800667222  \n",
      "42  0.9655891553701773  0.8527939949958299  \n",
      "43  0.9651720542231491  0.8507089241034195  \n",
      "44  0.9644421272158499  0.8494578815679733  \n",
      "45  0.9653806047966631  0.8498748957464555  \n",
      "46  0.9650677789363921  0.8511259382819016  \n",
      "47  0.9597497393117831  0.8052543786488742  \n",
      "48   0.946089676746611  0.7585487906588824  \n",
      "49  0.9385818561001043  0.7793994995829858  \n",
      "50  0.9716371220020854  0.8015012510425354  \n",
      "51  0.9672575599582899  0.7973311092577148  \n",
      "52  0.9691345151199166  0.8035863219349457  \n",
      "53  0.9665276329509906  0.8098415346121768  \n",
      "54  0.9742440041710114  0.8035863219349458  \n",
      "55             0.97145  0.9366537474203454  \n",
      "56             0.96295  0.9289273669904867  \n",
      "57             0.96245  0.9283485176423214  \n"
     ]
    }
   ],
   "source": [
    "# Collect the per-file records into one table (one row per record) and show it.\n",
    "data = pd.DataFrame(data=data_list)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9c484b98-3b02-48c7-be03-5d353cfe2d8d",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'data' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mdata\u001b[49m\n\u001b[1;32m      4\u001b[0m metric_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisentanglement\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m      5\u001b[0m metrics \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdisentanglement\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcompleteness\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_loss\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "df = data\n",
    "\n",
    "# Grid layout: one row per metric, one column per dataset.\n",
    "metrics = ['disentanglement', 'completeness', 'train_loss', 'test_loss']\n",
    "datasets = ['MNISTCIFAR','CUB']\n",
    "n_rows = len(metrics)\n",
    "n_cols = len(datasets)\n",
    "methods  = ['erm', 'frz1', 'frz2', 'frz3', 'frz4']  # alternative: ['erm', 'rw', 'gdro']\n",
    "# figsize is (width, height): width scales with the columns, height with the rows.\n",
    "# squeeze=False keeps axs two-dimensional even with a single metric or dataset.\n",
    "fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(n_cols*5, n_rows*3), squeeze=False)\n",
    "for j, metric_name in enumerate(metrics):\n",
    "    for i, dataset in enumerate(datasets):\n",
    "        ax = axs[j, i]\n",
    "        print(dataset)\n",
    "        dataset_df = df[(df['dataset'] == dataset) & (df['method'].isin(methods))]\n",
    "        \n",
    "        # Titles and axis labels are currently disabled:\n",
    "        #ax.set_title(f'{metric_name.capitalize()} - Dataset: {dataset}')\n",
    "        #ax.set_xlabel('Level of Correlation Between Spurious and Class Label')\n",
    "        #ax.set_ylabel(metric_name.capitalize())\n",
    "        \n",
    "        # One line per method; sort by spur so each line runs left to right.\n",
    "        for method in dataset_df['method'].unique():\n",
    "            method_df = dataset_df[dataset_df['method'] == method].sort_values('spur')\n",
    "            ax.plot(method_df['spur'], method_df[metric_name], label=method)\n",
    "        # Put the x-ticks exactly on the spur values present for this dataset.\n",
    "        x_ticks = sorted(dataset_df['spur'].unique())\n",
    "        ax.set_xticks(x_ticks)\n",
    "        ax.set_xticklabels([str(xt) for xt in x_ticks], rotation=0)  # labels kept horizontal\n",
    "       \n",
    "        ax.legend()\n",
    "        ax.grid(True)\n",
    "        # Save this subplot on its own: take its tight bounding box (in inches)\n",
    "        # and enlarge it so tick labels and legend are not clipped.\n",
    "        extent = ax.get_tightbbox(fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())\n",
    "        x_pad = 0.05  # relative padding along the x-axis\n",
    "        y_pad = 0.5  # relative padding along the y-axis\n",
    "        extent = extent.expanded(1.0 + x_pad, 1.0 + y_pad)\n",
    "        fig.savefig(f'subplot_{metric_name}_{dataset}.png', bbox_inches=extent)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "4220314e-b490-4927-8409-b369860b67b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/alain/Data/Tesis/svdrop/disentanglement_lib\n"
     ]
    }
   ],
   "source": [
    "# Print the current working directory; relative paths such as \"../results\" are resolved against it.\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "104fd170-176a-45f4-b978-462f3207b50a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>method</th>\n",
       "      <th>dataset</th>\n",
       "      <th>spur</th>\n",
       "      <th>disentanglement</th>\n",
       "      <th>completeness</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>test_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>erm</td>\n",
       "      <td>CelebA</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.9605967886025125</td>\n",
       "      <td>0.5610206996162619</td>\n",
       "      <td>0.97145</td>\n",
       "      <td>0.9366537474203454</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>gdro</td>\n",
       "      <td>CelebA</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.9400285258663984</td>\n",
       "      <td>0.5412908274943793</td>\n",
       "      <td>0.96245</td>\n",
       "      <td>0.9283485176423214</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>rw</td>\n",
       "      <td>CelebA</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.9179635729021276</td>\n",
       "      <td>0.5375095101508569</td>\n",
       "      <td>0.96295</td>\n",
       "      <td>0.9289273669904867</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   method dataset  spur     disentanglement        completeness train_loss  \\\n",
       "30    erm  CelebA   0.9  0.9605967886025125  0.5610206996162619    0.97145   \n",
       "31   gdro  CelebA   0.9  0.9400285258663984  0.5412908274943793    0.96245   \n",
       "32     rw  CelebA   0.9  0.9179635729021276  0.5375095101508569    0.96295   \n",
       "\n",
       "             test_loss  \n",
       "30  0.9366537474203454  \n",
       "31  0.9283485176423214  \n",
       "32  0.9289273669904867  "
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display only the rows that belong to the CelebA dataset.\n",
    "is_celeba = df['dataset'] == \"CelebA\"\n",
    "df.loc[is_celeba]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "7a7d3903-3d24-4d79-97c9-a55c936ab33a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllllllllllll}\n",
      "\\toprule\n",
      " & \\multicolumn{3}{r}{disentanglement} & \\multicolumn{3}{r}{completeness} & \\multicolumn{3}{r}{train_loss} & \\multicolumn{3}{r}{test_loss} \\\\\n",
      "method & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw \\\\\n",
      "spur &  &  &  &  &  &  &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "0.00 & 1.000 & 1.000 & 1.000 & 0.767 & 0.720 & 0.770 & 0.97 & 0.97 & 0.99 & 0.95 & 0.95 & 0.97 \\\\\n",
      "0.25 & 0.997 & 0.998 & 0.998 & 0.737 & 0.719 & 0.739 & 0.97 & 0.97 & 0.99 & 0.95 & 0.95 & 0.96 \\\\\n",
      "0.50 & 0.870 & 0.893 & 0.893 & 0.749 & 0.760 & 0.763 & 0.97 & 0.97 & 0.99 & 0.93 & 0.94 & 0.96 \\\\\n",
      "0.75 & 0.406 & 0.411 & 0.294 & 0.809 & 0.719 & 0.836 & 0.98 & 0.98 & 0.99 & 0.88 & 0.91 & 0.92 \\\\\n",
      "0.90 & 0.143 & 0.246 & 0.280 & 0.899 & 0.836 & 0.814 & 0.99 & 0.99 & 0.99 & 0.83 & 0.85 & 0.85 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n",
      "\\begin{tabular}{lllllllllllll}\n",
      "\\toprule\n",
      " & \\multicolumn{3}{r}{disentanglement} & \\multicolumn{3}{r}{completeness} & \\multicolumn{3}{r}{train_loss} & \\multicolumn{3}{r}{test_loss} \\\\\n",
      "method & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw \\\\\n",
      "spur &  &  &  &  &  &  &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "0.00 & 0.981 & 0.972 & 0.957 & 0.545 & 0.528 & 0.484 & 0.97 & 0.96 & 0.96 & 0.92 & 0.92 & 0.89 \\\\\n",
      "0.25 & 0.958 & 0.924 & 0.955 & 0.534 & 0.491 & 0.455 & 0.97 & 0.94 & 0.96 & 0.92 & 0.85 & 0.90 \\\\\n",
      "0.50 & 0.902 & 0.891 & 0.857 & 0.529 & 0.513 & 0.443 & 0.97 & 0.97 & 0.96 & 0.88 & 0.89 & 0.85 \\\\\n",
      "0.75 & 0.751 & 0.697 & 0.656 & 0.513 & 0.482 & 0.458 & 0.96 & 0.96 & 0.94 & 0.86 & 0.86 & 0.82 \\\\\n",
      "0.90 & 0.400 & 0.296 & 0.460 & 0.478 & 0.407 & 0.430 & 0.96 & 0.94 & 0.95 & 0.81 & 0.78 & 0.76 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_59208/858565565.py:9: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['spur'] = temp_df['spur'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/858565565.py:10: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['completeness'] = temp_df['completeness'].apply(format_float)\n",
      "/tmp/ipykernel_59208/858565565.py:11: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['disentanglement'] = temp_df['disentanglement'].apply(format_float)\n",
      "/tmp/ipykernel_59208/858565565.py:12: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['train_loss'] = temp_df['train_loss'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/858565565.py:13: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['test_loss'] = temp_df['test_loss'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/858565565.py:9: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['spur'] = temp_df['spur'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/858565565.py:10: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['completeness'] = temp_df['completeness'].apply(format_float)\n",
      "/tmp/ipykernel_59208/858565565.py:11: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['disentanglement'] = temp_df['disentanglement'].apply(format_float)\n",
      "/tmp/ipykernel_59208/858565565.py:12: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['train_loss'] = temp_df['train_loss'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/858565565.py:13: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['test_loss'] = temp_df['test_loss'].apply(format_float2)\n"
     ]
    }
   ],
   "source": [
    "def format_float(perc):\n",
    "    \"\"\"Return `perc` formatted as a string with three decimal places.\"\"\"\n",
    "    return '{:.3f}'.format(perc)\n",
    "\n",
    "def format_float2(perc):\n",
    "    \"\"\"Return `perc` formatted as a string with two decimal places.\"\"\"\n",
    "    return '{:.2f}'.format(perc)\n",
    "\n",
    "# Formatter applied to each column before the table is exported.\n",
    "column_formatters = {\n",
    "    'spur': format_float2,\n",
    "    'completeness': format_float,\n",
    "    'disentanglement': format_float,\n",
    "    'train_loss': format_float2,\n",
    "    'test_loss': format_float2,\n",
    "}\n",
    "\n",
    "for dataset in datasets:\n",
    "    # .copy() gives an independent frame, so the column assignments below do not\n",
    "    # trigger SettingWithCopyWarning and can never write back into df.\n",
    "    temp_df = df[df['dataset']==dataset].copy()\n",
    "    for column, formatter in column_formatters.items():\n",
    "        temp_df[column] = temp_df[column].apply(formatter)\n",
    "    # One LaTeX table per dataset: rows are spur levels, columns are (metric, method).\n",
    "    print(temp_df.pivot(index='spur',columns='method', values=([\"disentanglement\", \"completeness\",\"train_loss\",\"test_loss\"])).to_latex(index=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "6ea09a65-3971-42f7-bdcb-2af4f32591c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllllllllllll}\n",
      "\\toprule\n",
      " & \\multicolumn{3}{r}{disentanglement} & \\multicolumn{3}{r}{completeness} & \\multicolumn{3}{r}{train_loss} & \\multicolumn{3}{r}{test_loss} \\\\\n",
      "method & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw & erm & gdro & rw \\\\\n",
      "spur &  &  &  &  &  &  &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "0.90 & 0.961 & 0.940 & 0.918 & 0.561 & 0.541 & 0.538 & 0.97 & 0.96 & 0.96 & 0.94 & 0.93 & 0.93 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_59208/2858638449.py:2: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['spur'] = temp_df['spur'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/2858638449.py:3: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['completeness'] = temp_df['completeness'].apply(format_float)\n",
      "/tmp/ipykernel_59208/2858638449.py:4: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['disentanglement'] = temp_df['disentanglement'].apply(format_float)\n",
      "/tmp/ipykernel_59208/2858638449.py:5: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['train_loss'] = temp_df['train_loss'].apply(format_float2)\n",
      "/tmp/ipykernel_59208/2858638449.py:6: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  temp_df['test_loss'] = temp_df['test_loss'].apply(format_float2)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Copy the filtered rows so the formatting below writes to an independent frame\n",
    "# rather than a slice of `df` (which raises SettingWithCopyWarning).\n",
    "temp_df = df[df['dataset']==\"CelebA\"].copy()\n",
    "# Turn the numeric columns into fixed-precision strings for the LaTeX table.\n",
    "temp_df['spur'] = temp_df['spur'].apply(format_float2)\n",
    "temp_df['completeness'] = temp_df['completeness'].apply(format_float)\n",
    "temp_df['disentanglement'] = temp_df['disentanglement'].apply(format_float)\n",
    "temp_df['train_loss'] = temp_df['train_loss'].apply(format_float2)\n",
    "temp_df['test_loss'] = temp_df['test_loss'].apply(format_float2)\n",
    "# Rows are indexed by `spur`; columns are (metric, method) pairs.\n",
    "print(temp_df.pivot(index='spur',columns='method', values=([\"disentanglement\", \"completeness\",\"train_loss\",\"test_loss\"])).to_latex(index=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a1a71c-cd71-467e-a599-17eef209a148",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NOTE(review): `dataset`, `split` and `format_perc` are not defined in this cell;\n",
    "# they must already exist in the kernel (`dataset` may be left over from an earlier loop) - confirm.\n",
    "results = pd.read_csv(f'{dataset}_{split}.csv')\n",
    "# Worst-group accuracy: row-wise minimum over the four per-group accuracy columns.\n",
    "results['worst_acc'] = results[['avg_acc_group:0','avg_acc_group:1','avg_acc_group:2','avg_acc_group:3']].min(axis=1)\n",
    "# Mean accuracies for every (method, corr, epoch) combination.\n",
    "epoch_df = results.groupby(['method','corr',\"epoch\"]).agg(\n",
    "    {'worst_acc': 'mean','avg_acc': 'mean'}\n",
    ").reset_index()\n",
    "# Maximum accuracies reached by each (method, corr, seed) run.\n",
    "grouped_df = results.groupby(['method','corr','seed']).agg(\n",
    "    {'avg_acc': 'max', 'worst_acc': 'max'}\n",
    ").reset_index()\n",
    "# Grouping reused below for the mean, standard deviation and count per (method, corr).\n",
    "per_setting = grouped_df.groupby(['method','corr'])\n",
    "means_df = per_setting.agg(\n",
    "    {'avg_acc':'mean','worst_acc':'mean'}\n",
    ").reset_index()\n",
    "for m in ['worst_acc','avg_acc']:\n",
    "    # Scale to percentages before formatting.\n",
    "    means_df[m] = 100*means_df[m]\n",
    "    means_df[f'{m}_std'] = 100*per_setting[m].std().reset_index()[m]\n",
    "    means_df['count'] = per_setting[m].count().reset_index()[m]\n",
    "    means_df[m] = means_df[m].apply(format_perc)\n",
    "    means_df[f'{m}_std'] = means_df[f'{m}_std'].apply(format_float)\n",
    "    # Raw string so the LaTeX backslash is not parsed as a (invalid) string escape.\n",
    "    means_df[f'final_{m}'] = (means_df[m]).astype(str) + r\" $\\pm$ \" + means_df[f'{m}_std'].astype(str) + \" (\" + means_df['count'].astype(str) + \")\"\n",
    "\n",
    "def create_final_table(df, metric=\"avg_acc\"):\n",
    "    \"\"\"Pivot the ``final_<metric>`` column into a method-by-corr table.\n",
    "\n",
    "    Args:\n",
    "        df: Frame with ``method``, ``corr`` and ``final_<metric>`` columns.\n",
    "        metric: Metric name used to pick the ``final_<metric>`` column.\n",
    "\n",
    "    Returns:\n",
    "        The pivoted frame with ``method`` restored as a regular column.\n",
    "    \"\"\"\n",
    "    value_columns = [f\"final_{metric}\"]\n",
    "    return df.pivot(index='method', columns='corr', values=value_columns).reset_index()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
