{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ac754fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull in the network classes (NNs) and data/training helpers. This notebook\n",
    "# relies on them to also provide torch, np and device (none are imported here).\n",
    "%run NNs.ipynb\n",
    "%run Helpers.ipynb\n",
    "\n",
    "# Seed both RNGs so weight initialisation and any sampling below are reproducible.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca124b90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Show which device the networks and tensors will be placed on.\n",
    "print(str(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38b59e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2  # number of bidders\n",
    "m = 2  # number of items\n",
    "\n",
    "# Depth (number of Linear layers, e.g. 7 -> fc1..fc7) and hidden width of each MLP.\n",
    "nLayersAllocation   = 7\n",
    "nLayersPayment      = 7\n",
    "nLayersMisreport    = 7\n",
    "widthAllocation     = 100\n",
    "widthPayment        = 100\n",
    "widthMisreport      = 100\n",
    "\n",
    "# Each network needs at least 2 layers; fail fast here instead of inside the model builders.\n",
    "assert min(nLayersAllocation, nLayersPayment, nLayersMisreport) >= 2, 'every MLP needs at least 2 layers'\n",
    "\n",
    "batch_size = 20000  # size argument passed to get_train_or_test_set below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "021a570a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AllocationNet(\n",
      "  (mlp1): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (mlp2): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "PaymentNet(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "MisreportNetBNIC(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Build the allocation, payment and misreport networks on the target device in\n",
    "# training mode. nn.Module.to() and .train() both return the module itself, so the\n",
    "# calls chain; construction order (and hence RNG use for weight init) is unchanged.\n",
    "ANet = AllocationNet(n, m, nLayersAllocation, widthAllocation).to(device).train()\n",
    "PNet = PaymentNet(n, m, nLayersPayment, widthPayment).to(device).train()\n",
    "MNet = MisreportNetBNIC(n, m, nLayersMisreport, widthMisreport).to(device).train()\n",
    "\n",
    "print(ANet, PNet, MNet, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e10ea38c",
   "metadata": {},
   "source": [
    "# Set up datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da032379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probability vectors passed to get_train_or_test_set (presumably one entry per bidder, n = 2).\n",
    "# Tensors are created directly on the device instead of built on CPU and then copied.\n",
    "# NOTE(review): the last three are int64 while the first three are float32;\n",
    "# confirm get_train_or_test_set treats both dtypes the same way.\n",
    "prob_single_1 = torch.tensor([0.2, 0.8], device=device)\n",
    "prob_single_2 = torch.tensor([0.5, 0.7], device=device)\n",
    "prob_single_3 = torch.tensor([0.7, 0.9], device=device)\n",
    "no_prob_single = torch.tensor([1, 1], device=device)\n",
    "single_bidder_1 = torch.tensor([1, 0], device=device)\n",
    "single_bidder_2 = torch.tensor([0, 1], device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cb66a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_train_test_pair(probs):\n",
    "    \"\"\"Return (train_set, test_set) from two consecutive get_train_or_test_set calls with `probs`.\"\"\"\n",
    "    train = get_train_or_test_set(batch_size, n, m, probs)\n",
    "    test = get_train_or_test_set(batch_size, n, m, probs)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "# Keep this order: each call presumably consumes the seeded RNG, so reordering changes the data.\n",
    "train_set_prob_1, test_set_prob_1 = _make_train_test_pair(prob_single_1)\n",
    "train_set_prob_2, test_set_prob_2 = _make_train_test_pair(prob_single_2)\n",
    "train_set_prob_3, test_set_prob_3 = _make_train_test_pair(prob_single_3)\n",
    "train_set_no_prob, test_set_no_prob = _make_train_test_pair(no_prob_single)\n",
    "train_set_single_1, test_set_single_1 = _make_train_test_pair(single_bidder_1)\n",
    "train_set_single_2, test_set_single_2 = _make_train_test_pair(single_bidder_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "531955b6",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6435cdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step-decay schedule: 100 entries at each rate, dropping 10x at each step\n",
    "# (presumably one entry per training iteration of the loop below).\n",
    "learning_rates = [lr for lr in (0.0005, 0.00005, 0.000005) for _ in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a70c85f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 revenue:  0.1250172257423401 rgt:  0.036815743893384933 loss:  -0.124887615442276\n",
      "1 revenue:  0.13081888854503632 rgt:  0.03857725113630295 loss:  -0.12670066952705383\n",
      "2 revenue:  0.1368260681629181 rgt:  0.03887751325964928 loss:  -0.133848637342453\n",
      "3 revenue:  0.14395621418952942 rgt:  0.04069919511675835 loss:  -0.13697588443756104\n",
      "4 revenue:  0.15188676118850708 rgt:  0.04172016680240631 loss:  -0.1437510997056961\n",
      "5 revenue:  0.16129155457019806 rgt:  0.042852964252233505 loss:  -0.15174853801727295\n",
      "6 revenue:  0.17195504903793335 rgt:  0.04269567131996155 loss:  -0.16534952819347382\n",
      "7 revenue:  0.18498802185058594 rgt:  0.04479056969285011 loss:  -0.1736738383769989\n",
      "8 revenue:  0.19951313734054565 rgt:  0.047063183039426804 loss:  -0.1826651245355606\n",
      "9 revenue:  0.21312855184078217 rgt:  0.04819411411881447 loss:  -0.1939326822757721\n",
      "10 revenue:  0.22278998792171478 rgt:  0.04560153931379318 loss:  -0.21285948157310486\n",
      "11 revenue:  0.22934392094612122 rgt:  0.042320914566516876 loss:  -0.23085713386535645\n",
      "12 revenue:  0.23375731706619263 rgt:  0.036292996257543564 loss:  -0.25668418407440186\n",
      "13 revenue:  0.23553253710269928 rgt:  0.02903653122484684 loss:  -0.28587913513183594\n",
      "14 revenue:  0.23417504131793976 rgt:  0.020845595747232437 loss:  -0.3186904788017273\n",
      "15 revenue:  0.2307359278202057 rgt:  0.012152178213000298 loss:  -0.35796037316322327\n",
      "16 revenue:  0.22294728457927704 rgt:  0.006328152492642403 loss:  -0.38629454374313354\n",
      "17 revenue:  0.26184335350990295 rgt:  0.009346376173198223 loss:  -0.40568286180496216\n",
      "18 revenue:  0.2454826533794403 rgt:  0.004348003305494785 loss:  -0.4251740276813507\n",
      "19 revenue:  0.2889178991317749 rgt:  0.007721542380750179 loss:  -0.4419165253639221\n",
      "20 revenue:  0.26573044061660767 rgt:  0.003232955001294613 loss:  -0.45539766550064087\n",
      "21 revenue:  0.31094807386398315 rgt:  0.006755621172487736 loss:  -0.46867844462394714\n",
      "22 revenue:  0.28226426243782043 rgt:  0.002201416762545705 loss:  -0.48216384649276733\n",
      "23 revenue:  0.32989248633384705 rgt:  0.006265359465032816 loss:  -0.4889427125453949\n",
      "24 revenue:  0.29629623889923096 rgt:  0.001748394686728716 loss:  -0.500767707824707\n",
      "25 revenue:  0.3460848927497864 rgt:  0.004449670668691397 loss:  -0.5171335935592651\n",
      "26 revenue:  0.3081176280975342 rgt:  0.0010190417524427176 loss:  -0.5221404433250427\n",
      "27 revenue:  0.36106646060943604 rgt:  0.006297518964856863 loss:  -0.515233039855957\n",
      "28 revenue:  0.31831374764442444 rgt:  0.0005908363964408636 loss:  -0.5392930507659912\n",
      "29 revenue:  0.3740042746067047 rgt:  0.004458615556359291 loss:  -0.5403267741203308\n",
      "30 revenue:  0.32653510570526123 rgt:  0.00045796341146342456 loss:  -0.549572229385376\n",
      "31 revenue:  0.38630595803260803 rgt:  0.004747817292809486 loss:  -0.5478823184967041\n",
      "32 revenue:  0.3342645466327667 rgt:  0.00032726270728744566 loss:  -0.5597357749938965\n",
      "33 revenue:  0.39781299233436584 rgt:  0.005014326423406601 loss:  -0.5548973083496094\n",
      "34 revenue:  0.34195899963378906 rgt:  0.0003007463237736374 loss:  -0.567126989364624\n",
      "35 revenue:  0.4090770483016968 rgt:  0.0054968311451375484 loss:  -0.5599532127380371\n",
      "36 revenue:  0.349066823720932 rgt:  0.00023262960894498974 loss:  -0.5753307342529297\n",
      "37 revenue:  0.4197690188884735 rgt:  0.005386941600590944 loss:  -0.5691125392913818\n",
      "38 revenue:  0.355407178401947 rgt:  0.00020804144151043147 loss:  -0.5815252661705017\n",
      "39 revenue:  0.42913535237312317 rgt:  0.005234687123447657 loss:  -0.5774978399276733\n",
      "40 revenue:  0.36147600412368774 rgt:  0.00023834125022403896 loss:  -0.5855489373207092\n",
      "41 revenue:  0.4374412000179291 rgt:  0.003827754408121109 loss:  -0.5956960320472717\n",
      "42 revenue:  0.3660660982131958 rgt:  0.00024194683646783233 loss:  -0.5892342925071716\n",
      "43 revenue:  0.444490909576416 rgt:  0.006943313404917717 loss:  -0.5764310956001282\n",
      "44 revenue:  0.37095946073532104 rgt:  0.0003339449758641422 loss:  -0.5904536247253418\n",
      "45 revenue:  0.4506725072860718 rgt:  0.0053391046822071075 loss:  -0.5929125547409058\n",
      "46 revenue:  0.3740347623825073 rgt:  0.0003178406914230436 loss:  -0.5934351682662964\n",
      "47 revenue:  0.4556085765361786 rgt:  0.006652372423559427 loss:  -0.5867728590965271\n",
      "48 revenue:  0.37724021077156067 rgt:  0.0002387662825640291 loss:  -0.5985048413276672\n",
      "49 revenue:  0.46077483892440796 rgt:  0.005727185867726803 loss:  -0.597398042678833\n",
      "50 revenue:  0.3801646828651428 rgt:  0.0002820069785229862 loss:  -0.5994970202445984\n",
      "51 revenue:  0.4647887647151947 rgt:  0.005929008591920137 loss:  -0.5988245010375977\n",
      "52 revenue:  0.38209211826324463 rgt:  0.0003071716637350619 loss:  -0.6002997159957886\n",
      "53 revenue:  0.4678000807762146 rgt:  0.006623199209570885 loss:  -0.5959523320198059\n",
      "54 revenue:  0.3837945759296417 rgt:  0.0002861717657651752 loss:  -0.6023059487342834\n",
      "55 revenue:  0.47061917185783386 rgt:  0.006124116946011782 loss:  -0.6016354560852051\n",
      "56 revenue:  0.3854387104511261 rgt:  0.0003138216561637819 loss:  -0.602805495262146\n",
      "57 revenue:  0.4731296896934509 rgt:  0.006263724062591791 loss:  -0.6024362444877625\n",
      "58 revenue:  0.3869016766548157 rgt:  0.0002995991089846939 loss:  -0.6044028997421265\n",
      "59 revenue:  0.4754820168018341 rgt:  0.005789516028016806 loss:  -0.6076731085777283\n",
      "60 revenue:  0.3877546489238739 rgt:  0.00027341852546669543 loss:  -0.605887770652771\n",
      "61 revenue:  0.4774039089679718 rgt:  0.006841741036623716 loss:  -0.601387083530426\n",
      "62 revenue:  0.3887304365634918 rgt:  0.0002168353967135772 loss:  -0.6085370182991028\n",
      "63 revenue:  0.4791911840438843 rgt:  0.005557004362344742 loss:  -0.6121334433555603\n",
      "64 revenue:  0.3896526098251343 rgt:  0.0003111485275439918 loss:  -0.6062682867050171\n",
      "65 revenue:  0.4805257320404053 rgt:  0.0029742689803242683 loss:  -0.6356876492500305\n",
      "66 revenue:  0.3905518054962158 rgt:  0.00023043160035740584 loss:  -0.6095278263092041\n",
      "67 revenue:  0.4824735224246979 rgt:  0.004431820474565029 loss:  -0.6235986948013306\n",
      "68 revenue:  0.3912689983844757 rgt:  0.0002398995857220143 loss:  -0.6097832322120667\n",
      "69 revenue:  0.4836116433143616 rgt:  0.0057288953103125095 loss:  -0.6140029430389404\n",
      "70 revenue:  0.39170604944229126 rgt:  0.00021824891155119985 loss:  -0.6108694076538086\n",
      "71 revenue:  0.48432058095932007 rgt:  0.005444567184895277 loss:  -0.6166989207267761\n",
      "72 revenue:  0.3920624554157257 rgt:  0.000153097091242671 loss:  -0.6136186122894287\n",
      "73 revenue:  0.48551103472709656 rgt:  0.006104079075157642 loss:  -0.6125530004501343\n",
      "74 revenue:  0.3930724859237671 rgt:  0.00015494064427912235 loss:  -0.6143485903739929\n",
      "75 revenue:  0.48628219962120056 rgt:  0.00802227109670639 loss:  -0.5997495055198669\n",
      "76 revenue:  0.39368218183517456 rgt:  0.00014491478214040399 loss:  -0.6152539253234863\n",
      "77 revenue:  0.48685935139656067 rgt:  0.006136622279882431 loss:  -0.6132792830467224\n",
      "78 revenue:  0.3940216600894928 rgt:  0.0001970447920029983 loss:  -0.613473653793335\n",
      "79 revenue:  0.48734667897224426 rgt:  0.004422146361321211 loss:  -0.6271801590919495\n",
      "80 revenue:  0.3941650092601776 rgt:  0.00016258281539194286 loss:  -0.6149083375930786\n",
      "81 revenue:  0.4879821240901947 rgt:  0.008288156241178513 loss:  -0.5992292761802673\n",
      "82 revenue:  0.3945855498313904 rgt:  0.0002028317830991 loss:  -0.6137123107910156\n",
      "83 revenue:  0.48792392015457153 rgt:  0.007665663026273251 loss:  -0.603295624256134\n",
      "84 revenue:  0.39455750584602356 rgt:  0.0001596211368450895 loss:  -0.6153404712677002\n",
      "85 revenue:  0.48799949884414673 rgt:  0.004690160974860191 loss:  -0.6253939867019653\n",
      "86 revenue:  0.3944743275642395 rgt:  0.00013299632701091468 loss:  -0.616402268409729\n",
      "87 revenue:  0.48838597536087036 rgt:  0.004276912193745375 loss:  -0.6291704177856445\n",
      "88 revenue:  0.3944532573223114 rgt:  0.0001422926870873198 loss:  -0.6159800887107849\n",
      "89 revenue:  0.48863866925239563 rgt:  0.004684360232204199 loss:  -0.6258994936943054\n",
      "90 revenue:  0.3943139314651489 rgt:  0.00013198908709455281 loss:  -0.6163192987442017\n",
      "91 revenue:  0.48883816599845886 rgt:  0.007784753106534481 loss:  -0.6031531691551208\n",
      "92 revenue:  0.39474719762802124 rgt:  9.080329618882388e-05 loss:  -0.6186640858650208\n",
      "93 revenue:  0.4891149401664734 rgt:  0.005709443707019091 loss:  -0.6180965900421143\n",
      "94 revenue:  0.3951071798801422 rgt:  0.000128626634250395 loss:  -0.6171011328697205\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "95 revenue:  0.48916223645210266 rgt:  0.007072612643241882 loss:  -0.6082293391227722\n",
      "96 revenue:  0.3951667547225952 rgt:  0.000108072352304589 loss:  -0.6181142926216125\n",
      "97 revenue:  0.48918840289115906 rgt:  0.006137530785053968 loss:  -0.6149395108222961\n",
      "98 revenue:  0.39534348249435425 rgt:  0.00012708203576039523 loss:  -0.617358922958374\n",
      "99 revenue:  0.48916396498680115 rgt:  0.007020598277449608 loss:  -0.6085924506187439\n",
      "100 revenue:  0.3952955901622772 rgt:  0.00013502390356734395 loss:  -0.616966187953949\n",
      "101 revenue:  0.4050484001636505 rgt:  0.00011105569137725979 loss:  -0.6257801055908203\n",
      "102 revenue:  0.4147597551345825 rgt:  6.650287832599133e-05 loss:  -0.6357909440994263\n",
      "103 revenue:  0.4244755208492279 rgt:  0.00010293532250216231 loss:  -0.6412643790245056\n",
      "104 revenue:  0.4340401887893677 rgt:  0.00012732289906125516 loss:  -0.6474018692970276\n",
      "105 revenue:  0.44353923201560974 rgt:  0.00015731548774056137 loss:  -0.6532836556434631\n",
      "106 revenue:  0.45244529843330383 rgt:  0.00026345005608163774 loss:  -0.6561429500579834\n",
      "107 revenue:  0.4506569802761078 rgt:  0.000592848751693964 loss:  -0.6463665962219238\n",
      "108 revenue:  0.4416293799877167 rgt:  8.753113797865808e-05 loss:  -0.6551033854484558\n",
      "109 revenue:  0.451218843460083 rgt:  0.00010022481728810817 loss:  -0.6616119146347046\n",
      "110 revenue:  0.46072685718536377 rgt:  0.0002530330093577504 loss:  -0.6626055240631104\n",
      "111 revenue:  0.45698925852775574 rgt:  0.0001635299704503268 loss:  -0.6630545258522034\n",
      "112 revenue:  0.4663199186325073 rgt:  0.0009722549002617598 loss:  -0.6507214307785034\n",
      "113 revenue:  0.45725974440574646 rgt:  0.0001475686440244317 loss:  -0.663910448551178\n",
      "114 revenue:  0.46482816338539124 rgt:  0.0003520966856740415 loss:  -0.6626641750335693\n",
      "115 revenue:  0.4560641348361969 rgt:  0.0001389262906741351 loss:  -0.6633954048156738\n",
      "116 revenue:  0.4655381143093109 rgt:  0.0006593744037672877 loss:  -0.6559640169143677\n",
      "117 revenue:  0.4562176764011383 rgt:  0.0002497368259355426 loss:  -0.6593829393386841\n",
      "118 revenue:  0.449344664812088 rgt:  8.857334614731371e-05 loss:  -0.6608266234397888\n",
      "119 revenue:  0.4590246081352234 rgt:  0.00011459090455900878 loss:  -0.6666895747184753\n",
      "120 revenue:  0.46866005659103394 rgt:  0.0008773656445555389 loss:  -0.6540881991386414\n",
      "121 revenue:  0.4590858221054077 rgt:  0.00020461497479118407 loss:  -0.6630463004112244\n",
      "122 revenue:  0.4652634561061859 rgt:  0.00015748364967294037 loss:  -0.6693916320800781\n",
      "123 revenue:  0.47477519512176514 rgt:  0.00038855988532304764 loss:  -0.6689363718032837\n",
      "124 revenue:  0.46558886766433716 rgt:  0.00015517629799433053 loss:  -0.6697246432304382\n",
      "125 revenue:  0.4750915765762329 rgt:  0.002279957989230752 loss:  -0.6392390727996826\n",
      "126 revenue:  0.4654915928840637 rgt:  0.00029525504214689136 loss:  -0.6647883653640747\n",
      "127 revenue:  0.4566360116004944 rgt:  0.00012921029701828957 loss:  -0.6642478108406067\n",
      "128 revenue:  0.46630144119262695 rgt:  0.0008422023383900523 loss:  -0.6529980897903442\n",
      "129 revenue:  0.4569934010505676 rgt:  0.00014298361202236265 loss:  -0.6639081835746765\n",
      "130 revenue:  0.46665117144584656 rgt:  0.0008696637814864516 loss:  -0.6527572870254517\n",
      "131 revenue:  0.4570541977882385 rgt:  0.00011830119910882786 loss:  -0.6650583148002625\n",
      "132 revenue:  0.4667662978172302 rgt:  0.0005211917450651526 loss:  -0.6598500609397888\n",
      "133 revenue:  0.4569289982318878 rgt:  0.0001122772300732322 loss:  -0.6652522087097168\n",
      "134 revenue:  0.4667285084724426 rgt:  0.00042828262667171657 loss:  -0.6620497107505798\n",
      "135 revenue:  0.4571874439716339 rgt:  0.00016910946578718722 loss:  -0.6629793047904968\n",
      "136 revenue:  0.46685731410980225 rgt:  0.000453560525784269 loss:  -0.6615167856216431\n",
      "137 revenue:  0.4569600522518158 rgt:  0.0001070692014764063 loss:  -0.6655289530754089\n",
      "138 revenue:  0.4667956829071045 rgt:  0.00044629545300267637 loss:  -0.6616501808166504\n",
      "139 revenue:  0.456881046295166 rgt:  8.257997978944331e-05 loss:  -0.666754424571991\n",
      "140 revenue:  0.466766893863678 rgt:  0.00021965026098769158 loss:  -0.6681598424911499\n",
      "141 revenue:  0.46222546696662903 rgt:  0.00023498953669331968 loss:  -0.6643041372299194\n",
      "142 revenue:  0.4547441899776459 rgt:  0.0001385003561154008 loss:  -0.6624358892440796\n",
      "143 revenue:  0.4646075665950775 rgt:  0.00013613438932225108 loss:  -0.6698132753372192\n",
      "144 revenue:  0.4743645489215851 rgt:  0.0011309855617582798 loss:  -0.653978705406189\n",
      "145 revenue:  0.4645411968231201 rgt:  0.00020344305085018277 loss:  -0.6671023964881897\n",
      "146 revenue:  0.47415849566459656 rgt:  0.0005408684955909848 loss:  -0.6647921204566956\n",
      "147 revenue:  0.46449169516563416 rgt:  0.0001681979192653671 loss:  -0.6683951616287231\n",
      "148 revenue:  0.46234583854675293 rgt:  0.00010308184573659673 loss:  -0.6696993708610535\n",
      "149 revenue:  0.4723080098628998 rgt:  0.002070185262709856 loss:  -0.6396761536598206\n",
      "150 revenue:  0.46237948536872864 rgt:  0.00013724644668400288 loss:  -0.668128252029419\n",
      "151 revenue:  0.4722125232219696 rgt:  0.00014378664491232485 loss:  -0.6750382781028748\n",
      "152 revenue:  0.48182156682014465 rgt:  0.001439804444089532 loss:  -0.6547479033470154\n",
      "153 revenue:  0.47202029824256897 rgt:  0.0008235357818193734 loss:  -0.6575148105621338\n",
      "154 revenue:  0.46217307448387146 rgt:  0.00026991491904482245 loss:  -0.6631311774253845\n",
      "155 revenue:  0.4533877372741699 rgt:  0.0001446368987672031 loss:  -0.6611654758453369\n",
      "156 revenue:  0.46336254477500916 rgt:  0.00014589533384423703 loss:  -0.6684786677360535\n",
      "157 revenue:  0.47306278347969055 rgt:  0.0002663778141140938 loss:  -0.6712051033973694\n",
      "158 revenue:  0.4633508622646332 rgt:  0.0002940967387985438 loss:  -0.663252592086792\n",
      "159 revenue:  0.4559248983860016 rgt:  0.0001008448816719465 loss:  -0.6650741696357727\n",
      "160 revenue:  0.4659266471862793 rgt:  0.0001878587354440242 loss:  -0.6686906218528748\n",
      "161 revenue:  0.4756785035133362 rgt:  0.000994684174656868 loss:  -0.657159686088562\n",
      "162 revenue:  0.4657132625579834 rgt:  0.00018706686387304217 loss:  -0.6685640215873718\n",
      "163 revenue:  0.4578092098236084 rgt:  0.00013357936404645443 loss:  -0.6649205088615417\n",
      "164 revenue:  0.4675996005535126 rgt:  0.0005544534651562572 loss:  -0.6597092151641846\n",
      "165 revenue:  0.4578976035118103 rgt:  7.453020953107625e-05 loss:  -0.6679680347442627\n",
      "166 revenue:  0.46816951036453247 rgt:  0.0008488023304380476 loss:  -0.6542444229125977\n",
      "167 revenue:  0.4580465257167816 rgt:  7.810048555256799e-05 loss:  -0.6678702235221863\n",
      "168 revenue:  0.46810194849967957 rgt:  0.0008050652104429901 loss:  -0.654999315738678\n",
      "169 revenue:  0.45797809958457947 rgt:  0.00010712563380366191 loss:  -0.6662787795066833\n",
      "170 revenue:  0.4680527150630951 rgt:  0.00012227025581523776 loss:  -0.6729595065116882\n",
      "171 revenue:  0.4779979884624481 rgt:  0.0014317685272544622 loss:  -0.6521022915840149\n",
      "172 revenue:  0.46800497174263 rgt:  0.00025848124641925097 loss:  -0.6677700281143188\n",
      "173 revenue:  0.45834729075431824 rgt:  7.076348992995918e-05 loss:  -0.6685248017311096\n",
      "174 revenue:  0.46854275465011597 rgt:  0.0003912305983249098 loss:  -0.6643285751342773\n",
      "175 revenue:  0.45874080061912537 rgt:  6.746121653122827e-05 loss:  -0.6690171360969543\n",
      "176 revenue:  0.46893250942230225 rgt:  0.00019035946752410382 loss:  -0.6707955002784729\n",
      "177 revenue:  0.4743242561817169 rgt:  0.0006065130583010614 loss:  -0.6634760499000549\n",
      "178 revenue:  0.4642155170440674 rgt:  0.00011152274237247184 loss:  -0.6706570386886597\n",
      "179 revenue:  0.469226598739624 rgt:  0.0010213978821411729 loss:  -0.6520189642906189\n",
      "180 revenue:  0.4590069353580475 rgt:  0.00013525351823773235 loss:  -0.6657311320304871\n",
      "181 revenue:  0.46910569071769714 rgt:  0.00022780473227612674 loss:  -0.6695886850357056\n",
      "182 revenue:  0.45902881026268005 rgt:  0.00014220888260751963 loss:  -0.6654451489448547\n",
      "183 revenue:  0.46908095479011536 rgt:  0.00012634503946173936 loss:  -0.6735237836837769\n",
      "184 revenue:  0.47912895679473877 rgt:  0.0026398785412311554 loss:  -0.6381708979606628\n",
      "185 revenue:  0.46901413798332214 rgt:  0.0001782554609235376 loss:  -0.6713129281997681\n",
      "186 revenue:  0.4738103449344635 rgt:  0.00010482294601388276 loss:  -0.6779908537864685\n",
      "187 revenue:  0.4734478294849396 rgt:  0.000316820340231061 loss:  -0.6699565052986145\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "188 revenue:  0.463042676448822 rgt:  0.00011712752166204154 loss:  -0.6695281267166138\n",
      "189 revenue:  0.4732464849948883 rgt:  8.597670966992155e-05 loss:  -0.678565502166748\n",
      "190 revenue:  0.48335349559783936 rgt:  0.0009921554010361433 loss:  -0.6627440452575684\n",
      "191 revenue:  0.47317683696746826 rgt:  0.0009899453725665808 loss:  -0.6554235816001892\n",
      "192 revenue:  0.4628523290157318 rgt:  7.016606105025858e-05 loss:  -0.6718799471855164\n",
      "193 revenue:  0.473177045583725 rgt:  0.0011101134587079287 loss:  -0.6534487009048462\n",
      "194 revenue:  0.46283388137817383 rgt:  0.00011365916725480929 loss:  -0.6695395708084106\n",
      "195 revenue:  0.4730711877346039 rgt:  0.0004305716429371387 loss:  -0.6666185855865479\n",
      "196 revenue:  0.46300750970840454 rgt:  0.00010377957369200885 loss:  -0.6701506972312927\n",
      "197 revenue:  0.47334203124046326 rgt:  0.0006981847691349685 loss:  -0.6608753800392151\n",
      "198 revenue:  0.462974488735199 rgt:  9.621338540455326e-05 loss:  -0.6705121994018555\n",
      "199 revenue:  0.4732546806335449 rgt:  0.00021554724662564695 loss:  -0.6730347275733948\n",
      "200 revenue:  0.4627624750137329 rgt:  0.0002677984011825174 loss:  -0.6636311411857605\n",
      "201 revenue:  0.4618239998817444 rgt:  5.093328218208626e-05 loss:  -0.6723816990852356\n",
      "202 revenue:  0.4628726541996002 rgt:  0.00011301515041850507 loss:  -0.6695989370346069\n",
      "203 revenue:  0.46391239762306213 rgt:  9.895775292534381e-05 loss:  -0.6710594892501831\n",
      "204 revenue:  0.4649510085582733 rgt:  0.00015413828077726066 loss:  -0.6692997813224792\n",
      "205 revenue:  0.4659827947616577 rgt:  0.00011774537415476516 loss:  -0.6716560125350952\n",
      "206 revenue:  0.4670126736164093 rgt:  0.00011888150766026229 loss:  -0.6723565459251404\n",
      "207 revenue:  0.4666053056716919 rgt:  0.00028232630575075746 loss:  -0.665997326374054\n",
      "208 revenue:  0.46556931734085083 rgt:  0.00019155674090143293 loss:  -0.668290913105011\n",
      "209 revenue:  0.4646342694759369 rgt:  0.00010780093725770712 loss:  -0.6711455583572388\n",
      "210 revenue:  0.46561747789382935 rgt:  7.533903408329934e-05 loss:  -0.6736008524894714\n",
      "211 revenue:  0.4666614234447479 rgt:  0.00011063052079407498 loss:  -0.6724928021430969\n",
      "212 revenue:  0.4676944315433502 rgt:  7.889387052273378e-05 loss:  -0.6749151945114136\n",
      "213 revenue:  0.46872982382774353 rgt:  0.0003368200268596411 loss:  -0.6659463047981262\n",
      "214 revenue:  0.46773239970207214 rgt:  0.00016949593555182219 loss:  -0.6707172989845276\n",
      "215 revenue:  0.4669254720211029 rgt:  9.502594184596092e-05 loss:  -0.6734712719917297\n",
      "216 revenue:  0.46796083450317383 rgt:  0.00023186823818832636 loss:  -0.6686143279075623\n",
      "217 revenue:  0.4670957922935486 rgt:  5.7495199143886566e-05 loss:  -0.6757974624633789\n",
      "218 revenue:  0.4681415855884552 rgt:  6.923070031916723e-05 loss:  -0.6758130788803101\n",
      "219 revenue:  0.4691850543022156 rgt:  0.00032556045334786177 loss:  -0.6665992736816406\n",
      "220 revenue:  0.4681420922279358 rgt:  0.0004939617356285453 loss:  -0.6614876985549927\n",
      "221 revenue:  0.46710261702537537 rgt:  0.0001104906405089423 loss:  -0.6728224158287048\n",
      "222 revenue:  0.46813544631004333 rgt:  0.00037511048140004277 loss:  -0.6644588112831116\n",
      "223 revenue:  0.4671332836151123 rgt:  0.00016367726493626833 loss:  -0.6705103516578674\n",
      "224 revenue:  0.4681604504585266 rgt:  0.00010908866534009576 loss:  -0.6736641526222229\n",
      "225 revenue:  0.46919089555740356 rgt:  0.00025171320885419846 loss:  -0.668854832649231\n",
      "226 revenue:  0.46817123889923096 rgt:  0.0005779552739113569 loss:  -0.6596097350120544\n",
      "227 revenue:  0.4671286940574646 rgt:  7.957226625876501e-05 loss:  -0.6744627356529236\n",
      "228 revenue:  0.46817126870155334 rgt:  0.00047168845776468515 loss:  -0.6620380878448486\n",
      "229 revenue:  0.46713075041770935 rgt:  0.00016484477964695543 loss:  -0.6704617738723755\n",
      "230 revenue:  0.4661785960197449 rgt:  6.156865129014477e-05 loss:  -0.6748583316802979\n",
      "231 revenue:  0.4672211706638336 rgt:  0.00010805872443597764 loss:  -0.6730278730392456\n",
      "232 revenue:  0.46825867891311646 rgt:  0.00012014866661047563 loss:  -0.6732083559036255\n",
      "233 revenue:  0.46928972005844116 rgt:  9.157589374808595e-05 loss:  -0.6753810048103333\n",
      "234 revenue:  0.4703260064125061 rgt:  0.00012463452003430575 loss:  -0.67451012134552\n",
      "235 revenue:  0.469882607460022 rgt:  0.000721795076970011 loss:  -0.6578899621963501\n",
      "236 revenue:  0.46884340047836304 rgt:  9.929459338309243e-05 loss:  -0.6746524572372437\n",
      "237 revenue:  0.46987390518188477 rgt:  0.0005789712886326015 loss:  -0.6608306765556335\n",
      "238 revenue:  0.46883392333984375 rgt:  0.0002804719842970371 loss:  -0.6676837801933289\n",
      "239 revenue:  0.46792566776275635 rgt:  0.00016275193775072694 loss:  -0.6711268424987793\n",
      "240 revenue:  0.4688195288181305 rgt:  0.00013389762898441404 loss:  -0.6729944348335266\n",
      "241 revenue:  0.46801599860191345 rgt:  7.961964729474857e-05 loss:  -0.675108790397644\n",
      "242 revenue:  0.46905526518821716 rgt:  9.429214696865529e-05 loss:  -0.6750662922859192\n",
      "243 revenue:  0.4700935482978821 rgt:  0.00019047388923354447 loss:  -0.6716384291648865\n",
      "244 revenue:  0.4693814218044281 rgt:  0.0001991285680560395 loss:  -0.6708002090454102\n",
      "245 revenue:  0.46836259961128235 rgt:  0.00012738991063088179 loss:  -0.6729517579078674\n",
      "246 revenue:  0.4693850576877594 rgt:  0.0002120315912179649 loss:  -0.6703401207923889\n",
      "247 revenue:  0.4684104323387146 rgt:  9.504302579443902e-05 loss:  -0.6745560169219971\n",
      "248 revenue:  0.4694460332393646 rgt:  8.253625128418207e-05 loss:  -0.6759884357452393\n",
      "249 revenue:  0.4704851508140564 rgt:  8.053166675381362e-05 loss:  -0.6768592596054077\n",
      "250 revenue:  0.47152283787727356 rgt:  0.00030701563809998333 loss:  -0.6688435077667236\n",
      "251 revenue:  0.47049060463905334 rgt:  0.0006608189432881773 loss:  -0.6595540642738342\n",
      "252 revenue:  0.46945369243621826 rgt:  9.96689050225541e-05 loss:  -0.6750788688659668\n",
      "253 revenue:  0.47048789262771606 rgt:  8.959248953033239e-05 loss:  -0.6763610243797302\n",
      "254 revenue:  0.47151732444763184 rgt:  0.00011906840518349782 loss:  -0.6756357550621033\n",
      "255 revenue:  0.47254353761672974 rgt:  0.000514170853421092 loss:  -0.6642264127731323\n",
      "256 revenue:  0.4715050756931305 rgt:  0.00018085111514665186 loss:  -0.6730296611785889\n",
      "257 revenue:  0.4705527126789093 rgt:  0.00029629297205246985 loss:  -0.6684561371803284\n",
      "258 revenue:  0.4695289433002472 rgt:  0.0004037001053802669 loss:  -0.6647234559059143\n",
      "259 revenue:  0.4684816300868988 rgt:  6.486842903541401e-05 loss:  -0.6763321161270142\n",
      "260 revenue:  0.46952617168426514 rgt:  0.0007785030175000429 loss:  -0.6565379500389099\n",
      "261 revenue:  0.4684840142726898 rgt:  0.00021766126155853271 loss:  -0.669484555721283\n",
      "262 revenue:  0.46748486161231995 rgt:  7.065572572173551e-05 loss:  -0.6752464175224304\n",
      "263 revenue:  0.4685330092906952 rgt:  0.00028437181026674807 loss:  -0.6673441529273987\n",
      "264 revenue:  0.4674922823905945 rgt:  7.539878424722701e-05 loss:  -0.6749697327613831\n",
      "265 revenue:  0.46853622794151306 rgt:  0.00029988260939717293 loss:  -0.666877269744873\n",
      "266 revenue:  0.46748220920562744 rgt:  8.88480426510796e-05 loss:  -0.6742066740989685\n",
      "267 revenue:  0.46851688623428345 rgt:  0.000124686659546569 loss:  -0.6731874942779541\n",
      "268 revenue:  0.4695412516593933 rgt:  0.000903442851267755 loss:  -0.6542684435844421\n",
      "269 revenue:  0.4684999883174896 rgt:  0.00017290888354182243 loss:  -0.6711444854736328\n",
      "270 revenue:  0.46757981181144714 rgt:  4.492803054745309e-05 loss:  -0.6770429015159607\n",
      "271 revenue:  0.4686277210712433 rgt:  8.080627594608814e-05 loss:  -0.675488293170929\n",
      "272 revenue:  0.46966323256492615 rgt:  0.00015870260540395975 loss:  -0.6725594401359558\n",
      "273 revenue:  0.46861156821250916 rgt:  0.0003323309065308422 loss:  -0.6659870743751526\n",
      "274 revenue:  0.46758532524108887 rgt:  0.0001986911374842748 loss:  -0.6695041656494141\n",
      "275 revenue:  0.4665268361568451 rgt:  6.932315591257066e-05 loss:  -0.674626350402832\n",
      "276 revenue:  0.4675738215446472 rgt:  7.26734142517671e-05 loss:  -0.6751903891563416\n",
      "277 revenue:  0.4686160087585449 rgt:  0.00010911181016126648 loss:  -0.673995852470398\n",
      "278 revenue:  0.46964970231056213 rgt:  0.00017659086734056473 loss:  -0.6718409061431885\n",
      "279 revenue:  0.4686610698699951 rgt:  0.0002565075410529971 loss:  -0.6683128476142883\n",
      "280 revenue:  0.4676346182823181 rgt:  0.00012560041795950383 loss:  -0.6725009679794312\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "281 revenue:  0.46826404333114624 rgt:  0.0005751760909333825 loss:  -0.6597381830215454\n",
      "282 revenue:  0.4672192335128784 rgt:  0.000118991345516406 loss:  -0.6725025773048401\n",
      "283 revenue:  0.46825307607650757 rgt:  0.0003287341387476772 loss:  -0.6658278107643127\n",
      "284 revenue:  0.46720173954963684 rgt:  4.908587652607821e-05 loss:  -0.6764592528343201\n",
      "285 revenue:  0.4682497978210449 rgt:  0.00014895095955580473 loss:  -0.6719302535057068\n",
      "286 revenue:  0.4674568772315979 rgt:  0.0004393834387883544 loss:  -0.6623049974441528\n",
      "287 revenue:  0.46640947461128235 rgt:  6.740561366314068e-05 loss:  -0.6746582388877869\n",
      "288 revenue:  0.4674612581729889 rgt:  0.00011875600466737524 loss:  -0.6726906299591064\n",
      "289 revenue:  0.4684999883174896 rgt:  6.113511335570365e-05 loss:  -0.6765841841697693\n",
      "290 revenue:  0.46954309940338135 rgt:  0.00010029322584159672 loss:  -0.675112247467041\n",
      "291 revenue:  0.4705766439437866 rgt:  0.0007754184771329165 loss:  -0.6573624610900879\n",
      "292 revenue:  0.46953392028808594 rgt:  9.36004362301901e-05 loss:  -0.6754520535469055\n",
      "293 revenue:  0.47056815028190613 rgt:  0.0008671138784848154 loss:  -0.6556641459465027\n",
      "294 revenue:  0.46952345967292786 rgt:  6.362068234011531e-05 loss:  -0.6771717071533203\n",
      "295 revenue:  0.47056540846824646 rgt:  0.00018214923329651356 loss:  -0.6722956299781799\n",
      "296 revenue:  0.4697756767272949 rgt:  6.778929673600942e-05 loss:  -0.6770946383476257\n",
      "297 revenue:  0.4708188474178314 rgt:  0.00020615269022528082 loss:  -0.6715947985649109\n",
      "298 revenue:  0.46977469325065613 rgt:  0.0005106694297865033 loss:  -0.6622902750968933\n",
      "299 revenue:  0.46873778104782104 rgt:  5.498277459992096e-05 loss:  -0.6771675944328308\n"
     ]
    }
   ],
   "source": [
    "training_set = train_set_prob_1\n",
    "\n",
    "# Build the optimizer ONCE so AdamW keeps its running moment estimates across\n",
    "# iterations; only the learning rate is updated each step. Re-creating AdamW\n",
    "# inside the loop resets that state, making every update a bias-corrected\n",
    "# first Adam step (~ lr * sign(grad)) instead of a real Adam step.\n",
    "opt_aucter = torch.optim.AdamW(list(ANet.parameters()) + list(PNet.parameters()))\n",
    "\n",
    "for i, lr in enumerate(learning_rates):\n",
    "    # Apply this iteration's learning rate from the schedule.\n",
    "    for group in opt_aucter.param_groups:\n",
    "        group['lr'] = lr\n",
    "\n",
    "    # Fresh (untrained) misreport network every iteration.\n",
    "    MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "    MNet.to(device)\n",
    "\n",
    "    # NOTE(review): the positional arguments 100, 300, 0.00001 are interpreted by\n",
    "    # find_max_regret_misreport_BNIC (Helpers.ipynb) - confirm their meaning there.\n",
    "    misreports = find_max_regret_misreport_BNIC(ANet,PNet,MNet,training_set,100, 300,0.00001)\n",
    "\n",
    "    opt_aucter.zero_grad()\n",
    "\n",
    "    u_orig = truthful_utility_calculation(ANet,PNet,training_set)\n",
    "    u_new = misreport_utility_calculation_BNIC(ANet,PNet, training_set, misreports)\n",
    "\n",
    "    rgt = torch.mean(u_new - u_orig)\n",
    "    rev = torch.mean(truthful_revenue_calculation(ANet,PNet,training_set))\n",
    "\n",
    "    # rgt is a mean of utility differences and is negative whenever the misreport\n",
    "    # search finds nothing better than truthful reporting; clamp it at 0 inside the\n",
    "    # sqrt so the loss (and every gradient) cannot turn into NaN.\n",
    "    loss =  -1 * (torch.sqrt(rev + 1e-7) -  torch.sqrt(torch.clamp(rgt, min=0.0) + 1e-7)) + rgt\n",
    "\n",
    "    loss.backward()\n",
    "    opt_aucter.step()\n",
    "\n",
    "    print(i, 'revenue: ',rev.item(), 'rgt: ', rgt.item(), 'loss: ', loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73dbf533",
   "metadata": {},
   "source": [
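    "# Test\n",
    "\n",
    "Each cell below evaluates `ANet`/`PNet` with `test` on one test set and prints `revenue  regret  (sqrt(revenue) - sqrt(regret))^2`."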
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0c1fd7af",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MisreportNetBNIC(\n",
       "  (mlp): MLP(\n",
       "    (model): Sequential(\n",
       "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
       "      (tanh1): Tanh()\n",
       "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh2): Tanh()\n",
       "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh3): Tanh()\n",
       "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh4): Tanh()\n",
       "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh5): Tanh()\n",
       "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh6): Tanh()\n",
       "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# New, untrained misreport network that is passed to `test` in the cells below.\n",
    "MNet = MisreportNetBNIC(n, m, nLayersMisreport, widthMisreport).to(device)\n",
    "MNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ac6dc74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6875700950622559 0.0008380782674066722 0.6403982762670604\n"
     ]
    }
   ],
   "source": [
    "# Prints: revenue, regret, (sqrt(revenue) - sqrt(regret))^2\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_no_prob)\n",
    "score = (np.sqrt(rev) - np.sqrt(rgt)) ** 2\n",
    "print(rev, rgt, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "107ef0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.46991145610809326 0.0053620594553649426 0.3748804248368806\n"
     ]
    }
   ],
   "source": [
    "# Prints: revenue, regret, (sqrt(revenue) - sqrt(regret))^2\n",
    "# optimal 0.5132 (presumably the optimal revenue for this setting)\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_prob_1)\n",
    "score = (np.sqrt(rev) - np.sqrt(rgt)) ** 2\n",
    "print(rev, rgt, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9388bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.03085928224027157 3.4931552363559604e-05 0.02881771420585173\n"
     ]
    }
   ],
   "source": [
    "# Prints: revenue, regret, (sqrt(revenue) - sqrt(regret))^2\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_single_1)\n",
    "score = (np.sqrt(rev) - np.sqrt(rgt)) ** 2\n",
    "print(rev, rgt, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "338b6f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5627250075340271 0.00769775127992034 0.43879119501577263\n"
     ]
    }
   ],
   "source": [
    "# Prints: revenue, regret, (sqrt(revenue) - sqrt(regret))^2\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_single_2)\n",
    "score = (np.sqrt(rev) - np.sqrt(rgt)) ** 2\n",
    "print(rev, rgt, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e4e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
