{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from ipynb_path import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from counterfactual.import_essentials import *\n",
    "from counterfactual.utils import *\n",
    "from counterfactual.train import *\n",
    "from counterfactual.training_module import *\n",
    "from counterfactual.net import *\n",
    "# from counterfactual.evaluate import *\n",
    "\n",
    "from torch.nn.parameter import Parameter\n",
    "from pytorch_lightning.metrics.functional.classification import *\n",
    "\n",
    "pl_logger = logging.getLogger('lightning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pl version: 1.1.0\n",
      "torch version: 1.7.1\n"
     ]
    }
   ],
   "source": [
    "print(f\"pl version: {pl.__version__}\")\n",
    "print(f\"torch version: {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export \n",
    "\n",
    "class Clamp(torch.autograd.Function):\n",
    "    \"\"\"\n",
    "    Clamp parameter to [0, 1]\n",
    "    code from: https://discuss.pytorch.org/t/regarding-clamped-learnable-parameter/58474/4\n",
    "    \"\"\"\n",
    "    @staticmethod\n",
    "    def forward(ctx, input):\n",
    "        return input.clamp(min=0, max=1)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        return grad_output.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class ExplainerBase(nn.Module):\n",
    "    def __init__(self, x: torch.tensor, model: pl.LightningModule):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.model.freeze()\n",
    "        self.x = x\n",
    "        self.clamp = Clamp()\n",
    "\n",
    "    def forward(self):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def compute_regularization_loss(self):\n",
    "        cat_idx = len(self.model.continous_cols)\n",
    "        regularization_loss = 0.\n",
    "        for i in range(self.n_cfs):\n",
    "            for col in self.model.cat_arrays:\n",
    "                cat_idx_end = cat_idx + len(col)\n",
    "                regularization_loss += torch.pow((torch.sum(self.cf[i][cat_idx: cat_idx_end]) - 1.0), 2)\n",
    "        return regularization_loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam([self.cf], lr=0.001)\n",
    "\n",
    "    def generate_cf(self, n_iters):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vanilla CF\n",
    "\n",
    "Wachter, S., Mittelstadt, B., & Russell, C. (2017). Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR. SSRN Electronic Journal. https://doi.org/10.2139/ssrn.3063289"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class VanillaCF(ExplainerBase):\n",
    "    def __init__(self, x: torch.tensor, model: BaselineModel):\n",
    "        \"\"\"vanilla version of counterfactual generation\n",
    "            - link: https://doi.org/10.2139/ssrn.3063289\n",
    "\n",
    "        Args:\n",
    "            x (torch.tensor): input instance\n",
    "            model (BaselineModel): black-box model\n",
    "        \"\"\"\n",
    "        super().__init__(x, model)\n",
    "        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)\n",
    "\n",
    "    def forward(self):\n",
    "        cf = self.cf * 1.0\n",
    "        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), False)\n",
    "        # return cf\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.RMSprop([self.cf], lr=0.001)\n",
    "\n",
    "    def compute_regularization_loss(self):\n",
    "        cat_idx = len(self.model.continous_cols)\n",
    "        regularization_loss = 0.\n",
    "        for col in self.model.cat_arrays:\n",
    "            cat_idx_end = cat_idx + len(col)\n",
    "            regularization_loss += torch.pow((torch.sum(self.cf[cat_idx: cat_idx_end]) - 1.0), 2)\n",
    "        return regularization_loss\n",
    "\n",
    "    def _loss_functions(self, x, c):\n",
    "        # target\n",
    "        y_pred = self.model.predict(x)\n",
    "        y_prime = torch.ones(y_pred.shape) - y_pred\n",
    "\n",
    "        c_y = self.model(c)\n",
    "        l_1 = F.binary_cross_entropy(c_y, y_prime.float())\n",
    "        l_2 = F.mse_loss(x, c)\n",
    "        return l_1, l_2\n",
    "\n",
    "    def _loss_compute(self, l_1, l_2):\n",
    "        return 1.0 * l_1 + 0.5 * l_2\n",
    "\n",
    "    def generate_cf(self, n_iters, debug: bool = False):\n",
    "        optim = self.configure_optimizers()\n",
    "        for i in range(n_iters):\n",
    "            c = self()\n",
    "            l_1, l_2 = self._loss_functions(self.x, c)\n",
    "            loss = self._loss_compute(l_1, l_2)\n",
    "            optim.zero_grad()\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "            if debug and i % 100 == 0:\n",
    "                print(f\"iter: {i}, loss: {loss.item()}\")\n",
    "\n",
    "            # contrain to [0,1]\n",
    "            self.clamp.apply(self.cf)\n",
    "\n",
    "        cf = self.cf * 1.0\n",
    "        self.clamp.apply(self.cf)\n",
    "        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diverse CF\n",
    "\n",
    "Mothilal, R. K., Sharma, A., & Tan, C. (2020). Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency, 607–617. https://doi.org/10.1145/3351095.3372850\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class DiverseCF(ExplainerBase):\n",
    "    def __init__(self, x: torch.tensor, model: CounterfactualTrainingModule):\n",
    "        \"\"\"diverse counterfactual explanation\n",
    "            - link: https://doi.org/10.1145/3351095.3372850\n",
    "\n",
    "        Args:\n",
    "            x (torch.tensor): input instance\n",
    "            model (CounterfactualTrainingModule): black-box model\n",
    "        \"\"\"\n",
    "        self.n_cfs = 5\n",
    "        super().__init__(x, model)\n",
    "        # self.cf = nn.Parameter(self.x.repeat(self.n_cfs, 1), requires_grad=True)\n",
    "        self.cf = nn.Parameter(torch.rand(self.n_cfs, self.x.size(1)), requires_grad=True)\n",
    "\n",
    "    def forward(self):\n",
    "        cf = self.cf * 1.0\n",
    "        return torch.clamp(cf, 0, 1)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.RMSprop([self.cf], lr=0.001)\n",
    "\n",
    "    def _compute_dist(self, x1, x2):\n",
    "        return torch.sum(torch.abs(x1 - x2), dim = 0)\n",
    "\n",
    "    def _compute_proximity_loss(self):\n",
    "        \"\"\"Compute the second part (distance from x1) of the loss function.\"\"\"\n",
    "        proximity_loss = 0.0\n",
    "        for i in range(self.n_cfs):\n",
    "            proximity_loss += self.compute_dist(self.cf[i], self.x1)\n",
    "        return proximity_loss/(torch.mul(len(self.minx[0]), self.total_CFs))\n",
    "\n",
    "    def _dpp_style(self, cf):\n",
    "        det_entries = torch.ones(self.n_cfs, self.n_cfs)\n",
    "        for i in range(self.n_cfs):\n",
    "            for j in range(self.n_cfs):\n",
    "                det_entries[i, j] = self._compute_dist(cf[i], cf[j])\n",
    "\n",
    "        # implement inverse distance\n",
    "        det_entries = 1.0 / (1.0 + det_entries)\n",
    "        det_entries += torch.eye(self.n_cfs) * 0.0001\n",
    "        return torch.det(det_entries)\n",
    "\n",
    "    def _compute_diverse_loss(self, c):\n",
    "        return self._dpp_style(c)\n",
    "\n",
    "    def _compute_regularization_loss(self):\n",
    "        cat_idx = len(self.model.continous_cols)\n",
    "        regularization_loss = 0.\n",
    "        for i in range(self.n_cfs):\n",
    "            for col in self.model.cat_arrays:\n",
    "                cat_idx_end = cat_idx + len(col)\n",
    "                regularization_loss += torch.pow((torch.sum(self.cf[i][cat_idx: cat_idx_end]) - 1.0), 2)\n",
    "        return regularization_loss\n",
    "\n",
    "    def _loss_functions(self, x, c):\n",
    "        # target\n",
    "        y_pred = self.model.predict(x)\n",
    "        y_prime = torch.ones(y_pred.shape) - y_pred\n",
    "\n",
    "        c_y = self.model(c)\n",
    "        # yloss\n",
    "        l_1 = hinge_loss(input=c_y, target=y_prime.float())\n",
    "        # proximity loss\n",
    "        l_2 = l1_mean(x, c)\n",
    "        # diverse loss\n",
    "        l_3 = self._compute_diverse_loss(c)\n",
    "        # categorical penalty\n",
    "        l_4 = self._compute_regularization_loss()\n",
    "        return l_1, l_2, l_3, l_4\n",
    "\n",
    "    def _compute_loss(self, *loss_f):\n",
    "        return sum(loss_f)\n",
    "\n",
    "    def generate_cf(self, n_iters, debug: bool = False):\n",
    "        optim = self.configure_optimizers()\n",
    "        for i in range(n_iters):\n",
    "            c = self()\n",
    "\n",
    "            l_1, l_2, l_3, l_4 = self._loss_functions(self.x, c)\n",
    "            loss = self._compute_loss(l_1, l_2, l_3, l_4)\n",
    "            optim.zero_grad()\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "            if  debug and i % 100 == 0:\n",
    "                print(f\"iter: {i}, loss: {loss.item()}\")\n",
    "\n",
    "            # contrain to [0,1]\n",
    "            self.clamp.apply(self.cf)\n",
    "\n",
    "        cf = self.cf * 1.0\n",
    "        cf = torch.clamp(cf, 0, 1)\n",
    "        # return cf[0]\n",
    "        return cat_normalize(cf[0].view(1, -1), self.model.cat_arrays, len(self.model.continous_cols), True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ProtoCF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "\n",
    "class AE(DataModule):\n",
    "    def __init__(self, configs, encoded_size=5):\n",
    "        super().__init__(configs)\n",
    "        input_dim = configs['encoder_dims'][0]\n",
    "        self.encoder_model = MultilayerPerception([input_dim, 20, 16, 14, 12, encoded_size])\n",
    "        self.decoder_model = MultilayerPerception([encoded_size, 12, 14, 16, 20, input_dim])\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encoded(x)\n",
    "        x_prime = self.decoder_model(z)\n",
    "        return x_prime\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "\n",
    "    def encoded(self, x):\n",
    "        return self.encoder_model(x)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, _ = batch\n",
    "        # prediction\n",
    "        x_prime = self(x)\n",
    "\n",
    "        loss = F.mse_loss(x_prime, x, reduction='mean')\n",
    "\n",
    "        self.log('train/loss', loss)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, _ = batch\n",
    "        # prediction\n",
    "        x_prime = self(x)\n",
    "\n",
    "        loss = F.mse_loss(x_prime, x, reduction='mean')\n",
    "\n",
    "        self.log('val/val_loss', loss)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class ProtoCF(ExplainerBase):\n",
    "    def __init__(self, x: torch.tensor, model: pl.LightningModule, train_loader: DataLoader, ae: AE):\n",
    "        \"\"\"vanilla version of counterfactual generation\n",
    "            - link: https://doi.org/10.2139/ssrn.3063289\n",
    "\n",
    "        Args:\n",
    "            x (torch.tensor): input instance\n",
    "            model (pl.LightningModule): black-box model\n",
    "        \"\"\"\n",
    "        super().__init__(x, model)\n",
    "        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)\n",
    "        self.sampled_data, _ = next(iter(train_loader))\n",
    "        self.sampled_label = self.model.predict(self.sampled_data)\n",
    "        self.ae = ae\n",
    "        self.ae.freeze()\n",
    "\n",
    "    def forward(self):\n",
    "        cf = self.cf * 1.0\n",
    "        # return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), False)\n",
    "        return cf\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.RMSprop([self.cf], lr=0.001)\n",
    "\n",
    "    def compute_regularization_loss(self):\n",
    "        cat_idx = len(self.model.continous_cols)\n",
    "        regularization_loss = 0.\n",
    "        for col in self.model.cat_arrays:\n",
    "            cat_idx_end = cat_idx + len(col)\n",
    "            regularization_loss += torch.pow((torch.sum(self.cf[cat_idx: cat_idx_end]) - 1.0), 2)\n",
    "        return regularization_loss\n",
    "\n",
    "    def proto(self, data):\n",
    "        return self.ae.encoded(data).mean(axis=0).view(1, -1)\n",
    "\n",
    "    def _loss_functions(self, x, c):\n",
    "        # target\n",
    "        y_pred = self.model.predict(x)\n",
    "        y = torch.ones(y_pred.shape) - y_pred\n",
    "\n",
    "        data = self.sampled_data[self.sampled_label == y]\n",
    "\n",
    "        l_1 = F.binary_cross_entropy(self.model(c), y)\n",
    "        l_2 = 0.1 * F.l1_loss(x, c) + F.mse_loss(x, c)\n",
    "        l_3 = F.mse_loss(self.ae.encoded(c), self.proto(data))\n",
    "\n",
    "        return l_1, l_2, l_3\n",
    "\n",
    "    def _loss_compute(self, l_1, l_2, l_3):\n",
    "        return l_1 + l_2 + l_3 #+ self.compute_regularization_loss()\n",
    "\n",
    "    def generate_cf(self, n_iters, debug: bool = False):\n",
    "        optim = self.configure_optimizers()\n",
    "        for i in range(n_iters):\n",
    "            c = self()\n",
    "\n",
    "            l_1, l_2, l_3 = self._loss_functions(self.x, c)\n",
    "            loss = self._loss_compute(l_1, l_2, l_3)\n",
    "            optim.zero_grad()\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "            if debug and i % 100 == 0:\n",
    "                print(f\"iter: {i}, loss: {loss.item()}\")\n",
    "\n",
    "            # contrain to [0,1]\n",
    "            self.clamp.apply(self.cf)\n",
    "\n",
    "        cf = self.cf * 1.0\n",
    "        self.clamp.apply(self.cf)\n",
    "        # return cf\n",
    "        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VAE-CF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "class VAE(pl.LightningModule):\n",
    "    def __init__(self, input_dims, encoded_size=5):\n",
    "        super().__init__()\n",
    "        self.encoder_mean = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])\n",
    "        self.encoder_var = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])\n",
    "        self.decoder_mean = MultilayerPerception([encoded_size + 1, 12, 14, 16, 20, input_dims])\n",
    "\n",
    "    def encoder(self, x):\n",
    "        mean = self.encoder_mean(x)\n",
    "        logvar = 0.5+ self.encoder_var(x)\n",
    "        return mean, logvar\n",
    "\n",
    "    def decoder(self, z):\n",
    "        mean = self.decoder_mean(z)\n",
    "        return mean\n",
    "\n",
    "    def sample_latent_code(self, mean, logvar):\n",
    "        eps = torch.randn_like(logvar)\n",
    "        return mean + torch.sqrt(logvar) * eps\n",
    "\n",
    "    def normal_likelihood(self, x, mean, logvar, raxis=1):\n",
    "        return torch.sum( -.5 * ((x - mean)*(1./logvar)*(x-mean) + torch.log(logvar) ), axis=1)\n",
    "\n",
    "    def forward(self, x, c):\n",
    "        \"\"\"\n",
    "        x: input instance\n",
    "        c: target y\n",
    "        \"\"\"\n",
    "        c = c.view(c.shape[0], 1).float()\n",
    "        # c = torch.tensor(c).float()\n",
    "        res = {}\n",
    "        mc_samples = 50\n",
    "        em, ev = self.encoder(torch.cat((x, c), 1))\n",
    "        res['em'] = em\n",
    "        res['ev'] = ev\n",
    "        res['z'] = []\n",
    "        res['x_pred'] = []\n",
    "        res['mc_samples'] = mc_samples\n",
    "        for i in range(mc_samples):\n",
    "            z = self.sample_latent_code(em, ev)\n",
    "            x_pred = self.decoder(torch.cat((z, c), 1))\n",
    "            res['z'].append(z)\n",
    "            res['x_pred'].append(x_pred)\n",
    "        return res\n",
    "\n",
    "    def compute_elbo(self, x, c, model):\n",
    "        c= c.clone().detach().float()\n",
    "        c=c.view(c.shape[0], 1)\n",
    "        em, ev = self.encoder(torch.cat((x,c),1))\n",
    "        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)\n",
    "\n",
    "        z = self.sample_latent_code(em, ev)\n",
    "        dm= self.decoder( torch.cat((z,c),1) )\n",
    "        log_px_z = torch.tensor(0.0)\n",
    "\n",
    "        x_pred= dm\n",
    "        return torch.mean(log_px_z), torch.mean(kl_divergence), x, x_pred, model.predict(x_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class VAE_CF(CounterfactualTrainingModule):\n",
    "    def __init__(self, config: Dict, model: pl.LightningModule):\n",
    "        \"\"\"\n",
    "        config: basic configs\n",
    "        model: the black-box model to be explained\n",
    "        \"\"\"\n",
    "        super().__init__(config)\n",
    "        self.model = model\n",
    "        self.model.freeze()\n",
    "        self.vae = VAE(input_dims=self.enc_dims[0])\n",
    "        # validity_reg set to 42.0\n",
    "        # according to https://interpret.ml/DiCE/notebooks/DiCE_getting_started_feasible.html#Generate-counterfactuals-using-a-VAE-model\n",
    "        self.validity_reg = config['validity_reg'] if 'validity_reg' in config.keys() else 1.0\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"lazy implementation since this method is actually not needed\"\"\"\n",
    "        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1 - self.model.predict(x), self.model)\n",
    "        # return y, c\n",
    "        return cf_label, x_pred\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "\n",
    "    def predict(self, x):\n",
    "        return self.model.predict(x)\n",
    "\n",
    "    def compute_loss(self, out, x, y):\n",
    "        em = out['em']\n",
    "        ev = out['ev']\n",
    "        z = out['z']\n",
    "        dm = out['x_pred']\n",
    "        mc_samples = out['mc_samples']\n",
    "        #KL Divergence\n",
    "        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)\n",
    "\n",
    "        #Reconstruction Term\n",
    "        #Proximity: L1 Loss\n",
    "        x_pred = dm[0]\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        # recon_err = - \\\n",
    "        #     torch.sum(torch.abs(x[:, cat_idx:-1] -\n",
    "        #                         x_pred[:, cat_idx:-1]), axis=1)\n",
    "        recon_err = - torch.sum(torch.abs(x - x_pred), axis=1)\n",
    "\n",
    "        # Sum to 1 over the categorical indexes of a feature\n",
    "        for col in self.cat_arrays:\n",
    "            cat_end_idx = cat_idx + len(col)\n",
    "            temp = - \\\n",
    "                torch.abs(1.0 - x_pred[:, cat_idx: cat_end_idx].sum(axis=1))\n",
    "            recon_err += temp\n",
    "\n",
    "        #Validity\n",
    "        c_y = self.model(x_pred)\n",
    "        validity_loss = torch.zeros(1, device=self.device)\n",
    "        validity_loss += hinge_loss(input=c_y, target=y.float())\n",
    "\n",
    "        for i in range(1, mc_samples):\n",
    "            x_pred = dm[i]\n",
    "\n",
    "            # recon_err += - \\\n",
    "            #     torch.sum(torch.abs(x[:, cat_idx:-1] -\n",
    "            #                         x_pred[:, cat_idx:-1]), axis=1)\n",
    "            recon_err += - torch.sum(torch.abs(x - x_pred), axis=1)\n",
    "\n",
    "            # Sum to 1 over the categorical indexes of a feature\n",
    "            for col in self.cat_arrays:\n",
    "                cat_end_idx = cat_idx + len(col)\n",
    "                temp = - \\\n",
    "                    torch.abs(1.0 - x_pred[:, cat_idx: cat_end_idx].sum(axis=1))\n",
    "                recon_err += temp\n",
    "\n",
    "            #Validity\n",
    "            c_y = self.model(x_pred)\n",
    "            validity_loss += hinge_loss(c_y, y.float())\n",
    "\n",
    "        recon_err = recon_err / mc_samples\n",
    "        validity_loss = -1 * self.validity_reg * validity_loss / mc_samples\n",
    "\n",
    "        return -torch.mean(recon_err - kl_divergence) - validity_loss\n",
    "\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, _ = batch\n",
    "        # prediction\n",
    "        y_hat = self.model.predict(x)\n",
    "        # target\n",
    "        y = 1.0 - y_hat\n",
    "\n",
    "        out = self.vae(x, y)\n",
    "        loss = self.compute_loss(out, x, y)\n",
    "\n",
    "        self.log('train/loss', loss)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, _ = batch\n",
    "        # prediction\n",
    "        y_hat = self.model.predict(x)\n",
    "        # target\n",
    "        y = 1.0 - y_hat\n",
    "\n",
    "        out = self.vae(x, y)\n",
    "        loss = self.compute_loss(out, x, y)\n",
    "\n",
    "        _, _, _, x_pred, cf_label = self.vae.compute_elbo(x, y, self.model)\n",
    "\n",
    "        cf_proximity = torch.abs(x - x_pred).sum(dim=1).mean()\n",
    "        cf_accuracy = accuracy(cf_label, y)\n",
    "\n",
    "        self.log('val/val_loss', loss)\n",
    "        self.log('val/proximity', cf_proximity)\n",
    "        self.log('val/cf_accuracy', cf_accuracy)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def validation_epoch_end(self, val_outs):\n",
    "        return\n",
    "\n",
    "    def generate_cf(self, x):\n",
    "        self.vae.freeze()\n",
    "        y_hat = self.model.predict(x)\n",
    "        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1.-y_hat, self.model)\n",
    "        return self.model.cat_normalize(x_pred, hard=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# C-CHVAE\n",
    "\n",
    "Pawelczyk, Martin, Klaus Broelemann and Gjergji Kasneci. “Learning Model-Agnostic Counterfactual Explanations for Tabular Data.” Proceedings of The Web Conference 2020 (2020)\n",
    "- https://arxiv.org/pdf/1910.09398.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "class CHVAE(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    https://github.com/carla-recourse/CARLA/blob/main/carla/recourse_methods/autoencoder/models.py\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dims, encoded_size=5):\n",
    "        super().__init__()\n",
    "        encoder = MultilayerPerception([input_dims, 20, 16, 14, 12])\n",
    "        decoder = MultilayerPerception([encoded_size, 12, 14, 16, 20])\n",
    "\n",
    "        self._mu_enc = nn.Sequential(encoder, nn.Linear(12, encoded_size))\n",
    "        self._log_var_enc = nn.Sequential(encoder, nn.Linear(12, encoded_size))\n",
    "        self.mu_dec = nn.Sequential(\n",
    "            decoder, nn.Linear(20, input_dims), nn.BatchNorm1d(input_dims), nn.Sigmoid(),\n",
    "        )\n",
    "        self.log_var_dec  = nn.Sequential(\n",
    "            decoder, nn.Linear(20, input_dims), nn.BatchNorm1d(input_dims), nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        return self._mu_enc(x), self._log_var_enc(x)\n",
    "\n",
    "    def decode(self, z):\n",
    "        return self.mu_dec(z), self.log_var_dec(z)\n",
    "\n",
    "    def __reparametrization_trick(self, mu, log_var):\n",
    "        std = torch.exp(0.5 * log_var)\n",
    "        epsilon = torch.randn_like(std)  # the Gaussian random noise\n",
    "        return mu + std * epsilon\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu_z, log_var_z = self.encode(x)\n",
    "        z_rep = self.__reparametrization_trick(mu_z, log_var_z)\n",
    "        mu_x, log_var_x = self.decode(z_rep)\n",
    "\n",
    "        return mu_x, log_var_x, z_rep, mu_z, log_var_z\n",
    "\n",
    "    def regenerate(self, z):\n",
    "        mu_x, log_var_x = self.decode(z)\n",
    "        return mu_x\n",
    "\n",
    "    def compute_loss(self, mse_loss, mu, logvar):\n",
    "        MSE = mse_loss\n",
    "        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "        return MSE + KLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CCHVAE(CounterfactualTrainingModule):\n",
    "    \"\"\"\n",
    "    Refer to https://github.com/carla-recourse/CARLA/blob/main/carla/recourse_methods/catalog/cchvae/model.py\n",
    "    \"\"\"\n",
    "    def __init__(self, config: Dict, model: pl.LightningModule):\n",
    "        \"\"\"\n",
    "        config: basic configs\n",
    "        model: the black-box model to be explained\n",
    "        \"\"\"\n",
    "        super().__init__(config)\n",
    "        self.model = model\n",
    "        self.model.freeze()\n",
    "        self.vae = CHVAE(input_dims=self.enc_dims[0])\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"lazy implementation as it is not needed\"\"\"\n",
    "        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)\n",
    "        return MU_X_eval, LOG_VAR_X_eval\n",
    "\n",
    "    def predict(self, x):\n",
    "        return self.model.predict(x)\n",
    "\n",
    "    def _hyper_sphere_coordindates(\n",
    "        self, x, high: int, low: int, n_search_samples: int\n",
    "    ) -> Tuple[np.ndarray, np.ndarray]:\n",
    "        \"\"\"\n",
    "        :param n_search_samples: int > 0\n",
    "        :param x: input point array\n",
    "        :param high: float>= 0, h>l; upper bound\n",
    "        :param low: float>= 0, l<h; lower bound\n",
    "        :return: candidate counterfactuals & distances\n",
    "        \"\"\"\n",
    "        delta_instance = torch.randn(n_search_samples, x.size(1))\n",
    "        dist = (\n",
    "            torch.rand(n_search_samples) * (high - low) + low\n",
    "        )  # length range [l, h)\n",
    "        norm_p = torch.norm(delta_instance, p=1, dim=1)\n",
    "        d_norm = torch.divide(dist, norm_p).reshape(-1, 1)  # rescale/normalize factor\n",
    "        delta_instance = torch.multiply(delta_instance, d_norm)\n",
    "        candidate_counterfactuals = x + delta_instance\n",
    "        return candidate_counterfactuals, dist\n",
    "\n",
    "    def generate_cf(self, x):\n",
    "        # params\n",
    "        n_search_samples = 300; count = 0; max_iter = 1000; step=0.1\n",
    "        low = 0; high = step\n",
    "\n",
    "        self.vae.freeze()\n",
    "        y_hat = self.model.predict(x)\n",
    "\n",
    "        # vectorize z\n",
    "        z = self.vae.encode(x)[0]\n",
    "        z_rep = torch.repeat_interleave(\n",
    "            z.reshape(1, -1), n_search_samples, dim=0\n",
    "        )\n",
    "\n",
    "        candidate_dist = []\n",
    "        x_ce: Union[np.ndarray, torch.Tensor] = torch.tensor([])\n",
    "\n",
    "        while count <= max_iter:\n",
    "            count = count + 1\n",
    "\n",
    "            # STEP 1 -- SAMPLE POINTS on hyper sphere around instance\n",
    "            latent_neighbourhood, _ = self._hyper_sphere_coordindates(z_rep, high, low, n_search_samples)\n",
    "            x_ce = self.vae.decode(latent_neighbourhood)[0]\n",
    "\n",
    "            x_ce = self.model.cat_normalize(x_ce, hard=True)\n",
    "            x_ce = x_ce.clip(0, 1)\n",
    "\n",
    "            # STEP 2 -- COMPUTE l1 norms\n",
    "            distances = torch.abs((x_ce - x)).sum(dim=1)\n",
    "\n",
    "            # counterfactual labels\n",
    "            y_candidate = self.model.predict(x_ce)\n",
    "            indeces = torch.where(y_candidate != y_hat)\n",
    "            candidate_counterfactuals = x_ce[indeces]\n",
    "            candidate_dist = distances[indeces]\n",
    "\n",
    "            if len(candidate_dist) == 0:\n",
    "                # no candidate found & push search range outside\n",
    "                low = high\n",
    "                high = low + step\n",
    "            elif len(candidate_dist) > 0:\n",
    "                # certain candidates generated\n",
    "                min_index = np.argmin(candidate_dist)\n",
    "                # return candidate_counterfactuals[min_index]\n",
    "                return candidate_counterfactuals[0]\n",
    "        return x_ce[0]\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)\n",
    "\n",
    "        reconstruction = MU_X_eval\n",
    "        mse_loss = F.mse_loss(reconstruction, x)\n",
    "        loss = self.vae.compute_loss(mse_loss, MU_Z_eval, LOG_VAR_Z_eval)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)\n",
    "\n",
    "        reconstruction = self.cat_normalize(MU_X_eval)\n",
    "        mse_loss = F.mse_loss(reconstruction, x)\n",
    "\n",
    "        self.log('val/val_loss', mse_loss)\n",
    "        return mse_loss\n",
    "\n",
    "    def validation_epoch_end(self, val_outs):\n",
    "        return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor([1, 2, 3]).reshape(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CounteRGAN\n",
    "\n",
    "Nemirovsky, D., Thiebaut, N., Xu, Y., & Gupta, A. (2020). Countergan: generating realistic counterfactuals with residual generative adversarial nets. arXiv preprint arXiv:2009.05199.\n",
    "\n",
    "- https://arxiv.org/pdf/2009.05199.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CounteRGANTrainingModule(DataModule):\n",
    "    def __init__(self, config: Dict, model, target_class: int):\n",
    "        super().__init__(config)\n",
    "        self.model = model\n",
    "        self.model.freeze()\n",
    "\n",
    "        if target_class in [0., 1.]:\n",
    "            self.target_class = target_class\n",
    "        else:\n",
    "            raise ValueError(f'`target_class` should be either `0` or `1`.')\n",
    "        self.init_rgan()\n",
    "\n",
    "    def init_rgan(self):\n",
    "        gen_dims = self.enc_dims + self.exp_dims + [self.enc_dims[0]]\n",
    "        self.generator = MultilayerPerception(gen_dims)\n",
    "        self.discriminator = nn.Sequential(\n",
    "            MultilayerPerception([gen_dims[0], 128]),\n",
    "            nn.Linear(128, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        pass\n",
    "\n",
    "    def discriminate(self, x):\n",
    "        y_hat = self.discriminator(x)\n",
    "        return torch.squeeze(y_hat, dim=-1)\n",
    "\n",
    "    def forward(self, x, hard=False, imutable=True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "        \"\"\"forward pass of CounteRGAN\n",
    "\n",
    "        Args:\n",
    "            x (torch.Tensor): input\n",
    "            hard (bool, optional): categorical features in counterfactual is one-hot-encoding or not.\n",
    "                Defaults to False.\n",
    "            imutable (bool, optional): whether to use immutable features or not. Defaults to True.\n",
    "\n",
    "        Returns:\n",
    "            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: outputs `cf`, `real_fake_y`, `cf_y`\n",
    "        \"\"\"\n",
    "        cf = self.generator(x)\n",
    "        cf = x + cf\n",
    "        cf = self.cat_normalize(cf, hard=hard)\n",
    "        if imutable:\n",
    "            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0\n",
    "        real_fake_y = self.discriminate(cf)\n",
    "        cf_y = self.model(cf)\n",
    "\n",
    "        return cf, real_fake_y, cf_y\n",
    "\n",
    "    def generate_cf(self, x):\n",
    "        cf, _, _ = self.forward(x, hard=True)\n",
    "        return cf\n",
    "\n",
    "    def discriminator_step(self, batch):\n",
    "        x_real, _ = batch\n",
    "        real_disc_y = self.discriminate(x_real)\n",
    "        x_fake, fake_disc_y, _ = self(x_real)\n",
    "        y_hat = torch.cat((real_disc_y, fake_disc_y))\n",
    "\n",
    "        x = torch.cat((x_real, x_fake))\n",
    "        y = torch.cat((torch.ones(len(x_real)), torch.zeros(len(x_fake))))\n",
    "\n",
    "        # # shuffle\n",
    "        # p = np.random.permutation(len(y))\n",
    "        # x, y = x[p], y[p]\n",
    "\n",
    "        # train model\n",
    "        # y_hat = self.discriminator(x)\n",
    "        loss = F.binary_cross_entropy(y_hat, y)\n",
    "        return loss\n",
    "\n",
    "    def generator_step(self, batch):\n",
    "        x, y = batch\n",
    "        cf, y_disc, cf_y = self.forward(x)\n",
    "        # cf loss\n",
    "        # y_prime = 1. - self.model.predict(x)\n",
    "        y_prime = self.target_class + torch.zeros_like(cf_y)\n",
    "        loss_cf = F.binary_cross_entropy(cf_y, y_prime)\n",
    "        # gan loss\n",
    "        y_fake = torch.ones(len(cf))\n",
    "        loss_gan = F.binary_cross_entropy(y_disc, y_fake)\n",
    "        # regularization loss\n",
    "        reg_loss = 0. * F.l1_loss(x, cf) + 1e-6 * F.mse_loss(x, cf)\n",
    "\n",
    "        return loss_gan + loss_cf + reg_loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        opt_1 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=0.0005)\n",
    "        opt_2 = torch.optim.RMSprop([p for p in self.parameters() if p.requires_grad], lr=2e-4)\n",
    "        return (opt_1, opt_2)\n",
    "\n",
    "    def training_step(self, batch, batch_idx, optimizer_idx):\n",
    "        self.model.freeze()\n",
    "        # pl_logger.info([p for p in self.model.parameters() if not p.requires_grad])\n",
    "        if batch_idx % 6 in [0, 1]:\n",
    "            if optimizer_idx == 0:\n",
    "                use_grad(self.discriminator, requires_grad=True)\n",
    "                return self.discriminator_step(batch)\n",
    "        else:\n",
    "            if optimizer_idx == 1:\n",
    "                use_grad(self.discriminator, requires_grad=False)\n",
    "                return self.generator_step(batch)\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss = self.generator_step(batch)\n",
    "        self.log('val/val_loss', loss)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CounteRGAN:\n",
    "    def __init__(\n",
    "        self,\n",
    "        rgan_0: CounteRGANTrainingModule,\n",
    "        rgan_1:CounteRGANTrainingModule,\n",
    "    ):\n",
    "        # copy attributes\n",
    "        self.__dict__ = rgan_0.__dict__.copy()\n",
    "        self.rgan_0 = rgan_0\n",
    "        self.rgan_1 = rgan_1\n",
    "        self.model = self.rgan_0.model\n",
    "        self.rgan_0.eval()\n",
    "        self.rgan_1.eval()\n",
    "        use_grad(self.rgan_0, self.rgan_1, requires_grad=False)\n",
    "\n",
    "    def predict(self, x):\n",
    "        return self.model.predict(x)\n",
    "\n",
    "    def generate_cf(self, x):\n",
    "        cf_0 = self.rgan_0.generate_cf(x)\n",
    "        cf_1 = self.rgan_1.generate_cf(x)\n",
    "        y_target = 1 - self.model.predict(x)\n",
    "        return torch.where(\n",
    "            torch.round(y_target).byte(), cf_1, cf_0\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_configs = {\n",
    "    'data_dir': 'data/s_adult.csv',\n",
    "    'lr':0.01, \n",
    "    'batch_size': 2048,\n",
    "    'lambda_1': 1.,\n",
    "    'lambda_2': .01,\n",
    "    'lambda_3': 1.,\n",
    "    'threshold': 1., \n",
    "    'continous_cols': ['age', 'hours_per_week'],\n",
    "    'discret_cols': ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender'],\n",
    "    'encoder_dims': [29, 50, 10],\n",
    "    'decoder_dims': [10, 10],\n",
    "    'explainer_dims': [10, 50],\n",
    "    'loss_1': 'mse',\n",
    "    'loss_2': 'mse',\n",
    "    'loss_3': 'mse'\n",
    "}\n",
    "# trainer configs\n",
    "t_configs = {\n",
    "    'max_epochs': 100,\n",
    "#     'deterministic': True,\n",
    "#     'gradient_clip_val': 0.5,\n",
    "    'num_sanity_val_steps': 0,\n",
    "#     'callbacks': [early_stopping],\n",
    "    'accelerator': 'ddp',\n",
    "    'gpus': 1,\n",
    "#     debug\n",
    "#     'weights_summary': 'full',\n",
    "#     'fast_dev_run': True,\n",
    "    'track_grad_norm':2\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quickly init a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: False\n",
      "TPU available: None, using: 0 TPU cores\n",
      "/opt/conda/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "x_cont: (32561, 2), x_cat: (32561, 27)\n",
      "(32561, 29)\n",
      "\n",
      "  | Name  | Type       | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------\n",
      "0 | model | Sequential | 2.3 K  | [1, 29]  | [1, 1]   \n",
      "------------------------------------------------------------\n",
      "2.3 K     Trainable params\n",
      "0         Non-trainable params\n",
      "2.3 K     Total params\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training: 0it [00:00, ?it/s]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "model = load_model('../saved_weights/adult/baseline/epoch=55-step=10695.ckpt', 56)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[8.7551e-01, 8.3667e-01, 8.5154e-02, 6.7861e-01, 4.8256e-01, 1.4477e-01,\n",
       "         3.7129e-01, 8.2560e-01, 6.8110e-01, 1.0998e-01, 6.4115e-01, 4.2497e-01,\n",
       "         3.1698e-01, 6.1735e-01, 5.8713e-01, 8.0798e-01, 7.3314e-06, 5.0367e-01,\n",
       "         7.4309e-01, 2.9842e-01, 9.4241e-01, 3.8378e-01, 6.7887e-01, 3.9197e-01,\n",
       "         6.5418e-01, 6.4981e-01, 2.8277e-01, 1.5848e-01, 2.7166e-01]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(1, 29)\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VanillaCF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.1 s, sys: 0 ns, total: 5.1 s\n",
      "Wall time: 3.42 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0.3636, 1.0516, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
       "         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000,\n",
       "         1.0000, 0.0000]], grad_fn=<CopySlices>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "cf = VanillaCF(x, model)\n",
    "cf.generate_cf(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DiverseCF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "../counterfactual/utils.py:222: UserWarning: Using a target size (torch.Size([5, 29])) that is different to the input size (torch.Size([1, 29])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.l1_loss(x, c, reduction='mean') / x.abs().mean() # MAD\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16.7 s, sys: 10.8 ms, total: 16.7 s\n",
      "Wall time: 9.01 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.8001, 0.5417, 0.2331, 0.7017, 0.0000, 0.1092, 0.0928, 0.0000, 0.0618,\n",
       "        0.0000, 0.5129, 0.4254, 0.4727, 0.6172, 0.5870, 0.5621, 0.3842, 0.5036,\n",
       "        0.7432, 0.2988, 0.9423, 0.3837, 0.6788, 0.5285, 0.6543, 0.6498, 0.3424,\n",
       "        0.3928, 0.2716], grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "cf = DiverseCF(x, model)\n",
    "cf.generate_cf(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ProtoCF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "GPU available: False, used: False\n",
      "TPU available: None, using: 0 TPU cores\n",
      "TPU available: None, using: 0 TPU cores\n",
      "hyper parameters: \"batch_size\":     128\n",
      "\"continous_cols\": ['age', 'hours_per_week']\n",
      "\"data_dir\":       ../data/s_adult.csv\n",
      "\"decoder_dims\":   [10, 10]\n",
      "\"discret_cols\":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']\n",
      "\"encoder_dims\":   [29, 50, 10]\n",
      "\"explainer_dims\": [10, 50]\n",
      "\"lambda_1\":       1.0\n",
      "\"lambda_2\":       0.01\n",
      "\"lambda_3\":       1.0\n",
      "\"loss_1\":         mse\n",
      "\"loss_2\":         mse\n",
      "\"loss_3\":         mse\n",
      "\"lr\":             0.01\n",
      "\"threshold\":      1.0\n",
      "hyper parameters: \"batch_size\":     128\n",
      "\"continous_cols\": ['age', 'hours_per_week']\n",
      "\"data_dir\":       ../data/s_adult.csv\n",
      "\"decoder_dims\":   [10, 10]\n",
      "\"discret_cols\":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']\n",
      "\"encoder_dims\":   [29, 50, 10]\n",
      "\"explainer_dims\": [10, 50]\n",
      "\"lambda_1\":       1.0\n",
      "\"lambda_2\":       0.01\n",
      "\"lambda_3\":       1.0\n",
      "\"loss_1\":         mse\n",
      "\"loss_2\":         mse\n",
      "\"loss_3\":         mse\n",
      "\"lr\":             0.01\n",
      "\"threshold\":      1.0\n",
      "x_cont: (32561, 2), x_cat: (32561, 27)\n",
      "x_cont: (32561, 2), x_cat: (32561, 27)\n",
      "(32561, 29)\n",
      "(32561, 29)\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 29]  | [1, 5]   \n",
      "1 | decoder_model | MultilayerPerception | 1.6 K  | [1, 5]   | [1, 29]  \n",
      "------------------------------------------------------------------------------\n",
      "3.2 K     Trainable params\n",
      "0         Non-trainable params\n",
      "3.2 K     Total params\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 29]  | [1, 5]   \n",
      "1 | decoder_model | MultilayerPerception | 1.6 K  | [1, 5]   | [1, 29]  \n",
      "------------------------------------------------------------------------------\n",
      "3.2 K     Trainable params\n",
      "0         Non-trainable params\n",
      "3.2 K     Total params\n",
      "C:\\Users\\Hangzhi Guo\\AppData\\Roaming\\Python\\Python38\\site-packages\\pytorch_lightning\\utilities\\distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd00ff3b3d4648b9bd9f0acd33132410",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "result = train(AE(m_configs), t_configs)\n",
    "ae = result['module']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ProtoCF initialized.\n",
      "iter: 0, loss: 6.1465535163879395\n",
      "l_1: 0.0, l_2: 0.0, l_3: 0.14655336737632751\n",
      "iter: 100, loss: 6.0259175300598145\n",
      "l_1: 0.015172960236668587, l_2: 0.0018211350543424487, l_3: 0.00983387790620327\n",
      "iter: 200, loss: 6.016486167907715\n",
      "l_1: 0.01386270672082901, l_2: 0.002581034554168582, l_3: 0.001332893269136548\n",
      "iter: 300, loss: 6.016288757324219\n",
      "l_1: 0.014074395410716534, l_2: 0.002700776094570756, l_3: 0.0008641568128950894\n",
      "iter: 400, loss: 6.016404628753662\n",
      "l_1: 0.01418527215719223, l_2: 0.0027024508453905582, l_3: 0.0008683655178174376\n",
      "iter: 500, loss: 6.016188144683838\n",
      "l_1: 0.013960369862616062, l_2: 0.0027024406008422375, l_3: 0.0008763322839513421\n",
      "iter: 600, loss: 6.016372203826904\n",
      "l_1: 0.014170968905091286, l_2: 0.002703545382246375, l_3: 0.000849399424623698\n",
      "iter: 700, loss: 6.016396999359131\n",
      "l_1: 0.014149093069136143, l_2: 0.002701388904824853, l_3: 0.0008973746444098651\n",
      "iter: 800, loss: 6.016327381134033\n",
      "l_1: 0.014108755625784397, l_2: 0.002702898345887661, l_3: 0.0008669801172800362\n",
      "iter: 900, loss: 6.016377925872803\n",
      "l_1: 0.01419689692556858, l_2: 0.002710389206185937, l_3: 0.000826049770694226\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0.9893, 0.9910, 0.8431, 0.9772, 0.4053, 0.0504, 0.6392, 0.8146, 0.4124,\n",
       "         0.5149, 0.6525, 0.9059, 0.5502, 0.0373, 0.3220, 0.4081, 0.0209, 0.7489,\n",
       "         0.6025, 0.2613, 0.9205, 0.7139, 0.0665, 0.4821, 0.3095, 0.4590, 0.4964,\n",
       "         0.9253, 0.7672]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cf = ProtoCF(x=x, model=model, train_loader=ae.train_dataloader(), ae=ae)\n",
    "cf.generate_cf(1000, debug=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VAE-CF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n",
      "Using environment variable NODE_RANK for node rank (0).\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "hyper parameters: \"batch_size\":     2048\n",
      "\"continous_cols\": ['age', 'hours_per_week']\n",
      "\"data_dir\":       data/s_adult.csv\n",
      "\"decoder_dims\":   [10, 10]\n",
      "\"discret_cols\":   ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']\n",
      "\"encoder_dims\":   [29, 50, 10]\n",
      "\"explainer_dims\": [10, 50]\n",
      "\"lambda_1\":       1.0\n",
      "\"lambda_2\":       0.01\n",
      "\"lambda_3\":       1.0\n",
      "\"loss_1\":         mse\n",
      "\"loss_2\":         mse\n",
      "\"loss_3\":         mse\n",
      "\"lr\":             0.01\n",
      "\"threshold\":      1.0\n",
      "x_cont: (32561, 2), x_cat: (32561, 27)\n",
      "(32561, 29)\n",
      "\n",
      "  | Name  | Type          | Params | In sizes | Out sizes\n",
      "---------------------------------------------------------------\n",
      "0 | model | BaselineModel | 2.3 K  | [1, 29]  | [1]      \n",
      "1 | vae   | VAE           | 4.8 K  | ?        | ?        \n",
      "---------------------------------------------------------------\n",
      "4.8 K     Trainable params\n",
      "2.3 K     Non-trainable params\n",
      "7.1 K     Total params\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:   0%|          | 0/16 [00:00<?, ?it/s] "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/opt/conda/envs/pytorch/lib/python3.7/site-packages/ipykernel_launcher.py:31: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd0bc471830>\n",
      "Traceback (most recent call last):\n",
      "  File \"/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 1203, in __del__\n",
      "Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd0bc471830>\n",
      "Traceback (most recent call last):\n",
      "  File \"/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 1203, in __del__\n",
      "        self._shutdown_workers()"
     ]
    }
   ],
   "source": [
    "cf = VAE_CF(m_configs, model=model)\n",
    "result = train(\n",
    "    cf, \n",
    "    t_configs,\n",
    "    logger=pl_loggers.TestTubeLogger(Path('../log/'), name=\"adult/vae\")\n",
    ")\n",
    "\n",
    "x = torch.rand(100, 29)\n",
    "result['module'].generate_cf(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Misc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>28.869472</td>\n",
       "      <td>75.537317</td>\n",
       "      <td>13.732009</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>56.541628</td>\n",
       "      <td>51.057476</td>\n",
       "      <td>14.176149</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>54.259902</td>\n",
       "      <td>46.058342</td>\n",
       "      <td>12.923560</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>43.165512</td>\n",
       "      <td>56.313580</td>\n",
       "      <td>12.536208</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>24.003729</td>\n",
       "      <td>26.398063</td>\n",
       "      <td>10.360779</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          x1         x2         x3    y\n",
       "0  28.869472  75.537317  13.732009  0.0\n",
       "1  56.541628  51.057476  14.176149  0.0\n",
       "2  54.259902  46.058342  12.923560  1.0\n",
       "3  43.165512  56.313580  12.536208  1.0\n",
       "4  24.003729  26.398063  10.360779  1.0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dummy = pd.read_csv('../data/dummy_data.csv')\n",
    "dummy[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0, loss: 1.7241566181182861\n",
      "iter: 100, loss: 0.14336611330509186\n",
      "iter: 200, loss: 0.10569803416728973\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-47-bf5136acfdd3>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[0mcf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mVanillaCF\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate_cf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m10000\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-46-0304445ddf65>\u001b[0m in \u001b[0;36mgenerate_cf\u001b[1;34m(self, n_iters)\u001b[0m\n\u001b[0;32m     35\u001b[0m             \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_loss_compute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ml_1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ml_2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     36\u001b[0m             \u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 37\u001b[1;33m             \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     38\u001b[0m             \u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     39\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Miniconda3\\lib\\site-packages\\torch\\tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[0;32m    183\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[1;33m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    184\u001b[0m         \"\"\"\n\u001b[1;32m--> 185\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    186\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    187\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Miniconda3\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[0;32m    123\u001b[0m         \u001b[0mretain_graph\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    124\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 125\u001b[1;33m     Variable._execution_engine.run_backward(\n\u001b[0m\u001b[0;32m    126\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    127\u001b[0m         allow_unreachable=True)  # allow_unreachable flag\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "input_instance = dummy[7500:]\n",
    "x = model.transform(input_instance)\n",
    "\n",
    "cf = VanillaCF(x=x, model=model)\n",
    "r = cf.generate_cf(10000).detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-151-91622a393fcc>:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  r = torch.tensor(r)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0), 0)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.freeze()\n",
    "r = torch.tensor(r)\n",
    "model.check_cont_robustness(x, r, model.predict(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def proximity(x, c):\n",
    "    return torch.abs(x - c).sum(dim=1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cf_accuracy(c_y, y_hat):\n",
    "    return accuracy(c_y > .5, y_hat < .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6029)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "proximity(x, r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y, c = model(x)\n",
    "c_y, _ = model(r)\n",
    "cf_accuracy(c_y, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2822, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y, c = model(x)\n",
    "proximity(x, c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9548)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_y, _ = model(c)\n",
    "cf_accuracy(c_y, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model configs\n",
    "m_configs = {\n",
    "    'data_dir': '../data/dummy_data.csv',\n",
    "    'lr':3e-4, \n",
    "    'batch_size': 128,\n",
    "    'lambda_1': 1.,\n",
    "    'lambda_2': 0.5,\n",
    "    'lambda_3': 1.,\n",
    "    'threshold': 1, \n",
    "    'continous_cols': ['x1', 'x2', 'x3',],\n",
    "    'discret_cols': [], \n",
    "    'encoder_dims': [3, 100, 10],\n",
    "    'decoder_dims': [10, 10],\n",
    "    'explainer_dims': [10, 10]\n",
    "}\n",
    "# trainer configs\n",
    "t_configs = {\n",
    "    'max_epochs': 100,\n",
    "#     'checkpoint_callback': checkpoint_callback,\n",
    "#     'callbacks': [early_stopping]\n",
    "#     'gpus': 1,\n",
    "#     debug\n",
    "#     'weights_summary': 'full',\n",
    "#     'fast_dev_run': True,\n",
    "    'track_grad_norm':2\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "GPU available: False, used: False\n",
      "TPU available: None, using: 0 TPU cores\n",
      "TPU available: None, using: 0 TPU cores\n",
      "hyper parameters: \"batch_size\":     128\n",
      "\"continous_cols\": ['x1', 'x2', 'x3']\n",
      "\"data_dir\":       ../data/dummy_data.csv\n",
      "\"decoder_dims\":   [10, 10]\n",
      "\"discret_cols\":   []\n",
      "\"encoder_dims\":   [3, 100, 10]\n",
      "\"explainer_dims\": [10, 10]\n",
      "\"lambda_1\":       1.0\n",
      "\"lambda_2\":       0.5\n",
      "\"lambda_3\":       1.0\n",
      "\"lr\":             0.0003\n",
      "\"threshold\":      1\n",
      "hyper parameters: \"batch_size\":     128\n",
      "\"continous_cols\": ['x1', 'x2', 'x3']\n",
      "\"data_dir\":       ../data/dummy_data.csv\n",
      "\"decoder_dims\":   [10, 10]\n",
      "\"discret_cols\":   []\n",
      "\"encoder_dims\":   [3, 100, 10]\n",
      "\"explainer_dims\": [10, 10]\n",
      "\"lambda_1\":       1.0\n",
      "\"lambda_2\":       0.5\n",
      "\"lambda_3\":       1.0\n",
      "\"lr\":             0.0003\n",
      "\"threshold\":      1\n",
      "x_cont: (10000, 3), x_cat: (10000, 0)\n",
      "x_cont: (10000, 3), x_cat: (10000, 0)\n",
      "(10000, 3)\n",
      "(10000, 3)\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10]  \n",
      "1 | predictor     | Sequential           | 141    | [1, 10]  | [1, 1]   \n",
      "2 | explainer     | Sequential           | 163    | [1, 10]  | [1, 3]   \n",
      "------------------------------------------------------------------------------\n",
      "1.9 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.9 K     Total params\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10]  \n",
      "1 | predictor     | Sequential           | 141    | [1, 10]  | [1, 1]   \n",
      "2 | explainer     | Sequential           | 163    | [1, 10]  | [1, 3]   \n",
      "------------------------------------------------------------------------------\n",
      "1.9 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.9 K     Total params\n",
      "C:\\Users\\Hangzhi Guo\\AppData\\Roaming\\Python\\Python38\\site-packages\\pytorch_lightning\\utilities\\distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b82cdf646f7f4563acb91190736e98e4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Hangzhi Guo\\AppData\\Roaming\\Python\\Python38\\site-packages\\pytorch_lightning\\utilities\\distributed.py:49: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "result = train(CounterfactualModel(m_configs), t_configs,logger_name = \"debug\")\n",
    "model = result['module']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "GPU available: False, used: False\n",
      "TPU available: None, using: 0 TPU cores\n",
      "TPU available: None, using: 0 TPU cores\n",
      "x_cont: (10000, 3), x_cat: (10000, 0)\n",
      "x_cont: (10000, 3), x_cat: (10000, 0)\n",
      "(10000, 3)\n",
      "(10000, 3)\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10]  \n",
      "1 | predictor     | Sequential           | 141    | [1, 10]  | [1, 1]   \n",
      "2 | explainer     | Sequential           | 163    | [1, 10]  | [1, 3]   \n",
      "------------------------------------------------------------------------------\n",
      "1.9 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.9 K     Total params\n",
      "\n",
      "  | Name          | Type                 | Params | In sizes | Out sizes\n",
      "------------------------------------------------------------------------------\n",
      "0 | encoder_model | MultilayerPerception | 1.6 K  | [1, 3]   | [1, 10]  \n",
      "1 | predictor     | Sequential           | 141    | [1, 10]  | [1, 1]   \n",
      "2 | explainer     | Sequential           | 163    | [1, 10]  | [1, 3]   \n",
      "------------------------------------------------------------------------------\n",
      "1.9 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.9 K     Total params\n",
      "C:\\Users\\Hangzhi Guo\\AppData\\Roaming\\Python\\Python38\\site-packages\\pytorch_lightning\\utilities\\distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "536176cd82bf4349a0435c98071fface",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = CounterfactualModel(m_configs)\n",
    "trainer = pl.Trainer(max_epochs=63, resume_from_checkpoint=\"../log/debug/version_0/checkpoints/epoch=62-step=3716.ckpt\")\n",
    "trainer.fit(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "1aebd4a71fcec916e49c5e2d294321100f54b5d9bbeca9249da939e91b36ccc5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
