{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/')\n",
    "sys.path.append('/usr0/home/naveenr/projects/spurious_concepts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import cv2\n",
    "from copy import copy \n",
    "import itertools\n",
    "import json\n",
    "import argparse \n",
    "import secrets\n",
    "import subprocess\n",
    "import shutil \n",
    "from torch.nn.utils import prune\n",
    "import resource "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'seed': 43, 'encoder_model': 'inceptionv3', 'debugging': False, 'dataset_name': 'CUB', 'num_concept_combinations': 200}\n"
     ]
    }
   ],
   "source": [
    "is_jupyter = 'ipykernel' XXXX-20sys.modules\n",
    "if is_jupyter:\n",
    "    encoder_model='inceptionv3'\n",
    "    seed = 43\n",
    "    dataset_name = 'CUB'\n",
    "    num_concept_combinations = 200\n",
    "else:\n",
    "    parser = argparse.ArgumentParser(description=\"Synthetic Dataset Experiments\")\n",
    "\n",
    "\n",
    "    parser.add_argument('--encoder_model', type=str, default='inceptionv3', help='Encoder model')\n",
    "    parser.add_argument('--seed', type=int, default=42, help='Random seed')\n",
    "    parser.add_argument('--dataset_name', type=str, default='CUB', help='Number of concept combinations')\n",
    "    parser.add_argument('--num_concept_combinations', type=int, default=100, help='Random seed')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "    encoder_model = args.encoder_model \n",
    "    seed = args.seed \n",
    "    dataset_name = args.dataset_name \n",
    "    num_concept_combinations = args.num_concept_combinations\n",
    "\n",
    "parameters = {\n",
    "    'seed': seed, \n",
    "    'encoder_model': encoder_model ,\n",
    "    'debugging': False,\n",
    "    'dataset_name': dataset_name,\n",
    "    'num_concept_combinations': num_concept_combinations,\n",
    "}\n",
    "print(parameters)\n",
    "torch.cuda.set_per_process_memory_fraction(0.5)\n",
    "resource.setrlimit(resource.RLIMIT_AS, (20 * 1024 * 1024 * 1024, -1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, val_loader, test_loader, train_pkl, val_pkl, test_pkl = get_data(1,encoder_model=encoder_model,dataset_name=dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_combinations(K):\n",
    "    all_combos = list(set([str(i['attribute_label']) for i XXXX-20train_pkl]))\n",
    "    random.seed(seed)    \n",
    "    # Generate all possible combinations\n",
    "    random.shuffle(all_combos)\n",
    "    all_combos = [eval(i) for i XXXX-20all_combos]\n",
    "\n",
    "    return all_combos[:K]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_combinations = generate_random_combinations(num_concept_combinations)\n",
    "formatted_combinations = []\n",
    "for r XXXX-20random_combinations:\n",
    "    formatted_combinations.append(str(int(\"\".join([str(i) for i XXXX-20r]),2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name == \"CUB\":\n",
    "    epochs = 100\n",
    "    command_to_run = \"python train_cbm.py -dataset CUB --encoder_model inceptionv3 --pretrained -epochs {} -num_attributes 112 -num_classes 200 -seed {} --attr_loss_weight 0.01 --optimizer adam --scheduler_step 100 -lr 0.005 --concept_restriction {}\".format(epochs,seed,\" \".join(formatted_combinations))\n",
    "elif dataset_name == \"coco\":\n",
    "    epochs = 25\n",
    "    command_to_run = \"python train_cbm.py -dataset coco --encoder_model inceptionv3 --pretrained -epochs {} -num_attributes 10 -num_classes 2 -seed {} --attr_loss_weight 0.1 --optimizer adam --scheduler_step 100 -lr 0.005 --concept_restriction {}\".format(epochs,seed,\" \".join(formatted_combinations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files to delete are []\n",
      "Namespace(adversarial_epsilon=0.01, adversarial_weight=0.25, attr_loss_weight=0.01, batch_size=32, bottleneck=False, ckpt='0', concept_restriction=[489434840942257862547355994169602, 20604902816464749788475737524264, 82417754025159411000739688890512, 20760269268464950908855439884288, 30904563943319113684455092652065, 731454243580909305755854349289601, 10460603154798514420549683513376, 2678707308143309534777359169882152, 407245753770509458926955160703040, 324518554867355428701590198976512, 649037726286876266790630683855368, 20604902816464173336528117057576, 669642011342243446517486512818216, 515082171979654095527797410766882, 489352823749794443838504306753796, 824016296030621995351173006106880, 406936308578464778494410333636609, 171212139289391367582460917385472, 669642010124481654523558334056520, 813951844719847969791594442262530, 669638296303146277227239087294504, 2760953143537429435283906260111362, 345118496100528533257654219325448, 20604902816464173336563013666849, 162260514920371258740860000405508, 649040207163522669616120931174416, 2305851805344497664, 190545365465799077102828663377952, 494467373811032338702970650626306, 1390774303232134339132778907566096, 818983453251511115459263129989378, 514748558092817067674346417366306], connect_CY=False, data_dir='../../../datasets/CUB/preprocessed', dataset='cub', encoder_model='inceptionv3', end2end=True, epochs=10, exp='Joint', expand_dim=0, expand_dim_encoder=0, experiment_name='CUB', freeze=False, image_dir='images', load_model='none', log_dir='../models/CUB/99c03dc5/joint', lr=0.005, mask_loss_weight=1.0, n_attributes=112, n_class_attr=2, no_img=False, normalize_loss=True, num_classes=200, num_middle_encoder=0, one_batch=False, optimizer='adam', pretrained=True, resampling=False, save_step=1000, scale_factor=1.5, scale_lr=5, scheduler='none', scheduler_step=100, seed=42, three_class=False, train_addition='', train_variation='none', uncertain_labels=False, use_attr=True, use_aux=True, use_relu=False, use_residual=False, use_sigmoid=True, use_unknown=False, weight_decay=0.0004, weighted_loss='multiple')\n",
      "[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "Stop epoch:  100\n",
      "train data path: ../../../datasets/CUB/preprocessed/train.pkl\n",
      "Concepts to binary [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]]\n",
      "Dataset length is 769\n",
      "On epoch 0\n",
      "New model best model at epoch 0\n",
      "Epoch [0]:\tTrain loss: 4.2440\tTrain accuracy: 1.9531\tTrain concept accuracy: 47.6969\tTrain concept auc: 0.4893\tVal loss: 3.0294\tVal acc: 0.6678\tVal concept acc: 49.9322\tVal concept auc: 0.5091\tBest val epoch: 0\n",
      "Current lr: [0.005]\n",
      "On epoch 1\n",
      "New model best model at epoch 1\n",
      "Epoch [1]:\tTrain loss: 3.9752\tTrain accuracy: 4.9479\tTrain concept accuracy: 50.5697\tTrain concept auc: 0.5143\tVal loss: 2.9615\tVal acc: 1.0851\tVal concept acc: 53.9553\tVal concept auc: 0.5413\tBest val epoch: 1\n",
      "On epoch 2\n",
      "New model best model at epoch 2\n",
      "Epoch [2]:\tTrain loss: 3.7350\tTrain accuracy: 6.2500\tTrain concept accuracy: 53.9074\tTrain concept auc: 0.5464\tVal loss: 2.9396\tVal acc: 1.3356\tVal concept acc: 57.3389\tVal concept auc: 0.5717\tBest val epoch: 2\n",
      "On epoch 3\n",
      "New model best model at epoch 3\n",
      "Epoch [3]:\tTrain loss: 3.5229\tTrain accuracy: 12.3698\tTrain concept accuracy: 56.5162\tTrain concept auc: 0.5790\tVal loss: 2.9319\tVal acc: 2.0868\tVal concept acc: 60.3394\tVal concept auc: 0.6061\tBest val epoch: 3\n",
      "On epoch 4\n",
      "New model best model at epoch 4\n",
      "Epoch [4]:\tTrain loss: 3.3508\tTrain accuracy: 18.3594\tTrain concept accuracy: 59.0169\tTrain concept auc: 0.6137\tVal loss: 2.9413\tVal acc: 3.0885\tVal concept acc: 62.0118\tVal concept auc: 0.6341\tBest val epoch: 4\n",
      "On epoch 5\n",
      "New model best model at epoch 5\n",
      "Epoch [5]:\tTrain loss: 3.1930\tTrain accuracy: 26.9531\tTrain concept accuracy: 59.8993\tTrain concept auc: 0.6398\tVal loss: 2.9618\tVal acc: 4.1736\tVal concept acc: 62.9584\tVal concept auc: 0.6561\tBest val epoch: 5\n",
      "On epoch 6\n",
      "New model best model at epoch 6\n",
      "Epoch [6]:\tTrain loss: 3.0555\tTrain accuracy: 41.5365\tTrain concept accuracy: 61.5572\tTrain concept auc: 0.6653\tVal loss: 2.9896\tVal acc: 6.1770\tVal concept acc: 63.7044\tVal concept auc: 0.6715\tBest val epoch: 6\n",
      "On epoch 7\n",
      "New model best model at epoch 7\n",
      "Epoch [7]:\tTrain loss: 2.9361\tTrain accuracy: 52.3438\tTrain concept accuracy: 61.9141\tTrain concept auc: 0.6799\tVal loss: 3.0189\tVal acc: 7.0952\tVal concept acc: 64.2186\tVal concept auc: 0.6852\tBest val epoch: 7\n",
      "On epoch 8\n",
      "New model best model at epoch 8\n",
      "Epoch [8]:\tTrain loss: 2.8325\tTrain accuracy: 60.2865\tTrain concept accuracy: 62.7825\tTrain concept auc: 0.7003\tVal loss: 3.0552\tVal acc: 7.9299\tVal concept acc: 64.3237\tVal concept auc: 0.6956\tBest val epoch: 8\n",
      "On epoch 9\n",
      "New model best model at epoch 9\n",
      "Epoch [9]:\tTrain loss: 2.7313\tTrain accuracy: 71.7448\tTrain concept accuracy: 63.4138\tTrain concept auc: 0.7131\tVal loss: 3.0831\tVal acc: 9.2654\tVal concept acc: 64.6293\tVal concept auc: 0.7031\tBest val epoch: 9\n",
      "Saving the model again to ../models/CUB/99c03dc5/joint!\n",
      "wandb: Currently logged XXXX-20as: navr414. Use `wandb login --relogin` to force relogin\n",
      "wandb: wandb version 0.19.1 is available!  To upgrade, please run:\n",
      "wandb:  $ pip install wandb --upgrade\n",
      "wandb: Tracking run with wandb version 0.13.5\n",
      "wandb: Run data is saved locally XXXX-42 /usr0/home/naveenr/projects/spurious_concepts/ConceptBottleneck/wandb/run-20241226_182703-1v8ivhpy\n",
      "wandb: Run `wandb offline` to turn off syncing.\n",
      "wandb: Syncing run playful-yogurt-68\n",
      "wandb: ⭐️ View project at XXXX\n",
      "wandb: 🚀 View run at XXXX\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions XXXX-20buffer. For large datasets this may lead to large memory footprint.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:163: UserWarning: The epoch parameter XXXX-42 `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: XXXX.\n",
      "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n",
      "/usr0/home/naveenr/miniconda3/envs/concepts_spurious/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:382: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  \"please use `get_last_lr()`.\", UserWarning)\n",
      "wandb: Waiting for W&B process to finish... (success).\n",
      "wandb: - 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\n",
      "wandb: \\ 0.005 MB of 0.021 MB uploaded (0.000 MB deduped)\n",
      "wandb: | 0.039 MB of 0.039 MB uploaded (0.000 MB deduped)\n",
      "wandb: \n",
      "wandb: Run history:\n",
      "wandb: lr ▁▁▁▁▁▁▁▁▁▁\n",
      "wandb: \n",
      "wandb: Run summary:\n",
      "wandb: lr 0.005\n",
      "wandb: \n",
      "wandb: Synced playful-yogurt-68: XXXX\n",
      "wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
      "wandb: Find logs at: ./wandb/run-20241226_182703-1v8ivhpy/logs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CompletedProcess(args='cd ../../ConceptBottleneck && python train_cbm.py -dataset CUB --encoder_model inceptionv3 --pretrained -epochs 10 -num_attributes 112 -num_classes 200 -seed 42 --attr_loss_weight 0.01 --optimizer adam --scheduler_step 100 -lr 0.005 --concept_restriction 489434840942257862547355994169602 20604902816464749788475737524264 82417754025159411000739688890512 20760269268464950908855439884288 30904563943319113684455092652065 731454243580909305755854349289601 10460603154798514420549683513376 2678707308143309534777359169882152 407245753770509458926955160703040 324518554867355428701590198976512 649037726286876266790630683855368 20604902816464173336528117057576 669642011342243446517486512818216 515082171979654095527797410766882 489352823749794443838504306753796 824016296030621995351173006106880 406936308578464778494410333636609 171212139289391367582460917385472 669642010124481654523558334056520 813951844719847969791594442262530 669638296303146277227239087294504 2760953143537429435283906260111362 345118496100528533257654219325448 20604902816464173336563013666849 162260514920371258740860000405508 649040207163522669616120931174416 2305851805344497664 190545365465799077102828663377952 494467373811032338702970650626306 1390774303232134339132778907566096 818983453251511115459263129989378 514748558092817067674346417366306', returncode=0)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subprocess.run(\"cd ../../ConceptBottleneck && {}\".format(command_to_run),shell=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_most_recent_file(directory):\n",
    "    files = [os.path.join(directory, f) for f XXXX-20os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n",
    "    if not files:\n",
    "        return None\n",
    "\n",
    "    most_recent_file = max(files, key=os.path.getmtime)\n",
    "    \n",
    "    return most_recent_file\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_recent_data = get_most_recent_file(\"../../models/model_data/\")\n",
    "rand_name = most_recent_data.split(\"/\")[-1].replace(\".json\",\"\")\n",
    "results_file = \"../../results/correlation/{}.json\".format(rand_name)\n",
    "delete_same_dict(parameters,\"../../results/correlation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_name = get_name_matching_parameters({'seed': seed, 'dataset': dataset_name})#, 'concept_restrictions': str([int(i) for i XXXX-20formatted_combinations])})\n",
    "temp = [json.load(open(\"../../models/model_data/{}.json\".format(i))) for i XXXX-20rand_name]\n",
    "rand_name = [rand_name[i] for i XXXX-20range(len(rand_name)) if 'concept_restriction' XXXX-20temp[i] and len(temp[i]['concept_restriction']) == num_concept_combinations]\n",
    "rand_name = rand_name[-1]\n",
    "results_file = \"../../results/correlation/{}.json\".format(rand_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_location = \"../../models/{}/{}/joint/best_model_{}.pth\".format(dataset_name,rand_name,seed)\n",
    "joint_model = torch.load(joint_location,map_location='cpu')\n",
    "\n",
    "if 'encoder_model' XXXX-20parameters and 'mlp' XXXX-20parameters['encoder_model']:\n",
    "    joint_model.encoder_model = True\n",
    "\n",
    "r = joint_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = joint_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_acc = get_concept_accuracy_by_concept(joint_model,run_joint_model,test_loader)\n",
    "locality_intervention = 1-torch.mean(concept_acc).detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data = {\n",
    "    'parameters': parameters,  \n",
    "    'locality_intervention': locality_intervention,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'parameters': {'seed': 43,\n",
       "  'encoder_model': 'inceptionv3',\n",
       "  'debugging': False,\n",
       "  'dataset_name': 'CUB',\n",
       "  'num_concept_combinations': 200},\n",
       " 'locality_intervention': 0.12926393747329712}"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(final_data,open(results_file,\"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concepts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
