{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp training_module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from ipynb_path import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from counterfactual.import_essentials import *\n",
    "from counterfactual.utils import *\n",
    "from pytorch_lightning.metrics.functional.classification import *\n",
    "from sklearn.preprocessing import StandardScaler,MinMaxScaler, OneHotEncoder\n",
    "from pytorch_lightning.callbacks import EarlyStopping\n",
    "\n",
    "pl_logger = logging.getLogger('lightning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pl version: 1.1.0\n",
      "torch version: 1.8.2\n"
     ]
    }
   ],
   "source": [
    "# record the library versions this notebook was last run with\n",
    "print(f\"pl version: {pl.__version__}\",\n",
    "      f\"torch version: {torch.__version__}\",\n",
    "      sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 747 ms, sys: 59.7 ms, total: 807 ms\n",
      "Wall time: 845 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load the toy dataset and the Adult income dataset. Paths are relative to the kernel's working\n",
    "# directory (normally the notebook's folder); load_adult_income_dataset comes from the project\n",
    "# star imports above.\n",
    "dummy_data = pd.read_csv('../data/dummy_data.csv')\n",
    "adult_data = load_adult_income_dataset('../data/adult.data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Base Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[39., 40.],\n",
       "       [50., 13.],\n",
       "       [38., 40.],\n",
       "       ...,\n",
       "       [58., 40.],\n",
       "       [22., 20.],\n",
       "       [52., 40.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "# sanity check: StandardScaler.inverse_transform must recover the original continuous values\n",
    "scalar = StandardScaler()\n",
    "cont = scalar.fit_transform(adult_data[['age', 'hours_per_week']])\n",
    "np.testing.assert_allclose(scalar.inverse_transform(cont), adult_data[['age', 'hours_per_week']].to_numpy(dtype=float))\n",
    "scalar.inverse_transform(cont)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['Government', 'Bachelors', 'Single', 'White-Collar', 'White',\n",
       "        'Male'],\n",
       "       ['Self-Employed', 'Bachelors', 'Married', 'White-Collar', 'White',\n",
       "        'Male'],\n",
       "       ['Private', 'HS-grad', 'Divorced', 'Blue-Collar', 'White', 'Male'],\n",
       "       ...,\n",
       "       ['Private', 'HS-grad', 'Widowed', 'White-Collar', 'White',\n",
       "        'Female'],\n",
       "       ['Private', 'HS-grad', 'Single', 'White-Collar', 'White', 'Male'],\n",
       "       ['Self-Employed', 'HS-grad', 'Married', 'White-Collar', 'White',\n",
       "        'Female']], dtype=object)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one-hot encode the categorical columns; the displayed inverse transform lets us eyeball the round-trip\n",
    "discrete_cols = ['workclass', 'education', 'marital_status', 'occupation', 'race', 'gender']\n",
    "enc = OneHotEncoder(sparse=False)\n",
    "cat = enc.fit_transform(adult_data[discrete_cols])\n",
    "enc.inverse_transform(cat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the one-hot columns of the immutable features in the concatenated\n",
    "# [continuous | one-hot] matrix. The one-hot block starts right after the continuous\n",
    "# columns, so derive that offset from `cont` instead of hard-coding it.\n",
    "imutable_cols = ['race', 'gender']\n",
    "imutable_idx_list = []\n",
    "cat_idx = cont.shape[-1]\n",
    "for col_name, cols in zip(discrete_cols, enc.categories_):\n",
    "    cat_end_idx = cat_idx + len(cols)\n",
    "    if col_name in imutable_cols:\n",
    "        imutable_idx_list += list(range(cat_idx, cat_end_idx))\n",
    "    cat_idx = cat_end_idx\n",
    "\n",
    "# expected indices for the Adult data with the two continuous columns above\n",
    "assert imutable_idx_list == list(range(25,29))\n",
    "np.concatenate((cont, cat), axis=-1)[:, imutable_idx_list].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class DataModule(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Base LightningModule that owns the tabular data pipeline shared by the training modules.\n",
    "\n",
    "    config[Dict]: containing configurations. Keys read here:\n",
    "        data_dir[str]: path of the csv file (read with `pd.read_csv`); its last column is the label\n",
    "        continous_cols[List[str]]: continuous feature columns (min-max scaled)\n",
    "        discret_cols[List[str]]: categorical feature columns (one-hot encoded)\n",
    "        imutable_cols[List[str]] (optional): feature columns a counterfactual must keep;\n",
    "            only categorical ones get column indices (see find_imutable_idx_list)\n",
    "        lr, batch_size: optimisation settings\n",
    "        lambda_1/2/3, threshold, smooth_y, loss_1/2/3 (optional): loss settings\n",
    "        encoder_dims, decoder_dims, explainer_dims (optional): layer sizes; encoder_dims[0]\n",
    "            must equal the width of the preprocessed input\n",
    "    \"\"\"\n",
    "    def __init__(self, config: Dict):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters(config)\n",
    "\n",
    "        # read data\n",
    "        self.data = pd.read_csv(Path(config['data_dir']))\n",
    "        self.continous_cols = config['continous_cols']\n",
    "        self.discret_cols = config['discret_cols']\n",
    "        self.imutable_cols = config.get('imutable_cols', [])\n",
    "        self.check_cols()\n",
    "\n",
    "        # set configs (optional keys fall back to the defaults below)\n",
    "        self.lr = config['lr']\n",
    "        self.batch_size = config['batch_size']\n",
    "        self.lambda_1 = config.get('lambda_1', 1)\n",
    "        self.lambda_2 = config.get('lambda_2', 1)\n",
    "        self.lambda_3 = config.get('lambda_3', 1)\n",
    "        self.threshold = config.get('threshold', 0.5)\n",
    "        self.smooth_y = config.get('smooth_y', True)\n",
    "\n",
    "        # loss functions, looked up by name\n",
    "        self.loss_func_1 = get_loss_functions(config.get('loss_1', 'cross_entropy'))\n",
    "        self.loss_func_2 = get_loss_functions(config.get('loss_2', 'l1_mean'))\n",
    "        self.loss_func_3 = get_loss_functions(config.get('loss_3', 'cross_entropy'))\n",
    "\n",
    "        # set model configs\n",
    "        self.enc_dims = config.get('encoder_dims', [])\n",
    "        self.dec_dims = config.get('decoder_dims', [])\n",
    "        self.exp_dims = config.get('explainer_dims', [])\n",
    "\n",
    "        # log graph: the example input needs the input width, so only set it (instead of\n",
    "        # raising IndexError) when encoder_dims is configured\n",
    "        if self.enc_dims:\n",
    "            self.example_input_array = torch.randn((1, self.enc_dims[0]))\n",
    "\n",
    "    def check_cols(self):\n",
    "        \"\"\"Cast the continuous columns to float and validate `imutable_cols`.\"\"\"\n",
    "        # builtin float: np.float was only an alias of it and is removed in NumPy >= 1.24\n",
    "        self.data = self.data.astype({col: float for col in self.continous_cols})\n",
    "        # check imutable cols\n",
    "        cols = self.continous_cols + self.discret_cols\n",
    "        for col in self.imutable_cols:\n",
    "            assert col in cols, f\"imutable column '{col}' is neither in continous_cols nor in discret_cols\"\n",
    "\n",
    "    def training_epoch_end(self, outs):\n",
    "        # log the hyper-parameters once, at the end of the first epoch\n",
    "        if self.current_epoch == 0:\n",
    "            self.logger.log_hyperparams(self.hparams)\n",
    "\n",
    "    def transform(self, x, return_tensor=True):\n",
    "        \"\"\"Preprocess a raw DataFrame into the [continuous | one-hot] layout.\n",
    "\n",
    "        Uses the scaler/encoder fitted in `prepare_data`. Returns a float tensor if\n",
    "        `return_tensor`, otherwise the numpy array.\n",
    "        \"\"\"\n",
    "        assert isinstance(x, pd.DataFrame)\n",
    "        x_cont = self.normalizer.transform(x[self.continous_cols]) if self.continous_cols else np.array([[] for _ in range(len(x))])\n",
    "        x_cat = self.encoder.transform(x[self.discret_cols]) if self.discret_cols else np.array([[] for _ in range(len(x))])\n",
    "        x = np.concatenate((x_cont, x_cat), axis=1)\n",
    "        return torch.from_numpy(x).float() if return_tensor else x\n",
    "\n",
    "    def inverse_transform(self, x, return_tensor=True):\n",
    "        \"\"\"Map a transformed tensor ([continuous | one-hot] layout) back to the raw feature space.\n",
    "\n",
    "        NOTE: when the categories are strings (as for the Adult data) the result is an object\n",
    "        array that torch cannot hold, so use `return_tensor=False` in that case.\n",
    "        \"\"\"\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        # inverse transform; without continuous columns the scaler was never fitted, so skip it\n",
    "        x_cont_inv = self.normalizer.inverse_transform(x[:, :cat_idx].cpu()) if self.continous_cols else np.array([[] for _ in range(len(x))])\n",
    "        x_cat_inv = self.encoder.inverse_transform(x[:, cat_idx:].cpu()) if self.discret_cols else np.array([[] for _ in range(len(x))])\n",
    "        x = np.concatenate((x_cont_inv, x_cat_inv), axis=1)\n",
    "        return torch.from_numpy(x).float() if return_tensor else x\n",
    "\n",
    "    def predict(self, x):\n",
    "        # implemented by subclasses\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def check_cont_robustness(self, x, c, c_y):\n",
    "        \"\"\"Snap near-identical continuous cf values back to x and count prediction flips.\n",
    "\n",
    "        Continuous counterfactual values that differ from x by less than `self.threshold`\n",
    "        (in the original, un-normalised units) are replaced by the value in x, and the\n",
    "        snapped counterfactual is re-predicted with `self.predict`.\n",
    "\n",
    "        Returns:\n",
    "            (number of rows whose predicted class differs from `c_y`,\n",
    "             number of rows with at least one snapped value)\n",
    "        \"\"\"\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        if cat_idx == 0:\n",
    "            # no continuous features: nothing to snap (and the scaler was never fitted)\n",
    "            return torch.zeros((), dtype=torch.long, device=c.device), 0\n",
    "        # work on a copy so the caller's counterfactual batch is left untouched\n",
    "        c = c.clone()\n",
    "        # inverse transform\n",
    "        x_cont_inv = self.normalizer.inverse_transform(x[:, :cat_idx].cpu())\n",
    "        c_cont_inv = self.normalizer.inverse_transform(c[:, :cat_idx].cpu())\n",
    "        # mask of continuous values that moved by less than the threshold\n",
    "        cont_diff = np.abs(x_cont_inv - c_cont_inv) < self.threshold\n",
    "        # number of rows containing at least one such value\n",
    "        total_diffs = np.sum(cont_diff.any(axis=1))\n",
    "        # new continous cf\n",
    "        c_cont_hat = np.where(cont_diff, x_cont_inv, c_cont_inv)\n",
    "        # .to(c) matches c's dtype and device\n",
    "        c[:, :cat_idx] = torch.from_numpy(self.normalizer.transform(c_cont_hat)).to(c)\n",
    "        c_y_hat = self.predict(c)\n",
    "        return ((c_y_hat > .5) != (c_y > .5)).sum(), total_diffs\n",
    "\n",
    "    def cat_normalize(self, c, hard=False):\n",
    "        # categorical feature starting index\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        return cat_normalize(c, self.cat_arrays, cat_idx, hard=hard)\n",
    "\n",
    "    def prepare_data(self):\n",
    "        \"\"\"Fit the scaler/encoder on the whole table (before splitting) and build the datasets.\n",
    "\n",
    "        NOTE(review): Lightning runs `prepare_data` on a single process and advises against\n",
    "        assigning state in it; under multi-process (DDP) training the attributes set here\n",
    "        would be missing on the other processes. Consider moving this into `setup`.\n",
    "        \"\"\"\n",
    "        def split_x_and_y(data):\n",
    "            # the last column is the label, all columns before it are features\n",
    "            X = data[data.columns[:-1]]\n",
    "            y = data[data.columns[-1]]\n",
    "            return X, y\n",
    "\n",
    "        def find_imutable_idx_list(\n",
    "            cat_idx: int,\n",
    "            imutable_col_names: List[str],\n",
    "            discrete_col_names: List[str],\n",
    "            cat_arrays: List[List[str]]\n",
    "        ) -> List[int]:\n",
    "            \"\"\"Transformed-matrix column indices of the one-hot blocks of immutable categorical features.\"\"\"\n",
    "            imutable_idx_list = []\n",
    "            for col_name, cols in zip(discrete_col_names, cat_arrays):\n",
    "                cat_end_idx = cat_idx + len(cols)\n",
    "                if col_name in imutable_col_names:\n",
    "                    imutable_idx_list += list(range(cat_idx, cat_end_idx))\n",
    "                cat_idx = cat_end_idx\n",
    "            return imutable_idx_list\n",
    "\n",
    "        X, y = split_x_and_y(self.data)\n",
    "\n",
    "        # preprocessing\n",
    "        self.normalizer = MinMaxScaler()\n",
    "        self.encoder = OneHotEncoder(sparse=False)\n",
    "        X_cont = self.normalizer.fit_transform(X[self.continous_cols]) if self.continous_cols else np.array([[] for _ in range(len(X))])\n",
    "        X_cat = self.encoder.fit_transform(X[self.discret_cols]) if self.discret_cols else np.array([[] for _ in range(len(X))])\n",
    "        X = np.concatenate((X_cont, X_cat), axis=1)\n",
    "        self.cat_arrays = self.encoder.categories_ if self.discret_cols else []\n",
    "        # imutable\n",
    "        self.imutable_idx_list = find_imutable_idx_list(\n",
    "            cat_idx=len(self.continous_cols), imutable_col_names=self.imutable_cols, discrete_col_names=self.discret_cols, cat_arrays=self.cat_arrays\n",
    "        )\n",
    "        pl_logger.info(f\"x_cont: {X_cont.shape}, x_cat: {X_cat.shape}\")\n",
    "        pl_logger.info(X.shape)\n",
    "        assert self.enc_dims and X.shape[-1] == self.enc_dims[0], f'The input dimension X (shape: {X.shape[-1]})  != encoder_dims[0]: {self.enc_dims}'\n",
    "\n",
    "        # prepare train & test (shuffle=False: the split follows the csv row order)\n",
    "        train_X, test_X, train_y, test_y = train_test_split(X, y.to_numpy(), shuffle=False)\n",
    "        self.train_dataset = NumpyDataset(train_X, train_y)\n",
    "        self.val_dataset = NumpyDataset(test_X, test_y)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return DataLoader(self.train_dataset, batch_size=self.batch_size,\n",
    "                          pin_memory=True, shuffle=True, num_workers=0)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        # no shuffling: keeps the per-epoch validation metrics deterministic and comparable\n",
    "        return DataLoader(self.val_dataset, batch_size=self.batch_size,\n",
    "                          pin_memory=True, shuffle=False, num_workers=0)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return DataLoader(self.val_dataset, batch_size=self.batch_size,\n",
    "                          pin_memory=True, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 1])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([0, 1, 1, 0])\n",
    "x[x == 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0800, 0.8633, 0.8413, 0.1227])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# preview of the label smoothing used in training: 1 -> [0.8, 0.95), 0 -> [0.05, 0.2)\n",
    "torch.where(x == 1,\n",
    "            0.8 + torch.rand(4) * 0.15,\n",
    "            0.05 + torch.rand(4) * 0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export utils\n",
    "def uniform(shape: tuple, r1: float, r2: float, device=None):\n",
    "    assert r1 < r2\n",
    "    return (r2 - r1) * torch.rand(*shape, device=device) + r1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class BaselineTrainingModule(DataModule):\n",
    "    \"\"\"Predictor-only model trained with binary cross-entropy on (optionally smoothed) labels.\n",
    "\n",
    "    Subclasses implement `model_forward`, which receives the tuple of positional inputs.\n",
    "    \"\"\"\n",
    "    def __init__(self, config: Dict):\n",
    "        super().__init__(config)\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, *x):\n",
    "        # x is the tuple of all positional inputs; it is passed to model_forward as a tuple\n",
    "        return self.model_forward(x)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return rounded (0./1.) predictions for an already-preprocessed input tensor.\n",
    "\n",
    "        NOTE: calls `self.freeze()` (Lightning: disables gradients, switches to eval mode)\n",
    "        and does not undo it; call `self.unfreeze()` before training again.\n",
    "        \"\"\"\n",
    "        self.freeze()\n",
    "        y_hat = self(x)\n",
    "        return torch.round(y_hat)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # batch: every element but the last is a model input, the last one is the label\n",
    "        *x, y = batch\n",
    "        # fwd\n",
    "        y_hat = self(*x)\n",
    "        # label smoothing, controlled by the `smooth_y` config flag (default True) -- the same\n",
    "        # switch CounterfactualTrainingModule honours\n",
    "        if self.smooth_y:\n",
    "            y = torch.where(y == 1,\n",
    "                            uniform(y.size(), 0.8, 0.95, device=self.device),\n",
    "                            uniform(y.size(), 0.05, 0.2, device=self.device))\n",
    "        loss = F.binary_cross_entropy(y_hat, y)\n",
    "        # Logging to TensorBoard by default\n",
    "        self.log('train/train_loss_1', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        *x, y = batch\n",
    "        # fwd\n",
    "        y_hat = self(*x)\n",
    "        # loss on the unsmoothed labels\n",
    "        loss = F.binary_cross_entropy(y_hat, y)\n",
    "        score = accuracy(y_hat > .5, y)\n",
    "        return {'score': score, 'val_loss': loss}\n",
    "\n",
    "    def validation_epoch_end(self, val_outs):\n",
    "        # average the per-batch values\n",
    "        avg_loss = torch.stack([output['val_loss'] for output in val_outs]).mean()\n",
    "        avg_score = torch.stack([output['score'] for output in val_outs]).mean()\n",
    "        self.log('val/val_loss', avg_loss)\n",
    "        self.log('val/pred_accuracy', avg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Counterfactual Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export utils\n",
    "def hinge_loss(input, target):\n",
    "    \"\"\"\n",
    "    reference:\n",
    "    - https://github.com/interpretml/DiCE/blob/a772c8d4fcd88d1cab7f2e02b0bcc045dc0e2eab/dice_ml/explainer_interfaces/dice_pytorch.py#L196-L202\n",
    "    - https://en.wikipedia.org/wiki/Hinge_loss\n",
    "    \"\"\"\n",
    "    input = torch.log((abs(input - 1e-6) / (1 - abs(input - 1e-6))))\n",
    "    all_ones = torch.ones_like(target)\n",
    "    target = 2 * target - all_ones\n",
    "    loss = all_ones - torch.mul(target, input)\n",
    "    loss = F.relu(loss)\n",
    "    return torch.norm(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: tensor([[0.2904],\n",
      "        [0.3001],\n",
      "        [0.0508],\n",
      "        [0.3926]]) \n",
      "hinge: 8.592495918273926 \n"
     ]
    }
   ],
   "source": [
    "# predictions and targets must share a shape: a (4, 1) x against a (1, 4) target would\n",
    "# silently broadcast to a (4, 4) pairwise loss\n",
    "torch.manual_seed(0)\n",
    "x = torch.rand(1, 4)\n",
    "target = torch.tensor([[1, 1, 0, 1]])\n",
    "assert x.shape == target.shape\n",
    "\n",
    "print(f\"x: {x} \")\n",
    "print(f\"hinge: {hinge_loss(x, target)} \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Counterfactual Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9465, -0.0229,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
       "          1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
       "          0.0000,  0.0000,  0.0000,  0.0000]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits = torch.randn(1, 20)\n",
    "logits[:, 2:] = F.gumbel_softmax(logits[:, 2:], hard=True)\n",
    "logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export utils\n",
    "def cat_normalize(c, cat_arrays, cat_idx, hard=False):\n",
    "    # categorical feature starting index\n",
    "    for col in cat_arrays:\n",
    "        cat_end_idx = cat_idx + len(col)\n",
    "        if hard:\n",
    "            c[:, cat_idx: cat_end_idx] = F.gumbel_softmax(c[:, cat_idx: cat_end_idx].clone(), hard=hard)\n",
    "        else:\n",
    "            c[:, cat_idx: cat_end_idx] = F.softmax(c[:, cat_idx: cat_end_idx].clone(), dim=-1)\n",
    "        cat_idx = cat_end_idx\n",
    "    return c\n",
    "\n",
    "def l1_mean(x, c):\n",
    "    return F.l1_loss(x, c, reduction='mean') / x.abs().mean() # MAD\n",
    "\n",
    "_loss_functions = {\n",
    "    'cross_entropy': F.binary_cross_entropy,\n",
    "    'l1': F.l1_loss,\n",
    "    'l1_mean': l1_mean,\n",
    "    'mse': F.mse_loss\n",
    "}\n",
    "\n",
    "def get_loss_functions(f_name: str):\n",
    "    assert f_name in _loss_functions.keys(), f'function name \"{f_name}\" is not in the loss function list {_loss_functions.keys()}'\n",
    "    return _loss_functions[f_name]\n",
    "\n",
    "_optimizers = {\n",
    "    'adam': torch.optim.Adam\n",
    "}\n",
    "\n",
    "def get_optimizers(o_name: str):\n",
    "    assert o_name in _optimizers.keys(), f'optimizer name \"{o_name}\" is not in the optimizer list {_optimizers.keys()}'\n",
    "    return _optimizers[o_name]\n",
    "\n",
    "def smooth_y(y, device=None):\n",
    "    \"\"\"Label smoothing: entries equal to 1 -> U[0.8, 0.95), all others -> U[0.05, 0.2).\n",
    "\n",
    "    Samples are drawn on y's device; the `device` argument is accepted but ignored.\n",
    "    \"\"\"\n",
    "    on_device = y.device\n",
    "    positives = uniform(y.size(), 0.8, 0.95, device=on_device)\n",
    "    negatives = uniform(y.size(), 0.05, 0.2, device=on_device)\n",
    "    return torch.where(y == 1, positives, negatives)\n",
    "\n",
    "def use_grad(*models, requires_grad: bool):\n",
    "    for model in models:\n",
    "        assert isinstance(model, nn.Module), f\"{model} is not a `nn.Module` \"\n",
    "        for p in model.parameters():\n",
    "            p.requires_grad = requires_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class CounterfactualTrainingModule(DataModule):\n",
    "    def __init__(self, config: Dict):\n",
    "        super().__init__(config)\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, x, hard=False, imutable=True):\n",
    "        \"\"\"hard: categorical features in counterfactual is one-hot-encoding or not\"\"\"\n",
    "        y, cf = self.model_forward(x)\n",
    "        cf = self.cat_normalize(cf, hard=hard)\n",
    "        if imutable:\n",
    "            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0\n",
    "        return y, cf\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"x has not been preprocessed\"\"\"\n",
    "        # x = self.transform(x)\n",
    "        # self.freeze()\n",
    "        # pl_logger.info(f\"x: {x}\")\n",
    "        y_hat, c = self.model_forward(x)\n",
    "        return torch.round(y_hat)\n",
    "\n",
    "    def generate_cf(self, x, clamp=False, imutable=True):\n",
    "        self.freeze()\n",
    "        y, cf = self.model_forward(x)\n",
    "        if imutable:\n",
    "            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0\n",
    "        if clamp:\n",
    "            cf = torch.clamp(cf, 0., 1.)\n",
    "        return self.cat_normalize(cf, hard=True)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "        return optimizer\n",
    "\n",
    "    def _loss_functions(self, x, c, y, y_hat, y_prime=None, y_prime_mode='predicted', is_val=False):\n",
    "        \"\"\"\n",
    "        x: input value\n",
    "        c: conterfactual example\n",
    "        y: ground truth\n",
    "        y_hat: predicted result\n",
    "        y_prime_mode: 'label' or 'predicted'\n",
    "        \"\"\"\n",
    "        # flip zero/one\n",
    "        if y_prime == None:\n",
    "            if y_prime_mode == 'label':\n",
    "                y_prime = torch.ones(y.shape) - y\n",
    "            elif y_prime_mode == 'predicted':\n",
    "                y_prime = (y_hat < .5).clone().detach().float()\n",
    "\n",
    "        c_y, _ = self(c)\n",
    "        # loss functions\n",
    "        if self.smooth_y and not is_val:\n",
    "            y = smooth_y(y)\n",
    "            y_prime = smooth_y(y_prime)\n",
    "        # l_1 = F.binary_cross_entropy(y_hat, y)\n",
    "        # l_2 = F.l1_loss(c, x, reduction='mean') / x.abs().mean() # MAD\n",
    "        # l_3 = F.binary_cross_entropy(c_y, y_prime)\n",
    "        l_1 = self.loss_func_1(y_hat, y)\n",
    "        l_2 = self.loss_func_2(x, c)\n",
    "        l_3 = self.loss_func_3(c_y, y_prime)\n",
    "\n",
    "        return l_1, l_2, l_3\n",
    "\n",
    "    def _loss_compute(self, l_1, l_2, l_3):\n",
    "        return self.lambda_1 * l_1 + self.lambda_2 * l_2 + self.lambda_3 * l_3\n",
    "\n",
    "    def _logging_gradient(self):\n",
    "        enc_grads = []\n",
    "        pred_grads = []\n",
    "        exp_grads = []\n",
    "        for n, p in self.model.named_parameters():\n",
    "            if (p.requires_grad) and ('bias' not in n):\n",
    "                _grad = p.grad\n",
    "                if ('encoder' in n) and (_grad is not None):\n",
    "                    enc_grads.append(_grad)\n",
    "                elif ('predictor' in n) and (_grad is not None):\n",
    "                    pred_grads.append(_grad)\n",
    "                elif ('explainer' in n) and (_grad is not None):\n",
    "                    exp_grads.append(_grad)\n",
    "\n",
    "        logger = self.logger.experiment\n",
    "        if len(enc_grads) > 0:\n",
    "            logger.add_histogram('gradient/encoder', torch.tensor(enc_grads), self.global_step, bins='auto')\n",
    "        if len(pred_grads) > 0:\n",
    "            logger.add_histogram('gradient/predictor', torch.tensor(pred_grads), self.global_step, bins='auto')\n",
    "        if len(exp_grads) > 0:\n",
    "            logger.add_histogram('gradient/explainer', torch.tensor(exp_grads), self.global_step, bins='auto')\n",
    "\n",
    "    def _logging_loss(self, l_1, l_2, l_3, stage: str, on_step: bool = False):\n",
    "        self.log(f'{stage}/{stage}_loss_1', l_1, on_step=on_step, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        self.log(f'{stage}/{stage}_loss_2', l_2, on_step=on_step, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        self.log(f'{stage}/{stage}_loss_3', l_3, on_step=on_step, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "\n",
    "    def _logging_cf_results(self, x, c, y, y_hat, c_y):\n",
    "        \"\"\"\n",
    "        params:\n",
    "            x: input value\n",
    "            c: conterfactual example\n",
    "            y: ground truth\n",
    "            y_hat: predicted result\n",
    "            c_y: the prediction of counterfactual example\n",
    "        \"\"\"\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        log = None\n",
    "        if self.current_epoch % 10 == 0:# and self.current_epoch != 0:\n",
    "            x = x.cpu()\n",
    "            c = c.cpu()\n",
    "            x_0_cont = self.normalizer.inverse_transform(x[0, :cat_idx].reshape(1, -1))\n",
    "            c_0_cont = self.normalizer.inverse_transform(c[0, :cat_idx].reshape(1, -1))\n",
    "            x_0_cat = self.encoder.inverse_transform(x[0, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []\n",
    "            c_0_cat = self.encoder.inverse_transform(c[0, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []\n",
    "\n",
    "            x_log = f\"x_cont: {x_0_cont}, x_cat: {x_0_cat}, y_hat: {y_hat[0]}\"\n",
    "            c_log = f\"c_cont: {c_0_cont}, c_cat: {c_0_cat}, y_ctf: {c_y[0]}\"\n",
    "            label_log = f\"label: {y[0]}\"\n",
    "\n",
    "            log = f\"\"\"\n",
    "            {\"==\" * 25}\n",
    "            {label_log}\n",
    "            {x_log}\n",
    "            {c_log}\n",
    "            {\"==\" * 25}\n",
    "            \"\"\"\n",
    "        return log\n",
    "\n",
    "    def transformed_cf_results(self, x, y):\n",
    "        cat_idx = len(self.continous_cols)\n",
    "        # y_hat, c = self(x, hard=True)\n",
    "        c = self.generate_cf(x, clamp=True)\n",
    "\n",
    "        log = \"\"\n",
    "        x = x.cpu()\n",
    "        c = c.cpu()\n",
    "\n",
    "        sparsity = 0\n",
    "        distance = 1000\n",
    "        best_log = \"\"\n",
    "\n",
    "        for i in range(len(x)):\n",
    "            x_0_cont = self.normalizer.inverse_transform(x[i, :cat_idx].reshape(1, -1))\n",
    "            c_0_cont = self.normalizer.inverse_transform(c[i, :cat_idx].reshape(1, -1))\n",
    "            x_0_cat = self.encoder.inverse_transform(x[i, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []\n",
    "            c_0_cat = self.encoder.inverse_transform(c[i, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []\n",
    "\n",
    "            x_log = f\"x_cont: {np.round(x_0_cont)}, x_cat: {x_0_cat}\"\n",
    "            c_log = f\"c_cont: {np.round(c_0_cont)}, c_cat: {c_0_cat}\"\n",
    "            original_c = f\"c: {c[i, :]}\"\n",
    "\n",
    "            cont_diff = np.abs(x_0_cont - c_0_cont) < 10.0\n",
    "            # total nums of differences\n",
    "            total_diffs = np.sum(cont_diff)\n",
    "\n",
    "            log += f\"\"\"\n",
    "            {\"==\" * 25}\n",
    "            {x_log}\n",
    "            {c_log}\n",
    "            {original_c}\n",
    "            {\"==\" * 25}\n",
    "            \"\"\"\n",
    "            # if total_diffs > sparsity:\n",
    "            #     best_log = f\"\"\"{\"==\" * 25}\\n{x_log}\\n{c_log}\\n{original_c}\\n{\"==\" * 25}\"\"\"\n",
    "            if sum(abs(x[i, :] - c[i, :])) < distance:\n",
    "                best_log = f\"\"\"{\"==\" * 25}\\n{x_log}\\n{c_log}\\n{original_c}\\n{\"==\" * 25}\"\"\"\n",
    "        return log, best_log\n",
    "\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, y = batch\n",
    "        # fwd\n",
    "        y_hat, c = self(x)\n",
    "        # pl_logger.info(f\"y_hat: {y_hat.requires_grad}, c: {c.requires_grad}\")\n",
    "        # loss\n",
    "        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat)\n",
    "        # pl_logger.info(f\"l_1: {l_1.requires_grad}, l_2: {l_2.requires_grad}\")\n",
    "        # logging train loss\n",
    "        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)\n",
    "\n",
    "        return self._loss_compute(l_1, l_2, l_3)\n",
    "\n",
    "#     def on_before_zero_grad(self, optimizer):\n",
    "#         self._logging_gradient()\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        # batch\n",
    "        x, y = batch\n",
    "        # fwd\n",
    "        y_hat, c = self(x, hard=True)\n",
    "        c_y, _ = self(c)\n",
    "        # loss\n",
    "        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat, is_val=True)\n",
    "        loss = l_1 + self.lambda_3 * l_3 + self.lambda_2 * l_2\n",
    "        # logging val loss\n",
    "        self._logging_loss(l_1, l_2, l_3, stage='val')\n",
    "\n",
    "        # metrics\n",
    "        cat_idx = len(self.continous_cols)\n",
    "\n",
    "        pred_acc = accuracy(y_hat > .5, y)\n",
    "        cf_proximity = torch.abs(x - c).sum(dim=1).mean()\n",
    "        cf_acc = accuracy(c_y > .5, y_hat < .5)\n",
    "\n",
    "        # print counterfactual results\n",
    "        # log = self._logging_cf_results(x, c, y, y_hat, c_y)\n",
    "        log = None\n",
    "\n",
    "        # logging robustness on manipulating small\n",
    "        diffs, total_num = self.check_cont_robustness(x, c, c_y)\n",
    "\n",
    "        return {\n",
    "                'pred_acc': pred_acc,\n",
    "                'cf_proximity': cf_proximity,\n",
    "                'cf_acc': cf_acc,\n",
    "                'val_loss': loss,\n",
    "                'log': log,\n",
    "                'diffs': diffs,\n",
    "                'total_num': total_num\n",
    "               }\n",
    "\n",
    "    def validation_epoch_end(self, val_outs):\n",
    "        \"\"\"Aggregate the per-batch dicts returned by `validation_step` and log epoch metrics.\n",
    "\n",
    "        Logs the batch-averaged loss, predictive accuracy, counterfactual proximity and\n",
    "        validity, the robustness ratio `1 - diffs / total_num`, the summed `total_num`,\n",
    "        and, when any batch produced one, the joined counterfactual text logs.\n",
    "        \"\"\"\n",
    "        # nothing to average (e.g. an empty validation loader); avoids dividing by zero below\n",
    "        if not val_outs:\n",
    "            return\n",
    "\n",
    "        loss = pred_accuracy = cf_proximity = cf_accuracy = diffs = total_diff_num = 0.\n",
    "        logs = []\n",
    "        for out in val_outs:\n",
    "            loss += out['val_loss']\n",
    "            pred_accuracy += out['pred_acc']\n",
    "            cf_proximity += out['cf_proximity']\n",
    "            cf_accuracy += out['cf_acc']\n",
    "            diffs += out['diffs']\n",
    "            total_diff_num += out['total_num']\n",
    "            if out['log'] is not None:\n",
    "                logs.append(out['log'])\n",
    "\n",
    "        size = len(val_outs)\n",
    "        # clamp only the denominator so the logged total_diff_num stays the true count\n",
    "        robust_denom = total_diff_num if total_diff_num != 0 else 1\n",
    "\n",
    "        self.log('val/val_loss', loss / size, sync_dist=True)\n",
    "        self.log('val/pred_accuracy', pred_accuracy / size, sync_dist=True)\n",
    "        self.log('val/cf_proximity', cf_proximity / size, sync_dist=True)\n",
    "        self.log('val/cf_accuracy', cf_accuracy / size, sync_dist=True)\n",
    "        self.log('val/robustness', (1 - diffs / robust_denom), sync_dist=True)\n",
    "        self.log('val/total_diff_num', total_diff_num, sync_dist=True)\n",
    "        # skip empty text entries and runs without a logger\n",
    "        if logs and self.logger is not None:\n",
    "            self.logger.experiment.add_text('results','\\n\\n'.join(logs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Counterfactual Model with Loss Wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class LossWrapper(pl.LightningModule):\n",
    "    \"\"\"Combine several loss terms with learnable log-variance weights.\n",
    "\n",
    "    Term i contributes `sum(exp(-log_vars[i]) * l_i ** 2 + log_vars[i], -1)`; the\n",
    "    accumulated total is then reduced with `.mean()`. `log_vars` starts at zero.\n",
    "    \"\"\"\n",
    "    def __init__(self, loss_num=3):\n",
    "        super().__init__()\n",
    "        self.loss_num = loss_num\n",
    "        # one learnable log-variance per loss term\n",
    "        self.log_vars = nn.Parameter(torch.zeros(loss_num))\n",
    "\n",
    "    def forward(self, *loss_f):\n",
    "        \"\"\"Return the weighted combination of exactly `loss_num` loss terms.\"\"\"\n",
    "        assert self.loss_num == len(loss_f), f\"expected {self.loss_num} loss terms, got {len(loss_f)}\"\n",
    "\n",
    "        loss = 0.\n",
    "        for i, l in enumerate(loss_f):\n",
    "            w = torch.exp(-self.log_vars[i])\n",
    "            # each term is regularised by its own log-variance (previously log_vars[0] for every term)\n",
    "            loss += torch.sum(w * l ** 2 + self.log_vars[i], -1)\n",
    "\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CounterfactualTrainingModuleLossWrapper(CounterfactualTrainingModule):\n",
    "    \"\"\"Counterfactual training module whose three losses are combined by a `LossWrapper`.\"\"\"\n",
    "    def __init__(self, config, loss_wrapper=None):\n",
    "        super().__init__(config)\n",
    "        if loss_wrapper is None:\n",
    "            loss_wrapper = LossWrapper()\n",
    "        self.loss_wrapper = loss_wrapper\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the three losses, log each one, and return their wrapped combination.\"\"\"\n",
    "        inputs, labels = batch\n",
    "        preds, cf = self(inputs)\n",
    "        loss_1, loss_2, loss_3 = self._loss_functions(inputs, cf, labels, preds)\n",
    "\n",
    "        # per-term losses, logged per step and per epoch\n",
    "        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)\n",
    "        self.log('train/train_loss_1', loss_1, **log_kwargs)\n",
    "        self.log('train/train_loss_2', loss_2, **log_kwargs)\n",
    "        self.log('train/train_loss_3', loss_3, **log_kwargs)\n",
    "\n",
    "        return self.loss_wrapper(loss_1, loss_2, loss_3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Counterfactual Model with 2 Optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class CounterfactualTrainingModule2Optimizers(CounterfactualTrainingModule):\n",
    "    \"\"\"Alternate a predictor update (optimizer 0) and an explainer update (optimizer 1).\n",
    "\n",
    "    Both Adam optimizers are built over every parameter with `requires_grad` set.\n",
    "    \"\"\"\n",
    "    def configure_optimizers(self):\n",
    "        return tuple(\n",
    "            torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "            for _ in range(2)\n",
    "        )\n",
    "\n",
    "    def training_step(self, batch, batch_idx, optimizer_idx):\n",
    "        \"\"\"Return the loss for the optimizer selected by `optimizer_idx` and log all three terms.\"\"\"\n",
    "        inputs, labels = batch\n",
    "        preds, cf = self(inputs)\n",
    "        l_1, l_2, l_3 = self._loss_functions(inputs, cf, labels, preds)\n",
    "\n",
    "        if optimizer_idx == 0:\n",
    "            step_loss = self.predictor_step(l_1, l_3)\n",
    "        elif optimizer_idx == 1:\n",
    "            step_loss = self.explainer_step(l_2, l_3)\n",
    "        else:\n",
    "            step_loss = 0\n",
    "\n",
    "        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)\n",
    "        return step_loss\n",
    "\n",
    "    def predictor_step(self, l_1, l_3):\n",
    "        \"\"\"Predictor loss `lambda_1 * l_1`; `l_3` is accepted but not used.\"\"\"\n",
    "        p_loss = self.lambda_1 * l_1\n",
    "        self.log('train/p_loss', p_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        return p_loss\n",
    "\n",
    "    def explainer_step(self, l_2, l_3):\n",
    "        \"\"\"Explainer loss `lambda_2 * l_2 + lambda_3 * l_3`.\"\"\"\n",
    "        e_loss = self.lambda_2 * l_2 + self.lambda_3 * l_3\n",
    "        self.log('train/e_loss', e_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        return e_loss\n",
    "\n",
    "\n",
    "class CounterfactualTrainingModulePosthoc(CounterfactualTrainingModule):\n",
    "    \"\"\"Two-phase schedule on a single optimizer: predictor first, explainer second.\n",
    "\n",
    "    Epochs before `trainer.max_epochs // 2` optimise the predictor loss with\n",
    "    `use_grad(self, requires_grad=True)`; later epochs call `use_grad(...,\n",
    "    requires_grad=False)` on the encoder and predictor and optimise the explainer loss.\n",
    "    \"\"\"\n",
    "    def configure_optimizers(self):\n",
    "        trainable = [p for p in self.parameters() if p.requires_grad]\n",
    "        return torch.optim.Adam(trainable, lr=self.lr)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the loss of the current phase and log all three loss terms.\"\"\"\n",
    "        inputs, labels = batch\n",
    "        preds, cf = self(inputs)\n",
    "        l_1, l_2, l_3 = self._loss_functions(inputs, cf, labels, preds)\n",
    "\n",
    "        explainer_phase = self.current_epoch >= self.trainer.max_epochs // 2\n",
    "        if explainer_phase:\n",
    "            use_grad(self.encoder_model, self.predictor, self.pred_linear, requires_grad=False)\n",
    "            step_loss = self.explainer_step(l_2, l_3)\n",
    "        else:\n",
    "            use_grad(self, requires_grad=True)\n",
    "            step_loss = self.predictor_step(l_1, l_3)\n",
    "\n",
    "        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)\n",
    "        return step_loss\n",
    "\n",
    "    def predictor_step(self, l_1, l_3):\n",
    "        \"\"\"Predictor loss `lambda_1 * l_1`; `l_3` is accepted but not used.\"\"\"\n",
    "        p_loss = self.lambda_1 * l_1\n",
    "        self.log('train/p_loss', p_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        return p_loss\n",
    "\n",
    "    def explainer_step(self, l_2, l_3):\n",
    "        \"\"\"Explainer loss `lambda_2 * l_2 + lambda_3 * l_3`.\"\"\"\n",
    "        e_loss = self.lambda_2 * l_2 + self.lambda_3 * l_3\n",
    "        self.log('train/e_loss', e_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)\n",
    "        return e_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Counterfactual Model with 2 Optimizers and Loss Wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CounterfactualModel2OptimizerWithLossWrapper(CounterfactualTrainingModule):\n",
    "    \"\"\"Manually alternate two optimizers, each minimising a `LossWrapper` over two loss terms.\n",
    "\n",
    "    Optimizer 1 minimises `loss_wrapper_1(l_1, l_3)` and optimizer 2 minimises\n",
    "    `loss_wrapper_2(l_2, l_3)`; both optimizers cover every trainable parameter.\n",
    "    \"\"\"\n",
    "    def __init__(self, config, data=None):\n",
    "        # NOTE(review): `data` is kept only so existing call sites still work; the sibling\n",
    "        # subclasses all call `super().__init__(config)`, so it is not forwarded -- confirm.\n",
    "        super().__init__(config)\n",
    "        # each wrapper combines exactly two loss terms (the default loss_num=3 would fail its assert)\n",
    "        self.loss_wrapper_1 = LossWrapper(loss_num=2)\n",
    "        self.loss_wrapper_2 = LossWrapper(loss_num=2)\n",
    "\n",
    "    @property\n",
    "    def automatic_optimization(self) -> bool:\n",
    "        # training_step drives both optimizers itself via self.optimizers() / manual_backward\n",
    "        return False\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        opt_1 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "        opt_2 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)\n",
    "        return (opt_1, opt_2)\n",
    "\n",
    "    def _compute_losses(self, x, y):\n",
    "        \"\"\"Run a forward pass and return the three loss terms (l_1, l_2, l_3).\"\"\"\n",
    "        y_hat, c = self(x)\n",
    "        return self._loss_functions(x, c, y, y_hat)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Apply one opt_1 update and one opt_2 update, then log the loss terms.\"\"\"\n",
    "        x, y = batch\n",
    "        opt_1, opt_2 = self.optimizers()\n",
    "\n",
    "        # update l_1 + l_3\n",
    "        l_1, l_2, l_3 = self._compute_losses(x, y)\n",
    "        loss_1 = self.loss_wrapper_1(l_1, l_3)\n",
    "        opt_1.zero_grad()\n",
    "        self.manual_backward(loss_1, opt_1)\n",
    "        opt_1.step()\n",
    "\n",
    "        # update l_2 + l_3 on a fresh forward pass: the first graph was freed by backward\n",
    "        # and the weights it saved were modified in place by opt_1.step()\n",
    "        _, l_2_new, l_3_new = self._compute_losses(x, y)\n",
    "        loss_2 = self.loss_wrapper_2(l_2_new, l_3_new)\n",
    "        opt_2.zero_grad()\n",
    "        self.manual_backward(loss_2, opt_2)\n",
    "        opt_2.step()\n",
    "\n",
    "        # log the loss terms from the first (pre-update) forward pass\n",
    "        self.log('plt/train_loss_1', l_1, on_step=True, on_epoch=True, prog_bar=False, logger=True)\n",
    "        self.log('plt/train_loss_2', l_2, on_step=True, on_epoch=True, prog_bar=False, logger=True)\n",
    "        self.log('plt/train_loss_3', l_3, on_step=True, on_epoch=True, prog_bar=False, logger=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "\n",
    "class _LinearBlock(nn.Module):\n",
    "    \"\"\"ICML version: Linear -> BatchNorm1d -> LeakyReLU -> Dropout.\"\"\"\n",
    "    def __init__(self, input_dim, out_dim, dropout=0.3):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(input_dim, out_dim),\n",
    "            nn.BatchNorm1d(num_features=out_dim),\n",
    "            nn.LeakyReLU(),\n",
    "            nn.Dropout(dropout),\n",
    "        ]\n",
    "        self.block = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.block(x)\n",
    "        return out\n",
    "\n",
    "\n",
    "class LinearBlock(nn.Module):\n",
    "    \"\"\"Linear -> LeakyReLU -> Dropout (no BatchNorm1d, unlike `_LinearBlock`).\"\"\"\n",
    "    def __init__(self, input_dim, out_dim, dropout=0.3):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(input_dim, out_dim),\n",
    "            nn.LeakyReLU(),\n",
    "            nn.Dropout(dropout),\n",
    "        ]\n",
    "        self.block = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.block(x)\n",
    "        return out\n",
    "\n",
    "class _MultilayerPerception(nn.Module):\n",
    "    \"\"\"ICML version: chain of `_LinearBlock`s mapping dims[0] -> dims[1] -> ... -> dims[-1].\"\"\"\n",
    "    def __init__(self, dims=[3, 100, 10]):\n",
    "        super().__init__()\n",
    "        # one block per consecutive (in, out) pair of dims\n",
    "        blocks = [_LinearBlock(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]\n",
    "        self.model = nn.Sequential(*blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "class MultilayerPerception(nn.Module):\n",
    "    \"\"\"Chain of `LinearBlock`s mapping dims[0] -> dims[1] -> ... -> dims[-1].\"\"\"\n",
    "    def __init__(self, dims=[3, 100, 10]):\n",
    "        super().__init__()\n",
    "        # one block per consecutive (in, out) pair of dims\n",
    "        blocks = [LinearBlock(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]\n",
    "        self.model = nn.Sequential(*blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "class BaselineModel(BaselineTrainingModule):\n",
    "    \"\"\"Plain predictor: encoder MLP -> decoder MLP -> Linear(1) -> sigmoid.\"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0]\n",
    "        encoder = _MultilayerPerception(self.enc_dims)\n",
    "        decoder = _MultilayerPerception(self.dec_dims)\n",
    "        head = nn.Linear(self.dec_dims[-1], 1)\n",
    "        self.model = nn.Sequential(encoder, decoder, head)\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        # x is a 1-tuple wrapping the input tensor\n",
    "        (inputs,) = x\n",
    "        probs = torch.sigmoid(self.model(inputs))\n",
    "        return probs.squeeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "\n",
    "class ConvBlock(nn.Module):\n",
    "    \"\"\"Conv1d(kernel_size=3, padding=1) -> BatchNorm1d -> LeakyReLU.\n",
    "\n",
    "    `dropout` is accepted but not used by this block.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, out_dim, dropout=0.3):\n",
    "        super().__init__()\n",
    "        conv = nn.Conv1d(input_dim, out_dim, kernel_size=3, padding=1)\n",
    "        norm = nn.BatchNorm1d(num_features=out_dim)\n",
    "        self.block = nn.Sequential(conv, norm, nn.LeakyReLU())\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.block(x)\n",
    "\n",
    "class MultilayerConv(nn.Module):\n",
    "    \"\"\"Chain of `ConvBlock`s whose channel counts follow dims[0] -> ... -> dims[-1].\"\"\"\n",
    "    def __init__(self, dims=[3, 100, 10]):\n",
    "        super().__init__()\n",
    "        # one block per consecutive (in_channels, out_channels) pair of dims\n",
    "        blocks = [ConvBlock(c_in, c_out) for c_in, c_out in zip(dims, dims[1:])]\n",
    "        self.model = nn.Sequential(*blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "\n",
    "class CounterfactualModel(CounterfactualTrainingModule):\n",
    "    \"\"\"Shared encoder feeding a predictor head (sigmoid output) and an explainer head.\n",
    "\n",
    "    The explainer maps the encoding back to `enc_dims[0]` features to produce the counterfactual.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0]\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0]\n",
    "\n",
    "        n_inputs = self.enc_dims[0]\n",
    "        self.encoder_model = MultilayerPerception(self.enc_dims)\n",
    "        self.predictor = nn.Sequential(\n",
    "            MultilayerPerception(self.dec_dims),\n",
    "            nn.Linear(self.dec_dims[-1], 1),\n",
    "        )\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(self.exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], n_inputs),\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        z = self.encoder_model(x)\n",
    "        y_hat = torch.sigmoid(self.predictor(z)).squeeze(-1)\n",
    "        cf = self.explainer(z)\n",
    "        return y_hat, cf\n",
    "\n",
    "class CounterfactualModel2Optimizers(CounterfactualTrainingModule2Optimizers):\n",
    "    \"\"\"Encoder + predictor + explainer; the explainer also sees the predictor's hidden features.\n",
    "\n",
    "    The explainer input is `cat(encoding, predictor_hidden)`, so its first layer has\n",
    "    `exp_dims[0] + dec_dims[-1]` inputs; its output has `enc_dims[0]` features.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.encoder_model = MultilayerPerception(self.enc_dims)\n",
    "        # predictor\n",
    "        self.predictor = MultilayerPerception(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer: widen the first layer to also take the predictor's hidden features\n",
    "        exp_dims = list(self.exp_dims)\n",
    "        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]\n",
    "\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, c): predicted probability (last dim squeezed) and counterfactual.\"\"\"\n",
    "        x = self.encoder_model(x)\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(x)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred))\n",
    "        # counterfactual example from [encoding, predictor hidden features]\n",
    "        x = torch.cat((x, pred), -1)\n",
    "        c = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), c\n",
    "\n",
    "class CounterfactualModel2OptsNoPass(CounterfactualTrainingModule2Optimizers):\n",
    "    \"\"\"Encoder + predictor + explainer where the explainer reads only the encoding.\n",
    "\n",
    "    Unlike `CounterfactualModel2Optimizers`, the predictor's hidden features are not\n",
    "    passed to the explainer.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.encoder_model = MultilayerPerception(self.enc_dims)\n",
    "        # predictor\n",
    "        self.predictor = MultilayerPerception(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(self.exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, c): predicted probability (last dim squeezed) and counterfactual.\"\"\"\n",
    "        x = self.encoder_model(x)\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(x)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred))\n",
    "        # counterfactual example from the encoding alone\n",
    "        c = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), c\n",
    "\n",
    "\n",
    "class CounterfactualModelSeparate(CounterfactualTrainingModule2Optimizers):\n",
    "    \"\"\"Predictor and explainer as separate networks: the explainer reads the raw input x.\n",
    "\n",
    "    The explainer MLP uses dims `enc_dims + exp_dims[1:]` and shares no weights\n",
    "    with `encoder_model`.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.encoder_model = MultilayerPerception(self.enc_dims)\n",
    "        # predictor\n",
    "        self.predictor = MultilayerPerception(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer: its own copy of the encoder layout followed by the explainer layers\n",
    "        exp_dims = self.enc_dims + self.exp_dims[1:]\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, cf): predicted probability (last dim squeezed) and counterfactual.\"\"\"\n",
    "        p = self.encoder_model(x)\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(p)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred))\n",
    "        # counterfactual example straight from the raw input\n",
    "        cf = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), cf\n",
    "\n",
    "\n",
    "class CounterfactualModelPosthoc(CounterfactualTrainingModulePosthoc):\n",
    "    \"\"\"Train in a post-hoc fashion, i.e., train predictive model first, then train explainer.\n",
    "\n",
    "    Same architecture as `CounterfactualModel2Optimizers`: the explainer input is\n",
    "    `cat(encoding, predictor_hidden)`.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.encoder_model = MultilayerPerception(self.enc_dims)\n",
    "        # predictor\n",
    "        self.predictor = MultilayerPerception(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer: widen the first layer to also take the predictor's hidden features\n",
    "        exp_dims = list(self.exp_dims)\n",
    "        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]\n",
    "\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, c): predicted probability (last dim squeezed) and counterfactual.\"\"\"\n",
    "        x = self.encoder_model(x)\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(x)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred))\n",
    "        # counterfactual example from [encoding, predictor hidden features]\n",
    "        x = torch.cat((x, pred), -1)\n",
    "        c = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "class ConvCounterNet(CounterfactualTrainingModule2Optimizers):\n",
    "    \"\"\"Convolutional variant: encoder and predictor are `MultilayerConv` stacks.\n",
    "\n",
    "    Inputs get a trailing length-1 axis, so (assuming x is (batch, enc_dims[0]))\n",
    "    features act as Conv1d channels. The explainer is an MLP over the concatenated\n",
    "    encoder and predictor channels.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.encoder_model = MultilayerConv(self.enc_dims)\n",
    "        # predictor\n",
    "        self.predictor = MultilayerConv(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer: widen the first layer to also take the predictor's channels\n",
    "        exp_dims = list(self.exp_dims)\n",
    "        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]\n",
    "\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, c): predicted probability (last dim squeezed) and counterfactual.\"\"\"\n",
    "        # add a length-1 sequence axis for Conv1d\n",
    "        x = x.unsqueeze(dim=-1)\n",
    "        x = self.encoder_model(x)\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(x)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred.squeeze(-1)))\n",
    "        # counterfactual example: concatenate along the channel axis, then drop the length-1 axis\n",
    "        x = torch.cat((x, pred), 1).squeeze(-1)\n",
    "        c = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 29, 8])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch: a (100, 29, 1) @ (1, 8) product gives a (100, 29, 8) tensor, and the encoder layer keeps that shape.\n",
    "# NOTE: in torch 1.8, nn.TransformerEncoderLayer reads dim 0 as the sequence axis (there is no batch_first yet).\n",
    "a = torch.matmul(torch.rand((100, 29, 1)), torch.rand((1, 8)))\n",
    "nn.TransformerEncoderLayer(d_model=8, nhead=4)(a).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: a random (100, 29) tensor\n",
    "torch.rand((100, 29))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "class Embed(nn.Module):\n",
    "    \"\"\"Map each scalar entry to an `emb_dims`-vector by multiplying with a learned (1, emb_dims) weight.\"\"\"\n",
    "    def __init__(self, emb_dims: int):\n",
    "        super().__init__()\n",
    "        weight = torch.empty((1, emb_dims))\n",
    "        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))\n",
    "        self.weight = nn.Parameter(weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # (..., 1) @ (1, emb_dims) -> (..., emb_dims)\n",
    "        return torch.matmul(x, self.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 29, 8])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch: Embed(8) maps a (100, 29, 1) input to (100, 29, 8)\n",
    "emb = Embed(8)\n",
    "emb(torch.rand((100, 29, 1))).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export net\n",
    "class TransCounterNet(CounterfactualTrainingModule2Optimizers):\n",
    "    \"\"\"Transformer-encoder variant: each feature becomes a token, plus a trailing special token.\n",
    "\n",
    "    Features are embedded to 8 dims with `Embed` and passed through one\n",
    "    `nn.TransformerEncoderLayer(d_model=8, nhead=4)`. The special token's output\n",
    "    feeds the predictor and, concatenated with the predictor's hidden features,\n",
    "    the explainer.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        assert self.enc_dims[-1] == self.dec_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})\"\n",
    "        assert self.enc_dims[-1] == self.exp_dims[0], f\"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})\"\n",
    "\n",
    "        self.emb = Embed(emb_dims=8)\n",
    "        self.encoder_model = nn.TransformerEncoderLayer(d_model=8, nhead=4)\n",
    "\n",
    "        # the predictor consumes the 8-dim special-token encoding\n",
    "        self.dec_dims = [8] + self.dec_dims\n",
    "        # predictor\n",
    "        self.predictor = MultilayerPerception(self.dec_dims)\n",
    "        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)\n",
    "        # explainer\n",
    "        exp_dims = list(self.exp_dims)\n",
    "        exp_dims[0] = self.dec_dims[0] + self.dec_dims[-1]\n",
    "\n",
    "        self.explainer = nn.Sequential(\n",
    "            MultilayerPerception(exp_dims),\n",
    "            nn.Linear(self.exp_dims[-1], self.enc_dims[0])\n",
    "        )\n",
    "\n",
    "    def model_forward(self, x):\n",
    "        \"\"\"Return (y_hat, c); assumes x is (batch, n_features) -- TODO confirm.\"\"\"\n",
    "        # append special token (-1), created on x's device/dtype so this also works on GPU\n",
    "        special = x.new_full((x.size(0), 1), -1.)\n",
    "        x = torch.cat((x, special), dim=-1)\n",
    "        # (batch, tokens) -> (batch, tokens, 1) -> (batch, tokens, 8)\n",
    "        x = self.emb(x.unsqueeze(dim=-1))\n",
    "        # nn.TransformerEncoderLayer (torch 1.8, no batch_first) expects (seq, batch, d_model):\n",
    "        # transpose so attention runs across each sample's feature tokens, not across the batch\n",
    "        x = self.encoder_model(x.transpose(0, 1)).transpose(0, 1)\n",
    "        # keep the special token's encoding\n",
    "        x = x[:, -1, :]\n",
    "        # predicted y_hat\n",
    "        pred = self.predictor(x)\n",
    "        y_hat = torch.sigmoid(self.pred_linear(pred))\n",
    "        # counterfactual example\n",
    "        x = torch.cat((x, pred), -1)\n",
    "        c = self.explainer(x)\n",
    "        return torch.squeeze(y_hat, -1), c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 10, 1])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch: Conv1d(kernel_size=3, padding=1) keeps a length-1 input at length 1 (the setup ConvBlock sees in ConvCounterNet)\n",
    "nn.Conv1d(3, 10, kernel_size=3, padding=1)(torch.rand(10, 3, 1)).size()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.11 ('cf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "4ded4f5ecb61699c2399dc500d4672bb7b883865a2d3678f8ff433138842e475"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
