{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from counterfactuals.datasets import MoonsDataset\n",
    "from counterfactuals.losses import MulticlassDiscLoss\n",
    "from counterfactuals.cf_methods import PUMAL\n",
    "from counterfactuals.generative_models import MaskedAutoregressiveFlow\n",
    "from counterfactuals.discriminative_models import MultilayerPerceptron\n",
    "from counterfactuals.metrics import CFMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the NumPy and PyTorch global RNGs before any data loading or model\n",
    "# training so that a fresh Restart & Run All is repeatable as far as these\n",
    "# two generators are concerned (shuffling, weight init, sampling).\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Load the moons dataset from its CSV file (path is relative to the notebook).\n",
    "dataset = MoonsDataset(\"../data/moons.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the classifier from the data: one input per feature column and one\n",
    "# output per column of the encoded label matrix, with two hidden layers of 512.\n",
    "n_features = dataset.X_test.shape[1]\n",
    "n_classes = dataset.y_test.shape[1]\n",
    "disc_model = MultilayerPerceptron(n_features, [512, 512], n_classes)\n",
    "\n",
    "# Shuffle only the training batches; the test loader keeps its order.\n",
    "disc_train_loader = dataset.train_dataloader(batch_size=128, shuffle=True)\n",
    "disc_test_loader = dataset.test_dataloader(batch_size=128, shuffle=False)\n",
    "\n",
    "# NOTE(review): `patience` presumably drives early stopping on the second\n",
    "# loader -- confirm against MultilayerPerceptron.fit.\n",
    "disc_model.fit(\n",
    "    disc_train_loader,\n",
    "    disc_test_loader,\n",
    "    epochs=5000,\n",
    "    patience=100,\n",
    "    lr=1e-3,\n",
    ")\n",
    "\n",
    "# Leave the model in evaluation mode for the cells below.\n",
    "disc_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy: fraction of test points whose predicted class index equals\n",
    "# the true one, recovered from the encoded label matrix via a column argmax.\n",
    "y_true = np.argmax(dataset.y_test, axis=1)\n",
    "y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()\n",
    "test_accuracy = (y_pred == y_true).mean()\n",
    "print(\"Test accuracy:\", test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the dataset labels with the classifier's own predictions,\n",
    "# re-encoded through the dataset's label transformer. Every later cell that\n",
    "# reads dataset.y_train / dataset.y_test therefore sees the model's labels,\n",
    "# not the ground truth.\n",
    "train_pred_labels = disc_model.predict(dataset.X_train).detach().numpy()\n",
    "dataset.y_train = dataset.y_transformer.transform(train_pred_labels.reshape(-1, 1))\n",
    "\n",
    "test_pred_labels = disc_model.predict(dataset.X_test).detach().numpy()\n",
    "dataset.y_test = dataset.y_transformer.transform(test_pred_labels.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generative model: a masked autoregressive flow over the feature columns.\n",
    "# NOTE(review): context_features=2 presumably matches the number of encoded\n",
    "# label columns -- confirm if the dataset's class count changes.\n",
    "gen_model = MaskedAutoregressiveFlow(\n",
    "    features=dataset.X_train.shape[1],\n",
    "    hidden_features=16,\n",
    "    num_blocks_per_layer=2,\n",
    "    num_layers=5,\n",
    "    context_features=2,\n",
    "    batch_norm_within_layers=True,\n",
    "    batch_norm_between_layers=True,\n",
    "    use_random_permutations=True,\n",
    ")\n",
    "\n",
    "# Training batches are requested with noise_lvl=0.03; the test loader is\n",
    "# built without a noise level and without shuffling.\n",
    "train_dataloader = dataset.train_dataloader(\n",
    "    batch_size=256, shuffle=True, noise_lvl=0.03\n",
    ")\n",
    "test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)\n",
    "\n",
    "# Pass the held-out loader as the second argument, mirroring\n",
    "# disc_model.fit(train, test). Previously the training loader was passed\n",
    "# twice and test_dataloader was never used, so patience and the saved\n",
    "# checkpoint were driven by the noisy training batches.\n",
    "gen_model.fit(\n",
    "    train_dataloader,\n",
    "    test_dataloader,\n",
    "    learning_rate=1e-3,\n",
    "    patience=100,\n",
    "    num_epochs=500,\n",
    "    checkpoint_path=\"moons_flow1.pth\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_class = 0\n",
    "target_class = 1\n",
    "\n",
    "# Select the test rows whose encoded label (column argmax) is the source\n",
    "# class; these are the points to be explained.\n",
    "is_source = np.argmax(dataset.y_test, axis=1) == source_class\n",
    "X_test_origin = dataset.X_test[is_source]\n",
    "y_test_origin = dataset.y_test[is_source]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual explainer built on the trained flow and classifier, with no\n",
    "# feature marked as non-actionable and no Neptune run attached.\n",
    "disc_criterion = MulticlassDiscLoss(eps=0.01)\n",
    "cf_method = PUMAL(\n",
    "    X=X_test_origin,\n",
    "    cf_method_type=\"GCE\",\n",
    "    K=6,\n",
    "    gen_model=gen_model,\n",
    "    disc_model=disc_model,\n",
    "    disc_model_criterion=disc_criterion,\n",
    "    not_actionable_features=None,\n",
    "    neptune_run=None,\n",
    ")\n",
    "\n",
    "# Threshold = 0.25 quantile of the flow's log-probabilities over the\n",
    "# (unshuffled) training loader.\n",
    "train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)\n",
    "train_log_probs = gen_model.predict_log_prob(train_dataloader_for_log_prob)\n",
    "log_prob_threshold = torch.quantile(train_log_probs, 0.25)\n",
    "\n",
    "# Wrap the source-class test points and their labels as float tensors in an\n",
    "# ordered loader for the explainer.\n",
    "origin_features = torch.tensor(X_test_origin).float()\n",
    "origin_labels = torch.tensor(y_test_origin).float()\n",
    "cf_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(origin_features, origin_labels),\n",
    "    batch_size=4096,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the alpha_* values are presumably the weights of the\n",
    "# individual loss terms -- confirm against PUMAL.explain_dataloader.\n",
    "loss_weights = dict(\n",
    "    alpha_dist=1e-1,\n",
    "    alpha_plaus=10**4,\n",
    "    alpha_class=10**5,\n",
    "    alpha_s=10**4,\n",
    "    alpha_k=10**3,\n",
    "    alpha_d=10**2,\n",
    ")\n",
    "delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(\n",
    "    dataloader=cf_dataloader,\n",
    "    target_class=target_class,\n",
    "    epochs=20000,\n",
    "    lr=0.01,\n",
    "    patience=500,\n",
    "    **loss_weights,\n",
    "    log_prob_threshold=log_prob_threshold,\n",
    ")\n",
    "\n",
    "# delta is callable; its output is added to the original points to obtain\n",
    "# the counterfactuals.\n",
    "shifts = delta().detach().numpy()\n",
    "Xs_cfs = Xs + shifts\n",
    "\n",
    "# Every feature column is passed as continuous.\n",
    "feature_indices = list(range(dataset.X_train.shape[1]))\n",
    "metrics = CFMetrics(\n",
    "    X_cf=Xs_cfs,\n",
    "    y_target=ys_target,\n",
    "    X_train=dataset.X_train,\n",
    "    y_train=dataset.y_train,\n",
    "    X_test=X_test_origin,\n",
    "    y_test=y_test_origin,\n",
    "    disc_model=disc_model,\n",
    "    gen_model=gen_model,\n",
    "    continuous_features=feature_indices,\n",
    "    categorical_features=dataset.categorical_features,\n",
    "    prob_plausibility_threshold=log_prob_threshold,\n",
    ")\n",
    "metrics.calc_all_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
