{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ac754fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The %run notebooks provide torch, np, device and the network/helper code used below.\n",
    "%run NNs.ipynb\n",
    "%run Helpers.ipynb\n",
    "\n",
    "# One seed for every RNG; torch.manual_seed seeds the CPU and all CUDA devices.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca124b90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the device every network and probability tensor below is moved to.\n",
    "print(str(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38b59e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2  # number of bidders\n",
    "m = 2  # number of items\n",
    "\n",
    "# Depth = number of Linear layers in each MLP (7 gives fc1..fc7); width = hidden-layer size.\n",
    "nLayersAllocation   = 7  # minimum 2\n",
    "nLayersPayment      = 7  # minimum 2\n",
    "nLayersMisreport    = 7  # minimum 2\n",
    "widthAllocation     = 100\n",
    "widthPayment        = 100\n",
    "widthMisreport      = 100\n",
    "\n",
    "# Enforce the 'minimum 2' requirement noted above so a bad value is caught in this cell.\n",
    "assert min(nLayersAllocation, nLayersPayment, nLayersMisreport) >= 2, 'every network needs at least 2 layers'\n",
    "\n",
    "# Passed as the first argument of get_train_or_test_set (presumably the number of samples per set).\n",
    "batch_size = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "021a570a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AllocationNet(\n",
      "  (mlp1): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      "  (mlp2): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "PaymentNet(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=2, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "MisreportNetBNIC(\n",
      "  (mlp): MLP(\n",
      "    (model): Sequential(\n",
      "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
      "      (tanh1): Tanh()\n",
      "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh2): Tanh()\n",
      "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh3): Tanh()\n",
      "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh4): Tanh()\n",
      "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh5): Tanh()\n",
      "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
      "      (tanh6): Tanh()\n",
      "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Build the allocation, payment and misreport networks (in this order), move each to `device`\n",
    "# and put it in training mode. Chaining relies on nn.Module.to()/.train() returning the module.\n",
    "ANet = AllocationNet(n, m, nLayersAllocation, widthAllocation).to(device).train()\n",
    "PNet = PaymentNet(n, m, nLayersPayment, widthPayment).to(device).train()\n",
    "MNet = MisreportNetBNIC(n, m, nLayersMisreport, widthMisreport).to(device).train()\n",
    "\n",
    "print(ANet, PNet, MNet, sep=\"\\n\")"
   ]
  },
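  {
   "cell_type": "markdown",
   "id": "e10ea38c",
   "metadata": {},
   "source": [
    "# Set up datasets\n",
    "\n",
    "Each probability vector below defines one data distribution; for each one, a separate train set and test set are drawn with `get_train_or_test_set(batch_size, n, m, probs)`."
   ]
  },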
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da032379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-element probability vectors, each defining one data distribution for get_train_or_test_set.\n",
    "# Float literals give all four vectors the same (default floating) dtype, and device= allocates\n",
    "# them directly on the target device instead of building on the CPU and copying.\n",
    "no_prob_single = torch.tensor([1.0, 1.0], device=device)  # all ones\n",
    "prob_single_1 = torch.tensor([0.2, 0.8], device=device)\n",
    "prob_single_2 = torch.tensor([0.5, 0.7], device=device)\n",
    "prob_single_3 = torch.tensor([0.7, 0.9], device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cb66a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _train_test_pair(probs):\n",
    "    \"\"\"Return (train_set, test_set), both built by get_train_or_test_set for one probability vector.\"\"\"\n",
    "    train_set = get_train_or_test_set(batch_size, n, m, probs)  # train is always drawn before test\n",
    "    test_set = get_train_or_test_set(batch_size, n, m, probs)\n",
    "    return train_set, test_set\n",
    "\n",
    "\n",
    "train_set_no_prob, test_set_no_prob = _train_test_pair(no_prob_single)\n",
    "train_set_prob_1, test_set_prob_1 = _train_test_pair(prob_single_1)\n",
    "train_set_prob_2, test_set_prob_2 = _train_test_pair(prob_single_2)\n",
    "train_set_prob_3, test_set_prob_3 = _train_test_pair(prob_single_3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "531955b6",
   "metadata": {},
   "source": [
    "# Train\n",
    "\n",
    "The learning-rate schedule holds each rate for 100 entries: 5e-4, then 5e-5, then 5e-6."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6435cdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step schedule: 100 entries at each rate, dropping 10x per stage (5e-4 -> 5e-5 -> 5e-6).\n",
    "learning_rates = [lr for lr in (5e-4, 5e-5, 5e-6) for _ in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a70c85f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 revenue:  0.24349819123744965 rgt:  0.07162182778120041 loss:  -0.1542108952999115\n",
      "1 revenue:  0.2510206997394562 rgt:  0.07214757800102234 loss:  -0.16026899218559265\n",
      "2 revenue:  0.25901028513908386 rgt:  0.07432665675878525 loss:  -0.16197463870048523\n",
      "3 revenue:  0.26741668581962585 rgt:  0.07559631019830704 loss:  -0.16657927632331848\n",
      "4 revenue:  0.276561975479126 rgt:  0.07741955667734146 loss:  -0.1702282428741455\n",
      "5 revenue:  0.2860710918903351 rgt:  0.07825466990470886 loss:  -0.1768609881401062\n",
      "6 revenue:  0.2955203652381897 rgt:  0.07907644659280777 loss:  -0.18333595991134644\n",
      "7 revenue:  0.3032947778701782 rgt:  0.06863927841186523 loss:  -0.22009167075157166\n",
      "8 revenue:  0.3151087760925293 rgt:  0.072422556579113 loss:  -0.219808429479599\n",
      "9 revenue:  0.32129353284835815 rgt:  0.06620771437883377 loss:  -0.24331116676330566\n",
      "10 revenue:  0.32683253288269043 rgt:  0.05721049755811691 loss:  -0.2752948999404907\n",
      "11 revenue:  0.33195003867149353 rgt:  0.047650836408138275 loss:  -0.31020936369895935\n",
      "12 revenue:  0.3357914984226227 rgt:  0.03616747260093689 loss:  -0.3531300723552704\n",
      "13 revenue:  0.33725276589393616 rgt:  0.021812118589878082 loss:  -0.41123300790786743\n",
      "14 revenue:  0.3384305536746979 rgt:  0.010447677224874496 loss:  -0.4690859019756317\n",
      "15 revenue:  0.35534900426864624 rgt:  0.005671708378940821 loss:  -0.5151285529136658\n",
      "16 revenue:  0.3869420886039734 rgt:  0.004184484016150236 loss:  -0.5531738996505737\n",
      "17 revenue:  0.41386571526527405 rgt:  0.0030328503344208 loss:  -0.5852189064025879\n",
      "18 revenue:  0.44524267315864563 rgt:  0.0031597272027283907 loss:  -0.6078930497169495\n",
      "19 revenue:  0.46203920245170593 rgt:  0.002713222522288561 loss:  -0.6249319314956665\n",
      "20 revenue:  0.48706093430519104 rgt:  0.0033270022831857204 loss:  -0.6368895173072815\n",
      "21 revenue:  0.49512410163879395 rgt:  0.0024545311462134123 loss:  -0.6516518592834473\n",
      "22 revenue:  0.5156624913215637 rgt:  0.0025228706654161215 loss:  -0.6653444170951843\n",
      "23 revenue:  0.5139698386192322 rgt:  0.0017182078445330262 loss:  -0.6737462878227234\n",
      "24 revenue:  0.5333675146102905 rgt:  0.0012467509368434548 loss:  -0.6937627196311951\n",
      "25 revenue:  0.52876877784729 rgt:  0.001091402256861329 loss:  -0.6930356621742249\n",
      "26 revenue:  0.5518118739128113 rgt:  0.0007257079123519361 loss:  -0.7151739597320557\n",
      "27 revenue:  0.5490237474441528 rgt:  0.0005993569502606988 loss:  -0.7158782482147217\n",
      "28 revenue:  0.5925602316856384 rgt:  0.0012691570445895195 loss:  -0.732883632183075\n",
      "29 revenue:  0.5727176666259766 rgt:  0.0008271967526525259 loss:  -0.7271912693977356\n",
      "30 revenue:  0.6112232804298401 rgt:  0.0006058239378035069 loss:  -0.7565864324569702\n",
      "31 revenue:  0.6006902456283569 rgt:  0.0005438735825009644 loss:  -0.751175045967102\n",
      "32 revenue:  0.663500964641571 rgt:  0.002478185575455427 loss:  -0.7622951865196228\n",
      "33 revenue:  0.6260613799095154 rgt:  0.0008868279401212931 loss:  -0.7605722546577454\n",
      "34 revenue:  0.6856129169464111 rgt:  0.002850677352398634 loss:  -0.7717742323875427\n",
      "35 revenue:  0.6403995156288147 rgt:  0.0011105377925559878 loss:  -0.7658129930496216\n",
      "36 revenue:  0.7023227214813232 rgt:  0.002770062070339918 loss:  -0.7826446294784546\n",
      "37 revenue:  0.6516007781028748 rgt:  0.0010433035204187036 loss:  -0.7738729119300842\n",
      "38 revenue:  0.7195327281951904 rgt:  0.0027821639087051153 loss:  -0.7927234768867493\n",
      "39 revenue:  0.6630535125732422 rgt:  0.001297823153436184 loss:  -0.7769564986228943\n",
      "40 revenue:  0.7330247759819031 rgt:  0.003941833972930908 loss:  -0.7894421219825745\n",
      "41 revenue:  0.6714345216751099 rgt:  0.0011735205771401525 loss:  -0.7839795351028442\n",
      "42 revenue:  0.7445150017738342 rgt:  0.0033733032178133726 loss:  -0.801398515701294\n",
      "43 revenue:  0.6811421513557434 rgt:  0.0016442009946331382 loss:  -0.783119261264801\n",
      "44 revenue:  0.751347541809082 rgt:  0.004548467230051756 loss:  -0.7948115468025208\n",
      "45 revenue:  0.6823112368583679 rgt:  0.0014306033262982965 loss:  -0.7867661118507385\n",
      "46 revenue:  0.7561734318733215 rgt:  0.0031884026248008013 loss:  -0.8099271655082703\n",
      "47 revenue:  0.6869518160820007 rgt:  0.0010798816801980138 loss:  -0.7948827147483826\n",
      "48 revenue:  0.7653538584709167 rgt:  0.0030965260230004787 loss:  -0.8161012530326843\n",
      "49 revenue:  0.6946727633476257 rgt:  0.0011951586930081248 loss:  -0.7977027297019958\n",
      "50 revenue:  0.7741202116012573 rgt:  0.004407862666994333 loss:  -0.8090407252311707\n",
      "51 revenue:  0.6987062692642212 rgt:  0.0020324348006397486 loss:  -0.7887704968452454\n",
      "52 revenue:  0.7747227549552917 rgt:  0.0035796493757516146 loss:  -0.8167727589607239\n",
      "53 revenue:  0.7003125548362732 rgt:  0.0017716598231345415 loss:  -0.7929829359054565\n",
      "54 revenue:  0.7779020667076111 rgt:  0.002861110959202051 loss:  -0.8256362676620483\n",
      "55 revenue:  0.7016395330429077 rgt:  0.001877712202258408 loss:  -0.7924278974533081\n",
      "56 revenue:  0.7799755930900574 rgt:  0.0038529178127646446 loss:  -0.8172367215156555\n",
      "57 revenue:  0.7017846703529358 rgt:  0.0011186827905476093 loss:  -0.8031590580940247\n",
      "58 revenue:  0.7830033302307129 rgt:  0.00470936531201005 loss:  -0.811539888381958\n",
      "59 revenue:  0.7048267126083374 rgt:  0.0016364993061870337 loss:  -0.7974482178688049\n",
      "60 revenue:  0.783568799495697 rgt:  0.004339708015322685 loss:  -0.8149774074554443\n",
      "61 revenue:  0.704857587814331 rgt:  0.0010699530830606818 loss:  -0.8057764172554016\n",
      "62 revenue:  0.7857265472412109 rgt:  0.002815085928887129 loss:  -0.8305388689041138\n",
      "63 revenue:  0.7077311277389526 rgt:  0.0017060916870832443 loss:  -0.7982555031776428\n",
      "64 revenue:  0.7851700782775879 rgt:  0.0034028925001621246 loss:  -0.8243602514266968\n",
      "65 revenue:  0.7062182426452637 rgt:  0.0010137382196262479 loss:  -0.807513415813446\n",
      "66 revenue:  0.7870467901229858 rgt:  0.003013911657035351 loss:  -0.8292427659034729\n",
      "67 revenue:  0.7077016234397888 rgt:  0.0010230033658444881 loss:  -0.8082411289215088\n",
      "68 revenue:  0.7880880832672119 rgt:  0.0026636768598109484 loss:  -0.8334678411483765\n",
      "69 revenue:  0.709453284740448 rgt:  0.0010162914404645562 loss:  -0.8093933463096619\n",
      "70 revenue:  0.7890880703926086 rgt:  0.0018003495642915368 loss:  -0.8440743088722229\n",
      "71 revenue:  0.7115821242332458 rgt:  0.0009844485903158784 loss:  -0.8111914396286011\n",
      "72 revenue:  0.791145384311676 rgt:  0.002471635350957513 loss:  -0.8372754454612732\n",
      "73 revenue:  0.7123565673828125 rgt:  0.001562528545036912 loss:  -0.8029196262359619\n",
      "74 revenue:  0.7884451746940613 rgt:  0.0013355943374335766 loss:  -0.8500616550445557\n",
      "75 revenue:  0.7121371030807495 rgt:  0.0008012905600480735 loss:  -0.8147720694541931\n",
      "76 revenue:  0.7901164889335632 rgt:  0.0028447473887354136 loss:  -0.832703173160553\n",
      "77 revenue:  0.7132328152656555 rgt:  0.0007270100759342313 loss:  -0.8168392181396484\n",
      "78 revenue:  0.7935732007026672 rgt:  0.0021177036687731743 loss:  -0.8426899909973145\n",
      "79 revenue:  0.7147578597068787 rgt:  0.0014930546749383211 loss:  -0.8052992224693298\n",
      "80 revenue:  0.790008544921875 rgt:  0.0030794665217399597 loss:  -0.8302510380744934\n",
      "81 revenue:  0.7109062671661377 rgt:  0.0006064824992790818 loss:  -0.8179172873497009\n",
      "82 revenue:  0.7909643650054932 rgt:  0.0020525981672108173 loss:  -0.8420025706291199\n",
      "83 revenue:  0.7133099436759949 rgt:  0.0010642894776538014 loss:  -0.8108875751495361\n",
      "84 revenue:  0.7887680530548096 rgt:  0.0019224613206461072 loss:  -0.8423566818237305\n",
      "85 revenue:  0.7127586007118225 rgt:  0.0009139858884736896 loss:  -0.8131025433540344\n",
      "86 revenue:  0.7868064641952515 rgt:  0.001075620879419148 loss:  -0.8531473875045776\n",
      "87 revenue:  0.7125608921051025 rgt:  0.001026643207296729 loss:  -0.811063826084137\n",
      "88 revenue:  0.7861495614051819 rgt:  0.0011786067625507712 loss:  -0.8511399030685425\n",
      "89 revenue:  0.7117499113082886 rgt:  0.0005583780002780259 loss:  -0.8194622993469238\n",
      "90 revenue:  0.7896108627319336 rgt:  0.0013773883692920208 loss:  -0.8501086831092834\n",
      "91 revenue:  0.7135279774665833 rgt:  0.0003714261983986944 loss:  -0.8250595331192017\n",
      "92 revenue:  0.799095869064331 rgt:  0.0017771907150745392 loss:  -0.8499865531921387\n",
      "93 revenue:  0.7223989963531494 rgt:  0.0007347511127591133 loss:  -0.8220977783203125\n",
      "94 revenue:  0.7954466342926025 rgt:  0.0012292113387957215 loss:  -0.8555874228477478\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "95 revenue:  0.722159206867218 rgt:  0.0009625991806387901 loss:  -0.8178095817565918\n",
      "96 revenue:  0.7915706038475037 rgt:  0.0010106468107551336 loss:  -0.8568997383117676\n",
      "97 revenue:  0.7201322317123413 rgt:  0.0005581857403740287 loss:  -0.8244198560714722\n",
      "98 revenue:  0.7931422591209412 rgt:  0.0017173894448205829 loss:  -0.8474254012107849\n",
      "99 revenue:  0.7200196981430054 rgt:  0.0005908901803195477 loss:  -0.8236386775970459\n",
      "100 revenue:  0.7906131744384766 rgt:  0.0009337374358437955 loss:  -0.857671856880188\n",
      "101 revenue:  0.7839911580085754 rgt:  0.0003600028867367655 loss:  -0.8660964369773865\n",
      "102 revenue:  0.7871705889701843 rgt:  0.00034940970363095403 loss:  -0.8681818246841431\n",
      "103 revenue:  0.7969229221343994 rgt:  0.00028175811166875064 loss:  -0.8756350874900818\n",
      "104 revenue:  0.8068138360977173 rgt:  0.0005351772997528315 loss:  -0.8745569586753845\n",
      "105 revenue:  0.801582396030426 rgt:  0.00030558250728063285 loss:  -0.8775220513343811\n",
      "106 revenue:  0.8113608360290527 rgt:  0.000426092796260491 loss:  -0.8796852231025696\n",
      "107 revenue:  0.8185715079307556 rgt:  0.0010711821960285306 loss:  -0.8709478974342346\n",
      "108 revenue:  0.8122437000274658 rgt:  0.00036595750134438276 loss:  -0.8817471265792847\n",
      "109 revenue:  0.8147105574607849 rgt:  0.0006522081093862653 loss:  -0.8764206767082214\n",
      "110 revenue:  0.8133219480514526 rgt:  0.0002607051865197718 loss:  -0.885433554649353\n",
      "111 revenue:  0.8233093619346619 rgt:  0.0006105552311055362 loss:  -0.8820420503616333\n",
      "112 revenue:  0.8178703784942627 rgt:  0.00035880241193808615 loss:  -0.8850583434104919\n",
      "113 revenue:  0.8268596529960632 rgt:  0.0009410398197360337 loss:  -0.8776993155479431\n",
      "114 revenue:  0.818804919719696 rgt:  0.00031896750442683697 loss:  -0.8866970539093018\n",
      "115 revenue:  0.8283664584159851 rgt:  0.0008147731423377991 loss:  -0.8807857036590576\n",
      "116 revenue:  0.8202871084213257 rgt:  0.0005240450846031308 loss:  -0.8822788596153259\n",
      "117 revenue:  0.8267836570739746 rgt:  0.0011735734296962619 loss:  -0.8738440275192261\n",
      "118 revenue:  0.8178384900093079 rgt:  0.0006696227355860174 loss:  -0.8777956962585449\n",
      "119 revenue:  0.8109976649284363 rgt:  0.000384299288270995 loss:  -0.8805637955665588\n",
      "120 revenue:  0.8192060589790344 rgt:  0.0005726407980546355 loss:  -0.8805955052375793\n",
      "121 revenue:  0.8156692981719971 rgt:  0.00033840074320323765 loss:  -0.8844074010848999\n",
      "122 revenue:  0.8255358934402466 rgt:  0.0011153609957545996 loss:  -0.874076247215271\n",
      "123 revenue:  0.8169986009597778 rgt:  0.0005166155751794577 loss:  -0.8806318640708923\n",
      "124 revenue:  0.822308361530304 rgt:  0.0003853916423395276 loss:  -0.8867928981781006\n",
      "125 revenue:  0.8304706811904907 rgt:  0.0006519563030451536 loss:  -0.8851143717765808\n",
      "126 revenue:  0.8247050642967224 rgt:  0.0010560114169493318 loss:  -0.8745788931846619\n",
      "127 revenue:  0.8183009624481201 rgt:  0.0004666933382395655 loss:  -0.8825278878211975\n",
      "128 revenue:  0.825570285320282 rgt:  0.0006985364598222077 loss:  -0.8814787864685059\n",
      "129 revenue:  0.8165737390518188 rgt:  0.0003487465437501669 loss:  -0.8846185803413391\n",
      "130 revenue:  0.8256896138191223 rgt:  0.0003696942294482142 loss:  -0.8890750408172607\n",
      "131 revenue:  0.8253060579299927 rgt:  0.0007923696539364755 loss:  -0.8795204162597656\n",
      "132 revenue:  0.8257704973220825 rgt:  0.0004898796323686838 loss:  -0.8860938549041748\n",
      "133 revenue:  0.8194913864135742 rgt:  0.0005914618377573788 loss:  -0.8803442120552063\n",
      "134 revenue:  0.8116567134857178 rgt:  0.00048375592450611293 loss:  -0.8784394860267639\n",
      "135 revenue:  0.8187738656997681 rgt:  0.0004265767929609865 loss:  -0.8837785720825195\n",
      "136 revenue:  0.825709342956543 rgt:  0.0003674422041513026 loss:  -0.8891466856002808\n",
      "137 revenue:  0.8254708647727966 rgt:  0.00041525153210386634 loss:  -0.8877588510513306\n",
      "138 revenue:  0.8328661918640137 rgt:  0.0006180261261761189 loss:  -0.8871349096298218\n",
      "139 revenue:  0.826866090297699 rgt:  0.0005976930260658264 loss:  -0.8842743039131165\n",
      "140 revenue:  0.8203593492507935 rgt:  0.0003134215367026627 loss:  -0.887717068195343\n",
      "141 revenue:  0.830143928527832 rgt:  0.0006379293627105653 loss:  -0.8852251768112183\n",
      "142 revenue:  0.8231704831123352 rgt:  0.0002978985430672765 loss:  -0.8897269368171692\n",
      "143 revenue:  0.8328776955604553 rgt:  0.0004786761128343642 loss:  -0.8902617692947388\n",
      "144 revenue:  0.8267562389373779 rgt:  0.0006072671967558563 loss:  -0.8840094208717346\n",
      "145 revenue:  0.8287445306777954 rgt:  0.00031584728276357055 loss:  -0.8922634124755859\n",
      "146 revenue:  0.8351494073867798 rgt:  0.0007095911423675716 loss:  -0.8865154981613159\n",
      "147 revenue:  0.8260132670402527 rgt:  0.0003954682615585625 loss:  -0.8885684013366699\n",
      "148 revenue:  0.8334611058235168 rgt:  0.00044976643403060734 loss:  -0.8912811279296875\n",
      "149 revenue:  0.8292175531387329 rgt:  0.0006817764369770885 loss:  -0.8838194012641907\n",
      "150 revenue:  0.8296053409576416 rgt:  0.0004257378459442407 loss:  -0.889765202999115\n",
      "151 revenue:  0.827427327632904 rgt:  0.0003681326634250581 loss:  -0.8900728821754456\n",
      "152 revenue:  0.8344697952270508 rgt:  0.00035791113623417914 loss:  -0.8942141532897949\n",
      "153 revenue:  0.8425247073173523 rgt:  0.001077646273188293 loss:  -0.8839848041534424\n",
      "154 revenue:  0.8348329663276672 rgt:  0.00041453485027886927 loss:  -0.892914891242981\n",
      "155 revenue:  0.8333579897880554 rgt:  0.0003059617301914841 loss:  -0.8950839042663574\n",
      "156 revenue:  0.8422096371650696 rgt:  0.0008216550340875983 loss:  -0.8882319331169128\n",
      "157 revenue:  0.8357322216033936 rgt:  0.00030779707594774663 loss:  -0.8963291645050049\n",
      "158 revenue:  0.8444380760192871 rgt:  0.0006291630561463535 loss:  -0.8932189345359802\n",
      "159 revenue:  0.8383712768554688 rgt:  0.0003035930567421019 loss:  -0.897895872592926\n",
      "160 revenue:  0.8474855422973633 rgt:  0.0011998250847682357 loss:  -0.8847500681877136\n",
      "161 revenue:  0.8378371000289917 rgt:  0.0006314751226454973 loss:  -0.8895718455314636\n",
      "162 revenue:  0.8352091312408447 rgt:  0.0002361384395044297 loss:  -0.8982915878295898\n",
      "163 revenue:  0.8447859883308411 rgt:  0.0004585534334182739 loss:  -0.8972477316856384\n",
      "164 revenue:  0.843068540096283 rgt:  0.0014623320894315839 loss:  -0.8784835934638977\n",
      "165 revenue:  0.8356431722640991 rgt:  0.00023051253810990602 loss:  -0.8987188339233398\n",
      "166 revenue:  0.8461182117462158 rgt:  0.0011333674192428589 loss:  -0.8850464820861816\n",
      "167 revenue:  0.8388808369636536 rgt:  0.00026927379076369107 loss:  -0.8992225527763367\n",
      "168 revenue:  0.8484058380126953 rgt:  0.0006092015537433326 loss:  -0.8957962989807129\n",
      "169 revenue:  0.8388407826423645 rgt:  0.0002323709923075512 loss:  -0.9004032015800476\n",
      "170 revenue:  0.8493767380714417 rgt:  0.0004965374246239662 loss:  -0.8988345265388489\n",
      "171 revenue:  0.8429187536239624 rgt:  0.0004882988869212568 loss:  -0.8955181241035461\n",
      "172 revenue:  0.84832763671875 rgt:  0.000973900780081749 loss:  -0.8888642191886902\n",
      "173 revenue:  0.8405656218528748 rgt:  0.0004201559931971133 loss:  -0.895903468132019\n",
      "174 revenue:  0.8469463586807251 rgt:  0.0015570727409794927 loss:  -0.8792788982391357\n",
      "175 revenue:  0.8393046855926514 rgt:  0.0004190686740912497 loss:  -0.8952431082725525\n",
      "176 revenue:  0.8386717438697815 rgt:  0.0003627561964094639 loss:  -0.896378755569458\n",
      "177 revenue:  0.8450232148170471 rgt:  0.001839898293837905 loss:  -0.8745164275169373\n",
      "178 revenue:  0.8390163779258728 rgt:  0.0009621222270652652 loss:  -0.8839966058731079\n",
      "179 revenue:  0.8328778147697449 rgt:  0.0002879709645640105 loss:  -0.8953608274459839\n",
      "180 revenue:  0.8412699103355408 rgt:  0.0005254953284747899 loss:  -0.8937563896179199\n",
      "181 revenue:  0.8365805745124817 rgt:  0.0004152806068304926 loss:  -0.8938516974449158\n",
      "182 revenue:  0.8343668580055237 rgt:  0.0005687500815838575 loss:  -0.8890175819396973\n",
      "183 revenue:  0.8341243863105774 rgt:  0.0003901672607753426 loss:  -0.8931587934494019\n",
      "184 revenue:  0.8405677676200867 rgt:  0.0005508805043064058 loss:  -0.89280104637146\n",
      "185 revenue:  0.8374255895614624 rgt:  0.0002718046016525477 loss:  -0.8983483910560608\n",
      "186 revenue:  0.8458929657936096 rgt:  0.0007340625743381679 loss:  -0.8918949365615845\n",
      "187 revenue:  0.8382628560066223 rgt:  0.0002906662703026086 loss:  -0.8982244729995728\n",
      "188 revenue:  0.8457126617431641 rgt:  0.0004309868672862649 loss:  -0.8984327912330627\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "189 revenue:  0.8407487869262695 rgt:  0.00030349093140102923 loss:  -0.899196207523346\n",
      "190 revenue:  0.843417763710022 rgt:  0.000329932983731851 loss:  -0.8998811841011047\n",
      "191 revenue:  0.8501589894294739 rgt:  0.0009470182121731341 loss:  -0.8903183937072754\n",
      "192 revenue:  0.8421216607093811 rgt:  0.0003456995473243296 loss:  -0.8987305164337158\n",
      "193 revenue:  0.8441858887672424 rgt:  0.0003234459727536887 loss:  -0.900485098361969\n",
      "194 revenue:  0.8509671688079834 rgt:  0.000475106033263728 loss:  -0.9002045392990112\n",
      "195 revenue:  0.8504695296287537 rgt:  0.0007501700893044472 loss:  -0.8940678238868713\n",
      "196 revenue:  0.8438194990158081 rgt:  0.00039527571061626077 loss:  -0.8983172178268433\n",
      "197 revenue:  0.8483757376670837 rgt:  0.0010760938748717308 loss:  -0.8871917128562927\n",
      "198 revenue:  0.8430131673812866 rgt:  0.00021782297699246556 loss:  -0.9031775593757629\n",
      "199 revenue:  0.853565514087677 rgt:  0.0006158162723295391 loss:  -0.8984526991844177\n",
      "200 revenue:  0.8465335965156555 rgt:  0.0003206497640348971 loss:  -0.9018425345420837\n",
      "201 revenue:  0.8471552729606628 rgt:  0.00023827204131521285 loss:  -0.9047328233718872\n",
      "202 revenue:  0.8480835556983948 rgt:  0.0004293651436455548 loss:  -0.8997616767883301\n",
      "203 revenue:  0.8481819033622742 rgt:  0.0009217953193001449 loss:  -0.8896834850311279\n",
      "204 revenue:  0.847966194152832 rgt:  0.00029526615981012583 loss:  -0.9033693671226501\n",
      "205 revenue:  0.8486647009849548 rgt:  0.00021928166097495705 loss:  -0.9061992764472961\n",
      "206 revenue:  0.8497302532196045 rgt:  0.0005921802367083728 loss:  -0.8968791961669922\n",
      "207 revenue:  0.8495609164237976 rgt:  0.0005825445987284184 loss:  -0.8969957828521729\n",
      "208 revenue:  0.8486466407775879 rgt:  0.0002992248337250203 loss:  -0.9036200046539307\n",
      "209 revenue:  0.8493679761886597 rgt:  0.0003006648039445281 loss:  -0.9039684534072876\n",
      "210 revenue:  0.85003662109375 rgt:  0.0002845606650225818 loss:  -0.9048179388046265\n",
      "211 revenue:  0.850835919380188 rgt:  0.0002546326140873134 loss:  -0.9061927795410156\n",
      "212 revenue:  0.8513686060905457 rgt:  0.0006972562405280769 loss:  -0.8955917358398438\n",
      "213 revenue:  0.8505397439002991 rgt:  0.0005067904130555689 loss:  -0.8992261290550232\n",
      "214 revenue:  0.8495896458625793 rgt:  0.00027959130238741636 loss:  -0.904728353023529\n",
      "215 revenue:  0.8501999974250793 rgt:  0.0007390590035356581 loss:  -0.8941364884376526\n",
      "216 revenue:  0.8494222164154053 rgt:  0.00045710967970080674 loss:  -0.8998015522956848\n",
      "217 revenue:  0.8491084575653076 rgt:  0.0004109914298169315 loss:  -0.9007844924926758\n",
      "218 revenue:  0.8489628434181213 rgt:  0.0003911819658242166 loss:  -0.9012197852134705\n",
      "219 revenue:  0.8487800359725952 rgt:  0.00021952713723294437 loss:  -0.9062533378601074\n",
      "220 revenue:  0.8498278856277466 rgt:  0.0002829507284332067 loss:  -0.904754102230072\n",
      "221 revenue:  0.8506138324737549 rgt:  0.000281893735518679 loss:  -0.9052128195762634\n",
      "222 revenue:  0.8513667583465576 rgt:  0.00032480439404025674 loss:  -0.9043455719947815\n",
      "223 revenue:  0.8508902192115784 rgt:  0.00032312158145941794 loss:  -0.9041357040405273\n",
      "224 revenue:  0.8515306711196899 rgt:  0.00020814232993870974 loss:  -0.908145546913147\n",
      "225 revenue:  0.8525949120521545 rgt:  0.0007019800832495093 loss:  -0.8961619734764099\n",
      "226 revenue:  0.8515933156013489 rgt:  0.0005096629611216486 loss:  -0.8997305631637573\n",
      "227 revenue:  0.8510047793388367 rgt:  0.0003034495166502893 loss:  -0.9047731161117554\n",
      "228 revenue:  0.8517594337463379 rgt:  0.00041620529373176396 loss:  -0.9020884037017822\n",
      "229 revenue:  0.8513078093528748 rgt:  0.0003683238464873284 loss:  -0.903100848197937\n",
      "230 revenue:  0.8515628576278687 rgt:  0.00026485355920158327 loss:  -0.9062594175338745\n",
      "231 revenue:  0.8525609970092773 rgt:  0.0004254672967363149 loss:  -0.9022876024246216\n",
      "232 revenue:  0.8522215485572815 rgt:  0.0004077367193531245 loss:  -0.9025558233261108\n",
      "233 revenue:  0.8515346646308899 rgt:  0.00048503425205126405 loss:  -0.9002755880355835\n",
      "234 revenue:  0.8513894081115723 rgt:  0.00019860388420056552 loss:  -0.9084128737449646\n",
      "235 revenue:  0.8524577617645264 rgt:  0.0007544827531091869 loss:  -0.895062267780304\n",
      "236 revenue:  0.8517349362373352 rgt:  0.0004212566709611565 loss:  -0.9019467234611511\n",
      "237 revenue:  0.8508951663970947 rgt:  0.0004889813717454672 loss:  -0.8998357057571411\n",
      "238 revenue:  0.8504776954650879 rgt:  0.0002943806175608188 loss:  -0.9047587513923645\n",
      "239 revenue:  0.8507632613182068 rgt:  0.00036872323835268617 loss:  -0.9027948379516602\n",
      "240 revenue:  0.8513035178184509 rgt:  0.0005877003422938287 loss:  -0.8978288769721985\n",
      "241 revenue:  0.850326657295227 rgt:  0.0008199508301913738 loss:  -0.8926752209663391\n",
      "242 revenue:  0.8496413826942444 rgt:  0.0003126778465230018 loss:  -0.9037618041038513\n",
      "243 revenue:  0.8502774238586426 rgt:  0.00037499828613363206 loss:  -0.9023625254631042\n",
      "244 revenue:  0.8497627973556519 rgt:  0.00021110643865540624 loss:  -0.9070818424224854\n",
      "245 revenue:  0.8508356213569641 rgt:  0.0002213423140347004 loss:  -0.9073052406311035\n",
      "246 revenue:  0.8518728017807007 rgt:  0.00038739063893444836 loss:  -0.9028974771499634\n",
      "247 revenue:  0.8510950207710266 rgt:  0.00024579858290962875 loss:  -0.9066212177276611\n",
      "248 revenue:  0.8520708680152893 rgt:  0.0003853718808386475 loss:  -0.9030581712722778\n",
      "249 revenue:  0.8513208627700806 rgt:  0.00019949530542362481 loss:  -0.9083432555198669\n",
      "250 revenue:  0.8523811101913452 rgt:  0.00040487799560651183 loss:  -0.9027159810066223\n",
      "251 revenue:  0.8514723181724548 rgt:  0.0005599524592980742 loss:  -0.8985272645950317\n",
      "252 revenue:  0.850875973701477 rgt:  0.0003199570346623659 loss:  -0.9042193293571472\n",
      "253 revenue:  0.8514222502708435 rgt:  0.00035478759673424065 loss:  -0.9035322666168213\n",
      "254 revenue:  0.8508204817771912 rgt:  0.00021204528457019478 loss:  -0.9076220393180847\n",
      "255 revenue:  0.8518369197845459 rgt:  0.0002890366595238447 loss:  -0.905657172203064\n",
      "256 revenue:  0.852520763874054 rgt:  0.0005382188828662038 loss:  -0.8995806574821472\n",
      "257 revenue:  0.851620078086853 rgt:  0.0003220468061044812 loss:  -0.90456223487854\n",
      "258 revenue:  0.8513343334197998 rgt:  0.0005720898625440896 loss:  -0.8981853127479553\n",
      "259 revenue:  0.851085901260376 rgt:  0.000324348162394017 loss:  -0.904206395149231\n",
      "260 revenue:  0.8511859178543091 rgt:  0.00032227771589532495 loss:  -0.9043202996253967\n",
      "261 revenue:  0.8518214821815491 rgt:  0.0005020726821385324 loss:  -0.9000305533409119\n",
      "262 revenue:  0.8508366942405701 rgt:  0.0004053097509313375 loss:  -0.9018680453300476\n",
      "263 revenue:  0.8511950969696045 rgt:  0.0003482803876977414 loss:  -0.9035891890525818\n",
      "264 revenue:  0.8512345552444458 rgt:  0.0002945241576526314 loss:  -0.9051646590232849\n",
      "265 revenue:  0.8516706824302673 rgt:  0.00045897893141955137 loss:  -0.9009750485420227\n",
      "266 revenue:  0.8515769243240356 rgt:  0.0008758414187468588 loss:  -0.8923371434211731\n",
      "267 revenue:  0.8507591485977173 rgt:  0.00027119697188027203 loss:  -0.9056238532066345\n",
      "268 revenue:  0.8512670397758484 rgt:  0.0004023231740575284 loss:  -0.9021785855293274\n",
      "269 revenue:  0.8511468768119812 rgt:  0.0004050206916872412 loss:  -0.9020437002182007\n",
      "270 revenue:  0.850338339805603 rgt:  0.00032551269396208227 loss:  -0.9037677049636841\n",
      "271 revenue:  0.8504640460014343 rgt:  0.0002193987020291388 loss:  -0.9071712493896484\n",
      "272 revenue:  0.8515436053276062 rgt:  0.0003083117480855435 loss:  -0.9049212336540222\n",
      "273 revenue:  0.8521669507026672 rgt:  0.0003928508667740971 loss:  -0.9029130935668945\n",
      "274 revenue:  0.8519759774208069 rgt:  0.0003774047363549471 loss:  -0.9032185673713684\n",
      "275 revenue:  0.8522729277610779 rgt:  0.00033026578603312373 loss:  -0.9046801328659058\n",
      "276 revenue:  0.8522300720214844 rgt:  0.0003964554925914854 loss:  -0.9028530120849609\n",
      "277 revenue:  0.852321445941925 rgt:  0.0007041179342195392 loss:  -0.895971417427063\n",
      "278 revenue:  0.8517321944236755 rgt:  0.0002755550085566938 loss:  -0.9060150384902954\n",
      "279 revenue:  0.8521019220352173 rgt:  0.0003589075931813568 loss:  -0.90378737449646\n",
      "280 revenue:  0.8526154160499573 rgt:  0.00026732560945674777 loss:  -0.9067513942718506\n",
      "281 revenue:  0.8535158038139343 rgt:  0.0003921451570931822 loss:  -0.9036619067192078\n",
      "282 revenue:  0.8535072207450867 rgt:  0.00035160529660061 loss:  -0.9047491550445557\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "283 revenue:  0.8531474471092224 rgt:  0.0002792628074530512 loss:  -0.9066664576530457\n",
      "284 revenue:  0.85345458984375 rgt:  0.00027755918563343585 loss:  -0.9068853855133057\n",
      "285 revenue:  0.8537778854370117 rgt:  0.000330882437992841 loss:  -0.9054772853851318\n",
      "286 revenue:  0.8536757826805115 rgt:  0.0003787800669670105 loss:  -0.9041022062301636\n",
      "287 revenue:  0.8528673648834229 rgt:  0.0003614378802012652 loss:  -0.904132604598999\n",
      "288 revenue:  0.8528608083724976 rgt:  0.0004893525037914515 loss:  -0.9008917808532715\n",
      "289 revenue:  0.852715015411377 rgt:  0.00047809790703468025 loss:  -0.9010798931121826\n",
      "290 revenue:  0.8525850176811218 rgt:  0.00035893410677090287 loss:  -0.9040482044219971\n",
      "291 revenue:  0.8531213402748108 rgt:  0.0002203032490797341 loss:  -0.9085794687271118\n",
      "292 revenue:  0.8541746139526367 rgt:  0.0009413962834514678 loss:  -0.8925905227661133\n",
      "293 revenue:  0.8532872796058655 rgt:  0.000419512449298054 loss:  -0.9028316140174866\n",
      "294 revenue:  0.853208601474762 rgt:  0.00046527860104106367 loss:  -0.9016550779342651\n",
      "295 revenue:  0.8529717326164246 rgt:  0.0004021948843728751 loss:  -0.9031052589416504\n",
      "296 revenue:  0.8520784974098206 rgt:  0.00026747601805254817 loss:  -0.9064557552337646\n",
      "297 revenue:  0.8528198003768921 rgt:  0.0002048508176812902 loss:  -0.9089615345001221\n",
      "298 revenue:  0.8538892269134521 rgt:  0.0003358596295583993 loss:  -0.9053962826728821\n",
      "299 revenue:  0.8544452786445618 rgt:  0.00028097446192987263 loss:  -0.9073159098625183\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(learning_rates)):\n",
    "    opt_aucter = torch.optim.AdamW(list(ANet.parameters()) + list(PNet.parameters()),learning_rates[i])\n",
    "    \n",
    "    if (i % 1 == 0):\n",
    "        MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "        MNet.to(device)\n",
    "    \n",
    "    misreports = find_max_regret_misreport_BNIC(ANet,PNet,MNet,train_set_no_prob,100, 300,0.00001)\n",
    "    \n",
    "    opt_aucter.zero_grad()\n",
    "\n",
    "    u_orig = truthful_utility_calculation(ANet,PNet,train_set_no_prob)\n",
    "    u_new = misreport_utility_calculation_BNIC(ANet,PNet, train_set_no_prob, misreports)\n",
    "        \n",
    "    rgt = torch.mean(u_new - u_orig)\n",
    "    rev = torch.mean(truthful_revenue_calculation(ANet,PNet,train_set_no_prob))\n",
    "\n",
    "    #update NN\n",
    "    loss =  -1 * (torch.sqrt(rev + 1e-7) -  torch.sqrt(rgt + 1e-7)) + rgt\n",
    "    \n",
    "    loss.backward()\n",
    "    opt_aucter.step()\n",
    "\n",
    "    print(i, 'revenue: ',rev.item(), 'rgt: ', rgt.item(), 'loss: ', loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81f8ab1c",
   "metadata": {},
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "15af6ba6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MisreportNetBNIC(\n",
       "  (mlp): MLP(\n",
       "    (model): Sequential(\n",
       "      (fc1): Linear(in_features=4, out_features=100, bias=True)\n",
       "      (tanh1): Tanh()\n",
       "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh2): Tanh()\n",
       "      (fc3): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh3): Tanh()\n",
       "      (fc4): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh4): Tanh()\n",
       "      (fc5): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh5): Tanh()\n",
       "      (fc6): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (tanh6): Tanh()\n",
       "      (fc7): Linear(in_features=100, out_features=4, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MNet = MisreportNetBNIC(n,m,nLayersMisreport,widthMisreport)\n",
    "MNet.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3ac6dc74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.853687584400177 0.00030504169990308583 0.8217181586726173\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_no_prob)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "107ef0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2793348431587219 0.0009892991511151195 0.2470768432788602\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_prob_1)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9388bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4221286475658417 0.0031390590593218803 0.3524642244188611\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_prob_2)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "338b6f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.614517092704773 0.004496638663113117 0.5138802672720019\n"
     ]
    }
   ],
   "source": [
    "rev,rgt = test(ANet,PNet,MNet,test_set_prob_3)\n",
    "print(rev,rgt,(np.sqrt(rev) - np.sqrt(rgt))**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e4e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
