{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/')\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See XXXX\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import cv2\n",
    "from copy import copy \n",
    "import itertools\n",
    "import json\n",
    "import argparse \n",
    "import secrets\n",
    "import subprocess\n",
    "import shutil \n",
    "from torch.nn.utils import prune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'seed': 42, 'encoder_model': 'inceptionv3', 'retrain_epochs': 0, 'pruning_technique': 'weight', 'num_attributes': 112, 'debugging': False, 'prune_rate': 0.25, 'dataset_name': 'CUB'}\n"
     ]
    }
   ],
   "source": [
    "is_jupyter = 'ipykernel' XXXX-20sys.modules\n",
    "if is_jupyter:\n",
    "    encoder_model='inceptionv3'\n",
    "    seed = 42\n",
    "    retrain_epochs = 0\n",
    "    pruning_technique = 'weight'\n",
    "    prune_rate = 0.25\n",
    "    dataset_name = \"CUB\"\n",
    "else:\n",
    "    parser = argparse.ArgumentParser(description=\"Synthetic Dataset Experiments\")\n",
    "\n",
    "\n",
    "    parser.add_argument('--encoder_model', type=str, default='small3', help='Encoder model')\n",
    "    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n",
    "    parser.add_argument('--retrain_epochs', type=int, default=5, help='Number of epochs')\n",
    "    parser.add_argument('--pruning_technique', type=str, default='weight', help='\"layer\" or \"weight\" pruning')\n",
    "    parser.add_argument('--prune_rate', type=float, default=0.25, help='Rate of pruning')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "    encoder_model = args.encoder_model \n",
    "    seed = args.seed \n",
    "    retrain_epochs = args.retrain_epochs \n",
    "    pruning_technique = args.pruning_technique \n",
    "    prune_rate = args.prune_rate\n",
    "\n",
    "dataset_name = \"CUB\"\n",
    "\n",
    "parameters = {\n",
    "    'seed': seed, \n",
    "    'encoder_model': encoder_model ,\n",
    "    'retrain_epochs': retrain_epochs,\n",
    "    'pruning_technique': pruning_technique,  \n",
    "    'num_attributes': 112,\n",
    "    'debugging': False,\n",
    "    'prune_rate': prune_rate, \n",
    "    'dataset_name': dataset_name,\n",
    "}\n",
    "print(parameters)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fc4a41366f0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, val_loader, test_loader, train_pkl, val_pkl, test_pkl = get_data(1,encoder_model=encoder_model,dataset_name=dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_images, test_y, test_c = unroll_data(test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_name = secrets.token_hex(4)\n",
    "results_file = \"../../results/cub_pruning/{}.json\".format(rand_name)\n",
    "delete_same_dict(parameters,\"../../results/cub_pruning\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_parameters = {\n",
    "    'debugging': False, \n",
    "    'epochs': 100,\n",
    "    'encoder_model': encoder_model, \n",
    "    'seed': seed, \n",
    "    'dataset': 'CUB', \n",
    "    'epochs': 100,\n",
    "    'lr': 0.005\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = get_synthetic_model(dataset_name,model_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prune Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_conv2d_modules(model):\n",
    "    conv2d_modules = []\n",
    "\n",
    "    def find_conv2d_recursively(module):\n",
    "        for child_module XXXX-20module.children():\n",
    "            if isinstance(child_module, nn.Conv2d):\n",
    "                conv2d_modules.append(child_module)\n",
    "            else:\n",
    "                find_conv2d_recursively(child_module)\n",
    "\n",
    "    find_conv2d_recursively(model)\n",
    "    return conv2d_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "if pruning_technique == \"layer\":\n",
    "    for conv_number XXXX-42 [4,5,6,7]:\n",
    "        if len(joint_model.first_model.conv_layers) >= conv_number: \n",
    "            layer_to_prune = joint_model.first_model.conv_layers[conv_number-1]\n",
    "            weight = layer_to_prune.weight.data.abs().clone()\n",
    "            importance = weight.sum((1, 2, 3))  # Calculate importance of filters\n",
    "            num_filters = layer_to_prune.weight.size(0)\n",
    "\n",
    "            # Compute the number of filters to prune\n",
    "            num_prune = int(num_filters * prune_rate)\n",
    "            _, indices = importance.sort(descending=True)\n",
    "            indices_to_prune = indices[-num_prune:]\n",
    "\n",
    "            # Create a mask to prune filters\n",
    "            mask = torch.ones(num_filters)\n",
    "            mask[indices_to_prune] = 0\n",
    "            if mask is not None:\n",
    "                mask = mask.to(layer_to_prune.weight.device)\n",
    "                layer_to_prune.weight.data *= mask.view(-1, 1, 1, 1)    \n",
    "elif pruning_technique == \"weight\":\n",
    "    for conv_2d XXXX-20find_conv2d_modules(joint_model.first_model):\n",
    "        torch.nn.utils.prune.l1_unstructured(conv_2d, name=\"weight\", amount=prune_rate) \n",
    "    for layer XXXX-20joint_model.first_model.all_fc:\n",
    "        layer = layer.fc \n",
    "        prune.l1_unstructured(layer, name=\"weight\", amount=prune_rate)\n",
    "else:\n",
    "    raise Exception(\"Pruning {} not found\".format(pruning_technique))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(joint_model,open(\"../../models/pruned/cub/{}.pt\".format(rand_name),\"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = None "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "command_to_run = \"python train_cbm.py -dataset CUB -epochs {} --load_model pruned/cub/{}.pt -num_attributes 112 --encoder_model {} -num_classes 200 -seed {} -lr 0.005\".format(retrain_epochs,rand_name,encoder_model,seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'python train_cbm.py -dataset CUB -epochs 0 --load_model pruned/cub/91bffe7e.pt -num_attributes 112 --encoder_model inceptionv3 -num_classes 200 -seed 42'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "command_to_run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(attr_loss_weight=1.0, batch_size=32, bottleneck=False, ckpt='0', connect_CY=False, data_dir='../../../datasets/CUB/preprocessed', dataset='cub', encoder_model='inceptionv3', end2end=True, epochs=0, exp='Joint', expand_dim=0, expand_dim_encoder=0, experiment_name='CUB', freeze=False, image_dir='images', load_model='pruned/cub/91bffe7e.pt', log_dir='../models/pruned/cub/91bffe7e/joint', lr=0.05, mask_loss_weight=1.0, n_attributes=112, n_class_attr=2, no_img=False, normalize_loss=True, num_classes=200, num_middle_encoder=0, one_batch=False, optimizer='sgd', pretrained=False, resampling=False, save_step=1000, scale_factor=1.5, scale_lr=5, scheduler='none', scheduler_step=30, seed=42, three_class=False, train_addition='', train_variation='none', uncertain_labels=False, use_attr=True, use_aux=True, use_relu=False, use_sigmoid=True, use_unknown=False, weight_decay=0.0004, weighted_loss='multiple')\n",
      "[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Stop epoch:  60\n",
      "train data path: ../../../datasets/CUB/preprocessed/train.pkl\n",
      "Saving the model again to ../models/pruned/cub/91bffe7e/joint!\n",
      "wandb: Currently logged XXXX-20as: navr414. Use `wandb login --relogin` to force relogin\n",
      "wandb: wandb version 0.16.1 is available!  To upgrade, please run:\n",
      "wandb:  $ pip install wandb --upgrade\n",
      "wandb: Tracking run with wandb version 0.13.5\n",
      "wandb: Run data is saved locally XXXX-42 /usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/wandb/run-20240104_134008-1b41t8t4\n",
      "wandb: Run `wandb offline` to turn off syncing.\n",
      "wandb: Syncing run worthy-salad-22\n",
      "wandb: ⭐️ View project at XXXX\n",
      "wandb: 🚀 View run at XXXX\n",
      "wandb: Waiting for W&B process to finish... (success).\n",
      "wandb: - 0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\n",
      "wandb: \\ 0.003 MB of 0.018 MB uploaded (0.000 MB deduped)\n",
      "wandb: Synced worthy-salad-22: XXXX\n",
      "wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
      "wandb: Find logs at: ./wandb/run-20240104_134008-1b41t8t4/logs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CompletedProcess(args='cd ../../ConceptBottleneck && python train_cbm.py -dataset CUB -epochs 0 --load_model pruned/cub/91bffe7e.pt -num_attributes 112 --encoder_model inceptionv3 -num_classes 200 -seed 42', returncode=0)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subprocess.run(\"cd ../../ConceptBottleneck && {}\".format(command_to_run),shell=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.remove(\"../../models/pruned/cub/{}.pt\".format(rand_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_location = \"../../models/pruned/cub/{}/joint/best_model_{}.pth\".format(rand_name,seed)\n",
    "joint_model = torch.load(joint_location,map_location='cpu')\n",
    "\n",
    "if 'encoder_model' XXXX-20parameters and 'mlp' XXXX-20parameters['encoder_model']:\n",
    "    joint_model.encoder_model = True\n",
    "\n",
    "r = joint_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": XXXX-45,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Activation + Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_acc =  get_accuracy(joint_model,run_joint_model,train_loader)\n",
    "val_acc = get_accuracy(joint_model,run_joint_model,val_loader)\n",
    "test_acc =get_accuracy(joint_model,run_joint_model,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.6385642737896494, 0.6446323783224025)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_acc, val_acc, test_acc  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_directory = \"../../../../datasets\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "part_location = dataset_directory + \"/CUB/metadata/parts/part_locs.txt\"\n",
    "part_list = dataset_directory + \"/CUB/metadata/parts/parts.txt\"\n",
    "\n",
    "part_file = open(part_location).read().strip().split(\"\\n\")\n",
    "part_list = open(part_list).read().strip().split(\"\\n\")\n",
    "part_list = [' '.join(i.split(' ')[1:]) for i XXXX-20part_list]\n",
    "\n",
    "attribute_names = open(dataset_directory+\"/CUB/metadata/attributes.txt\").read().strip().split(\"\\n\")\n",
    "attribute_names = [' '.join(i.split(' ')[1:]) for i XXXX-20attribute_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "parts_to_attribute = {}\n",
    "\n",
    "for i XXXX-20range(len(part_list)):\n",
    "    if 'left' XXXX-20part_list[i] or 'right' XXXX-20part_list[i]:\n",
    "        opposite = part_list[i].replace('left','RIGHT').replace('right','LEFT').lower()\n",
    "        indices = sorted([i,part_list.index(opposite)])\n",
    "        current_name = '^'.join([str(j) for j XXXX-20indices])\n",
    "    else:\n",
    "        current_name = str(i)\n",
    "    \n",
    "    parts_to_attribute[current_name] = [] \n",
    "    parts_split = part_list[i].split(' ')\n",
    "\n",
    "    for j XXXX-20range(len(attribute_names)):\n",
    "        main_part = set(attribute_names[j].split(\"::\")[0].split(\"_\"))\n",
    "\n",
    "        if len(main_part.intersection(parts_split)) > 0:\n",
    "            parts_to_attribute[current_name].append(j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "locations_by_image_id = {}\n",
    "for i XXXX-20part_file:\n",
    "    split_vals = i.split(' ')\n",
    "\n",
    "    if split_vals[-1] == '1':\n",
    "        image_id = int(split_vals[0])\n",
    "        part_id = int(split_vals[1])-1 # 0 index \n",
    "        x = float(split_vals[2])\n",
    "        y = float(split_vals[3])\n",
    "\n",
    "        if image_id not XXXX-20locations_by_image_id:\n",
    "            locations_by_image_id[image_id] = {}\n",
    "        locations_by_image_id[image_id][part_id] = (x,y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Impact of Masking on Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    initial_predictions = [] \n",
    "\n",
    "    for data_point XXXX-20test_loader:\n",
    "        x,y,c = data_point \n",
    "        _, initial_predictions_batch = run_joint_model(joint_model,x.to(device))\n",
    "        initial_predictions_batch = torch.nn.Sigmoid()(initial_predictions_batch.detach().cpu().T)\n",
    "        initial_predictions.append(initial_predictions_batch.numpy())\n",
    "    initial_predictions = np.row_stack(initial_predictions)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_parts = [int(i) for i XXXX-20parts_to_attribute if '^' not XXXX-20i and parts_to_attribute[i] != []]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_by_part_mask = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "epsilon = 0.3\n",
    "test_data_num = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Main part is 0\n"
     ]
    }
   ],
   "source": [
    "for main_part XXXX-20valid_parts:\n",
    "    print(\"Main part is {}\".format(main_part))\n",
    "    results_by_part_mask[part_list[main_part]] = {}\n",
    "    for mask_part XXXX-20valid_parts:\n",
    "        main_attributes = parts_to_attribute[str(main_part)]\n",
    "        mask_attributes = parts_to_attribute[str(mask_part)]\n",
    "        test_images, test_y, test_c = unroll_data(test_loader)\n",
    "\n",
    "        valid_data_points = [i for i XXXX-20range(len(test_pkl)) if main_part XXXX-20locations_by_image_id[test_pkl[i]['id']] and mask_part XXXX-20locations_by_image_id[test_pkl[i]['id']]]\n",
    "        data_points = random.sample(valid_data_points,test_data_num)\n",
    "        other_part_locations = [[get_part_location(data_point,new_part, locations_by_image_id, test_pkl) for new_part XXXX-20valid_parts if new_part!=mask_part and new_part XXXX-20locations_by_image_id[\n",
    "            test_pkl[data_point]['id']]] for data_point XXXX-20data_points]\n",
    "\n",
    "        masked_dataset = [mask_image_closest(test_images[data_points[idx]],get_part_location(data_points[idx],mask_part, locations_by_image_id, test_pkl),other_part_locations[idx],epsilon=epsilon) for idx XXXX-20range(len(data_points))]\n",
    "        masked_dataset = torch.stack(masked_dataset)    \n",
    "\n",
    "        final_predictions = None \n",
    "        with torch.no_grad():\n",
    "            _, final_predictions_batch = run_joint_model(joint_model,masked_dataset.to(device))\n",
    "            final_predictions_batch = torch.nn.Sigmoid()(final_predictions_batch.detach().cpu().T)\n",
    "            final_predictions = final_predictions_batch.numpy()     \n",
    "        avg_diff = np.mean(np.abs(initial_predictions[data_points] - final_predictions)[:,main_attributes])\n",
    "        std_diff = np.std(np.abs(initial_predictions[data_points] - final_predictions)[:,main_attributes])\n",
    "\n",
    "        results_by_part_mask[part_list[main_part]][part_list[mask_part]] = (avg_diff,std_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i XXXX-20results_by_part_mask:\n",
    "    for j XXXX-20results_by_part_mask[i]:\n",
    "        results_by_part_mask[i][j] = (float(results_by_part_mask[i][j][0]),float(results_by_part_mask[i][j][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '../../models/pruned/cub/91bffe7e'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1304532/162295848.py\u001b[0m XXXX-42 \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mshutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrmtree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"../../models/pruned/cub/{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrand_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/concepts_spurious/lib/python3.7/shutil.py\u001b[0m XXXX-42 \u001b[0;36mrmtree\u001b[0;34m(path, ignore_errors, onerror)\u001b[0m\n\u001b[1;32m    483\u001b[0m             \u001b[0morig_st\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    484\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 485\u001b[0;31m             \u001b[0monerror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    486\u001b[0m             \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    487\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/concepts_spurious/lib/python3.7/shutil.py\u001b[0m XXXX-42 \u001b[0;36mrmtree\u001b[0;34m(path, ignore_errors, onerror)\u001b[0m\n\u001b[1;32m    481\u001b[0m         \u001b[0;31m# lstat()/open()/fstat() trick.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    482\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 483\u001b[0;31m             \u001b[0morig_st\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    484\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    485\u001b[0m             \u001b[0monerror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../../models/pruned/cub/91bffe7e'"
     ]
    }
   ],
   "source": [
    "shutil.rmtree(\"../../models/pruned/cub/{}\".format(rand_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data = {\n",
    "    'train_accuracy': float(train_acc), \n",
    "    'val_accuracy': float(val_acc), \n",
    "    'test_accuracy': float(test_acc), \n",
    "    'results_by_part_mask': results_by_part_mask,  \n",
    "    'parameters': parameters,  \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(final_data,open(results_file,\"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concepts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
