{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8869f6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import hashlib\n",
    "import time\n",
    "from matplotlib.pyplot import figure \n",
    "import matplotlib.ticker as ticker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f76909a7",
   "metadata": {},
   "source": [
    "# Deep learning experiments for dp-joins"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61e1966c",
   "metadata": {},
   "source": [
    "## Baseline: train the usual model without privacy and without joins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aea1ea3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Hyper-parameters\n",
    "input_size = 784  # 28x28 pixels, flattened\n",
    "hidden_size = 500\n",
    "num_classes = 10\n",
    "\n",
    "# EMNIST \"digits\" split. train=True is used for training and train=False for testing\n",
    "# (the splits are not flipped). ToTensor() scales the uint8 pixels to floats in [0, 1].\n",
    "train_dataset = torchvision.datasets.EMNIST(root='./data',\n",
    "                                            train=True,\n",
    "                                            split=\"digits\",\n",
    "                                            transform=transforms.ToTensor(),\n",
    "                                            download=True)\n",
    "\n",
    "test_dataset = torchvision.datasets.EMNIST(root='./data',\n",
    "                                           train=False,\n",
    "                                           split=\"digits\",\n",
    "                                           transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2c01b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes of the EMNIST \"digits\" train and test splits.\n",
    "(len(train_dataset), len(test_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67965afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data loaders: reshuffle the training set every epoch, keep the test order fixed.\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=train_dataset, batch_size=batch_size, shuffle=True\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    dataset=test_dataset, batch_size=batch_size, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4338ae0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataLoader iterators do not provide a .next() method in recent PyTorch;\n",
    "# use the built-in next() instead.\n",
    "examples = iter(test_loader)\n",
    "example_data, example_targets = next(examples)\n",
    "\n",
    "# Show six test images (positions 5-10 of the first batch) in a 2x3 grid.\n",
    "for i in range(6):\n",
    "    plt.subplot(2, 3, i + 1)\n",
    "    plt.imshow(example_data[i + 5][0], cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae59788b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class netter(nn.Module):\n",
    "    \"\"\"Fully connected classifier with batch norm, dropout and two residual blocks.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of input features (784 for a flattened 28x28 image).\n",
    "        hidden_size: width of the first hidden layers; the tail of the network\n",
    "            narrows to hidden_size // 2, // 4 and // 8.\n",
    "        num_classes: number of output scores (defaults to the 10 digit classes).\n",
    "\n",
    "    forward(x) takes x of shape (batch, input_size) and returns scores of shape\n",
    "    (batch, num_classes). No softmax is applied; the loss is expected to handle it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.bnorm0 = nn.BatchNorm1d(input_size)\n",
    "\n",
    "        self.lin1 = nn.Linear(input_size, hidden_size)\n",
    "        self.bnorm1 = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        # Layers of the two full-width residual blocks.\n",
    "        self.lin2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2 = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        self.lin2a = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2a = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        self.lin3 = nn.Linear(hidden_size, hidden_size//2)\n",
    "        self.bnorm3 = nn.BatchNorm1d(hidden_size//2)\n",
    "\n",
    "        self.lin4 = nn.Linear(hidden_size//2, hidden_size//4)\n",
    "        self.bnorm4 = nn.BatchNorm1d(hidden_size//4)\n",
    "\n",
    "        self.lin5 = nn.Linear(hidden_size//4, hidden_size//8)\n",
    "        self.bnorm5 = nn.BatchNorm1d(hidden_size//8)\n",
    "\n",
    "        self.lin6 = nn.Linear(hidden_size//8, num_classes)\n",
    "        self.bnorm6 = nn.BatchNorm1d(num_classes)\n",
    "\n",
    "        self.relu = nn.ReLU()\n",
    "        self.drop = nn.Dropout(0.2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.bnorm0(x)\n",
    "\n",
    "        x = self.lin1(x)\n",
    "        x = self.bnorm1(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # Residual block 1: bnorm2/relu must act on the branch output y, not on x;\n",
    "        # otherwise the result of lin2 is thrown away.\n",
    "        x = self.drop(x)\n",
    "        y = self.lin2(x)\n",
    "        y = self.bnorm2(y)\n",
    "        y = self.relu(y)\n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "\n",
    "        # Residual block 2.\n",
    "        y = self.lin2a(x)\n",
    "        y = self.bnorm2a(y)\n",
    "        y = self.relu(y)\n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "\n",
    "        x = self.lin3(x)\n",
    "        x = self.bnorm3(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin4(x)\n",
    "        x = self.bnorm4(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin5(x)\n",
    "        x = self.bnorm5(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin6(x)\n",
    "        x = self.bnorm6(x)\n",
    "        # NOTE(review): a ReLU on the output scores forces them to be >= 0, which\n",
    "        # limits what softmax/cross-entropy can express; consider removing it.\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # No need for softmax. Since we use cross entropy loss\n",
    "        return x\n",
    "\n",
    "model = netter(input_size, hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcc7c670",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model's parameters and buffers to the current CUDA device.\n",
    "model = model.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7df9c24",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "learning_rate = 0.0001\n",
    "\n",
    "# Make the cell safe to re-run: a later cell moves the model to the CPU and evaluation\n",
    "# may leave it in eval mode, so restore GPU placement and training mode first.\n",
    "model = model.cuda()\n",
    "model.train()\n",
    "\n",
    "loss = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Flatten each 1x28x28 image into a 784-dimensional vector for the MLP.\n",
    "        images = images.reshape(-1, 28*28).cuda()\n",
    "\n",
    "        outputs = model(images)\n",
    "        l = loss(outputs, labels.cuda())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print(epoch, l.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c67c67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the baseline on the CPU.\n",
    "model = model.to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f4fe90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate in eval mode so batch norm uses its running statistics and dropout is off.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    nc = 0  # correct predictions\n",
    "    ns = 0  # samples seen\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, 28*28)\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        ns += outputs.shape[0]\n",
    "        nc += (predicted == labels).sum().item()\n",
    "acc = nc / ns\n",
    "print(acc)\n",
    "# After 5 epochs 98.035% (NOTE(review): recorded before eval mode was used here; re-check)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3a6be26",
   "metadata": {},
   "source": [
    "# Let's start creating sketches now"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a1a5b4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "total_epochs = 201\n",
    "epoch_log = 50\n",
    "results_dict = {}\n",
    "for run in range(0, 1):\n",
    "    for eps in [0.1, 10, 1, 0.25, 0.5, 2, 4, 8, 20]:\n",
    "        v = 10  # number of candidate labels per id (k in the plots)\n",
    "        n_ab = 240_000  # number of training ids in the join\n",
    "        for d in [n_ab*100*v, n_ab*10*v, int(n_ab*1*v), int(n_ab*0.1*v)]:\n",
    "            t0 = time.time()\n",
    "\n",
    "            # Create the sketch. Each (id, true label) pair is hashed with a fresh salt:\n",
    "            # the lowest bit of the hash gives a +/-1 sign, the remaining bits a bucket in [0, d).\n",
    "            salt = np.random.randint(low=65, high=90, size=20, dtype=\"int32\").view(f\"U{20}\")[0]\n",
    "            hashes = []\n",
    "            vectors = []\n",
    "            for i in range(n_ab):\n",
    "                s = str(i) + str(train_dataset.targets[i].item()) + salt\n",
    "                hot = int(hashlib.sha256(s.encode(\"utf-8\")).hexdigest(), 16)\n",
    "                hashes.append(hot)\n",
    "                vectors.append((hot // 2) % d)\n",
    "            vectors = np.array(vectors)\n",
    "\n",
    "            buckets = np.zeros(d)\n",
    "            for i in range(vectors.size):\n",
    "                sign = 2 * (hashes[i] % 2) - 1\n",
    "                buckets[vectors[i]] += sign\n",
    "\n",
    "            # Add two sided geometric noise to the sketch\n",
    "            alpha = np.exp(-eps)\n",
    "            buckets += np.random.geometric(1-alpha, size=buckets.shape) - np.random.geometric(1-alpha, size=buckets.shape)\n",
    "\n",
    "            ##============================================================================##\n",
    "            ##================== This marks the end of Sender's work =====================##\n",
    "            ##============================================================================##\n",
    "\n",
    "            # Now clip them: bucket values are integers, so this maps them to {-1, 0, +1}.\n",
    "            buckets = np.clip(buckets, -1, 1)\n",
    "\n",
    "            # The sum over i should be over the \"ids\" that party R is privy to.\n",
    "            # For id i and candidate label j, R probes the bucket the pair (i, j) hashes to:\n",
    "            # out[i, j] is the signed bucket value, probe_buckets[i, j] the bucket index.\n",
    "            out = np.empty((n_ab, v))\n",
    "            probe_buckets = np.empty((n_ab, v), dtype=np.int64)\n",
    "            for i in range(n_ab):\n",
    "                for j in range(v):\n",
    "                    s = str(i) + str(j) + salt\n",
    "                    hot = int(hashlib.sha256(s.encode(\"utf-8\")).hexdigest(), 16)\n",
    "                    vec = (hot // 2) % d\n",
    "                    sign = 2 * (hot % 2) - 1\n",
    "                    probe_buckets[i, j] = vec\n",
    "                    out[i, j] = sign * buckets[vec]\n",
    "\n",
    "            # N_R from Definition 6.2 of our paper: each probe is weighted by 1 / (number of\n",
    "            # probes that land in the same bucket). Reuses the bucket indices computed above.\n",
    "            _, probe_index, probe_counts = np.unique(probe_buckets, return_inverse=True, return_counts=True)\n",
    "            w = 1.0 / probe_counts[probe_index].reshape(probe_buckets.shape)\n",
    "\n",
    "            # Estimate how many labels get flipped: arg-max candidate vs. true label.\n",
    "            result = np.argmax(out, axis=1)\n",
    "            tru = int((result == train_dataset.targets[:n_ab].numpy()).sum())\n",
    "            fal = n_ab - tru\n",
    "            print(\"For epsilon\", eps,\" and d \", d, \"We have true: \", tru, \"and false: \", fal)\n",
    "\n",
    "            # Now party R starts training on the noisy, weighted labels.\n",
    "            labels_raw = torch.tensor(out).cuda()\n",
    "            weights_raw = torch.tensor(w).cuda()\n",
    "            # Scale pixels to [0, 1], matching transforms.ToTensor() used by test_loader.\n",
    "            image_data = train_dataset.data.cuda().float() / 255.0\n",
    "            batch_size = 2048\n",
    "            model_raw = netter(input_size, hidden_size)\n",
    "            model_raw = model_raw.cuda()\n",
    "\n",
    "            num_epochs = total_epochs\n",
    "            optimizer_raw = torch.optim.AdamW(model_raw.parameters(), lr = 1e-5)\n",
    "\n",
    "            for epoch in range(num_epochs):\n",
    "                # Full batches only; the trailing partial batch is skipped.\n",
    "                for i in range(labels_raw.shape[0] // batch_size):\n",
    "                    rows = slice(i*batch_size, (i+1)*batch_size)\n",
    "                    images = image_data[rows].reshape(-1, 28*28)\n",
    "                    labels_final = torch.mul(labels_raw[rows], weights_raw[rows])\n",
    "                    outputs = model_raw(images)\n",
    "\n",
    "                    # Soft-label cross entropy; probabilities are clipped at 1e-4 so log() stays finite.\n",
    "                    out_norm = torch.nn.functional.softmax(outputs, dim=1)\n",
    "                    out_clip = torch.where(out_norm > 0.0001, out_norm, 0.0001)\n",
    "                    out_log = torch.log(out_clip)\n",
    "                    l = - torch.mean(torch.sum(labels_final * out_log, dim=1))\n",
    "\n",
    "                    optimizer_raw.zero_grad()\n",
    "                    l.backward()\n",
    "                    optimizer_raw.step()\n",
    "\n",
    "                    if epoch % epoch_log == 0 and i == 0:\n",
    "                        # Evaluate the model being trained (model_raw, not the baseline `model`)\n",
    "                        # in eval mode, then switch it back to training mode.\n",
    "                        model_raw.eval()\n",
    "                        with torch.no_grad():\n",
    "                            nc = 0\n",
    "                            ns = 0\n",
    "                            for test_images, labels_test in test_loader:\n",
    "                                test_images = test_images.reshape(-1,28*28).cuda()\n",
    "                                test_outputs = model_raw(test_images)\n",
    "                                _, predicted = torch.max(test_outputs,1)\n",
    "                                ns += test_outputs.shape[0]\n",
    "                                nc += (predicted==labels_test.cuda()).sum().item()\n",
    "                        model_raw.train()\n",
    "                        acc= nc/ns\n",
    "                        print(f\"For {epoch}, {eps}, {d} the test accuracy is: {acc} \") \n",
    "                        print(f\"The time it took for {epoch_log} epochs: {time.time() - t0} \\n\")\n",
    "                        t0 = time.time()\n",
    "\n",
    "                        results_dict[d,eps,epoch,run] = [acc, tru, fal]\n",
    "                        # Checkpoint the results collected so far.\n",
    "                        with open(\"perfect_join_ours\", \"wb\") as f:\n",
    "                            pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adfd1f5a",
   "metadata": {},
   "source": [
    "# Analyse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bafd706",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = {}\n",
    "n_ab = 240_000\n",
    "v = 10\n",
    "stop_epoch = 100\n",
    "# For every (epsilon, hash dimension, run) keep the accuracies logged at stop_epoch.\n",
    "# results_dict keys are (d, eps, epoch, run); values are [acc, tru, fal].\n",
    "for run in range(10):\n",
    "    for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "        for d in [n_ab*100*v, n_ab*10*v, int(n_ab*1*v), int(n_ab*0.1*v)]:\n",
    "            arr = [results_dict[key][0] for key in results_dict\n",
    "                   if key[1] == eps and key[0] == d and key[3] == run and key[2] == stop_epoch]\n",
    "            if arr:\n",
    "                res[eps, d, run] = arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4d3a33d",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_avg = {}\n",
    "# Average the results for (epsilon, hash_dimension) across all runs.\n",
    "# res keys are (eps, d, run); each value is a list whose first entry is the accuracy.\n",
    "for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "    for d in [n_ab*100*v, n_ab*10*v, int(n_ab*1*v), int(n_ab*0.1*v)]:\n",
    "        # Exactly one accuracy per matching run.\n",
    "        arr = [accs[0] for key, accs in res.items() if key[0] == eps and key[1] == d]\n",
    "        if len(arr) > 0:\n",
    "            res_avg[eps, d] = np.mean(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72af19b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# res_eps[eps]: run-averaged accuracy for a fixed epsilon, one entry per hash dimension\n",
    "# (in the order of the d list below).\n",
    "res_eps = {}\n",
    "for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "    res_eps[eps] = [res_avg[eps, d] for d in [n_ab*100*v, n_ab*10*v, int(n_ab*1*v), int(n_ab*0.1*v)]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e50aaa1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# res_d[d]: run-averaged accuracy for a fixed hash dimension, one entry per epsilon\n",
    "# (in the order of the epsilon list below, which the plot's eps_vals mirrors).\n",
    "res_d = {}\n",
    "for d in [n_ab*100*v, n_ab*10*v, int(n_ab*1*v), int(n_ab*0.1*v)]:\n",
    "    res_d[d] = [res_avg[eps, d] for eps in [0.1, 0.25, 0.5, 1, 2, 4, 8, 10, 20]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f491f23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Privacy budgets on the x-axis; must match the epsilon order used to build res_d.\n",
    "eps_vals = [\n",
    "    0.1, 0.25, 0.5,\n",
    "    1, 2, 4, 8, 10, 20,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1dcff68",
   "metadata": {},
   "outputs": [],
   "source": [
    "figure(figsize=(8*1.2, 4*1.2), dpi = 200)\n",
    "# One curve per hash dimension d, labelled as a multiple of k * |D_R ⋈ D_S| = 10 * 240K.\n",
    "# Raw strings keep the mathtext backslashes without invalid-escape warnings.\n",
    "for key, values in res_d.items():\n",
    "    if key == 240_000:\n",
    "        curve_label = r\"0.1$ \\cdot k |D_R ⋈ D_S| $\"\n",
    "    elif key == 24_000:\n",
    "        curve_label = r\"0.01$ \\cdot k |D_R ⋈ D_S| $\"\n",
    "    else:\n",
    "        curve_label = str(int(key/ 240_000_0)) + r\"$ \\cdot k  |D_R ⋈ D_S| $\"\n",
    "    plt.plot(eps_vals, values[:10], label=curve_label)\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=10.0))\n",
    "custom_formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.2g}')\n",
    "plt.gca().xaxis.set_major_formatter(custom_formatter)\n",
    "\n",
    "plt.tick_params(axis='both', which='major', labelsize=14)\n",
    "# Non-private baseline accuracy (hard-coded; presumably taken from the baseline evaluation).\n",
    "plt.axhline(y=0.9804, color='k', linestyle='--', label = \"w/o privacy\")\n",
    "plt.legend(fontsize = 12)\n",
    "plt.figtext(0.5, 0.9, r\"EMNIST test accuracy vs. $\\epsilon$. $k=10, |D_R ⋈ D_S|$ = 240K.\", wrap = True, horizontalalignment='center', fontsize=16)\n",
    "plt.ylabel(\"Test accuracy\", size = 20)\n",
    "plt.xlabel(\"epsilon\", size = 20)\n",
    "plt.savefig('emnist_with_eps.png', bbox_inches = \"tight\")\n",
    "# NOTE(review): these run after savefig, so they do not affect the saved PNG;\n",
    "# they change the global rcParams used by later figures.\n",
    "plt.rc('xtick', labelsize=12) \n",
    "plt.rc('ytick', labelsize=12) \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02f451c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c405d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
