{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8869f6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import hashlib\n",
    "import time\n",
    "from matplotlib.pyplot import figure "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f76909a7",
   "metadata": {},
   "source": [
    "# Deep learning experiments for dp-joins"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61e1966c",
   "metadata": {},
   "source": [
    "## First, we load the data and define the baseline model, without privacy and without joins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aea1ea3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Handle for the first CUDA device\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Hyper-parameters\n",
    "input_size = 784  # each 28x28 image is flattened to a vector\n",
    "hidden_size = 500\n",
    "num_classes = 10\n",
    "\n",
    "# EMNIST \"digits\" split. Both splits share the same root, split name and transform;\n",
    "# only the training split is asked to download the data.\n",
    "emnist_options = dict(root='./data', split=\"digits\", transform=transforms.ToTensor())\n",
    "\n",
    "train_dataset = torchvision.datasets.EMNIST(train=True, download=True, **emnist_options)\n",
    "test_dataset = torchvision.datasets.EMNIST(train=False, **emnist_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2c01b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628f004e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch loaders over the two EMNIST splits; only the training loader shuffles.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4338ae0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first six images of the first test batch.\n",
    "# The built-in next() is used because iterator objects have no .next() method in Python 3.\n",
    "examples = iter(test_loader)\n",
    "example_data, example_targets = next(examples)\n",
    "\n",
    "for i in range(6):\n",
    "    plt.subplot(2,3,i+1)\n",
    "    # example_data[i][0] is the single channel of image i\n",
    "    plt.imshow(example_data[i][0], cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae59788b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class netter(nn.Module):\n",
    "    \"\"\"Fully connected classifier with batch norm, dropout and two residual blocks.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features of each (flattened) input row.\n",
    "        hidden_size: width of the first hidden layers; later layers shrink to\n",
    "            hidden_size//2, hidden_size//4 and hidden_size//8.\n",
    "        num_classes: number of output scores per input (default 10).\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, input_size, hidden_size, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.bnorm0 = nn.BatchNorm1d(input_size)\n",
    "        \n",
    "        self.lin1 = nn.Linear(input_size, hidden_size)\n",
    "        self.bnorm1 = nn.BatchNorm1d(hidden_size)\n",
    "        \n",
    "        # First residual block\n",
    "        self.lin2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2 = nn.BatchNorm1d(hidden_size)\n",
    "        \n",
    "        # Second residual block\n",
    "        self.lin2a = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2a = nn.BatchNorm1d(hidden_size)\n",
    "        \n",
    "        self.lin3 = nn.Linear(hidden_size, hidden_size//2)\n",
    "        self.bnorm3 = nn.BatchNorm1d(hidden_size//2)\n",
    "        \n",
    "        self.lin4 = nn.Linear(hidden_size//2, hidden_size//4)\n",
    "        self.bnorm4 = nn.BatchNorm1d(hidden_size//4)\n",
    "        \n",
    "        self.lin5 = nn.Linear(hidden_size//4, hidden_size//8)\n",
    "        self.bnorm5 = nn.BatchNorm1d(hidden_size//8)\n",
    "        \n",
    "        self.lin6 = nn.Linear(hidden_size//8, num_classes)\n",
    "        self.bnorm6 = nn.BatchNorm1d(num_classes)\n",
    "        \n",
    "        self.relu = nn.ReLU()\n",
    "        self.drop = nn.Dropout(0.2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"Return one row of num_classes non-negative scores per input row.\n",
    "\n",
    "        Args:\n",
    "            x: 2-D float tensor with input_size columns.\n",
    "        \"\"\"\n",
    "        x = self.bnorm0(x)\n",
    "        \n",
    "        x = self.lin1(x)\n",
    "        x = self.bnorm1(x)\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        # Residual block 1. The branch must be normalised and activated on y (the\n",
    "        # output of lin2), not on x; otherwise lin2 and bnorm2 never affect the output.\n",
    "        x = self.drop(x)\n",
    "        y = self.lin2(x)\n",
    "        y = self.bnorm2(y)\n",
    "        y = self.relu(y)\n",
    "        \n",
    "        # Residual block 2\n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "        y = self.lin2a(x)\n",
    "        y = self.bnorm2a(y)\n",
    "        y = self.relu(y)\n",
    "        \n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "        x = self.lin3(x)\n",
    "        x = self.bnorm3(x)\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        x = self.drop(x)\n",
    "        x = self.lin4(x)\n",
    "        x = self.bnorm4(x)\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        x = self.drop(x)\n",
    "        x = self.lin5(x)\n",
    "        x = self.bnorm5(x)\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        x = self.drop(x)\n",
    "        x = self.lin6(x)\n",
    "        x = self.bnorm6(x)\n",
    "        # NOTE(review): the final ReLU makes the returned scores non-negative; kept as in\n",
    "        # the original experiments.\n",
    "        x = self.relu(x)\n",
    "        \n",
    "        # No softmax here: the caller applies it when computing the loss\n",
    "        return x\n",
    "    \n",
    "# Instantiate the network with the hyper-parameters defined above\n",
    "model = netter(input_size, hidden_size)       "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3a6be26",
   "metadata": {},
   "source": [
    "# Create the Zhao et al. sketch and run the experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a66b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hash constructor used to map strings to sketch buckets and signs\n",
    "h1 = hashlib.sha256\n",
    "\n",
    "# Collects one result entry per (gamma, beta, rho, run, epoch) key\n",
    "model_results_dict = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed5a14fe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "epoch_log = 50        # evaluate on the test set every epoch_log epochs\n",
    "total_epochs = 251\n",
    "n_ab = 240_000        # record ids 0..n_ab-1 of the training set take part in the join\n",
    "val_range = 10        # number of candidate label values the receiver tries per record\n",
    "\n",
    "\n",
    "def _bucket_and_sign(salt_k, record_id, value, d):\n",
    "    \"\"\"Map (salt, record id, value) to a sketch bucket in [0, d) and a sign in {-1, +1}.\n",
    "\n",
    "    The lowest bit of the SHA-256 digest picks the sign, the remaining bits pick the bucket.\n",
    "    Sender and receiver must hash identically, so both go through this helper.\n",
    "    \"\"\"\n",
    "    s = salt_k + str(record_id) + str(value)\n",
    "    hot = int(h1(s.encode(\"utf-8\")).hexdigest(),16)\n",
    "    return (hot // 2) % d, 2 * (hot % 2) - 1\n",
    "\n",
    "\n",
    "for run in range(0,5):\n",
    "    for gamma in (1e-7,1e-6,1e-5):\n",
    "        for beta in (1e-3,1e-2,1e-1, 3e-1):\n",
    "            for rho in (50, 10, 1, 0.5, 0.1, 0.022, 0.005)[::-1]:\n",
    "                \n",
    "                num_has = int(np.log(2/beta))          # number of hash functions (sketch rows)\n",
    "                # round() before int(): 1/1e-5 evaluates to 99999.99999999999 and would be truncated\n",
    "                d = int(round(1/gamma))                # number of buckets per hash function\n",
    "                std = (np.log(2/beta) / rho)**0.5      # standard deviation of the Gaussian noise\n",
    "\n",
    "                # One random salt per hash function: 20 code points drawn from [65, 90)\n",
    "                # reinterpreted as a single 20-character string.\n",
    "                salt = {}\n",
    "                for k in range(num_has):\n",
    "                    salt[k] = np.random.randint(low=65,high=90,size=20,dtype=\"int32\").view(f\"U{20}\")[0]\n",
    "\n",
    "                true_labels = train_dataset.targets[:n_ab]\n",
    "                label_values = true_labels.tolist()\n",
    "                t = time.time()\n",
    "\n",
    "                # The sender creates the sketch: one row of d signed counters per hash function\n",
    "                sketch = np.zeros((num_has, d))\n",
    "                for k in range(num_has):\n",
    "                    for i in range(n_ab):\n",
    "                        buck, sign = _bucket_and_sign(salt[k], i, label_values[i], d)\n",
    "                        sketch[k,buck] += sign\n",
    "\n",
    "                print(f\"Sender done with clean sketch ! Time it took {time.time()-t}\")\n",
    "                            \n",
    "                # The sketch is ready. Now add Gaussian noise as given by Zhao et al.\n",
    "\n",
    "                t = time.time()\n",
    "                for k in range(num_has):\n",
    "                    sketch[k] += std*np.random.randn(d)\n",
    "\n",
    "                # Clip the sketch values to be between -1 and 1. This helps with our\n",
    "                # problem as argued in our theorems\n",
    "                np.clip(sketch, -1, 1, out=sketch)\n",
    "                    \n",
    "                print(f\"Sender done! Time it took {time.time()-t}\")\n",
    "                \n",
    "                ########################################\n",
    "                # Now the receiver party does the work #\n",
    "                ########################################\n",
    "\n",
    "                t = time.time()\n",
    "                \n",
    "                # The sum over i should be over the \"ids\" that party R is privy to.\n",
    "                # estimates[k, i, j] is hash function k's estimate that record i has label j.\n",
    "                estimates = np.empty((num_has, n_ab, val_range))\n",
    "                for k in range(num_has):\n",
    "                    for i in range(n_ab):\n",
    "                        for j in range(val_range):\n",
    "                            buck, sign = _bucket_and_sign(salt[k], i, j, d)\n",
    "                            estimates[k,i,j] = sign*sketch[k,buck]\n",
    "\n",
    "                # For the count sketch, the median over the hash functions gives the estimate.\n",
    "                labels_raw = torch.tensor(np.median(estimates, axis=0), dtype=torch.float32)\n",
    "                print(f\"Receiver done! Time it took {time.time()-t}\")\n",
    "\n",
    "                # Log the number of labels that are (in)correctly reconstructed\n",
    "                n_correct = int((torch.argmax(labels_raw,dim=1) == true_labels).sum())\n",
    "                n_wrong = n_ab - n_correct\n",
    "\n",
    "                labels_raw = labels_raw.cuda()\n",
    "                image_data = train_dataset.data.cuda()\n",
    "                # Separate name so the global batch_size used by the data loaders is not clobbered\n",
    "                train_batch_size = 2048\n",
    "                model_raw = netter(input_size, hidden_size).cuda()\n",
    "                \n",
    "                optimizer_raw = torch.optim.AdamW(model_raw.parameters(), lr = 1e-5)\n",
    "                t0 = time.time()\n",
    "                \n",
    "                for epoch in range(total_epochs):\n",
    "                    # Integer division: a trailing partial batch is not used for training\n",
    "                    for i in range(labels_raw.shape[0] // train_batch_size):\n",
    "                        batch = slice(i*train_batch_size, (i+1)*train_batch_size)\n",
    "                        images = image_data[batch].float().reshape(-1,28*28)\n",
    "                        labels_final = labels_raw[batch]\n",
    "                        outputs = model_raw(images)\n",
    "\n",
    "                        # Cross entropy against the (soft, possibly negative) sketch estimates.\n",
    "                        # The probabilities are bounded below BEFORE the log so the loss stays finite.\n",
    "                        out_norm = torch.nn.functional.softmax(outputs,dim=1)\n",
    "                        out_clip = torch.clamp(out_norm, min=0.0001)\n",
    "                        out_log = torch.log(out_clip)\n",
    "\n",
    "                        l = - torch.mean( torch.sum(labels_final * out_log,dim=1) )\n",
    "                        optimizer_raw.zero_grad()\n",
    "                        l.backward()\n",
    "                        optimizer_raw.step()\n",
    "\n",
    "                        if epoch % epoch_log == 0 and i == 0:\n",
    "                            # NOTE(review): the model is still in training mode here (dropout active,\n",
    "                            # batch statistics used). Switching to eval() would also require feeding\n",
    "                            # training images on the same scale as the ToTensor() test images - confirm.\n",
    "                            with torch.no_grad():\n",
    "                                nc = 0\n",
    "                                ns = 0\n",
    "                                # Distinct names so the training loop's i and images are not shadowed\n",
    "                                for test_images, labels_test in test_loader:\n",
    "                                    test_images = test_images.reshape(-1,28*28).cuda()\n",
    "                                    test_outputs = model_raw(test_images)\n",
    "                                    _, predicted = torch.max(test_outputs,1)\n",
    "                                    ns += test_outputs.shape[0]\n",
    "                                    nc += (predicted==labels_test.cuda()).sum().item()\n",
    "                            acc= nc/ns\n",
    "                            print(f\"For {epoch}, {gamma}, {beta}, {rho} the test accuracy is: {acc}\")                            \n",
    "                            print(f\"The time it took for {epoch_log} epochs: {time.time() - t0} \\n\")\n",
    "                            t0 = time.time()\n",
    "                            model_results_dict[gamma,beta,rho,run,epoch] = [acc,n_correct,n_wrong]\n",
    "\n",
    "                            # Persist whenever a new result is recorded; the context manager\n",
    "                            # closes the file even if pickling fails.\n",
    "                            with open(\"perfect_join_zhao\",\"wb\") as f:\n",
    "                                pickle.dump(model_results_dict,f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adfd1f5a",
   "metadata": {},
   "source": [
    "# Analyse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bafd706",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect, for every (gamma, beta, rho, run), the accuracy logged at stop_epoch.\n",
    "stop_epoch = 100\n",
    "res = {}\n",
    "n_ab = 240_000\n",
    "v = 10\n",
    "res_avg = {}\n",
    "\n",
    "for run in range(0,5):\n",
    "    for gamma in (1e-7,1e-6,1e-5):\n",
    "        for beta in (1e-3,1e-2,1e-1,3e-1):\n",
    "            for rho in (50, 10, 1, 0.5, 0.1, 0.022, 0.005)[::-1]:\n",
    "                wanted = (gamma, beta, rho, run, stop_epoch)\n",
    "                # The first entry of each stored result is the test accuracy\n",
    "                matches = [entry[0] for key, entry in model_results_dict.items() if key[:5] == wanted]\n",
    "                if matches:\n",
    "                    res_avg[gamma, beta, rho,run] = np.max(matches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae14c199",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_ab = 240_000\n",
    "v = 10\n",
    "res_avg_run = {}\n",
    "\n",
    "# Mean over all runs of the per-run accuracies, for each (gamma, beta, rho).\n",
    "for gamma in (1e-7,1e-6,1e-5):\n",
    "    for beta in (1e-3,1e-2,1e-1,3e-1):\n",
    "        for rho in (50, 10, 1, 0.5, 0.1, 0.022, 0.005)[::-1]:\n",
    "            per_run = [acc for key, acc in res_avg.items() if key[:3] == (gamma, beta, rho)]\n",
    "            if per_run:\n",
    "                res_avg_run[gamma, beta, rho] = np.mean(per_run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9137c761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run test accuracy at stop_epoch, keyed by (gamma, beta, rho, run)\n",
    "res_avg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1b7b514",
   "metadata": {},
   "source": [
    "### Now load our results "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49a721c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the results of our method. The context manager closes the file even if\n",
    "# unpickling fails.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load files you produced yourself.\n",
    "with open(\"perfect_join_ours\",\"rb\") as f:\n",
    "    our_res = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8218c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "res2 = {}\n",
    "n_ab = 240_000\n",
    "v = 10\n",
    "stop_epoch = 100\n",
    "\n",
    "# Gather the results for (epsilon, hash_dimension) for each run.\n",
    "# Keys of our_res are read as (hash_dimension, epsilon, epoch, run).\n",
    "for run in range(10):\n",
    "    for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "        for d in [n_ab*10*v, n_ab*1*v, int(n_ab*0.1*v)]:\n",
    "            wanted = (d, eps, stop_epoch, run)\n",
    "            found = [entry[0] for key, entry in our_res.items() if key[:4] == wanted]\n",
    "            if found:\n",
    "                res2[eps, d, run] = found"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9c255a",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_avg2 = {}\n",
    "# Average the results for (epsilon, hash_dimension) across all runs.\n",
    "# res2 is keyed by (epsilon, hash_dimension, run), so every matching key is one run;\n",
    "# each entry is counted exactly once (no extra loop over run indices is needed).\n",
    "for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "    for b in [n_ab*10*v, n_ab*1*v, int(n_ab*0.1*v)]:\n",
    "        # The first element of each stored list is the accuracy used for the average\n",
    "        arr = [accs[0] for key, accs in res2.items() if key[0] == eps and key[1] == b]\n",
    "        if len(arr) > 0:\n",
    "            res_avg2[eps, b] = np.mean(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b917f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "eps_vals = [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]\n",
    "hash_dims = [n_ab*10*v, n_ab*1*v, int(n_ab*0.1*v)]\n",
    "\n",
    "# res_eps2: for a fixed epsilon, the averaged results as the hash dimension changes\n",
    "res_eps2 = {eps: [res_avg2[eps, b] for b in hash_dims] for eps in eps_vals}\n",
    "\n",
    "# res_d2: for a fixed hash dimension, the averaged results as epsilon changes\n",
    "res_d2 = {d: [res_avg2[eps, d] for eps in eps_vals] for d in hash_dims}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d5597c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy averaged over runs, keyed by (gamma, beta, rho)\n",
    "res_avg_run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6fcfaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One accuracy-vs-epsilon curve per (gamma, beta) setting of the Zhao et al. baseline.\n",
    "fig, ax = plt.subplots(figsize=(8*1.2, 4*1.2), dpi = 200)\n",
    "ax.set_xscale('log',base=10)\n",
    "\n",
    "rho_values = (50, 10, 1, 0.5, 0.1, 0.022, 0.005)[::-1]\n",
    "# x-axis value used for each rho\n",
    "eps_axis = [(2*rho)**0.5 for rho in rho_values]\n",
    "\n",
    "for gamma in (1e-7,1e-6,1e-5):\n",
    "    for beta in (1e-3,1e-2,1e-1, 3e-1):\n",
    "        accuracies = [res_avg_run[gamma,beta,rho] for rho in rho_values]\n",
    "        ax.plot(eps_axis, accuracies, label= (gamma,beta))\n",
    "\n",
    "# Curve for our method, currently disabled:\n",
    "# plt.plot(eps_vals[:-1], res_d2[2400000][:-1], label = r\"Ours, $ d=2.4 \\times 10^{6}$\",color = \"black\",linestyle='dashed')\n",
    "fig.text(0.5, 0.9, r\"EMNIST test accuracy. Ours vs. Zhao et al.\", wrap=True, horizontalalignment='center', fontsize=20)\n",
    "ax.legend(loc=\"lower right\")\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax.set_ylabel(\"Test accuracy\", size = 20)\n",
    "ax.set_xlabel(\"epsilon\", size = 20)\n",
    "fig.savefig('zhao_baseline.png',bbox_inches = \"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf9f2f3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f12b361",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02f451c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c405d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
