{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from counterfactuals.datasets import MoonsDataset\n",
    "from counterfactuals.losses import MulticlassDiscLoss\n",
    "from counterfactuals.cf_methods import PUMAL\n",
    "from counterfactuals.generative_models import MaskedAutoregressiveFlow\n",
    "from counterfactuals.discriminative_models import MultilayerPerceptron\n",
    "from counterfactuals.metrics import CFMetrics\n",
    "\n",
    "from counterfactuals.plot_utils import (\n",
    "    plot_generative_model_distribution,\n",
    "    plot_classifier_decision_region,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = MoonsDataset(\"../data/moons.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "disc_model = MultilayerPerceptron(dataset.X_test.shape[1], [512, 512], 2)\n",
    "# disc_model.fit(\n",
    "#     dataset.train_dataloader(batch_size=128, shuffle=True),\n",
    "#     dataset.test_dataloader(batch_size=128, shuffle=False),\n",
    "#     epochs=5000,\n",
    "#     patience=100,\n",
    "#     lr=1e-3,\n",
    "#     checkpoint_path=\"moons_mlp.pt\",\n",
    "# )\n",
    "disc_model.load(\"globe-ce-moons/moons_mlp.pt\")\n",
    "disc_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = disc_model.predict(dataset.X_test).detach().numpy().flatten()\n",
    "print(\"Test accuracy:\", (y_pred == np.argmax(dataset.y_test, axis=1)).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.y_train = dataset.y_transformer.transform(\n",
    "    disc_model.predict(dataset.X_train).detach().numpy().reshape(-1, 1)\n",
    ")\n",
    "dataset.y_test = dataset.y_transformer.transform(\n",
    "    disc_model.predict(dataset.X_test).detach().numpy().reshape(-1, 1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_model = MaskedAutoregressiveFlow(\n",
    "    features=dataset.X_train.shape[1],\n",
    "    hidden_features=16,\n",
    "    num_blocks_per_layer=2,\n",
    "    num_layers=5,\n",
    "    context_features=2,\n",
    "    batch_norm_within_layers=True,\n",
    "    batch_norm_between_layers=True,\n",
    "    use_random_permutations=True,\n",
    ")\n",
    "train_dataloader = dataset.train_dataloader(\n",
    "    batch_size=256, shuffle=True, noise_lvl=0.03\n",
    ")\n",
    "test_dataloader = dataset.test_dataloader(batch_size=256, shuffle=False)\n",
    "\n",
    "# gen_model.fit(\n",
    "#     train_dataloader,\n",
    "#     train_dataloader,\n",
    "#     learning_rate=1e-3,\n",
    "#     patience=100,\n",
    "#     num_epochs=500,\n",
    "#     checkpoint_path=\"moons_flow1.pth\",\n",
    "# )\n",
    "gen_model.load(\"moons_flow1.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_class = 0\n",
    "target_class = 1\n",
    "X_test_origin = dataset.X_test[np.argmax(dataset.y_test, axis=1) == source_class]\n",
    "y_test_origin = dataset.y_test[np.argmax(dataset.y_test, axis=1) == source_class]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset.actionable_features = [0, 1, 2, 3, 4]\n",
    "# dataset.not_actionable_features = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]\n",
    "cf_method = PUMAL(\n",
    "    X=X_test_origin,\n",
    "    cf_method_type=\"GCE\",\n",
    "    K=2,\n",
    "    gen_model=gen_model,\n",
    "    disc_model=disc_model,\n",
    "    disc_model_criterion=MulticlassDiscLoss(eps=0.01),\n",
    "    not_actionable_features=None,\n",
    "    neptune_run=None,\n",
    ")\n",
    "\n",
    "train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)\n",
    "log_prob_threshold = torch.quantile(\n",
    "    gen_model.predict_log_prob(train_dataloader_for_log_prob),\n",
    "    0.1,\n",
    ")\n",
    "\n",
    "cf_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(\n",
    "        torch.tensor(X_test_origin).float(),\n",
    "        torch.tensor(y_test_origin).float(),\n",
    "    ),\n",
    "    batch_size=4096,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(\n",
    "    dataloader=cf_dataloader,\n",
    "    target_class=target_class,\n",
    "    epochs=20000,\n",
    "    lr=0.01,\n",
    "    patience=500,\n",
    "    alpha_dist=1e-1,\n",
    "    alpha_plaus=10**2,\n",
    "    alpha_class=10**5,\n",
    "    alpha_s=10**3,\n",
    "    alpha_k=10**1,\n",
    "    log_prob_threshold=log_prob_threshold,\n",
    ")\n",
    "\n",
    "M, S, D = delta.get_matrices()\n",
    "print(S.sum(axis=0))\n",
    "Xs_cfs = Xs + delta().detach().numpy()\n",
    "\n",
    "values, indexes = S.max(dim=1)\n",
    "\n",
    "total = len(values)\n",
    "i_correct = indexes[values == 1]\n",
    "print(f\"Correct: {len(i_correct)}/{total}\")\n",
    "print(len(set(i_correct.tolist())))\n",
    "\n",
    "metrics = CFMetrics(\n",
    "    X_cf=Xs_cfs,\n",
    "    y_target=ys_target,\n",
    "    X_train=dataset.X_train,\n",
    "    y_train=dataset.y_train,\n",
    "    X_test=X_test_origin,\n",
    "    y_test=y_test_origin,\n",
    "    disc_model=disc_model,\n",
    "    gen_model=gen_model,\n",
    "    continuous_features=list(range(dataset.X_train.shape[1])),\n",
    "    categorical_features=dataset.categorical_features,\n",
    "    prob_plausibility_threshold=log_prob_threshold,\n",
    ")\n",
    "metrics.calc_all_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(15, 5))\n",
    "\n",
    "groups = S.argmax(dim=1)\n",
    "\n",
    "for i in range(D.shape[0]):\n",
    "    plt.subplot(1, 3, i + 1)\n",
    "    plt.bar(range(2), D[i].detach().numpy())\n",
    "    mean_magn = M.squeeze()[groups == i].mean(axis=0)\n",
    "    std_magn = M.squeeze()[groups == i].std(axis=0)\n",
    "    n_vectors = (S.argmax(axis=1) == i).sum()\n",
    "    plt.title(\n",
    "        f\"CF {i}, # of cfs: {n_vectors},  Magnitude: {mean_magn:.2f} +- {std_magn:.2f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, ax = plt.subplots(figsize=(10, 10))\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "\n",
    "\n",
    "# Add arrows between each Xs and Xs_cfs\n",
    "group_colors = [\n",
    "    \"red\",\n",
    "    \"blue\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "group_cf_colors = [\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "for group_i in range(S.shape[1]):\n",
    "    xs_group = Xs[S.argmax(dim=1) == group_i]\n",
    "    xs_cfs_group = Xs_cfs[S.argmax(dim=1) == group_i]\n",
    "    ax.scatter(\n",
    "        xs_cfs_group[:, 0],\n",
    "        xs_cfs_group[:, 1],\n",
    "        c=\"orange\",\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    ax.scatter(\n",
    "        xs_group[:, 0],\n",
    "        xs_group[:, 1],\n",
    "        c=group_colors[group_i],\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    for i in range(len(xs_group)):\n",
    "        ax.arrow(\n",
    "            xs_group[i, 0],\n",
    "            xs_group[i, 1],\n",
    "            xs_cfs_group[i, 0] - xs_group[i, 0],\n",
    "            xs_cfs_group[i, 1] - xs_group[i, 1],\n",
    "            head_width=0.00,\n",
    "            head_length=-0.05,\n",
    "            fc=\"grey\",\n",
    "            ec=\"grey\",\n",
    "            # fc=group_colors[group_i],\n",
    "            # ec=group_colors[group_i],\n",
    "            alpha=0.5,\n",
    "        )\n",
    "# for i in range(len(Xs)):\n",
    "#     ax.arrow(\n",
    "#         Xs[i, 0],\n",
    "#         Xs[i, 1],\n",
    "#         Xs_cfs[i, 0] - Xs[i, 0],\n",
    "#         Xs_cfs[i, 1] - Xs[i, 1],\n",
    "#         head_width=0.02,\n",
    "#         head_length=0.00,\n",
    "#         fc=\"gray\",\n",
    "#         ec=\"gray\",\n",
    "#         alpha=0.5,\n",
    "# )\n",
    "\n",
    "plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)\n",
    "plot_classifier_decision_region(ax, disc_model)\n",
    "# plot_observations(ax, Xs, ys_orig, group_colors)\n",
    "# plot_counterfactuals(ax, Xs_cfs)\n",
    "# plot_arrows(ax, Xs, Xs_cfs)\n",
    "# remove boundaries\n",
    "ax.get_xaxis().set_visible(False)\n",
    "ax.get_yaxis().set_visible(False)\n",
    "# remove frame\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "ax.spines[\"bottom\"].set_visible(False)\n",
    "ax.spines[\"left\"].set_visible(False)\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"teaser_groupwise.pdf\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [],
   "source": [
    "cf_method = PUMAL(\n",
    "    X=X_test_origin,\n",
    "    cf_method_type=\"GCE\",\n",
    "    K=1,\n",
    "    gen_model=gen_model,\n",
    "    disc_model=disc_model,\n",
    "    disc_model_criterion=MulticlassDiscLoss(eps=0.01),\n",
    "    not_actionable_features=None,\n",
    "    neptune_run=None,\n",
    ")\n",
    "\n",
    "train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)\n",
    "log_prob_threshold = torch.quantile(\n",
    "    gen_model.predict_log_prob(train_dataloader_for_log_prob),\n",
    "    0.1,\n",
    ")\n",
    "\n",
    "cf_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(\n",
    "        torch.tensor(X_test_origin).float(),\n",
    "        torch.tensor(y_test_origin).float(),\n",
    "    ),\n",
    "    batch_size=4096,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(\n",
    "    dataloader=cf_dataloader,\n",
    "    target_class=target_class,\n",
    "    epochs=100,\n",
    "    lr=0.001,\n",
    "    patience=500,\n",
    "    alpha_dist=1e-1,\n",
    "    alpha_plaus=10**3,\n",
    "    alpha_class=10**8,\n",
    "    alpha_s=0,\n",
    "    alpha_k=0,\n",
    "    log_prob_threshold=log_prob_threshold,\n",
    ")\n",
    "\n",
    "M, S, D = delta.get_matrices()\n",
    "print(S.sum(axis=0))\n",
    "Xs_cfs = Xs + delta().detach().numpy()\n",
    "\n",
    "values, indexes = S.max(dim=1)\n",
    "\n",
    "total = len(values)\n",
    "i_correct = indexes[values == 1]\n",
    "print(f\"Correct: {len(i_correct)}/{total}\")\n",
    "print(len(set(i_correct.tolist())))\n",
    "\n",
    "metrics = CFMetrics(\n",
    "    X_cf=Xs_cfs,\n",
    "    y_target=ys_target,\n",
    "    X_train=dataset.X_train,\n",
    "    y_train=dataset.y_train,\n",
    "    X_test=X_test_origin,\n",
    "    y_test=y_test_origin,\n",
    "    disc_model=disc_model,\n",
    "    gen_model=gen_model,\n",
    "    continuous_features=list(range(dataset.X_train.shape[1])),\n",
    "    categorical_features=dataset.categorical_features,\n",
    "    prob_plausibility_threshold=log_prob_threshold,\n",
    ")\n",
    "metrics.calc_all_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, ax = plt.subplots(figsize=(10, 10))\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "# Add arrows between each Xs and Xs_cfs\n",
    "group_colors = [\n",
    "    \"blue\",\n",
    "    \"red\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "group_cf_colors = [\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "for group_i in range(S.shape[1]):\n",
    "    xs_group = Xs[S.argmax(dim=1) == group_i]\n",
    "    xs_cfs_group = Xs_cfs[S.argmax(dim=1) == group_i]\n",
    "    ax.scatter(\n",
    "        xs_cfs_group[:, 0],\n",
    "        xs_cfs_group[:, 1],\n",
    "        c=\"orange\",\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    ax.scatter(\n",
    "        xs_group[:, 0],\n",
    "        xs_group[:, 1],\n",
    "        c=group_colors[group_i],\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    for i in range(len(xs_group)):\n",
    "        ax.arrow(\n",
    "            xs_group[i, 0],\n",
    "            xs_group[i, 1],\n",
    "            xs_cfs_group[i, 0] - xs_group[i, 0],\n",
    "            xs_cfs_group[i, 1] - xs_group[i, 1],\n",
    "            head_width=0.00,\n",
    "            head_length=-0.05,\n",
    "            fc=\"grey\",\n",
    "            ec=\"grey\",\n",
    "            alpha=0.5,\n",
    "        )\n",
    "\n",
    "plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)\n",
    "plot_classifier_decision_region(ax, disc_model)\n",
    "# remove boundaries\n",
    "ax.get_xaxis().set_visible(False)\n",
    "ax.get_yaxis().set_visible(False)\n",
    "# remove frame\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "ax.spines[\"bottom\"].set_visible(False)\n",
    "ax.spines[\"left\"].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"teaser_global.pdf\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LOCAL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [],
   "source": [
    "cf_method = PUMAL(\n",
    "    X=X_test_origin,\n",
    "    cf_method_type=\"GCE\",\n",
    "    K=None,\n",
    "    gen_model=gen_model,\n",
    "    disc_model=disc_model,\n",
    "    disc_model_criterion=MulticlassDiscLoss(eps=0.01),\n",
    "    not_actionable_features=None,\n",
    "    neptune_run=None,\n",
    ")\n",
    "\n",
    "train_dataloader_for_log_prob = dataset.train_dataloader(batch_size=4096, shuffle=False)\n",
    "log_prob_threshold = torch.quantile(\n",
    "    gen_model.predict_log_prob(train_dataloader_for_log_prob),\n",
    "    0.1,\n",
    ")\n",
    "\n",
    "cf_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(\n",
    "        torch.tensor(X_test_origin).float(),\n",
    "        torch.tensor(y_test_origin).float(),\n",
    "    ),\n",
    "    batch_size=4096,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta, Xs, ys_orig, ys_target = cf_method.explain_dataloader(\n",
    "    dataloader=cf_dataloader,\n",
    "    target_class=target_class,\n",
    "    epochs=20000,\n",
    "    lr=0.01,\n",
    "    patience=500,\n",
    "    alpha_dist=1e1,\n",
    "    alpha_plaus=10**3,\n",
    "    alpha_class=10**8,\n",
    "    alpha_s=0,\n",
    "    alpha_d=0,\n",
    "    alpha_k=0,\n",
    "    log_prob_threshold=log_prob_threshold,\n",
    ")\n",
    "\n",
    "M, S, D = delta.get_matrices()\n",
    "print(S.sum(axis=0))\n",
    "Xs_cfs = Xs + delta().detach().numpy()\n",
    "\n",
    "values, indexes = S.max(dim=1)\n",
    "\n",
    "total = len(values)\n",
    "i_correct = indexes[values == 1]\n",
    "print(f\"Correct: {len(i_correct)}/{total}\")\n",
    "print(len(set(i_correct.tolist())))\n",
    "\n",
    "metrics = CFMetrics(\n",
    "    X_cf=Xs_cfs,\n",
    "    y_target=ys_target,\n",
    "    X_train=dataset.X_train,\n",
    "    y_train=dataset.y_train,\n",
    "    X_test=X_test_origin,\n",
    "    y_test=y_test_origin,\n",
    "    disc_model=disc_model,\n",
    "    gen_model=gen_model,\n",
    "    continuous_features=list(range(dataset.X_train.shape[1])),\n",
    "    categorical_features=dataset.categorical_features,\n",
    "    prob_plausibility_threshold=log_prob_threshold,\n",
    ")\n",
    "metrics.calc_all_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, ax = plt.subplots(figsize=(10, 10))\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "# Add arrows between each Xs and Xs_cfs\n",
    "group_colors = [\n",
    "    \"blue\",\n",
    "    \"red\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "group_cf_colors = [\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"green\",\n",
    "    \"orange\",\n",
    "    \"purple\",\n",
    "    \"brown\",\n",
    "    \"pink\",\n",
    "    \"gray\",\n",
    "    \"olive\",\n",
    "    \"cyan\",\n",
    "]\n",
    "for group_i in range(S.shape[1]):\n",
    "    xs_group = Xs[S.argmax(dim=1) == group_i]\n",
    "    xs_cfs_group = Xs_cfs[S.argmax(dim=1) == group_i]\n",
    "    ax.scatter(\n",
    "        xs_cfs_group[:, 0],\n",
    "        xs_cfs_group[:, 1],\n",
    "        c=\"orange\",\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    ax.scatter(\n",
    "        xs_group[:, 0],\n",
    "        xs_group[:, 1],\n",
    "        c=group_colors[0],\n",
    "        cmap=matplotlib.colormaps[\"tab10\"],\n",
    "        s=40,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "    for i in range(len(xs_group)):\n",
    "        ax.arrow(\n",
    "            xs_group[i, 0],\n",
    "            xs_group[i, 1],\n",
    "            xs_cfs_group[i, 0] - xs_group[i, 0],\n",
    "            xs_cfs_group[i, 1] - xs_group[i, 1],\n",
    "            head_width=0.00,\n",
    "            head_length=-0.05,\n",
    "            fc=\"grey\",\n",
    "            ec=\"grey\",\n",
    "            alpha=0.5,\n",
    "        )\n",
    "\n",
    "plot_generative_model_distribution(ax, gen_model, log_prob_threshold, 2)\n",
    "plot_classifier_decision_region(ax, disc_model)\n",
    "# remove boundaries\n",
    "ax.get_xaxis().set_visible(False)\n",
    "ax.get_yaxis().set_visible(False)\n",
    "# remove frame\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "ax.spines[\"bottom\"].set_visible(False)\n",
    "ax.spines[\"left\"].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"teaser_local.pdf\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
