{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%load_ext tensorboard\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import rc\n",
    "from matplotlib import cm\n",
    "import seaborn as sns\n",
    "from importlib import reload\n",
    "from pathlib import Path\n",
    "import sklearn\n",
    "from tensorflow.keras.models import load_model\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import cub_experiments as cub\n",
    "import models\n",
    "from CUB200.cub_loader import load_data, find_class_imbalance\n",
    "import torch\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "## Global Variables Defining Experiment Flow\n",
    "################################################################################\n",
    "\n",
    "LATEX_SYMBOL = \"$\"\n",
    "RESULTS_DIR = \"results/\"\n",
    "CUB_RESULTS_DIR = os.path.join(\n",
    "    RESULTS_DIR,\n",
    "    \"cub\"\n",
    ")\n",
    "SPLIT_USED = 0\n",
    "rc('text', usetex=(LATEX_SYMBOL == \"$\"))\n",
    "plt.style.use('seaborn-whitegrid')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Model Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "normal_cub_configs = defaultdict(dict)\n",
    "for file in os.listdir(CUB_RESULTS_DIR):\n",
    "    if '_experiment_config.joblib' in file:\n",
    "        config = load(os.path.join(CUB_RESULTS_DIR, file))\n",
    "        fold = int(file[file.find(\"_fold_\") + len(\"_fold_\"):file.find(\"_experiment_config\")]) - 1\n",
    "        model_name = f\"{config['architecture']}{config.get('extra_name', '')}\"\n",
    "        normal_cub_configs[str(fold)][model_name] = config\n",
    "\n",
    "print(\"Normal Model names:\")\n",
    "for model_name, _ in normal_cub_configs[\"0\"].items(): \n",
    "    print(\"\\t\", model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# See Selected Attributes for CUB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_attributes = [\n",
    "    1,\n",
    "    4,\n",
    "    6,\n",
    "    7,\n",
    "    10,\n",
    "    14,\n",
    "    15,\n",
    "    20,\n",
    "    21,\n",
    "    23,\n",
    "    25,\n",
    "    29,\n",
    "    30,\n",
    "    35,\n",
    "    36,\n",
    "    38,\n",
    "    40,\n",
    "    44,\n",
    "    45,\n",
    "    50,\n",
    "    51,\n",
    "    53,\n",
    "    54,\n",
    "    56,\n",
    "    57,\n",
    "    59,\n",
    "    63,\n",
    "    64,\n",
    "    69,\n",
    "    70,\n",
    "    72,\n",
    "    75,\n",
    "    80,\n",
    "    84,\n",
    "    90,\n",
    "    91,\n",
    "    93,\n",
    "    99,\n",
    "    101,\n",
    "    106,\n",
    "    110,\n",
    "    111,\n",
    "    116,\n",
    "    117,\n",
    "    119,\n",
    "    125,\n",
    "    126,\n",
    "    131,\n",
    "    132,\n",
    "    134,\n",
    "    145,\n",
    "    149,\n",
    "    151,\n",
    "    152,\n",
    "    153,\n",
    "    157,\n",
    "    158,\n",
    "    163,\n",
    "    164,\n",
    "    168,\n",
    "    172,\n",
    "    178,\n",
    "    179,\n",
    "    181,\n",
    "    183,\n",
    "    187,\n",
    "    188,\n",
    "    193,\n",
    "    194,\n",
    "    196,\n",
    "    198,\n",
    "    202,\n",
    "    203,\n",
    "    208,\n",
    "    209,\n",
    "    211,\n",
    "    212,\n",
    "    213,\n",
    "    218,\n",
    "    220,\n",
    "    221,\n",
    "    225,\n",
    "    235,\n",
    "    236,\n",
    "    238,\n",
    "    239,\n",
    "    240,\n",
    "    242,\n",
    "    243,\n",
    "    244,\n",
    "    249,\n",
    "    253,\n",
    "    254,\n",
    "    259,\n",
    "    260,\n",
    "    262,\n",
    "    268,\n",
    "    274,\n",
    "    277,\n",
    "    283,\n",
    "    289,\n",
    "    292,\n",
    "    293,\n",
    "    294,\n",
    "    298,\n",
    "    299,\n",
    "    304,\n",
    "    305,\n",
    "    308,\n",
    "    309,\n",
    "    310,\n",
    "    311,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "attributes = [\n",
    "    \"has_bill_shape::curved_(up_or_down)\",\n",
    "    \"has_bill_shape::dagger\",\n",
    "    \"has_bill_shape::hooked\",\n",
    "    \"has_bill_shape::needle\",\n",
    "    \"has_bill_shape::hooked_seabird\",\n",
    "    \"has_bill_shape::spatulate\",\n",
    "    \"has_bill_shape::all-purpose\",\n",
    "    \"has_bill_shape::cone\",\n",
    "    \"has_bill_shape::specialized\",\n",
    "    \"has_wing_color::blue\",\n",
    "    \"has_wing_color::brown\",\n",
    "    \"has_wing_color::iridescent\",\n",
    "    \"has_wing_color::purple\",\n",
    "    \"has_wing_color::rufous\",\n",
    "    \"has_wing_color::grey\",\n",
    "    \"has_wing_color::yellow\",\n",
    "    \"has_wing_color::olive\",\n",
    "    \"has_wing_color::green\",\n",
    "    \"has_wing_color::pink\",\n",
    "    \"has_wing_color::orange\",\n",
    "    \"has_wing_color::black\",\n",
    "    \"has_wing_color::white\",\n",
    "    \"has_wing_color::red\",\n",
    "    \"has_wing_color::buff\",\n",
    "    \"has_upperparts_color::blue\",\n",
    "    \"has_upperparts_color::brown\",\n",
    "    \"has_upperparts_color::iridescent\",\n",
    "    \"has_upperparts_color::purple\",\n",
    "    \"has_upperparts_color::rufous\",\n",
    "    \"has_upperparts_color::grey\",\n",
    "    \"has_upperparts_color::yellow\",\n",
    "    \"has_upperparts_color::olive\",\n",
    "    \"has_upperparts_color::green\",\n",
    "    \"has_upperparts_color::pink\",\n",
    "    \"has_upperparts_color::orange\",\n",
    "    \"has_upperparts_color::black\",\n",
    "    \"has_upperparts_color::white\",\n",
    "    \"has_upperparts_color::red\",\n",
    "    \"has_upperparts_color::buff\",\n",
    "    \"has_underparts_color::blue\",\n",
    "    \"has_underparts_color::brown\",\n",
    "    \"has_underparts_color::iridescent\",\n",
    "    \"has_underparts_color::purple\",\n",
    "    \"has_underparts_color::rufous\",\n",
    "    \"has_underparts_color::grey\",\n",
    "    \"has_underparts_color::yellow\",\n",
    "    \"has_underparts_color::olive\",\n",
    "    \"has_underparts_color::green\",\n",
    "    \"has_underparts_color::pink\",\n",
    "    \"has_underparts_color::orange\",\n",
    "    \"has_underparts_color::black\",\n",
    "    \"has_underparts_color::white\",\n",
    "    \"has_underparts_color::red\",\n",
    "    \"has_underparts_color::buff\",\n",
    "    \"has_breast_pattern::solid\",\n",
    "    \"has_breast_pattern::spotted\",\n",
    "    \"has_breast_pattern::striped\",\n",
    "    \"has_breast_pattern::multi-colored\",\n",
    "    \"has_back_color::blue\",\n",
    "    \"has_back_color::brown\",\n",
    "    \"has_back_color::iridescent\",\n",
    "    \"has_back_color::purple\",\n",
    "    \"has_back_color::rufous\",\n",
    "    \"has_back_color::grey\",\n",
    "    \"has_back_color::yellow\",\n",
    "    \"has_back_color::olive\",\n",
    "    \"has_back_color::green\",\n",
    "    \"has_back_color::pink\",\n",
    "    \"has_back_color::orange\",\n",
    "    \"has_back_color::black\",\n",
    "    \"has_back_color::white\",\n",
    "    \"has_back_color::red\",\n",
    "    \"has_back_color::buff\",\n",
    "    \"has_tail_shape::forked_tail\",\n",
    "    \"has_tail_shape::rounded_tail\",\n",
    "    \"has_tail_shape::notched_tail\",\n",
    "    \"has_tail_shape::fan-shaped_tail\",\n",
    "    \"has_tail_shape::pointed_tail\",\n",
    "    \"has_tail_shape::squared_tail\",\n",
    "    \"has_upper_tail_color::blue\",\n",
    "    \"has_upper_tail_color::brown\",\n",
    "    \"has_upper_tail_color::iridescent\",\n",
    "    \"has_upper_tail_color::purple\",\n",
    "    \"has_upper_tail_color::rufous\",\n",
    "    \"has_upper_tail_color::grey\",\n",
    "    \"has_upper_tail_color::yellow\",\n",
    "    \"has_upper_tail_color::olive\",\n",
    "    \"has_upper_tail_color::green\",\n",
    "    \"has_upper_tail_color::pink\",\n",
    "    \"has_upper_tail_color::orange\",\n",
    "    \"has_upper_tail_color::black\",\n",
    "    \"has_upper_tail_color::white\",\n",
    "    \"has_upper_tail_color::red\",\n",
    "    \"has_upper_tail_color::buff\",\n",
    "    \"has_head_pattern::spotted\",\n",
    "    \"has_head_pattern::malar\",\n",
    "    \"has_head_pattern::crested\",\n",
    "    \"has_head_pattern::masked\",\n",
    "    \"has_head_pattern::unique_pattern\",\n",
    "    \"has_head_pattern::eyebrow\",\n",
    "    \"has_head_pattern::eyering\",\n",
    "    \"has_head_pattern::plain\",\n",
    "    \"has_head_pattern::eyeline\",\n",
    "    \"has_head_pattern::striped\",\n",
    "    \"has_head_pattern::capped\",\n",
    "    \"has_breast_color::blue\",\n",
    "    \"has_breast_color::brown\",\n",
    "    \"has_breast_color::iridescent\",\n",
    "    \"has_breast_color::purple\",\n",
    "    \"has_breast_color::rufous\",\n",
    "    \"has_breast_color::grey\",\n",
    "    \"has_breast_color::yellow\",\n",
    "    \"has_breast_color::olive\",\n",
    "    \"has_breast_color::green\",\n",
    "    \"has_breast_color::pink\",\n",
    "    \"has_breast_color::orange\",\n",
    "    \"has_breast_color::black\",\n",
    "    \"has_breast_color::white\",\n",
    "    \"has_breast_color::red\",\n",
    "    \"has_breast_color::buff\",\n",
    "    \"has_throat_color::blue\",\n",
    "    \"has_throat_color::brown\",\n",
    "    \"has_throat_color::iridescent\",\n",
    "    \"has_throat_color::purple\",\n",
    "    \"has_throat_color::rufous\",\n",
    "    \"has_throat_color::grey\",\n",
    "    \"has_throat_color::yellow\",\n",
    "    \"has_throat_color::olive\",\n",
    "    \"has_throat_color::green\",\n",
    "    \"has_throat_color::pink\",\n",
    "    \"has_throat_color::orange\",\n",
    "    \"has_throat_color::black\",\n",
    "    \"has_throat_color::white\",\n",
    "    \"has_throat_color::red\",\n",
    "    \"has_throat_color::buff\",\n",
    "    \"has_eye_color::blue\",\n",
    "    \"has_eye_color::brown\",\n",
    "    \"has_eye_color::purple\",\n",
    "    \"has_eye_color::rufous\",\n",
    "    \"has_eye_color::grey\",\n",
    "    \"has_eye_color::yellow\",\n",
    "    \"has_eye_color::olive\",\n",
    "    \"has_eye_color::green\",\n",
    "    \"has_eye_color::pink\",\n",
    "    \"has_eye_color::orange\",\n",
    "    \"has_eye_color::black\",\n",
    "    \"has_eye_color::white\",\n",
    "    \"has_eye_color::red\",\n",
    "    \"has_eye_color::buff\",\n",
    "    \"has_bill_length::about_the_same_as_head\",\n",
    "    \"has_bill_length::longer_than_head\",\n",
    "    \"has_bill_length::shorter_than_head\",\n",
    "    \"has_forehead_color::blue\",\n",
    "    \"has_forehead_color::brown\",\n",
    "    \"has_forehead_color::iridescent\",\n",
    "    \"has_forehead_color::purple\",\n",
    "    \"has_forehead_color::rufous\",\n",
    "    \"has_forehead_color::grey\",\n",
    "    \"has_forehead_color::yellow\",\n",
    "    \"has_forehead_color::olive\",\n",
    "    \"has_forehead_color::green\",\n",
    "    \"has_forehead_color::pink\",\n",
    "    \"has_forehead_color::orange\",\n",
    "    \"has_forehead_color::black\",\n",
    "    \"has_forehead_color::white\",\n",
    "    \"has_forehead_color::red\",\n",
    "    \"has_forehead_color::buff\",\n",
    "    \"has_under_tail_color::blue\",\n",
    "    \"has_under_tail_color::brown\",\n",
    "    \"has_under_tail_color::iridescent\",\n",
    "    \"has_under_tail_color::purple\",\n",
    "    \"has_under_tail_color::rufous\",\n",
    "    \"has_under_tail_color::grey\",\n",
    "    \"has_under_tail_color::yellow\",\n",
    "    \"has_under_tail_color::olive\",\n",
    "    \"has_under_tail_color::green\",\n",
    "    \"has_under_tail_color::pink\",\n",
    "    \"has_under_tail_color::orange\",\n",
    "    \"has_under_tail_color::black\",\n",
    "    \"has_under_tail_color::white\",\n",
    "    \"has_under_tail_color::red\",\n",
    "    \"has_under_tail_color::buff\",\n",
    "    \"has_nape_color::blue\",\n",
    "    \"has_nape_color::brown\",\n",
    "    \"has_nape_color::iridescent\",\n",
    "    \"has_nape_color::purple\",\n",
    "    \"has_nape_color::rufous\",\n",
    "    \"has_nape_color::grey\",\n",
    "    \"has_nape_color::yellow\",\n",
    "    \"has_nape_color::olive\",\n",
    "    \"has_nape_color::green\",\n",
    "    \"has_nape_color::pink\",\n",
    "    \"has_nape_color::orange\",\n",
    "    \"has_nape_color::black\",\n",
    "    \"has_nape_color::white\",\n",
    "    \"has_nape_color::red\",\n",
    "    \"has_nape_color::buff\",\n",
    "    \"has_belly_color::blue\",\n",
    "    \"has_belly_color::brown\",\n",
    "    \"has_belly_color::iridescent\",\n",
    "    \"has_belly_color::purple\",\n",
    "    \"has_belly_color::rufous\",\n",
    "    \"has_belly_color::grey\",\n",
    "    \"has_belly_color::yellow\",\n",
    "    \"has_belly_color::olive\",\n",
    "    \"has_belly_color::green\",\n",
    "    \"has_belly_color::pink\",\n",
    "    \"has_belly_color::orange\",\n",
    "    \"has_belly_color::black\",\n",
    "    \"has_belly_color::white\",\n",
    "    \"has_belly_color::red\",\n",
    "    \"has_belly_color::buff\",\n",
    "    \"has_wing_shape::rounded-wings\",\n",
    "    \"has_wing_shape::pointed-wings\",\n",
    "    \"has_wing_shape::broad-wings\",\n",
    "    \"has_wing_shape::tapered-wings\",\n",
    "    \"has_wing_shape::long-wings\",\n",
    "    \"has_size::large_(16_-_32_in)\",\n",
    "    \"has_size::small_(5_-_9_in)\",\n",
    "    \"has_size::very_large_(32_-_72_in)\",\n",
    "    \"has_size::medium_(9_-_16_in)\",\n",
    "    \"has_size::very_small_(3_-_5_in)\",\n",
    "    \"has_shape::upright-perching_water-like\",\n",
    "    \"has_shape::chicken-like-marsh\",\n",
    "    \"has_shape::long-legged-like\",\n",
    "    \"has_shape::duck-like\",\n",
    "    \"has_shape::owl-like\",\n",
    "    \"has_shape::gull-like\",\n",
    "    \"has_shape::hummingbird-like\",\n",
    "    \"has_shape::pigeon-like\",\n",
    "    \"has_shape::tree-clinging-like\",\n",
    "    \"has_shape::hawk-like\",\n",
    "    \"has_shape::sandpiper-like\",\n",
    "    \"has_shape::upland-ground-like\",\n",
    "    \"has_shape::swallow-like\",\n",
    "    \"has_shape::perching-like\",\n",
    "    \"has_back_pattern::solid\",\n",
    "    \"has_back_pattern::spotted\",\n",
    "    \"has_back_pattern::striped\",\n",
    "    \"has_back_pattern::multi-colored\",\n",
    "    \"has_tail_pattern::solid\",\n",
    "    \"has_tail_pattern::spotted\",\n",
    "    \"has_tail_pattern::striped\",\n",
    "    \"has_tail_pattern::multi-colored\",\n",
    "    \"has_belly_pattern::solid\",\n",
    "    \"has_belly_pattern::spotted\",\n",
    "    \"has_belly_pattern::striped\",\n",
    "    \"has_belly_pattern::multi-colored\",\n",
    "    \"has_primary_color::blue\",\n",
    "    \"has_primary_color::brown\",\n",
    "    \"has_primary_color::iridescent\",\n",
    "    \"has_primary_color::purple\",\n",
    "    \"has_primary_color::rufous\",\n",
    "    \"has_primary_color::grey\",\n",
    "    \"has_primary_color::yellow\",\n",
    "    \"has_primary_color::olive\",\n",
    "    \"has_primary_color::green\",\n",
    "    \"has_primary_color::pink\",\n",
    "    \"has_primary_color::orange\",\n",
    "    \"has_primary_color::black\",\n",
    "    \"has_primary_color::white\",\n",
    "    \"has_primary_color::red\",\n",
    "    \"has_primary_color::buff\",\n",
    "    \"has_leg_color::blue\",\n",
    "    \"has_leg_color::brown\",\n",
    "    \"has_leg_color::iridescent\",\n",
    "    \"has_leg_color::purple\",\n",
    "    \"has_leg_color::rufous\",\n",
    "    \"has_leg_color::grey\",\n",
    "    \"has_leg_color::yellow\",\n",
    "    \"has_leg_color::olive\",\n",
    "    \"has_leg_color::green\",\n",
    "    \"has_leg_color::pink\",\n",
    "    \"has_leg_color::orange\",\n",
    "    \"has_leg_color::black\",\n",
    "    \"has_leg_color::white\",\n",
    "    \"has_leg_color::red\",\n",
    "    \"has_leg_color::buff\",\n",
    "    \"has_bill_color::blue\",\n",
    "    \"has_bill_color::brown\",\n",
    "    \"has_bill_color::iridescent\",\n",
    "    \"has_bill_color::purple\",\n",
    "    \"has_bill_color::rufous\",\n",
    "    \"has_bill_color::grey\",\n",
    "    \"has_bill_color::yellow\",\n",
    "    \"has_bill_color::olive\",\n",
    "    \"has_bill_color::green\",\n",
    "    \"has_bill_color::pink\",\n",
    "    \"has_bill_color::orange\",\n",
    "    \"has_bill_color::black\",\n",
    "    \"has_bill_color::white\",\n",
    "    \"has_bill_color::red\",\n",
    "    \"has_bill_color::buff\",\n",
    "    \"has_crown_color::blue\",\n",
    "    \"has_crown_color::brown\",\n",
    "    \"has_crown_color::iridescent\",\n",
    "    \"has_crown_color::purple\",\n",
    "    \"has_crown_color::rufous\",\n",
    "    \"has_crown_color::grey\",\n",
    "    \"has_crown_color::yellow\",\n",
    "    \"has_crown_color::olive\",\n",
    "    \"has_crown_color::green\",\n",
    "    \"has_crown_color::pink\",\n",
    "    \"has_crown_color::orange\",\n",
    "    \"has_crown_color::black\",\n",
    "    \"has_crown_color::white\",\n",
    "    \"has_crown_color::red\",\n",
    "    \"has_crown_color::buff\",\n",
    "    \"has_wing_pattern::solid\",\n",
    "    \"has_wing_pattern::spotted\",\n",
    "    \"has_wing_pattern::striped\",\n",
    "    \"has_wing_pattern::multi-colored\",\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_attribute_names = list(np.array(attributes)[selected_attributes])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "\n",
    "train_dl = load_data(\n",
    "    pkl_paths=[os.path.join(cub.BASE_DIR, 'train.pkl')],\n",
    "    use_attr=True,\n",
    "    no_img=False,\n",
    "    batch_size=128,\n",
    "    uncertain_label=False,\n",
    "    n_class_attr=2,\n",
    "    image_dir='images',\n",
    "    resampling=False,\n",
    "    root_dir='CUB200/',\n",
    "    num_workers=4,\n",
    ")\n",
    "imbalance = find_class_imbalance(os.path.join(cub.BASE_DIR, 'train.pkl'), True)\n",
    "sample = next(iter(train_dl))\n",
    "\n",
    "n_concepts, n_tasks = sample[2].shape[-1], 200\n",
    "total_c = []\n",
    "for (_, _, c) in train_dl:\n",
    "    total_c.append(c.cpu().detach())\n",
    "total_c = np.concatenate(total_c, axis=0)\n",
    "concept_corr_matrix = np.corrcoef(total_c.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "\n",
    "test_dl = load_data(\n",
    "    pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "    use_attr=True,\n",
    "    no_img=False,\n",
    "    batch_size=128,\n",
    "    uncertain_label=False,\n",
    "    n_class_attr=2,\n",
    "    image_dir='images',\n",
    "    resampling=False,\n",
    "    root_dir='CUB200/',\n",
    "    num_workers=4,\n",
    ")\n",
    "\n",
    "# And split this up into arrays for ease of use\n",
    "x_test, c_test, y_test = [], [], []\n",
    "for (x, y, c) in test_dl:\n",
    "    x_test.append(x)\n",
    "    y_test.append(y)\n",
    "    c_test.append(c)\n",
    "x_test = np.concatenate(x_test, axis=0)\n",
    "print(\"x_test.shape =\", x_test.shape)\n",
    "c_test = np.concatenate(c_test, axis=0)\n",
    "print(\"c_test.shape =\", c_test.shape)\n",
    "y_test = np.concatenate(y_test, axis=0)\n",
    "print(\"y_test.shape =\", y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "\n",
    "test_all_concept_ones_dl = load_data(\n",
    "    pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "    use_attr=True,\n",
    "    no_img=False,\n",
    "    batch_size=128,\n",
    "    uncertain_label=False,\n",
    "    n_class_attr=2,\n",
    "    image_dir='images',\n",
    "    resampling=False,\n",
    "    root_dir='CUB200/',\n",
    "    num_workers=4,\n",
    "    concept_transform=lambda c: np.ones(len(c)),\n",
    ")\n",
    "\n",
    "# And split this up into arrays for ease of use\n",
    "x_test_all_ones, c_test_all_ones, y_test_all_ones = [], [], []\n",
    "for (x, y, c) in test_all_concept_ones_dl:\n",
    "    x_test_all_ones.append(x)\n",
    "    y_test_all_ones.append(y)\n",
    "    c_test_all_ones.append(c)\n",
    "x_test_all_ones = np.concatenate(x_test_all_ones, axis=0)\n",
    "print(\"x_test_all_ones.shape =\", x_test_all_ones.shape)\n",
    "c_test_all_ones = np.concatenate(c_test_all_ones, axis=0)\n",
    "print(\"c_test_all_ones.shape =\", c_test_all_ones.shape)\n",
    "y_test_all_ones = np.concatenate(y_test_all_ones, axis=0)\n",
    "print(\"y_test_all_ones.shape =\", y_test_all_ones.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "\n",
    "test_all_concept_zeros_dl = load_data(\n",
    "    pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "    use_attr=True,\n",
    "    no_img=False,\n",
    "    batch_size=128,\n",
    "    uncertain_label=False,\n",
    "    n_class_attr=2,\n",
    "    image_dir='images',\n",
    "    resampling=False,\n",
    "    root_dir='CUB200/',\n",
    "    num_workers=4,\n",
    "    concept_transform=lambda c: np.zeros(len(c)),\n",
    ")\n",
    "\n",
    "# And split this up into arrays for ease of use\n",
    "x_test_all_zeros, c_test_all_zeros, y_test_all_zeros = [], [], []\n",
    "for (x, y, c) in test_all_concept_zeros_dl:\n",
    "    x_test_all_zeros.append(x)\n",
    "    y_test_all_zeros.append(y)\n",
    "    c_test_all_zeros.append(c)\n",
    "x_test_all_zeros = np.concatenate(x_test_all_zeros, axis=0)\n",
    "print(\"x_test_all_zeros.shape =\", x_test_all_zeros.shape)\n",
    "c_test_all_zeros = np.concatenate(c_test_all_zeros, axis=0)\n",
    "print(\"c_test_all_zeros.shape =\", c_test_all_zeros.shape)\n",
    "y_test_all_zeros = np.concatenate(y_test_all_zeros, axis=0)\n",
    "print(\"y_test_all_zeros.shape =\", y_test_all_zeros.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_CONCEPTS, N_TASKS = 112, 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load CUB Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_cub_model(\n",
    "    config,\n",
    "    n_tasks=N_TASKS,\n",
    "    result_dir=CUB_RESULTS_DIR,\n",
    "    n_concepts=N_CONCEPTS,\n",
    "    split=0,\n",
    "    imbalance=None,\n",
    "    intervention_idxs=None,\n",
    "    adversarial_intervention=False,\n",
    "    train_dl=None,\n",
    "):\n",
    "    if split is not None:\n",
    "        full_run_name = f\"{config['architecture']}{config.get('extra_name', '')}_{config['c_extractor_arch']}_fold_{split + 1}\"\n",
    "    else:\n",
    "        full_run_name = f\"{config['architecture']}{config.get('extra_name', '')}_{config['c_extractor_arch']}\"\n",
    "    selected_concepts = np.arange(n_concepts)\n",
    "    if config.get(\"message_passing_layers\"):\n",
    "        edges = []\n",
    "        edge_weights = []\n",
    "        corr_thresh = config.get('corr_thresh')\n",
    "        sorted_selected = sorted(selected_concepts)\n",
    "        for i in range(n_concepts):\n",
    "            i_idx = sorted_selected[i]\n",
    "            for j in range(i + 1, n_concepts):\n",
    "                j_idx = sorted_selected[j]\n",
    "                if np.abs(concept_corr_matrix[i_idx, j_idx]) >= corr_thresh:\n",
    "                    edges.append(np.array([[i, j], [j, i]]))\n",
    "                    if config.get(\"weighted_edges\"):\n",
    "                        weight = np.abs(concept_corr_matrix[i_idx, j_idx])\n",
    "                    else:\n",
    "                        weight = 1\n",
    "                    edge_weights.extend([weight, weight])\n",
    "        concept_edge_list = torch.cuda.LongTensor(np.concatenate(edges, axis=-1))\n",
    "        concept_edge_weights = torch.cuda.FloatTensor(np.array(edge_weights))\n",
    "    else:\n",
    "        concept_edge_list = None\n",
    "        concept_edge_weights = None\n",
    "    if (\n",
    "        (intervention_idxs is not None) and\n",
    "        (train_dl is not None) and\n",
    "        (config['architecture'] == \"ConceptBottleneckModel\") and\n",
    "        (not config.get('sigmoidal_prob', True))\n",
    "    ):\n",
    "        # Then let's look at the empirical distribution of the logits in order to\n",
    "        # be able to intervene\n",
    "        model = models.construct_model(\n",
    "            n_concepts=n_concepts,\n",
    "            n_tasks=n_tasks,\n",
    "            config=config,\n",
    "            imbalance=imbalance,\n",
    "            concept_edge_list=concept_edge_list,\n",
    "            concept_edge_weights=concept_edge_weights,\n",
    "        )\n",
    "        trainer = pl.Trainer(\n",
    "            gpus=GPU,\n",
    "        )\n",
    "        batch_results = trainer.predict(model, train_dl)\n",
    "        out_embs = np.concatenate(\n",
    "            list(map(lambda x: x[1], batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "        active_intervention_values = []\n",
    "        inactive_intervention_values = []\n",
    "        for idx in range(n_concepts):\n",
    "            active_intervention_values.append(np.percentile(out_embs[:, idx], 95))\n",
    "            inactive_intervention_values.append(np.percentile(out_embs[:, idx], 5))\n",
    "        print(\"For\", full_run_name, \"we found its intervention values to be:\")\n",
    "        print(\"\\tactive_intervention_values =\", active_intervention_values)\n",
    "        print(\"\\tinactive_intervention_values =\", inactive_intervention_values)\n",
    "    else:\n",
    "        active_intervention_values = inactive_intervention_values = None\n",
    "    model = models.construct_model(\n",
    "        n_concepts=n_concepts,\n",
    "        n_tasks=n_tasks,\n",
    "        config=config,\n",
    "        imbalance=imbalance,\n",
    "        concept_edge_list=concept_edge_list,\n",
    "        concept_edge_weights=concept_edge_weights,\n",
    "        intervention_idxs=intervention_idxs,\n",
    "        adversarial_intervention=adversarial_intervention,\n",
    "        active_intervention_values=active_intervention_values,\n",
    "        inactive_intervention_values=inactive_intervention_values,\n",
    "    )\n",
    "    model_saved_path = os.path.join(\n",
    "        result_dir or \".\",\n",
    "        f'{full_run_name}.pt'\n",
    "    )\n",
    "    model.load_state_dict(torch.load(model_saved_path))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "import models_cub\n",
    "reload(models_cub)\n",
    "\n",
    "WHITELIST = [\n",
    "    'ConceptBottleneckModelFuzzyExtraCapacity_Logit',\n",
    "    'MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat',\n",
    "]\n",
    "all_models = defaultdict(dict)\n",
    "intervention_models = defaultdict(dict)\n",
    "for split, runs in normal_cub_configs.items():\n",
    "    for model_name, config in runs.items(): \n",
    "        if model_name not in WHITELIST:\n",
    "            continue\n",
    "        try:\n",
    "            config[\"shared_prob_gen\"] = config.get(\"shared_prob_gen\", False)\n",
    "            config[\"per_concept_weight\"] = config.get(\"per_concept_weight\", False)\n",
    "            all_models[split][model_name] = load_cub_model(\n",
    "                config=config,\n",
    "                n_tasks=N_TASKS,\n",
    "                n_concepts=N_CONCEPTS,\n",
    "                result_dir=CUB_RESULTS_DIR,\n",
    "                split=int(split),\n",
    "                imbalance=imbalance,\n",
    "            )\n",
    "            intervention_models[split][model_name] = load_cub_model(\n",
    "                config=config,\n",
    "                n_tasks=N_TASKS,\n",
    "                n_concepts=N_CONCEPTS,\n",
    "                result_dir=CUB_RESULTS_DIR,\n",
    "                split=int(split),\n",
    "                imbalance=imbalance,\n",
    "                intervention_idxs=np.arange(N_CONCEPTS),\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(\"Could not load model\", model_name, \"for split\", split)\n",
    "            print(\"\\t\", e)\n",
    "            raise e"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract Test Embeddings using all Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(models_cub)\n",
    "\n",
    "test_c_embs = {}\n",
    "test_c_embs_pos = {}\n",
    "test_c_embs_neg = {}\n",
    "test_c_embs_oracle = {}\n",
    "test_c_embs_pure = {}\n",
    "test_y_preds = {}\n",
    "test_c_sems = {}\n",
    "\n",
    "\n",
    "for model_name, model in all_models['0'].items():\n",
    "    print(model_name)\n",
    "    trainer = pl.Trainer(\n",
    "        gpus=1,\n",
    "    )\n",
    "    batch_results = trainer.predict(model, test_dl)\n",
    "    test_c_sems[model_name] = np.concatenate(\n",
    "        list(map(lambda x: x[0], batch_results)),\n",
    "        axis=0,\n",
    "    )\n",
    "    test_y_preds[model_name] = np.concatenate(\n",
    "        list(map(lambda x: x[2], batch_results)),\n",
    "        axis=0,\n",
    "    )\n",
    "    complete_c_embs = np.concatenate(\n",
    "        list(map(lambda x: x[1], batch_results)),\n",
    "        axis=0,\n",
    "    )\n",
    "    n_concepts = test_c_sems[model_name].shape[-1]\n",
    "    if (\"SplitEmbModel\" in model_name) or ('MaskedSplitEmbModel' in model_name):\n",
    "        \n",
    "        c_embs = []\n",
    "        c_embs_pure = []\n",
    "        emb_size = complete_c_embs.shape[-1] // n_concepts\n",
    "        for i in range(complete_c_embs.shape[-1] // emb_size):\n",
    "            c_embs.append(\n",
    "                np.expand_dims(\n",
    "                    complete_c_embs[:, i*emb_size:(i+1)*emb_size],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "            c_embs_pure.append(\n",
    "                np.expand_dims(\n",
    "                    # Do not include the last dimension as that is the probability/semantics dimension\n",
    "                    complete_c_embs[:, i*emb_size:(i+1)*emb_size - 1],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "    elif (\"MixtureEmbModel\" in model_name):\n",
    "        \n",
    "        c_embs = []\n",
    "        c_embs_pure = []\n",
    "        emb_size = complete_c_embs.shape[-1] // n_concepts\n",
    "        for i in range(complete_c_embs.shape[-1] // emb_size):\n",
    "            c_embs.append(\n",
    "                np.expand_dims(\n",
    "                    complete_c_embs[:, i*emb_size:(i+1)*emb_size],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "            c_embs_pure.append(\n",
    "                c_embs[-1],\n",
    "            )\n",
    "            \n",
    "        # And compute also only positive\n",
    "        pos_batch_results = trainer.predict(\n",
    "            intervention_models['0'][model_name],\n",
    "            test_all_concept_ones_dl,\n",
    "        )\n",
    "        pos_c_embs_complete = np.concatenate(\n",
    "            list(map(lambda x: x[1], pos_batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "        # and negative embeddings\n",
    "        neg_batch_results = trainer.predict(\n",
    "            intervention_models['0'][model_name],\n",
    "            test_all_concept_zeros_dl,\n",
    "        )\n",
    "        neg_c_embs_complete = np.concatenate(\n",
    "            list(map(lambda x: x[1], neg_batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "        # also those embeddings that perfectly match their concepts\n",
    "        oracle_batch_results = trainer.predict(\n",
    "            intervention_models['0'][model_name],\n",
    "            test_dl,\n",
    "        )\n",
    "        oracle_c_embs_complete = np.concatenate(\n",
    "            list(map(lambda x: x[1], oracle_batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "        \n",
    "        # put them together in the right form\n",
    "        pos_c_embs = []\n",
    "        neg_c_embs = []\n",
    "        oracle_c_embs = []\n",
    "        for i in range(complete_c_embs.shape[-1] // emb_size):\n",
    "            pos_c_embs.append(\n",
    "                np.expand_dims(\n",
    "                    pos_c_embs_complete[:, i*emb_size:(i+1)*emb_size],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "            neg_c_embs.append(\n",
    "                np.expand_dims(\n",
    "                    neg_c_embs_complete[:, i*emb_size:(i+1)*emb_size],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "            oracle_c_embs.append(\n",
    "                np.expand_dims(\n",
    "                    oracle_c_embs_complete[:, i*emb_size:(i+1)*emb_size],\n",
    "                    axis=1,\n",
    "                )\n",
    "            )\n",
    "        test_c_embs_pos[model_name] = np.concatenate(pos_c_embs, axis=1)\n",
    "        test_c_embs_neg[model_name] = np.concatenate(neg_c_embs, axis=1)\n",
    "        test_c_embs_oracle[model_name] = np.concatenate(oracle_c_embs, axis=1)\n",
    "        \n",
    "    elif \"ConceptBottleneckModel\" in model_name:\n",
    "        c_embs = []\n",
    "        c_embs_pure = []\n",
    "        for i in range(n_concepts):\n",
    "            pure_entries = []\n",
    "            entries = [complete_c_embs[:, i:i+1]]\n",
    "            if complete_c_embs.shape[-1] > n_concepts:\n",
    "                entries.append(complete_c_embs[:, n_concepts:])\n",
    "                pure_entries.append(complete_c_embs[:, n_concepts:])\n",
    "            c_embs.append(\n",
    "                np.expand_dims(\n",
    "                   np.concatenate(entries, axis=-1),\n",
    "                   axis=1,\n",
    "                )\n",
    "            )\n",
    "            if pure_entries:\n",
    "                c_embs_pure.append(\n",
    "                    np.expand_dims(\n",
    "                       np.concatenate(pure_entries, axis=-1),\n",
    "                       axis=1,\n",
    "                    )\n",
    "                )\n",
    "    c_embs = np.concatenate(c_embs, axis=1)\n",
    "    if len(c_embs_pure):\n",
    "        c_embs_pure = np.concatenate(c_embs_pure, axis=1)\n",
    "\n",
    "    test_c_embs[model_name] = c_embs\n",
    "    test_c_embs_pure[model_name] = c_embs_pure\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Functions for visualizing embedding clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import NearestNeighbors\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "from scipy.cluster.hierarchy import fcluster\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "class UnNormalize(object):\n",
    "    def __init__(self, mean, std):\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.\n",
    "        Returns:\n",
    "            Tensor: Normalized image.\n",
    "        \"\"\"\n",
    "        return tensor * np.expand_dims(np.expand_dims(self.std, axis=-1), axis=-1) + np.expand_dims(\n",
    "            np.expand_dims(self.mean, axis=-1),\n",
    "            axis=-1\n",
    "        )\n",
    "    \n",
    "def show_bird_image(image, ax):\n",
    "    ax.grid(False)\n",
    "    ax.axis(False)\n",
    "    ax.imshow(np.transpose(\n",
    "        (UnNormalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])(image) * 255).astype(np.int32),\n",
    "        axes=[1, 2, 0]\n",
    "    ))\n",
    "\n",
    "def show_closest_activation_examples(\n",
    "    x_test,\n",
    "    test_dl,\n",
    "    test_c_embs,\n",
    "    concept_semantics=None,\n",
    "    num_examples=5,\n",
    "    shown_neighs=5,\n",
    "    scale=1.5,\n",
    "    selected_concepts=None,\n",
    "    seed=None,\n",
    "):\n",
    "    np.random.seed(seed)\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "    for selected_concept in selected_concepts:\n",
    "        if concept_semantics is not None:\n",
    "            print(\n",
    "                \"Selected concept at index\",\n",
    "                selected_concept,\n",
    "                \"with semantics\",\n",
    "                concept_semantics[selected_concept],\n",
    "            )\n",
    "\n",
    "        selected_inds = []\n",
    "        for i, (_, _, c_batch) in enumerate(test_dl):\n",
    "            for idx, c in enumerate(c_batch):\n",
    "                if c[selected_concept] == 1:\n",
    "                    selected_inds.append(i*c_batch.shape[0] + idx)\n",
    "        selected_inds = np.random.choice(selected_inds, size=num_examples, replace=False,)\n",
    "        fig, axs = plt.subplots(\n",
    "            num_examples,\n",
    "            shown_neighs + 2,\n",
    "            figsize=(scale*shown_neighs, scale*num_examples),\n",
    "        )\n",
    "        if concept_semantics is not None:\n",
    "            fig.suptitle(\n",
    "                f'Closest Embeddings for Concept {concept_semantics[selected_concept]}',\n",
    "                fontsize=15,\n",
    "            )\n",
    "        for i, example_idx in enumerate(selected_inds):\n",
    "            show_bird_image(x_test[example_idx, :, :, :], axs[i, 0])\n",
    "            if i == 0:\n",
    "                # Then add a title here\n",
    "                axs[i, 0].set_title(\"Sample\", fontsize=20)\n",
    "            # Let's add an empty image in between as a separator\n",
    "            axs[i, 1].grid(False)\n",
    "            axs[i, 1].axis(False)\n",
    "            nbrs = NearestNeighbors(n_neighbors=(shown_neighs + 1), algorithm='ball_tree').fit(\n",
    "                test_c_embs[:, selected_concept, :]\n",
    "            )\n",
    "            [distances], [nearest_indices] = nbrs.kneighbors(test_c_embs[example_idx:example_idx+1, selected_concept, :])\n",
    "            for j, sample_idx in enumerate(nearest_indices[1:], start=2):\n",
    "                show_bird_image(x_test[sample_idx, :, :, :], axs[i, j])\n",
    "                if (i == 0) and ((j - 2) == shown_neighs // 2):\n",
    "                    axs[i, j].set_title(\"Nearest Neighbors\", fontsize=20)\n",
    "        fig.tight_layout()\n",
    "        fig.subplots_adjust(\n",
    "            wspace=0,\n",
    "            hspace=0,\n",
    "        )\n",
    "        plt.show()\n",
    "\n",
    "def show_concept_clusters(\n",
    "    x_test,\n",
    "    test_c_embs,\n",
    "    test_c_sems,\n",
    "    max_d=50, #100\n",
    "    concept_semantics=None,\n",
    "    show_activated_only=True,\n",
    "    show_examples=True,\n",
    "    max_clusters=5,\n",
    "    shown_samples=5,\n",
    "    scale=1.5,\n",
    "    selected_concepts=None,\n",
    "    model_name=\"\",\n",
    "):\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "    for selected_concept in selected_concepts:\n",
    "        if concept_semantics is not None:\n",
    "            print(\n",
    "                \"Selected concept at index\",\n",
    "                selected_concept,\n",
    "                \"with semantics\",\n",
    "                concept_semantics[selected_concept],\n",
    "            )\n",
    "        if show_activated_only:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0])[\n",
    "                test_c_sems[:, selected_concept] > 0.5,\n",
    "            ].astype(np.int32)\n",
    "        else:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)\n",
    "\n",
    "        selected_test_embs = test_c_embs[\n",
    "            selected_inds,\n",
    "            :,\n",
    "            :\n",
    "        ]\n",
    "\n",
    "        if show_examples:\n",
    "            print(\"Examples of selected test samples:\")\n",
    "            fig = plt.figure(figsize=(14, 6))\n",
    "            for i, idx in enumerate(selected_inds[:8]):\n",
    "                fig.add_subplot(1, 8, i + 1)\n",
    "                show_bird_image(x_test[idx, :, :, :], plt)\n",
    "            plt.show()\n",
    "\n",
    "        # selected_test_embs = test_c_embs\n",
    "        Z = linkage(selected_test_embs[:, selected_concept, :], 'ward')\n",
    "        clusters = fcluster(Z, max_d, criterion='distance')\n",
    "        cluster_types = np.unique(clusters)\n",
    "        print(\"Found\", len(cluster_types), \"clusters from\", clusters.shape[0], \"samples\")\n",
    "        cluster_map = [\n",
    "            [] for _ in range(len(cluster_types))\n",
    "        ]\n",
    "        for i, cluster_type in enumerate(clusters):\n",
    "            cluster_map[cluster_type - 1].append(i)\n",
    "\n",
    "        fig, axs = plt.subplots(\n",
    "            min(len(cluster_map), max_clusters),\n",
    "            shown_samples,\n",
    "            figsize=(scale*shown_samples, scale*min(max_clusters, len(cluster_map))),\n",
    "        )\n",
    "        if concept_semantics is not None:\n",
    "            fig.suptitle(\n",
    "                f'{model_name} Sample Concept Clusters for Concept {concept_semantics[selected_concept]}',\n",
    "                fontsize=15,\n",
    "            )\n",
    "        for row in axs:\n",
    "            for ax in row:\n",
    "                ax.grid(False)\n",
    "                ax.axis(False)\n",
    "\n",
    "        for cluster_id, samples in enumerate(cluster_map):\n",
    "            if cluster_id >= max_clusters:\n",
    "                break\n",
    "            real_shown_samples = min(shown_samples, len(samples))\n",
    "            centroid = np.expand_dims(\n",
    "                np.mean(selected_test_embs[samples, selected_concept, :], axis=0),\n",
    "                axis=0,\n",
    "            )\n",
    "            nbrs = NearestNeighbors(n_neighbors=real_shown_samples, algorithm='ball_tree').fit(\n",
    "                selected_test_embs[samples, selected_concept, :]\n",
    "            )\n",
    "            [distances], [nearest_indices] = nbrs.kneighbors(centroid)\n",
    "            for i, sample_idx in enumerate(nearest_indices):\n",
    "                real_idx = selected_inds[samples[sample_idx]]\n",
    "                show_bird_image(x_test[real_idx, :, :, :], axs[cluster_id, i])\n",
    "        fig.tight_layout()\n",
    "        fig.subplots_adjust(\n",
    "            wspace=0,\n",
    "            hspace=0.1,\n",
    "        )\n",
    "        plt.show()\n",
    "\n",
    "def show_inter_concept_similarity(\n",
    "    x_test,\n",
    "    test_c_embs,\n",
    "    test_c_sems,\n",
    "    concept_semantics=None,\n",
    "    show_activated_only=True,\n",
    "    selected_concepts=None,\n",
    "    normalize=True,\n",
    "    n_closest=5,\n",
    "    metric='cosine',\n",
    "    to_console=True,\n",
    "):\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "\n",
    "    centroids = np.zeros((len(selected_concepts), test_c_embs.shape[-1]))\n",
    "    for i, concept_idx in enumerate(selected_concepts):\n",
    "        if show_activated_only:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0])[\n",
    "                test_c_sems[:, concept_idx] > 0.5,\n",
    "            ].astype(np.int32)\n",
    "        else:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)\n",
    "\n",
    "        selected_test_embs = test_c_embs[\n",
    "            selected_inds,\n",
    "            :,\n",
    "            :\n",
    "        ]\n",
    "        centroids[i, :] = np.mean(\n",
    "            selected_test_embs[:, concept_idx, :],\n",
    "            axis=0,\n",
    "        )\n",
    "    if normalize:\n",
    "        centroids = sklearn.preprocessing.normalize(centroids, axis=1)\n",
    "    nbrs = NearestNeighbors(\n",
    "        n_neighbors=n_closest + 1,\n",
    "        algorithm='auto',\n",
    "        metric=metric,\n",
    "    ).fit(centroids)\n",
    "    \n",
    "    result = []\n",
    "    for i, concept_idx in enumerate(selected_concepts):\n",
    "        [distances], [nearest_concepts] = nbrs.kneighbors(centroids[i:i+1, :])\n",
    "        concept_name = concept_idx\n",
    "        nearest_concepts_idx = nearest_concepts\n",
    "        if concept_semantics is not None:\n",
    "            concept_name = concept_semantics[concept_idx]\n",
    "            nearest_concepts = np.array(concept_semantics)[nearest_concepts]\n",
    "        \n",
    "        if to_console:\n",
    "            print(f\"Nearest concepts to concept {concept_name}:\")\n",
    "        partial_lst = []\n",
    "        for j, name, dist in zip(nearest_concepts_idx, nearest_concepts[1:], distances[1:]):\n",
    "            if to_console:\n",
    "                print(f\"\\t{name} (distance {dist})\")\n",
    "            partial_lst.append((j, dist))\n",
    "        result.append(partial_lst)\n",
    "        if to_console:\n",
    "            print()\n",
    "    return centroids, result\n",
    "\n",
    "def plot_concept_centroids(\n",
    "    x_test,\n",
    "    test_c_embs,\n",
    "    test_c_sems,\n",
    "    concept_semantics=None,\n",
    "    selected_concepts=None,\n",
    "    perplexity=50,\n",
    "    n_iter=1000,\n",
    "    figsize=(8, 6),\n",
    "    show_activated_only=True,\n",
    "    model_name=\"SplitEmb\",\n",
    "    annotation_size=5,\n",
    "    concept_colors=None,\n",
    "    dot_size=10,\n",
    "    half_emb=False,\n",
    "):\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "    if half_emb:\n",
    "        centroids = np.zeros((len(selected_concepts), test_c_embs.shape[-1]//2))\n",
    "    else:\n",
    "        centroids = np.zeros((len(selected_concepts), test_c_embs.shape[-1]))\n",
    "    for i, concept_idx in enumerate(selected_concepts):\n",
    "        if show_activated_only:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0])[\n",
    "                test_c_sems[:, concept_idx] > 0.5,\n",
    "            ].astype(np.int32)\n",
    "        else:\n",
    "            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)\n",
    "        \n",
    "        if half_emb:\n",
    "            selected_test_embs = test_c_embs[\n",
    "                selected_inds,\n",
    "                :,\n",
    "                :test_c_embs.shape[-1]//2\n",
    "            ]\n",
    "        else:\n",
    "            selected_test_embs = test_c_embs[\n",
    "                selected_inds,\n",
    "                :,\n",
    "                :\n",
    "            ]\n",
    "        centroids[i, :] = np.mean(\n",
    "            selected_test_embs[:, concept_idx, :],\n",
    "            axis=0,\n",
    "        )\n",
    "#     centroids = sklearn.preprocessing.normalize(centroids, axis=1)\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        verbose=1,\n",
    "        perplexity=perplexity,\n",
    "        n_iter=n_iter,\n",
    "        init='pca',\n",
    "        learning_rate='auto',\n",
    "    )\n",
    "    tsne_results = tsne.fit_transform(centroids)\n",
    "    fig, ax = plt.subplots(\n",
    "        1,\n",
    "        1,\n",
    "        figsize=figsize,\n",
    "    )\n",
    "    ax.set_title(\n",
    "        f\"{model_name} Cluster Centroids\",\n",
    "        fontsize=15,\n",
    "    )\n",
    "    if concept_colors is None:\n",
    "        colors = []\n",
    "        marker = []\n",
    "        concept_semantics = concept_semantics or [\n",
    "            f'concept_{idx}' for idx in selected_concepts\n",
    "        ]\n",
    "        for i, concept_idx in enumerate(selected_concepts):\n",
    "            concept_name = concept_semantics[concept_idx]\n",
    "            if \"color\" in concept_name:\n",
    "                marker.append(\"o\")\n",
    "                color = concept_name[concept_name.find(\"::\") + 2:]\n",
    "                colors.append(color)\n",
    "            else:\n",
    "                colors.append(\"black\")\n",
    "                marker.append(\"x\")\n",
    "    else:\n",
    "        colors = list(np.array(concept_colors)[selected_concepts])\n",
    "        markers = [\"o\" for _ in selected_concepts]\n",
    "    for i, color in enumerate(colors):\n",
    "        if color == \"buff\":\n",
    "            color = \"palegoldenrod\" \n",
    "        elif color == \"multi-colored\":\n",
    "            color = \"palegreen\"\n",
    "        elif color == \"white\":\n",
    "            color = \"cyan\"\n",
    "        colors[i] = color\n",
    "    ax.scatter(\n",
    "        tsne_results[:, 0],\n",
    "        tsne_results[:, 1],\n",
    "        c=colors,\n",
    "        s=dot_size,\n",
    "        # TODO!!!!!!!\n",
    "#         marker=markers,\n",
    "    )\n",
    "   \n",
    "    for i, concept_idx in enumerate(selected_concepts):\n",
    "        concept_name = concept_semantics[concept_idx]\n",
    "        ax.annotate(\n",
    "            concept_semantics[concept_idx],\n",
    "            (tsne_results[i, 0], tsne_results[i, 1]),\n",
    "            fontsize=annotation_size,\n",
    "        )\n",
    "    ax.grid(False)\n",
    "    ax.axis(False)\n",
    "    fig.legend(fontsize=10) #, loc='center right')\n",
    "    plt.show()\n",
    "            \n",
    "def plot_tsne_embeddings(\n",
    "    test_c_embs,\n",
    "    c_test,\n",
    "    color_activations=None,\n",
    "    color_activation_labels=None,\n",
    "    attributes=None,\n",
    "    perplexity=50,\n",
    "    n_iter=1000,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=None,\n",
    "    y_test=None,\n",
    "    model_name=\"SplitEmb\",\n",
    "):\n",
    "    results = []\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "    for selected_concept in selected_concepts:\n",
    "        if attributes is not None:\n",
    "            print(\n",
    "                \"Selected concept at index\",\n",
    "                selected_concept,\n",
    "                \"with semantics\",\n",
    "                attributes[selected_concept],\n",
    "            )\n",
    "        tsne = TSNE(\n",
    "            n_components=2,\n",
    "            verbose=1,\n",
    "            perplexity=perplexity,\n",
    "            n_iter=n_iter,\n",
    "            init='pca',\n",
    "            learning_rate='auto',\n",
    "        )\n",
    "        tsne_results = tsne.fit_transform(test_c_embs[:, selected_concept, :])\n",
    "        results.append(tsne_results)\n",
    "        if y_test is not None:\n",
    "            fig, ax = plt.subplots(\n",
    "                1,\n",
    "                1,\n",
    "                figsize=figsize,\n",
    "            )\n",
    "            if attributes is not None:\n",
    "                ax.set_title(\n",
    "                    f\"TSNE Embeddings for {attributes[selected_concept]} (by class)\",\n",
    "                    fontsize=15,\n",
    "                )\n",
    "            ax.scatter(\n",
    "                tsne_results[:, 0],\n",
    "                tsne_results[:, 1],\n",
    "                c=y_test,\n",
    "                s=5,\n",
    "                cmap='ocean',\n",
    "            )\n",
    "            ax.grid(False)\n",
    "            ax.axis(False)\n",
    "            plt.show()\n",
    "            \n",
    "        if color_activations is not None:\n",
    "            activations = color_activations\n",
    "        else:\n",
    "            activations = [c_test[:, selected_concept]]\n",
    "        for i, activation in enumerate(activations):\n",
    "            if color_activation_labels is not None:\n",
    "                activation_label = color_activation_labels[i]\n",
    "            elif (color_activations is None):\n",
    "                if (attributes is not None):\n",
    "                    activation_label = attributes[selected_concept]\n",
    "                else:\n",
    "                    activation_label = f\"Concept {selected_concept + 1}\"\n",
    "            else:\n",
    "                activation_label = f\"Concept {i + 1}\"\n",
    "            \n",
    "            mask = activation == 1\n",
    "            neg_mask = np.logical_not(mask)\n",
    "\n",
    "            # And let's plot all of these\n",
    "            fig, ax = plt.subplots(\n",
    "                1,\n",
    "                1,\n",
    "                figsize=figsize,\n",
    "            )\n",
    "            if attributes is not None:\n",
    "                ax.set_title(\n",
    "                    f\"{model_name} TSNE Embeddings for {attributes[selected_concept]}\",\n",
    "                    fontsize=15,\n",
    "                )\n",
    "            ax.scatter(\n",
    "                tsne_results[mask, 0],\n",
    "                tsne_results[mask, 1],\n",
    "                color='red',\n",
    "                label=activation_label + \" active\",\n",
    "                s=5,\n",
    "            )\n",
    "\n",
    "            ax.scatter(\n",
    "                tsne_results[neg_mask, 0],\n",
    "                tsne_results[neg_mask, 1],\n",
    "                color='blue',\n",
    "                label=activation_label + \" not active\",\n",
    "                s=5,\n",
    "            )\n",
    "            ax.grid(False)\n",
    "            ax.axis(False)\n",
    "            fig.legend(fontsize=10) #, loc='center right')\n",
    "            plt.show()\n",
    "        \n",
    "    return results\n",
    "\n",
    "\n",
    "def plot_tsne_latent_space(\n",
    "    test_c_embs,\n",
    "    c_test,\n",
    "    attributes=None,\n",
    "    perplexity=50,\n",
    "    n_iter=1000,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=None,\n",
    "    y_test=None,\n",
    "):\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        verbose=1,\n",
    "        perplexity=perplexity,\n",
    "        n_iter=n_iter,\n",
    "        init='pca',\n",
    "        learning_rate='auto',\n",
    "    )\n",
    "    latent_space = test_c_embs.reshape(\n",
    "        test_c_embs.shape[0],\n",
    "        -1,\n",
    "    )\n",
    "    tsne_results = tsne.fit_transform(latent_space)\n",
    "    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))\n",
    "    if y_test is not None:\n",
    "        fig, ax = plt.subplots(\n",
    "            1,\n",
    "            1,\n",
    "            figsize=figsize,\n",
    "        )\n",
    "        if attributes is not None:\n",
    "            ax.set_title(\n",
    "                f\"SplitEmb COMPLETE Latent Space TSNE colored by class\",\n",
    "                fontsize=15,\n",
    "            )\n",
    "        ax.scatter(\n",
    "            tsne_results[:, 0],\n",
    "            tsne_results[:, 1],\n",
    "            c=y_test,\n",
    "            s=5,\n",
    "            cmap='ocean',\n",
    "        )\n",
    "\n",
    "        ax.grid(False)\n",
    "        ax.axis(False)\n",
    "        plt.show()\n",
    "        \n",
    "    for selected_concept in selected_concepts:\n",
    "        if attributes is not None:\n",
    "            print(\n",
    "                \"Selected concept at index\",\n",
    "                selected_concept,\n",
    "                \"with semantics\",\n",
    "                attributes[selected_concept],\n",
    "            )\n",
    "\n",
    "        mask = c_test[:, selected_concept] == 1\n",
    "        neg_mask = np.logical_not(mask)\n",
    "\n",
    "        # And let's plot all of these\n",
    "        fig, ax = plt.subplots(\n",
    "            1,\n",
    "            1,\n",
    "            figsize=figsize,\n",
    "        )\n",
    "        if attributes is not None:\n",
    "            ax.set_title(\n",
    "                f\"SplitEmb COMPLETE Latent Space TSNE colored by {attributes[selected_concept]}\",\n",
    "                fontsize=15,\n",
    "            )\n",
    "        ax.scatter(\n",
    "            tsne_results[mask, 0],\n",
    "            tsne_results[mask, 1],\n",
    "            color='red',\n",
    "            label=\"Concept activated\",\n",
    "            s=5,\n",
    "        )\n",
    "\n",
    "        ax.scatter(\n",
    "            tsne_results[neg_mask, 0],\n",
    "            tsne_results[neg_mask, 1],\n",
    "            color='blue',\n",
    "            label=\"Concept not present\",\n",
    "            s=5,\n",
    "        )\n",
    "        ax.grid(False)\n",
    "        ax.axis(False)\n",
    "        fig.legend(fontsize=10) #, loc='center right')\n",
    "        plt.show()\n",
    "    return tsne_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time to look at clusters in SplitEmbedding model with shared probability generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_idxs = []\n",
    "colors = []\n",
    "colored_rename = []\n",
    "real_attributes = np.array(attributes)[selected_attributes]\n",
    "for i, name in enumerate(real_attributes):\n",
    "    if (\"color\" in name) and (\"multi-colored\" not in name):\n",
    "        color_idxs.append(i)\n",
    "        rename, color = name.split(\"::\")\n",
    "        if rename.startswith(\"has_\"):\n",
    "            rename = rename[len(\"has_\"):]\n",
    "        if rename.endswith(\"_color\"):\n",
    "            rename = rename[:-len(\"_color\")]\n",
    "        colors.append(color)\n",
    "        colored_rename.append(color + \"_\" + rename)\n",
    "    else:\n",
    "        colors.append(\"black\")\n",
    "        colored_rename.append(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For mixture of embeddings model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_tsne_embeddings(\n",
    "    c_test=c_test,\n",
    "    y_test=y_test,\n",
    "    test_c_embs=test_c_embs[f'MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat'],\n",
    "    attributes=real_attributes,\n",
    "    perplexity=30,\n",
    "    n_iter=1500,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=list(range(0, len(selected_attributes), 2)),\n",
    "    model_name=\"MixCEM\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_tsne_embeddings(\n",
    "    c_test=c_test,\n",
    "    y_test=y_test,\n",
    "    test_c_embs=test_c_embs_oracle[f'MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat'],\n",
    "    attributes=real_attributes,\n",
    "    perplexity=30,\n",
    "    n_iter=1500,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=list(range(0, len(selected_attributes), 2)),\n",
    "    model_name=\"MixCEM\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "show_closest_activation_examples(\n",
    "    x_test=x_test,\n",
    "    test_c_embs=test_c_embs[f'MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat'],\n",
    "    concept_semantics=real_attributes,\n",
    "    selected_concepts=list(range(0, len(selected_attributes), 2)),\n",
    "    test_dl=test_dl,\n",
    "    num_examples=5,\n",
    "    shown_neighs=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Same experiments for Hybrid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_tsne_embeddings(\n",
    "    c_test=c_test,\n",
    "    y_test=y_test,\n",
    "    test_c_embs=test_c_embs[f'ConceptBottleneckModelFuzzyExtraCapacity_Logit'],\n",
    "    attributes=real_attributes,\n",
    "    perplexity=30,\n",
    "    n_iter=1500,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=list(range(0, len(selected_attributes), 2)),\n",
    "    model_name=\"Hybrid\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_tsne_embeddings(\n",
    "    c_test=c_test,\n",
    "    y_test=y_test,\n",
    "    test_c_embs=test_c_embs_pure[f'ConceptBottleneckModelFuzzyExtraCapacity_Logit'],\n",
    "    attributes=real_attributes,\n",
    "    perplexity=30,\n",
    "    n_iter=1500,\n",
    "    figsize=(8, 6),\n",
    "    selected_concepts=list(range(0, len(selected_attributes), 2)),\n",
    "    model_name=\"Hybrid\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
