{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ac754fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run NNs.ipynb\n",
    "%run Helpers.ipynb\n",
    "\n",
    "import random\n",
    "\n",
    "# Seed every RNG the experiment may draw from so Restart & Run All is reproducible.\n",
    "# `torch` and `np` are not imported in this cell; they are expected to come from the %run notebooks above.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)  # last statement returns None, so the cell prints nothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca124b90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Report the device that the networks and tensors below are moved to.\n",
    "print(str(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38b59e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size\n",
    "n = 2  # number of bidders\n",
    "m = 2  # number of items\n",
    "\n",
    "# All three networks share the same depth (number of Linear layers, minimum 2) and hidden width.\n",
    "nLayersAllocation = nLayersPayment = nLayersMisreport = 7\n",
    "widthAllocation = widthPayment = widthMisreport = 100\n",
    "\n",
    "# First argument of every get_train_or_test_set call below (presumably the number of samples drawn).\n",
    "batch_size = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "021a570a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AllocationNet(\n",
      "  (mlp1): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (mlp2): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "PaymentNet(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "MisreportNetBNIC(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Build the three networks. The order (allocation, payment, misreport) is kept\n",
    "# so that the seeded weight initialisation draws happen in the same sequence.\n",
    "ANet = AllocationNet(n, m, nLayersAllocation, widthAllocation)\n",
    "PNet = PaymentNet(n, m, nLayersPayment, widthPayment)\n",
    "MNet = MisreportNetBNIC(n, m, nLayersMisreport, widthMisreport)\n",
    "\n",
    "# Move each network to the target device, put it in training mode, then show its architecture.\n",
    "for _net in (ANet, PNet, MNet):\n",
    "    _net.to(device)\n",
    "    _net.train()\n",
    "    print(_net)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e10ea38c",
   "metadata": {},
   "source": [
    "# Set up datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da032379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectors passed as the last argument of get_train_or_test_set, one per data setting.\n",
    "# NOTE(review): the prob_single_* tensors are floating point, while no_prob_single and\n",
    "# single_bidder_* are integer tensors; confirm get_train_or_test_set handles both dtypes alike.\n",
    "prob_single_1 = torch.tensor([0.2, 0.8], device=device)\n",
    "prob_single_2 = torch.tensor([0.5, 0.7], device=device)\n",
    "prob_single_3 = torch.tensor([0.7, 0.9], device=device)\n",
    "no_prob_single = torch.tensor([1, 1], device=device)\n",
    "single_bidder_1 = torch.tensor([1, 0], device=device)\n",
    "single_bidder_2 = torch.tensor([0, 1], device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cb66a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_train_and_test_sets(setting):\n",
    "    \"\"\"Return (train_set, test_set) built by two separate get_train_or_test_set calls, train first.\"\"\"\n",
    "    train_set = get_train_or_test_set(batch_size, n, m, setting)\n",
    "    test_set = get_train_or_test_set(batch_size, n, m, setting)\n",
    "    return train_set, test_set\n",
    "\n",
    "\n",
    "# Settings are processed in the same order as before so any seeded random draws line up.\n",
    "train_set_prob_1, test_set_prob_1 = make_train_and_test_sets(prob_single_1)\n",
    "train_set_prob_2, test_set_prob_2 = make_train_and_test_sets(prob_single_2)\n",
    "train_set_prob_3, test_set_prob_3 = make_train_and_test_sets(prob_single_3)\n",
    "train_set_no_prob, test_set_no_prob = make_train_and_test_sets(no_prob_single)\n",
    "train_set_single_1, test_set_single_1 = make_train_and_test_sets(single_bidder_1)\n",
    "train_set_single_2, test_set_single_2 = make_train_and_test_sets(single_bidder_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "531955b6",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6435cdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Piecewise-constant schedule: 100 entries at each rate, each rate 10x smaller than the previous one.\n",
    "learning_rates = [lr for lr in (0.0005, 0.00005, 0.000005) for _ in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a70c85f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 revenue:  0.1482238918542862 rgt:  0.04326273128390312 loss:  -0.13373878598213196\n",
      "1 revenue:  0.15353310108184814 rgt:  0.05745549127459526 loss:  -0.09467867016792297\n",
      "2 revenue:  0.15747517347335815 rgt:  0.06642243266105652 loss:  -0.0726834237575531\n",
      "3 revenue:  0.1571110486984253 rgt:  0.06802401691675186 loss:  -0.06753414124250412\n",
      "4 revenue:  0.1559925675392151 rgt:  0.0665435791015625 loss:  -0.07045486569404602\n",
      "5 revenue:  0.15514886379241943 rgt:  0.0636075884103775 loss:  -0.0780763179063797\n",
      "6 revenue:  0.1546611338853836 rgt:  0.05954122170805931 loss:  -0.0897177904844284\n",
      "7 revenue:  0.15459808707237244 rgt:  0.054330576211214066 loss:  -0.10576976835727692\n",
      "8 revenue:  0.15501174330711365 rgt:  0.04781804233789444 loss:  -0.12722378969192505\n",
      "9 revenue:  0.15566933155059814 rgt:  0.039657145738601685 loss:  -0.15575124323368073\n",
      "10 revenue:  0.1563897430896759 rgt:  0.02173192985355854 loss:  -0.22631175816059113\n",
      "11 revenue:  0.15863369405269623 rgt:  0.01775064691901207 loss:  -0.2473059892654419\n",
      "12 revenue:  0.15964074432849884 rgt:  0.010039378888905048 loss:  -0.2893142104148865\n",
      "13 revenue:  0.1614830493927002 rgt:  0.004376691300421953 loss:  -0.33131566643714905\n",
      "14 revenue:  0.16364552080631256 rgt:  0.001618646550923586 loss:  -0.36267906427383423\n",
      "15 revenue:  0.1753869354724884 rgt:  0.0008610653458163142 loss:  -0.3885856568813324\n",
      "16 revenue:  0.2000514715909958 rgt:  0.0012774261413142085 loss:  -0.41025134921073914\n",
      "17 revenue:  0.21973934769630432 rgt:  0.0011795529862865806 loss:  -0.43323811888694763\n",
      "18 revenue:  0.2462255358695984 rgt:  0.0016334975371137261 loss:  -0.45416000485420227\n",
      "19 revenue:  0.2685072720050812 rgt:  0.0016064088558778167 loss:  -0.4764893054962158\n",
      "20 revenue:  0.2989595830440521 rgt:  0.004633724689483643 loss:  -0.4740661382675171\n",
      "21 revenue:  0.28326207399368286 rgt:  0.0024806843139231205 loss:  -0.4799356162548065\n",
      "22 revenue:  0.2894366681575775 rgt:  0.002110410248860717 loss:  -0.4899425506591797\n",
      "23 revenue:  0.2965393662452698 rgt:  0.0011892098700627685 loss:  -0.508878767490387\n",
      "24 revenue:  0.3274659514427185 rgt:  0.003321547294035554 loss:  -0.5112912058830261\n",
      "25 revenue:  0.33356142044067383 rgt:  0.0021439974661916494 loss:  -0.5290994644165039\n",
      "26 revenue:  0.3656436800956726 rgt:  0.006414793431758881 loss:  -0.5181770324707031\n",
      "27 revenue:  0.3539375960826874 rgt:  0.002990842331200838 loss:  -0.5372462272644043\n",
      "28 revenue:  0.3724902868270874 rgt:  0.003772000316530466 loss:  -0.5451304912567139\n",
      "29 revenue:  0.3667217493057251 rgt:  0.0027244361117482185 loss:  -0.5506541132926941\n",
      "30 revenue:  0.40931567549705505 rgt:  0.003805106272920966 loss:  -0.5742864608764648\n",
      "31 revenue:  0.3897726833820343 rgt:  0.003649124875664711 loss:  -0.5602598786354065\n",
      "32 revenue:  0.3942817151546478 rgt:  0.004202705807983875 loss:  -0.5588868260383606\n",
      "33 revenue:  0.4140995442867279 rgt:  0.003805136773735285 loss:  -0.5780140161514282\n",
      "34 revenue:  0.4608907699584961 rgt:  0.009374313056468964 loss:  -0.5726935863494873\n",
      "35 revenue:  0.430914044380188 rgt:  0.003394666127860546 loss:  -0.594781219959259\n",
      "36 revenue:  0.48825979232788086 rgt:  0.010146223939955235 loss:  -0.5878807902336121\n",
      "37 revenue:  0.44987234473228455 rgt:  0.004501972813159227 loss:  -0.5991258025169373\n",
      "38 revenue:  0.5181741118431091 rgt:  0.013000099919736385 loss:  -0.5928246974945068\n",
      "39 revenue:  0.4730105400085449 rgt:  0.0062871635891497135 loss:  -0.6021782755851746\n",
      "40 revenue:  0.5412019491195679 rgt:  0.0098567521199584 loss:  -0.6265259385108948\n",
      "41 revenue:  0.4793839156627655 rgt:  0.0069612194783985615 loss:  -0.6019798517227173\n",
      "42 revenue:  0.5304938554763794 rgt:  0.012779856100678444 loss:  -0.6025218963623047\n",
      "43 revenue:  0.473482608795166 rgt:  0.0053282310254871845 loss:  -0.6097771525382996\n",
      "44 revenue:  0.5546084642410278 rgt:  0.014563348144292831 loss:  -0.6094779968261719\n",
      "45 revenue:  0.4957810342311859 rgt:  0.0069285109639167786 loss:  -0.6139504313468933\n",
      "46 revenue:  0.5774824023246765 rgt:  0.017576701939105988 loss:  -0.6097684502601624\n",
      "47 revenue:  0.5124896168708801 rgt:  0.008874082937836647 loss:  -0.6128069758415222\n",
      "48 revenue:  0.59258633852005 rgt:  0.01957375928759575 loss:  -0.6103159785270691\n",
      "49 revenue:  0.522855818271637 rgt:  0.009906336665153503 loss:  -0.613650381565094\n",
      "50 revenue:  0.6023215651512146 rgt:  0.01468287967145443 loss:  -0.6402376294136047\n",
      "51 revenue:  0.5214836001396179 rgt:  0.009872314520180225 loss:  -0.6129059195518494\n",
      "52 revenue:  0.5956788659095764 rgt:  0.020009389147162437 loss:  -0.6103380918502808\n",
      "53 revenue:  0.520952045917511 rgt:  0.009318484924733639 loss:  -0.6159188747406006\n",
      "54 revenue:  0.6037437319755554 rgt:  0.02032664231956005 loss:  -0.614111065864563\n",
      "55 revenue:  0.5308626294136047 rgt:  0.010130359791219234 loss:  -0.6178227663040161\n",
      "56 revenue:  0.613764762878418 rgt:  0.02182978391647339 loss:  -0.613852322101593\n",
      "57 revenue:  0.5385463237762451 rgt:  0.011011965572834015 loss:  -0.6179068684577942\n",
      "58 revenue:  0.6203494071960449 rgt:  0.022829586640000343 loss:  -0.6136981248855591\n",
      "59 revenue:  0.5434802770614624 rgt:  0.01150711253285408 loss:  -0.6184324622154236\n",
      "60 revenue:  0.6251137852668762 rgt:  0.019456032663583755 loss:  -0.6317001581192017\n",
      "61 revenue:  0.5384005308151245 rgt:  0.01168602705001831 loss:  -0.6139695048332214\n",
      "62 revenue:  0.6161075234413147 rgt:  0.02211526781320572 loss:  -0.614097535610199\n",
      "63 revenue:  0.5384934544563293 rgt:  0.010421316139400005 loss:  -0.621314525604248\n",
      "64 revenue:  0.623832106590271 rgt:  0.023082949221134186 loss:  -0.6148164868354797\n",
      "65 revenue:  0.5453804731369019 rgt:  0.011255193501710892 loss:  -0.6211526989936829\n",
      "66 revenue:  0.6290982365608215 rgt:  0.023929588496685028 loss:  -0.614535391330719\n",
      "67 revenue:  0.5490188598632812 rgt:  0.011678071692585945 loss:  -0.6212144494056702\n",
      "68 revenue:  0.6321868300437927 rgt:  0.024418272078037262 loss:  -0.6144197583198547\n",
      "69 revenue:  0.5511963963508606 rgt:  0.011868491768836975 loss:  -0.6216145753860474\n",
      "70 revenue:  0.6343985199928284 rgt:  0.018456242978572845 loss:  -0.6421810984611511\n",
      "71 revenue:  0.5437455177307129 rgt:  0.011622113175690174 loss:  -0.6179625988006592\n",
      "72 revenue:  0.6234127879142761 rgt:  0.02270538918673992 loss:  -0.6161762475967407\n",
      "73 revenue:  0.543062686920166 rgt:  0.010162033140659332 loss:  -0.6259584426879883\n",
      "74 revenue:  0.631439745426178 rgt:  0.02387011982500553 loss:  -0.616261899471283\n",
      "75 revenue:  0.5498198866844177 rgt:  0.011161625385284424 loss:  -0.6246877908706665\n",
      "76 revenue:  0.6352149844169617 rgt:  0.02451368421316147 loss:  -0.6159213781356812\n",
      "77 revenue:  0.5521173477172852 rgt:  0.011459521017968655 loss:  -0.62453693151474\n",
      "78 revenue:  0.6368573904037476 rgt:  0.024792062118649483 loss:  -0.6157861948013306\n",
      "79 revenue:  0.5530209541320801 rgt:  0.011526806280016899 loss:  -0.6247636079788208\n",
      "80 revenue:  0.6374943256378174 rgt:  0.017755040898919106 loss:  -0.6474289894104004\n",
      "81 revenue:  0.5456060767173767 rgt:  0.011148347519338131 loss:  -0.6219170689582825\n",
      "82 revenue:  0.6249366998672485 rgt:  0.022962452843785286 loss:  -0.6160329580307007\n",
      "83 revenue:  0.5414689183235168 rgt:  0.009680688381195068 loss:  -0.6277741193771362\n",
      "84 revenue:  0.6299166083335876 rgt:  0.02307852730154991 loss:  -0.6186779141426086\n",
      "85 revenue:  0.5476201772689819 rgt:  0.010330712422728539 loss:  -0.6280423998832703\n",
      "86 revenue:  0.6341403722763062 rgt:  0.02378169260919094 loss:  -0.618334174156189\n",
      "87 revenue:  0.5505864024162292 rgt:  0.010740325786173344 loss:  -0.6276388168334961\n",
      "88 revenue:  0.6353210806846619 rgt:  0.023938661441206932 loss:  -0.6184101700782776\n",
      "89 revenue:  0.5507516860961914 rgt:  0.010658804327249527 loss:  -0.6282257437705994\n",
      "90 revenue:  0.6354378461837769 rgt:  0.01574069820344448 loss:  -0.655940592288971\n",
      "91 revenue:  0.5434205532073975 rgt:  0.010217562317848206 loss:  -0.6258707046508789\n",
      "92 revenue:  0.6202059984207153 rgt:  0.02150602638721466 loss:  -0.619376003742218\n",
      "93 revenue:  0.5366523265838623 rgt:  0.008209393359720707 loss:  -0.6337499618530273\n",
      "94 revenue:  0.6276748180389404 rgt:  0.0220542773604393 loss:  -0.6216979622840881\n",
      "95 revenue:  0.544241726398468 rgt:  0.009198683314025402 loss:  -0.6326184868812561\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96 revenue:  0.6312900185585022 rgt:  0.022675076499581337 loss:  -0.621279776096344\n",
      "97 revenue:  0.5460897088050842 rgt:  0.009394985623657703 loss:  -0.632655680179596\n",
      "98 revenue:  0.6313297152519226 rgt:  0.02259841188788414 loss:  -0.6216362714767456\n",
      "99 revenue:  0.5461145043373108 rgt:  0.009213249199092388 loss:  -0.633796215057373\n",
      "100 revenue:  0.6314941048622131 rgt:  0.016141550615429878 loss:  -0.6514747738838196\n",
      "101 revenue:  0.6228730082511902 rgt:  0.0207374207675457 loss:  -0.6244804263114929\n",
      "102 revenue:  0.6138484477996826 rgt:  0.018158113583922386 loss:  -0.6305742859840393\n",
      "103 revenue:  0.6054623126983643 rgt:  0.015793193131685257 loss:  -0.6366500854492188\n",
      "104 revenue:  0.5973973274230957 rgt:  0.013958706520497799 loss:  -0.6408088207244873\n",
      "105 revenue:  0.5897406935691833 rgt:  0.012022008188068867 loss:  -0.646278440952301\n",
      "106 revenue:  0.5921255946159363 rgt:  0.012452756986021996 loss:  -0.6454519033432007\n",
      "107 revenue:  0.5866422057151794 rgt:  0.01137326005846262 loss:  -0.6479066014289856\n",
      "108 revenue:  0.5866364240646362 rgt:  0.011737053282558918 loss:  -0.6458468437194824\n",
      "109 revenue:  0.5789022445678711 rgt:  0.010147872380912304 loss:  -0.6499713063240051\n",
      "110 revenue:  0.581479549407959 rgt:  0.007463004905730486 loss:  -0.6686959266662598\n",
      "111 revenue:  0.5900845527648926 rgt:  0.013109284453094006 loss:  -0.6405642032623291\n",
      "112 revenue:  0.5811921954154968 rgt:  0.011879355646669865 loss:  -0.6414873600006104\n",
      "113 revenue:  0.5743650197982788 rgt:  0.009488934651017189 loss:  -0.6509681940078735\n",
      "114 revenue:  0.5813379287719727 rgt:  0.010310791432857513 loss:  -0.6506019830703735\n",
      "115 revenue:  0.5872778296470642 rgt:  0.011142389848828316 loss:  -0.6496402621269226\n",
      "116 revenue:  0.5825973749160767 rgt:  0.01013773214071989 loss:  -0.6524562239646912\n",
      "117 revenue:  0.5886475443840027 rgt:  0.011661756783723831 loss:  -0.647581934928894\n",
      "118 revenue:  0.5801751613616943 rgt:  0.010053644888103008 loss:  -0.6513704061508179\n",
      "119 revenue:  0.5831009149551392 rgt:  0.010838592424988747 loss:  -0.6486629247665405\n",
      "120 revenue:  0.574536919593811 rgt:  0.006971423048526049 loss:  -0.6675150990486145\n",
      "121 revenue:  0.5833421945571899 rgt:  0.01169481873512268 loss:  -0.6439306735992432\n",
      "122 revenue:  0.5744173526763916 rgt:  0.011034018360078335 loss:  -0.6418259143829346\n",
      "123 revenue:  0.566537082195282 rgt:  0.008396397344768047 loss:  -0.6526578664779663\n",
      "124 revenue:  0.5756258368492126 rgt:  0.009545656852424145 loss:  -0.6514521241188049\n",
      "125 revenue:  0.5760294795036316 rgt:  0.008959426544606686 loss:  -0.6553519368171692\n",
      "126 revenue:  0.5848715305328369 rgt:  0.010373460128903389 loss:  -0.6525448560714722\n",
      "127 revenue:  0.5798357129096985 rgt:  0.009226608090102673 loss:  -0.6561871767044067\n",
      "128 revenue:  0.5870862603187561 rgt:  0.010883642360568047 loss:  -0.651006817817688\n",
      "129 revenue:  0.5790355205535889 rgt:  0.009376395493745804 loss:  -0.6547352075576782\n",
      "130 revenue:  0.585614800453186 rgt:  0.007585249841213226 loss:  -0.6705756783485413\n",
      "131 revenue:  0.594019889831543 rgt:  0.013285424560308456 loss:  -0.6421786546707153\n",
      "132 revenue:  0.584627628326416 rgt:  0.012275206856429577 loss:  -0.6415402889251709\n",
      "133 revenue:  0.5761778354644775 rgt:  0.00965733453631401 loss:  -0.6511342525482178\n",
      "134 revenue:  0.5751402378082275 rgt:  0.00867052748799324 loss:  -0.656593382358551\n",
      "135 revenue:  0.5848261713981628 rgt:  0.010481750592589378 loss:  -0.6518766283988953\n",
      "136 revenue:  0.5810773372650146 rgt:  0.009305744431912899 loss:  -0.6565118432044983\n",
      "137 revenue:  0.5900671482086182 rgt:  0.010992315597832203 loss:  -0.6523213386535645\n",
      "138 revenue:  0.581936776638031 rgt:  0.009098699316382408 loss:  -0.658361554145813\n",
      "139 revenue:  0.5861502885818481 rgt:  0.01016006339341402 loss:  -0.6546468734741211\n",
      "140 revenue:  0.579383134841919 rgt:  0.006861958187073469 loss:  -0.6714727878570557\n",
      "141 revenue:  0.5880764722824097 rgt:  0.012073085643351078 loss:  -0.6449103355407715\n",
      "142 revenue:  0.5793277025222778 rgt:  0.01086440123617649 loss:  -0.6460385918617249\n",
      "143 revenue:  0.5712842345237732 rgt:  0.007950943894684315 loss:  -0.6587139964103699\n",
      "144 revenue:  0.5805832743644714 rgt:  0.009527212008833885 loss:  -0.6548250317573547\n",
      "145 revenue:  0.5783228874206543 rgt:  0.008511172607541084 loss:  -0.65970778465271\n",
      "146 revenue:  0.5867855548858643 rgt:  0.00982047338038683 loss:  -0.6571000814437866\n",
      "147 revenue:  0.5833330750465393 rgt:  0.008979316800832748 loss:  -0.6600234508514404\n",
      "148 revenue:  0.5883819460868835 rgt:  0.010225561447441578 loss:  -0.6557130813598633\n",
      "149 revenue:  0.579832136631012 rgt:  0.008550586178898811 loss:  -0.6604466438293457\n",
      "150 revenue:  0.585972011089325 rgt:  0.0073498631827533245 loss:  -0.6724063754081726\n",
      "151 revenue:  0.5943595767021179 rgt:  0.013192997314035892 loss:  -0.6428930759429932\n",
      "152 revenue:  0.5845798850059509 rgt:  0.011294436641037464 loss:  -0.6470081806182861\n",
      "153 revenue:  0.5764790177345276 rgt:  0.008501152507960796 loss:  -0.6585587859153748\n",
      "154 revenue:  0.578783392906189 rgt:  0.00840340368449688 loss:  -0.6607041954994202\n",
      "155 revenue:  0.5883821249008179 rgt:  0.010051418095827103 loss:  -0.6567520499229431\n",
      "156 revenue:  0.5839396119117737 rgt:  0.008791252039372921 loss:  -0.6616060137748718\n",
      "157 revenue:  0.5926675796508789 rgt:  0.010439920239150524 loss:  -0.6572327613830566\n",
      "158 revenue:  0.585480272769928 rgt:  0.008841143921017647 loss:  -0.6622978448867798\n",
      "159 revenue:  0.591871976852417 rgt:  0.010146229527890682 loss:  -0.6584570407867432\n",
      "160 revenue:  0.5843330025672913 rgt:  0.007427942473441362 loss:  -0.6708028316497803\n",
      "161 revenue:  0.592707097530365 rgt:  0.012747741304337978 loss:  -0.6442207098007202\n",
      "162 revenue:  0.5833486914634705 rgt:  0.010229264385998249 loss:  -0.6524031162261963\n",
      "163 revenue:  0.5759662985801697 rgt:  0.007700689602643251 loss:  -0.663469672203064\n",
      "164 revenue:  0.5816075205802917 rgt:  0.008236432448029518 loss:  -0.6636403203010559\n",
      "165 revenue:  0.5896127820014954 rgt:  0.009572375565767288 loss:  -0.6604511737823486\n",
      "166 revenue:  0.5848104357719421 rgt:  0.008342544548213482 loss:  -0.6650484204292297\n",
      "167 revenue:  0.5934291481971741 rgt:  0.009933941066265106 loss:  -0.6607400178909302\n",
      "168 revenue:  0.5866613984107971 rgt:  0.008553564548492432 loss:  -0.6648987531661987\n",
      "169 revenue:  0.5929065346717834 rgt:  0.009867170825600624 loss:  -0.6608030200004578\n",
      "170 revenue:  0.585541307926178 rgt:  0.006999058183282614 loss:  -0.6745467782020569\n",
      "171 revenue:  0.5936599969863892 rgt:  0.012151254341006279 loss:  -0.6481090188026428\n",
      "172 revenue:  0.5835166573524475 rgt:  0.010139432735741138 loss:  -0.653048038482666\n",
      "173 revenue:  0.5738762617111206 rgt:  0.007054893299937248 loss:  -0.6664974093437195\n",
      "174 revenue:  0.5827348828315735 rgt:  0.008131438866257668 loss:  -0.6650643348693848\n",
      "175 revenue:  0.5820198059082031 rgt:  0.007638619747012854 loss:  -0.6678639054298401\n",
      "176 revenue:  0.5913481712341309 rgt:  0.009200829081237316 loss:  -0.6638694405555725\n",
      "177 revenue:  0.5849396586418152 rgt:  0.00771297886967659 loss:  -0.6692764759063721\n",
      "178 revenue:  0.592130184173584 rgt:  0.00900618638843298 loss:  -0.66559237241745\n",
      "179 revenue:  0.5868592262268066 rgt:  0.007952501066029072 loss:  -0.6689375638961792\n",
      "180 revenue:  0.588477611541748 rgt:  0.006830884609371424 loss:  -0.6776424050331116\n",
      "181 revenue:  0.5964477062225342 rgt:  0.012298881076276302 loss:  -0.6491007208824158\n",
      "182 revenue:  0.5858296751976013 rgt:  0.009791206568479538 loss:  -0.6566529273986816\n",
      "183 revenue:  0.5763739943504333 rgt:  0.0067469412460923195 loss:  -0.6703057289123535\n",
      "184 revenue:  0.5829755663871765 rgt:  0.007552754133939743 loss:  -0.6690685749053955\n",
      "185 revenue:  0.5842617154121399 rgt:  0.0074265990406274796 loss:  -0.6707652807235718\n",
      "186 revenue:  0.5930452942848206 rgt:  0.008886120282113552 loss:  -0.6669415235519409\n",
      "187 revenue:  0.5859788656234741 rgt:  0.007387019693851471 loss:  -0.6721572279930115\n",
      "188 revenue:  0.5945016741752625 rgt:  0.008939370512962341 loss:  -0.6675512790679932\n",
      "189 revenue:  0.5862705707550049 rgt:  0.007214573677629232 loss:  -0.6735293865203857\n",
      "190 revenue:  0.591801643371582 rgt:  0.00691558513790369 loss:  -0.679210364818573\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "191 revenue:  0.5994903445243835 rgt:  0.012160344049334526 loss:  -0.6518328785896301\n",
      "192 revenue:  0.5886650085449219 rgt:  0.009721247479319572 loss:  -0.6589270830154419\n",
      "193 revenue:  0.578370213508606 rgt:  0.006597364787012339 loss:  -0.6726844906806946\n",
      "194 revenue:  0.581397533416748 rgt:  0.006492091808468103 loss:  -0.6754281520843506\n",
      "195 revenue:  0.5905809998512268 rgt:  0.007931004278361797 loss:  -0.6715050339698792\n",
      "196 revenue:  0.5842381119728088 rgt:  0.0064956373535096645 loss:  -0.6772629618644714\n",
      "197 revenue:  0.5926228761672974 rgt:  0.007948733866214752 loss:  -0.6727151274681091\n",
      "198 revenue:  0.5849306583404541 rgt:  0.006190309766680002 loss:  -0.6799382567405701\n",
      "199 revenue:  0.5912980437278748 rgt:  0.0076485793106257915 loss:  -0.673853874206543\n",
      "200 revenue:  0.5820252895355225 rgt:  0.00611566286534071 loss:  -0.6785869002342224\n",
      "201 revenue:  0.5828055739402771 rgt:  0.009486490860581398 loss:  -0.6565315127372742\n",
      "202 revenue:  0.5818871855735779 rgt:  0.008425396867096424 loss:  -0.6625994443893433\n",
      "203 revenue:  0.5811654925346375 rgt:  0.006490909494459629 loss:  -0.6752845048904419\n",
      "204 revenue:  0.5818130970001221 rgt:  0.006112837232649326 loss:  -0.6784687638282776\n",
      "205 revenue:  0.5822383165359497 rgt:  0.005821026396006346 loss:  -0.680928111076355\n",
      "206 revenue:  0.5829275250434875 rgt:  0.005794358905404806 loss:  -0.6815813183784485\n",
      "207 revenue:  0.5837231278419495 rgt:  0.005910059437155724 loss:  -0.6812301874160767\n",
      "208 revenue:  0.5845479369163513 rgt:  0.0062089357525110245 loss:  -0.6795510053634644\n",
      "209 revenue:  0.5854100584983826 rgt:  0.006940440274775028 loss:  -0.6748707294464111\n",
      "210 revenue:  0.585909366607666 rgt:  0.006460618227720261 loss:  -0.6786080598831177\n",
      "211 revenue:  0.5866845846176147 rgt:  0.009916078299283981 loss:  -0.6564574241638184\n",
      "212 revenue:  0.5857127904891968 rgt:  0.00862084235996008 loss:  -0.6638489365577698\n",
      "213 revenue:  0.5847899913787842 rgt:  0.006943484302610159 loss:  -0.6744440793991089\n",
      "214 revenue:  0.5845968127250671 rgt:  0.006598637439310551 loss:  -0.6767581105232239\n",
      "215 revenue:  0.5845907330513 rgt:  0.0062926397658884525 loss:  -0.6789659857749939\n",
      "216 revenue:  0.5848321914672852 rgt:  0.006260286550968885 loss:  -0.6793603897094727\n",
      "217 revenue:  0.5853338241577148 rgt:  0.006515528541058302 loss:  -0.6778362989425659\n",
      "218 revenue:  0.5860082507133484 rgt:  0.007186588831245899 loss:  -0.6735509037971497\n",
      "219 revenue:  0.5863250494003296 rgt:  0.008047088049352169 loss:  -0.6679654717445374\n",
      "220 revenue:  0.5861746668815613 rgt:  0.0061013950034976006 loss:  -0.68140709400177\n",
      "221 revenue:  0.5869408845901489 rgt:  0.009234761819243431 loss:  -0.6607877612113953\n",
      "222 revenue:  0.5859541296958923 rgt:  0.008947307243943214 loss:  -0.6619384288787842\n",
      "223 revenue:  0.5849126577377319 rgt:  0.006966934073716402 loss:  -0.6743601560592651\n",
      "224 revenue:  0.5847744941711426 rgt:  0.006473164074122906 loss:  -0.6777758002281189\n",
      "225 revenue:  0.5849389433860779 rgt:  0.006227289326488972 loss:  -0.6796719431877136\n",
      "226 revenue:  0.58534836769104 rgt:  0.006201806478202343 loss:  -0.6801267266273499\n",
      "227 revenue:  0.5859298706054688 rgt:  0.006439481861889362 loss:  -0.6787741184234619\n",
      "228 revenue:  0.5865501761436462 rgt:  0.007100389804691076 loss:  -0.6745009422302246\n",
      "229 revenue:  0.5865646600723267 rgt:  0.007846971042454243 loss:  -0.6694445013999939\n",
      "230 revenue:  0.5863537192344666 rgt:  0.006230090744793415 loss:  -0.6805757880210876\n",
      "231 revenue:  0.5871184468269348 rgt:  0.00961234886199236 loss:  -0.6585811972618103\n",
      "232 revenue:  0.5861541032791138 rgt:  0.009020975790917873 loss:  -0.6616067886352539\n",
      "233 revenue:  0.5852464437484741 rgt:  0.0072899991646409035 loss:  -0.6723420023918152\n",
      "234 revenue:  0.5847911238670349 rgt:  0.006582543719559908 loss:  -0.6770004034042358\n",
      "235 revenue:  0.5847803354263306 rgt:  0.006272162310779095 loss:  -0.6792396903038025\n",
      "236 revenue:  0.5850557088851929 rgt:  0.0061484770849347115 loss:  -0.6803281307220459\n",
      "237 revenue:  0.5856256484985352 rgt:  0.006186941172927618 loss:  -0.6804171800613403\n",
      "238 revenue:  0.5863385796546936 rgt:  0.006445278879255056 loss:  -0.6789991855621338\n",
      "239 revenue:  0.5870689153671265 rgt:  0.007042626850306988 loss:  -0.675240695476532\n",
      "240 revenue:  0.5873028039932251 rgt:  0.00618598610162735 loss:  -0.6815192699432373\n",
      "241 revenue:  0.5880821347236633 rgt:  0.009781161323189735 loss:  -0.658183753490448\n",
      "242 revenue:  0.5870996713638306 rgt:  0.00891434121876955 loss:  -0.6628937125205994\n",
      "243 revenue:  0.5861802697181702 rgt:  0.007170550525188446 loss:  -0.6737738847732544\n",
      "244 revenue:  0.5857278108596802 rgt:  0.006619294639676809 loss:  -0.6773496866226196\n",
      "245 revenue:  0.5856232047080994 rgt:  0.006273151375353336 loss:  -0.679783284664154\n",
      "246 revenue:  0.5858899354934692 rgt:  0.006215320434421301 loss:  -0.6803812980651855\n",
      "247 revenue:  0.5865673422813416 rgt:  0.006308381445705891 loss:  -0.6801426410675049\n",
      "248 revenue:  0.5874207019805908 rgt:  0.006403458304703236 loss:  -0.6800081729888916\n",
      "249 revenue:  0.5881001353263855 rgt:  0.006851409561932087 loss:  -0.6772516369819641\n",
      "250 revenue:  0.5882139801979065 rgt:  0.005945905111730099 loss:  -0.6838948726654053\n",
      "251 revenue:  0.5889884829521179 rgt:  0.009319047443568707 loss:  -0.6616011261940002\n",
      "252 revenue:  0.5879704356193542 rgt:  0.009158434346318245 loss:  -0.6619337201118469\n",
      "253 revenue:  0.5868937373161316 rgt:  0.007321476005017757 loss:  -0.6732022762298584\n",
      "254 revenue:  0.5862622261047363 rgt:  0.006700861733406782 loss:  -0.6771174073219299\n",
      "255 revenue:  0.586049497127533 rgt:  0.0063477191142737865 loss:  -0.6795178055763245\n",
      "256 revenue:  0.5861682891845703 rgt:  0.006260668858885765 loss:  -0.6802306175231934\n",
      "257 revenue:  0.5865954160690308 rgt:  0.006478144787251949 loss:  -0.6789295673370361\n",
      "258 revenue:  0.5871860980987549 rgt:  0.007123211398720741 loss:  -0.6747578382492065\n",
      "259 revenue:  0.5871293544769287 rgt:  0.008003072813153267 loss:  -0.6687802672386169\n",
      "260 revenue:  0.5868348479270935 rgt:  0.00658177025616169 loss:  -0.6783410310745239\n",
      "261 revenue:  0.5875892639160156 rgt:  0.009707173332571983 loss:  -0.6583111882209778\n",
      "262 revenue:  0.5866032838821411 rgt:  0.00872834213078022 loss:  -0.6637459397315979\n",
      "263 revenue:  0.5856739282608032 rgt:  0.007002157624810934 loss:  -0.6746117472648621\n",
      "264 revenue:  0.5853154063224792 rgt:  0.006428590044379234 loss:  -0.6784514784812927\n",
      "265 revenue:  0.5853698253631592 rgt:  0.0061471289955079556 loss:  -0.6805433630943298\n",
      "266 revenue:  0.585704505443573 rgt:  0.006096518598496914 loss:  -0.6811360120773315\n",
      "267 revenue:  0.5861635804176331 rgt:  0.006234783213585615 loss:  -0.6804172396659851\n",
      "268 revenue:  0.5867013335227966 rgt:  0.006752799730747938 loss:  -0.6770355701446533\n",
      "269 revenue:  0.5868095755577087 rgt:  0.007546032313257456 loss:  -0.6716205477714539\n",
      "270 revenue:  0.586685061454773 rgt:  0.006513901520520449 loss:  -0.6787305474281311\n",
      "271 revenue:  0.587436318397522 rgt:  0.009656177833676338 loss:  -0.6585214734077454\n",
      "272 revenue:  0.586460292339325 rgt:  0.008592158555984497 loss:  -0.6645205020904541\n",
      "273 revenue:  0.5855302810668945 rgt:  0.006755969952791929 loss:  -0.67624831199646\n",
      "274 revenue:  0.5852064490318298 rgt:  0.006285724695771933 loss:  -0.6794190406799316\n",
      "275 revenue:  0.5852465033531189 rgt:  0.006010020151734352 loss:  -0.6814791560173035\n",
      "276 revenue:  0.585551917552948 rgt:  0.005933472886681557 loss:  -0.6822505593299866\n",
      "277 revenue:  0.5861254930496216 rgt:  0.00598074309527874 loss:  -0.6822717785835266\n",
      "278 revenue:  0.5867382884025574 rgt:  0.006033797282725573 loss:  -0.6822766065597534\n",
      "279 revenue:  0.5869054794311523 rgt:  0.00633226428180933 loss:  -0.6801892518997192\n",
      "280 revenue:  0.5869508385658264 rgt:  0.006117470562458038 loss:  -0.6817948818206787\n",
      "281 revenue:  0.5877118706703186 rgt:  0.009144708514213562 loss:  -0.6618505120277405\n",
      "282 revenue:  0.5867127180099487 rgt:  0.00860616471618414 loss:  -0.6645957827568054\n",
      "283 revenue:  0.5856981873512268 rgt:  0.006740017328411341 loss:  -0.6764710545539856\n",
      "284 revenue:  0.5854620933532715 rgt:  0.006162445526570082 loss:  -0.6804906725883484\n",
      "285 revenue:  0.5855656862258911 rgt:  0.005912735592573881 loss:  -0.6824150681495667\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "286 revenue:  0.5858736038208008 rgt:  0.005892641376703978 loss:  -0.6827670335769653\n",
      "287 revenue:  0.5863825082778931 rgt:  0.006111503578722477 loss:  -0.6814680099487305\n",
      "288 revenue:  0.586947500705719 rgt:  0.006768614985048771 loss:  -0.6770842671394348\n",
      "289 revenue:  0.5867429971694946 rgt:  0.007811091374605894 loss:  -0.6697995662689209\n",
      "290 revenue:  0.586187481880188 rgt:  0.005979974754154682 loss:  -0.6823179721832275\n",
      "291 revenue:  0.5869581699371338 rgt:  0.00884171575307846 loss:  -0.6632594466209412\n",
      "292 revenue:  0.5859780311584473 rgt:  0.0086179468780756 loss:  -0.6640406847000122\n",
      "293 revenue:  0.5849003791809082 rgt:  0.006581616122275591 loss:  -0.6770784854888916\n",
      "294 revenue:  0.5846669673919678 rgt:  0.006015219260007143 loss:  -0.6810615658760071\n",
      "295 revenue:  0.5849106311798096 rgt:  0.005801362916827202 loss:  -0.6828258633613586\n",
      "296 revenue:  0.5854107141494751 rgt:  0.005854693241417408 loss:  -0.6827501654624939\n",
      "297 revenue:  0.5860819220542908 rgt:  0.006207487545907497 loss:  -0.680564284324646\n",
      "298 revenue:  0.5867598056793213 rgt:  0.006659068167209625 loss:  -0.6777397394180298\n",
      "299 revenue:  0.586848795413971 rgt:  0.00722609693184495 loss:  -0.6738275289535522\n"
     ]
    }
   ],
   "source": [
    "training_set = train_set_prob_2\n",
    "\n",
    "for i in range(len(learning_rates)):\n",
    "    opt_aucter = torch.optim.AdamW(list(ANet.parameters()) + list(PNet.parameters()),learning_rates[i])\n",
    "    \n",
    "    if (i % 10 == 0):\n",
    "        MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "        MNet.to(device)\n",
    "    \n",
    "    misreports = find_max_regret_misreport_BNIC(ANet,PNet,MNet,training_set,100, 300,0.00001)\n",
    "    \n",
    "    opt_aucter.zero_grad()\n",
    "\n",
    "    u_orig = truthful_utility_calculation(ANet,PNet,training_set)\n",
    "    u_new = misreport_utility_calculation_BNIC(ANet,PNet, training_set, misreports)\n",
    "        \n",
    "    rgt = torch.mean(u_new - u_orig)\n",
    "    rev = torch.mean(truthful_revenue_calculation(ANet,PNet,training_set))\n",
    "\n",
    "    #update NN\n",
    "    loss =  -1 * (torch.sqrt(rev + 1e-7) -  torch.sqrt(rgt + 1e-7)) + rgt\n",
    "    \n",
    "    loss.backward()\n",
    "    opt_aucter.step()\n",
    "\n",
    "    print(i, 'revenue: ',rev.item(), 'rgt: ', rgt.item(), 'loss: ', loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c50478a8",
   "metadata": {},
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "91828680",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MisreportNetBNIC(\n",
       "  (mlp): MLP(\n",
       "    (model): Sequential(\n",
       "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
       "      (tanh1): Tanh()\n",
       "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh2): Tanh()\n",
       "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh3): Tanh()\n",
       "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh4): Tanh()\n",
       "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh5): Tanh()\n",
       "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh6): Tanh()\n",
       "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "MNet.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ac6dc74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8857050538063049 0.01857674866914749 0.647739250998677\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_no_prob)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "107ef0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5803734064102173 0.009080263786017895 0.44426491547198615\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_prob_2) #optimal 0.5795\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9388bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5027751922607422 0.00522175757214427 0.4055202175025896\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_single_1)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "338b6f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5633723735809326 0.004841404501348734 0.4637625732947778\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_single_2)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e4e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
