{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff2b0e1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_parties = 3 # Set the number of answering parties.\n",
    "index = 6 # Set the index of the mnist test set to use as the query (index of a sample)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "685afe1a",
   "metadata": {},
   "source": [
    "#### Imports and helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "87590b3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/dockuser/code/he-transformer/build/ext_ngraph_tf/src/ext_ngraph_tf/build_cmake/venv-tf-py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "This script assumes that a subdir with name {n_parties} exists in /models with the model \n",
    "files stored here.\n",
    "The number of model files should equal the value of {n_parties} + 1.\n",
    "It kicks off a server for each answering party and a single client who will be \n",
    "requesting queries.\n",
    "client.py holds the clients training protocol, and server.py the response algorithms.\n",
    "train_inits.py should be run first to train each model on a separate partition and save \n",
    "them as per the required scheme.\n",
    "USAGE: call this file with: \n",
    "OMP_NUM_THREADS=24 NGRAPH_HE_VERBOSE_OPS=all NGRAPH_HE_LOG_LEVEL=3 python run_experiment.py\n",
    "SETUP: create a tmux session with 3 panes, each in /home/dockuser/code/capc\n",
    "\"\"\"\n",
    "\n",
    "import warnings\n",
    "from utils import client_data\n",
    "from utils.client_data import get_data\n",
    "from utils.time_utils import get_timestamp, log_timing\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "import tensorflow as tf\n",
    "\n",
    "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n",
    "import argparse\n",
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import atexit\n",
    "import libtmux\n",
    "from utils.remove_files import remove_files_by_name\n",
    "import consts\n",
    "from consts import out_client_name, out_server_name, out_final_name\n",
    "import getpass\n",
    "import get_r_star\n",
    "import subprocess\n",
    "import os\n",
    "import client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "33759015",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_FLAGS():\n",
    "    \"\"\"Initial setup of parameters to be used.\"\"\"\n",
    "    parser = argparse.ArgumentParser('')\n",
    "    parser.add_argument('--session', type=str, help='session name',\n",
    "                        default='capc')\n",
    "    parser.add_argument('--log_timing_file', type=str,\n",
    "                        help='name of the global log timing file',\n",
    "                        default=f'logs/log-timing-{get_timestamp()}.log')\n",
    "    parser.add_argument('--n_parties', type=int, default=n_parties,\n",
    "                        help='number of servers')\n",
    "    parser.add_argument('--seed', type=int, default=2,\n",
    "                        help='seed for top level script')\n",
    "    parser.add_argument('--batch_size', type=int, default=1,\n",
    "                        help='batch size')\n",
    "    parser.add_argument('--num_classes', type=int, default=10,\n",
    "                        help='Number of classes in the dataset.')\n",
    "    parser.add_argument(\n",
    "        \"--rstar_exp\",\n",
    "        type=int,\n",
    "        default=10,\n",
    "        help='The exponent for 2 to generate the random r* from.',\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--max_logit\",\n",
    "        type=float,\n",
    "        default=36.0,\n",
    "        help='The maximum value of a logit.',\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--user\",\n",
    "        type=str,\n",
    "        default=getpass.getuser(),\n",
    "        help=\"The name of the OS USER.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--log_level\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "        help='log level for he-transformer',\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        '--round_exp',\n",
    "        type=int,\n",
    "        default=3,\n",
    "        help='Multiply r* and logits by 2^round_exp.'\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        '--num_threads',\n",
    "        type=int,\n",
    "        default=20,\n",
    "        help='Number of threads.',\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        '--qp_id', type=int, default=0, help='which model is the QP?')\n",
    "    parser.add_argument(\n",
    "        \"--start_batch\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "        help=\"Test data start index\")\n",
    "    parser.add_argument(\n",
    "        \"--model_type\",\n",
    "        type=str,\n",
    "        default='cryptonets-relu',\n",
    "        help=\"The type of models used.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--input_node\",\n",
    "        type=str,\n",
    "        default=\"import/input:0\",\n",
    "        help=\"Tensor name of data input\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--output_node\",\n",
    "        type=str,\n",
    "        default=\"import/output/BiasAdd:0\",\n",
    "        help=\"Tensor name of model output\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        '--dataset_path', type=str,\n",
    "        default='/home/dockuser/queries',\n",
    "        help='where the queries are.')\n",
    "    parser.add_argument(\n",
    "        '--dataset_name', type=str,\n",
    "        default='mnist',\n",
    "        help='name of dataset where queries came from')\n",
    "    parser.add_argument('--debug', default=False, action='store_true')\n",
    "    parser.add_argument('--start_port', type=int, default=37000,\n",
    "                    help='the number of the starting port')\n",
    "    parser.add_argument('--n_queries',\n",
    "                        type=int,\n",
    "                        default=1,\n",
    "                        help='total len(queries)')\n",
    "    parser.add_argument('--checkpoint_dir', type=str,\n",
    "                        default='/home/dockuser/checkpoints',\n",
    "                        help='dir with all checkpoints')\n",
    "    parser.add_argument('--cpu', default=False, action='store_true',\n",
    "                        help='set to use cpu and no encryption.')\n",
    "    parser.add_argument('--ignore_parties', default=True, action='store_true', #False\n",
    "                        help='set when using crypto models.')\n",
    "    # parser.add_argument('--',\n",
    "    #                     default='$HE_TRANSFORMER/configs/he_seal_ckks_config_N13_L5_gc.json')\n",
    "    parser.add_argument('--encryption_params',\n",
    "                        default='config/10.json')\n",
    "    FLAGS, unparsed = parser.parse_known_args()\n",
    "    if unparsed:\n",
    "        print(\"Unparsed flags:\", unparsed)\n",
    "        exit(1)\n",
    "    return FLAGS\n",
    "\n",
    "def clean_old_files():\n",
    "    \"\"\"Delete old data files. This function is called before running the protocol.\"\"\"\n",
    "    for name in [out_client_name,\n",
    "                 out_server_name,\n",
    "                 out_final_name,\n",
    "                 consts.input_data,\n",
    "                 consts.input_labels,\n",
    "                 consts.predict_labels, consts.label_final_name]:\n",
    "        remove_files_by_name(starts_with=name)\n",
    "\n",
    "\n",
    "# Provide data.\n",
    "def set_data_labels(FLAGS):\n",
    "    \"\"\"Gets MNIST data and labels, saving it in the local folder\"\"\"\n",
    "    data, labels = get_data(start_batch=FLAGS.start_batch,\n",
    "                            batch_size=FLAGS.batch_size)\n",
    "    np.save(consts.input_data, data)\n",
    "    np.save(consts.input_labels, labels)\n",
    "\n",
    "\n",
    "def get_models(model_dir, n_parties, ignore_parties):\n",
    "    \"\"\"Gets model files from model_dir.\"\"\"\n",
    "    model_files = [f for f in os.listdir(model_dir) if\n",
    "                   os.path.isfile(os.path.join(model_dir, f))]\n",
    "    if len(model_files) != n_parties and not ignore_parties:\n",
    "        raise ValueError(\n",
    "            f'{len(model_files)} models found when {n_parties + 1} parties requested. Not equal.')\n",
    "    return model_dir, model_files\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69b9e575",
   "metadata": {},
   "source": [
    "#### Initial setup for CaPC protocol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b9552a76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unparsed flags: ['-f', '/home/dockuser/.local/share/jupyter/runtime/kernel-33b6b4db-5ebc-41ed-895c-fa235191cc1b.json']\n",
      "delete file: files/logits37000privacy.txt\n",
      "delete file: files/output37000.txt\n",
      "delete file: files/noise37000.txt\n",
      "delete file: files/output.txt\n",
      "delete file: files/noise.txt\n",
      "delete file: files/inference_times\n",
      "delete file: files/argmax_times\n",
      "delete file: files/client_csp_times\n",
      "delete file: files/inference_no_network_times\n"
     ]
    }
   ],
   "source": [
    "FLAGS = get_FLAGS()\n",
    "np.random.seed(FLAGS.seed)\n",
    "clean_old_files()\n",
    "set_data_labels(FLAGS=FLAGS)\n",
    "\n",
    "\n",
    "log_timing_file = FLAGS.log_timing_file\n",
    "log_timing('main: start capc', log_file=log_timing_file)\n",
    "n_parties = FLAGS.n_parties\n",
    "batch_size = FLAGS.batch_size\n",
    "num_classes = FLAGS.num_classes\n",
    "rstar_exp = FLAGS.rstar_exp\n",
    "log_level = FLAGS.log_level\n",
    "round_exp = FLAGS.round_exp\n",
    "num_threads = FLAGS.num_threads\n",
    "input_node = FLAGS.input_node\n",
    "output_node = FLAGS.output_node\n",
    "start_port = FLAGS.start_port\n",
    "index = FLAGS.start_batch\n",
    "backend = 'HE_SEAL' if not FLAGS.cpu else 'CPU'\n",
    "models_loc, model_files = get_models(\n",
    "    FLAGS.checkpoint_dir, n_parties=n_parties,\n",
    "    ignore_parties=FLAGS.ignore_parties)\n",
    "for port in range(37000, 37000 + n_parties): \n",
    "    files_to_delete = [consts.out_client_name + str(port) + 'privacy.txt']\n",
    "    files_to_delete += [consts.out_final_name + str(port) + '.txt' ]#+ 'privacy.txt']\n",
    "    files_to_delete += [consts.out_server_name + str(port) + '.txt']#+ 'privacy.txt']\n",
    "    files_to_delete += [f\"{out_final_name}.txt\",\n",
    "                        f\"{out_server_name}.txt\"]  # aggregates across all parties\n",
    "    files_to_delete += [consts.inference_times_name,\n",
    "                        consts.argmax_times_name,\n",
    "                        consts.client_csp_times_name,\n",
    "                        consts.inference_no_network_times_name]\n",
    "    for f in files_to_delete:\n",
    "        if os.path.exists(f):\n",
    "            print(f'delete file: {f}')\n",
    "            os.remove(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f261f7b",
   "metadata": {},
   "source": [
    "#### Step 1 of Protocol. The files server.py and client.py together complete Step 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff81bd25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "port: 37000\n",
      "run_exp rstar: [ -95.08244041 -934.90307722  137.7087547   -96.45973992 -127.08674132\n",
      " -311.47428658 -568.87959749  280.26693909 -374.30722831 -441.53774059]\n",
      "port: 37000\n",
      "Start the servers (answering parties: APs)\n",
      "port: 37000\n",
      "port: 37001\n",
      "run_exp rstar: [ 284.08208951   95.68300908 -712.38027193   63.80799235 -610.26715516\n",
      "  620.36638273  760.94139933   24.19704296  745.75792201 -824.88606309]\n",
      "port: 37001\n",
      "Start the servers (answering parties: APs)\n",
      "port: 37001\n",
      "port: 37002\n",
      "run_exp rstar: [  46.74399257 -854.29323902 -111.20547308 -790.30468473 -727.57637796\n",
      "  234.13439279 -525.12742276 -768.97523853 -536.81288792 -271.55576831]\n",
      "port: 37002\n",
      "Start the servers (answering parties: APs)\n",
      "port: 37002\n"
     ]
    }
   ],
   "source": [
    "for query_num in range(FLAGS.n_queries):  #Querying process\n",
    "    for port, model_file in zip(\n",
    "            [37000 + int(i + query_num * n_parties) for i in\n",
    "             range(n_parties)],\n",
    "            model_files):\n",
    "        print(f\"port: {port}\")\n",
    "        full_model_file = fr'{models_loc}/{model_file}'\n",
    "        full_model_file_new = \"\"\n",
    "        for s in full_model_file:\n",
    "            if s == '(' or s == ')':\n",
    "                full_model_file_new += \"\\\\\"\n",
    "            full_model_file_new += s\n",
    "        full_model_file = full_model_file_new\n",
    "        new_model_file = os.path.join(\"/home/dockuser/models\",\n",
    "                                      str(port) + \".pb\")\n",
    "        r_star = get_r_star.get_rstar_server(  # Generate random vector needed in Step 1a\n",
    "            max_logit=FLAGS.max_logit,\n",
    "            batch_size=batch_size,\n",
    "            num_classes=num_classes,\n",
    "            exp=FLAGS.rstar_exp,\n",
    "        ).flatten()\n",
    "        print(f\"run_exp rstar: {r_star}\")\n",
    "        print(f\"port: {port}\")\n",
    "        print('Start the servers (answering parties: APs)')\n",
    "        log_timing('start server (AP)', log_file=log_timing_file)\n",
    "        cmd_string = \" \".join([  # Command to start server with the relevant parameters. \n",
    "                                  f'OMP_NUM_THREADS={num_threads}',\n",
    "                                  f'NGRAPH_HE_LOG_LEVEL={log_level}',\n",
    "                                  'python -W ignore', 'server.py',\n",
    "                                  '--backend', backend,\n",
    "                                  '--n_parties', f'{n_parties}',   \n",
    "                                  '--model_file', new_model_file,\n",
    "                                  '--dataset_name', FLAGS.dataset_name,\n",
    "                                  '--indext', str(index),\n",
    "                                  '--encryption_parameters',\n",
    "                                  FLAGS.encryption_params,\n",
    "                                  '--enable_client', 'true',\n",
    "                                  '--enable_gc', 'true',\n",
    "                                  '--mask_gc_inputs', 'true',\n",
    "                                  '--mask_gc_outputs', 'true',\n",
    "                                  '--from_pytorch', '1', '--dataset_name',\n",
    "                                  FLAGS.dataset_name,\n",
    "                                  '--dataset_path', FLAGS.dataset_path,\n",
    "                                  '--num_gc_threads', f'{num_threads}',\n",
    "                                  '--input_node', f'{input_node}',\n",
    "                                  '--output_node', f'{output_node}',\n",
    "                                  '--minibatch_id', f'{query_num}',\n",
    "                                  '--rstar_exp', f'{rstar_exp}',\n",
    "                                  '--num_classes', f'{num_classes}',\n",
    "                                  '--round_exp', f'{round_exp}',\n",
    "                                  '--log_timing_file', log_timing_file,\n",
    "                                  \"--r_star\"] + [str(x) for x in r_star] + [\n",
    "                                  '--port', f'{port}',\n",
    "                              ])\n",
    "        subprocess.Popen(cmd_string, shell = True)  # Run server.py with the given parameters. \n",
    "        if not FLAGS.cpu:\n",
    "            time.sleep(1)\n",
    "            print(f\"port: {port}\")\n",
    "            log_timing('start cleint (the querying party: QP)', log_file=log_timing_file)\n",
    "            cmd_string = \" \".join([\n",
    "                        # Command to start client server with the relevant parameters.\n",
    "                        f'OMP_NUM_THREADS={num_threads}',\n",
    "                        f'NGRAPH_HE_LOG_LEVEL={log_level}',\n",
    "                        'python -W ignore client.py',\n",
    "                        '--batch_size', f'{batch_size}',\n",
    "                        '--encrypt_data_str', 'encrypt',\n",
    "                        '--indext', str(index),\n",
    "                        '--n_parties', f'{n_parties}',\n",
    "                        '--round_exp', f'{round_exp}',\n",
    "                        '--from_pytorch', '1',\n",
    "                        '--minibatch_id', f'{query_num}',\n",
    "                        '--dataset_path', f'{FLAGS.dataset_path}',\n",
    "                        '--port', f'{port}',\n",
    "                        '--dataset_name', FLAGS.dataset_name,\n",
    "                        '--data_partition', 'test',\n",
    "                        '--log_timing_file', log_timing_file,\n",
    "                    ])\n",
    "            subprocess.Popen(cmd_string, shell=True)  # Run client_server.py with the given parameters. \n",
    "            time.sleep(16) \n",
    "        else:\n",
    "            time.sleep(1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ca0d3e0",
   "metadata": {},
   "source": [
    "#### Steps 2 and 3 of the Protocol. The file pg.py runs the privacy guardian. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7b71146a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start privacy guardian: python -W ignore pg.py 37000 37003\n"
     ]
    }
   ],
   "source": [
    "log_timing('start privacy guardian', log_file=log_timing_file)\n",
    "# Command to run Privacy Guardian (Steps 2 & 3).\n",
    "cmd_string = \" \".join(\n",
    "    ['python -W ignore', 'pg.py',\n",
    "     f'{start_port + int(query_num * n_parties)}',\n",
    "     f'{start_port + int(query_num * n_parties) + n_parties}'\n",
    "     ])\n",
    "print(f\"start privacy guardian: {cmd_string}\")\n",
    "subprocess.Popen(cmd_string, shell=True)  # Run pg.py with the given parameters.\n",
    "time.sleep(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed27f686",
   "metadata": {},
   "source": [
    "#### Compare predicted label with actual label. The client (querying party) print the outputted label. The actual label is manually found using the index of the query used"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ff0bc27f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted label:  7\n",
      "Actual label 7\n"
     ]
    }
   ],
   "source": [
    "client.print_label()\n",
    "(x_train, y_train, x_test, y_test) = client_data.load_mnist_data(index, 1)\n",
    "print(\"Actual label: \", np.argmax(y_test)) \n",
    "log_timing('finish capc', log_file=log_timing_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
