{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ac754fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run NNs.ipynb\n",
    "%run Helpers.ipynb\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca124b90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38b59e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2 #number of bidders\n",
    "m = 2 #number of items\n",
    "\n",
    "nLayersAllocation   = 7 #minimum 2\n",
    "nLayersPayment      = 7 #minimum 2\n",
    "nLayersMisreport    = 7 #minimum 2\n",
    "widthAllocation     = 100\n",
    "widthPayment        = 100\n",
    "widthMisreport      = 100\n",
    "\n",
    "batch_size = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "021a570a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AllocationNet(\n",
      "  (mlp1): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (mlp2): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "PaymentNet(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "MisreportNetBNIC(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "ANet = AllocationNet(n,m,nLayersAllocation,widthAllocation)\n",
    "ANet.to(device)\n",
    "ANet.train()\n",
    "\n",
    "\n",
    "PNet = PaymentNet(n,m,nLayersPayment,widthPayment)\n",
    "PNet.to(device)\n",
    "PNet.train()\n",
    "\n",
    "MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "MNet.to(device)\n",
    "MNet.train()\n",
    "\n",
    "print(ANet)\n",
    "print(PNet)\n",
    "print(MNet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e10ea38c",
   "metadata": {},
   "source": [
    "# set up dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da032379",
   "metadata": {},
   "outputs": [],
   "source": [
    "prob_single_1 = torch.tensor([0.2,0.8]).to(device)\n",
    "prob_single_2 = torch.tensor([0.5,0.7]).to(device)\n",
    "prob_single_3 = torch.tensor([0.7,0.9]).to(device)\n",
    "no_prob_single = torch.tensor([1,1]).to(device)\n",
    "single_bidder_1 = torch.tensor([1,0]).to(device)\n",
    "single_bidder_2 = torch.tensor([0,1]).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cb66a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set_prob_1 = get_train_or_test_set(batch_size,n,m,prob_single_1)\n",
    "test_set_prob_1 = get_train_or_test_set(batch_size,n,m,prob_single_1)\n",
    "\n",
    "train_set_prob_2 = get_train_or_test_set(batch_size,n,m,prob_single_2)\n",
    "test_set_prob_2 = get_train_or_test_set(batch_size,n,m,prob_single_2)\n",
    "\n",
    "train_set_prob_3 = get_train_or_test_set(batch_size,n,m,prob_single_3)\n",
    "test_set_prob_3 = get_train_or_test_set(batch_size,n,m,prob_single_3)\n",
    "\n",
    "train_set_no_prob = get_train_or_test_set(batch_size,n,m,no_prob_single)\n",
    "test_set_no_prob = get_train_or_test_set(batch_size,n,m,no_prob_single)\n",
    "\n",
    "train_set_single_1 = get_train_or_test_set(batch_size,n,m,single_bidder_1)\n",
    "test_set_single_1 = get_train_or_test_set(batch_size,n,m,single_bidder_1)\n",
    "\n",
    "train_set_single_2 = get_train_or_test_set(batch_size,n,m,single_bidder_2)\n",
    "test_set_single_2 = get_train_or_test_set(batch_size,n,m,single_bidder_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "531955b6",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6435cdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rates = [0.0005] * 100 + [0.00005] * 100 + [0.000005] * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a70c85f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 revenue:  0.19551920890808105 rgt:  0.057282786816358566 loss:  -0.14555442333221436\n",
      "1 revenue:  0.20206886529922485 rgt:  0.07579467445611954 loss:  -0.09841766208410263\n",
      "2 revenue:  0.20486298203468323 rgt:  0.08636758476495743 loss:  -0.07236664742231369\n",
      "3 revenue:  0.20274685323238373 rgt:  0.08771194517612457 loss:  -0.0664001852273941\n",
      "4 revenue:  0.20059727132320404 rgt:  0.08549781143665314 loss:  -0.06998293101787567\n",
      "5 revenue:  0.1987055540084839 rgt:  0.08134649693965912 loss:  -0.07920439541339874\n",
      "6 revenue:  0.19703851640224457 rgt:  0.07561076432466507 loss:  -0.09330528229475021\n",
      "7 revenue:  0.19554834067821503 rgt:  0.06822425127029419 loss:  -0.11278641223907471\n",
      "8 revenue:  0.1941530406475067 rgt:  0.05893489718437195 loss:  -0.1389279067516327\n",
      "9 revenue:  0.19293388724327087 rgt:  0.04744058847427368 loss:  -0.1739930808544159\n",
      "10 revenue:  0.19151292741298676 rgt:  0.0247390978038311 loss:  -0.2555959224700928\n",
      "11 revenue:  0.19130448997020721 rgt:  0.017807522788643837 loss:  -0.2861310839653015\n",
      "12 revenue:  0.18988490104675293 rgt:  0.008123830892145634 loss:  -0.33750125765800476\n",
      "13 revenue:  0.18972648680210114 rgt:  0.002932438161224127 loss:  -0.3784908354282379\n",
      "14 revenue:  0.21231821179389954 rgt:  0.002080047968775034 loss:  -0.41309142112731934\n",
      "15 revenue:  0.23921281099319458 rgt:  0.0017451249295845628 loss:  -0.44557294249534607\n",
      "16 revenue:  0.26818886399269104 rgt:  0.0016872561536729336 loss:  -0.475104957818985\n",
      "17 revenue:  0.2979276180267334 rgt:  0.0016658883541822433 loss:  -0.5033451318740845\n",
      "18 revenue:  0.3309513032436371 rgt:  0.002013185527175665 loss:  -0.5284008979797363\n",
      "19 revenue:  0.35832124948501587 rgt:  0.0019559895154088736 loss:  -0.5524157881736755\n",
      "20 revenue:  0.3923993706703186 rgt:  0.00754101388156414 loss:  -0.5320373177528381\n",
      "21 revenue:  0.38969358801841736 rgt:  0.00590264517813921 loss:  -0.5415225625038147\n",
      "22 revenue:  0.40530893206596375 rgt:  0.004897874314337969 loss:  -0.5617554187774658\n",
      "23 revenue:  0.41434431076049805 rgt:  0.0031397428829222918 loss:  -0.5845218300819397\n",
      "24 revenue:  0.44405731558799744 rgt:  0.00632847985252738 loss:  -0.5804955363273621\n",
      "25 revenue:  0.44467028975486755 rgt:  0.004359423648566008 loss:  -0.5964499711990356\n",
      "26 revenue:  0.46149906516075134 rgt:  0.004667519126087427 loss:  -0.606349766254425\n",
      "27 revenue:  0.45668280124664307 rgt:  0.0030810879543423653 loss:  -0.6171936392784119\n",
      "28 revenue:  0.489385724067688 rgt:  0.004758482798933983 loss:  -0.6258202195167542\n",
      "29 revenue:  0.4865659177303314 rgt:  0.0036269216798245907 loss:  -0.633691132068634\n",
      "30 revenue:  0.5351361632347107 rgt:  0.005636346992105246 loss:  -0.6508174538612366\n",
      "31 revenue:  0.5160935521125793 rgt:  0.006816842593252659 loss:  -0.6290149688720703\n",
      "32 revenue:  0.5263649225234985 rgt:  0.007553257513791323 loss:  -0.6310468316078186\n",
      "33 revenue:  0.5412409901618958 rgt:  0.006323666777461767 loss:  -0.6498450636863708\n",
      "34 revenue:  0.6005781888961792 rgt:  0.013941246084868908 loss:  -0.6429551839828491\n",
      "35 revenue:  0.5727353096008301 rgt:  0.007986590266227722 loss:  -0.6594379544258118\n",
      "36 revenue:  0.6351777911186218 rgt:  0.016466932371258736 loss:  -0.6521896719932556\n",
      "37 revenue:  0.6001986265182495 rgt:  0.009798713028430939 loss:  -0.6659373044967651\n",
      "38 revenue:  0.6639900207519531 rgt:  0.018588002771139145 loss:  -0.6599296927452087\n",
      "39 revenue:  0.6208379864692688 rgt:  0.01139967329800129 loss:  -0.6697633862495422\n",
      "40 revenue:  0.6882275938987732 rgt:  0.015624630264937878 loss:  -0.6889713406562805\n",
      "41 revenue:  0.6275243163108826 rgt:  0.012630372308194637 loss:  -0.6671486496925354\n",
      "42 revenue:  0.6644124984741211 rgt:  0.01647326536476612 loss:  -0.6702932119369507\n",
      "43 revenue:  0.6582447290420532 rgt:  0.013818230479955673 loss:  -0.6799533367156982\n",
      "44 revenue:  0.7036314010620117 rgt:  0.02109677530825138 loss:  -0.6724830865859985\n",
      "45 revenue:  0.6444256901741028 rgt:  0.011469924822449684 loss:  -0.6841932535171509\n",
      "46 revenue:  0.721876323223114 rgt:  0.023061787709593773 loss:  -0.6747098565101624\n",
      "47 revenue:  0.6587515473365784 rgt:  0.01272608246654272 loss:  -0.6860987544059753\n",
      "48 revenue:  0.7363768815994263 rgt:  0.024714229628443718 loss:  -0.6762019991874695\n",
      "49 revenue:  0.670015811920166 rgt:  0.013889401219785213 loss:  -0.6868019104003906\n",
      "50 revenue:  0.7470406889915466 rgt:  0.020710207521915436 loss:  -0.6996942162513733\n",
      "51 revenue:  0.6704705357551575 rgt:  0.014949287287890911 loss:  -0.6816058158874512\n",
      "52 revenue:  0.7218506932258606 rgt:  0.021904675289988518 loss:  -0.6797107458114624\n",
      "53 revenue:  0.653251051902771 rgt:  0.01211489550769329 loss:  -0.6860565543174744\n",
      "54 revenue:  0.7344096302986145 rgt:  0.022342629730701447 loss:  -0.6851596236228943\n",
      "55 revenue:  0.6682264804840088 rgt:  0.013241306878626347 loss:  -0.6891387104988098\n",
      "56 revenue:  0.7488158941268921 rgt:  0.024287914857268333 loss:  -0.6852074861526489\n",
      "57 revenue:  0.6774324774742126 rgt:  0.014326765201985836 loss:  -0.6890413165092468\n",
      "58 revenue:  0.7569822072982788 rgt:  0.025394702330231667 loss:  -0.6852951645851135\n",
      "59 revenue:  0.6834922432899475 rgt:  0.015015346929430962 loss:  -0.6891830563545227\n",
      "60 revenue:  0.7620394229888916 rgt:  0.023039797320961952 loss:  -0.6981199383735657\n",
      "61 revenue:  0.6774032711982727 rgt:  0.015292015857994556 loss:  -0.6840919256210327\n",
      "62 revenue:  0.746529757976532 rgt:  0.02372153475880623 loss:  -0.686279833316803\n",
      "63 revenue:  0.6728190183639526 rgt:  0.012978198006749153 loss:  -0.6933549642562866\n",
      "64 revenue:  0.7545164823532104 rgt:  0.024335509166121483 loss:  -0.6882948875427246\n",
      "65 revenue:  0.6804625391960144 rgt:  0.013673732988536358 loss:  -0.6942927241325378\n",
      "66 revenue:  0.7615030407905579 rgt:  0.025225462391972542 loss:  -0.6885904669761658\n",
      "67 revenue:  0.686182975769043 rgt:  0.014343843795359135 loss:  -0.6942516565322876\n",
      "68 revenue:  0.7661951184272766 rgt:  0.025801481679081917 loss:  -0.6888956427574158\n",
      "69 revenue:  0.6901648044586182 rgt:  0.014780331403017044 loss:  -0.6944065093994141\n",
      "70 revenue:  0.7690334916114807 rgt:  0.02104250341653824 loss:  -0.7108424305915833\n",
      "71 revenue:  0.6816348433494568 rgt:  0.01451600156724453 loss:  -0.6906130909919739\n",
      "72 revenue:  0.7506261467933655 rgt:  0.0231450404971838 loss:  -0.6911066174507141\n",
      "73 revenue:  0.6751533150672913 rgt:  0.011708525009453297 loss:  -0.701762318611145\n",
      "74 revenue:  0.7584601640701294 rgt:  0.023678628727793694 loss:  -0.6933387517929077\n",
      "75 revenue:  0.6831854581832886 rgt:  0.012527805753052235 loss:  -0.7020944952964783\n",
      "76 revenue:  0.7648806571960449 rgt:  0.024550193920731544 loss:  -0.6933390498161316\n",
      "77 revenue:  0.6883283257484436 rgt:  0.01321379654109478 loss:  -0.7014901041984558\n",
      "78 revenue:  0.7686070203781128 rgt:  0.025024769827723503 loss:  -0.6934851408004761\n",
      "79 revenue:  0.691366970539093 rgt:  0.013562966138124466 loss:  -0.7014613151550293\n",
      "80 revenue:  0.7703711986541748 rgt:  0.019613036885857582 loss:  -0.718048095703125\n",
      "81 revenue:  0.6835945248603821 rgt:  0.013813531026244164 loss:  -0.695452868938446\n",
      "82 revenue:  0.7432178258895874 rgt:  0.021206121891736984 loss:  -0.6952711939811707\n",
      "83 revenue:  0.6645905375480652 rgt:  0.010097533464431763 loss:  -0.7046397924423218\n",
      "84 revenue:  0.7472652196884155 rgt:  0.020572451874613762 loss:  -0.700441300868988\n",
      "85 revenue:  0.6734420657157898 rgt:  0.010443948209285736 loss:  -0.7079951763153076\n",
      "86 revenue:  0.7559193968772888 rgt:  0.021513881161808968 loss:  -0.7012459635734558\n",
      "87 revenue:  0.6803675293922424 rgt:  0.01118732150644064 loss:  -0.7078860998153687\n",
      "88 revenue:  0.7611041069030762 rgt:  0.02208293043076992 loss:  -0.7017263174057007\n",
      "89 revenue:  0.6846650838851929 rgt:  0.011695727705955505 loss:  -0.7076019644737244\n",
      "90 revenue:  0.7638339400291443 rgt:  0.01607905514538288 loss:  -0.7310932874679565\n",
      "91 revenue:  0.6759892702102661 rgt:  0.01137237623333931 loss:  -0.7041715383529663\n",
      "92 revenue:  0.7333258986473083 rgt:  0.018337151035666466 loss:  -0.7025923728942871\n",
      "93 revenue:  0.655256450176239 rgt:  0.007441590540111065 loss:  -0.7157723307609558\n",
      "94 revenue:  0.7396690249443054 rgt:  0.01783759333193302 loss:  -0.7086448073387146\n",
      "95 revenue:  0.6698914766311646 rgt:  0.007981732487678528 loss:  -0.7211462259292603\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96 revenue:  0.75432950258255 rgt:  0.01956460252404213 loss:  -0.7090830206871033\n",
      "97 revenue:  0.680133044719696 rgt:  0.009015043266117573 loss:  -0.7207387089729309\n",
      "98 revenue:  0.7616126537322998 rgt:  0.02043144591152668 loss:  -0.7093338370323181\n",
      "99 revenue:  0.6855230927467346 rgt:  0.009688426740467548 loss:  -0.7198445200920105\n",
      "100 revenue:  0.7644681930541992 rgt:  0.016133293509483337 loss:  -0.731188178062439\n",
      "101 revenue:  0.7559671998023987 rgt:  0.0198065172880888 loss:  -0.7089213132858276\n",
      "102 revenue:  0.7468795776367188 rgt:  0.016669675707817078 loss:  -0.71844083070755\n",
      "103 revenue:  0.7385786771774292 rgt:  0.013362050987780094 loss:  -0.7304492592811584\n",
      "104 revenue:  0.7318181991577148 rgt:  0.011872408911585808 loss:  -0.7346303462982178\n",
      "105 revenue:  0.7266086935997009 rgt:  0.01089733187109232 loss:  -0.7371254563331604\n",
      "106 revenue:  0.7279793620109558 rgt:  0.0110876290127635 loss:  -0.7368311882019043\n",
      "107 revenue:  0.7289048433303833 rgt:  0.011393190361559391 loss:  -0.735626757144928\n",
      "108 revenue:  0.7258901596069336 rgt:  0.010906840674579144 loss:  -0.7366488575935364\n",
      "109 revenue:  0.7308597564697266 rgt:  0.011693882755935192 loss:  -0.7350708246231079\n",
      "110 revenue:  0.7256749272346497 rgt:  0.008609920740127563 loss:  -0.7504655122756958\n",
      "111 revenue:  0.7244297862052917 rgt:  0.012765688821673393 loss:  -0.7253829836845398\n",
      "112 revenue:  0.7143355011940002 rgt:  0.010458555072546005 loss:  -0.73245769739151\n",
      "113 revenue:  0.7115275859832764 rgt:  0.007927426137030125 loss:  -0.7465569376945496\n",
      "114 revenue:  0.715846061706543 rgt:  0.008190497756004333 loss:  -0.7473845481872559\n",
      "115 revenue:  0.7245767712593079 rgt:  0.009637845680117607 loss:  -0.7434099912643433\n",
      "116 revenue:  0.7219713926315308 rgt:  0.009163962677121162 loss:  -0.7447959780693054\n",
      "117 revenue:  0.7297368049621582 rgt:  0.010294799692928791 loss:  -0.7424877882003784\n",
      "118 revenue:  0.7267445921897888 rgt:  0.009795229882001877 loss:  -0.7437266111373901\n",
      "119 revenue:  0.7331979274749756 rgt:  0.010720815509557724 loss:  -0.7420071959495544\n",
      "120 revenue:  0.7335780262947083 rgt:  0.008954038843512535 loss:  -0.752911388874054\n",
      "121 revenue:  0.7308347821235657 rgt:  0.012835775502026081 loss:  -0.7287575006484985\n",
      "122 revenue:  0.7203627824783325 rgt:  0.010926268994808197 loss:  -0.733286440372467\n",
      "123 revenue:  0.7104883790016174 rgt:  0.007677657995373011 loss:  -0.7476043105125427\n",
      "124 revenue:  0.7130421996116638 rgt:  0.006918624043464661 loss:  -0.7543208599090576\n",
      "125 revenue:  0.7221412062644958 rgt:  0.008357660844922066 loss:  -0.7500105500221252\n",
      "126 revenue:  0.7219825983047485 rgt:  0.008366766385734081 loss:  -0.7498583197593689\n",
      "127 revenue:  0.7298523783683777 rgt:  0.00948968157172203 loss:  -0.7474088668823242\n",
      "128 revenue:  0.7260424494743347 rgt:  0.009008681401610374 loss:  -0.7481579780578613\n",
      "129 revenue:  0.7330687046051025 rgt:  0.010040782392024994 loss:  -0.745949387550354\n",
      "130 revenue:  0.7301390171051025 rgt:  0.008689130656421185 loss:  -0.7525766491889954\n",
      "131 revenue:  0.7275558710098267 rgt:  0.011855675838887691 loss:  -0.7322290539741516\n",
      "132 revenue:  0.7170289158821106 rgt:  0.01006532832980156 loss:  -0.7363837361335754\n",
      "133 revenue:  0.7090443968772888 rgt:  0.007092839572578669 loss:  -0.7507354021072388\n",
      "134 revenue:  0.7153880000114441 rgt:  0.006516295485198498 loss:  -0.7585657238960266\n",
      "135 revenue:  0.7194271683692932 rgt:  0.006981095764786005 loss:  -0.7576559782028198\n",
      "136 revenue:  0.728533148765564 rgt:  0.008547907695174217 loss:  -0.75253826379776\n",
      "137 revenue:  0.7261396050453186 rgt:  0.00826213974505663 loss:  -0.7529793381690979\n",
      "138 revenue:  0.7338769435882568 rgt:  0.009500212967395782 loss:  -0.7496964931488037\n",
      "139 revenue:  0.7294512391090393 rgt:  0.008972386829555035 loss:  -0.7503836154937744\n",
      "140 revenue:  0.7361959218978882 rgt:  0.008766180835664272 loss:  -0.7556241154670715\n",
      "141 revenue:  0.7272096872329712 rgt:  0.01126941293478012 loss:  -0.7353386282920837\n",
      "142 revenue:  0.717609703540802 rgt:  0.009084556251764297 loss:  -0.7427204847335815\n",
      "143 revenue:  0.7094833850860596 rgt:  0.005851156543940306 loss:  -0.7599638104438782\n",
      "144 revenue:  0.7177825570106506 rgt:  0.005861745215952396 loss:  -0.7647960782051086\n",
      "145 revenue:  0.7247085571289062 rgt:  0.007123023271560669 loss:  -0.7597765922546387\n",
      "146 revenue:  0.7330694198608398 rgt:  0.00851218681782484 loss:  -0.7554205656051636\n",
      "147 revenue:  0.7292750477790833 rgt:  0.007945298217236996 loss:  -0.7568938136100769\n",
      "148 revenue:  0.7367556691169739 rgt:  0.009099377319216728 loss:  -0.7538542747497559\n",
      "149 revenue:  0.7321723699569702 rgt:  0.008522043004631996 loss:  -0.7548333406448364\n",
      "150 revenue:  0.7386237382888794 rgt:  0.009028004482388496 loss:  -0.7553879022598267\n",
      "151 revenue:  0.7340413331985474 rgt:  0.011644480749964714 loss:  -0.7372076511383057\n",
      "152 revenue:  0.7232037782669067 rgt:  0.008918261155486107 loss:  -0.747058629989624\n",
      "153 revenue:  0.7142468690872192 rgt:  0.005578894633799791 loss:  -0.7648597955703735\n",
      "154 revenue:  0.7244819402694702 rgt:  0.006284540984779596 loss:  -0.7656048536300659\n",
      "155 revenue:  0.7278338670730591 rgt:  0.006953362841159105 loss:  -0.762791097164154\n",
      "156 revenue:  0.7363627552986145 rgt:  0.008276022039353848 loss:  -0.7588666677474976\n",
      "157 revenue:  0.7336810231208801 rgt:  0.00805697962641716 loss:  -0.7587336301803589\n",
      "158 revenue:  0.7408416271209717 rgt:  0.009197140112519264 loss:  -0.7556222677230835\n",
      "159 revenue:  0.7373940944671631 rgt:  0.008883967064321041 loss:  -0.7555773258209229\n",
      "160 revenue:  0.7434579730033875 rgt:  0.009776311926543713 loss:  -0.753588080406189\n",
      "161 revenue:  0.7346420884132385 rgt:  0.011269381269812584 loss:  -0.7396854758262634\n",
      "162 revenue:  0.7246209383010864 rgt:  0.00802371185272932 loss:  -0.7536473274230957\n",
      "163 revenue:  0.7189682722091675 rgt:  0.005256881006062031 loss:  -0.7701581120491028\n",
      "164 revenue:  0.7285047769546509 rgt:  0.0059691863134503365 loss:  -0.7702946662902832\n",
      "165 revenue:  0.7347521185874939 rgt:  0.0072407471016049385 loss:  -0.7648429870605469\n",
      "166 revenue:  0.741065502166748 rgt:  0.008323290385305882 loss:  -0.7612957954406738\n",
      "167 revenue:  0.73834228515625 rgt:  0.008070385083556175 loss:  -0.7613622546195984\n",
      "168 revenue:  0.7446205615997314 rgt:  0.009239006787538528 loss:  -0.757554829120636\n",
      "169 revenue:  0.7387322783470154 rgt:  0.00856687594205141 loss:  -0.7583706378936768\n",
      "170 revenue:  0.7437764406204224 rgt:  0.009078793227672577 loss:  -0.7580627799034119\n",
      "171 revenue:  0.737240195274353 rgt:  0.01090305671095848 loss:  -0.7433058023452759\n",
      "172 revenue:  0.7262235283851624 rgt:  0.00828204769641161 loss:  -0.7528992295265198\n",
      "173 revenue:  0.7200544476509094 rgt:  0.005452115088701248 loss:  -0.7692690491676331\n",
      "174 revenue:  0.7310248613357544 rgt:  0.0057394057512283325 loss:  -0.7735010385513306\n",
      "175 revenue:  0.7365226149559021 rgt:  0.006860264576971531 loss:  -0.7685214281082153\n",
      "176 revenue:  0.7453122138977051 rgt:  0.00819468218833208 loss:  -0.7645950317382812\n",
      "177 revenue:  0.743074893951416 rgt:  0.00808736588805914 loss:  -0.7640002965927124\n",
      "178 revenue:  0.7500648498535156 rgt:  0.00920133013278246 loss:  -0.760937511920929\n",
      "179 revenue:  0.7462211847305298 rgt:  0.008960016071796417 loss:  -0.7602230906486511\n",
      "180 revenue:  0.7524720430374146 rgt:  0.009687143377959728 loss:  -0.7593406438827515\n",
      "181 revenue:  0.7481400966644287 rgt:  0.011958062648773193 loss:  -0.7436395883560181\n",
      "182 revenue:  0.7372095584869385 rgt:  0.008766786195337772 loss:  -0.7562107443809509\n",
      "183 revenue:  0.731178879737854 rgt:  0.0057624271139502525 loss:  -0.7734163403511047\n",
      "184 revenue:  0.740604817867279 rgt:  0.006890531629323959 loss:  -0.7706837058067322\n",
      "185 revenue:  0.749709963798523 rgt:  0.00840654969215393 loss:  -0.7657636404037476\n",
      "186 revenue:  0.7456203699111938 rgt:  0.00796442199498415 loss:  -0.7662845849990845\n",
      "187 revenue:  0.7523936629295349 rgt:  0.008976590819656849 loss:  -0.7636843919754028\n",
      "188 revenue:  0.747692346572876 rgt:  0.008544627577066422 loss:  -0.7637098431587219\n",
      "189 revenue:  0.7545176148414612 rgt:  0.0098984744399786 loss:  -0.7592397332191467\n",
      "190 revenue:  0.7484911680221558 rgt:  0.009406859055161476 loss:  -0.7587575912475586\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "191 revenue:  0.7452443838119507 rgt:  0.011393573135137558 loss:  -0.745140790939331\n",
      "192 revenue:  0.739575982093811 rgt:  0.008760442957282066 loss:  -0.7576279044151306\n",
      "193 revenue:  0.733589231967926 rgt:  0.005903156008571386 loss:  -0.7737624645233154\n",
      "194 revenue:  0.74489825963974 rgt:  0.00685362983494997 loss:  -0.773434042930603\n",
      "195 revenue:  0.7505785226821899 rgt:  0.008330400101840496 loss:  -0.7667574286460876\n",
      "196 revenue:  0.7564989328384399 rgt:  0.009120224975049496 loss:  -0.7651488780975342\n",
      "197 revenue:  0.7525803446769714 rgt:  0.00868511013686657 loss:  -0.7656343579292297\n",
      "198 revenue:  0.7594854235649109 rgt:  0.010162229649722576 loss:  -0.7605140805244446\n",
      "199 revenue:  0.7522140145301819 rgt:  0.009248585440218449 loss:  -0.7618840932846069\n",
      "200 revenue:  0.760044515132904 rgt:  0.010189726017415524 loss:  -0.7606709599494934\n",
      "201 revenue:  0.7590843439102173 rgt:  0.012754182331264019 loss:  -0.7455655336380005\n",
      "202 revenue:  0.7579980492591858 rgt:  0.01077695470303297 loss:  -0.756041407585144\n",
      "203 revenue:  0.7570351958274841 rgt:  0.008994328789412975 loss:  -0.7662444710731506\n",
      "204 revenue:  0.7565410733222961 rgt:  0.008527473546564579 loss:  -0.768921434879303\n",
      "205 revenue:  0.7561662197113037 rgt:  0.008481343276798725 loss:  -0.7690021395683289\n",
      "206 revenue:  0.7558823227882385 rgt:  0.008833813481032848 loss:  -0.766592264175415\n",
      "207 revenue:  0.7556656002998352 rgt:  0.009454101324081421 loss:  -0.7626035213470459\n",
      "208 revenue:  0.7555736303329468 rgt:  0.009858590550720692 loss:  -0.7600879669189453\n",
      "209 revenue:  0.7556190490722656 rgt:  0.010508784092962742 loss:  -0.7562419772148132\n",
      "210 revenue:  0.755174458026886 rgt:  0.010125059634447098 loss:  -0.7582589387893677\n",
      "211 revenue:  0.7546489834785461 rgt:  0.012221304699778557 loss:  -0.7459336519241333\n",
      "212 revenue:  0.7539904117584229 rgt:  0.009883304126560688 loss:  -0.7590277194976807\n",
      "213 revenue:  0.7531347870826721 rgt:  0.008055034093558788 loss:  -0.7700279951095581\n",
      "214 revenue:  0.7531236410140991 rgt:  0.008019655011594296 loss:  -0.770254373550415\n",
      "215 revenue:  0.7533367276191711 rgt:  0.008165089413523674 loss:  -0.7694233059883118\n",
      "216 revenue:  0.7538662552833557 rgt:  0.008759898133575916 loss:  -0.7659000158309937\n",
      "217 revenue:  0.7546869516372681 rgt:  0.009353709407150745 loss:  -0.7626584768295288\n",
      "218 revenue:  0.7555259466171265 rgt:  0.009565805085003376 loss:  -0.7618387341499329\n",
      "219 revenue:  0.7562882900238037 rgt:  0.010144379921257496 loss:  -0.758784294128418\n",
      "220 revenue:  0.7568743824958801 rgt:  0.00976274348795414 loss:  -0.7614154815673828\n",
      "221 revenue:  0.7565625905990601 rgt:  0.012243503704667091 loss:  -0.7469117641448975\n",
      "222 revenue:  0.7563387751579285 rgt:  0.010477087460458279 loss:  -0.7568422555923462\n",
      "223 revenue:  0.7553714513778687 rgt:  0.008313030935823917 loss:  -0.7696316838264465\n",
      "224 revenue:  0.7550527453422546 rgt:  0.008173744194209576 loss:  -0.770354688167572\n",
      "225 revenue:  0.7549974322319031 rgt:  0.008288990706205368 loss:  -0.7695725560188293\n",
      "226 revenue:  0.7555655241012573 rgt:  0.008879531174898148 loss:  -0.7661214470863342\n",
      "227 revenue:  0.7563047409057617 rgt:  0.009672019630670547 loss:  -0.761638879776001\n",
      "228 revenue:  0.7569195032119751 rgt:  0.009776392951607704 loss:  -0.7613586783409119\n",
      "229 revenue:  0.7575308680534363 rgt:  0.010377802886068821 loss:  -0.7581127882003784\n",
      "230 revenue:  0.7578354477882385 rgt:  0.010061347857117653 loss:  -0.7601694464683533\n",
      "231 revenue:  0.7572197318077087 rgt:  0.012502594850957394 loss:  -0.7458657622337341\n",
      "232 revenue:  0.7562623023986816 rgt:  0.010623669251799583 loss:  -0.7559381127357483\n",
      "233 revenue:  0.7552890181541443 rgt:  0.0086943618953228 loss:  -0.7671353220939636\n",
      "234 revenue:  0.7548698782920837 rgt:  0.008255169726908207 loss:  -0.7697188258171082\n",
      "235 revenue:  0.7547567486763 rgt:  0.008198986761271954 loss:  -0.7700196504592896\n",
      "236 revenue:  0.7552669048309326 rgt:  0.008518537506461143 loss:  -0.7682459950447083\n",
      "237 revenue:  0.7561469674110413 rgt:  0.009381270967423916 loss:  -0.7633283734321594\n",
      "238 revenue:  0.756895899772644 rgt:  0.009598484262824059 loss:  -0.7624268531799316\n",
      "239 revenue:  0.7576298713684082 rgt:  0.00997651182115078 loss:  -0.7605599164962769\n",
      "240 revenue:  0.7581853270530701 rgt:  0.00969141535460949 loss:  -0.7626015543937683\n",
      "241 revenue:  0.7577938437461853 rgt:  0.012519107200205326 loss:  -0.7461052536964417\n",
      "242 revenue:  0.7569224238395691 rgt:  0.01045896578580141 loss:  -0.75728440284729\n",
      "243 revenue:  0.7559599280357361 rgt:  0.008537174202501774 loss:  -0.7685250639915466\n",
      "244 revenue:  0.7555903196334839 rgt:  0.008289732038974762 loss:  -0.7699087858200073\n",
      "245 revenue:  0.7554556727409363 rgt:  0.008418981917202473 loss:  -0.7689950466156006\n",
      "246 revenue:  0.7558943033218384 rgt:  0.008938266895711422 loss:  -0.7659407258033752\n",
      "247 revenue:  0.7567298412322998 rgt:  0.009168495424091816 loss:  -0.7649809718132019\n",
      "248 revenue:  0.7576156854629517 rgt:  0.00946106668561697 loss:  -0.7636817097663879\n",
      "249 revenue:  0.7584094405174255 rgt:  0.010048766620457172 loss:  -0.7605743408203125\n",
      "250 revenue:  0.7590810060501099 rgt:  0.009480632841587067 loss:  -0.7644028663635254\n",
      "251 revenue:  0.7587338089942932 rgt:  0.012575194239616394 loss:  -0.7463384866714478\n",
      "252 revenue:  0.7581756711006165 rgt:  0.010830096900463104 loss:  -0.7558345794677734\n",
      "253 revenue:  0.7571981549263 rgt:  0.008734275586903095 loss:  -0.7679792642593384\n",
      "254 revenue:  0.7567271590232849 rgt:  0.008482794277369976 loss:  -0.769315242767334\n",
      "255 revenue:  0.7564377784729004 rgt:  0.008480624295771122 loss:  -0.7691629528999329\n",
      "256 revenue:  0.7567601799964905 rgt:  0.008885708637535572 loss:  -0.7667694687843323\n",
      "257 revenue:  0.7574295997619629 rgt:  0.009523960761725903 loss:  -0.7631891369819641\n",
      "258 revenue:  0.7580671310424805 rgt:  0.009861942380666733 loss:  -0.7615007758140564\n",
      "259 revenue:  0.758639931678772 rgt:  0.010709877125918865 loss:  -0.7568005323410034\n",
      "260 revenue:  0.7586092948913574 rgt:  0.010762492194771767 loss:  -0.7564764618873596\n",
      "261 revenue:  0.7581665515899658 rgt:  0.012516891583800316 loss:  -0.7463313937187195\n",
      "262 revenue:  0.757102906703949 rgt:  0.010323727503418922 loss:  -0.7581866979598999\n",
      "263 revenue:  0.7561815977096558 rgt:  0.008448668755590916 loss:  -0.769221305847168\n",
      "264 revenue:  0.7559288144111633 rgt:  0.008223576471209526 loss:  -0.7705336809158325\n",
      "265 revenue:  0.7560308575630188 rgt:  0.008265107870101929 loss:  -0.7703220844268799\n",
      "266 revenue:  0.7566677331924438 rgt:  0.008809521794319153 loss:  -0.7671974301338196\n",
      "267 revenue:  0.7574664950370789 rgt:  0.009574692696332932 loss:  -0.7628999948501587\n",
      "268 revenue:  0.7580476403236389 rgt:  0.009755001403391361 loss:  -0.7621364593505859\n",
      "269 revenue:  0.7586627006530762 rgt:  0.010287117213010788 loss:  -0.7592994570732117\n",
      "270 revenue:  0.7590798735618591 rgt:  0.010541871190071106 loss:  -0.7580359578132629\n",
      "271 revenue:  0.7583809494972229 rgt:  0.012585428543388844 loss:  -0.7460801601409912\n",
      "272 revenue:  0.7574040293693542 rgt:  0.01034220028668642 loss:  -0.7582504153251648\n",
      "273 revenue:  0.7564693093299866 rgt:  0.008371992036700249 loss:  -0.7698813676834106\n",
      "274 revenue:  0.7561970353126526 rgt:  0.008241262286901474 loss:  -0.7705727815628052\n",
      "275 revenue:  0.7561590075492859 rgt:  0.00824451632797718 loss:  -0.7705297470092773\n",
      "276 revenue:  0.7566906213760376 rgt:  0.00860597938299179 loss:  -0.768504798412323\n",
      "277 revenue:  0.7574079632759094 rgt:  0.009149202145636082 loss:  -0.7654907703399658\n",
      "278 revenue:  0.7582814693450928 rgt:  0.009557271376252174 loss:  -0.7634745836257935\n",
      "279 revenue:  0.7591380476951599 rgt:  0.010027797892689705 loss:  -0.7611181735992432\n",
      "280 revenue:  0.7600288391113281 rgt:  0.010217064991593361 loss:  -0.7604992985725403\n",
      "281 revenue:  0.7596421837806702 rgt:  0.012618432752788067 loss:  -0.7466239929199219\n",
      "282 revenue:  0.7589535117149353 rgt:  0.010486624203622341 loss:  -0.7582881450653076\n",
      "283 revenue:  0.757989764213562 rgt:  0.008578723296523094 loss:  -0.7694255113601685\n",
      "284 revenue:  0.7575756907463074 rgt:  0.008389683440327644 loss:  -0.7704028487205505\n",
      "285 revenue:  0.7574742436408997 rgt:  0.008473198860883713 loss:  -0.7698063254356384\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "286 revenue:  0.758033037185669 rgt:  0.00894187018275261 loss:  -0.7671471238136292\n",
      "287 revenue:  0.7588338851928711 rgt:  0.009810665622353554 loss:  -0.7622507810592651\n",
      "288 revenue:  0.7591957449913025 rgt:  0.010107862763106823 loss:  -0.7606722116470337\n",
      "289 revenue:  0.7595864534378052 rgt:  0.010646244511008263 loss:  -0.7577152848243713\n",
      "290 revenue:  0.7597152590751648 rgt:  0.009774086996912956 loss:  -0.7629779577255249\n",
      "291 revenue:  0.7593008875846863 rgt:  0.012640835717320442 loss:  -0.7463060617446899\n",
      "292 revenue:  0.7588286995887756 rgt:  0.010726599022746086 loss:  -0.7568114995956421\n",
      "293 revenue:  0.7578632831573486 rgt:  0.008508282713592052 loss:  -0.7698043584823608\n",
      "294 revenue:  0.7575498223304749 rgt:  0.008391637355089188 loss:  -0.7703753709793091\n",
      "295 revenue:  0.7575072050094604 rgt:  0.008467313833534718 loss:  -0.7698631286621094\n",
      "296 revenue:  0.7581851482391357 rgt:  0.008980832993984222 loss:  -0.7669897675514221\n",
      "297 revenue:  0.758976936340332 rgt:  0.009756380692124367 loss:  -0.7626615762710571\n",
      "298 revenue:  0.7594936490058899 rgt:  0.009791262447834015 loss:  -0.7627468109130859\n",
      "299 revenue:  0.7600890398025513 rgt:  0.010413452051579952 loss:  -0.7593706846237183\n"
     ]
    }
   ],
   "source": [
    "training_set = train_set_prob_3\n",
    "MISREPORT_RESET_EVERY = 10  # re-create the misreport network every N outer iterations\n",
    "\n",
    "# Create the auctioneer optimizer ONCE so AdamW keeps its moment estimates across\n",
    "# iterations. Re-creating it inside the loop discarded that state, so every update was\n",
    "# effectively a fresh first Adam step. The scheduled learning rate is applied per\n",
    "# iteration through param_groups instead (the default lr here is overwritten before use).\n",
    "opt_aucter = torch.optim.AdamW(list(ANet.parameters()) + list(PNet.parameters()))\n",
    "\n",
    "for i, lr in enumerate(learning_rates):\n",
    "    for group in opt_aucter.param_groups:\n",
    "        group['lr'] = float(lr)\n",
    "\n",
    "    # Periodically restart the misreport search from a newly initialised network.\n",
    "    if i % MISREPORT_RESET_EVERY == 0:\n",
    "        MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "        MNet.to(device)\n",
    "\n",
    "    # NOTE(review): the positional arguments 100, 300, 0.00001 are interpreted by\n",
    "    # find_max_regret_misreport_BNIC (loaded via %run above) - confirm their meaning there.\n",
    "    misreports = find_max_regret_misreport_BNIC(ANet,PNet,MNet,training_set,100, 300,0.00001)\n",
    "\n",
    "    # Zero gradients right before building the auctioneer loss; this also drops any\n",
    "    # gradients the misreport search may have left on ANet/PNet.\n",
    "    opt_aucter.zero_grad()\n",
    "\n",
    "    u_orig = truthful_utility_calculation(ANet,PNet,training_set)\n",
    "    u_new = misreport_utility_calculation_BNIC(ANet,PNet, training_set, misreports)\n",
    "\n",
    "    # rgt: mean utility difference (misreport minus truthful); rev: mean truthful revenue.\n",
    "    rgt = torch.mean(u_new - u_orig)\n",
    "    rev = torch.mean(truthful_revenue_calculation(ANet,PNet,training_set))\n",
    "\n",
    "    # update NN\n",
    "    # rgt is a mean of differences and can dip below zero; clamp it inside the square\n",
    "    # root so the loss (and hence the weights) cannot become NaN. Unchanged for rgt >= 0.\n",
    "    loss = -1 * (torch.sqrt(rev + 1e-7) - torch.sqrt(torch.clamp(rgt, min=0.0) + 1e-7)) + rgt\n",
    "\n",
    "    loss.backward()\n",
    "    opt_aucter.step()\n",
    "\n",
    "    print(i, 'revenue: ',rev.item(), 'rgt: ', rgt.item(), 'loss: ', loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0858fec",
   "metadata": {},
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4847750b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MisreportNetBNIC(\n",
       "  (mlp): MLP(\n",
       "    (model): Sequential(\n",
       "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
       "      (tanh1): Tanh()\n",
       "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh2): Tanh()\n",
       "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh3): Tanh()\n",
       "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh4): Tanh()\n",
       "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh5): Tanh()\n",
       "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh6): Tanh()\n",
       "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Newly constructed misreport network used by the test(...) calls below.\n",
    "# Module.to() moves parameters in place and returns the module itself.\n",
    "MNet = MisreportNetBNIC(n, m, nLayersMisreport, widthMisreport).to(device)\n",
    "MNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ac6dc74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9272199869155884 0.017088692635297775 0.6925550113614808\n"
     ]
    }
   ],
   "source": [
    "# test() returns (rev, rgt) - presumably revenue and regret on the given set.\n",
    "# Print both plus the combined score (sqrt(rev) - sqrt(rgt))^2.\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_no_prob)\n",
    "print(rev, rgt, (np.sqrt(rev) - np.sqrt(rgt)) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "107ef0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7646074295043945 0.013217831030488014 0.5767635604739821\n"
     ]
    }
   ],
   "source": [
    "# test() returns (rev, rgt) - presumably revenue and regret on the given set.\n",
    "# Reference optimum noted by the author: 0.7351 (presumably revenue) - TODO cite source.\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_prob_3)\n",
    "print(rev, rgt, (np.sqrt(rev) - np.sqrt(rgt)) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9388bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.47261863946914673 0.00021396297961473465 0.4527206266808484\n"
     ]
    }
   ],
   "source": [
    "# test() returns (rev, rgt) - presumably revenue and regret on the given set.\n",
    "# Print both plus the combined score (sqrt(rev) - sqrt(rgt))^2.\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_single_1)\n",
    "print(rev, rgt, (np.sqrt(rev) - np.sqrt(rgt)) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "338b6f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5347976684570312 0.002219789894297719 0.4681077069387778\n"
     ]
    }
   ],
   "source": [
    "# test() returns (rev, rgt) - presumably revenue and regret on the given set.\n",
    "# Print both plus the combined score (sqrt(rev) - sqrt(rgt))^2.\n",
    "rev, rgt = test(ANet, PNet, MNet, test_set_single_2)\n",
    "print(rev, rgt, (np.sqrt(rev) - np.sqrt(rgt)) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e4e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
