{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/')\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import cv2\n",
    "from copy import copy \n",
    "import itertools\n",
    "import json\n",
    "import argparse \n",
    "import secrets\n",
    "import subprocess\n",
    "import shutil \n",
    "from torch.nn.utils import prune\n",
    "import resource "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'seed': 42, 'encoder_model': 'small7', 'debugging': False, 'dataset_name': 'synthetic_object/synthetic_4', 'num_concept_combinations': 8}\n"
     ]
    }
   ],
   "source": [
    "is_jupyter = 'ipykernel' XXXX-20sys.modules\n",
    "if is_jupyter:\n",
    "    encoder_model='small7'\n",
    "    seed = 42\n",
    "    num_concept_combinations = 8\n",
    "\n",
    "    num_objects = 4\n",
    "else:\n",
    "    parser = argparse.ArgumentParser(description=\"Synthetic Dataset Experiments\")\n",
    "\n",
    "\n",
    "    parser.add_argument('--encoder_model', type=str, default='small3', help='Encoder model')\n",
    "    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n",
    "    parser.add_argument('--num_concept_combinations', type=int, default=1, help='Number of concept combinations')\n",
    "    parser.add_argument('--num_objects', type=int, default=4, help='Number of objects/which synthetic dataset')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "    encoder_model = args.encoder_model \n",
    "    seed = args.seed \n",
    "    num_concept_combinations = args.num_concept_combinations \n",
    "    num_objects = args.num_objects\n",
    "\n",
    "dataset_name = \"synthetic_object/synthetic_{}\".format(num_objects)\n",
    "\n",
    "parameters = {\n",
    "    'seed': seed, \n",
    "    'encoder_model': encoder_model ,\n",
    "    'debugging': False,\n",
    "    'dataset_name': dataset_name,\n",
    "    'num_concept_combinations': num_concept_combinations\n",
    "}\n",
    "print(parameters)\n",
    "torch.cuda.set_per_process_memory_fraction(0.5)\n",
    "resource.setrlimit(resource.RLIMIT_AS, (20 * 1024 * 1024 * 1024, -1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, val_loader, test_loader, train_pkl, val_pkl, test_pkl = get_data(num_objects,encoder_model=encoder_model,dataset_name=dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_images, test_y, test_c = unroll_data(test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_combinations(L, K):\n",
    "    random.seed(seed)    \n",
    "    # Generate all possible combinations\n",
    "    all_combinations = list(itertools.product([0, 1], repeat=L))\n",
    "    random.shuffle(all_combinations)\n",
    "\n",
    "    return all_combinations[:K]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_combinations = generate_random_combinations(num_objects,num_concept_combinations)\n",
    "random_full_combinations = []\n",
    "for c XXXX-20random_combinations:\n",
    "    random_full_combinations.append([])\n",
    "    for d XXXX-20c:\n",
    "        random_full_combinations[-1].append(d)\n",
    "        random_full_combinations[-1].append(1-d)\n",
    "formatted_combinations = []\n",
    "for r XXXX-20random_full_combinations:\n",
    "    formatted_combinations.append(str(int(\"\".join([str(i) for i XXXX-20r]),2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "command_to_run = \"python train_cbm.py -dataset {} -epochs {} -num_attributes {} --encoder_model {} -num_classes 2 -seed {} --concept_restriction {}\".format(dataset_name,50,num_objects*2,encoder_model,seed,\" \".join(formatted_combinations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files to delete are ['0f361857']\n",
      "Namespace(adversarial_epsilon=0.01, adversarial_weight=0.25, attr_loss_weight=1.0, batch_size=32, bottleneck=False, ckpt='0', concept_restriction=[106, 150, 102, 105, 169, 153, 165, 149], connect_CY=False, data_dir='../../../datasets/synthetic_object/synthetic_4/preprocessed', dataset='cub', encoder_model='small7', end2end=True, epochs=50, exp='Joint', expand_dim=0, expand_dim_encoder=0, experiment_name='CUB', freeze=False, image_dir='images', load_model='none', log_dir='../models/synthetic_object/synthetic_4/c2ffe863/joint', lr=0.05, mask_loss_weight=1.0, n_attributes=8, n_class_attr=2, no_img=False, normalize_loss=True, num_classes=2, num_middle_encoder=0, one_batch=False, optimizer='sgd', pretrained=False, resampling=False, save_step=1000, scale_factor=1.5, scale_lr=5, scheduler='none', scheduler_step=30, seed=42, three_class=False, train_addition='', train_variation='none', uncertain_labels=False, use_attr=True, use_aux=True, use_relu=False, use_residual=False, use_sigmoid=True, use_unknown=False, weight_decay=0.0004, weighted_loss='multiple')\n",
      "[1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Stop epoch:  60\n",
      "train data path: ../../../datasets/synthetic_object/synthetic_4/preprocessed/train.pkl\n",
      "Concepts to binary [[0, 1, 1, 0, 1, 0, 1, 0], [1, 0, 0, 1, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 1, 0], [0, 1, 1, 0, 1, 0, 0, 1], [1, 0, 1, 0, 1, 0, 0, 1], [1, 0, 0, 1, 1, 0, 0, 1], [1, 0, 1, 0, 0, 1, 0, 1], [1, 0, 0, 1, 0, 1, 0, 1]]\n",
      "Dataset length is 530\n",
      "On epoch 0\n",
      "New model best model at epoch 0\n",
      "Epoch [0]:\tTrain loss: 0.9532\tTrain accuracy: 75.1953\tTrain concept accuracy: 59.8633\tTrain concept auc: 0.6584\tVal loss: 0.6872\tVal acc: 66.9922\tVal concept acc: 50.8789\tVal concept auc: 0.5598\tBest val epoch: 0\n",
      "Current lr: [0.05]\n",
      "On epoch 1\n",
      "Epoch [1]:\tTrain loss: 0.9302\tTrain accuracy: 74.6094\tTrain concept accuracy: 60.3027\tTrain concept auc: 0.6887\tVal loss: 0.6960\tVal acc: 66.9922\tVal concept acc: 50.8789\tVal concept auc: 0.5677\tBest val epoch: 0\n",
      "On epoch 2\n",
      "Epoch [2]:\tTrain loss: 0.9173\tTrain accuracy: 75.0000\tTrain concept accuracy: 60.3760\tTrain concept auc: 0.6926\tVal loss: 0.7081\tVal acc: 66.9922\tVal concept acc: 50.4395\tVal concept auc: 0.5675\tBest val epoch: 0\n",
      "On epoch 3\n",
      "Epoch [3]:\tTrain loss: 0.9159\tTrain accuracy: 75.1953\tTrain concept accuracy: 59.7900\tTrain concept auc: 0.6861\tVal loss: 0.7124\tVal acc: 66.9922\tVal concept acc: 53.2471\tVal concept auc: 0.5655\tBest val epoch: 0\n",
      "On epoch 4\n",
      "Epoch [4]:\tTrain loss: 0.9126\tTrain accuracy: 75.5859\tTrain concept accuracy: 60.9863\tTrain concept auc: 0.6923\tVal loss: 0.7140\tVal acc: 66.9922\tVal concept acc: 50.8789\tVal concept auc: 0.5666\tBest val epoch: 0\n",
      "On epoch 5\n",
      "Epoch [5]:\tTrain loss: 0.9126\tTrain accuracy: 74.4141\tTrain concept accuracy: 60.5957\tTrain concept auc: 0.6949\tVal loss: 0.7090\tVal acc: 66.9922\tVal concept acc: 50.8789\tVal concept auc: 0.5659\tBest val epoch: 0\n",
      "On epoch 6\n",
      "Epoch [6]:\tTrain loss: 0.9127\tTrain accuracy: 74.4141\tTrain concept accuracy: 61.4014\tTrain concept auc: 0.6930\tVal loss: 0.7104\tVal acc: 66.9922\tVal concept acc: 56.3477\tVal concept auc: 0.5744\tBest val epoch: 0\n",
      "On epoch 7\n",
      "Epoch [7]:\tTrain loss: 0.9107\tTrain accuracy: 74.4141\tTrain concept accuracy: 65.6738\tTrain concept auc: 0.7088\tVal loss: 0.7087\tVal acc: 66.9922\tVal concept acc: 58.7402\tVal concept auc: 0.5800\tBest val epoch: 0\n",
      "On epoch 8\n",
      "Epoch [8]:\tTrain loss: 0.9061\tTrain accuracy: 75.0000\tTrain concept accuracy: 67.4072\tTrain concept auc: 0.7244\tVal loss: 0.7089\tVal acc: 66.9922\tVal concept acc: 59.6191\tVal concept auc: 0.6021\tBest val epoch: 0\n",
      "On epoch 9\n",
      "Epoch [9]:\tTrain loss: 0.8367\tTrain accuracy: 74.8047\tTrain concept accuracy: 69.1162\tTrain concept auc: 0.7611\tVal loss: 0.4848\tVal acc: 66.9922\tVal concept acc: 83.4961\tVal concept auc: 0.9036\tBest val epoch: 0\n",
      "On epoch 10\n",
      "Epoch [10]:\tTrain loss: 0.1221\tTrain accuracy: 74.4141\tTrain concept accuracy: 99.6094\tTrain concept auc: 0.9997\tVal loss: 2.7445\tVal acc: 66.9922\tVal concept acc: 91.1377\tVal concept auc: 0.9462\tBest val epoch: 0\n",
      "Current lr: [0.05]\n",
      "On epoch 11\n",
      "New model best model at epoch 11\n",
      "Epoch [11]:\tTrain loss: 0.0604\tTrain accuracy: 75.3906\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.5068\tVal acc: 73.6328\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 11\n",
      "On epoch 12\n",
      "Epoch [12]:\tTrain loss: 0.0554\tTrain accuracy: 86.5234\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.7906\tVal acc: 73.6328\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 11\n",
      "On epoch 13\n",
      "Epoch [13]:\tTrain loss: 0.0500\tTrain accuracy: 86.9141\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.7459\tVal acc: 73.6328\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 11\n",
      "On epoch 14\n",
      "New model best model at epoch 14\n",
      "Epoch [14]:\tTrain loss: 0.0469\tTrain accuracy: 91.7969\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.6414\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 14\n",
      "On epoch 15\n",
      "Epoch [15]:\tTrain loss: 0.0438\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.5289\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 14\n",
      "On epoch 16\n",
      "Epoch [16]:\tTrain loss: 0.0413\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.4173\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9429\tBest val epoch: 14\n",
      "On epoch 17\n",
      "Epoch [17]:\tTrain loss: 0.0395\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.3073\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 18\n",
      "Epoch [18]:\tTrain loss: 0.0375\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.2006\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 19\n",
      "Epoch [19]:\tTrain loss: 0.0356\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 4.0967\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 20\n",
      "Epoch [20]:\tTrain loss: 0.0339\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.9953\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "Current lr: [0.05]\n",
      "On epoch 21\n",
      "Epoch [21]:\tTrain loss: 0.0323\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.8967\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 22\n",
      "Epoch [22]:\tTrain loss: 0.0314\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.8004\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 23\n",
      "Epoch [23]:\tTrain loss: 0.0296\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.7065\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch XXXX-45\n",
      "Epoch [XXXX-45]:\tTrain loss: 0.0284\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.6149\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 25\n",
      "Epoch [25]:\tTrain loss: 0.0270\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.5260\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 26\n",
      "Epoch [26]:\tTrain loss: 0.0264\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.4389\tVal acc: 80.0781\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 14\n",
      "On epoch 27\n",
      "New model best model at epoch 27\n",
      "Epoch [27]:\tTrain loss: 0.0256\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.3542\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9509\tBest val epoch: 27\n",
      "On epoch 28\n",
      "Epoch [28]:\tTrain loss: 0.0247\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.2716\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "On epoch 29\n",
      "Epoch [29]:\tTrain loss: 0.0236\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.1912\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "On epoch 30\n",
      "Epoch [30]:\tTrain loss: 0.0228\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.1128\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "Current lr: [0.05]\n",
      "On epoch 31\n",
      "Epoch [31]:\tTrain loss: 0.0224\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 3.0362\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "On epoch 32\n",
      "Epoch [32]:\tTrain loss: 0.0216\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.9616\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "On epoch 33\n",
      "Epoch [33]:\tTrain loss: 0.0210\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.8890\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9502\tBest val epoch: 27\n",
      "On epoch 34\n",
      "Epoch [34]:\tTrain loss: 0.0206\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.8180\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 35\n",
      "Epoch [35]:\tTrain loss: 0.0200\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.7490\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 36\n",
      "Epoch [36]:\tTrain loss: 0.0195\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.6816\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 37\n",
      "Epoch [37]:\tTrain loss: 0.0190\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.6160\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 38\n",
      "Epoch [38]:\tTrain loss: 0.0186\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.5520\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 39\n",
      "Epoch [39]:\tTrain loss: 0.0180\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.4895\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 40\n",
      "Epoch [40]:\tTrain loss: 0.0177\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.4287\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "Current lr: [0.05]\n",
      "On epoch 41\n",
      "Epoch [41]:\tTrain loss: 0.0171\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.3694\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 42\n",
      "Epoch [42]:\tTrain loss: 0.0168\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.3116\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 43\n",
      "Epoch [43]:\tTrain loss: 0.0166\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.2553\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 44\n",
      "Epoch [44]:\tTrain loss: 0.0162\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.2003\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 45\n",
      "Epoch [45]:\tTrain loss: 0.0160\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.1467\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 46\n",
      "Epoch [46]:\tTrain loss: 0.0154\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.0944\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 47\n",
      "Epoch [47]:\tTrain loss: 0.0152\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 2.0435\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 48\n",
      "Epoch [48]:\tTrain loss: 0.0148\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 1.9938\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "On epoch 49\n",
      "Epoch [49]:\tTrain loss: 0.0144\tTrain accuracy: 100.0000\tTrain concept accuracy: 100.0000\tTrain concept auc: 1.0000\tVal loss: 1.9455\tVal acc: 87.6953\tVal concept acc: 91.1377\tVal concept auc: 0.9495\tBest val epoch: 27\n",
      "Saving the model again to ../models/synthetic_object/synthetic_4/c2ffe863/joint!\n",
      "wandb: Currently logged XXXX-20as: navr414. Use `wandb login --relogin` to force relogin\n",
      "wandb: wandb version 0.18.6 is available!  To upgrade, please run:\n",
      "wandb:  $ pip install wandb --upgrade\n",
      "wandb: Tracking run with wandb version 0.13.5\n",
      "wandb: Run data is saved locally XXXX-42 /usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/wandb/run-20241112_193139-31i132do\n",
      "wandb: Run `wandb offline` to turn off syncing.\n",
      "wandb: Syncing run toasty-fire-307\n",
      "wandb: ⭐️ View project at XXXX\n",
      "wandb: 🚀 View run at XXXX\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions XXXX-20buffer. For large datasets this may lead to large memory footprint.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:163: UserWarning: The epoch parameter XXXX-42 `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: XXXX.\n",
      "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:382: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  \"please use `get_last_lr()`.\", UserWarning)\n",
      "wandb: Waiting for W&B process to finish... (success).\n",
      "wandb: \n",
      "wandb: Run history:\n",
      "wandb: lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
      "wandb: \n",
      "wandb: Run summary:\n",
      "wandb: lr 0.05\n",
      "wandb: \n",
      "wandb: Synced toasty-fire-307: XXXX\n",
      "wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
      "wandb: Find logs at: ./wandb/run-20241112_193139-31i132do/logs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CompletedProcess(args='cd ../../ConceptBottleneck && python train_cbm.py -dataset synthetic_object/synthetic_4 -epochs 50 -num_attributes 8 --encoder_model small7 -num_classes 2 -seed 42 --concept_restriction 106 150 102 105 169 153 165 149', returncode=0)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subprocess.run(\"cd ../../ConceptBottleneck && {}\".format(command_to_run),shell=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_most_recent_file(directory):\n",
    "    files = [os.path.join(directory, f) for f XXXX-20os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n",
    "    if not files:\n",
    "        return None\n",
    "\n",
    "    most_recent_file = max(files, key=os.path.getmtime)\n",
    "    \n",
    "    return most_recent_file\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_recent_data = get_most_recent_file(\"../../models/model_data/\")\n",
    "rand_name = most_recent_data.split(\"/\")[-1].replace(\".json\",\"\")\n",
    "results_file = \"../../results/correlation/{}.json\".format(rand_name)\n",
    "delete_same_dict(parameters,\"../../results/correlation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_location = \"../../models/synthetic_object/synthetic_{}/{}/joint/best_model_{}.pth\".format(num_objects,rand_name,seed)\n",
    "joint_model = torch.load(joint_location,map_location='cpu')\n",
    "\n",
    "if 'encoder_model' XXXX-20parameters and 'mlp' XXXX-20parameters['encoder_model']:\n",
    "    joint_model.encoder_model = True\n",
    "\n",
    "r = joint_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_acc =  get_accuracy(joint_model,run_joint_model,train_loader)\n",
    "val_acc = get_accuracy(joint_model,run_joint_model,val_loader)\n",
    "test_acc =get_accuracy(joint_model,run_joint_model,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_values = []\n",
    "run_model_function = run_joint_model\n",
    "for concept_num XXXX-20range(num_objects*2):\n",
    "    val_for_concept = 0\n",
    "    trials = 5\n",
    "\n",
    "    for _ XXXX-20range(trials):\n",
    "        data_point = random.randint(0,len(test_images)-1)\n",
    "        input_image = deepcopy(test_images[data_point:data_point+1])\n",
    "        current_concept_val = test_c[data_point][concept_num]\n",
    "\n",
    "        ret_image = get_maximal_activation(joint_model,run_model_function,concept_num,\n",
    "                                        get_valid_image_function(concept_num,num_objects,epsilon=32),fixed_image=input_image,current_concept_val=current_concept_val).to(device)\n",
    "        predicted_concept = torch.nn.Sigmoid()(run_model_function(joint_model,ret_image)[1].detach().cpu())[concept_num][0].detach().numpy()\n",
    "        \n",
    "        val_for_concept += abs(predicted_concept-current_concept_val.detach().numpy())/trials \n",
    "    activation_values.append(val_for_concept)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_distro = 0\n",
    "correct_in_distro = 0 \n",
    "\n",
    "out_distro = 0\n",
    "correct_out_distro = 0\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "with torch.no_grad():  # Use torch.no_grad() to disable gradient computation\n",
    "\n",
    "    for data XXXX-20test_loader:\n",
    "        x, y, c = data\n",
    "        y_pred, c_pred = run_joint_model(joint_model, x.to(device))\n",
    "        c_pred = torch.stack([i.detach() for i XXXX-20c_pred])\n",
    "        c_pred = torch.nn.Sigmoid()(c_pred)\n",
    "\n",
    "        c_pred = c_pred.numpy().T\n",
    "        y_pred = logits_to_index(y_pred.detach())\n",
    "\n",
    "        c = torch.stack([i.detach() for i XXXX-20c]).numpy().T\n",
    "\n",
    "        in_distribution = []\n",
    "\n",
    "        for i XXXX-20range(len(c)):\n",
    "            binary_c = c[i]\n",
    "            combo = str(int(\"\".join([str(i) for i XXXX-20binary_c]),2))\n",
    "\n",
    "            if combo XXXX-20formatted_combinations:\n",
    "                in_distribution.append(True)\n",
    "            else:\n",
    "                in_distribution.append(False)\n",
    "        \n",
    "        in_distro += in_distribution.count(True) * len(c[0])\n",
    "        out_distro += in_distribution.count(False) * len(c[0])\n",
    "\n",
    "        in_distribution = np.array(in_distribution)\n",
    "\n",
    "        correct_in_distro += np.sum(np.clip(np.round(c_pred[in_distribution]),0,1) == c[in_distribution]).item() \n",
    "        correct_out_distro += np.sum(np.clip(np.round(c_pred[~in_distribution]),0,1) == c[~in_distribution]).item() \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_accuracies = []\n",
    "\n",
    "# Try and flip each concept\n",
    "for concept_num XXXX-20range(num_objects*2):\n",
    "# Set this concept_num to 1 (which sets the corresponding thing to 0)\n",
    "    total_flipped = 0\n",
    "    total_points = 0\n",
    "\n",
    "    with torch.no_grad():  # Use torch.no_grad() to disable gradient computation\n",
    "\n",
    "        for data XXXX-20test_loader:\n",
    "            x, y, c = data\n",
    "            y_pred, c_pred = run_joint_model(joint_model, x.to(device))\n",
    "            c_pred = torch.stack([i.detach() for i XXXX-20c_pred]).numpy().T\n",
    "            y_pred = logits_to_index(y_pred.detach())\n",
    "\n",
    "            c = torch.stack([i.detach() for i XXXX-20c]).numpy().T\n",
    "\n",
    "            in_distribution = []\n",
    "\n",
    "            for i XXXX-20range(len(c)):\n",
    "                # Just look for errors where binary_c = 1 XXXX-20prediction\n",
    "\n",
    "                binary_c = c[i]\n",
    "\n",
    "                if binary_c[concept_num] == 1:\n",
    "                    in_distribution.append(True)\n",
    "                else:\n",
    "                    in_distribution.append(False)\n",
    "            \n",
    "            in_distribution = np.array(in_distribution)\n",
    "            total_points += np.sum(in_distribution) \n",
    "            total_flipped += np.sum(np.clip(np.round(c_pred[in_distribution,concept_num]),0,1) == c[in_distribution,concept_num]) \n",
    "            \n",
    "    concept_accuracies.append(total_flipped/total_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data = {\n",
    "    'train_accuracy': train_acc, \n",
    "    'val_accuracy': val_acc, \n",
    "    'test_accuracy': test_acc, \n",
    "    'in_distro': correct_in_distro/in_distro, \n",
    "    'num_in_distro': in_distro, \n",
    "    'out_distro': correct_out_distro/out_distro, \n",
    "    'num_out_distro': out_distro, \n",
    "    'concept_accuracies': concept_accuracies,\n",
    "    'combinations': formatted_combinations,\n",
    "    'parameters': parameters,  \n",
    "    'adversarial_activations': activation_values\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train_accuracy': 0.8935546875,\n",
       " 'val_accuracy': 0.876953125,\n",
       " 'test_accuracy': 0.8515625,\n",
       " 'in_distro': 0.982976653696498,\n",
       " 'num_in_distro': 2056,\n",
       " 'out_distro': 0.8117647058823529,\n",
       " 'num_out_distro': 2040,\n",
       " 'concept_accuracies': [0.8537549407114624,\n",
       "  0.8803088803088803,\n",
       "  0.9964664310954063,\n",
       "  0.6899563318777293,\n",
       "  0.9958333333333333,\n",
       "  0.9852941176470589,\n",
       "  1.0,\n",
       "  0.9814814814814815],\n",
       " 'combinations': ['106', '150', '102', '105', '169', '153', '165', '149'],\n",
       " 'parameters': {'seed': 42,\n",
       "  'encoder_model': 'small7',\n",
       "  'debugging': False,\n",
       "  'dataset_name': 'synthetic_object/synthetic_4',\n",
       "  'num_concept_combinations': 8},\n",
       " 'adversarial_activations': [1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  0.9999999999740239,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0]}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(final_data,open(results_file,\"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concepts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
