{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/')\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See XXXX\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import itertools\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import resource\n",
    "import secrets\n",
    "import shutil\n",
    "import subprocess\n",
    "from copy import copy\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from PIL import Image\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from torch.nn.utils import prune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cap resource usage for this process.\n",
    "# The CUDA cap is only applied when a GPU is present: the notebook later falls back\n",
    "# to CPU, and set_per_process_memory_fraction needs an initialised CUDA runtime.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_per_process_memory_fraction(0.5)\n",
    "\n",
    "# Soft address-space limit of 30 GiB; the hard limit is passed as -1\n",
    "# (presumably RLIM_INFINITY on this platform - confirm).\n",
    "resource.setrlimit(resource.RLIMIT_AS, (30 * 1024 * 1024 * 1024, -1))\n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'seed': 42, 'encoder_model': 'inceptionv3', 'retrain_epochs': 0, 'pruning_technique': 'weight', 'num_attributes': 10, 'debugging': False, 'prune_rate': 0.25, 'dataset_name': 'coco'}\n"
     ]
    }
   ],
   "source": [
    "# Run configuration: fixed values inside a notebook kernel, CLI flags when the\n",
    "# notebook is exported and executed as a script.\n",
    "is_jupyter = 'ipykernel' in sys.modules\n",
    "\n",
    "# The dataset is the same in both modes, so it is set once here.\n",
    "dataset_name = \"coco\"\n",
    "\n",
    "if is_jupyter:\n",
    "    encoder_model='inceptionv3'\n",
    "    seed = 42\n",
    "    # Note: the notebook default (0) differs from the CLI default (5) below.\n",
    "    retrain_epochs = 0\n",
    "    pruning_technique = 'weight'\n",
    "    prune_rate = 0.25\n",
    "else:\n",
    "    parser = argparse.ArgumentParser(description=\"Synthetic Dataset Experiments\")\n",
    "\n",
    "    parser.add_argument('--encoder_model', type=str, default='inceptionv3', help='Encoder model')\n",
    "    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n",
    "    parser.add_argument('--retrain_epochs', type=int, default=5, help='Number of epochs')\n",
    "    parser.add_argument('--pruning_technique', type=str, default='weight', help='\"layer\" or \"weight\" pruning')\n",
    "    parser.add_argument('--prune_rate', type=float, default=0.25, help='Rate of pruning')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "    encoder_model = args.encoder_model\n",
    "    seed = args.seed\n",
    "    retrain_epochs = args.retrain_epochs\n",
    "    pruning_technique = args.pruning_technique\n",
    "    prune_rate = args.prune_rate\n",
    "\n",
    "# Stored alongside the results and passed to delete_same_dict.\n",
    "parameters = {\n",
    "    'seed': seed,\n",
    "    'encoder_model': encoder_model,\n",
    "    'retrain_epochs': retrain_epochs,\n",
    "    'pruning_technique': pruning_technique,\n",
    "    'num_attributes': 10,\n",
    "    'debugging': False,\n",
    "    'prune_rate': prune_rate,\n",
    "    'dataset_name': dataset_name,\n",
    "}\n",
    "print(parameters)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fd598028750>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed every RNG this notebook draws from. The stdlib random module is used later\n",
    "# (random.sample) to pick the masked test images, so it has to be seeded too for\n",
    "# the reported numbers to be reproducible.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the train/val/test loaders and the matching per-split record lists.\n",
    "# NOTE(review): the meaning of the leading positional 1 is not visible here - confirm against get_data.\n",
    "train_loader, val_loader, test_loader, train_pkl, val_pkl, test_pkl = get_data(1,encoder_model=encoder_model,dataset_name=dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the test loader into indexable collections.\n",
    "# Presumably images, class labels and concept labels, in that order - confirm in unroll_data.\n",
    "test_images, test_y, test_c = unroll_data(test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random 8-hex-character run id; it names the results file and the temporary model files.\n",
    "rand_name = secrets.token_hex(4)\n",
    "results_file = \"../../results/coco_pruning/{}.json\".format(rand_name)\n",
    "# NOTE(review): presumably removes earlier result files saved with identical parameters - confirm in src.util.\n",
    "delete_same_dict(parameters,\"../../results/coco_pruning\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'debugging': False, 'encoder_model': 'inceptionv3', 'seed': 42, 'dataset': 'coco', 'epochs': 25, 'lr': 0.005, 'attr_loss_weight': 0.1, 'scheduler': 'none', 'train_variation': 'none'}\n"
     ]
    }
   ],
   "source": [
    "# Configuration handed to get_synthetic_model to obtain the starting joint model.\n",
    "model_parameters = dict(\n",
    "    debugging=False,\n",
    "    encoder_model=encoder_model,\n",
    "    seed=seed,\n",
    "    dataset='coco',\n",
    "    epochs=25,\n",
    "    lr=0.005,\n",
    "    attr_loss_weight=0.1,\n",
    "    scheduler='none',\n",
    "    train_variation='none',\n",
    ")\n",
    "print(model_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when PyTorch can see one; otherwise run on CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the joint model for this dataset and configuration from the project helper.\n",
    "# NOTE(review): presumably loads an already-trained checkpoint - confirm in get_synthetic_model.\n",
    "joint_model = get_synthetic_model(dataset_name,model_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the selected device before pruning.\n",
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prune Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_conv2d_modules(model):\n",
    "    \"\"\"Return every distinct nn.Conv2d inside `model`, in depth-first order.\n",
    "\n",
    "    Args:\n",
    "        model: any nn.Module; it may itself be a Conv2d.\n",
    "\n",
    "    Returns:\n",
    "        List of nn.Conv2d modules. A layer that is shared between several\n",
    "        parents is listed once, so a caller that prunes each entry does not\n",
    "        prune the same weights twice.\n",
    "    \"\"\"\n",
    "    # Module.modules() walks the module tree pre-order (root first) and yields\n",
    "    # each module object only once.\n",
    "    return [module for module in model.modules() if isinstance(module, nn.Conv2d)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "if pruning_technique == \"layer\":\n",
    "    # Filter-level pruning: in conv layers 4-7 (1-indexed), zero the weights of the\n",
    "    # filters with the smallest summed absolute weight.\n",
    "    for conv_number in [4,5,6,7]:\n",
    "        if len(joint_model.first_model.conv_layers) < conv_number:\n",
    "            continue\n",
    "\n",
    "        layer_to_prune = joint_model.first_model.conv_layers[conv_number-1]\n",
    "        weight = layer_to_prune.weight.data\n",
    "        num_filters = weight.size(0)\n",
    "\n",
    "        # Compute the number of filters to prune\n",
    "        num_prune = int(num_filters * prune_rate)\n",
    "        if num_prune == 0:\n",
    "            # Nothing to remove. Without this guard indices[-0:] would select\n",
    "            # every filter and wipe the whole layer.\n",
    "            continue\n",
    "\n",
    "        importance = weight.abs().sum((1, 2, 3))  # Calculate importance of filters\n",
    "        _, indices = importance.sort(descending=True)\n",
    "        indices_to_prune = indices[-num_prune:]\n",
    "\n",
    "        # Build the 0/1 filter mask on the weight's own device and dtype so that\n",
    "        # the indexing and the in-place multiply never mix devices.\n",
    "        mask = torch.ones(num_filters, device=weight.device, dtype=weight.dtype)\n",
    "        mask[indices_to_prune] = 0\n",
    "        layer_to_prune.weight.data *= mask.view(-1, 1, 1, 1)\n",
    "elif pruning_technique == \"weight\":\n",
    "    # Unstructured L1 pruning of every conv layer and every fully connected head.\n",
    "    for conv_2d in find_conv2d_modules(joint_model.first_model):\n",
    "        prune.l1_unstructured(conv_2d, name=\"weight\", amount=prune_rate)\n",
    "    for layer in joint_model.first_model.all_fc:\n",
    "        prune.l1_unstructured(layer.fc, name=\"weight\", amount=prune_rate)\n",
    "else:\n",
    "    raise ValueError(\"Pruning {} not found\".format(pruning_technique))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint the pruned model so the retraining command can load it from disk.\n",
    "# Passing the path instead of an open() handle lets torch.save open and close the\n",
    "# file itself, so no file descriptor is left dangling.\n",
    "torch.save(joint_model, \"../../models/pruned/coco/{}.pt\".format(rand_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the notebook's reference to the pruned model; it is reloaded after retraining.\n",
    "joint_model = None "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached GPU memory before the retraining subprocess starts.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Command line for the retraining script; it loads the pruned checkpoint saved for this run.\n",
    "# NOTE(review): -lr 0.005 repeats model_parameters['lr'] - keep the two in sync.\n",
    "command_to_run = \"python train_cbm.py -dataset coco -epochs {} --load_model pruned/coco/{}.pt -num_attributes 10 --encoder_model {} -num_classes 2 -seed {} -lr 0.005\".format(retrain_epochs,rand_name,encoder_model,seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'python train_cbm.py -dataset coco -epochs 0 --load_model pruned/coco/c3901820.pt -num_attributes 10 --encoder_model inceptionv3 -num_classes 2 -seed 42 -lr 0.005'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the assembled command for the record.\n",
    "command_to_run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(attr_loss_weight=1.0, batch_size=32, bottleneck=False, ckpt='0', connect_CY=False, data_dir='../../../datasets/coco/preprocessed', dataset='cub', encoder_model='inceptionv3', end2end=True, epochs=0, exp='Joint', expand_dim=0, expand_dim_encoder=0, experiment_name='CUB', freeze=False, image_dir='images', load_model='pruned/coco/c3901820.pt', log_dir='../models/pruned/coco/c3901820/joint', lr=0.005, mask_loss_weight=1.0, n_attributes=10, n_class_attr=2, no_img=False, normalize_loss=True, num_classes=2, num_middle_encoder=0, one_batch=False, optimizer='sgd', pretrained=False, resampling=False, save_step=1000, scale_factor=1.5, scale_lr=5, scheduler='none', scheduler_step=30, seed=42, three_class=False, train_addition='', train_variation='none', uncertain_labels=False, use_attr=True, use_aux=True, use_relu=False, use_sigmoid=True, use_unknown=False, weight_decay=0.0004, weighted_loss='multiple')\n",
      "[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Stop epoch:  30\n",
      "train data path: ../../../datasets/coco/preprocessed/train.pkl\n",
      "Saving the model again to ../models/pruned/coco/c3901820/joint!\n",
      "wandb: Currently logged XXXX-20as: navr414. Use `wandb login --relogin` to force relogin\n",
      "wandb: wandb version 0.16.1 is available!  To upgrade, please run:\n",
      "wandb:  $ pip install wandb --upgrade\n",
      "wandb: Tracking run with wandb version 0.13.5\n",
      "wandb: Run data is saved locally XXXX-42 /usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/wandb/run-20240109_105113-v5nzizzi\n",
      "wandb: Run `wandb offline` to turn off syncing.\n",
      "wandb: Syncing run sage-sponge-87\n",
      "wandb: ⭐️ View project at XXXX\n",
      "wandb: 🚀 View run at XXXX\n",
      "wandb: Waiting for W&B process to finish... (success).\n",
      "wandb: - 0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\n",
      "wandb: \\ 0.003 MB of 0.018 MB uploaded (0.000 MB deduped)\n",
      "wandb: Synced sage-sponge-87: XXXX\n",
      "wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
      "wandb: Find logs at: ./wandb/run-20240109_105113-v5nzizzi/logs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CompletedProcess(args='cd ../../ConceptBottleneck && python train_cbm.py -dataset coco -epochs 0 --load_model pruned/coco/c3901820.pt -num_attributes 10 --encoder_model inceptionv3 -num_classes 2 -seed 42 -lr 0.005', returncode=0)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the retraining script from inside the ConceptBottleneck directory.\n",
    "# cwd= replaces the former `cd ... &&` shell prefix, and check=True raises\n",
    "# CalledProcessError when the script exits non-zero, so a failed retraining stops\n",
    "# here instead of surfacing later as a missing checkpoint.\n",
    "subprocess.run(command_to_run, shell=True, cwd=\"../../ConceptBottleneck\", check=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the pruned checkpoint that was handed to the retraining command.\n",
    "os.remove(\"../../models/pruned/coco/{}.pt\".format(rand_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model saved by the retraining run for this run id and seed.\n",
    "# It is loaded onto the CPU first and moved to the target device afterwards.\n",
    "joint_location = \"../../models/pruned/coco/{}/joint/best_model_{}.pth\".format(rand_name,seed)\n",
    "joint_model = torch.load(joint_location,map_location='cpu')\n",
    "\n",
    "# Flag MLP encoders on the model object.\n",
    "# NOTE(review): presumably read by run_joint_model - confirm.\n",
    "if 'encoder_model' in parameters and 'mlp' in parameters['encoder_model']:\n",
    "    joint_model.encoder_model = True\n",
    "\n",
    "# eval() returns the module itself; binding the result keeps the cell output quiet.\n",
    "r = joint_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the reloaded model to the evaluation device.\n",
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached GPU memory before evaluation.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Activation + Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the reloaded model on each split, evaluated in train/val/test order.\n",
    "train_acc, val_acc, test_acc = [\n",
    "    get_accuracy(joint_model, run_joint_model, split_loader)\n",
    "    for split_loader in (train_loader, val_loader, test_loader)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9993002099370188, 0.8324022346368715, 0.8388625592417062)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display (train, val, test) accuracy.\n",
    "train_acc, val_acc, test_acc  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the datasets folder, relative to the notebook's working directory.\n",
    "dataset_directory = \"../../../../datasets\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'annotations' entries of both COCO instance files.\n",
    "# The with-blocks close each file as soon as it has been parsed.\n",
    "with open(dataset_directory+\"/coco/preprocessed/instances_train2014.json\") as train_file:\n",
    "    train_locations = json.load(train_file)['annotations']\n",
    "with open(dataset_directory+\"/coco/preprocessed/instances_val2014.json\") as val_file:\n",
    "    val_locations = json.load(val_file)['annotations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category ids that serve as concepts; a category's position in this sequence is\n",
    "# its concept index.\n",
    "# NOTE: pickle.load can execute arbitrary code - only point this at trusted files.\n",
    "with open(dataset_directory+\"/coco/preprocessed/concepts.pkl\",\"rb\") as concepts_file:\n",
    "    concepts = pickle.load(concepts_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image id -> one list of bounding boxes per concept (filled by the annotation loop)\n",
    "locations_by_image = {}\n",
    "# ids of every image that appears in any of the three splits\n",
    "image_ids = {record['id'] for record in train_pkl + val_pkl + test_pkl}\n",
    "# Not referenced again in this notebook; kept so the name still exists.\n",
    "id_to_idx = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First position of each category id in `concepts`: the same answer concepts.index\n",
    "# gives, without rescanning the list for every annotation.\n",
    "concept_to_idx = {}\n",
    "for concept_idx, category_id in enumerate(concepts):\n",
    "    concept_to_idx.setdefault(category_id, concept_idx)\n",
    "\n",
    "# Group bounding boxes by image and concept, keeping only images that belong to\n",
    "# one of the splits and categories that are concepts.\n",
    "for annotation in train_locations + val_locations:\n",
    "    image_id = annotation['image_id']\n",
    "    concept_idx = concept_to_idx.get(annotation['category_id'])\n",
    "    if image_id not in image_ids or concept_idx is None:\n",
    "        continue\n",
    "    if image_id not in locations_by_image:\n",
    "        locations_by_image[image_id] = [[] for _ in range(len(concepts))]\n",
    "    locations_by_image[image_id][concept_idx].append(annotation['bbox'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Impact of Masking on Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline concept predictions on the unmasked test set, stacked over all batches.\n",
    "# Later cells index the rows by test-image position, which relies on test_loader\n",
    "# iterating in a fixed order - TODO confirm it is not shuffled.\n",
    "batch_predictions = []\n",
    "with torch.no_grad():\n",
    "    for x, y, c in test_loader:\n",
    "        _, concept_logits = run_joint_model(joint_model, x.to(device))\n",
    "        # NOTE(review): the transpose implies the helper returns concepts on the\n",
    "        # first axis - confirm in run_joint_model.\n",
    "        concept_probs = torch.sigmoid(concept_logits.detach().cpu().T)\n",
    "        batch_predictions.append(concept_probs.numpy())\n",
    "\n",
    "# np.vstack is the supported spelling of the deprecated np.row_stack alias.\n",
    "initial_predictions = np.vstack(batch_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached GPU memory before the masking experiment.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# main concept -> masked concept -> (mean, std) of the prediction change\n",
    "results_by_part_mask = {}\n",
    "# number of test images sampled for every concept pair\n",
    "test_data_num = 100\n",
    "\n",
    "# Keep the (main, mask) concept pairs that co-occur in strictly more than\n",
    "# test_data_num test images. Pairs with main == mask are included.\n",
    "num_concepts = len(concepts)\n",
    "valid_pairs = [\n",
    "    (i, j)\n",
    "    for i in range(num_concepts)\n",
    "    for j in range(num_concepts)\n",
    "    if sum(1 for k in range(len(test_pkl)) if test_c[k][i] == 1 and test_c[k][j] == 1) > test_data_num\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "On main part 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "On main part 3\n",
      "On main part 3\n",
      "On main part 4\n",
      "On main part 4\n",
      "On main part 4\n",
      "On main part 5\n",
      "On main part 5\n",
      "On main part 7\n",
      "On main part 8\n",
      "On main part 9\n"
     ]
    }
   ],
   "source": [
    "# For every valid (main, mask) pair: mask the boxes of the `mask` concept in a\n",
    "# random sample of test images and measure how far the prediction for the `main`\n",
    "# concept moves away from the unmasked baseline.\n",
    "for (main_part,mask_part) in valid_pairs:\n",
    "    print(\"On main part {}\".format(main_part))\n",
    "    results_for_main = results_by_part_mask.setdefault(concepts[main_part], {})\n",
    "\n",
    "    # NOTE(review): the test set is re-unrolled for every pair. This is kept because\n",
    "    # mask_bbox may modify images in place; hoist it only after confirming it does not.\n",
    "    test_images, test_y, test_c = unroll_data(test_loader)\n",
    "    valid_data_points = [k for k in range(len(test_pkl)) if test_c[k][main_part] == 1 and test_c[k][mask_part] == 1]\n",
    "    data_points = random.sample(valid_data_points,test_data_num)\n",
    "\n",
    "    # Mask every annotated box of the `mask` concept in each sampled image.\n",
    "    masked_images = []\n",
    "    for idx in data_points:\n",
    "        bboxes = locations_by_image[test_pkl[idx]['id']][mask_part]\n",
    "        masked_images.append(mask_bbox(test_images[idx],[get_new_x_y(bbox,idx,test_pkl) for bbox in bboxes]))\n",
    "    masked_dataset = torch.stack(masked_images)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        _, final_logits = run_joint_model(joint_model,masked_dataset.to(device))\n",
    "    final_predictions = torch.sigmoid(final_logits.detach().cpu().T).numpy()\n",
    "\n",
    "    # Absolute change of the main concept's prediction, computed once and reused.\n",
    "    abs_diff = np.abs(initial_predictions[data_points] - final_predictions)[:,main_part]\n",
    "    results_for_main[concepts[mask_part]] = (float(np.mean(abs_diff)),float(np.std(abs_diff)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete this run's model directory; the retrained model was already loaded from it.\n",
    "shutil.rmtree(\"../../models/pruned/coco/{}\".format(rand_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything written to the results file: split accuracies, masking results and\n",
    "# the parameters that produced them.\n",
    "final_data = dict(\n",
    "    train_accuracy=float(train_acc),\n",
    "    val_accuracy=float(val_acc),\n",
    "    test_accuracy=float(test_acc),\n",
    "    results_by_part_mask=results_by_part_mask,\n",
    "    parameters=parameters,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the results; the with-block guarantees the file is flushed and closed.\n",
    "with open(results_file,\"w\") as results_handle:\n",
    "    json.dump(final_data,results_handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concepts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
