{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concept Bottleneck Models: Dot Training Example\n",
    "\n",
    "This short notebook shows how to set up a Concept Embedding Model\n",
    "(CEM) using our library and train it on the Dot dataset proposed in our CEM\n",
    "NeurIPS 2022 paper.\n",
    "\n",
    "Our example is composed of four steps:\n",
    "1. Loading the dataset of interest in a format that can be \"digested\" by our models.\n",
    "2. Instantiating a CEM with the embedding size and encoder/decoder architectures we want to use.\n",
    "3. Training the CEM on the Dot dataset.\n",
    "4. Evaluating the CEM's task accuracy, concept AUC, and concept alignment score (CAS)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load Data\n",
    "\n",
    "As a first step, we will show you how one can generate a dataset from scratch\n",
    "that is compatible with how our training pipeline is set up.\n",
    "\n",
    "In practice, you can train any CEM (or CBM variant) using our library as long as\n",
    "your dataset is structured such that:\n",
    "1. It is contained within a PyTorch `DataLoader` object.\n",
    "2. Every sample is a tuple with three elements: the sample $\\mathbf{x} \\in \\mathbb{R}^n$, the task label $y \\in \\{0, \\cdots, L -1\\}$, and a vector of $k$ binary concept annotations $\\mathbf{c} \\in \\{0, 1\\}^k$ (in that order).\n",
    "\n",
    "Below, we show how we do this for the Dot dataset. For details on the actual\n",
    "dataset, please refer to our paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# We first create a simple helper function to sample random labeled instances\n",
    "# from the Dot dataset:\n",
    "def generate_dot_data(size):\n",
    "    # Dimensionality of each latent vector\n",
    "    emb_size = 2\n",
    "    # Generate the latent vectors: v1 and v3 are sampled from a scaled normal\n",
    "    # distribution, while v2 and v4 are fixed reference vectors\n",
    "    v1 = np.random.randn(size, emb_size) * 2\n",
    "    v2 = np.ones(emb_size)\n",
    "    v3 = np.random.randn(size, emb_size) * 2\n",
    "    v4 = -np.ones(emb_size)\n",
    "    # Generate the sample\n",
    "    x = np.hstack([v1+v3, v1-v3])\n",
    "    \n",
    "    # Now the concept vector\n",
    "    c = np.stack([\n",
    "        np.dot(v1, v2).ravel() > 0,\n",
    "        np.dot(v3, v4).ravel() > 0,\n",
    "    ]).T\n",
    "    # And finally the label\n",
    "    y = ((v1*v3).sum(axis=-1) > 0).astype(np.int64)\n",
    "\n",
    "    # We NEED to put all of these into torch Tensors (THIS IS VERY IMPORTANT)\n",
    "    x = torch.FloatTensor(x)\n",
    "    c = torch.FloatTensor(c)\n",
    "    y = torch.Tensor(y)\n",
    "    return x, y, c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# We then use our helper function to generate DataLoaders with the correct\n",
    "# number of samples in them. We use a separate function for this to avoid\n",
    "# repeating code to generate the different folds of our dataset:\n",
    "def data_generator(\n",
    "    dataset_size,\n",
    "    batch_size,\n",
    "    seed=None,\n",
    "):\n",
    "    # For the sake of determinism, let's always first seed everything\n",
    "    # so that things can be recreated\n",
    "    seed_everything(seed)\n",
    "    x, y, c = generate_dot_data(dataset_size)\n",
    "    data = torch.utils.data.TensorDataset(x, y, c)\n",
    "    dl = torch.utils.data.DataLoader(\n",
    "        data,\n",
    "        batch_size=batch_size,\n",
    "    )\n",
    "    return dl\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, we generate our training, testing, and validation folds with\n",
    "# different random seeds\n",
    "train_dl = data_generator(\n",
    "    dataset_size=int(3000 * 0.7),\n",
    "    batch_size=256,\n",
    "    seed=42,\n",
    ")\n",
    "test_dl = data_generator(\n",
    "    dataset_size=int(3000 * 0.2),\n",
    "    batch_size=256,\n",
    "    seed=43,\n",
    ")\n",
    "val_dl = data_generator(\n",
    "    dataset_size=int(3000 * 0.1),\n",
    "    batch_size=256,\n",
    "    seed=44,\n",
    ")"
   ]
  },
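  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the batch format described above, the following\n",
    "self-contained sketch builds a tiny `DataLoader` from random toy tensors (these\n",
    "are hypothetical placeholders, not the Dot dataset) and unpacks one batch in\n",
    "the expected `(x, y, c)` order. The `toy_*` names are assumptions introduced\n",
    "only for this illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy data (NOT the Dot dataset): 10 samples with 4 features,\n",
    "# one binary task label, and 2 binary concept annotations per sample\n",
    "toy_x = torch.randn(10, 4)\n",
    "toy_y = torch.randint(0, 2, (10,)).float()\n",
    "toy_c = torch.randint(0, 2, (10, 2)).float()\n",
    "\n",
    "# The order of the tensors matters: (sample, task label, concepts)\n",
    "toy_dl = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(toy_x, toy_y, toy_c),\n",
    "    batch_size=4,\n",
    ")\n",
    "\n",
    "# Each batch unpacks into three tensors that share the batch dimension\n",
    "toy_x_batch, toy_y_batch, toy_c_batch = next(iter(toy_dl))\n",
    "assert toy_x_batch.shape == (4, 4)\n",
    "assert toy_y_batch.shape == (4,)\n",
    "assert toy_c_batch.shape == (4, 2)\n",
    "print(toy_x_batch.shape, toy_y_batch.shape, toy_c_batch.shape)"
   ]
  },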
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Create CEM Model\n",
    "\n",
    "Now that we have our dataset in the correct `DataLoader` format, we can\n",
    "proceed to construct our CEM object. For this, we will simply import\n",
    "our `ConceptEmbeddingModel` class from the `cem` library. We can then instantiate\n",
    "a CEM by indicating:\n",
    "1. The number of concepts `n_concepts` in the dataset we will train it on (e.g., 2 for the Dot dataset).\n",
    "2. The number of output tasks/labels `n_tasks` in the dataset of interest (e.g., 1 for the binary task in the Dot dataset).\n",
    "3. The size `emb_size` of each concept embedding.\n",
    "4. The weight `concept_loss_weight` to use for the concept prediction loss during training of the CEM (e.g., in our paper we set this value to 1 for the Dot dataset).\n",
    "5. The `learning_rate` and `optimizer` to use during training (e.g., \"adam\" or \"sgd\").\n",
    "6. The probability `training_intervention_prob` to perform a random intervention at training time via RandInt (we recommend setting this to 0.25).\n",
    "7. The model architecture `c_extractor_arch` to use for the latent code generator (i.e., the model that generates a latent representation to learn embeddings from the input samples).\n",
    "8. The model `c2y_model` to use as a label predictor **after** all concept embeddings have been generated by a CEM.\n",
    "\n",
    "The only non-trivial parameters to set for this instantiation are the model\n",
    "architectures for the latent code generator (passed via the `c_extractor_arch`\n",
    "argument) and for the label predictor (passed via the `c2y_model` argument).\n",
    "\n",
    "\n",
    "The first of these arguments, namely the latent code generator `c_extractor_arch`,\n",
    "must be provided as a simple Python function that takes as an input a named\n",
    "argument `output_dim` and generates a model that maps inputs from your task\n",
    "of interest to a latent code with shape `output_dim`. For our Dot example,\n",
    "we will do this via a simple MLP (although in practice you can use an\n",
    "arbitrarily complex model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def latent_code_generator_model(output_dim):\n",
    "    if output_dim is None:\n",
    "        output_dim = 128\n",
    "    return torch.nn.Sequential(*[\n",
    "        # 4 because Dot has inputs with 4 features in them\n",
    "        torch.nn.Linear(4, 128),\n",
    "        torch.nn.LeakyReLU(),\n",
    "        torch.nn.Linear(128, 128),\n",
    "        torch.nn.LeakyReLU(),\n",
    "        torch.nn.Linear(128, output_dim),\n",
    "    ])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second of these arguments, namely the label predictor `c2y_model`, must\n",
    "be any valid PyTorch model that takes as an input as many activations as the\n",
    "CEM's bottleneck (i.e., `n_concepts` * `emb_size`) and generates `n_tasks`\n",
    "outputs, one for each output label in our dataset's downstream task. If not\n",
    "provided, or if set to `None`, then by default we will simply attach a linear\n",
    "mapping after the CEM's bottleneck to obtain the output label prediction.\n",
    "In practice, this is how a CEM is usually constructed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We simply import our CEM class (the same can be done with CBMs to easily train\n",
    "# any of their variants)\n",
    "from cem.models.cem import ConceptEmbeddingModel\n",
    "\n",
    "# And generate the actual model\n",
    "cem_model = ConceptEmbeddingModel(\n",
    "  n_concepts=2, # Number of training-time concepts. Dot has 2\n",
    "  n_tasks=1, # Number of output labels. Dot is binary so it has 1.\n",
    "  emb_size=128,  # We will use an embedding size of 128\n",
    "  concept_loss_weight=1,  # The weight assigned to the concept prediction loss relative to the task predictive loss.\n",
    "  learning_rate=1e-3,  # The learning rate to use during training.\n",
    "  optimizer=\"adam\",  # The optimizer to use during training.\n",
    "  training_intervention_prob=0.25, # RandInt probability. We recommend setting this to 0.25.\n",
    "  c_extractor_arch=latent_code_generator_model, # Here we provide our generating function for the latent code generator model.\n",
    "  c2y_model=None,  # We will let the API simply add a linear layer from the concept bottleneck to the downstream task labels\n",
    ")\n",
    "print(cem_model)"
   ]
  },
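  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a non-linear label predictor is preferred over the default linear layer,\n",
    "a custom `c2y_model` could look like the minimal sketch below. The hidden\n",
    "size of 64 and the `toy_*`/`custom_c2y_model` names are assumptions made for\n",
    "illustration only; the one requirement taken from the description above is\n",
    "that the model maps `n_concepts * emb_size` inputs to `n_tasks` outputs.\n",
    "Such a model would then be passed as `c2y_model=custom_c2y_model` instead of\n",
    "`None` when instantiating the CEM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Same values as the ones used to instantiate the CEM above\n",
    "toy_n_concepts, toy_emb_size, toy_n_tasks = 2, 128, 1\n",
    "\n",
    "# Hypothetical non-linear label predictor: its input size must match the\n",
    "# CEM's bottleneck and its output size must match the number of tasks\n",
    "custom_c2y_model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(toy_n_concepts * toy_emb_size, 64),\n",
    "    torch.nn.LeakyReLU(),\n",
    "    torch.nn.Linear(64, toy_n_tasks),\n",
    ")\n",
    "\n",
    "# Sanity check on a dummy bottleneck for a batch of 8 samples\n",
    "toy_bottleneck = torch.randn(8, toy_n_concepts * toy_emb_size)\n",
    "assert custom_c2y_model(toy_bottleneck).shape == (8, toy_n_tasks)"
   ]
  },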
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Train the CEM\n",
    "\n",
    "Now that we have both the dataset and the model defined, we can train our CEM\n",
    "using PyTorch Lightning's wrappers for ease. This should be very simple via\n",
    "PyTorch Lightning's `Trainer` once the data has been generated:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",  # Change to \"cpu\" if you are not running on a GPU!\n",
    "    devices=\"auto\",\n",
    "    max_epochs=500,  # The number of epochs we will train our model for\n",
    "    check_val_every_n_epoch=5,  # And how often we will check for validation metrics\n",
    "    logger=False,  # No logs to be dumped for this trainer\n",
    ")\n",
    "\n",
    "# train_dl and val_dl are the DataLoaders we built in Step 1\n",
    "trainer.fit(cem_model, train_dl, val_dl)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more details on all the things you may add/configure to the Trainer for more\n",
    "control, please refer to the [official documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Evaluate Model\n",
    "\n",
    "Once the CEM has been trained, you can evaluate it with test data to generate\n",
    "the learnt embeddings, the predicted concepts, and the predicted task labels!\n",
    "\n",
    "A CEM or CBM model can be called with any input sample of shape `(batch_size, ...)`\n",
    "using PyTorch's functional API:\n",
    "```python\n",
    "(c_pred, c_embs, y_pred) = cem_model(x)\n",
    "```\n",
    "Where:\n",
    "1. `c_pred` is a $(\\text{batch\\_size}, k)$-dimensional tensor where the i-th dimension indicates the probability that the i-th concept is on.\n",
    "2. `c_embs` is a $(\\text{batch\\_size}, k \\cdot \\text{emb\\_size})$-dimensional tensor representing the CEM's bottleneck. This corresponds to all concept embeddings concatenated in the same order as given in the training annotations (so reshaping it to $(\\text{batch\\_size}, k, \\text{emb\\_size})$ will allow you to access each concept's embedding directly).\n",
    "3. `y_pred` is a $(\\text{batch\\_size}, L)$-dimensional tensor where the i-th dimension is the logit assigned to the i-th label for the current sample (the model outputs logits by default). If the downstream task is binary, then the CEM will output a $(\\text{batch\\_size})$-dimensional vector where each entry is the logit of the probability of the downstream class being $1$.\n",
    "\n",
    "This allows us to compute some metrics of interest. Below, we will use\n",
    "PyTorch Lightning's API to run inference in batches on a GPU to\n",
    "obtain all test activations.\n",
    "\n",
    "Before doing this, we will turn our test dataset into numpy arrays as they\n",
    "will be easier to work with if we want to compute custom metrics:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Before anything, however, let's get the underlying numpy arrays of our\n",
    "# test dataset as they will be easier to work with\n",
    "x_test, y_test, c_test = [], [], []\n",
    "for (x, y, c) in test_dl:\n",
    "    x_test.append(x)\n",
    "    y_test.append(y)\n",
    "    c_test.append(c)\n",
    "x_test = np.concatenate(x_test, axis=0)\n",
    "y_test = np.concatenate(y_test, axis=0)\n",
    "c_test = np.concatenate(c_test, axis=0)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to generate the concept, label, and embedding predictions for\n",
    "the test set using our trained CEM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use a Trainer object to run inference in batches over our test\n",
    "# dataset\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",  # Change to \"cpu\" if you are not running on a GPU!\n",
    "    devices=\"auto\",\n",
    "    logger=False, # No logs to be dumped for this trainer\n",
    ")\n",
    "batch_results = trainer.predict(cem_model, test_dl)\n",
    "\n",
    "# Then we combine all results into numpy arrays by joining over the batch\n",
    "# dimension\n",
    "c_pred = np.concatenate(\n",
    "    list(map(lambda x: x[0].detach().cpu().numpy(), batch_results)),\n",
    "    axis=0,\n",
    ")\n",
    "c_embs = np.concatenate(\n",
    "    list(map(lambda x: x[1].detach().cpu().numpy(), batch_results)),\n",
    "    axis=0,\n",
    ")\n",
    "# Reshape them so that we have embeddings (batch_size, k, emb_size)\n",
    "c_embs = np.reshape(c_embs, (c_test.shape[0], c_test.shape[1], -1))\n",
    "\n",
    "y_pred = np.concatenate(\n",
    "    list(map(lambda x: x[2].detach().cpu().numpy(), batch_results)),\n",
    "    axis=0,\n",
    ")"
   ]
  },
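  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reshape in the cell above concrete, the following sketch applies\n",
    "the same operation to a small dummy bottleneck. The sizes and the `toy_*`\n",
    "names are assumptions chosen only for illustration; the layout (all concept\n",
    "embeddings concatenated in annotation order) is the one described earlier\n",
    "for `c_embs`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_batch_size, toy_k, toy_emb = 3, 2, 4\n",
    "\n",
    "# Dummy bottleneck: each row holds toy_k embeddings concatenated back to back\n",
    "toy_flat_embs = np.arange(toy_batch_size * toy_k * toy_emb).reshape(\n",
    "    toy_batch_size,\n",
    "    toy_k * toy_emb,\n",
    ")\n",
    "\n",
    "# Reshaping splits each row into toy_k consecutive chunks of size toy_emb\n",
    "toy_embs = np.reshape(toy_flat_embs, (toy_batch_size, toy_k, -1))\n",
    "assert toy_embs.shape == (toy_batch_size, toy_k, toy_emb)\n",
    "\n",
    "# The embedding of concept 1 for sample 0 is the second chunk of row 0\n",
    "assert np.array_equal(toy_embs[0, 1], toy_flat_embs[0, toy_emb:2 * toy_emb])"
   ]
  },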
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And compute all the metrics of interest:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########\n",
    "## Compute test task accuracy\n",
    "##########\n",
    "\n",
    "from scipy.special import expit\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# We explicitly apply a sigmoid before thresholding, as CEMs always return\n",
    "# logits for the task prediction\n",
    "task_accuracy = accuracy_score(y_test, expit(y_pred) >= 0.5)\n",
    "print(f\"Our CEM's test task accuracy is {task_accuracy*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########\n",
    "## Compute test concept AUC\n",
    "##########\n",
    "\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# c_pred already holds concept probabilities (see the description of the\n",
    "# CEM's outputs above), so no sigmoid is needed before computing the AUC\n",
    "concept_auc = roc_auc_score(c_test, c_pred)\n",
    "print(f\"Our CEM's test concept AUC is {concept_auc*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########\n",
    "## Compute test concept alignment score\n",
    "##########\n",
    "\n",
    "from cem.metrics.cas import concept_alignment_score\n",
    "\n",
    "cas, _ = concept_alignment_score(\n",
    "    c_vec=c_embs,\n",
    "    c_test=c_test,\n",
    "    y_test=y_test,\n",
    "    step=5,\n",
    "    progress_bar=False,\n",
    ")\n",
    "print(f\"Our CEM's concept alignment score (CAS) is {cas*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
