{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Move the working directory up one level. The path is relative, so every extra\n",
    "# execution of this cell moves up one further level: run it exactly once per kernel.\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/home/njr61/rds/hpc-work/spurious-concepts/ConceptBottleneck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import json\n",
    "import pickle\n",
    "from copy import copy\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.nn.utils.prune\n",
    "import torch.optim as optim\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from PIL import Image\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ConceptBottleneck.CUB.models import ModelXtoC, ModelOracleCtoY\n",
    "from ConceptBottleneck.CUB.dataset import load_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Impact of Bounding Box Size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Small Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_objects = 1\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small3'\n",
    "optimizer = 'sgd'\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_file = \"results/ablation/bounding_box.json\"\n",
    "\n",
    "# Store the run configuration alongside the measurements that later cells add.\n",
    "results_dict = dict(\n",
    "    num_objects=num_objects,\n",
    "    noisy=noisy,\n",
    "    weight_decay=weight_decay,\n",
    "    encoder_model=encoder_model,\n",
    "    optimizer=optimizer,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x14fa8c5ea6d0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed NumPy's global RNG and PyTorch's default generator with the configured seed.\n",
    "# torch.manual_seed returns the generator object, which is what this cell displays.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train/validation loaders and pickle records for this configuration.\n",
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "# Unroll the validation loader into three objects; judging by the names these are\n",
    "# images, labels and concepts - TODO confirm against unroll_data.\n",
    "val_images, val_y, val_c = unroll_data(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the joint model that matches the configuration above.\n",
    "joint_model = get_synthetic_model(\n",
    "    num_objects,\n",
    "    encoder_model,\n",
    "    noisy,\n",
    "    weight_decay,\n",
    "    optimizer,\n",
    "    seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each epsilon value to the list of per-concept activations measured for it.\n",
    "all_epsilons = {}\n",
    "\n",
    "# Sweep the epsilon argument of get_valid_image_function over 0, 10, ..., 50.\n",
    "for epsilon in range(0, 51, 10):\n",
    "    activation_values = []\n",
    "\n",
    "    for concept_num in range(num_objects * 2):\n",
    "        valid_image_fn = get_valid_image_function(concept_num, num_objects, epsilon=epsilon)\n",
    "        ret_image = get_maximal_activation(joint_model, run_joint_model, concept_num, valid_image_fn)\n",
    "\n",
    "        # Element [1] of the model output is passed through a sigmoid; the entry for\n",
    "        # this concept is then stored as a NumPy value.\n",
    "        concept_output = run_joint_model(joint_model, ret_image)[1]\n",
    "        activation_values.append(\n",
    "            torch.nn.Sigmoid()(concept_output)[concept_num][0].detach().numpy()\n",
    "        )\n",
    "\n",
    "    all_epsilons[epsilon] = activation_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stored by reference: later in-place changes to this dict are visible via all_epsilons too.\n",
    "results_dict['epsilon_adversarial'] = all_epsilons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the NumPy scalars collected per epsilon into plain Python floats so that\n",
    "# the dictionary can be serialised with json. Only existing keys are reassigned,\n",
    "# which is safe while iterating over the dictionary.\n",
    "for epsilon, activations in results_dict['epsilon_adversarial'].items():\n",
    "    results_dict['epsilon_adversarial'][epsilon] = np.array(activations).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed deterministically.\n",
    "with open(json_file, \"w\") as results_file:\n",
    "    json.dump(results_dict, results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'num_objects': 1,\n",
       " 'noisy': False,\n",
       " 'weight_decay': 0.0004,\n",
       " 'encoder_model': 'small3',\n",
       " 'optimizer': 'sgd',\n",
       " 'seed': 42,\n",
       " 'epsilon_adversarial': {0: [1.0, 1.0],\n",
       "  10: [1.0, 1.0],\n",
       "  20: [1.0, 0.07281245291233063],\n",
       "  30: [1.0, 0.0015784641727805138],\n",
       "  40: [1.0, 1.1123177500849124e-05],\n",
       "  50: [1.0, 5.825044269158752e-08]}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Large Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_objects = 1\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small7'\n",
    "optimizer = 'sgd'\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_file = \"results/ablation/bounding_box_large.json\"\n",
    "\n",
    "# Store the run configuration alongside the measurements that later cells add.\n",
    "results_dict = dict(\n",
    "    num_objects=num_objects,\n",
    "    noisy=noisy,\n",
    "    weight_decay=weight_decay,\n",
    "    encoder_model=encoder_model,\n",
    "    optimizer=optimizer,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1553ec05e6b0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed NumPy's global RNG and PyTorch's default generator with the configured seed.\n",
    "# torch.manual_seed returns the generator object, which is what this cell displays.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train/validation loaders and pickle records for this configuration.\n",
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "# Unroll the validation loader into three objects; judging by the names these are\n",
    "# images, labels and concepts - TODO confirm against unroll_data.\n",
    "val_images, val_y, val_c = unroll_data(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the joint model that matches the configuration above.\n",
    "joint_model = get_synthetic_model(\n",
    "    num_objects,\n",
    "    encoder_model,\n",
    "    noisy,\n",
    "    weight_decay,\n",
    "    optimizer,\n",
    "    seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each epsilon value to the list of per-concept activations measured for it.\n",
    "all_epsilons = {}\n",
    "\n",
    "# Sweep the epsilon argument of get_valid_image_function over 0, 10, ..., 50.\n",
    "for epsilon in range(0, 51, 10):\n",
    "    activation_values = []\n",
    "\n",
    "    for concept_num in range(num_objects * 2):\n",
    "        valid_image_fn = get_valid_image_function(concept_num, num_objects, epsilon=epsilon)\n",
    "        ret_image = get_maximal_activation(joint_model, run_joint_model, concept_num, valid_image_fn)\n",
    "\n",
    "        # Element [1] of the model output is passed through a sigmoid; the entry for\n",
    "        # this concept is then stored as a NumPy value.\n",
    "        concept_output = run_joint_model(joint_model, ret_image)[1]\n",
    "        activation_values.append(\n",
    "            torch.nn.Sigmoid()(concept_output)[concept_num][0].detach().numpy()\n",
    "        )\n",
    "\n",
    "    all_epsilons[epsilon] = activation_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stored by reference: later in-place changes to this dict are visible via all_epsilons too.\n",
    "results_dict['epsilon_adversarial'] = all_epsilons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the NumPy scalars collected per epsilon into plain Python floats so that\n",
    "# the dictionary can be serialised with json. Only existing keys are reassigned,\n",
    "# which is safe while iterating over the dictionary.\n",
    "for epsilon, activations in results_dict['epsilon_adversarial'].items():\n",
    "    results_dict['epsilon_adversarial'][epsilon] = np.array(activations).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed deterministically.\n",
    "with open(json_file, \"w\") as results_file:\n",
    "    json.dump(results_dict, results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'num_objects': 1,\n",
       " 'noisy': False,\n",
       " 'weight_decay': 0.0004,\n",
       " 'encoder_model': 'small7',\n",
       " 'optimizer': 'sgd',\n",
       " 'seed': 42,\n",
       " 'epsilon_adversarial': {0: [1.0, 1.0],\n",
       "  10: [1.0, 1.0],\n",
       "  20: [1.0, 1.0],\n",
       "  30: [1.0, 1.0],\n",
       "  40: [1.0, 1.0],\n",
       "  50: [1.0, 1.0]}}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pruning Ablation Study"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weight Pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_objects = 1\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small7'\n",
    "optimizer = 'sgd'\n",
    "seed = 44\n",
    "\n",
    "results_json = \"results/ablation/pruning_{}.json\".format(seed)\n",
    "results = {\n",
    "    'layer': {},\n",
    "    'weight': {},\n",
    "    'num_objects': num_objects, \n",
    "    'noisy': noisy,\n",
    "    'weight_decay': weight_decay,\n",
    "    'encoder_model': encoder_model, \n",
    "    'optimizer': optimizer, \n",
    "    'seed': seed,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1464509766f0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed NumPy's global RNG and PyTorch's default generator with the configured seed.\n",
    "# torch.manual_seed returns the generator object, which is what this cell displays.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train/validation loaders and pickle records for this configuration.\n",
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "# Unroll the validation loader into three objects; judging by the names these are\n",
    "# images, labels and concepts - TODO confirm against unroll_data.\n",
    "val_images, val_y, val_c = unroll_data(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_color = torch.Tensor([0.25,0.25,0.25])\n",
    "full_color = torch.Tensor([-0.25,-0.25,0.25])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Broadcast each colour over a 3x256x256 image (expand creates views, not copies).\n",
    "blank_image = no_color.clone().view(3, 1, 1).expand((3, 256, 256))\n",
    "full_image = full_color.clone().view(3, 1, 1).expand_as(blank_image)\n",
    "\n",
    "# Writable copies of the blank image with columns 0-127 (half_left) or 128-255\n",
    "# (half_right) overwritten by full_color.\n",
    "half_left = blank_image.clone()\n",
    "half_left[:, :, :128] = full_image[:, :, :128]\n",
    "\n",
    "half_right = blank_image.clone()\n",
    "half_right[:, :, 128:] = full_image[:, :, 128:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "for prune_rate in [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]:\n",
    "    # get_synthetic_model is called once per rate; presumably it returns a fresh,\n",
    "    # unpruned model each time - TODO confirm.\n",
    "    joint_model = get_synthetic_model(num_objects, encoder_model, noisy, weight_decay, optimizer, seed)\n",
    "\n",
    "    # L1-unstructured pruning of the weights of every direct Conv2d/Linear child of\n",
    "    # joint_model.first_model (children() does not descend into nested modules).\n",
    "    for layer in joint_model.first_model.children():\n",
    "        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):\n",
    "            torch.nn.utils.prune.l1_unstructured(layer, name=\"weight\", amount=prune_rate)\n",
    "\n",
    "    activation_values = []\n",
    "\n",
    "    for concept_num in range(num_objects * 2):\n",
    "        valid_image_fn = get_valid_image_function(concept_num, num_objects, epsilon=32)\n",
    "        ret_image = get_maximal_activation(joint_model, run_joint_model, concept_num, valid_image_fn)\n",
    "\n",
    "        # Sigmoid of element [1] of the model output, entry for this concept.\n",
    "        concept_output = run_joint_model(joint_model, ret_image)[1]\n",
    "        activation_values.append(\n",
    "            torch.nn.Sigmoid()(concept_output)[concept_num][0].detach().numpy()\n",
    "        )\n",
    "\n",
    "    # NOTE(review): per-concept accuracy is measured on train_loader, not val_loader.\n",
    "    concept_accuracies = get_concept_accuracy_by_concept(joint_model, run_joint_model, train_loader, sigmoid=True)\n",
    "\n",
    "    results['weight'][prune_rate] = {\n",
    "        'activations': np.array(activation_values).tolist(),\n",
    "        'accuracies': np.array(concept_accuracies.detach().numpy()).tolist(),\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Layer Pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "for prune_rate in [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]:\n",
    "    # get_synthetic_model is called once per rate; presumably it returns a fresh,\n",
    "    # unpruned model each time - TODO confirm.\n",
    "    joint_model = get_synthetic_model(num_objects, encoder_model, noisy, weight_decay, optimizer, seed)\n",
    "\n",
    "    for layer_to_prune in [joint_model.first_model.conv4,\n",
    "                           joint_model.first_model.conv5,\n",
    "                           joint_model.first_model.conv6,\n",
    "                           joint_model.first_model.conv7]:\n",
    "\n",
    "        # Importance of an output filter = sum of the absolute values of its weights.\n",
    "        importance = layer_to_prune.weight.data.abs().sum((1, 2, 3))\n",
    "        num_filters = layer_to_prune.weight.size(0)\n",
    "\n",
    "        # Number of filters to zero out (rounded down).\n",
    "        num_prune = int(num_filters * prune_rate)\n",
    "\n",
    "        # Take the num_prune least important filters from the tail of the descending\n",
    "        # order. The slice start is explicit because indices[-0:] would select EVERY\n",
    "        # filter whenever num_prune rounds down to 0.\n",
    "        _, indices = importance.sort(descending=True)\n",
    "        indices_to_prune = indices[num_filters - num_prune:]\n",
    "\n",
    "        # Zero the selected filters' weights in place; biases are left untouched.\n",
    "        mask = torch.ones(num_filters, device=layer_to_prune.weight.device)\n",
    "        mask[indices_to_prune] = 0\n",
    "        layer_to_prune.weight.data *= mask.view(-1, 1, 1, 1)\n",
    "\n",
    "    activation_values = []\n",
    "\n",
    "    for concept_num in range(num_objects * 2):\n",
    "        valid_image_fn = get_valid_image_function(concept_num, num_objects, epsilon=32)\n",
    "        ret_image = get_maximal_activation(joint_model, run_joint_model, concept_num, valid_image_fn)\n",
    "\n",
    "        # Sigmoid of element [1] of the model output, entry for this concept.\n",
    "        concept_output = run_joint_model(joint_model, ret_image)[1]\n",
    "        activation_values.append(\n",
    "            torch.nn.Sigmoid()(concept_output)[concept_num][0].detach().numpy()\n",
    "        )\n",
    "\n",
    "    # NOTE(review): per-concept accuracy is measured on train_loader, not val_loader.\n",
    "    concept_accuracies = get_concept_accuracy_by_concept(joint_model, run_joint_model, train_loader, sigmoid=True)\n",
    "\n",
    "    results['layer'][prune_rate] = {\n",
    "        'activations': np.array(activation_values).tolist(),\n",
    "        'accuracies': np.array(concept_accuracies.detach().numpy()).tolist(),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed deterministically.\n",
    "with open(results_json, 'w') as results_file:\n",
    "    json.dump(results, results_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Independent Spatial Locality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_objects = 1\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small7'\n",
    "optimizer = 'sgd'\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train/validation loaders and pickle records for this configuration.\n",
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "# Unroll the validation loader into three objects; judging by the names these are\n",
    "# images, labels and concepts - TODO confirm against unroll_data.\n",
    "val_images, val_y, val_c = unroll_data(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per seed, each holding the per-concept activations for that seed.\n",
    "all_activation_values = []\n",
    "\n",
    "# NOTE: the loop variable rebinds the notebook-level `seed`; it is left at 44 afterwards.\n",
    "for seed in [42, 43, 44]:\n",
    "\n",
    "    independent_model_encoder = get_independent_encoder(num_objects, encoder_model, noisy, weight_decay, optimizer, seed)\n",
    "    independent_model_decoder = get_independent_decoder(num_objects, encoder_model, noisy, weight_decay, optimizer, seed)\n",
    "    # The helpers take the encoder/decoder pair as a two-element list.\n",
    "    independent_model = [independent_model_encoder, independent_model_decoder]\n",
    "\n",
    "    activation_values = []\n",
    "\n",
    "    for concept_num in range(num_objects * 2):\n",
    "        valid_image_fn = get_valid_image_function(concept_num, num_objects, epsilon=32)\n",
    "        ret_image = get_maximal_activation(independent_model, run_independent_model, concept_num, valid_image_fn)\n",
    "\n",
    "        # Sigmoid of element [1] of the model output, entry for this concept.\n",
    "        concept_output = run_independent_model(independent_model, ret_image)[1]\n",
    "        activation_values.append(\n",
    "            torch.nn.Sigmoid()(concept_output)[concept_num][0].detach().numpy()\n",
    "        )\n",
    "\n",
    "    all_activation_values.append(activation_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip through a NumPy array to turn the nested NumPy values into plain\n",
    "# Python lists that json can serialise.\n",
    "all_activation_values = np.array(all_activation_values).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed deterministically.\n",
    "with open(\"results/ablation/independent_spatial.json\", \"w\") as results_file:\n",
    "    json.dump({'spatial_locality': all_activation_values}, results_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Independent Semantic Locality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'dsprites_20'\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small3'\n",
    "optimizer = 'sgd'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'../cem/cem/dsprites_20/preprocessed/'"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_name = dataset\n",
    "# Directory of the preprocessed splits, relative to the current working directory.\n",
    "data_dir = \"../cem/cem/{}/preprocessed/\".format(dataset_name)\n",
    "data_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_path = os.path.join(data_dir, 'train.pkl')\n",
    "# The validation split sits next to the training split.\n",
    "val_data_path = train_data_path.replace('train.pkl', 'val.pkl')\n",
    "# NOTE: extra.pkl is read from the 'dsprites' directory, not from data_dir.\n",
    "extra_data_path = '../cem/cem/dsprites/preprocessed/extra.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained = True\n",
    "freeze = False\n",
    "use_aux = True\n",
    "expand_dim = 0\n",
    "three_class = False\n",
    "use_attr = True\n",
    "no_img = False\n",
    "batch_size = 64\n",
    "uncertain_labels = False\n",
    "image_dir = 'images'\n",
    "num_class_attr = 2\n",
    "resampling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every loader rewrites record paths so that they resolve below ../cem/cem/.\n",
    "# NOTE: the training loader is created with is_training=False.\n",
    "train_loader = load_data(\n",
    "    [train_data_path], use_attr, no_img, batch_size, uncertain_labels,\n",
    "    image_dir=image_dir,\n",
    "    n_class_attr=num_class_attr,\n",
    "    resampling=resampling,\n",
    "    path_transform=lambda path: \"../cem/cem/\"+path,\n",
    "    is_training=False,\n",
    ")\n",
    "val_loader = load_data(\n",
    "    [val_data_path], use_attr,\n",
    "    no_img=False,\n",
    "    batch_size=64,\n",
    "    image_dir=image_dir,\n",
    "    n_class_attr=num_class_attr,\n",
    "    path_transform=lambda path: \"../cem/cem/\"+path,\n",
    ")\n",
    "extra_loader = load_data(\n",
    "    [extra_data_path], use_attr,\n",
    "    no_img=False,\n",
    "    batch_size=64,\n",
    "    image_dir=image_dir,\n",
    "    n_class_attr=num_class_attr,\n",
    "    path_transform=lambda path: \"../cem/cem/\"+path,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the raw records of each split. Context managers close the files right away.\n",
    "# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.\n",
    "with open(train_data_path, \"rb\") as train_file:\n",
    "    train_pkl = pickle.load(train_file)\n",
    "with open(val_data_path, \"rb\") as val_file:\n",
    "    val_pkl = pickle.load(val_file)\n",
    "with open(extra_data_path, \"rb\") as extra_file:\n",
    "    extra_pkl = pickle.load(extra_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unroll the extra loader into three objects; judging by the names these are\n",
    "# images, labels and concepts - TODO confirm against unroll_data.\n",
    "extra_images, extra_y, extra_c = unroll_data(extra_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the 'attribute_label' entry of every extra record into one array (one row per record).\n",
    "attributes_as_matrix = np.array([record['attribute_label'] for record in extra_pkl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collects one exact-match accuracy per seed; filled by the loop in the next cell.\n",
    "accuracy_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset here as well so that re-running this cell never appends duplicate entries.\n",
    "accuracy_list = []\n",
    "\n",
    "for seed in [42, 43, 44]:\n",
    "    log_folder = \"results/dsprites_20/independent_model_small3\"\n",
    "\n",
    "    # Load the saved encoder and decoder for this seed on the CPU, in eval mode.\n",
    "    # NOTE: torch.load unpickles the file; only load trusted checkpoints.\n",
    "    independent_location = \"ConceptBottleneck/{}/concept/best_model_{}.pth\".format(log_folder, seed)\n",
    "    independent_encoder = torch.load(independent_location, map_location=torch.device('cpu')).eval()\n",
    "    independent_location = \"ConceptBottleneck/{}/bottleneck/best_model_{}.pth\".format(log_folder, seed)\n",
    "    independent_decoder = torch.load(independent_location, map_location=torch.device('cpu')).eval()\n",
    "\n",
    "    # Second output of the model, transposed and passed through a sigmoid; each row\n",
    "    # is then rounded to integers and encoded with list_to_string.\n",
    "    _, extra_predictions = run_independent_model([independent_encoder, independent_decoder], extra_images)\n",
    "    extra_predictions = torch.nn.Sigmoid()(extra_predictions.T)\n",
    "\n",
    "    all_predictions = [\n",
    "        list_to_string([int(value) for value in np.round(row.detach().numpy())])\n",
    "        for row in extra_predictions\n",
    "    ]\n",
    "    correct_answers = [list_to_string(row) for row in attributes_as_matrix]\n",
    "\n",
    "    # A sample counts as correct only when its whole encoded prediction matches.\n",
    "    num_exact = sum(correct_answers[i] == all_predictions[i] for i in range(len(correct_answers)))\n",
    "    overall_accuracy = num_exact / len(all_predictions)\n",
    "    accuracy_list.append(overall_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.08333333333333333, 0.06944444444444445, 0.08333333333333333]"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed deterministically.\n",
    "with open('results/ablation/independent_semantic.json', 'w') as results_file:\n",
    "    json.dump({'all_accuracies': accuracy_list}, results_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cem",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
