{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%load_ext tensorboard\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import rc\n",
    "from matplotlib import cm\n",
    "import seaborn as sns\n",
    "from importlib import reload\n",
    "from pathlib import Path\n",
    "import sklearn\n",
    "from tensorflow.keras.models import load_model\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import cub_experiments as cub\n",
    "import models\n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning import seed_everything\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from pytorch_lightning import seed_everything\n",
    "import celeba_experiments as celeba\n",
    "from CUB200.cub_loader import load_data, find_class_imbalance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "## Global Variables Defining Experiment Flow\n",
    "################################################################################\n",
    "\n",
    "GPU = 1\n",
    "NUM_WORKERS = 5\n",
    "LATEX_SYMBOL = \"$\"\n",
    "RESULTS_DIR = \"results/\"\n",
    "CUB_RESULTS_DIR = os.path.join(\n",
    "    \"results/cub_0.25_subsample\",\n",
    ")\n",
    "\n",
    "rc('text', usetex=(LATEX_SYMBOL == \"$\"))\n",
    "plt.style.use('seaborn-whitegrid')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Model Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "cub_configs = defaultdict(dict)\n",
    "\n",
    "for file in os.listdir(CUB_RESULTS_DIR):\n",
    "    if '_experiment_config.joblib' in file:\n",
    "        config = load(os.path.join(CUB_RESULTS_DIR, file))\n",
    "        fold = int(file[file.find(\"_fold_\") + len(\"_fold_\"):file.find(\"_experiment_config\")]) - 1\n",
    "        model_name = f\"{config['architecture']}{config.get('extra_name', '')}\"\n",
    "        cub_configs[str(fold)][model_name] = config\n",
    "\n",
    "\n",
    "print(\"CUB Model names:\")\n",
    "for model_name, _ in cub_configs['0'].items(): \n",
    "    print(\"\\t\", model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_everything(42)\n",
    "N_CONCEPTS, N_TASKS = 112, 200\n",
    "og_config = cub_configs['0']['MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat']\n",
    "if og_config['weight_loss']:\n",
    "    imbalance = find_class_imbalance(os.path.join(cub.BASE_DIR, 'train.pkl'), True)\n",
    "else:\n",
    "    imbalance = None\n",
    "sampling_percent = og_config.get(\"sampling_percent\", 1)\n",
    "if sampling_percent != 1:\n",
    "    # Do the subsampling\n",
    "    new_n_concepts = int(np.ceil(N_CONCEPTS * sampling_percent))\n",
    "    selected_concepts_file = os.path.join(\n",
    "        CUB_RESULTS_DIR,\n",
    "        f\"selected_concepts_sampling_{sampling_percent}.npy\",\n",
    "    )\n",
    "    selected_concepts = np.load(selected_concepts_file)\n",
    "    print(\"\\t\\tSelected concepts:\", selected_concepts)\n",
    "    def subsample_transform(sample):\n",
    "        if isinstance(sample, list):\n",
    "            sample = np.array(sample)\n",
    "        return sample[selected_concepts]\n",
    "\n",
    "    if og_config['weight_loss']:\n",
    "        imbalance = np.array(imbalance)[selected_concepts]\n",
    "\n",
    "    train_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'train.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "        concept_transform=subsample_transform,\n",
    "    )\n",
    "    test_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "        concept_transform=subsample_transform,\n",
    "    )\n",
    "    \n",
    "    train_complete_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'train.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "    )\n",
    "    test_complete_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "    )\n",
    "\n",
    "    # And set the right number of concepts to be used\n",
    "    N_CONCEPTS = new_n_concepts\n",
    "else:\n",
    "    selected_concepts = list(range(N_CONCEPTS))\n",
    "    train_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'train.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "    )\n",
    "    test_dl = load_data(\n",
    "        pkl_paths=[os.path.join(cub.BASE_DIR, 'test.pkl')],\n",
    "        use_attr=True,\n",
    "        no_img=False,\n",
    "        batch_size=og_config['batch_size'],\n",
    "        uncertain_label=False,\n",
    "        n_class_attr=2,\n",
    "        image_dir='images',\n",
    "        resampling=False,\n",
    "        root_dir=cub.CUB_DIR,\n",
    "        num_workers=NUM_WORKERS,\n",
    "    )\n",
    "\n",
    "total_c = []\n",
    "for (_, _, c) in train_dl:\n",
    "    total_c.append(c.cpu().detach())\n",
    "total_c = np.concatenate(total_c, axis=0)\n",
    "concept_corr_matrix = np.corrcoef(total_c.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And split this up into arrays for ease of use\n",
    "x_train, c_train, y_train = [], [], []\n",
    "for (x, y, c) in train_dl:\n",
    "    x_train.append(x)\n",
    "    y_train.append(y)\n",
    "    c_train.append(c)\n",
    "x_train = np.concatenate(x_train, axis=0)\n",
    "print(\"x_train.shape =\", x_train.shape)\n",
    "c_train = np.concatenate(c_train, axis=0)\n",
    "print(\"c_train.shape =\", c_train.shape)\n",
    "y_train = np.concatenate(y_train, axis=0)\n",
    "print(\"y_train.shape =\", y_train.shape)\n",
    "\n",
    "# And split this up into arrays for ease of use\n",
    "x_train_complete, c_train_complete, y_train_complete = [], [], []\n",
    "for (x, y, c) in train_complete_dl:\n",
    "    x_train_complete.append(x)\n",
    "    y_train_complete.append(y)\n",
    "    c_train_complete.append(c)\n",
    "x_train_complete = np.concatenate(x_train_complete, axis=0)\n",
    "print(\"x_train_complete.shape =\", x_train_complete.shape)\n",
    "c_train_complete = np.concatenate(c_train_complete, axis=0)\n",
    "print(\"c_train_complete.shape =\", c_train_complete.shape)\n",
    "y_train_complete = np.concatenate(y_train_complete, axis=0)\n",
    "print(\"y_train_complete.shape =\", y_train_complete.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And split this up into arrays for ease of use\n",
    "x_test, c_test, y_test = [], [], []\n",
    "for (x, y, c) in test_dl:\n",
    "    x_test.append(x)\n",
    "    y_test.append(y)\n",
    "    c_test.append(c)\n",
    "x_test = np.concatenate(x_test, axis=0)\n",
    "print(\"x_test.shape =\", x_test.shape)\n",
    "c_test = np.concatenate(c_test, axis=0)\n",
    "print(\"c_test.shape =\", c_test.shape)\n",
    "y_test = np.concatenate(y_test, axis=0)\n",
    "print(\"y_test.shape =\", y_test.shape)\n",
    "\n",
    "# And split this up into arrays for ease of use\n",
    "x_test_complete, c_test_complete, y_test_complete = [], [], []\n",
    "for (x, y, c) in test_complete_dl:\n",
    "    x_test_complete.append(x)\n",
    "    y_test_complete.append(y)\n",
    "    c_test_complete.append(c)\n",
    "x_test_complete = np.concatenate(x_test_complete, axis=0)\n",
    "print(\"x_test_complete.shape =\", x_test_complete.shape)\n",
    "c_test_complete = np.concatenate(c_test_complete, axis=0)\n",
    "print(\"c_test_complete.shape =\", c_test_complete.shape)\n",
    "y_test_complete = np.concatenate(y_test_complete, axis=0)\n",
    "print(\"y_test_complete.shape =\", y_test_complete.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_trained_model(\n",
    "    config,\n",
    "    n_tasks=N_TASKS,\n",
    "    result_dir=CUB_RESULTS_DIR,\n",
    "    n_concepts=N_CONCEPTS,\n",
    "    split=0,\n",
    "    imbalance=None,\n",
    "    intervention_idxs=None,\n",
    "    adversarial_intervention=False,\n",
    "    train_dl=None,\n",
    "):\n",
    "    if split is not None:\n",
    "        full_run_name = f\"{config['architecture']}{config.get('extra_name', '')}_{config['c_extractor_arch']}_fold_{split + 1}\"\n",
    "    else:\n",
    "        full_run_name = f\"{config['architecture']}{config.get('extra_name', '')}_{config['c_extractor_arch']}\"\n",
    "    selected_concepts = np.arange(n_concepts)\n",
    "    if config.get(\"message_passing_layers\"):\n",
    "        edges = []\n",
    "        edge_weights = []\n",
    "        corr_thresh = config.get('corr_thresh')\n",
    "        sorted_selected = sorted(selected_concepts)\n",
    "        for i in range(n_concepts):\n",
    "            i_idx = sorted_selected[i]\n",
    "            for j in range(i + 1, n_concepts):\n",
    "                j_idx = sorted_selected[j]\n",
    "                if np.abs(concept_corr_matrix[i_idx, j_idx]) >= corr_thresh:\n",
    "                    edges.append(np.array([[i, j], [j, i]]))\n",
    "                    if config.get(\"weighted_edges\"):\n",
    "                        weight = np.abs(concept_corr_matrix[i_idx, j_idx])\n",
    "                    else:\n",
    "                        weight = 1\n",
    "                    edge_weights.extend([weight, weight])\n",
    "        concept_edge_list = torch.cuda.LongTensor(np.concatenate(edges, axis=-1))\n",
    "        concept_edge_weights = torch.cuda.FloatTensor(np.array(edge_weights))\n",
    "    else:\n",
    "        concept_edge_list = None\n",
    "        concept_edge_weights = None\n",
    "    if (\n",
    "        (intervention_idxs is not None) and\n",
    "        (train_dl is not None) and\n",
    "        (config['architecture'] == \"ConceptBottleneckModel\") and\n",
    "        (not config.get('sigmoidal_prob', True))\n",
    "    ):\n",
    "        # Then let's look at the empirical distribution of the logits in order to\n",
    "        # be able to intervene\n",
    "        model = models.construct_model(\n",
    "            n_concepts=n_concepts,\n",
    "            n_tasks=n_tasks,\n",
    "            config=config,\n",
    "            imbalance=imbalance,\n",
    "            concept_edge_list=concept_edge_list,\n",
    "            concept_edge_weights=concept_edge_weights,\n",
    "        )\n",
    "        trainer = pl.Trainer(\n",
    "            gpus=GPU,\n",
    "        )\n",
    "        batch_results = trainer.predict(model, train_dl)\n",
    "        out_embs = np.concatenate(\n",
    "            list(map(lambda x: x[1], batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "        active_intervention_values = []\n",
    "        inactive_intervention_values = []\n",
    "        for idx in range(n_concepts):\n",
    "            active_intervention_values.append(np.percentile(out_embs[:, idx], 95))\n",
    "            inactive_intervention_values.append(np.percentile(out_embs[:, idx], 5))\n",
    "        print(\"For\", full_run_name, \"we found its intervention values to be:\")\n",
    "        print(\"\\tactive_intervention_values =\", active_intervention_values)\n",
    "        print(\"\\tinactive_intervention_values =\", inactive_intervention_values)\n",
    "    else:\n",
    "        active_intervention_values = inactive_intervention_values = None\n",
    "    model = models.construct_model(\n",
    "        n_concepts=n_concepts,\n",
    "        n_tasks=n_tasks,\n",
    "        config=config,\n",
    "        imbalance=imbalance,\n",
    "        concept_edge_list=concept_edge_list,\n",
    "        concept_edge_weights=concept_edge_weights,\n",
    "        intervention_idxs=intervention_idxs,\n",
    "        adversarial_intervention=adversarial_intervention,\n",
    "        active_intervention_values=active_intervention_values,\n",
    "        inactive_intervention_values=inactive_intervention_values,\n",
    "    )\n",
    "    model_saved_path = os.path.join(\n",
    "        result_dir or \".\",\n",
    "        f'{full_run_name}.pt'\n",
    "    )\n",
    "    model.load_state_dict(torch.load(model_saved_path))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reload(cub)\n",
    "import models_cub\n",
    "reload(models_cub)\n",
    "\n",
    "WHITELIST = [\n",
    "    'ConceptBottleneckModelFuzzyExtraCapacity_Logit',\n",
    "    'MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat',\n",
    "]\n",
    "all_models = defaultdict(dict)\n",
    "for split, runs in cub_configs.items():\n",
    "    for model_name, config in runs.items(): \n",
    "        if model_name not in WHITELIST:\n",
    "            continue\n",
    "        try:\n",
    "            config[\"shared_prob_gen\"] = config.get(\"shared_prob_gen\", False)\n",
    "            config[\"per_concept_weight\"] = config.get(\"per_concept_weight\", False)\n",
    "            all_models[split][model_name] = load_trained_model(\n",
    "                config=config,\n",
    "                n_tasks=N_TASKS,\n",
    "                n_concepts=N_CONCEPTS,\n",
    "                result_dir=CUB_RESULTS_DIR,\n",
    "                split=int(split),\n",
    "                imbalance=imbalance,\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(\"Could not load model\", model_name, \"for split\", split)\n",
    "            print(\"\\t\", e)\n",
    "            raise e"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Now time to train a simple classifier to predict the hidden concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_bin_accuracy(y_pred, y_true):\n",
    "    y_probs = y_pred.cpu().detach()\n",
    "    y_pred = y_probs > 0.5\n",
    "    y_true = y_true.reshape(-1).cpu().detach()\n",
    "    y_accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)\n",
    "    try:\n",
    "        y_auc = sklearn.metrics.roc_auc_score(y_true, y_probs)\n",
    "    except:\n",
    "        y_auc = 0\n",
    "    try:\n",
    "        y_f1 = sklearn.metrics.f1_score(y_true, y_pred)\n",
    "    except:\n",
    "        y_f1 = 0\n",
    "    return (y_accuracy, y_auc, y_f1)\n",
    "\n",
    "def compute_accuracy(\n",
    "    y_pred,\n",
    "    y_true,\n",
    "):\n",
    "    if (len(y_pred.shape) < 2) or (y_pred.shape[-1] == 1):\n",
    "        return compute_bin_accuracy(\n",
    "            y_pred,\n",
    "            y_true,\n",
    "        )\n",
    "    y_probs = torch.nn.Softmax(dim=-1)(y_pred).cpu().detach()\n",
    "    used_classes = np.unique(y_true.reshape(-1).cpu().detach())\n",
    "    y_probs = y_probs[:, sorted(list(used_classes))]\n",
    "    y_pred = y_pred.argmax(dim=-1).cpu().detach()\n",
    "    y_true = y_true.reshape(-1).cpu().detach()\n",
    "    y_accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)\n",
    "    try:\n",
    "        y_auc = sklearn.metrics.roc_auc_score(y_true, y_probs, multi_class='ovo')\n",
    "    except:\n",
    "        y_auc = 0.0\n",
    "    y_f1 = 0.0\n",
    "    return (y_accuracy, y_auc, y_f1)\n",
    "\n",
    "\n",
    "class SimpleMLP(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_features,\n",
    "        n_tasks,\n",
    "        layer_sizes=[],\n",
    "        momentum=0.9,\n",
    "        learning_rate=0.01,\n",
    "        weight_decay=4e-05,\n",
    "        optimizer=\"sgd\",\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.n_tasks = n_tasks\n",
    "        layers = []\n",
    "        prev_size = n_features\n",
    "        for size in (layer_sizes or []):\n",
    "            layers.append(torch.nn.Linear(prev_size, size))\n",
    "            layers.append(torch.nn.LeakyReLU())\n",
    "            prev_size = size\n",
    "        layers.append(torch.nn.Linear(prev_size, n_tasks))\n",
    "        self.model = torch.nn.Sequential(*layers)\n",
    "        self.loss_task = torch.nn.CrossEntropyLoss() if n_tasks > 1 else torch.nn.BCEWithLogitsLoss()\n",
    "        self.momentum = momentum\n",
    "        self.learning_rate = learning_rate\n",
    "        self.optimizer_name = optimizer\n",
    "        self.weight_decay = weight_decay\n",
    "\n",
    "    def _unpack_batch(self, batch):\n",
    "        x, y = batch\n",
    "        return x, y\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "    \n",
    "    def predict_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        x, y = self._unpack_batch(batch)\n",
    "        return self(x)\n",
    "\n",
    "    def _run_step(self, batch, batch_idx, train=False):\n",
    "        x, y = self._unpack_batch(batch)\n",
    "        y_logits = self(x)\n",
    "        loss = self.loss_task(\n",
    "            y_logits if y_logits.shape[-1] > 1 else y_logits.reshape(-1),\n",
    "            y,\n",
    "        )\n",
    "        # compute accuracy\n",
    "        (y_accuracy, y_auc, y_f1) = compute_accuracy(\n",
    "            y_logits,\n",
    "            y,\n",
    "        )\n",
    "        result = {\n",
    "            \"y_accuracy\": y_accuracy,\n",
    "            \"y_auc\": y_auc,\n",
    "            \"y_f1\": y_f1,\n",
    "            \"loss\": loss.detach(),\n",
    "        }\n",
    "        return loss, result\n",
    "\n",
    "    def training_step(self, batch, batch_no):\n",
    "        loss, result = self._run_step(batch, batch_no, train=True)\n",
    "        for name, val in result.items():\n",
    "            self.log(name, val, prog_bar=(\"accuracy\" in name))\n",
    "        return {\n",
    "            \"loss\": loss,\n",
    "            \"log\": {\n",
    "                \"y_accuracy\": result['y_accuracy'],\n",
    "                \"y_auc\": result['y_auc'],\n",
    "                \"y_f1\": result['y_f1'],\n",
    "                \"loss\": result['loss'],\n",
    "            },\n",
    "        }\n",
    "\n",
    "    def validation_step(self, batch, batch_no):\n",
    "        loss, result = self._run_step(batch, batch_no, train=False)\n",
    "        for name, val in result.items():\n",
    "            self.log(\"val_\" + name, val, prog_bar=(\"accuracy\" in name))\n",
    "        return {\n",
    "            \"val_\" + key: val\n",
    "            for key, val in result.items()\n",
    "        }\n",
    "\n",
    "    def test_step(self, batch, batch_no):\n",
    "        loss, result = self._run_step(batch, batch_no, train=False)\n",
    "        for name, val in result.items():\n",
    "            self.log(\"test_\" + name, val, prog_bar=True)\n",
    "        return result['loss']\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        if self.optimizer_name.lower() == \"adam\":\n",
    "            optimizer = torch.optim.Adam(\n",
    "                self.parameters(),\n",
    "                lr=self.learning_rate,\n",
    "                weight_decay=self.weight_decay,\n",
    "            )\n",
    "        else:\n",
    "            optimizer = torch.optim.SGD(\n",
    "                filter(lambda p: p.requires_grad, self.parameters()),\n",
    "                lr=self.learning_rate,\n",
    "                momentum=self.momentum,\n",
    "                weight_decay=self.weight_decay,\n",
    "            )\n",
    "        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n",
    "        return {\n",
    "            \"optimizer\": optimizer,\n",
    "            \"lr_scheduler\": lr_scheduler,\n",
    "            \"monitor\": \"loss\",\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_mlp(train_dl, test_dl, n_tasks=1, layer_sizes=[], max_epochs=100, gpu=1, verbose=True):\n",
    "    (x, y) = next(iter(train_dl))\n",
    "    n_features = x.shape[-1]\n",
    "    print(\"n_features =\", n_features)\n",
    "    model = SimpleMLP(\n",
    "        n_features=n_features,\n",
    "        n_tasks=n_tasks,\n",
    "        layer_sizes=layer_sizes,\n",
    "    )\n",
    "    trainer = pl.Trainer(\n",
    "        gpus=gpu,\n",
    "        max_epochs=max_epochs,\n",
    "        check_val_every_n_epoch=5\n",
    "    )\n",
    "    # Else it is time to train it\n",
    "    trainer.fit(model, train_dl)\n",
    "    # freeze model and compute test accuracy\n",
    "    model.freeze()\n",
    "    [test_results] = trainer.test(model, test_dl)\n",
    "    y_accuracy = test_results[\"test_y_accuracy\"]\n",
    "    return model, test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_bottleneck(\n",
    "    x_test_complete,\n",
    "    y_test_complete,\n",
    "    c_test_complete,\n",
    "    model_name,\n",
    "    selected_concepts=selected_concepts,\n",
    "    max_epochs=50,\n",
    "    layer_sizes=[],\n",
    "):\n",
    "    folds = sorted(list(map(int, all_models.keys())))\n",
    "    trainer = pl.Trainer(\n",
    "        gpus=1,\n",
    "    )\n",
    "    model_test_results = []\n",
    "    model_mean_other_concept_accs = []\n",
    "    maj_mean_other_concept_accs = []\n",
    "    for fold in folds:\n",
    "        print(\"Starting with fold\", fold + 1)\n",
    "        if str(fold) not in all_models or (\n",
    "            model_name not in all_models[str(fold)]\n",
    "        ):\n",
    "            print(\"Skipping\", model_name, \"for fold\", fold + 1, \"as we could not find it\")\n",
    "            continue\n",
    "        model = all_models[str(fold)][model_name]\n",
    "        test_batch_results = trainer.predict(\n",
    "            model,\n",
    "            torch.utils.data.DataLoader(\n",
    "                torch.utils.data.TensorDataset(\n",
    "                    torch.cuda.FloatTensor(x_test_complete),\n",
    "                    torch.cuda.FloatTensor(y_test_complete),\n",
    "                    torch.cuda.FloatTensor(c_test_complete),\n",
    "                ),\n",
    "                batch_size=1,\n",
    "            ),\n",
    "        )\n",
    "        train_batch_results = trainer.predict(\n",
    "            model,\n",
    "            torch.utils.data.DataLoader(\n",
    "                torch.utils.data.TensorDataset(\n",
    "                    torch.cuda.FloatTensor(x_train_complete),\n",
    "                    torch.cuda.FloatTensor(y_train_complete),\n",
    "                    torch.cuda.FloatTensor(c_train_complete),\n",
    "                ),\n",
    "                batch_size=1,\n",
    "            ),\n",
    "        )\n",
    "        test_complete_embs = np.concatenate(\n",
    "            list(map(lambda x: x[1], test_batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "\n",
    "        train_complete_embs = np.concatenate(\n",
    "            list(map(lambda x: x[1], train_batch_results)),\n",
    "            axis=0,\n",
    "        )\n",
    "\n",
    "        mean_acc = 0\n",
    "        maj_class_acc = 0 \n",
    "        for concept_idx in range(c_train_complete.shape[-1]):\n",
    "            if concept_idx in selected_concepts:\n",
    "                print(\"Skipping concept\", concept_idx, \"as it was used during training\")\n",
    "                continue\n",
    "            print(\"Training with concept index\", concept_idx, \"for fold\", fold)\n",
    "            current_train_ds = torch.utils.data.TensorDataset(\n",
    "                torch.cuda.FloatTensor(\n",
    "                    train_complete_embs\n",
    "                ),\n",
    "                torch.cuda.FloatTensor(c_train_complete[:, concept_idx]),\n",
    "            )\n",
    "            current_train_dl = torch.utils.data.DataLoader(\n",
    "                current_train_ds,\n",
    "                batch_size=512,\n",
    "            )\n",
    "\n",
    "            current_test_ds = torch.utils.data.TensorDataset(\n",
    "                torch.cuda.FloatTensor(\n",
    "                    test_complete_embs\n",
    "                ),\n",
    "                torch.cuda.FloatTensor(c_test_complete[:, concept_idx]),\n",
    "            )\n",
    "            current_test_dl = torch.utils.data.DataLoader(\n",
    "                current_test_ds,\n",
    "                batch_size=512,\n",
    "            )\n",
    "\n",
    "\n",
    "            _, test_results = train_mlp(\n",
    "                train_dl=current_train_dl,\n",
    "                test_dl=current_test_dl,\n",
    "                layer_sizes=layer_sizes,\n",
    "                max_epochs=max_epochs,\n",
    "                verbose=False,\n",
    "            )\n",
    "            print(\"For concept\", concept_idx, \"we obtained accuracy\", test_results['test_y_accuracy'])\n",
    "            mean_acc += test_results['test_y_accuracy']/(c_train_complete.shape[-1] - len(selected_concepts))\n",
    "            maj_class_acc += max(\n",
    "                1 - np.mean(c_test_complete[:, concept_idx]),\n",
    "                np.mean(c_test_complete[:, concept_idx]),\n",
    "            )/(c_train_complete.shape[-1] - len(selected_concepts))\n",
    "        model_mean_other_concept_accs.append(mean_acc)\n",
    "        maj_mean_other_concept_accs.append(maj_class_acc)\n",
    "        print(\"Resulting mean accuracy is:\", mean_acc)\n",
    "        print(\"Majority class accuracy is:\", maj_class_acc)\n",
    "    return model_mean_other_concept_accs, maj_mean_other_concept_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now the same but with a single linear layer\n",
    "import logging\n",
    "logging.getLogger(\"lightning\").setLevel(logging.ERROR)\n",
    "import logging\n",
    "logging.getLogger(\"lightning\").addHandler(logging.NullHandler())\n",
    "logging.getLogger(\"lightning\").propagate = False\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "single_mixed_mean_accs, maj_mean_accs = check_bottleneck(\n",
    "    x_test_complete=x_test_complete,\n",
    "    y_test_complete=y_test_complete,\n",
    "    c_test_complete=c_test_complete,\n",
    "    model_name='MixtureEmbModelSharedProb_AdaptiveDropout_NoProbConcat',\n",
    "    selected_concepts=selected_concepts,\n",
    "    max_epochs=50,\n",
    "    layer_sizes=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir results/cub_representation_use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "\n",
    "joblib.dump(\n",
    "    single_mixed_mean_accs,\n",
    "    os.path.join(\"results/cub_representation_use\", f'single_mixed_mean_accs.joblib'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result: 94.32602682047421 ± 0.8794648020417654\n"
     ]
    }
   ],
   "source": [
    "print(\"Result:\", 100*np.mean(single_mixed_mean_accs), \"±\", 200*np.std(single_mixed_mean_accs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time to do the same with the Hybrid representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.getLogger(\"lightning\").setLevel(logging.ERROR)\n",
    "import logging\n",
    "logging.getLogger(\"lightning\").addHandler(logging.NullHandler())\n",
    "logging.getLogger(\"lightning\").propagate = False\n",
    "import warnings\n",
    "\n",
    "single_hybrid_mean_accs, maj_mean_accs = check_bottleneck(\n",
    "    x_test_complete=x_test_complete,\n",
    "    y_test_complete=y_test_complete,\n",
    "    c_test_complete=c_test_complete,\n",
    "    model_name='ConceptBottleneckModelFuzzyExtraCapacity_Logit',\n",
    "    selected_concepts=selected_concepts,\n",
    "    max_epochs=50,\n",
    "    layer_sizes=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "\n",
    "joblib.dump(\n",
    "    single_hybrid_mean_accs,\n",
    "    os.path.join(\"results/cub_representation_use\", f'single_hybrid_mean_accs.joblib'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result: 91.82535575495827 ± 0.5091445744528944\n"
     ]
    }
   ],
   "source": [
    "print(\"Result:\", 100*np.mean(single_hybrid_mean_accs), \"±\", 200*np.std(single_hybrid_mean_accs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
