{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/home/njr61/rds/hpc-work/spurious-concepts/ConceptBottleneck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from captum.attr import visualization as viz\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import cv2\n",
    "from copy import copy \n",
    "import itertools\n",
    "from matplotlib.patches import Circle\n",
    "import json\n",
    "from collections import Counter \n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score\n",
    "import argparse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ConceptBottleneck.CUB.models import ModelXtoC, ModelOracleCtoY\n",
    "from ConceptBottleneck.CUB.dataset import load_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up dataset + model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parser = argparse.ArgumentParser(description=\"Your script description here\")\n",
    "\n",
    "# # Add command-line arguments\n",
    "# parser.add_argument('--seed', type=int, default=42, help='Random seed')\n",
    "\n",
    "# # Parse the command-line arguments\n",
    "# args = parser.parse_args()\n",
    "\n",
    "# seed = args.seed\n",
    "\n",
    "seed = 42\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'dsprites_20'\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "encoder_model='small3'\n",
    "optimizer = 'sgd'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'../cem/cem/dsprites_20/preprocessed/'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_name = dataset\n",
    "data_dir = \"../cem/cem/{}/preprocessed/\".format(dataset_name)\n",
    "data_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_path = os.path.join(data_dir, 'train.pkl')\n",
    "val_data_path = train_data_path.replace('train.pkl', 'val.pkl')\n",
    "extra_data_path = '../cem/cem/dsprites/preprocessed/extra.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained = True\n",
    "freeze = False\n",
    "use_aux = True\n",
    "expand_dim = 0\n",
    "three_class = False\n",
    "use_attr = True\n",
    "no_img = False\n",
    "batch_size = 64\n",
    "uncertain_labels = False\n",
    "image_dir = 'images'\n",
    "num_class_attr = 2\n",
    "resampling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = load_data([train_data_path], use_attr, no_img, batch_size, uncertain_labels, image_dir=image_dir, \n",
    "                         n_class_attr=num_class_attr, resampling=resampling, path_transform=lambda path: \"../cem/cem/\"+path, is_training=False)\n",
    "val_loader = load_data([val_data_path], use_attr, no_img=False, batch_size=64, image_dir=image_dir, \n",
    "                        n_class_attr=num_class_attr, path_transform=lambda path: \"../cem/cem/\"+path)\n",
    "extra_loader = load_data([extra_data_path], use_attr, no_img=False, batch_size=64, image_dir=image_dir, \n",
    "                        n_class_attr=num_class_attr, path_transform=lambda path: \"../cem/cem/\"+path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ConceptBottleneck/results/dsprites_20/joint_model_small3/joint/best_model_42.pth\n"
     ]
    }
   ],
   "source": [
    "log_folder = get_log_folder(dataset_name,weight_decay,encoder_model,optimizer)\n",
    "joint_location = \"ConceptBottleneck/{}/best_model_{}.pth\".format(log_folder,seed)\n",
    "print(joint_location)\n",
    "joint_model = torch.load(joint_location,map_location=torch.device('cpu'))\n",
    "r = joint_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pkl = pickle.load(open(train_data_path,\"rb\"))\n",
    "val_pkl = pickle.load(open(val_data_path,\"rb\"))\n",
    "extra_pkl = pickle.load(open(extra_data_path,\"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFTElEQVR4nO3dO27DMBBAQTHQ/a/MdK+SEamQ9clMrYKuHhZcg2POORcAWJbl5+oDAHAfogBARAGAiAIAEQUAIgoARBQAiCgAkHXvh2OMM88BwMn2/FfZpABARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIOvVB4C7mHMe+n6McdJJ4DomBQAiCgBEFACIKAAQUQAgto94haObQ8A2kwIAEQUAIgoARBQAiItmLuWCGO7FpABARAGAiAIAEQUAIgoAxPYRX/O2TaOt3+PhHZ7OpABARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEC8vMbXbL1K9h9eY1sWL7LxHCYFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIA8cgOl/r0+MzbHt+BpzApABBRACCiAEBEAYCIAgCxfQR/+LQhBW9kUgAgogBARAGAiAIAEQUAYvuI17I1BMeZFACIKAAQUQAgogBAXDRzSy6J4RomBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIuvfDOeeZ5wDgBkwKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkF++9JRko0DwIAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFOUlEQVR4nO3cMWrEQBAAQa3R/788Dgyd+Awy+FjduSoUCkZRM7DaNTNzAMBxHB+7BwDgPkQBgIgCABEFACIKAEQUAIgoABBRACDn1RfXWs+cA4Anu/Kvsk0BgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoA5Nw9AK9rZnaPkLXW7hHgLdgUAIgoABBRACCiAEBEAYA4ffRP3enk0F949D1OJMHv2RQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAOXcPwB5rrW/PZmbDJMCd2BQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABDXXPC2frq249EVH8AXmwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAMTdR9yS+4lgD5sCABEFACIKAEQUAIgoABCnj4gTP4BNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgA5r744M8+cA4AbsCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJBPs10cFhYMqcoAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFEElEQVR4nO3bIW7EQBBFQXfk+1+5wx42WGs2ShUe8NlTg5nd3QsAruv6OT0AgO8hCgBEFACIKAAQUQAgogBARAGAiAIAuZ8+nJk3dwDwsid/lV0KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgC5Tw+At+zu6QkfNTOnJ/APuBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAMh9egC8ZWZOT4A/x6UAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ++nD3X1zBwBfwKUAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEB+AYv+EA56inLrAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFNElEQVR4nO3dO27DMBBAQTHQ/a+86V4tO1Doz0xNwNs9LEDRa2bmAIDjOH52DwDA6xAFACIKAEQUAIgoABBRACCiAEBEAYCcVw+ute6cA4CbXflW2aYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAHLuHgB4DTNz+exa68ZJ2MmmAEBEAYCIAgARBQAiCgDE7SPgYY/cVNrFDann2BQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACDePoIv8w7vFrGPTQGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoA5Nw9AMBfrLV2j/BRbAoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIhnLuDL7HgWYmb+/Td5jk0BgIgCABEFACIKAEQUAIjbR8Dt/BHO+7ApABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAnFcPzsydcwDwAmwKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkF70BGRaG61npAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFQUlEQVR4nO3cS2rDQBBAQU3Q/a/c2b2VDTZEHztVSyPQ7B7NtLVmZjYA2Lbt5+oDAHAfogBARAGAiAIAEQUAIgoARBQAiCgAkP3VB9daR54DgIO98l9lkwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACD71QfgGjNz+jvXWqe/E3iPSQGAiAIAEQUAIgoAxEXzTV1xEQxgUgAgogBARAGAiAIAEQUAYvuI0zzaqPLpC7gXkwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg+9UH4LG11sPfZ+bkkwD/iUkBgIgCABEFACIKAEQUAIjtIy71bJvq2fYVcCyTAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxLeP+HO+WwSfy6QAQEQBgIgCABEFACIKAMT20ZezCQS8w6QAQEQBgIgCABEFAOKi+cO4OAaOZFIAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgOyvPjgzR54DgBswKQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkF86SRwdFLDvCAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_images_show = 5\n",
    "for i XXXX-20range(num_images_show):\n",
    "    img_path = '../cem/cem/'+extra_pkl[i]['img_path']\n",
    "    image = Image.open(img_path)\n",
    "    image_array = np.array(image)\n",
    "    plt.figure()\n",
    "    plt.imshow(image_array)\n",
    "    plt.axis('off') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_images, extra_y, extra_c = unroll_data(extra_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = get_accuracy(joint_model,run_joint_model,train_loader), get_accuracy(joint_model,run_joint_model,val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_accuracy = get_concept_accuracy_by_concept(joint_model,run_joint_model,train_loader,sigmoid=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Model over all Concept Combinations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, extra_predictions = run_joint_model(joint_model,extra_images)\n",
    "extra_predictions = extra_predictions.T\n",
    "extra_predictions = torch.nn.Sigmoid()(extra_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "attributes_as_matrix = np.array([i['attribute_label'] for i XXXX-20extra_pkl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_names = [\n",
    "    \"is_white\",\n",
    "    \"is_square\",\n",
    "    \"is_ellipse\",\n",
    "    \"is_heart\",\n",
    "    \"is_scale_0.5\",\n",
    "    \"is_scale_0.6\",\n",
    "    \"is_scale_0.7\",\n",
    "    \"is_scale_0.8\",\n",
    "    \"is_scale_0.9\",\n",
    "    \"is_scale_1\",\n",
    "    \"is_orientation_0\",\n",
    "    \"is_orientation_90\",\n",
    "    \"is_orientation_180\",\n",
    "    \"is_orientation_270\",\n",
    "    \"is_x_0\",\n",
    "    \"is_x_16\",\n",
    "    \"is_y_0\",\n",
    "    \"is_y_16\",\n",
    "] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(extra_predictions,attributes_as_matrix):\n",
    "    results = {}\n",
    "    num_concepts = extra_predictions.shape[1]\n",
    "    for i XXXX-20range(num_concepts):\n",
    "        data_points_with_concept = np.where(attributes_as_matrix[:, i] == 1)[0]\n",
    "        data_points_without_concept = np.where(attributes_as_matrix[:, i] == 0)[0]\n",
    "\n",
    "        min_with_concept = torch.min(extra_predictions[data_points_with_concept,i]).item()\n",
    "        max_with_concept = torch.max(extra_predictions[data_points_with_concept,i]).item()\n",
    "        argmin_with_concept = data_points_with_concept[torch.argmin(extra_predictions[data_points_with_concept,i]).item()].item()\n",
    "        argmax_with_concept = data_points_with_concept[torch.argmax(extra_predictions[data_points_with_concept,i]).item()].item()\n",
    "        freq_with_concept_adversarial = (len([p for p XXXX-20data_points_with_concept if extra_predictions[p, i] < 0.75]),len(data_points_with_concept))\n",
    "\n",
    "        if len(data_points_without_concept) == 0:\n",
    "            min_without_concept = -1\n",
    "            max_without_concept = -1\n",
    "\n",
    "            argmin_without_concept = -1\n",
    "            argmax_without_concept = -1\n",
    "        else:\n",
    "            min_without_concept = torch.min(extra_predictions[data_points_without_concept,i]).item()\n",
    "            max_without_concept = torch.max(extra_predictions[data_points_without_concept,i]).item()\n",
    "\n",
    "            argmin_without_concept = data_points_without_concept[torch.argmin(extra_predictions[data_points_without_concept,i]).item()].item()\n",
    "            argmax_without_concept = data_points_without_concept[torch.argmax(extra_predictions[data_points_without_concept,i]).item()].item()\n",
    "\n",
    "        results[concept_names[i]] = {\n",
    "            'min_with_concept': min_with_concept, \n",
    "            'max_with_concept': max_with_concept, \n",
    "            'min_without_concept': min_without_concept, \n",
    "            'max_without_concept': max_without_concept,\n",
    "            'freq_with_concept_adversarial': freq_with_concept_adversarial,\n",
    "            'argmin_with_concept': argmin_with_concept, \n",
    "            'argmax_with_concept': argmax_with_concept, \n",
    "            'argmin_without_concept': argmin_without_concept, \n",
    "            'argmax_without_concept': argmax_without_concept,\n",
    "        }\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results_dataset(dataset_name):\n",
    "    log_folder = get_log_folder(dataset_name,weight_decay,encoder_model,optimizer)\n",
    "    joint_location = \"ConceptBottleneck/{}/best_model_{}.pth\".format(log_folder,seed)\n",
    "    model = torch.load(joint_location,map_location=torch.device('cpu'))\n",
    "    model.eval()\n",
    "\n",
    "    _, extra_predictions_dataset = run_joint_model(model,extra_images)\n",
    "    extra_predictions_dataset = extra_predictions_dataset.T\n",
    "    extra_predictions_dataset = torch.nn.Sigmoid()(extra_predictions_dataset)\n",
    "\n",
    "    return get_results(extra_predictions_dataset,attributes_as_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsprites_20_results = {\n",
    "    'accuracy': accuracy[0].item(), \n",
    "    'concept_accuracy': concept_accuracy.detach().numpy().tolist() \n",
    "}\n",
    "\n",
    "json.dump(dsprites_20_results,open('results/dsprites/dsprites_20_accuracy_{}.json'.format(seed),'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = get_results(extra_predictions,attributes_as_matrix)\n",
    "json.dump(results,open('results/dsprites/dsprites_20_results_{}.json'.format(seed),'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_5 = get_results_dataset('dsprites_5')\n",
    "results_10 = get_results_dataset('dsprites')\n",
    "results_15 = get_results_dataset('dsprites_15')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(results_5,open('results/dsprites/dsprites_5_results_{}.json'.format(seed),'w'))\n",
    "json.dump(results_10,open('results/dsprites/dsprites_10_results_{}.json'.format(seed),'w'))\n",
    "json.dump(results_15,open('results/dsprites/dsprites_15_results_{}.json'.format(seed),'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "present_concept_combo = list(set([list_to_string(i['attribute_label']) for i XXXX-20train_pkl]))\n",
    "all_predictions = [list_to_string([int(i) for i XXXX-20np.round(j.detach().numpy())]) for j XXXX-20extra_predictions]\n",
    "correct_answers = [list_to_string(i) for i XXXX-20attributes_as_matrix]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_20 = {\n",
    "    'our_predictions': all_predictions, \n",
    "    'ground_truths': correct_answers, \n",
    "    'train_concepts': present_concept_combo, \n",
    "    'predictions_raw': extra_predictions.detach().numpy().tolist()\n",
    "}\n",
    "\n",
    "json.dump(predictions_20,open('results/dsprites/dsprites_20_predictions_{}.json'.format(seed),'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_as_numpy = extra_images.detach().numpy() \n",
    "images_reshaped = images_as_numpy.reshape(288, -1)\n",
    "pca = PCA(n_components=2)\n",
    "images_pca = pca.fit_transform(images_reshaped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": XXXX-45,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_similarity = np.zeros((len(extra_images),len(extra_images)))\n",
    "for i XXXX-20range(len(extra_images)):\n",
    "    for j XXXX-20range(len(extra_images)):\n",
    "        image_similarity[i][j] = torch.norm(extra_images[i]-extra_images[j]).item() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "difficulty_of_prediction = []\n",
    "\n",
    "for i XXXX-20range(len(concept_names)):\n",
    "    if i == 0:\n",
    "        difficulty_of_prediction.append(1)\n",
    "        continue\n",
    "\n",
    "    X = images_pca \n",
    "    y = [int(a[i]) for a XXXX-20correct_answers]\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    svm_classifier = SVC(kernel='linear')  # You can choose different kernels as needed\n",
    "    svm_classifier.fit(X_train, y_train)\n",
    "\n",
    "    # Make predictions on the test data\n",
    "    y_pred = svm_classifier.predict(X_test)\n",
    "\n",
    "    # Calculate the accuracy of the model\n",
    "    accuracy = accuracy_score(y_test, y_pred)\n",
    "    difficulty_of_prediction.append(accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump({\n",
    "    'pca_images': images_pca.tolist(), \n",
    "    'image_similarities': image_similarity.tolist(), \n",
    "    'svm_accuracies': difficulty_of_prediction\n",
    "},open('results/dsprites/pca_images_{}.json'.format(seed),'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adversarial_rate(r):\n",
    "    return np.sum([r[i]['freq_with_concept_adversarial'][0] for i XXXX-20r])/np.sum([r[i]['freq_with_concept_adversarial'][1] for i XXXX-20r])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dsprites_5': 0.4195601851851852,\n",
       " 'dsprites_10': 0.36747685185185186,\n",
       " 'dsprites_15': 0.3472222222222222,\n",
       " 'dsprites_20': 0.3420138888888889}"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adversarial_rates = {\n",
    "    'dsprites_5': adversarial_rate(results_5), \n",
    "    'dsprites_10': adversarial_rate(results_10), \n",
    "    'dsprites_15': adversarial_rate(results_15), \n",
    "    'dsprites_20': adversarial_rate(results), \n",
    "}\n",
    "adversarial_rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "overall_accuracy = len([i for i XXXX-20range(len(correct_answers)) if correct_answers[i] == all_predictions[i]])/len(all_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.28125"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "frequency_train_predicted = len([i for i XXXX-20all_predictions if i XXXX-20present_concept_combo])/len(all_predictions)\n",
    "frequency_train_predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'is_white': 0.0,\n",
       " 'is_square': 0.65625,\n",
       " 'is_ellipse': 0.5729166666666666,\n",
       " 'is_heart': 0.3541666666666667,\n",
       " 'is_scale_0.5': 0.9166666666666666,\n",
       " 'is_scale_0.6': 0.8125,\n",
       " 'is_scale_0.7': 0.6875,\n",
       " 'is_scale_0.8': 0.8541666666666666,\n",
       " 'is_scale_0.9': 0.9375,\n",
       " 'is_scale_1': 0.25,\n",
       " 'is_orientation_0': 0.8055555555555556,\n",
       " 'is_orientation_90': 0.5555555555555556,\n",
       " 'is_orientation_180': 0.7083333333333334,\n",
       " 'is_orientation_270': 0.8472222222222222,\n",
       " 'is_x_0': 0.027777777777777776,\n",
       " 'is_x_16': 0.013888888888888888,\n",
       " 'is_y_0': 0.020833333333333332,\n",
       " 'is_y_16': 0.041666666666666664}"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adversarial_rates_concept = {}\n",
    "for i XXXX-20results:\n",
    "    adversarial_rates_concept[i] = results[i]['freq_with_concept_adversarial'][0]/results[i]['freq_with_concept_adversarial'][1]\n",
    "adversarial_rates_concept"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_stats = {\n",
    "    'adversarial_rates': adversarial_rates, \n",
    "    'frequency_train_predicted': frequency_train_predicted, \n",
    "    'overall_accuracy': overall_accuracy, \n",
    "    'adversarial_rates_concept': adversarial_rates_concept, \n",
    "}\n",
    "\n",
    "json.dump(summary_stats, open('results/dsprites/summary_stats_{}.json'.format(seed),'w'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cem",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
