{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MIRI UCI dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "from src.imputer_wrapper import impute_now\n",
    "import argparse\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from uci_datasets import Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "p = 0.6\n",
    "seed = 1\n",
    "\n",
    "dset = \"wine\"\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "methods = ['miri', 'hyperimpute', 'knewimp'] \n",
    "print(\"methods: \", methods)\n",
    "dataset = Dataset(dset) # in the paper, we also standardize the data before the iomputation. We do not do it here for simplicity of the demo\n",
    "Xdata = torch.tensor(dataset.x).float().to(device)\n",
    "n = Xdata.shape[0]\n",
    "d = Xdata.shape[1]\n",
    "def sample_ref(n):\n",
    "    X = Xdata[:n, :].detach().clone()\n",
    "    Xstar = X.detach().clone()\n",
    "    \n",
    "    M = torch.distributions.bernoulli.Bernoulli(torch.ones(n, d)*p).sample().to(device)\n",
    "    X[M==0] = torch.randn(n, d).to(device)[M==0]\n",
    "    \n",
    "    return X.cpu(), M.cpu(), Xstar.cpu()\n",
    "\n",
    "X0, M, Xstar = sample_ref(n)\n",
    "print(\"Data size\", \"n: \", X0.shape[0], \"d: \", X0.shape[1], \"p: \", p)\n",
    "print(\"Missing rate: \", 1 - M.mean().item())\n",
    "print(\"------------ The tests begins now ------------ \\n \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running Imputation Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for method in methods:\n",
    "    X_tilde, mmd_list, mi_list = impute_now(X0, M, Xstar, method, max_rounds=10, batchsize=50)\n",
    "    torch.save([X_tilde, mmd_list, mi_list], f\"res/{dset}_{seed}_{method}.pt\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Printing Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "results = []\n",
    "for method in methods:\n",
    "    X_tilde, mmd_list, mi_list = torch.load(f\"res/{dset}_{seed}_{method}.pt\")\n",
    "    results.append({\n",
    "        \"method\": method,\n",
    "        \"mmd (the smaller the better)\": float(mmd_list[-1]),\n",
    "        \"mi (the smaller the better)\": float(mi_list[-1])\n",
    "    })\n",
    "\n",
    "df_results = pd.DataFrame(results)\n",
    "print(f\"Imputation Results Summary of {dset} dataset with seed {seed} and p={p}\")\n",
    "print(df_results)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rectflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
