{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "845cfdc3-a605-4ab6-bde8-8175d5550fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# Run on the GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# Make float64 the default tensor dtype for everything created in this notebook.\n",
    "# NOTE(review): this runs before the project-local imports below; keep that order in\n",
    "# case those modules create tensors at import time (not verifiable from here).\n",
    "dtype = torch.float64\n",
    "torch.set_default_dtype(dtype)\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.nn import BCELoss\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "from madgrad import MADGRAD\n",
    "# NOTE(review): the star imports below presumably provide `seed`, `MultiLayerLogicDataset`\n",
    "# and `MultiLayerDNLWithNegation` used later. A later star import can shadow names from\n",
    "# an earlier one, so do not reorder them without checking.\n",
    "from dataset.binary_logic import *\n",
    "from library.LFL_modules import *\n",
    "from library.utils import *\n",
    "from pprint import pprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3765cad-fa09-466e-9c97-6d67030a7247",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 mean_loss: {'supervision': 1.6574606340280718, 'regularization': 1.4443489121354494}\n",
      "epoch: 1 mean_loss: {'supervision': 0.6611690459553947, 'regularization': 1.0666495825146414}\n",
      "epoch: 2 mean_loss: {'supervision': 0.8356701826588864, 'regularization': 0.9209922972193612}\n",
      "epoch: 3 mean_loss: {'supervision': 0.5271902186789398, 'regularization': 0.963164856581713}\n",
      "epoch: 4 mean_loss: {'supervision': 0.5566949738408089, 'regularization': 1.0420131988048278}\n",
      "epoch: 5 mean_loss: {'supervision': 0.5112371744767026, 'regularization': 1.0154056947532781}\n",
      "epoch: 6 mean_loss: {'supervision': 0.5314298070175509, 'regularization': 0.9613010161462622}\n",
      "epoch: 7 mean_loss: {'supervision': 0.507450351556828, 'regularization': 0.9706714824921662}\n",
      "epoch: 8 mean_loss: {'supervision': 0.5231353073138345, 'regularization': 1.0149087218966781}\n",
      "epoch: 9 mean_loss: {'supervision': 0.5030043707193039, 'regularization': 0.9840713092740041}\n",
      "epoch: 10 mean_loss: {'supervision': 0.5229369440970817, 'regularization': 0.949190969583835}\n",
      "epoch: 11 mean_loss: {'supervision': 0.5066177204812699, 'regularization': 0.9752680588254713}\n",
      "epoch: 12 mean_loss: {'supervision': 0.5022375039301478, 'regularization': 0.9767479419502827}\n",
      "epoch: 13 mean_loss: {'supervision': 0.5030881468241017, 'regularization': 0.926143014220137}\n",
      "epoch: 14 mean_loss: {'supervision': 0.4946141634472264, 'regularization': 0.9132995195488148}\n",
      "epoch: 15 mean_loss: {'supervision': 0.5066119838222112, 'regularization': 0.9232012095779012}\n",
      "epoch: 16 mean_loss: {'supervision': 0.48997845343734325, 'regularization': 0.8837125484303823}\n",
      "epoch: 17 mean_loss: {'supervision': 0.5048693765847563, 'regularization': 0.8385811669653842}\n",
      "epoch: 18 mean_loss: {'supervision': 0.4881667909723941, 'regularization': 0.8236062532993494}\n",
      "epoch: 19 mean_loss: {'supervision': 0.49219248508221836, 'regularization': 0.7989589472707996}\n",
      "epoch: 20 mean_loss: {'supervision': 0.5049225916994999, 'regularization': 0.7465486280662954}\n",
      "epoch: 21 mean_loss: {'supervision': 0.5103633229987835, 'regularization': 0.7158827873185041}\n",
      "epoch: 22 mean_loss: {'supervision': 0.48889966930096945, 'regularization': 0.6870185218035276}\n",
      "epoch: 23 mean_loss: {'supervision': 0.5026875525989307, 'regularization': 0.638405217232429}\n",
      "epoch: 24 mean_loss: {'supervision': 0.49648938768319834, 'regularization': 0.6019620269474059}\n",
      "epoch: 25 mean_loss: {'supervision': 0.4955369090866659, 'regularization': 0.578669580823771}\n",
      "epoch: 26 mean_loss: {'supervision': 0.5039102677113262, 'regularization': 0.5306100651282719}\n",
      "epoch: 27 mean_loss: {'supervision': 0.49858501646101827, 'regularization': 0.5039156372593154}\n",
      "epoch: 28 mean_loss: {'supervision': 0.49818219201215425, 'regularization': 0.4837114951048119}\n",
      "epoch: 29 mean_loss: {'supervision': 0.5100713769059435, 'regularization': 0.4386559107096835}\n",
      "epoch: 30 mean_loss: {'supervision': 0.48996482562911586, 'regularization': 0.4168949564958846}\n",
      "epoch: 31 mean_loss: {'supervision': 0.49149385429647446, 'regularization': 0.4016973086790565}\n",
      "epoch: 32 mean_loss: {'supervision': 0.4887205896741836, 'regularization': 0.3710958045155636}\n",
      "epoch: 33 mean_loss: {'supervision': 0.48863426263208765, 'regularization': 0.354495473107618}\n",
      "epoch: 34 mean_loss: {'supervision': 0.4949138076253112, 'regularization': 0.35187831193769203}\n",
      "epoch: 35 mean_loss: {'supervision': 0.48667712383705125, 'regularization': 0.3220994653341033}\n",
      "epoch: 36 mean_loss: {'supervision': 0.4773276783136816, 'regularization': 0.31399069272681435}\n",
      "epoch: 37 mean_loss: {'supervision': 0.4569124145176647, 'regularization': 0.32486931252667106}\n",
      "epoch: 38 mean_loss: {'supervision': 0.4251149664837718, 'regularization': 0.309756375945896}\n",
      "epoch: 39 mean_loss: {'supervision': 0.34384467490148585, 'regularization': 0.32114486947158555}\n",
      "epoch: 40 mean_loss: {'supervision': 0.28877147634765943, 'regularization': 0.3133339613144995}\n",
      "epoch: 41 mean_loss: {'supervision': 0.21301929025819585, 'regularization': 0.3025435928597583}\n",
      "epoch: 42 mean_loss: {'supervision': 0.17918271769524988, 'regularization': 0.2901748242703395}\n",
      "epoch: 43 mean_loss: {'supervision': 0.1468509331354823, 'regularization': 0.2801852822934078}\n",
      "epoch: 44 mean_loss: {'supervision': 0.11216599262531114, 'regularization': 0.27768868626634446}\n",
      "epoch: 45 mean_loss: {'supervision': 0.07208798435727651, 'regularization': 0.2772024564766992}\n",
      "epoch: 46 mean_loss: {'supervision': 0.04460507689696911, 'regularization': 0.2755348316265118}\n",
      "epoch: 47 mean_loss: {'supervision': 0.01982907527572521, 'regularization': 0.27102676554930927}\n",
      "epoch: 48 mean_loss: {'supervision': 0.007191823498691739, 'regularization': 0.26505909140913575}\n",
      "epoch: 49 mean_loss: {'supervision': 0.0033736940759517735, 'regularization': 0.25886928048663393}\n",
      "epoch: 50 mean_loss: {'supervision': 0.0020998434660986884, 'regularization': 0.253614849631816}\n",
      "epoch: 51 mean_loss: {'supervision': 0.0015367984859512383, 'regularization': 0.249341726405826}\n",
      "epoch: 52 mean_loss: {'supervision': 0.0012076528758017332, 'regularization': 0.24582974790121978}\n",
      "epoch: 53 mean_loss: {'supervision': 0.0010011446407943007, 'regularization': 0.24285742973019508}\n",
      "epoch: 54 mean_loss: {'supervision': 0.0008485466553501467, 'regularization': 0.24026787108739403}\n",
      "epoch: 55 mean_loss: {'supervision': 0.000746710520135291, 'regularization': 0.23796194606393195}\n",
      "epoch: 56 mean_loss: {'supervision': 0.0006672307233645624, 'regularization': 0.23587359839912636}\n",
      "epoch: 57 mean_loss: {'supervision': 0.0006103142283234654, 'regularization': 0.2339505231973718}\n",
      "epoch: 58 mean_loss: {'supervision': 0.0005650646847758621, 'regularization': 0.23215290825874796}\n",
      "epoch: 59 mean_loss: {'supervision': 0.0005292906710448718, 'regularization': 0.2304570813687526}\n",
      "epoch: 60 mean_loss: {'supervision': 0.000499148227392868, 'regularization': 0.22884949513671785}\n",
      "epoch: 61 mean_loss: {'supervision': 0.00047519405140630447, 'regularization': 0.22731296697673192}\n",
      "epoch: 62 mean_loss: {'supervision': 0.0004553069462673602, 'regularization': 0.22581715287700874}\n",
      "epoch: 63 mean_loss: {'supervision': 0.00043711871072280495, 'regularization': 0.22431880294716172}\n",
      "epoch: 64 mean_loss: {'supervision': 0.0004228079688501921, 'regularization': 0.22277263575500506}\n",
      "epoch: 65 mean_loss: {'supervision': 0.0004095268508791641, 'regularization': 0.22114961644033654}\n",
      "epoch: 66 mean_loss: {'supervision': 0.0003987355497690882, 'regularization': 0.2194429267876009}\n",
      "epoch: 67 mean_loss: {'supervision': 0.0003886450234802823, 'regularization': 0.21765764204076682}\n",
      "epoch: 68 mean_loss: {'supervision': 0.00037996094504777234, 'regularization': 0.21580831185347887}\n",
      "epoch: 69 mean_loss: {'supervision': 0.00037220405001838755, 'regularization': 0.2139256902193625}\n",
      "epoch: 70 mean_loss: {'supervision': 0.0003651276404746134, 'regularization': 0.21205214683766482}\n",
      "epoch: 71 mean_loss: {'supervision': 0.0003586879437344817, 'regularization': 0.21023144948493055}\n",
      "epoch: 72 mean_loss: {'supervision': 0.00035302796324258363, 'regularization': 0.20851059155364463}\n",
      "epoch: 73 mean_loss: {'supervision': 0.0003471915732576106, 'regularization': 0.20694715967341698}\n",
      "epoch: 74 mean_loss: {'supervision': 0.0003419271354818263, 'regularization': 0.205602014014908}\n",
      "epoch: 75 mean_loss: {'supervision': 0.00033641002574502453, 'regularization': 0.2045108879584167}\n",
      "epoch: 76 mean_loss: {'supervision': 0.00033099973975091987, 'regularization': 0.2036622817866669}\n",
      "epoch: 77 mean_loss: {'supervision': 0.0003256295088750984, 'regularization': 0.20300505434790558}\n",
      "epoch: 78 mean_loss: {'supervision': 0.00032021835183438307, 'regularization': 0.20247358627670956}\n",
      "epoch: 79 mean_loss: {'supervision': 0.0003149761394432326, 'regularization': 0.20200847848498904}\n",
      "epoch: 80 mean_loss: {'supervision': 0.0003102747523796041, 'regularization': 0.20156493374022785}\n",
      "epoch: 81 mean_loss: {'supervision': 0.0003054461799658203, 'regularization': 0.20111470384974456}\n",
      "epoch: 82 mean_loss: {'supervision': 0.0003008429265152511, 'regularization': 0.20064710329645208}\n",
      "epoch: 83 mean_loss: {'supervision': 0.00029657231238170107, 'regularization': 0.20016691772446343}\n",
      "epoch: 84 mean_loss: {'supervision': 0.00029233154047839647, 'regularization': 0.1996869053348533}\n",
      "epoch: 85 mean_loss: {'supervision': 0.00028823137159098716, 'regularization': 0.19922224297524926}\n",
      "epoch: 86 mean_loss: {'supervision': 0.0002842765130228175, 'regularization': 0.19879332435735242}\n",
      "epoch: 87 mean_loss: {'supervision': 0.00028030470637258314, 'regularization': 0.19842361468474612}\n",
      "epoch: 88 mean_loss: {'supervision': 0.0002763930705909763, 'regularization': 0.198124946503053}\n",
      "epoch: 89 mean_loss: {'supervision': 0.0002726961834046972, 'regularization': 0.19789009164562865}\n",
      "epoch: 90 mean_loss: {'supervision': 0.00026915951019910154, 'regularization': 0.19770145190595909}\n",
      "epoch: 91 mean_loss: {'supervision': 0.00026548887736123773, 'regularization': 0.19754164028929547}\n",
      "epoch: 92 mean_loss: {'supervision': 0.00026208740227928924, 'regularization': 0.1973974213289891}\n",
      "epoch: 93 mean_loss: {'supervision': 0.0002588170544819759, 'regularization': 0.1972592775264254}\n",
      "epoch: 94 mean_loss: {'supervision': 0.0002555216647691837, 'regularization': 0.19711982427110836}\n",
      "epoch: 95 mean_loss: {'supervision': 0.0002525087542301904, 'regularization': 0.19697235397030677}\n",
      "epoch: 96 mean_loss: {'supervision': 0.00024941404574123225, 'regularization': 0.196809706354299}\n",
      "epoch: 97 mean_loss: {'supervision': 0.00024649678183387034, 'regularization': 0.19662357681981463}\n",
      "epoch: 98 mean_loss: {'supervision': 0.00024378852276007496, 'regularization': 0.19640473256259508}\n",
      "epoch: 99 mean_loss: {'supervision': 0.0002410325987687649, 'regularization': 0.19614529291559302}\n",
      "epoch: 100 mean_loss: {'supervision': 0.00023826531435232288, 'regularization': 0.1958444385285401}\n",
      "epoch: 101 mean_loss: {'supervision': 0.00023572212725225125, 'regularization': 0.1955151769772051}\n",
      "epoch: 102 mean_loss: {'supervision': 0.00023321502520619604, 'regularization': 0.1951821876646933}\n",
      "epoch: 103 mean_loss: {'supervision': 0.00023064414213061496, 'regularization': 0.1948650818261899}\n",
      "epoch: 104 mean_loss: {'supervision': 0.00022820207181165915, 'regularization': 0.1945629184009912}\n",
      "epoch: 105 mean_loss: {'supervision': 0.00022567202410921837, 'regularization': 0.1942573621480672}\n",
      "epoch: 106 mean_loss: {'supervision': 0.0002231837838211303, 'regularization': 0.19392970421583686}\n",
      "epoch: 107 mean_loss: {'supervision': 0.00022084255726636653, 'regularization': 0.193580978365669}\n",
      "epoch: 108 mean_loss: {'supervision': 0.00021836376449271703, 'regularization': 0.19324126606891806}\n",
      "epoch: 109 mean_loss: {'supervision': 0.00021605913687678757, 'regularization': 0.1929483881749785}\n",
      "epoch: 110 mean_loss: {'supervision': 0.00021382163114578602, 'regularization': 0.19271487739258159}\n",
      "epoch: 111 mean_loss: {'supervision': 0.0002116222899391722, 'regularization': 0.19252573722402724}\n",
      "epoch: 112 mean_loss: {'supervision': 0.0002093259934655455, 'regularization': 0.19235721968864317}\n",
      "epoch: 113 mean_loss: {'supervision': 0.00020714488771857823, 'regularization': 0.1921882137077791}\n",
      "epoch: 114 mean_loss: {'supervision': 0.0002052248774540905, 'regularization': 0.19200206961597677}\n",
      "epoch: 115 mean_loss: {'supervision': 0.00020315184386647978, 'regularization': 0.19178706645293145}\n",
      "epoch: 116 mean_loss: {'supervision': 0.0002011544729400815, 'regularization': 0.1915402662140491}\n",
      "epoch: 117 mean_loss: {'supervision': 0.00019915714388574776, 'regularization': 0.19127406400922436}\n",
      "epoch: 118 mean_loss: {'supervision': 0.00019741731441730455, 'regularization': 0.19101635284226987}\n",
      "epoch: 119 mean_loss: {'supervision': 0.0001954589023758112, 'regularization': 0.19079570767644194}\n",
      "epoch: 120 mean_loss: {'supervision': 0.00019353873159274105, 'regularization': 0.190624100096663}\n",
      "epoch: 121 mean_loss: {'supervision': 0.0001916456344910464, 'regularization': 0.19049606173367473}\n",
      "epoch: 122 mean_loss: {'supervision': 0.00018971616109471213, 'regularization': 0.19039925866322077}\n",
      "epoch: 123 mean_loss: {'supervision': 0.0001877743887936985, 'regularization': 0.19032243954641845}\n",
      "epoch: 124 mean_loss: {'supervision': 0.00018596596954945625, 'regularization': 0.19025767177159592}\n",
      "epoch: 125 mean_loss: {'supervision': 0.0001841285259505617, 'regularization': 0.19019985754091834}\n",
      "epoch: 126 mean_loss: {'supervision': 0.0001822772459898338, 'regularization': 0.19014573653074607}\n",
      "epoch: 127 mean_loss: {'supervision': 0.00018053490044204714, 'regularization': 0.1900930833034974}\n",
      "epoch: 128 mean_loss: {'supervision': 0.00017879088647981237, 'regularization': 0.19004017725281616}\n",
      "epoch: 129 mean_loss: {'supervision': 0.0001770807426105166, 'regularization': 0.18998544520429225}\n",
      "epoch: 130 mean_loss: {'supervision': 0.00017536827858761263, 'regularization': 0.18992715846673855}\n",
      "epoch: 131 mean_loss: {'supervision': 0.0001737696870567773, 'regularization': 0.18986311526167554}\n",
      "epoch: 132 mean_loss: {'supervision': 0.00017213717726740453, 'regularization': 0.18979023486073904}\n",
      "epoch: 133 mean_loss: {'supervision': 0.0001705702405034284, 'regularization': 0.18970393143417466}\n",
      "epoch: 134 mean_loss: {'supervision': 0.00016900395214616777, 'regularization': 0.1895971457475426}\n",
      "epoch: 135 mean_loss: {'supervision': 0.00016747150025254526, 'regularization': 0.18945902060013037}\n",
      "epoch: 136 mean_loss: {'supervision': 0.00016596874979222308, 'regularization': 0.18927396036764682}\n",
      "epoch: 137 mean_loss: {'supervision': 0.00016448768463003924, 'regularization': 0.18902453022105786}\n",
      "epoch: 138 mean_loss: {'supervision': 0.00016303153195103157, 'regularization': 0.18870608412269396}\n",
      "epoch: 139 mean_loss: {'supervision': 0.0001617272392389648, 'regularization': 0.18835143497563156}\n",
      "epoch: 140 mean_loss: {'supervision': 0.00016028866420597242, 'regularization': 0.1880228541870033}\n",
      "epoch: 141 mean_loss: {'supervision': 0.00015889541732150216, 'regularization': 0.1877525540166179}\n",
      "epoch: 142 mean_loss: {'supervision': 0.00015757930039645313, 'regularization': 0.1875184969531821}\n",
      "epoch: 143 mean_loss: {'supervision': 0.00015629347472945247, 'regularization': 0.18727694317185334}\n",
      "epoch: 144 mean_loss: {'supervision': 0.00015504668617028825, 'regularization': 0.1869933995284131}\n",
      "epoch: 145 mean_loss: {'supervision': 0.00015376795784985382, 'regularization': 0.1866678664952654}\n",
      "epoch: 146 mean_loss: {'supervision': 0.00015264087492866663, 'regularization': 0.18635137965559329}\n",
      "epoch: 147 mean_loss: {'supervision': 0.00015141996041812155, 'regularization': 0.18610654215681177}\n",
      "epoch: 148 mean_loss: {'supervision': 0.00015020528148162673, 'regularization': 0.1859478688028699}\n",
      "epoch: 149 mean_loss: {'supervision': 0.00014892881448662565, 'regularization': 0.18585095287773845}\n",
      "epoch: 150 mean_loss: {'supervision': 0.0001477453229524094, 'regularization': 0.18578899972739443}\n",
      "epoch: 151 mean_loss: {'supervision': 0.0001465100034468593, 'regularization': 0.18574534083363076}\n",
      "epoch: 152 mean_loss: {'supervision': 0.00014535634403453615, 'regularization': 0.1857112880864824}\n",
      "epoch: 153 mean_loss: {'supervision': 0.0001441953735486067, 'regularization': 0.18568248852285663}\n",
      "epoch: 154 mean_loss: {'supervision': 0.0001430141855991072, 'regularization': 0.18565672467166444}\n",
      "epoch: 155 mean_loss: {'supervision': 0.0001418532751336845, 'regularization': 0.18563281329034287}\n",
      "epoch: 156 mean_loss: {'supervision': 0.00014070381056279462, 'regularization': 0.18561008544344534}\n",
      "epoch: 157 mean_loss: {'supervision': 0.00013963504206901655, 'regularization': 0.1855881352833068}\n",
      "epoch: 158 mean_loss: {'supervision': 0.00013850253648197298, 'regularization': 0.18556669672363998}\n",
      "epoch: 159 mean_loss: {'supervision': 0.0001374309854170715, 'regularization': 0.18554557189117826}\n",
      "epoch: 160 mean_loss: {'supervision': 0.00013636351925653222, 'regularization': 0.1855245930401428}\n",
      "epoch: 161 mean_loss: {'supervision': 0.0001353024953923284, 'regularization': 0.18550360046722605}\n",
      "epoch: 162 mean_loss: {'supervision': 0.00013424829298192665, 'regularization': 0.18548241395638349}\n",
      "epoch: 163 mean_loss: {'supervision': 0.0001331851172199331, 'regularization': 0.185460808662598}\n",
      "epoch: 164 mean_loss: {'supervision': 0.0001321853318674948, 'regularization': 0.18543847888673684}\n",
      "epoch: 165 mean_loss: {'supervision': 0.00013121262761594957, 'regularization': 0.18541498022488168}\n",
      "epoch: 166 mean_loss: {'supervision': 0.00013020266525097567, 'regularization': 0.18538962913986584}\n",
      "epoch: 167 mean_loss: {'supervision': 0.0001292751914999916, 'regularization': 0.1853613064944932}\n",
      "epoch: 168 mean_loss: {'supervision': 0.0001283038703487595, 'regularization': 0.18532807044705335}\n",
      "epoch: 169 mean_loss: {'supervision': 0.00012737085541225456, 'regularization': 0.185286300195892}\n",
      "epoch: 170 mean_loss: {'supervision': 0.00012644875143119248, 'regularization': 0.1852286543476245}\n",
      "epoch: 171 mean_loss: {'supervision': 0.00012559327120390765, 'regularization': 0.18513878644929554}\n",
      "epoch: 172 mean_loss: {'supervision': 0.00012476337631653093, 'regularization': 0.18497761432226137}\n",
      "epoch: 173 mean_loss: {'supervision': 0.00012409485954332524, 'regularization': 0.18465914483960605}\n",
      "epoch: 174 mean_loss: {'supervision': 0.00012364199358271775, 'regularization': 0.18412180916616983}\n",
      "epoch: 175 mean_loss: {'supervision': 0.00012321298926343065, 'regularization': 0.18363173850304182}\n",
      "epoch: 176 mean_loss: {'supervision': 0.00012244700992698274, 'regularization': 0.1834188242099233}\n",
      "epoch: 177 mean_loss: {'supervision': 0.00012165238618963593, 'regularization': 0.18334871319711182}\n",
      "epoch: 178 mean_loss: {'supervision': 0.00012075621833260116, 'regularization': 0.18331876589844984}\n",
      "epoch: 179 mean_loss: {'supervision': 0.00011991000907250453, 'regularization': 0.18329968328040222}\n",
      "epoch: 180 mean_loss: {'supervision': 0.00011903430761915947, 'regularization': 0.1832840137564737}\n",
      "epoch: 181 mean_loss: {'supervision': 0.00011820995175175185, 'regularization': 0.18326965382028138}\n",
      "epoch: 182 mean_loss: {'supervision': 0.00011736526118978556, 'regularization': 0.18325592029864024}\n",
      "epoch: 183 mean_loss: {'supervision': 0.00011655293609736968, 'regularization': 0.1832425588165738}\n",
      "epoch: 184 mean_loss: {'supervision': 0.0001157175194157278, 'regularization': 0.18322946285302838}\n",
      "epoch: 185 mean_loss: {'supervision': 0.00011490915605952422, 'regularization': 0.1832165833815769}\n",
      "epoch: 186 mean_loss: {'supervision': 0.0001141068495432747, 'regularization': 0.18320389321445846}\n",
      "epoch: 187 mean_loss: {'supervision': 0.00011331835474319026, 'regularization': 0.18319137705005817}\n",
      "epoch: 188 mean_loss: {'supervision': 0.00011255212916308754, 'regularization': 0.1831790241643933}\n",
      "epoch: 189 mean_loss: {'supervision': 0.00011176225788856649, 'regularization': 0.18316682769128118}\n",
      "epoch: 190 mean_loss: {'supervision': 0.0001110041164134857, 'regularization': 0.18315478259865287}\n",
      "epoch: 191 mean_loss: {'supervision': 0.00011023577937611962, 'regularization': 0.18314288337150833}\n",
      "epoch: 192 mean_loss: {'supervision': 0.00010950003839259105, 'regularization': 0.183131125604219}\n",
      "epoch: 193 mean_loss: {'supervision': 0.00010876280224751682, 'regularization': 0.1831195052025108}\n",
      "epoch: 194 mean_loss: {'supervision': 0.00010804435699420545, 'regularization': 0.18310801775780527}\n",
      "epoch: 195 mean_loss: {'supervision': 0.000107312214331421, 'regularization': 0.18309665978828188}\n",
      "epoch: 196 mean_loss: {'supervision': 0.00010659171130694665, 'regularization': 0.18308542923703183}\n",
      "epoch: 197 mean_loss: {'supervision': 0.00010589792365304895, 'regularization': 0.18307432133824256}\n",
      "epoch: 198 mean_loss: {'supervision': 0.00010523707457532208, 'regularization': 0.1830633310682061}\n",
      "epoch: 199 mean_loss: {'supervision': 0.00010452453237616596, 'regularization': 0.18305245680843335}\n",
      "epoch: 200 mean_loss: {'supervision': 0.00010385296370577305, 'regularization': 0.18304169528098954}\n",
      "epoch: 201 mean_loss: {'supervision': 0.0001031740762051228, 'regularization': 0.18303104265396686}\n",
      "epoch: 202 mean_loss: {'supervision': 0.0001025218197690195, 'regularization': 0.18302049514414653}\n",
      "epoch: 203 mean_loss: {'supervision': 0.00010185963765535846, 'regularization': 0.18301004877087806}\n",
      "epoch: 204 mean_loss: {'supervision': 0.00010118800122383998, 'regularization': 0.1829997010124932}\n",
      "epoch: 205 mean_loss: {'supervision': 0.00010056856421824731, 'regularization': 0.18298944690150698}\n",
      "epoch: 206 mean_loss: {'supervision': 9.993505044758851e-05, 'regularization': 0.18297928185491164}\n",
      "epoch: 207 mean_loss: {'supervision': 9.927213907297248e-05, 'regularization': 0.1829692035368392}\n",
      "epoch: 208 mean_loss: {'supervision': 9.865801900031473e-05, 'regularization': 0.18295920714045655}\n",
      "epoch: 209 mean_loss: {'supervision': 9.804922628473369e-05, 'regularization': 0.18294928707467187}\n",
      "epoch: 210 mean_loss: {'supervision': 9.744992916915372e-05, 'regularization': 0.18293943799977186}\n",
      "epoch: 211 mean_loss: {'supervision': 9.684667697602682e-05, 'regularization': 0.1829296541023837}\n",
      "epoch: 212 mean_loss: {'supervision': 9.623951412146401e-05, 'regularization': 0.18291993085079009}\n",
      "epoch: 213 mean_loss: {'supervision': 9.56741547454488e-05, 'regularization': 0.1829102611562014}\n",
      "epoch: 214 mean_loss: {'supervision': 9.506026485237735e-05, 'regularization': 0.1829006381497492}\n",
      "epoch: 215 mean_loss: {'supervision': 9.447853742929701e-05, 'regularization': 0.18289105488886198}\n",
      "epoch: 216 mean_loss: {'supervision': 9.391566228822816e-05, 'regularization': 0.1828815017409629}\n",
      "epoch: 217 mean_loss: {'supervision': 9.33483859920964e-05, 'regularization': 0.18287196983550646}\n",
      "epoch: 218 mean_loss: {'supervision': 9.280737299288093e-05, 'regularization': 0.1828624474961609}\n",
      "epoch: 219 mean_loss: {'supervision': 9.224130331335283e-05, 'regularization': 0.18285292262676917}\n",
      "epoch: 220 mean_loss: {'supervision': 9.166981294323024e-05, 'regularization': 0.1828433821861914}\n",
      "epoch: 221 mean_loss: {'supervision': 9.113800834861872e-05, 'regularization': 0.18283380940770438}\n",
      "epoch: 222 mean_loss: {'supervision': 9.061761900999448e-05, 'regularization': 0.18282418461524813}\n",
      "epoch: 223 mean_loss: {'supervision': 9.008183288437632e-05, 'regularization': 0.18281448504491143}\n",
      "epoch: 224 mean_loss: {'supervision': 8.956833217901232e-05, 'regularization': 0.1828046835241669}\n",
      "epoch: 225 mean_loss: {'supervision': 8.903056703098713e-05, 'regularization': 0.18279474924256767}\n",
      "epoch: 226 mean_loss: {'supervision': 8.851989878224328e-05, 'regularization': 0.1827846441441476}\n",
      "epoch: 227 mean_loss: {'supervision': 8.800876271501898e-05, 'regularization': 0.18277432067209115}\n",
      "epoch: 228 mean_loss: {'supervision': 8.749465600076947e-05, 'regularization': 0.18276372030255533}\n",
      "epoch: 229 mean_loss: {'supervision': 8.699289588883372e-05, 'regularization': 0.18275277003477472}\n",
      "epoch: 230 mean_loss: {'supervision': 8.649302627092969e-05, 'regularization': 0.1827413775999216}\n",
      "epoch: 231 mean_loss: {'supervision': 8.601072843096065e-05, 'regularization': 0.18272942383844443}\n",
      "epoch: 232 mean_loss: {'supervision': 8.552793904726097e-05, 'regularization': 0.18271675399688572}\n",
      "epoch: 233 mean_loss: {'supervision': 8.503714740354779e-05, 'regularization': 0.18270316436001052}\n",
      "epoch: 234 mean_loss: {'supervision': 8.456122619277694e-05, 'regularization': 0.1826883801197543}\n",
      "epoch: 235 mean_loss: {'supervision': 8.407778927154451e-05, 'regularization': 0.1826720236323749}\n",
      "epoch: 236 mean_loss: {'supervision': 8.362355155714358e-05, 'regularization': 0.18265356436435928}\n",
      "epoch: 237 mean_loss: {'supervision': 8.31657454725428e-05, 'regularization': 0.18263224085930974}\n",
      "epoch: 238 mean_loss: {'supervision': 8.270204969202007e-05, 'regularization': 0.18260693181195597}\n",
      "epoch: 239 mean_loss: {'supervision': 8.22274803493009e-05, 'regularization': 0.18257593868964372}\n",
      "epoch: 240 mean_loss: {'supervision': 8.180005228291322e-05, 'regularization': 0.18253660848956962}\n",
      "epoch: 241 mean_loss: {'supervision': 8.133196096122362e-05, 'regularization': 0.1824846869973929}\n",
      "epoch: 242 mean_loss: {'supervision': 8.091693835029357e-05, 'regularization': 0.18241320758993912}\n",
      "epoch: 243 mean_loss: {'supervision': 8.049266185864001e-05, 'regularization': 0.1823107802237146}\n",
      "epoch: 244 mean_loss: {'supervision': 8.006640906910311e-05, 'regularization': 0.1821599408053197}\n",
      "epoch: 245 mean_loss: {'supervision': 7.963275085412849e-05, 'regularization': 0.18193993382742346}\n",
      "epoch: 246 mean_loss: {'supervision': 7.922572127808027e-05, 'regularization': 0.18164609530904585}\n",
      "epoch: 247 mean_loss: {'supervision': 7.884149144613489e-05, 'regularization': 0.1813256532474659}\n",
      "epoch: 248 mean_loss: {'supervision': 7.843613115819977e-05, 'regularization': 0.1810634378577348}\n",
      "epoch: 249 mean_loss: {'supervision': 7.802700655519106e-05, 'regularization': 0.18089833647378198}\n",
      "epoch: 250 mean_loss: {'supervision': 7.763691156137334e-05, 'regularization': 0.18080813381907218}\n",
      "epoch: 251 mean_loss: {'supervision': 7.721972922985695e-05, 'regularization': 0.18075957274823018}\n",
      "epoch: 252 mean_loss: {'supervision': 7.682749714234459e-05, 'regularization': 0.18073161129134924}\n",
      "epoch: 253 mean_loss: {'supervision': 7.641825506870539e-05, 'regularization': 0.18071374478113028}\n",
      "epoch: 254 mean_loss: {'supervision': 7.600967614058528e-05, 'regularization': 0.18070101429745836}\n",
      "epoch: 255 mean_loss: {'supervision': 7.560758405752651e-05, 'regularization': 0.18069104335323938}\n"
     ]
    }
   ],
   "source": [
    "seed(42)\n",
    "\n",
    "train_dataset = MultiLayerLogicDataset()\n",
    "batch_size = 128\n",
    "loss_fn = BCELoss()\n",
    "dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "epochs = 2 ** 8\n",
    "\n",
    "# Optimizer learning rate and regularization weight, kept verbatim.\n",
    "# NOTE(review): presumably the result of a hyperparameter search - record its provenance.\n",
    "lr = 60.74940936219056\n",
    "reg_coef = 0.13347185900881423\n",
    "\n",
    "# Build the model and optimizer before `train` so every global it reads is defined above it.\n",
    "# (Defining a function consumes no RNG, so initialization order is unchanged.)\n",
    "model = MultiLayerDNLWithNegation(n_input=8, n_hiddens=[32, 16, 2], layer_kwargs=[{} for i in range(3)]).to(device)\n",
    "optimizer = MADGRAD([{'params': model.parameters()}], lr=lr)\n",
    "loss_records = []\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one training epoch over `dataloader` and return the epoch's mean losses.\n",
    "\n",
    "    Reads the notebook globals `model`, `optimizer`, `dataloader`, `loss_fn`,\n",
    "    `reg_coef` and `device`. Each step minimizes\n",
    "    ``loss_fn(pred, y) + reg_coef * model.reg_loss()``.\n",
    "\n",
    "    Returns:\n",
    "        dict with plain-float values:\n",
    "          'supervision'    - supervision loss averaged over samples (each batch's\n",
    "                             mean loss is weighted by its size, so a short final\n",
    "                             batch is not over-weighted);\n",
    "          'regularization' - `model.reg_loss()` averaged over the epoch's steps.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no batches.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    sup_loss_sum = 0.0  # sum over batches of (batch-mean loss * batch size)\n",
    "    n_samples = 0\n",
    "    reg_loss_values = []\n",
    "    for x, y in dataloader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        pred = model(x)\n",
    "        sup_loss = loss_fn(pred, y)\n",
    "        reg_loss = model.reg_loss()\n",
    "\n",
    "        # Log this batch's losses before the parameters are updated.\n",
    "        batch_len = x.shape[0]\n",
    "        sup_loss_sum += sup_loss.item() * batch_len\n",
    "        n_samples += batch_len\n",
    "        reg_loss_values.append(reg_loss.item())\n",
    "\n",
    "        # Out-of-place sum keeps `sup_loss` as the pure supervision term.\n",
    "        total_loss = sup_loss + reg_coef * reg_loss\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    if n_samples == 0:\n",
    "        raise ValueError('dataloader yielded no batches')\n",
    "    # Plain Python floats (not np.float64) keep the printed dict readable under NumPy >= 2,\n",
    "    # where NumPy scalars repr as `np.float64(...)`.\n",
    "    return {\n",
    "        'supervision': sup_loss_sum / n_samples,\n",
    "        'regularization': sum(reg_loss_values) / len(reg_loss_values),\n",
    "    }\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    mean_loss = train()\n",
    "    print(f'epoch: {epoch} mean_loss: {mean_loss}')\n",
    "    loss_records.append(mean_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "df22fad6-64ec-43b1-b671-b7c50e389145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sympy symbol per model input (the model is built with n_input=8), named x0..x7.\n",
    "input_symbols = [sp.Symbol(f'x{idx}') for idx in range(8)]\n",
    "model_expressions = model.expression(input_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "593e1177-cb72-420c-8b83-9fbbc69e6a00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x0 & x1 & x4 & x5) | (x0 & x1 & x6 & x7) | (x2 & x3 & x4 & x5) | (x2 & x3 & x6 & x7)\n",
      "(~x0 & ~x1 & ~x4 & ~x5) | (~x0 & ~x1 & ~x6 & ~x7) | (~x2 & ~x3 & ~x4 & ~x5) | (~x2 & ~x3 & ~x6 & ~x7)\n"
     ]
    }
   ],
   "source": [
    "# Show the symbolic expressions extracted for the first two model outputs.\n",
    "for output_idx in (0, 1):\n",
    "    print(model_expressions[output_idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
