{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8869f6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import hashlib\n",
    "import time\n",
    "from matplotlib.pyplot import figure \n",
    "import matplotlib.ticker as ticker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f76909a7",
   "metadata": {},
   "source": [
    "# Deep learning experiments for dp-joins"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61e1966c",
   "metadata": {},
   "source": [
    "## First, we load the data and train the usual model without privacy and without joins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aea1ea3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): `device` is not referenced by any cell; the experiment calls .cuda() directly.\n",
    "device = \"cuda:0\"\n",
    "\n",
    "# Hyper-parameters\n",
    "input_size = 28 * 28  # a flattened 28x28 image = 784 features\n",
    "hidden_size = 500\n",
    "num_classes = 10\n",
    "\n",
    "# EMNIST \"digits\" split; ToTensor() turns each image into a float tensor.\n",
    "# Only the train split requests a download; the test split assumes the files are already under ./data.\n",
    "train_dataset = torchvision.datasets.EMNIST(\n",
    "    root='./data',\n",
    "    split=\"digits\",\n",
    "    train=True,\n",
    "    transform=transforms.ToTensor(),\n",
    "    download=True,\n",
    ")\n",
    "test_dataset = torchvision.datasets.EMNIST(\n",
    "    root='./data',\n",
    "    split=\"digits\",\n",
    "    train=False,\n",
    "    transform=transforms.ToTensor(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f4ae26",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "506e138c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data loaders for the baseline model: batches of `batch_size`, in stored order (no shuffling).\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4338ae0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first 9 images of the first training batch in a 3x3 grid.\n",
    "examples = iter(train_loader)\n",
    "# Use the builtin next(): DataLoader iterators no longer provide a .next() method.\n",
    "example_data, example_targets = next(examples)\n",
    "for i in range(9):\n",
    "    plt.subplot(3, 3, i + 1)\n",
    "    plt.imshow(example_data[i][0], cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae59788b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class netter(nn.Module):\n",
    "    \"\"\"Fully connected classifier with two residual hidden blocks and a tapering head.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of input features (784 for a flattened 28x28 image).\n",
    "        hidden_size: width of the first hidden layers; the head narrows to\n",
    "            hidden_size//2, //4, //8 and finally 10 outputs.\n",
    "\n",
    "    Every Linear layer is followed by BatchNorm1d and ReLU, and Dropout(0.2) is applied at\n",
    "    several points, so train()/eval() mode changes the predictions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.bnorm0 = nn.BatchNorm1d(input_size)\n",
    "\n",
    "        self.lin1 = nn.Linear(input_size, hidden_size)\n",
    "        self.bnorm1 = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        self.lin2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2 = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        self.lin2a = nn.Linear(hidden_size, hidden_size)\n",
    "        self.bnorm2a = nn.BatchNorm1d(hidden_size)\n",
    "\n",
    "        self.lin3 = nn.Linear(hidden_size, hidden_size//2)\n",
    "        self.bnorm3 = nn.BatchNorm1d(hidden_size//2)\n",
    "\n",
    "        self.lin4 = nn.Linear(hidden_size//2, hidden_size//4)\n",
    "        self.bnorm4 = nn.BatchNorm1d(hidden_size//4)\n",
    "\n",
    "        self.lin5 = nn.Linear(hidden_size//4, hidden_size//8)\n",
    "        self.bnorm5 = nn.BatchNorm1d(hidden_size//8)\n",
    "\n",
    "        self.lin6 = nn.Linear(hidden_size//8, 10)\n",
    "        self.bnorm6 = nn.BatchNorm1d(10)\n",
    "\n",
    "        self.relu = nn.ReLU()\n",
    "        self.drop = nn.Dropout(0.2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a (batch, input_size) float tensor to (batch, 10) class scores.\"\"\"\n",
    "        x = self.bnorm0(x)\n",
    "\n",
    "        x = self.lin1(x)\n",
    "        x = self.bnorm1(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # Residual block 1: x + drop(relu(bnorm2(lin2(x)))).\n",
    "        # bnorm2/relu must consume lin2's output (y); applying them to x bypasses lin2 entirely.\n",
    "        x = self.drop(x)\n",
    "        y = self.lin2(x)\n",
    "        y = self.bnorm2(y)\n",
    "        y = self.relu(y)\n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "\n",
    "        # Residual block 2: x + drop(relu(bnorm2a(lin2a(x)))).\n",
    "        y = self.lin2a(x)\n",
    "        y = self.bnorm2a(y)\n",
    "        y = self.relu(y)\n",
    "        y = self.drop(y)\n",
    "        x = x + y\n",
    "\n",
    "        x = self.lin3(x)\n",
    "        x = self.bnorm3(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin4(x)\n",
    "        x = self.bnorm4(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin5(x)\n",
    "        x = self.bnorm5(x)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        x = self.drop(x)\n",
    "        x = self.lin6(x)\n",
    "        x = self.bnorm6(x)\n",
    "        # NOTE(review): this ReLU clamps every output score at >= 0 before the loss, which is\n",
    "        # unusual for a classifier head; kept unchanged here, consider removing it.\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself.\n",
    "        return x\n",
    "\n",
    "model = netter(input_size, hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "187cae70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of trainable parameters of the baseline `model`.\n",
    "count = sum(p.numel() for p in model.parameters() if p.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "026c736d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the number of trainable parameters of the baseline model (computed in the previous cell).\n",
    "count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7df9c24",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline: train `model` on the true (non-private) labels with Adam and cross-entropy.\n",
    "num_epochs = 10\n",
    "learning_rate = 0.0001\n",
    "loss = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        images = images.reshape(-1, 28 * 28)  # flatten each 1x28x28 image to 784 features\n",
    "        outputs = model(images)\n",
    "        l = loss(outputs, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Log the current batch loss every 100 batches.\n",
    "        if i % 100 == 0:\n",
    "            print(epoch, l.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f4fe90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy of the baseline model.\n",
    "# eval() turns Dropout off and makes BatchNorm use its running statistics instead of batch statistics.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    nc = 0  # correct predictions\n",
    "    ns = 0  # samples seen\n",
    "    for i, (images, labels) in enumerate(test_loader):\n",
    "        images = images.reshape(-1, 28*28)\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        ns += outputs.shape[0]\n",
    "        nc += (predicted == labels).sum().item()\n",
    "model.train()  # restore training mode so re-running the training cell behaves as before\n",
    "acc = nc / ns\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a1a5b4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Private-join experiment. For each sketch dimension d and each number n_a of extra (non-joining)\n",
    "# receiver records:\n",
    "# - the sender builds a noisy signed hash sketch of its labels;\n",
    "# - the receiver decodes per-class scores for the ids it holds;\n",
    "# - a fresh model is trained on those scores.\n",
    "total_epochs = 401\n",
    "epoch_log = 50  # evaluate on the test set every `epoch_log` epochs\n",
    "results_dict = {}\n",
    "# Raw uint8 training images, copied to the GPU once; row k is paired with receiver record k below.\n",
    "image_data = train_dataset.data.cuda()\n",
    "for eps in [1]:\n",
    "    for run in range(0,1):\n",
    "        v = 10\n",
    "        n_ab = 20_000\n",
    "        for d in [int(n_ab*100*v), int(n_ab*10*v),int(n_ab*1*v),int(n_ab*0.1*v)]:\n",
    "            for n_a in range(0, 220_001 , 20_000):\n",
    "\n",
    "                t0 = time.time()\n",
    "                labels = []\n",
    "                hashes = []\n",
    "                vectors = []\n",
    "\n",
    "                # Create the sketch. Record i with label y is hashed (with a random 20-letter salt)\n",
    "                # to a bucket index and a sign.\n",
    "                salt = np.random.randint(low=65,high=90,size=20,dtype=\"int32\").view(f\"U{20}\")[0]\n",
    "                for i in range(n_ab):\n",
    "                    labels.append(train_dataset.targets[i])\n",
    "                    s = str(i) + str(train_dataset.targets[i].item()) + salt\n",
    "                    hot = int(hashlib.sha256(s.encode(\"utf-8\")).hexdigest(),16)\n",
    "                    hashes.append(hot)\n",
    "                    vectors.append((hot // 2) % d)  # bucket index; the lowest hash bit gives the sign\n",
    "\n",
    "                vectors = np.array(vectors)\n",
    "                one_hot_matrix_hash = {}\n",
    "\n",
    "                for i in range(vectors.size):\n",
    "                    hot = hashes[i]\n",
    "                    sign = 2 * (hot % 2) - 1\n",
    "                    one_hot_matrix_hash[i] = [vectors[i], sign]\n",
    "\n",
    "                buckets = np.zeros(d)\n",
    "                for i in range(vectors.size):\n",
    "                    buckets[one_hot_matrix_hash[i][0]] += one_hot_matrix_hash[i][1]\n",
    "\n",
    "                # Add two sided geometric noise to the sketch\n",
    "                alpha = np.exp(-eps)\n",
    "                buckets += np.random.geometric(1-alpha,size = (buckets.shape))-np.random.geometric(1-alpha,size = (buckets.shape))\n",
    "\n",
    "                ##============================================================================##\n",
    "                ##================== This marks the end of Sender's work ====================##\n",
    "                ##============================================================================##\n",
    "\n",
    "                # Now clip the (integer-valued) noisy buckets to {-1, 0, 1}\n",
    "                buckets[buckets >= 1] = 1\n",
    "                buckets[buckets <= -1] = -1\n",
    "\n",
    "                # The sum over i should be over the \"ids\" that party R is privy to:\n",
    "                # the n_ab joining ids plus n_a random ids that are not in the sender's data.\n",
    "                na_set = set()\n",
    "                for i in range(n_a):\n",
    "                    na_set.add(np.random.randint(1e12,1e18))\n",
    "                na_list = list(na_set)\n",
    "\n",
    "                # Receiver: query the sketch at (id, class j) for every id it holds and every class j.\n",
    "                # rep_dict[vec] counts how many queries land in bucket vec.\n",
    "                query_ids = list(range(vectors.size)) + na_list\n",
    "                out = []\n",
    "                query_vecs = []  # bucket indices of each id's v queries, reused for the weights below\n",
    "                rep_dict = {}\n",
    "                for i in query_ids:\n",
    "                    temp = []\n",
    "                    vecs = []\n",
    "                    for j in range(v):\n",
    "                        s = str(i) + str(j) + salt\n",
    "                        hot = int(hashlib.sha256(s.encode(\"utf-8\")).hexdigest(),16)\n",
    "                        vec = (hot // 2) % d\n",
    "                        sign = 2 * (hot % 2) - 1\n",
    "                        rep_dict[vec] = rep_dict.get(vec, 0) + 1\n",
    "                        temp.append(sign*buckets[vec])\n",
    "                        vecs.append(vec)\n",
    "                    out.append([temp])\n",
    "                    query_vecs.append(vecs)\n",
    "\n",
    "                # The below lines count N_R from definition 6.2 of our paper: each query is weighted\n",
    "                # by 1 / (number of queries that share its bucket).\n",
    "                w = [[[1/rep_dict[vec] for vec in vecs]] for vecs in query_vecs]\n",
    "\n",
    "                out = np.array(out).squeeze()\n",
    "                w = np.array(w).squeeze()\n",
    "\n",
    "                ##============================================================================##\n",
    "                ##============ This marks the end of Receiver's non-training work ============##\n",
    "                ##============================================================================##\n",
    "\n",
    "                # The lines below only estimate how many of the n_ab joining labels get flipped.\n",
    "                result = np.argmax(out,axis = 1)\n",
    "                tru = 0\n",
    "                fal = 0\n",
    "                for i in range(vectors.size):\n",
    "                    if int(result[i]) == int(train_dataset.targets[i]):\n",
    "                        tru  += 1\n",
    "                    else:\n",
    "                        fal += 1\n",
    "                print(\"For epsilon\", eps,\" and d \", d, \"We have true: \", tru, \"and false: \", fal)\n",
    "\n",
    "                # Now start training\n",
    "\n",
    "                labels_raw = torch.tensor(out).cuda()\n",
    "                weights_raw = torch.tensor(w).cuda()\n",
    "                batch_size = 2048\n",
    "\n",
    "                model_raw = netter(input_size, hidden_size)\n",
    "                model_raw = model_raw.cuda()\n",
    "\n",
    "                loss_raw = nn.CrossEntropyLoss()  # NOTE(review): unused; the loss is computed by hand below\n",
    "\n",
    "                num_epochs = total_epochs\n",
    "                optimizer_raw = torch.optim.AdamW(model_raw.parameters(), lr = 1e-5)\n",
    "\n",
    "                for epoch in range(num_epochs):\n",
    "                    for i in range(labels_raw.shape[0] // batch_size):\n",
    "                        # Scale uint8 pixels to [0, 1], matching the ToTensor() inputs used at test time,\n",
    "                        # so BatchNorm's running statistics are valid when the model is in eval mode.\n",
    "                        images = image_data[i*batch_size:(i+1)*batch_size,:,:].float() / 255.0\n",
    "                        labels_for_model = labels_raw[i*batch_size:(i+1)*batch_size,:]\n",
    "                        weights_for_model = weights_raw[i*batch_size:(i+1)*batch_size,:]\n",
    "\n",
    "                        images = images.reshape(-1,28*28)\n",
    "                        outputs = model_raw(images)\n",
    "                        labels_final = torch.mul(labels_for_model, weights_for_model)\n",
    "\n",
    "                        # Soft-label cross-entropy; probabilities are floored at 1e-4 so log() stays finite.\n",
    "                        out_norm = torch.nn.functional.softmax(outputs,dim=1)\n",
    "                        out_clip = torch.where(out_norm > 0.0001, out_norm, 0.0001)\n",
    "                        out_log = torch.log(out_clip)\n",
    "                        l = - torch.mean( torch.sum(labels_final * out_log,dim=1) )\n",
    "\n",
    "                        optimizer_raw.zero_grad()\n",
    "                        l.backward()\n",
    "                        optimizer_raw.step()\n",
    "\n",
    "                        if epoch % epoch_log == 0 and i == 0:\n",
    "                            model_raw.eval()  # evaluate the model under test: no dropout, BatchNorm running stats\n",
    "                            with torch.no_grad():\n",
    "                                nc = 0\n",
    "                                ns = 0\n",
    "                                for images_test, labels_test in test_loader:\n",
    "                                    images_test = images_test.reshape(-1,28*28).cuda()\n",
    "                                    outputs_test = model_raw(images_test)\n",
    "                                    _, predicted = torch.max(outputs_test,1)\n",
    "                                    ns += outputs_test.shape[0]\n",
    "                                    nc += (predicted==labels_test.cuda()).sum().item()\n",
    "                            model_raw.train()\n",
    "                            acc= nc/ns\n",
    "                            print(f\"For {epoch}, {eps}, {d}, {n_a} the test accuracy is: {acc} \")\n",
    "                            print(f\"The time it took for {epoch_log} epochs: {time.time() - t0} \\n\")\n",
    "                            t0 = time.time()\n",
    "                            results_dict[d,eps,epoch,n_a, run] = [acc, tru, fal]\n",
    "\n",
    "                # Checkpoint all results gathered so far after every (d, n_a) setting.\n",
    "                with open(\"different_join_ours\",\"wb\") as f:\n",
    "                    pickle.dump(results_dict,f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adfd1f5a",
   "metadata": {},
   "source": [
    "# Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bafd706",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the result at `stop_epoch` for every (eps, hash_dimension, n_a, run) combination.\n",
    "# results_dict keys are (d, eps, epoch, n_a, run) and values are [acc, tru, fal];\n",
    "# res[eps, d, n_a, run] = (test accuracy, epoch). Uses n_ab and v from the experiment cell.\n",
    "res = {}\n",
    "stop_epoch = 200\n",
    "for run in range(0,6):\n",
    "    for eps in [1]:\n",
    "        for n_a in range(0, 220_000, 20_000):\n",
    "            for b in [int(n_ab*100*v), int(n_ab*10*v),int(n_ab*1*v),int(n_ab*0.1*v)]:\n",
    "                # Exact-key lookup, so every component (including eps) must match.\n",
    "                key = (b, eps, stop_epoch, n_a, run)\n",
    "                if key in results_dict:\n",
    "                    arr = (results_dict[key][0], stop_epoch)\n",
    "                    res[eps, b, n_a, run] = arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165342db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spread of the test accuracy across runs for the single hash dimension d = n_ab*1*v.\n",
    "# res_avg[eps, d, n_a] = [max, min]; res_err[eps, d, n_a] = [25th, 50th, 75th percentile].\n",
    "# This plot has not been included in the paper, but it is helpful nevertheless.\n",
    "res_avg = {}\n",
    "res_err = {}\n",
    "for eps in [1]:\n",
    "    for n_a in range(0, 220_000, 20_000):\n",
    "        for b in [int(n_ab*1*v)]:\n",
    "            arr = [acc for (key_eps, key_b, key_n_a, _), (acc, _) in res.items()\n",
    "                   if key_n_a == n_a and key_eps == eps and key_b == b]\n",
    "            if arr:\n",
    "                res_avg[eps, b, n_a] = [max(arr), min(arr)]\n",
    "                res_err[eps, b, n_a] = [np.percentile(arr, q) for q in (25, 50, 75)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bfbd923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One value per res_avg entry (insertion order, i.e. increasing n_a) for the quantile plot.\n",
    "q100 = np.array([res_avg[k][0] for k in res_avg], dtype=float)  # max over runs\n",
    "q0 = np.array([res_avg[k][1] for k in res_avg], dtype=float)    # min over runs\n",
    "q50 = np.array([res_err[k][1] for k in res_avg], dtype=float)   # median\n",
    "q75 = np.array([res_err[k][2] for k in res_avg], dtype=float)   # 75th percentile\n",
    "q25 = np.array([res_err[k][0] for k in res_avg], dtype=float)   # 25th percentile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56619498",
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis values: |D_R| / |D_R ⋈ D_S|, where D_R holds the n_ab joining ids plus n_a extra ids\n",
    "# (n_a = 0, 20K, ..., 200K, the same grid as the analysis cells above).\n",
    "join = n_ab\n",
    "total = n_ab + np.arange(0, 200_001, 20_000)\n",
    "ratio = total / join"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83da07cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max / quartiles / min of the test accuracy across runs vs. |D_R| / |D_R ⋈ D_S| (log x-axis).\n",
    "fig = figure(figsize=(8*1.2, 4*1.2), dpi=200)\n",
    "ax = fig.gca()\n",
    "ax.plot(ratio, q100, \":k\", label=\"max\")\n",
    "ax.errorbar(ratio, q50, yerr=(q50 - q25, q75 - q50), fmt=\"k\", label=\"0.25/0.5/0.75 quantile\")\n",
    "ax.plot(ratio, q0, \":k\", label=\"min\")\n",
    "ax.set_xscale(\"log\")\n",
    "ax.legend()\n",
    "fig.text(0.5, 0.9, r\"EMNIST test accuracy vs. Join size, d = $k \\cdot|D_R ⋈ D_S|$\", wrap=True, horizontalalignment='center', fontsize=14)\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax.set_ylabel(\"Test accuracy\", size=12)\n",
    "ax.set_xlabel(\"$|D_R|/ |D_R ⋈ D_S| $\", size=12)\n",
    "plt.rc('xtick', labelsize=12)\n",
    "plt.rc('ytick', labelsize=12)\n",
    "fig.savefig('change_with_na.png', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d07a601",
   "metadata": {},
   "source": [
    "# Multi-d join sizes: accuracy vs. join size for several sketch dimensions d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3ec6106",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean (and quartiles) of the test accuracy across runs, now for every hash dimension d;\n",
    "# each d becomes one colored line in the plot below.\n",
    "res_avg = {}\n",
    "res_err = {}\n",
    "for eps in [1]:\n",
    "    for n_a in range(0, 220_000, 20_000):\n",
    "        for b in [int(n_ab*100*v), int(n_ab*10*v), int(n_ab*1*v), int(n_ab*0.1*v)]:\n",
    "            arr = [acc for (key_eps, key_b, key_n_a, _), (acc, _) in res.items()\n",
    "                   if key_n_a == n_a and key_eps == eps and key_b == b]\n",
    "            if arr:\n",
    "                res_avg[eps, b, n_a] = np.mean(arr)\n",
    "                res_err[eps, b, n_a] = [np.percentile(arr, q) for q in (25, 50, 75)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac0edfad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean test accuracy vs. |D_R| / |D_R ⋈ D_S|, one line per hash dimension d.\n",
    "# NOTE(review): relies on `eps` left over from the loops in earlier cells (eps = 1 there).\n",
    "figure(figsize=(8*1.2, 4*1.2), dpi = 200)\n",
    "plt.xscale('log', base=10)\n",
    "for b in [int(n_ab*100*v), int(n_ab*10*v),int(n_ab*1*v),int(n_ab*0.1*v)]:\n",
    "    x = [(n_a + n_ab) / n_ab for n_a in range(0, 200_001, 20_000)]\n",
    "    y = [res_avg[eps, b, n_a] for n_a in range(0, 200_001, 20_000)]\n",
    "    # b = n_ab * m * v, so b // n_ab // v recovers the multiplier m (integer division gives 0 for m = 0.1).\n",
    "    dimer = b // n_ab // v\n",
    "    if dimer == 0:\n",
    "        dimer = 0.1\n",
    "    plt.plot(x, y, label=str(dimer) + r\" $ \\cdot k |D_R$ ⋈ $D_S| $\")\n",
    "\n",
    "# Accuracy of the model trained without privacy (hard-coded; compare with the baseline evaluation cell).\n",
    "# Drawn before legend() so that its label is included in the legend.\n",
    "plt.axhline(y=0.9804, color='k', linestyle='--', label = \"w/o privacy\")\n",
    "plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=2.0))\n",
    "custom_formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.2g}')\n",
    "plt.gca().xaxis.set_major_formatter(custom_formatter)\n",
    "plt.legend(fontsize = 12)\n",
    "plt.ylabel(\"Test accuracy\", size = 14)\n",
    "plt.figtext(0.5, 0.9, r\"EMNIST test accuracy vs. Join size. $\\varepsilon$ = 1. $k=10$, |$D_R ⋈ D_S$| = 20K\", wrap=True, horizontalalignment='center', fontsize=13)\n",
    "plt.xlabel(\"$|D_R|$ / $|D_R$ ⋈ $D_S|$ \", size = 20)\n",
    "plt.rc('xtick', labelsize=14)\n",
    "plt.rc('ytick', labelsize=14)\n",
    "\n",
    "plt.savefig('emnist_join.png',bbox_inches = \"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccf806c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e376ba",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91170a7b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
