{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from counterfactuals.datasets import (\n",
    "    LawDataset,\n",
    "    MoonsDataset,\n",
    "    HelocDataset,\n",
    "    AuditDataset,\n",
    ")\n",
    "\n",
    "from counterfactuals.generative_models.kde import KDE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = MoonsDataset(\"../data/moons.csv\")\n",
    "dataset = AuditDataset(\"../data/audit.csv\")\n",
    "# dataset = LawDataset(\"../data/law.csv\")\n",
    "# dataset = HelocDataset(\"../data/heloc.csv\")\n",
    "# dataset = PolishBankDataset(\"../data/polish_bankruptcy.csv\")\n",
    "train_dataloader = dataset.train_dataloader(batch_size=128, shuffle=True, noise_lvl=0)\n",
    "test_dataloader = dataset.test_dataloader(batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kde = KDE(bandwidth=0.1)\n",
    "kde.fit(train_dataloader)\n",
    "log_prob = []\n",
    "for x, y in train_dataloader:\n",
    "    log_prob.append(kde.log_prob(x, y))\n",
    "print(\"KDE Train\")\n",
    "med_kde_train = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "mean_kde_train = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "log_prob = []\n",
    "for x, y in test_dataloader:\n",
    "    log_prob.append(kde.log_prob(x, y))\n",
    "print(\"KDE Test\")\n",
    "med_kde_test = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "mean_kde_test = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "print(\"KDE Train\", med_kde_train, mean_kde_train)\n",
    "print(\"KDE Test\", med_kde_test, mean_kde_test)\n",
    "\n",
    "\n",
    "# flow = torch.load(\"../models/gen_model_FLOW_orig_MoonsDataset.pt\")\n",
    "# log_prob = []\n",
    "# with torch.no_grad():\n",
    "#     for x, y in train_dataloader:\n",
    "#         y = y.view(-1, 1)\n",
    "#         log_prob.append(flow.log_prob(x, y))\n",
    "# med_flow_train = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "# mean_flow_train = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "# log_prob = []\n",
    "# with torch.no_grad():\n",
    "#     for x, y in test_dataloader:\n",
    "#         y = y.view(-1, 1)\n",
    "#         log_prob.append(flow.log_prob(x, y))\n",
    "# med_flow_test = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "# mean_flow_test = np.mean(torch.hstack(log_prob).numpy()).round(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in [\n",
    "    MoonsDataset(\"../data/moons.csv\"),\n",
    "    LawDataset(\"../data/law.csv\"),\n",
    "    HelocDataset(\"../data/heloc.csv\"),\n",
    "    AuditDataset(\"../data/audit.csv\"),\n",
    "]:\n",
    "    train_dataloader = dataset.train_dataloader(\n",
    "        batch_size=128, shuffle=True, noise_lvl=0\n",
    "    )\n",
    "    test_dataloader = dataset.test_dataloader(batch_size=128, shuffle=False)\n",
    "    kde = KDE(bandwidth=0.1)\n",
    "    kde.fit(train_dataloader)\n",
    "    log_prob = []\n",
    "    for x, y in train_dataloader:\n",
    "        log_prob.append(kde.log_prob(x, y))\n",
    "    med_kde_train = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "    mean_kde_train = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "    log_prob = []\n",
    "    for x, y in test_dataloader:\n",
    "        log_prob.append(kde.log_prob(x, y))\n",
    "    med_kde_test = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "    mean_kde_test = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "    flow = torch.load(\n",
    "        f\"../models/gen_model_FLOW_orig_{str(dataset).split(' ')[0].split('.')[-1]}.pt\"\n",
    "    )\n",
    "    print(\n",
    "        f\"../models/gen_model_FLOW_orig_{str(dataset).split(' ')[0].split('.')[-1]}.pt\"\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        log_prob = []\n",
    "        with torch.no_grad():\n",
    "            for x, y in train_dataloader:\n",
    "                y = y.view(-1, 1)\n",
    "                log_prob.append(flow.log_prob(x, y).squeeze())\n",
    "        med_flow_train = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "        mean_flow_train = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "        log_prob = []\n",
    "        with torch.no_grad():\n",
    "            for x, y in test_dataloader:\n",
    "                y = y.view(-1, 1)\n",
    "                log_prob.append(flow.log_prob(x, y).squeeze())\n",
    "        med_flow_test = np.median(torch.hstack(log_prob).numpy()).round(2)\n",
    "        mean_flow_test = np.mean(torch.hstack(log_prob).numpy()).round(2)\n",
    "\n",
    "    print(str(dataset))\n",
    "    print(\"mean table\")\n",
    "    print(\n",
    "        f\"{mean_kde_train:.2f}, {mean_kde_test:.2f}, {mean_flow_train:.2f}, {mean_flow_test:.2f}\"\n",
    "    )\n",
    "    print(\"median table\")\n",
    "    print(\n",
    "        f\"{med_kde_train:.2f}, {med_kde_test:.2f}, {med_flow_train:.2f}, {med_flow_test:.2f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
