{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MST\n",
    "MST is a differentially private synthesizer that relies on [Private-PGM](https://github.com/ryan112358/private-pgm) to (privately) find the likeliest distribution over data given a set of measured marginals. Details on the method, and how it won a NIST competition, can be found in this paper (https://arxiv.org/abs/2108.04978).\n",
    "\n",
    "#### Why \"MST\"\n",
    "The acronym “MST” stands for “Maximum-Spanning-Tree” as the method produces differentially private synthetic data by relying on a “Maximum-Spanning-Tree\" of mutual information.\n",
    "\n",
    "MST finds the maximum spanning tree on a graph where nodes are data attributes and edge weights correspond to approximate mutual information between any two attributes. We say approximate here, because the “maximum spanning tree” is built using the exponential mechanism, which helps select edge weights with high levels of mutual information in a differentially private manner. The marginals are measured using the Gaussian mechanism.\n",
    "\n",
    "### Specifying Domain\n",
    "MST is easy to use, although it does require the data owner to specify the domain of each attribute a priori.\n",
    "\n",
    "Here, we walk through a basic example of MST, and how to **properly specify the domain** using either a JSON file or a python dictionary.\n"
   ]
  },
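  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection step described above can be sketched in a few lines. This is a simplified illustration and not the SmartNoise or Private-PGM implementation: the function names, the `sensitivity` parameter, the even split of `epsilon` across edges, and the example weights are all assumptions, and the sketch expects a weight for every pair of attributes.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def exponential_mechanism(scores, epsilon, sensitivity, rng):\n",
    "    # Pick index i with probability proportional to exp(epsilon * score_i / (2 * sensitivity)).\n",
    "    top = max(scores)  # shift by the maximum so exp() cannot overflow\n",
    "    probs = [math.exp(epsilon * (s - top) / (2.0 * sensitivity)) for s in scores]\n",
    "    return rng.choices(range(len(scores)), weights=probs)[0]\n",
    "\n",
    "def noisy_spanning_tree(attributes, weights, epsilon, sensitivity=1.0, seed=0):\n",
    "    # weights maps every attribute pair (a, b) to a score such as estimated mutual information.\n",
    "    rng = random.Random(seed)\n",
    "    component = {a: a for a in attributes}  # component label of each attribute\n",
    "    eps_per_edge = epsilon / (len(attributes) - 1)  # assumption: budget split evenly\n",
    "    tree = []\n",
    "    for _ in range(len(attributes) - 1):\n",
    "        # Only edges that join two different components keep the result a tree.\n",
    "        candidates = [e for e in weights if component[e[0]] != component[e[1]]]\n",
    "        scores = [weights[e] for e in candidates]\n",
    "        a, b = candidates[exponential_mechanism(scores, eps_per_edge, sensitivity, rng)]\n",
    "        tree.append((a, b))\n",
    "        # Merge the component of b into the component of a.\n",
    "        old, new = component[b], component[a]\n",
    "        for k in component:\n",
    "            if component[k] == old:\n",
    "                component[k] = new\n",
    "    return tree\n",
    "\n",
    "# Made-up weights for illustration only; they are not measured from PUMS.\n",
    "weights = {\n",
    "    ('age', 'sex'): 0.01, ('age', 'educ'): 0.20, ('age', 'married'): 0.30,\n",
    "    ('sex', 'educ'): 0.02, ('sex', 'married'): 0.05, ('educ', 'married'): 0.10,\n",
    "}\n",
    "tree = noisy_spanning_tree(['age', 'sex', 'educ', 'married'], weights, epsilon=1.0)\n",
    "print(tree)\n",
    "```\n",
    "\n",
    "With a very large `epsilon` the selection becomes effectively deterministic and the result is the ordinary (non-private) maximum spanning tree; smaller values make lower-weight edges more likely to be picked."
   ]
  },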
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from snsynth.mst import MSTSynthesizer\n",
    "\n",
    "git_root_dir = subprocess.check_output(\"git rev-parse --show-toplevel\".split(\" \")).decode(\"utf-8\").strip()\n",
    "\n",
    "csv_path = os.path.join(git_root_dir, os.path.join(\"datasets\", \"PUMS.csv\"))\n",
    "\n",
    "df = pd.read_csv(csv_path)\n",
    "df = df.drop([\"income\"], axis=1)\n",
    "df = df.sample(frac=1, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a correct `Domain` json file\n",
    "Here, we specify a domains dictionary, where we can list out names and filepaths for each of our datasets domains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "Domains = {\n",
    "    \"pums\": \"pums-domain.json\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our PUMS csv data here looks like this, which can be easily loaded into a pandas dataframe."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "age,sex,educ,race,married\n",
    "59,1,9,1,1\n",
    "31,0,1,3,0\n",
    "36,1,11,1,1\n",
    "54,1,11,1,1\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this data, the domain ```pums-domain.json``` file here looks like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```json\n",
    "{\n",
    "    \"age\": 95,\n",
    "    \"sex\": 2,\n",
    "    \"educ\": 17,\n",
    "    \"race\": 7,\n",
    "    \"married\": 2\n",
    "}\n",
    "```"
   ]
  },
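  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same domain can also be kept as a Python dictionary and written to disk. The sketch below is illustrative only: `check_domain` and `save_domain` are hypothetical helpers, not part of SmartNoise. It assumes each entry is the number of distinct values of a column, so that valid values run from 0 to *m* - 1; the executed cells in this notebook are consistent with that reading, since `age: 95` yields a largest synthetic age of 94. Domain sizes should come from public knowledge such as a codebook, not from the private data.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Sizes taken from the domain file shown above.\n",
    "domain = {'age': 95, 'sex': 2, 'educ': 17, 'race': 7, 'married': 2}\n",
    "\n",
    "def check_domain(data, domain):\n",
    "    # data: mapping of column name -> values (a dict of lists or a pandas DataFrame).\n",
    "    problems = []\n",
    "    for col in data:\n",
    "        values = list(data[col])\n",
    "        if col not in domain:\n",
    "            problems.append(col + ': missing from domain')\n",
    "        elif min(values) < 0 or max(values) >= domain[col]:\n",
    "            problems.append(col + ': value outside 0..' + str(domain[col] - 1))\n",
    "    return problems\n",
    "\n",
    "def save_domain(domain, path):\n",
    "    # Write the dictionary in the same JSON layout as the domain file.\n",
    "    with open(path, 'w') as f:\n",
    "        json.dump(domain, f, indent=4)\n",
    "\n",
    "# First rows of the PUMS sample shown earlier.\n",
    "rows = {'age': [59, 31], 'sex': [1, 0], 'educ': [9, 1], 'race': [1, 3], 'married': [1, 0]}\n",
    "print(check_domain(rows, domain))  # an empty list means every value fits its domain\n",
    "```\n",
    "\n",
    "Calling `save_domain(domain, 'pums-domain.json')` would then produce a file with the layout shown above."
   ]
  },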
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each column in our data **has to be included in our domain file**, and we must further specify the maximum value for each attribute, *m*, in their domain.\n",
    "\n",
    "MST will then impose a [0-*m*] range on each attribute when synthesizing.\n",
    "\n",
    "Note that MST does not work with continuous data, only categorical and low dimensional ordinal data. It is up to the data owner to properly (and privately) bin continous data for use with MST, if they so desire. Here, we have simply dropped the ```income``` column."
   ]
  },
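  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to prepare a continuous column such as `income` is to map it to integer bin indices using fixed edges chosen without looking at the private data. A minimal sketch, in which `INCOME_EDGES`, `bin_value`, and the sample incomes are hypothetical and the edges are arbitrary:\n",
    "\n",
    "```python\n",
    "import bisect\n",
    "\n",
    "# Hypothetical, publicly chosen bin edges (not derived from the private data).\n",
    "INCOME_EDGES = [10000, 25000, 50000, 100000]\n",
    "\n",
    "def bin_value(value, edges):\n",
    "    # Return an integer bin index between 0 and len(edges), inclusive.\n",
    "    return bisect.bisect_right(edges, value)\n",
    "\n",
    "incomes = [0, 9999, 10000, 42000, 250000]  # made-up values\n",
    "binned = [bin_value(x, INCOME_EDGES) for x in incomes]\n",
    "domain_size = len(INCOME_EDGES) + 1  # number of distinct bin indices\n",
    "print(binned, domain_size)\n",
    "```\n",
    "\n",
    "With pandas, the same function could be applied through `df['income'].apply(lambda v: bin_value(v, INCOME_EDGES))` before fitting. Assuming a domain entry counts the distinct values of a column, the binned column would then be listed in the domain file with `domain_size` as its entry."
   ]
  },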
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Synthesizing with MST\n",
    "Once the domain file is specified, synthesizing with MST is as easy as with any other smartnoise synthesizer.\n",
    "\n",
    "Specify an epsilon, a delta (if you like), and point to the domain file you created!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Domain(age: 95, sex: 2, educ: 17, race: 7, married: 2)\n",
      "Index(['age', 'sex', 'educ', 'race', 'married'], dtype='object')\n"
     ]
    }
   ],
   "source": [
    "mst_synth = MSTSynthesizer(domains_dict=Domains, \n",
    "                           domain='pums',\n",
    "                           epsilon=1.0,\n",
    "                           delta=1e-9)\n",
    "\n",
    "mst_synth.fit(df)\n",
    "\n",
    "sample_size = len(df)\n",
    "synth_data = mst_synth.sample(sample_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>educ</th>\n",
       "      <th>race</th>\n",
       "      <th>married</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>44.797000</td>\n",
       "      <td>0.514000</td>\n",
       "      <td>9.888000</td>\n",
       "      <td>1.954000</td>\n",
       "      <td>0.549000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>17.745385</td>\n",
       "      <td>0.500054</td>\n",
       "      <td>3.415424</td>\n",
       "      <td>1.155517</td>\n",
       "      <td>0.497842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>18.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>31.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>42.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>55.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>93.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>16.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               age          sex         educ         race      married\n",
       "count  1000.000000  1000.000000  1000.000000  1000.000000  1000.000000\n",
       "mean     44.797000     0.514000     9.888000     1.954000     0.549000\n",
       "std      17.745385     0.500054     3.415424     1.155517     0.497842\n",
       "min      18.000000     0.000000     1.000000     1.000000     0.000000\n",
       "25%      31.000000     0.000000     9.000000     1.000000     0.000000\n",
       "50%      42.000000     1.000000    11.000000     1.000000     1.000000\n",
       "75%      55.000000     1.000000    13.000000     3.000000     1.000000\n",
       "max      93.000000     1.000000    16.000000     6.000000     1.000000"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>educ</th>\n",
       "      <th>race</th>\n",
       "      <th>married</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "      <td>1000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>47.211000</td>\n",
       "      <td>0.544000</td>\n",
       "      <td>9.629000</td>\n",
       "      <td>1.908000</td>\n",
       "      <td>0.569000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>27.810104</td>\n",
       "      <td>0.498309</td>\n",
       "      <td>3.827271</td>\n",
       "      <td>1.075496</td>\n",
       "      <td>0.495464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>23.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>47.500000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>71.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>94.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>16.000000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               age          sex         educ         race      married\n",
       "count  1000.000000  1000.000000  1000.000000  1000.000000  1000.000000\n",
       "mean     47.211000     0.544000     9.629000     1.908000     0.569000\n",
       "std      27.810104     0.498309     3.827271     1.075496     0.495464\n",
       "min       0.000000     0.000000     0.000000     0.000000     0.000000\n",
       "25%      23.000000     0.000000     9.000000     1.000000     0.000000\n",
       "50%      47.500000     1.000000    11.000000     1.000000     1.000000\n",
       "75%      71.000000     1.000000    13.000000     3.000000     1.000000\n",
       "max      94.000000     1.000000    16.000000     5.000000     1.000000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data.describe()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e5d2dd739caba34ad17a129f3339262558015f1aa8f0a8a1538578d368c9f328"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 ('smartnoise-tests')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
