{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp methods.uncertain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipynb_path import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from counterfactual.import_essentials import *\n",
    "from counterfactual.utils import cat_normalize\n",
    "from counterfactual.training_module import CounterfactualTrainingModule\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exporti\n",
    "class TargetedLossFunction:\n",
    "    \"\"\"Loss function for generating CEs in a particular target class.\"\"\"\n",
    "    def loss(self, outputs: Tensor, original_labels: Tensor) -> Tensor:\n",
    "        batch_size = outputs.size(0)\n",
    "        targets = 1. - original_labels\n",
    "        assert targets.size() == (batch_size, )\n",
    "        return -F.binary_cross_entropy(outputs[:, 1], targets)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exporti\n",
    "def _get_prediction_and_grad(\n",
    "    pred_fn: Callable[[Tensor], Tensor],\n",
    "    loss_function,\n",
    "    examples: Tensor, original_labels: Tensor\n",
    ") ->Tuple[Tensor, Tensor, Tensor]: \n",
    "    examples = examples.clone().detach()\n",
    "    assert examples.grad is None\n",
    "    examples.requires_grad = True\n",
    "    output = pred_fn(examples)\n",
    "    loss = loss_function.loss(output, original_labels)\n",
    "    loss.backward()\n",
    "    confidence, _ = torch.max(output, dim=1)\n",
    "    return output, confidence, examples.grad.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exporti\n",
    "def _uncertaincf(\n",
    "    originals: Tensor,\n",
    "    pred_fn: Callable[[Tensor], Tensor],\n",
    "    loss_function,\n",
    "    n_steps: int,\n",
    "    n_changes: int,\n",
    "    confidence_threshold: float,\n",
    "    cat_arrays: List[List[str]],\n",
    "    cat_idx: int\n",
    ") -> Tensor:\n",
    "    \"\"\"Greedy, gradient-guided search for counterfactual examples.\n",
    "\n",
    "    Each step, for every example that is still active, the feature(s) with the largest\n",
    "    absolute gradient of `loss_function` are moved by `1 / n_changes` in the direction of\n",
    "    the gradient sign. An example stays active while its predicted class equals the original\n",
    "    one or its confidence is below `confidence_threshold`. After each step the batch is\n",
    "    clamped to [0, 1] and passed through `cat_normalize(..., hard=False)`.\n",
    "\n",
    "    originals: input batch (first dim is the batch); flattened to (batch, -1) internally.\n",
    "    pred_fn: maps inputs to per-class scores; `argmax(dim=1)` gives the predicted class.\n",
    "    loss_function: object with `loss(outputs, original_labels)`; its input gradient is ascended.\n",
    "    n_steps: maximum number of search iterations.\n",
    "    n_changes: maximum number of times a single feature may be changed; also sets the\n",
    "        step size `1 / n_changes`.\n",
    "    confidence_threshold: examples that already changed class keep being perturbed while their\n",
    "        confidence (max score) is below this value, as long as the search is still running.\n",
    "    cat_arrays, cat_idx: forwarded to `cat_normalize`; presumably the categorical value groups\n",
    "        and the column where categorical features start -- confirm against `cat_normalize`.\n",
    "\n",
    "    Returns the counterfactual batch with the same shape as `originals`.\n",
    "    \"\"\"\n",
    "    batch_size = originals.size(0)\n",
    "    examples = originals.clone().detach().reshape(batch_size, -1)\n",
    "\n",
    "    # Labels the search must move away from; no autograd graph is needed for them.\n",
    "    with torch.no_grad():\n",
    "        original_labels = pred_fn(originals).argmax(dim=1)\n",
    "    assert original_labels.shape == (batch_size, )\n",
    "\n",
    "    # Every change moves a feature by the same fixed amount. Using a Python scalar instead\n",
    "    # of a per-feature tensor built on the CPU keeps this working when `examples` lives on\n",
    "    # another device (indexing a CPU tensor with a CUDA mask raises an error).\n",
    "    step_size = 1.0 / n_changes\n",
    "\n",
    "    # How many times each feature of each example has been changed so far.\n",
    "    altered_pixels = torch.zeros(size=examples.shape, device=examples.device, dtype=torch.int)\n",
    "\n",
    "    for _ in range(n_steps):\n",
    "        prediction, confidence, grad = _get_prediction_and_grad(\n",
    "            pred_fn, loss_function, examples, original_labels\n",
    "        )\n",
    "\n",
    "        have_changed_class = torch.argmax(prediction, -1) != original_labels\n",
    "        # NOTE(review): the search stops once every example has changed class, even if some\n",
    "        # are still below `confidence_threshold`; the confidence criterion only matters while\n",
    "        # at least one example has not flipped yet. Confirm this is intended.\n",
    "        if torch.sum(have_changed_class) == batch_size:\n",
    "            break\n",
    "\n",
    "        # Features already changed `n_changes` times are frozen by zeroing their gradient.\n",
    "        grad[altered_pixels >= n_changes] = 0.0\n",
    "\n",
    "        # Most sensitive feature(s) per example: largest absolute gradient (ties select all).\n",
    "        abs_grad = grad.abs()\n",
    "        max_mask = abs_grad == abs_grad.max(dim=1, keepdim=True)[0]\n",
    "        # (batch, 1) mask, broadcast across features: keep perturbing an example while it has\n",
    "        # not changed class or while its confidence is below `confidence_threshold`.\n",
    "        active = (~have_changed_class | (confidence < confidence_threshold)).view(batch_size, 1)\n",
    "        to_change_mask = max_mask & active & (grad != 0.0)\n",
    "\n",
    "        examples[to_change_mask] += step_size * grad[to_change_mask].sign()\n",
    "        altered_pixels[to_change_mask] += 1\n",
    "\n",
    "        examples = torch.clamp(examples, 0.0, 1.0)\n",
    "        examples = cat_normalize(examples, cat_arrays, cat_idx, hard=False)\n",
    "    return examples.reshape(originals.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class UncertainCF(CounterfactualTrainingModule):\n",
    "    \"\"\"Counterfactual explainer that runs the `_uncertaincf` search against a frozen black-box model.\"\"\"\n",
    "\n",
    "    def __init__(self, config, model: pl.LightningModule, n_steps: int = 500, n_changes: int = 10, confidence_threshold: float = 0.99):\n",
    "        \"\"\"\n",
    "        config: basic configs, forwarded to `CounterfactualTrainingModule.__init__`\n",
    "        model: the black-box model to be explained; it is frozen here and only queried\n",
    "        n_steps: maximum number of search iterations per `generate_cf` call\n",
    "        n_changes: maximum number of times `_uncertaincf` may change a single feature\n",
    "        confidence_threshold: confidence below which `_uncertaincf` keeps perturbing an example\n",
    "        \"\"\"\n",
    "        super().__init__(config)\n",
    "        # Search hyper-parameters, forwarded unchanged to `_uncertaincf`.\n",
    "        self.n_steps, self.n_changes, self.confidence_threshold = (\n",
    "            n_steps, n_changes, confidence_threshold\n",
    "        )\n",
    "        self.model = model\n",
    "        self.model.freeze()\n",
    "        self.prepare_data()\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Delegate prediction to the wrapped black-box model.\"\"\"\n",
    "        return self.model.predict(x)\n",
    "\n",
    "    def generate_cf(self, x):\n",
    "        \"\"\"Search for counterfactuals of the batch `x`.\"\"\"\n",
    "        def pred_fn(inputs):\n",
    "            # Turn the model output into two class columns `[1 - p, p]`\n",
    "            # (assumes `self.model` returns a (batch, 1) probability -- TODO confirm).\n",
    "            positive = self.model(inputs)\n",
    "            return torch.cat([1. - positive, positive], dim=1)\n",
    "\n",
    "        # Presumably categorical features start right after the continuous ones -- confirm.\n",
    "        first_cat_col = len(self.continous_cols)\n",
    "        return _uncertaincf(\n",
    "            x,\n",
    "            pred_fn,\n",
    "            TargetedLossFunction(),\n",
    "            self.n_steps,\n",
    "            self.n_changes,\n",
    "            self.confidence_threshold,\n",
    "            cat_arrays=self.cat_array,\n",
    "            cat_idx=first_cat_col,\n",
    "        )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
