{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "845cfdc3-a605-4ab6-bde8-8175d5550fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "dtype = torch.float64\n",
    "torch.set_default_dtype(dtype)\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.nn import BCELoss\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "from madgrad import MADGRAD\n",
    "from dataset.binary_logic import *\n",
    "from library.LFL_modules import *\n",
    "from library.utils import *\n",
    "from pprint import pprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "577ef06b-54d7-4ef4-bee6-8d0c90a39d23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 mean_loss: {'supervision': 0.5782821518786319, 'regularization': 1.5114227262852589}\n",
      "epoch: 1 mean_loss: {'supervision': 0.5696601676889982, 'regularization': 1.473567901272365}\n",
      "epoch: 2 mean_loss: {'supervision': 0.5397120816834651, 'regularization': 1.4347071369293096}\n",
      "epoch: 3 mean_loss: {'supervision': 0.5297573517603917, 'regularization': 1.3831538935639867}\n",
      "epoch: 4 mean_loss: {'supervision': 0.5091272800885367, 'regularization': 1.3303802661017174}\n",
      "epoch: 5 mean_loss: {'supervision': 0.48148183197797734, 'regularization': 1.2842442148856819}\n",
      "epoch: 6 mean_loss: {'supervision': 0.48716612513316854, 'regularization': 1.2365737761104714}\n",
      "epoch: 7 mean_loss: {'supervision': 0.489689622480702, 'regularization': 1.184989610144365}\n",
      "epoch: 8 mean_loss: {'supervision': 0.482094762008679, 'regularization': 1.1374668921734572}\n",
      "epoch: 9 mean_loss: {'supervision': 0.46269447588501594, 'regularization': 1.08743327588842}\n",
      "epoch: 10 mean_loss: {'supervision': 0.44804186226322107, 'regularization': 1.0290190236823484}\n",
      "epoch: 11 mean_loss: {'supervision': 0.4374157109560872, 'regularization': 0.9686938430579188}\n",
      "epoch: 12 mean_loss: {'supervision': 0.42253143415437744, 'regularization': 0.9212459064333808}\n",
      "epoch: 13 mean_loss: {'supervision': 0.3956090423023189, 'regularization': 0.8854636551039476}\n",
      "epoch: 14 mean_loss: {'supervision': 0.3639322083301117, 'regularization': 0.8524339182873311}\n",
      "epoch: 15 mean_loss: {'supervision': 0.3412991692583046, 'regularization': 0.82629386951948}\n",
      "epoch: 16 mean_loss: {'supervision': 0.2892025042998747, 'regularization': 0.8037956443046813}\n",
      "epoch: 17 mean_loss: {'supervision': 0.2702923600060907, 'regularization': 0.7901830083376906}\n",
      "epoch: 18 mean_loss: {'supervision': 0.23976991095469877, 'regularization': 0.7797881630604347}\n",
      "epoch: 19 mean_loss: {'supervision': 0.21164314869383663, 'regularization': 0.7688739929389329}\n",
      "epoch: 20 mean_loss: {'supervision': 0.20849070014774232, 'regularization': 0.7593804915666046}\n",
      "epoch: 21 mean_loss: {'supervision': 0.18141386398199522, 'regularization': 0.751144216883405}\n",
      "epoch: 22 mean_loss: {'supervision': 0.16817571038042828, 'regularization': 0.7441522505688211}\n",
      "epoch: 23 mean_loss: {'supervision': 0.16176122215686892, 'regularization': 0.7361612237397673}\n",
      "epoch: 24 mean_loss: {'supervision': 0.15524929724048833, 'regularization': 0.7271476359512239}\n",
      "epoch: 25 mean_loss: {'supervision': 0.14434263799988856, 'regularization': 0.7182872036672021}\n",
      "epoch: 26 mean_loss: {'supervision': 0.14226538950322742, 'regularization': 0.7091987994434604}\n",
      "epoch: 27 mean_loss: {'supervision': 0.11337392664324886, 'regularization': 0.6997307545586187}\n",
      "epoch: 28 mean_loss: {'supervision': 0.10662392220926833, 'regularization': 0.69091371901883}\n",
      "epoch: 29 mean_loss: {'supervision': 0.07283690663013681, 'regularization': 0.6828965281169719}\n",
      "epoch: 30 mean_loss: {'supervision': 0.06699500572639011, 'regularization': 0.676217883988588}\n",
      "epoch: 31 mean_loss: {'supervision': 0.05465088312338506, 'regularization': 0.6703255379566817}\n",
      "epoch: 32 mean_loss: {'supervision': 0.049219831790121196, 'regularization': 0.6648208279818916}\n",
      "epoch: 33 mean_loss: {'supervision': 0.05571479585106959, 'regularization': 0.6591490381649301}\n",
      "epoch: 34 mean_loss: {'supervision': 0.039376690007782655, 'regularization': 0.653318030074073}\n",
      "epoch: 35 mean_loss: {'supervision': 0.03371720910959346, 'regularization': 0.6478770723188907}\n",
      "epoch: 36 mean_loss: {'supervision': 0.024130207799249008, 'regularization': 0.6430380572766874}\n",
      "epoch: 37 mean_loss: {'supervision': 0.027521456932265168, 'regularization': 0.6385881884674568}\n",
      "epoch: 38 mean_loss: {'supervision': 0.0296859445418726, 'regularization': 0.6343179421782049}\n",
      "epoch: 39 mean_loss: {'supervision': 0.02137680479323475, 'regularization': 0.6305866965108582}\n",
      "epoch: 40 mean_loss: {'supervision': 0.017607194015872164, 'regularization': 0.6272950474980353}\n",
      "epoch: 41 mean_loss: {'supervision': 0.006530837993212945, 'regularization': 0.6234713566929673}\n",
      "epoch: 42 mean_loss: {'supervision': 0.009022295615387366, 'regularization': 0.6195585735384768}\n",
      "epoch: 43 mean_loss: {'supervision': 0.0040299177804499035, 'regularization': 0.6157203260829949}\n",
      "epoch: 44 mean_loss: {'supervision': 0.007447050171672193, 'regularization': 0.6120410679820275}\n",
      "epoch: 45 mean_loss: {'supervision': 0.005078368540873049, 'regularization': 0.6084692325721054}\n",
      "epoch: 46 mean_loss: {'supervision': 0.004841329100697923, 'regularization': 0.604967486909463}\n",
      "epoch: 47 mean_loss: {'supervision': 0.004854712303690852, 'regularization': 0.6015941619010687}\n",
      "epoch: 48 mean_loss: {'supervision': 0.0029163245263923512, 'regularization': 0.5983522138181498}\n",
      "epoch: 49 mean_loss: {'supervision': 0.0028194973090796827, 'regularization': 0.5952336021961151}\n",
      "epoch: 50 mean_loss: {'supervision': 0.002097075128502978, 'regularization': 0.5922152496791631}\n",
      "epoch: 51 mean_loss: {'supervision': 0.00347982839050054, 'regularization': 0.5892901457400547}\n",
      "epoch: 52 mean_loss: {'supervision': 0.0018736338209561084, 'regularization': 0.5864466746572219}\n",
      "epoch: 53 mean_loss: {'supervision': 0.0017671766032826494, 'regularization': 0.583680284114483}\n",
      "epoch: 54 mean_loss: {'supervision': 0.001800713706389641, 'regularization': 0.580984066276291}\n",
      "epoch: 55 mean_loss: {'supervision': 0.0017646227041249007, 'regularization': 0.5783501667903378}\n",
      "epoch: 56 mean_loss: {'supervision': 0.002519099746073972, 'regularization': 0.5757702275143559}\n",
      "epoch: 57 mean_loss: {'supervision': 0.0015384049469618407, 'regularization': 0.5732392084195471}\n",
      "epoch: 58 mean_loss: {'supervision': 0.0024329984046924306, 'regularization': 0.5707536782206852}\n",
      "epoch: 59 mean_loss: {'supervision': 0.0016079282461786853, 'regularization': 0.5683097159919751}\n",
      "epoch: 60 mean_loss: {'supervision': 0.0024165259587541196, 'regularization': 0.5659035054494237}\n",
      "epoch: 61 mean_loss: {'supervision': 0.0023602218603630872, 'regularization': 0.5635357724521479}\n",
      "epoch: 62 mean_loss: {'supervision': 0.0013735947349866953, 'regularization': 0.5612041361219067}\n",
      "epoch: 63 mean_loss: {'supervision': 0.0031581078320805575, 'regularization': 0.558913617472844}\n",
      "epoch: 64 mean_loss: {'supervision': 0.001443116603564811, 'regularization': 0.5566620517260772}\n",
      "epoch: 65 mean_loss: {'supervision': 0.0011667013576670074, 'regularization': 0.554448520827432}\n",
      "epoch: 66 mean_loss: {'supervision': 0.0011589469711088141, 'regularization': 0.552272223937496}\n",
      "epoch: 67 mean_loss: {'supervision': 0.0015340728075416386, 'regularization': 0.5501315568429971}\n",
      "epoch: 68 mean_loss: {'supervision': 0.0010973613085566419, 'regularization': 0.5480286317857866}\n",
      "epoch: 69 mean_loss: {'supervision': 0.001163947900731351, 'regularization': 0.5459675828764244}\n",
      "epoch: 70 mean_loss: {'supervision': 0.0015735603482111909, 'regularization': 0.543951535015029}\n",
      "epoch: 71 mean_loss: {'supervision': 0.0009916558426473101, 'regularization': 0.5419876017179872}\n",
      "epoch: 72 mean_loss: {'supervision': 0.0010028856190661371, 'regularization': 0.5400860971789829}\n",
      "epoch: 73 mean_loss: {'supervision': 0.0015831421094520733, 'regularization': 0.5382567113913668}\n",
      "epoch: 74 mean_loss: {'supervision': 0.0009961638899507585, 'regularization': 0.5365079113090427}\n",
      "epoch: 75 mean_loss: {'supervision': 0.0008201622563135218, 'regularization': 0.5348487649653226}\n",
      "epoch: 76 mean_loss: {'supervision': 0.001113413042474472, 'regularization': 0.5332836048238414}\n",
      "epoch: 77 mean_loss: {'supervision': 0.0010152221522737033, 'regularization': 0.531812230783274}\n",
      "epoch: 78 mean_loss: {'supervision': 0.0010586315422246291, 'regularization': 0.530431549041473}\n",
      "epoch: 79 mean_loss: {'supervision': 0.0009808535963024498, 'regularization': 0.5291362056511085}\n",
      "epoch: 80 mean_loss: {'supervision': 0.0009185821028681791, 'regularization': 0.5279191481192229}\n",
      "epoch: 81 mean_loss: {'supervision': 0.0011407325554059508, 'regularization': 0.5267717587363787}\n",
      "epoch: 82 mean_loss: {'supervision': 0.0012057934691859739, 'regularization': 0.525683924575064}\n",
      "epoch: 83 mean_loss: {'supervision': 0.00108940827934309, 'regularization': 0.5246446946062984}\n",
      "epoch: 84 mean_loss: {'supervision': 0.00111544181554794, 'regularization': 0.5236437615435598}\n",
      "epoch: 85 mean_loss: {'supervision': 0.000781320283819494, 'regularization': 0.5226718040747116}\n",
      "epoch: 86 mean_loss: {'supervision': 0.0008523359052074364, 'regularization': 0.521719880182312}\n",
      "epoch: 87 mean_loss: {'supervision': 0.0007446103714248864, 'regularization': 0.5207799944900703}\n",
      "epoch: 88 mean_loss: {'supervision': 0.001711449627890833, 'regularization': 0.5198470048408323}\n",
      "epoch: 89 mean_loss: {'supervision': 0.0008241495716803428, 'regularization': 0.5189167248155725}\n",
      "epoch: 90 mean_loss: {'supervision': 0.0007289563334303812, 'regularization': 0.5179879837071967}\n",
      "epoch: 91 mean_loss: {'supervision': 0.0007863499680748302, 'regularization': 0.5170615672044784}\n",
      "epoch: 92 mean_loss: {'supervision': 0.0009271702642572242, 'regularization': 0.5161403922994758}\n",
      "epoch: 93 mean_loss: {'supervision': 0.0009947869955815231, 'regularization': 0.5152296214627418}\n",
      "epoch: 94 mean_loss: {'supervision': 0.0009926557121173452, 'regularization': 0.5143348852700994}\n",
      "epoch: 95 mean_loss: {'supervision': 0.0012854981553815585, 'regularization': 0.5134604359696475}\n",
      "epoch: 96 mean_loss: {'supervision': 0.0006646378928086978, 'regularization': 0.512608227800454}\n",
      "epoch: 97 mean_loss: {'supervision': 0.0006997245144295839, 'regularization': 0.511778353067897}\n",
      "epoch: 98 mean_loss: {'supervision': 0.000684647516030767, 'regularization': 0.5109669532202438}\n",
      "epoch: 99 mean_loss: {'supervision': 0.0009222240018185227, 'regularization': 0.5101684283354068}\n",
      "epoch: 100 mean_loss: {'supervision': 0.0007635904944415292, 'regularization': 0.5093760568444059}\n",
      "epoch: 101 mean_loss: {'supervision': 0.000752395403003373, 'regularization': 0.5085830922696146}\n",
      "epoch: 102 mean_loss: {'supervision': 0.000690967154951598, 'regularization': 0.5077856507523253}\n",
      "epoch: 103 mean_loss: {'supervision': 0.0006535697796213377, 'regularization': 0.5069834733943261}\n",
      "epoch: 104 mean_loss: {'supervision': 0.0006273099596225459, 'regularization': 0.5061795420424308}\n",
      "epoch: 105 mean_loss: {'supervision': 0.0005666565538199374, 'regularization': 0.5053791392868808}\n",
      "epoch: 106 mean_loss: {'supervision': 0.0009983482186413971, 'regularization': 0.5045887746317115}\n",
      "epoch: 107 mean_loss: {'supervision': 0.0007764205373436336, 'regularization': 0.5038148617599363}\n",
      "epoch: 108 mean_loss: {'supervision': 0.0005242592259145947, 'regularization': 0.5030617292163552}\n",
      "epoch: 109 mean_loss: {'supervision': 0.0005755172965601128, 'regularization': 0.502335357584861}\n",
      "epoch: 110 mean_loss: {'supervision': 0.0006275602864290207, 'regularization': 0.5016405016797483}\n",
      "epoch: 111 mean_loss: {'supervision': 0.0005845599520033608, 'regularization': 0.5009798258497633}\n",
      "epoch: 112 mean_loss: {'supervision': 0.0005872800308585366, 'regularization': 0.5003536650380447}\n",
      "epoch: 113 mean_loss: {'supervision': 0.0006400232748910888, 'regularization': 0.4997606616386332}\n",
      "epoch: 114 mean_loss: {'supervision': 0.0005892647362766006, 'regularization': 0.499198273386618}\n",
      "epoch: 115 mean_loss: {'supervision': 0.0010884589932439858, 'regularization': 0.4986635774652299}\n",
      "epoch: 116 mean_loss: {'supervision': 0.0004703103344355755, 'regularization': 0.4981541446795117}\n",
      "epoch: 117 mean_loss: {'supervision': 0.0007800506803221109, 'regularization': 0.4976675871055583}\n",
      "epoch: 118 mean_loss: {'supervision': 0.0005445585712598825, 'regularization': 0.4972024487911504}\n",
      "epoch: 119 mean_loss: {'supervision': 0.0005038603748981525, 'regularization': 0.4967576926741485}\n",
      "epoch: 120 mean_loss: {'supervision': 0.0005700607744199749, 'regularization': 0.4963324919958384}\n",
      "epoch: 121 mean_loss: {'supervision': 0.0005792224425594833, 'regularization': 0.4959259003155925}\n",
      "epoch: 122 mean_loss: {'supervision': 0.000717420987681973, 'regularization': 0.49553658636646314}\n",
      "epoch: 123 mean_loss: {'supervision': 0.0007424021594703176, 'regularization': 0.49516324059185235}\n",
      "epoch: 124 mean_loss: {'supervision': 0.0006201568622194824, 'regularization': 0.4948036457190151}\n",
      "epoch: 125 mean_loss: {'supervision': 0.00046983452717176753, 'regularization': 0.4944551896770426}\n",
      "epoch: 126 mean_loss: {'supervision': 0.0007449387262839582, 'regularization': 0.4941150950998728}\n",
      "epoch: 127 mean_loss: {'supervision': 0.0004874742379395504, 'regularization': 0.49378006904608995}\n",
      "epoch: 128 mean_loss: {'supervision': 0.0005626259234886784, 'regularization': 0.4934472737960893}\n",
      "epoch: 129 mean_loss: {'supervision': 0.00046714162431335456, 'regularization': 0.4931131199010048}\n",
      "epoch: 130 mean_loss: {'supervision': 0.00042208280142575334, 'regularization': 0.49277408985860927}\n",
      "epoch: 131 mean_loss: {'supervision': 0.0005146963360816658, 'regularization': 0.4924234074678111}\n",
      "epoch: 132 mean_loss: {'supervision': 0.0006106330854735872, 'regularization': 0.492055360718557}\n",
      "epoch: 133 mean_loss: {'supervision': 0.0005464983410763764, 'regularization': 0.4916641411872664}\n",
      "epoch: 134 mean_loss: {'supervision': 0.0005408276503788937, 'regularization': 0.49124374267039905}\n",
      "epoch: 135 mean_loss: {'supervision': 0.0004646424918748444, 'regularization': 0.49078793198093396}\n",
      "epoch: 136 mean_loss: {'supervision': 0.0005776517417922062, 'regularization': 0.4902904090654384}\n",
      "epoch: 137 mean_loss: {'supervision': 0.0006710657432282271, 'regularization': 0.48974259405552967}\n",
      "epoch: 138 mean_loss: {'supervision': 0.0004061402705502584, 'regularization': 0.48913025632395046}\n",
      "epoch: 139 mean_loss: {'supervision': 0.0004966142774245917, 'regularization': 0.4884267422967127}\n",
      "epoch: 140 mean_loss: {'supervision': 0.0004153044956130678, 'regularization': 0.4875792296762809}\n",
      "epoch: 141 mean_loss: {'supervision': 0.00041532428928552236, 'regularization': 0.4864814774710974}\n",
      "epoch: 142 mean_loss: {'supervision': 0.0003976150705853694, 'regularization': 0.4849103863455676}\n",
      "epoch: 143 mean_loss: {'supervision': 0.0004849978315279613, 'regularization': 0.48237913373104063}\n",
      "epoch: 144 mean_loss: {'supervision': 0.0004169955018660703, 'regularization': 0.4779120502628799}\n",
      "epoch: 145 mean_loss: {'supervision': 0.0003698347356118664, 'regularization': 0.4706149174857768}\n",
      "epoch: 146 mean_loss: {'supervision': 0.0005229443260730727, 'regularization': 0.46279976138947343}\n",
      "epoch: 147 mean_loss: {'supervision': 0.0004169436945702219, 'regularization': 0.45832060775220007}\n",
      "epoch: 148 mean_loss: {'supervision': 0.0005090467738109109, 'regularization': 0.45656830751510474}\n",
      "epoch: 149 mean_loss: {'supervision': 0.00038731827939746, 'regularization': 0.4558420222441774}\n",
      "epoch: 150 mean_loss: {'supervision': 0.0003578812432764263, 'regularization': 0.4554385688425925}\n",
      "epoch: 151 mean_loss: {'supervision': 0.00041215558055050577, 'regularization': 0.4551421995981449}\n",
      "epoch: 152 mean_loss: {'supervision': 0.0004104056410073694, 'regularization': 0.45488280681972193}\n",
      "epoch: 153 mean_loss: {'supervision': 0.0003688185445141504, 'regularization': 0.45463414931194}\n",
      "epoch: 154 mean_loss: {'supervision': 0.0010606247424435377, 'regularization': 0.4543826895932943}\n",
      "epoch: 155 mean_loss: {'supervision': 0.000437756264007728, 'regularization': 0.4541198174728424}\n",
      "epoch: 156 mean_loss: {'supervision': 0.00031457810453764735, 'regularization': 0.45383932495738294}\n",
      "epoch: 157 mean_loss: {'supervision': 0.0005410024177621239, 'regularization': 0.4535365572312052}\n",
      "epoch: 158 mean_loss: {'supervision': 0.00037442109536948466, 'regularization': 0.4532088335432664}\n",
      "epoch: 159 mean_loss: {'supervision': 0.00030910424464927007, 'regularization': 0.45285671519605064}\n",
      "epoch: 160 mean_loss: {'supervision': 0.00038294435045360805, 'regularization': 0.45248423323995385}\n",
      "epoch: 161 mean_loss: {'supervision': 0.00033616809789712115, 'regularization': 0.4520989281656821}\n",
      "epoch: 162 mean_loss: {'supervision': 0.0003860650947488519, 'regularization': 0.4517109979268528}\n",
      "epoch: 163 mean_loss: {'supervision': 0.00038056042272965776, 'regularization': 0.4513318432093535}\n",
      "epoch: 164 mean_loss: {'supervision': 0.0008095921188443608, 'regularization': 0.4509721880227612}\n",
      "epoch: 165 mean_loss: {'supervision': 0.0004704095205170237, 'regularization': 0.45063944552339663}\n",
      "epoch: 166 mean_loss: {'supervision': 0.0003596853199878272, 'regularization': 0.4503355233633167}\n",
      "epoch: 167 mean_loss: {'supervision': 0.0003699695948674528, 'regularization': 0.45005681397786834}\n",
      "epoch: 168 mean_loss: {'supervision': 0.000475314456717363, 'regularization': 0.44979616044597226}\n",
      "epoch: 169 mean_loss: {'supervision': 0.0003071406597276409, 'regularization': 0.449545352071547}\n",
      "epoch: 170 mean_loss: {'supervision': 0.00030752314797414756, 'regularization': 0.4492967850755575}\n",
      "epoch: 171 mean_loss: {'supervision': 0.00028885032375596805, 'regularization': 0.4490437092870053}\n",
      "epoch: 172 mean_loss: {'supervision': 0.0003310369181327783, 'regularization': 0.4487809199685636}\n",
      "epoch: 173 mean_loss: {'supervision': 0.00036904474238797413, 'regularization': 0.4485049773573452}\n",
      "epoch: 174 mean_loss: {'supervision': 0.0003050558123018315, 'regularization': 0.44821391254784054}\n",
      "epoch: 175 mean_loss: {'supervision': 0.000374740944752852, 'regularization': 0.4479076527611111}\n",
      "epoch: 176 mean_loss: {'supervision': 0.00035360036926158363, 'regularization': 0.44758801141888827}\n",
      "epoch: 177 mean_loss: {'supervision': 0.0004565430949623728, 'regularization': 0.44725860406898366}\n",
      "epoch: 178 mean_loss: {'supervision': 0.00028578278103399556, 'regularization': 0.4469243423830758}\n",
      "epoch: 179 mean_loss: {'supervision': 0.00024284289891178022, 'regularization': 0.446590794238089}\n",
      "epoch: 180 mean_loss: {'supervision': 0.0006910843438841109, 'regularization': 0.4462630111806605}\n",
      "epoch: 181 mean_loss: {'supervision': 0.00030926253077271683, 'regularization': 0.4459445074161825}\n",
      "epoch: 182 mean_loss: {'supervision': 0.0002922683433608543, 'regularization': 0.44563703834522367}\n",
      "epoch: 183 mean_loss: {'supervision': 0.0002763302655585095, 'regularization': 0.4453402616938965}\n",
      "epoch: 184 mean_loss: {'supervision': 0.0003089909387031048, 'regularization': 0.44505215863773506}\n",
      "epoch: 185 mean_loss: {'supervision': 0.000496880818264509, 'regularization': 0.4447696052291572}\n",
      "epoch: 186 mean_loss: {'supervision': 0.0002823320916913916, 'regularization': 0.4444894678100735}\n",
      "epoch: 187 mean_loss: {'supervision': 0.00026144480566207595, 'regularization': 0.4442087069749655}\n",
      "epoch: 188 mean_loss: {'supervision': 0.00032036162053216766, 'regularization': 0.44392468096358484}\n",
      "epoch: 189 mean_loss: {'supervision': 0.00025646760743015763, 'regularization': 0.4436354723227831}\n",
      "epoch: 190 mean_loss: {'supervision': 0.0003471160287664567, 'regularization': 0.4433403271437423}\n",
      "epoch: 191 mean_loss: {'supervision': 0.00036877304506619673, 'regularization': 0.44303984399805685}\n",
      "epoch: 192 mean_loss: {'supervision': 0.0002469194600176714, 'regularization': 0.4427361509589172}\n",
      "epoch: 193 mean_loss: {'supervision': 0.0002682861998469264, 'regularization': 0.44243242962499796}\n",
      "epoch: 194 mean_loss: {'supervision': 0.0002751123057584844, 'regularization': 0.44213266888767455}\n",
      "epoch: 195 mean_loss: {'supervision': 0.0002743129477235965, 'regularization': 0.44184092149097803}\n",
      "epoch: 196 mean_loss: {'supervision': 0.0002659421828927564, 'regularization': 0.4415612553057573}\n",
      "epoch: 197 mean_loss: {'supervision': 0.0002778020887801826, 'regularization': 0.4412977871536672}\n",
      "epoch: 198 mean_loss: {'supervision': 0.0003688994148874979, 'regularization': 0.4410544735246644}\n",
      "epoch: 199 mean_loss: {'supervision': 0.0003294410355452944, 'regularization': 0.44083465016125467}\n",
      "epoch: 200 mean_loss: {'supervision': 0.00023135512532447073, 'regularization': 0.44064014195190104}\n",
      "epoch: 201 mean_loss: {'supervision': 0.00030815268637491666, 'regularization': 0.4404708501157619}\n",
      "epoch: 202 mean_loss: {'supervision': 0.0003100150227703985, 'regularization': 0.4403249835688955}\n",
      "epoch: 203 mean_loss: {'supervision': 0.00030758690171084856, 'regularization': 0.44019954548306073}\n",
      "epoch: 204 mean_loss: {'supervision': 0.00023259348275214378, 'regularization': 0.4400911412908742}\n",
      "epoch: 205 mean_loss: {'supervision': 0.00021772166960536368, 'regularization': 0.4399965183161714}\n",
      "epoch: 206 mean_loss: {'supervision': 0.0007005477591260949, 'regularization': 0.439912811808852}\n",
      "epoch: 207 mean_loss: {'supervision': 0.0002464374736139067, 'regularization': 0.4398376197812351}\n",
      "epoch: 208 mean_loss: {'supervision': 0.0002292721637774813, 'regularization': 0.4397690426405139}\n",
      "epoch: 209 mean_loss: {'supervision': 0.00021809905927354053, 'regularization': 0.439705588103434}\n",
      "epoch: 210 mean_loss: {'supervision': 0.0003000285896124017, 'regularization': 0.43964604095384624}\n",
      "epoch: 211 mean_loss: {'supervision': 0.00029242557357218944, 'regularization': 0.43958915438697294}\n",
      "epoch: 212 mean_loss: {'supervision': 0.00021615413940820192, 'regularization': 0.43953439270101935}\n",
      "epoch: 213 mean_loss: {'supervision': 0.0001972111698267941, 'regularization': 0.4394811639036349}\n",
      "epoch: 214 mean_loss: {'supervision': 0.00024345739420642847, 'regularization': 0.4394289038186181}\n",
      "epoch: 215 mean_loss: {'supervision': 0.0003634150545668675, 'regularization': 0.4393770647155643}\n",
      "epoch: 216 mean_loss: {'supervision': 0.0002524697439000439, 'regularization': 0.4393250560018457}\n",
      "epoch: 217 mean_loss: {'supervision': 0.00019282705499876532, 'regularization': 0.43926799880576206}\n",
      "epoch: 218 mean_loss: {'supervision': 0.00022225501071164645, 'regularization': 0.4392098866044418}\n",
      "epoch: 219 mean_loss: {'supervision': 0.00021859473059493865, 'regularization': 0.4391511292823578}\n",
      "epoch: 220 mean_loss: {'supervision': 0.00022839177570145217, 'regularization': 0.4390903978133568}\n",
      "epoch: 221 mean_loss: {'supervision': 0.00021351555328157675, 'regularization': 0.43902634611708125}\n",
      "epoch: 222 mean_loss: {'supervision': 0.0002531456410942481, 'regularization': 0.43895746432835203}\n",
      "epoch: 223 mean_loss: {'supervision': 0.00021793964161484814, 'regularization': 0.4388818515674516}\n",
      "epoch: 224 mean_loss: {'supervision': 0.0002752780334304704, 'regularization': 0.43879765680586436}\n",
      "epoch: 225 mean_loss: {'supervision': 0.00021291823687760153, 'regularization': 0.43870266011217796}\n",
      "epoch: 226 mean_loss: {'supervision': 0.00020905585849814823, 'regularization': 0.4385941657196047}\n",
      "epoch: 227 mean_loss: {'supervision': 0.00020794733845301398, 'regularization': 0.43846960636479204}\n",
      "epoch: 228 mean_loss: {'supervision': 0.0001994429373556505, 'regularization': 0.43832723311189825}\n",
      "epoch: 229 mean_loss: {'supervision': 0.00021582553217237962, 'regularization': 0.4381672113283601}\n",
      "epoch: 230 mean_loss: {'supervision': 0.0002432683833712747, 'regularization': 0.4379935804759165}\n",
      "epoch: 231 mean_loss: {'supervision': 0.0001835621868456622, 'regularization': 0.43781442830913236}\n",
      "epoch: 232 mean_loss: {'supervision': 0.0001900088319005719, 'regularization': 0.437640580815306}\n",
      "epoch: 233 mean_loss: {'supervision': 0.00019789992418012576, 'regularization': 0.4374823443981757}\n",
      "epoch: 234 mean_loss: {'supervision': 0.00023785820102380063, 'regularization': 0.43734611800421486}\n",
      "epoch: 235 mean_loss: {'supervision': 0.00022896020582820816, 'regularization': 0.43723329071572853}\n",
      "epoch: 236 mean_loss: {'supervision': 0.00020202080965140356, 'regularization': 0.4371415743888418}\n",
      "epoch: 237 mean_loss: {'supervision': 0.00023782708801158492, 'regularization': 0.4370670124721906}\n",
      "epoch: 238 mean_loss: {'supervision': 0.0002130866857550766, 'regularization': 0.43700530647296554}\n",
      "epoch: 239 mean_loss: {'supervision': 0.0001813914649759626, 'regularization': 0.4369527119581875}\n",
      "epoch: 240 mean_loss: {'supervision': 0.0002431328708833713, 'regularization': 0.4369072039580672}\n",
      "epoch: 241 mean_loss: {'supervision': 0.00022625197120115583, 'regularization': 0.43686683517336977}\n",
      "epoch: 242 mean_loss: {'supervision': 0.00022393037558888823, 'regularization': 0.4368301144619317}\n",
      "epoch: 243 mean_loss: {'supervision': 0.0002805566539124437, 'regularization': 0.43679615877832517}\n",
      "epoch: 244 mean_loss: {'supervision': 0.00019297370026330425, 'regularization': 0.4367642205172666}\n",
      "epoch: 245 mean_loss: {'supervision': 0.0002192661665343137, 'regularization': 0.4367337803682194}\n",
      "epoch: 246 mean_loss: {'supervision': 0.0002227254821444277, 'regularization': 0.4367044520233375}\n",
      "epoch: 247 mean_loss: {'supervision': 0.00024405933989838185, 'regularization': 0.43667583723329295}\n",
      "epoch: 248 mean_loss: {'supervision': 0.0002698692977665819, 'regularization': 0.4366478174875808}\n",
      "epoch: 249 mean_loss: {'supervision': 0.00019869245800800605, 'regularization': 0.43662022648777005}\n",
      "epoch: 250 mean_loss: {'supervision': 0.00019830137250553272, 'regularization': 0.4365929145429885}\n",
      "epoch: 251 mean_loss: {'supervision': 0.0001975730585520142, 'regularization': 0.4365657471560759}\n",
      "epoch: 252 mean_loss: {'supervision': 0.00016460607290162067, 'regularization': 0.43653858050594596}\n",
      "epoch: 253 mean_loss: {'supervision': 0.00016195758757471936, 'regularization': 0.436511296643705}\n",
      "epoch: 254 mean_loss: {'supervision': 0.0001791150495743583, 'regularization': 0.4364837680086394}\n",
      "epoch: 255 mean_loss: {'supervision': 0.00018342387028055413, 'regularization': 0.43645584398933224}\n"
     ]
    }
   ],
   "source": [
    "# Seed first, before the dataset or model exist.\n",
    "# NOTE(review): `seed` arrives through one of the star imports (presumably library.utils) -\n",
    "# confirm which random number generators it actually seeds.\n",
    "seed(42)\n",
    "\n",
    "# Hyperparameters.\n",
    "# NOTE(review): these values look like the output of a hyperparameter search - TODO record provenance.\n",
    "epochs = 2 ** 8  # number of passes over the dataloader\n",
    "batch_size = 128\n",
    "lr = 28.353628505319445  # learning rate handed to the MADGRAD optimizer\n",
    "reg_coef = 0.05891286543711857  # weight applied to model.reg_loss() inside train()\n",
    "noise_scales = [0.8044418562388825, 0.09504842110818068, 0.42150252547988554]  # one `noise_scale` per layer\n",
    "\n",
    "# Training data and supervision loss.\n",
    "train_dataset = MultiLayerLogicDataset()\n",
    "dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "loss_fn = BCELoss()\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one optimisation pass over `dataloader` and report the mean losses.\n",
    "\n",
    "    Relies on the notebook-level `model`, `optimizer`, `dataloader`, `loss_fn`,\n",
    "    `reg_coef` and `device`. Every step minimises\n",
    "    `loss_fn(pred, y) + reg_coef * model.reg_loss()`.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'supervision' and 'regularization', each mapped to the mean\n",
    "        of the corresponding unweighted per-batch loss as a plain Python float.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    losses = {\n",
    "        'supervision': [],\n",
    "        'regularization': []\n",
    "    }\n",
    "    for x, y in dataloader:\n",
    "        pred = model(x.to(device))\n",
    "        supervision_loss = loss_fn(pred, y.to(device))\n",
    "        reg_loss = model.reg_loss()\n",
    "        # Record both terms separately; the regularization term is logged before weighting.\n",
    "        losses['supervision'].append(supervision_loss.item())\n",
    "        losses['regularization'].append(reg_loss.item())\n",
    "        loss = supervision_loss + reg_coef * reg_loss\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # float(...) keeps the printed values plain numbers regardless of the NumPy scalar repr.\n",
    "    return {name: float(np.mean(values)) for name, values in losses.items()}\n",
    "\n",
    "# 8 inputs, n_hiddens=[32, 16, 2]; each entry gets its own `noise_scale` keyword.\n",
    "# NOTE(review): presumably the n_hiddens entries are per-layer output sizes - confirm in library.LFL_modules.\n",
    "hidden_sizes = [32, 16, 2]\n",
    "model = MultiLayerLFLWithNegation(\n",
    "    n_input=8,\n",
    "    n_hiddens=hidden_sizes,\n",
    "    layer_kwargs=[{'noise_scale': noise_scales[layer]} for layer in range(len(hidden_sizes))],\n",
    ").to(device)\n",
    "optimizer = MADGRAD([{'params': model.parameters()}], lr=lr)\n",
    "\n",
    "# One mean-loss dict per epoch, in training order.\n",
    "loss_records = []\n",
    "for epoch in range(epochs):\n",
    "    mean_loss = train()\n",
    "    loss_records.append(mean_loss)\n",
    "    print(f'epoch: {epoch} mean_loss: {mean_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "df22fad6-64ec-43b1-b671-b7c50e389145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symbolic inputs x0..x7, one per model input, passed to the trained model to obtain its expressions.\n",
    "input_symbols = list(sp.symbols('x:8'))\n",
    "model_expressions = model.expression(input_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "593e1177-cb72-420c-8b83-9fbbc69e6a00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x0 & x1 & x4 & x5) | (x0 & x1 & x6 & x7) | (x2 & x3 & x4 & x5) | (x2 & x3 & x6 & x7 & ~x1) | (x1 & x2 & x3 & x6 & x7 & ~x4) | (x1 & x2 & x3 & x4 & x6 & x7 & ~x0)\n",
      "((~x6 & ~x7) | (~x0 & ~x1 & ~x4 & ~x5) | (~x2 & ~x3 & ~x4 & ~x5)) & ((x6 & ~x3 & ~x4) | (~x2 & ~x3 & ~x6) | (~x0 & ~x1 & ~x4 & ~x5) | (~x0 & ~x1 & ~x6 & ~x7)) & ((x1 & ~x2) | (x0 & x1 & x6) | (~x6 & ~x7) | (x0 & x1 & x6 & x7) | (x2 & x3 & x4 & x5) | (x6 & ~x3 & ~x4) | (x7 & ~x3 & ~x6) | (~x2 & ~x3 & ~x6) | (x2 & x3 & x6 & x7 & ~x1) | (x1 & x2 & x3 & x6 & x7 & ~x4) | (~x0 & ~x1 & ~x4 & ~x5) | (~x0 & ~x1 & ~x6 & ~x7) | (x1 & x2 & x3 & x4 & x6 & x7 & ~x0))\n"
     ]
    }
   ],
   "source": [
    "# Show the expression extracted for each of the two model outputs, one per line.\n",
    "print(model_expressions[0], model_expressions[1], sep='\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
