import torch as t
from torch import nn
from torch.nn import functional as nnF
from functools import lru_cache


class NBitsGelu(t.autograd.Function):
    xs_const = [
        t.tensor([0.00000000000000000]),
        t.tensor([-0.46132802110453103,
            -0.00000000000000055,
            0.46132802110452981]),
        t.tensor([-2.40884797702306042,
            -0.71192497573751246,
            -0.32633073511790534,
            -0.00000000000000026,
            0.32633073511790522,
            0.71192497573751290,
            2.40884797702305997]),
        t.tensor([-2.64952682104609716,
            -2.02781335198672830,
            -0.92408609910716744,
            -0.66777523353780222,
            -0.46992418705007644,
            -0.30085904432656912,
            -0.14501492885208012,
            0.00000000000000026,
            0.14501492885207939,
            0.30085904432656857,
            0.46992418705007816,
            0.66777523353780444,
            0.92408609910717177,
            2.02781335198672696,
            2.64952682104609893]),
        t.tensor([-2.92710030191747705,
            -2.37891574922338567,
            -1.98382220976643753,
            -1.52832260384709029,
            -1.36664201183946266,
            -1.19128153542139237,
            -0.97707449316668482,
            -0.81655879956253341,
            -0.68992865286588045,
            -0.56706086359017105,
            -0.46063400362679208,
            -0.36083071867741179,
            -0.27213466889115817,
            -0.18725333960087970,
            -0.09545320590096010,
            -0.00000000000016468,
            0.09545320590116181,
            0.18725333960099158,
            0.27213466889079202,
            0.36083071867779953,
            0.46063400362653045,
            0.56706086359029162,
            0.68992865286583815,
            0.81655879956254707,
            0.97707449316669170,
            1.19128153542138793,
            1.36664201183946377,
            1.52832260384708984,
            1.98382220976642665,
            2.37891574922339410,
            2.92710030191748238]),
        t.tensor([-2.82908059050189387,
            -2.26284818482951255,
            -1.82685489973241766,
            -1.75049771653586728,
            -1.68804605157983634,
            -1.62557643759734938,
            -1.56303423098593863,
            -1.50039423041699127,
            -1.43763056102797160,
            -1.37471847611981346,
            -1.31163666119048195,
            -1.24836995326919165,
            -1.18491233070203483,
            -1.12126991197296477,
            -1.05746356803751462,
            -0.99353063436294531,
            -0.92952516171149446,
            -0.86551623302573599,
            -0.80158413899831715,
            -0.73781462415744559,
            -0.67429187959286629,
            -0.61109130208778561,
            -0.54827312030519804,
            -0.48587777060813292,
            -0.42392347923054063,
            -0.36240603621418210,
            -0.30130037583300529,
            -0.24056338119248749,
            -0.18013730508720732,
            -0.11995329902345060,
            -0.05993471089287655,
            0.00000000000000023,
            0.05993471089287550,
            0.11995329902344819,
            0.18013730508720646,
            0.24056338119248902,
            0.30130037583300756,
            0.36240603621418088,
            0.42392347923053325,
            0.48587777060813098,
            0.54827312030520670,
            0.61109130208778950,
            0.67429187959286507,
            0.73781462415744126,
            0.80158413899831116,
            0.86551623302573999,
            0.92952516171149524,
            0.99353063436294486,
            1.05746356803751618,
            1.12126991197296499,
            1.18491233070203217,
            1.24836995326919209,
            1.31163666119048283,
            1.37471847611981302,
            1.43763056102797049,
            1.50039423041699171,
            1.56303423098594063,
            1.62557643759735071,
            1.68804605157983922,
            1.75049771653586461,
            1.82685489973241588,
            2.26284818482951300,
            2.82908059050189475]),
        t.tensor([-2.88528048414695926,
            -2.32849632999631462,
            -1.91753552506321467,
            -1.87504313068165085,
            -1.84379099310485439,
            -1.81254936311919557,
            -1.78130730712385787,
            -1.75006454499142006,
            -1.71882077022800850,
            -1.68757564630236745,
            -1.65632880888022860,
            -1.62507986892908796,
            -1.59382841673355569,
            -1.56257402684728408,
            -1.53131626398879894,
            -1.50005468986750046,
            -1.46878887090272481,
            -1.43751838677328969,
            -1.40624283970796338,
            -1.37496186439904089,
            -1.34367513839225872,
            -1.31238239277768609,
            -1.28108342297808231,
            -1.24977809940541729,
            -1.21846637773280175,
            -1.18714830851012998,
            -1.15582404583772758,
            -1.12449385480492792,
            -1.09315811740081514,
            -1.06181733661348310,
            -1.03047213845306129,
            -0.99912327166296688,
            -0.96777160492377334,
            -0.93641812140432301,
            -0.90506391057489488,
            -0.87371015726572165,
            -0.84235812802918786,
            -0.81100915494337356,
            -0.77966461707514034,
            -0.74832591989959718,
            -0.71699447304614849,
            -0.68567166680602054,
            -0.65435884788909637,
            -0.62305729495662410,
            -0.59176819447843099,
            -0.56049261746823575,
            -0.52923149763717003,
            -0.49798561147569753,
            -0.46675556072766844,
            -0.43554175766067044,
            -0.40434441346673067,
            -0.37316353004905911,
            -0.34199889536879641,
            -0.31085008244234680,
            -0.27971645199901091,
            -0.24859715873170521,
            -0.21749116100424748,
            -0.18639723381580167,
            -0.15531398477054340,
            -0.12423987275554012,
            -0.09317322899443597,
            -0.06211228011611894,
            -0.03105517285702154,
            0.00000000000000004,
            0.03105517285702139,
            0.06211228011611803,
            0.09317322899443668,
            0.12423987275553956,
            0.15531398477054290,
            0.18639723381580456,
            0.21749116100424717,
            0.24859715873170263,
            0.27971645199901546,
            0.31085008244235057,
            0.34199889536879102,
            0.37316353004906105,
            0.40434441346673899,
            0.43554175766066527,
            0.46675556072766483,
            0.49798561147570342,
            0.52923149763716848,
            0.56049261746823231,
            0.59176819447842566,
            0.62305729495662154,
            0.65435884788910181,
            0.68567166680601865,
            0.71699447304614883,
            0.74832591989959074,
            0.77966461707513235,
            0.81100915494337822,
            0.84235812802918753,
            0.87371015726572698,
            0.90506391057490554,
            0.93641812140432235,
            0.96777160492377357,
            0.99912327166296488,
            1.03047213845305397,
            1.06181733661348288,
            1.09315811740082225,
            1.12449385480492769,
            1.15582404583772447,
            1.18714830851012998,
            1.21846637773280375,
            1.24977809940542017,
            1.28108342297807853,
            1.31238239277768542,
            1.34367513839226094,
            1.37496186439903934,
            1.40624283970796471,
            1.43751838677328969,
            1.46878887090272525,
            1.50005468986750046,
            1.53131626398879805,
            1.56257402684728186,
            1.59382841673355613,
            1.62507986892909018,
            1.65632880888023148,
            1.68757564630236989,
            1.71882077022800828,
            1.75006454499141495,
            1.78130730712385676,
            1.81254936311919868,
            1.84379099310485528,
            1.87504313068164885,
            1.91753552506321157,
            2.32849632999631195,
            2.88528048414696370]),
        t.tensor([-2.90212209091241302,
            -2.35928671715313376,
            -1.96430075965753348,
            -1.93750744254844132,
            -1.92187759368646516,
            -1.90625314738377316,
            -1.89062870899535063,
            -1.87500427452834773,
            -1.85937984089812525,
            -1.84375540475908295,
            -1.82813096250145790,
            -1.81250651024946530,
            -1.79688204386070516,
            -1.78125755892688842,
            -1.76563305077596233,
            -1.75000851447576466,
            -1.73438394483921487,
            -1.71875933643113976,
            -1.70313468357683817,
            -1.68750998037237299,
            -1.67188522069669077,
            -1.65626039822562410,
            -1.64063550644774980,
            -1.62501053868220846,
            -1.60938548809844661,
            -1.59376034773792741,
            -1.57813511053775346,
            -1.56250976935625396,
            -1.54688431700046580,
            -1.53125874625546832,
            -1.51563304991553660,
            -1.50000722081701632,
            -1.48438125187289738,
            -1.46875513610890573,
            -1.45312886670109065,
            -1.43750243701472002,
            -1.42187584064437011,
            -1.40624907145508038,
            -1.39062212362433990,
            -1.37499499168481631,
            -1.35936767056754992,
            -1.34374015564546734,
            -1.32811244277697549,
            -1.31248452834937046,
            -1.29685640932190771,
            -1.28122808326822080,
            -1.26559954841781797,
            -1.24997080369646585,
            -1.23434184876513586,
            -1.21871268405723976,
            -1.20308331081392472,
            -1.18745373111709429,
            -1.17182394791994771,
            -1.15619396507467820,
            -1.14056378735716502,
            -1.12493342048831191,
            -1.10930287115182380,
            -1.09367214700818405,
            -1.07804125670456274,
            -1.06241020988049395,
            -1.04677901716910315,
            -1.03114769019364827,
            -1.01551624155930442,
            -0.99988468484002524,
            -0.98425303456034119,
            -0.96862130617198550,
            -0.95298951602533866,
            -0.93735768133569708,
            -0.92172582014420479,
            -0.90609395127362524,
            -0.89046209427892753,
            -0.87483026939284547,
            -0.85919849746646315,
            -0.84356679990495920,
            -0.82793519859883780,
            -0.81230371585069472,
            -0.79667237429787219,
            -0.78104119683128870,
            -0.76541020651068370,
            -0.74977942647666040,
            -0.73414887985993627,
            -0.71851858968814930,
            -0.70288857879062294,
            -0.68725886970161210,
            -0.67162948456237115,
            -0.65600044502265797,
            -0.64037177214205676,
            -0.62474348629167886,
            -0.60911560705676504,
            -0.59348815314065406,
            -0.57786114227070229,
            -0.56223459110662533,
            -0.54660851515181685,
            -0.53098292866810115,
            -0.51535784459445932,
            -0.49973327447025506,
            -0.48410922836328629,
            -0.46848571480323647,
            -0.45286274072091165,
            -0.43724031139364783,
            -0.42161843039722058,
            -0.40599709956470043,
            -0.39037631895238345,
            -0.37475608681324257,
            -0.35913639957801236,
            -0.34351725184409176,
            -0.32789863637249078,
            -0.31228054409280864,
            -0.29666296411643323,
            -0.28104588375784090,
            -0.26542928856407633,
            -0.24981316235227818,
            -0.23419748725514519,
            -0.21858224377419913,
            -0.20296741084060996,
            -0.18735296588336781,
            -0.17173888490450759,
            -0.15612514256100932,
            -0.14051171225312367,
            -0.12489856621860219,
            -0.10928567563248753,
            -0.09367301071195598,
            -0.07806054082574637,
            -0.06244823460763993,
            -0.04683606007349175,
            -0.03122398474118920,
            -0.01561197575305361,
            -0.00000000000000001,
            0.01561197575305308,
            0.03122398474118990,
            0.04683606007349152,
            0.06244823460763812,
            0.07806054082574569,
            0.09367301071195608,
            0.10928567563248735,
            0.12489856621860289,
            0.14051171225312384,
            0.15612514256100921,
            0.17173888490450637,
            0.18735296588336911,
            0.20296741084060804,
            0.21858224377419594,
            0.23419748725514794,
            0.24981316235227999,
            0.26542928856407960,
            0.28104588375783307,
            0.29666296411643089,
            0.31228054409281081,
            0.32789863637248962,
            0.34351725184409171,
            0.35913639957801463,
            0.37475608681324590,
            0.39037631895237779,
            0.40599709956470220,
            0.42161843039722463,
            0.43724031139364150,
            0.45286274072091520,
            0.46848571480323825,
            0.48410922836328263,
            0.49973327447025739,
            0.51535784459446143,
            0.53098292866809427,
            0.54660851515181996,
            0.56223459110663243,
            0.57786114227069840,
            0.59348815314065229,
            0.60911560705676715,
            0.62474348629167831,
            0.64037177214205820,
            0.65600044502265864,
            0.67162948456237759,
            0.68725886970161687,
            0.70288857879061017,
            0.71851858968813298,
            0.73414887985994182,
            0.74977942647666851,
            0.76541020651068470,
            0.78104119683129458,
            0.79667237429787030,
            0.81230371585068961,
            0.82793519859884246,
            0.84356679990496408,
            0.85919849746646237,
            0.87483026939284836,
            0.89046209427892520,
            0.90609395127361869,
            0.92172582014420412,
            0.93735768133570441,
            0.95298951602534021,
            0.96862130617197750,
            0.98425303456034507,
            0.99988468484002824,
            1.01551624155929576,
            1.03114769019364538,
            1.04677901716910893,
            1.06241020988049328,
            1.07804125670455875,
            1.09367214700818294,
            1.10930287115182891,
            1.12493342048831324,
            1.14056378735715991,
            1.15619396507467664,
            1.17182394791994615,
            1.18745373111709496,
            1.20308331081392517,
            1.21871268405723976,
            1.23434184876513742,
            1.24997080369646785,
            1.26559954841781486,
            1.28122808326821902,
            1.29685640932190838,
            1.31248452834936979,
            1.32811244277697527,
            1.34374015564546911,
            1.35936767056754926,
            1.37499499168481543,
            1.39062212362433968,
            1.40624907145508038,
            1.42187584064437145,
            1.43750243701471936,
            1.45312886670109176,
            1.46875513610890573,
            1.48438125187289649,
            1.50000722081701632,
            1.51563304991553371,
            1.53125874625546743,
            1.54688431700046691,
            1.56250976935625530,
            1.57813511053775146,
            1.59376034773792763,
            1.60938548809845039,
            1.62501053868220890,
            1.64063550644775069,
            1.65626039822562610,
            1.67188522069669032,
            1.68750998037236810,
            1.70313468357683373,
            1.71875933643114442,
            1.73438394483921554,
            1.75000851447576222,
            1.76563305077596300,
            1.78125755892688842,
            1.79688204386070294,
            1.81250651024946907,
            1.82813096250145990,
            1.84375540475907918,
            1.85937984089812192,
            1.87500427452834906,
            1.89062870899535418,
            1.90625314738377516,
            1.92187759368646116,
            1.93750744254844198,
            1.96430075965754036,
            2.35928671715313998,
            2.90212209091241968]),
    ]
    ys_const = [
        t.tensor([0.00000000000000000,
            1.00000000000000000]),
        t.tensor([-0.00760939905895985,
            0.32228164205929283,
            0.67771835794070634,
            1.00760939905895985]),
        t.tensor([-0.00109568445401875,
            -0.08859905421738420,
            0.12499325245301279,
            0.37208706831557098,
            0.62791293168442885,
            0.87500674754698693,
            1.08859905421738445,
            1.00109568445401886]),
        t.tensor([-0.00061544137315238,
            -0.05226400945224528,
            -0.10968093075636222,
            -0.01616241151827657,
            0.09284718920680775,
            0.20788265818235127,
            0.32538257439567753,
            0.44234954243304520,
            0.55765045756695475,
            0.67461742560432192,
            0.79211734181764903,
            0.90715281079319321,
            1.01616241151827857,
            1.10968093075636243,
            1.05226400945224530,
            1.00061544137315250]),
        t.tensor([-0.00029329340169840,
            -0.02854139952539195,
            -0.06641448071143671,
            -0.10915824625047750,
            -0.12825952808128421,
            -0.12406392804652215,
            -0.09986820686193564,
            -0.05358452610203352,
            -0.00010720712329275,
            0.05951007627620760,
            0.12435347674215998,
            0.19026991730183726,
            0.25588057946045256,
            0.32002152891506930,
            0.38804219558887709,
            0.46197742838139511,
            0.53802257161855382,
            0.61195780441124548,
            0.67997847108483567,
            0.74411942053955260,
            0.80973008269820845,
            0.87564652325779502,
            0.94048992372381446,
            1.00010720712328482,
            1.05358452610203734,
            1.09986820686193632,
            1.12406392804652100,
            1.12825952808128571,
            1.10915824625047810,
            1.06641448071143685,
            1.02854139952539136,
            1.00029329340169837]),
        t.tensor([-0.00038456582627704,
            -0.03558473441315251,
            -0.08052000068658792,
            -0.10725340496319002,
            -0.11364620936774648,
            -0.11872324596227261,
            -0.12298118622336887,
            -0.12622479141152740,
            -0.12824477968402076,
            -0.12882103148996180,
            -0.12772510670012566,
            -0.12472370596442466,
            -0.11958333584539720,
            -0.11207640028901288,
            -0.10198881774590096,
            -0.08912902765950530,
            -0.07333790442004298,
            -0.05449869461222381,
            -0.03254575041946415,
            -0.00747070445727027,
            0.02067505125705777,
            0.05178213983257415,
            0.08568898513139335,
            0.12219060292394657,
            0.16104891655868597,
            0.20200305165131152,
            0.24477836538698960,
            0.28909347104441679,
            0.33466504258847318,
            0.38121059236453880,
            0.42844964772304611,
            0.47610381714039773,
            0.52389618285960204,
            0.57155035227695261,
            0.61878940763545975,
            0.66533495741152693,
            0.71090652895558493,
            0.75522163461301073,
            0.79799694834868540,
            0.83895108344131109,
            0.87780939707605554,
            0.91431101486860944,
            0.94821786016742693,
            0.97932494874294029,
            1.00747070445726994,
            1.03254575041946217,
            1.05449869461222523,
            1.07333790442004129,
            1.08912902765950781,
            1.10198881774589919,
            1.11207640028901622,
            1.11958333584539438,
            1.12472370596442706,
            1.12772510670012571,
            1.12882103148995672,
            1.12824477968402248,
            1.12622479141152954,
            1.12298118622336940,
            1.11872324596226957,
            1.11364620936774883,
            1.10725340496319080,
            1.08052000068658849,
            1.03558473441315235,
            1.00038456582627711]),
        t.tensor([-0.00032963801607456,
            -0.03144899474976939,
            -0.07237475700758451,
            -0.09634133849055662,
            -0.10019163865503100,
            -0.10337801328801051,
            -0.10647568979161709,
            -0.10946593993161140,
            -0.11232851660834062,
            -0.11504224011654615,
            -0.11758503839729113,
            -0.11993399428201490,
            -0.12206540015412194,
            -0.12395482017400423,
            -0.12557716019059412,
            -0.12690674543644861,
            -0.12791740607459506,
            -0.12858257063040973,
            -0.12887536730095286,
            -0.12876873308402831,
            -0.12823553060903956,
            -0.12724867247958555,
            -0.12578125285157568,
            -0.12380668587069379,
            -0.12129885047714439,
            -0.11823224095572617,
            -0.11458212246486700,
            -0.11032469062304774,
            -0.10543723406677229,
            -0.09989829872705513,
            -0.09368785240527382,
            -0.08678744807239387,
            -0.07918038417476764,
            -0.07085186011302450,
            -0.06178912497703734,
            -0.05198161757636834,
            -0.04142109580967498,
            -0.03010175347350466,
            -0.01802032272339011,
            -0.00517616056997490,
            0.00842868198333836,
            0.02278940928127171,
            0.03789844964187317,
            0.05374544372714435,
            0.07031725275834511,
            0.08759798631036457,
            0.10556904906649998,
            0.12420920558818814,
            0.14349466186675403,
            0.16339916218689465,
            0.18389409965260192,
            0.20494863861034671,
            0.22652984715340754,
            0.24860283790303903,
            0.27113091533248224,
            0.29407572802116183,
            0.31739742438954849,
            0.34105481065999815,
            0.36500551000425752,
            0.38920612206322297,
            0.41361238224866209,
            0.43817932045027935,
            0.46286141896622940,
            0.48761276964403705,
            0.51238723035596290,
            0.53713858103377021,
            0.56182067954972070,
            0.58638761775133763,
            0.61079387793677709,
            0.63499448999574337,
            0.65894518934000257,
            0.68260257561045057,
            0.70592427197883911,
            0.72886908466752032,
            0.75139716209696139,
            0.77347015284659137,
            0.79505136138965604,
            0.81610590034740038,
            0.83660083781310102,
            0.85650533813324692,
            0.87579079441181340,
            0.89443095093349811,
            0.91240201368963370,
            0.92968274724165312,
            0.94625455627285515,
            0.96210155035812772,
            0.97721059071873084,
            0.99157131801665899,
            1.00517616056997339,
            1.01802032272338550,
            1.03010175347350508,
            1.04142109580967723,
            1.05198161757637187,
            1.06178912497703837,
            1.07085186011302347,
            1.07918038417476625,
            1.08678744807239269,
            1.09368785240527644,
            1.09989829872705225,
            1.10543723406677130,
            1.11032469062305084,
            1.11458212246487065,
            1.11823224095572660,
            1.12129885047714350,
            1.12380668587068833,
            1.12578125285157649,
            1.12724867247959071,
            1.12823553060903903,
            1.12876873308403147,
            1.12887536730094684,
            1.12858257063040957,
            1.12791740607459379,
            1.12690674543644653,
            1.12557716019059728,
            1.12395482017400528,
            1.12206540015411727,
            1.11993399428201990,
            1.11758503839728607,
            1.11504224011655251,
            1.11232851660833942,
            1.10946593993160847,
            1.10647568979161925,
            1.10337801328800289,
            1.10019163865503922,
            1.09634133849055670,
            1.07237475700758478,
            1.03144899474976981,
            1.00032963801607466]),
        t.tensor([-0.00031455783889701,
            -0.02988220810728750,
            -0.06838815676661437,
            -0.09052381626826272,
            -0.09279728070701555,
            -0.09446262302455798,
            -0.09611762826211076,
            -0.09776054918065887,
            -0.09938928584286907,
            -0.10100167418675354,
            -0.10259548653090733,
            -0.10416843212275302,
            -0.10571815779420764,
            -0.10724224872670400,
            -0.10873822932783883,
            -0.11020356422140801,
            -0.11163565935265615,
            -0.11303186321018566,
            -0.11438946816616835,
            -0.11570571193582646,
            -0.11697777915745033,
            -0.11820280309367875,
            -0.11937786745480922,
            -0.12050000834451111,
            -0.12156621632827740,
            -0.12257343862456249,
            -0.12351858141854224,
            -0.12439851229798785,
            -0.12521006281063407,
            -0.12595003114217504,
            -0.12661518491378290,
            -0.12720226409773372,
            -0.12770798404951841,
            -0.12812903865462680,
            -0.12846210358773133,
            -0.12870383968198731,
            -0.12885089640574107,
            -0.12889991544356544,
            -0.12884753437866370,
            -0.12869039047290171,
            -0.12842512454069202,
            -0.12804838491299517,
            -0.12755683148657232,
            -0.12694713985434555,
            -0.12621600551170145,
            -0.12536014813363192,
            -0.12437631591725358,
            -0.12326128998394514,
            -0.12201188883519327,
            -0.12062497285565056,
            -0.11909744885720500,
            -0.11742627465696613,
            -0.11560846368223408,
            -0.11364108959530056,
            -0.11152129093037615,
            -0.10924627573498059,
            -0.10681332620801866,
            -0.10421980332610868,
            -0.10146315145018489,
            -0.09854090290363757,
            -0.09545068251360159,
            -0.09219021210643127,
            -0.08875731494874418,
            -0.08514992012489672,
            -0.08136606684205296,
            -0.07740390865375155,
            -0.07326171759284809,
            -0.06893788820497800,
            -0.06443094147339049,
            -0.05973952862624372,
            -0.05486243481754034,
            -0.04979858267311532,
            -0.04454703569275254,
            -0.03910700150051654,
            -0.03347783493467905,
            -0.02765904096961883,
            -0.02165027746184888,
            -0.01545135771282986,
            -0.00906225284155850,
            -0.00248309396012123,
            0.00428582585400533,
            0.01124404979495591,
            0.01839095578019531,
            0.02572575512213179,
            0.03324749142416124,
            0.04095503970533251,
            0.04884710575724134,
            0.05692222573630709,
            0.06517876599410004,
            0.07361492314770822,
            0.08222872439180305,
            0.09101802805342750,
            0.09998052438993345,
            0.10911373663009812,
            0.11841502225785103,
            0.12788157453742816,
            0.13751042427843285,
            0.14729844183868110,
            0.15724233936216325,
            0.16733867324903981,
            0.17758384685420067,
            0.18797411341025075,
            0.19850557917062545,
            0.20917420676791010,
            0.21997581878220465,
            0.23090610151387087,
            0.24196060895484517,
            0.25313476695217318,
            0.26442387755726598,
            0.27582312355406297,
            0.28732757315902224,
            0.29893218488563772,
            0.31063181256599859,
            0.32242121052164247,
            0.33429503887592693,
            0.34624786899980142,
            0.35827418908292513,
            0.37036840982181118,
            0.38252487021667442,
            0.39473784346850216,
            0.40700154296784891,
            0.41931012836676890,
            0.43165771172523826,
            0.44403836372340311,
            0.45644611993092293,
            0.46887498712468101,
            0.48131894964608551,
            0.49377197578918691,
            0.50622802421081292,
            0.51868105035391421,
            0.53112501287531932,
            0.54355388006907612,
            0.55596163627659578,
            0.56834228827476230,
            0.58068987163323071,
            0.59299845703215071,
            0.60526215653149917,
            0.61747512978332419,
            0.62963159017818915,
            0.64172581091707503,
            0.65375213100019824,
            0.66570496112407151,
            0.67757878947835659,
            0.68936818743400397,
            0.70106781511436289,
            0.71267242684097609,
            0.72417687644593454,
            0.73557612244273374,
            0.74686523304782793,
            0.75803939104515428,
            0.76909389848612986,
            0.78002418121779771,
            0.79082579323208679,
            0.80149442082937317,
            0.81202588658975350,
            0.82241615314579497,
            0.83266132675096127,
            0.84275766063783841,
            0.85270155816131732,
            0.86248957572156837,
            0.87211842546257534,
            0.88158497774214473,
            0.89088626336989707,
            0.90001947561007434,
            0.90898197194657415,
            0.91777127560819560,
            0.92638507685229232,
            0.93482123400589601,
            0.94307777426369310,
            0.95115289424276328,
            0.95904496029466846,
            0.96675250857584227,
            0.97427424487786418,
            0.98160904421979878,
            0.98875595020504026,
            0.99571417414600072,
            1.00248309396012769,
            1.00906225284155204,
            1.01545135771283390,
            1.02165027746184989,
            1.02765904096961225,
            1.03347783493467849,
            1.03910700150051838,
            1.04454703569276153,
            1.04979858267310533,
            1.05486243481754638,
            1.05973952862623721,
            1.06443094147339479,
            1.06893788820498226,
            1.07326171759283784,
            1.07740390865375124,
            1.08136606684205927,
            1.08514992012489975,
            1.08875731494873618,
            1.09219021210643352,
            1.09545068251360012,
            1.09854090290364081,
            1.10146315145018203,
            1.10421980332610281,
            1.10681332620802841,
            1.10924627573497592,
            1.11152129093037022,
            1.11364108959530239,
            1.11560846368224431,
            1.11742627465694833,
            1.11909744885722429,
            1.12062497285564033,
            1.12201188883519509,
            1.12326128998395158,
            1.12437631591724840,
            1.12536014813363017,
            1.12621600551170431,
            1.12694713985434380,
            1.12755683148657737,
            1.12804838491300630,
            1.12842512454068311,
            1.12869039047288999,
            1.12884753437867258,
            1.12889991544355173,
            1.12885089640574687,
            1.12870383968198951,
            1.12846210358772892,
            1.12812903865463676,
            1.12770798404951300,
            1.12720226409773505,
            1.12661518491379242,
            1.12595003114215975,
            1.12521006281063585,
            1.12439851229799315,
            1.12351858141853889,
            1.12257343862456560,
            1.12156621632826603,
            1.12050000834453600,
            1.11937786745479806,
            1.11820280309367148,
            1.11697777915746044,
            1.11570571193580714,
            1.11438946816618079,
            1.11303186321019676,
            1.11163565935263753,
            1.11020356422141631,
            1.10873822932783717,
            1.10724224872670596,
            1.10571815779420612,
            1.10416843212275695,
            1.10259548653090800,
            1.10100167418676365,
            1.09938928584285445,
            1.09776054918065791,
            1.09611762826210790,
            1.09446262302456487,
            1.09279728070701454,
            1.09052381626826556,
            1.06838815676661425,
            1.02988220810728670,
            1.00031455783889700]),
    ]

    @staticmethod
    @lru_cache
    def xs(device, dtype, bits: int):
        return NBitsGelu.xs_const[bits - 1].to(device=device, dtype=dtype)

    @staticmethod
    @lru_cache
    def ys(device, dtype, bits: int):
        return NBitsGelu.ys_const[bits - 1].to(device=device, dtype=dtype)
    
    @staticmethod
    def forward(ctx, x: t.Tensor, bits: int):
        xs = NBitsGelu.xs(x.device, x.dtype, bits)
        discr = t.searchsorted(xs, x.float()).type(t.uint8)
        ctx.save_for_backward(discr)
        ctx.bits = bits
        
        return nnF.gelu(x)
        
    @staticmethod
    def backward(ctx, grad_output):
        discr, = ctx.saved_tensors
        ys = NBitsGelu.ys(grad_output.device, grad_output.dtype, ctx.bits)
        return ys[discr.type(t.int64)] * grad_output, None

    
class MyGelu(nn.Module):
    """Module wrapper that applies ``NBitsGelu`` with a fixed bit width."""

    def __init__(self, n_bits):
        """Store the bit width forwarded to ``NBitsGelu.apply``.

        Args:
            n_bits: value passed as the ``bits`` argument on every call.
        """
        super().__init__()
        self.n_bits = n_bits

    def extra_repr(self):
        """Show the configured bit width in ``repr(module)``."""
        return f"n_bits={self.n_bits}"

    def forward(self, x):
        """Apply ``NBitsGelu`` to ``x`` using the stored ``n_bits``."""
        return NBitsGelu.apply(x, self.n_bits)
