import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Metric drawn on the y-axis of the comparison plot. The baseline result dicts
# below populate 'wasserstein', 'uf' and 'tv', and the y-axis label is chosen
# for exactly those three keys.
p_metric: str = "uf"


# Per-run scores of the FDA linear-DP configurations, keyed first by
# configuration name and then by metric name; each list holds one value per
# run. Only 'etasq', 'wasserstein', 'uf' and 'tv' are populated for the
# 'fdami_linear_dp_*' entries; 'precision', 'recall', 'auroc', 'dp' and 'ftu'
# are empty placeholders.
experiment_results_linear = {
    # Placeholder entry: every metric list is empty.
    'original': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [], 'wasserstein': [], 'uf': [], 'tv': []},
    'fdami_linear_dp_1': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18564145770094503, 0.18656490104765863, 0.1857706032985035, 0.18722656706331176, 0.18610838390583181, 0.1904816478391715, 0.1878886287306574, 0.1856176023610295, 0.18644268131870045, 0.18711335566350676], 'wasserstein': [0.0001414213505890226, 0.0001414213504515393, 0.0001414213504150283, 0.00014142135056216385, 0.00014142135065754545, 0.00014142135051572696, 0.00014142135043340406, 0.00014142135052692845, 0.00014142135060931366, 0.00014142135044050108], 'uf': [0.29603936901807465, 0.2978447523487553, 0.294616595182425, 0.2898396747283301, 0.30266809776227627, 0.29897603780311094, 0.2933846054070365, 0.2976392308889447, 0.2986196807583404, 0.29499711029162623], 'tv': [0.19965396024452653, 0.20136046563779575, 0.20071532336792908, 0.1997903872613499, 0.2002487814825935, 0.20068187550770122, 0.19968448394190086, 0.20005598305574657, 0.1997037556632104, 0.1998828764041164]},
    'fdami_linear_dp_2': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18525384852234617, 0.1886684393479689, 0.18583111621841034, 0.18673276098009345, 0.18599588395121605, 0.18673290150693986, 0.1868203982530717, 0.18736420972851553, 0.18569915510778331, 0.18593729191104544], 'wasserstein': [0.1359785774201627, 0.1360739062939111, 0.13599465385087203, 0.13602022613074577, 0.13600004555489215, 0.13602071851977823, 0.13602296096723399, 0.1360368917684079, 0.13599296242996314, 0.13599760042117834], 'uf': [0.29054316742072456, 0.29543105561569716, 0.30169045516298476, 0.2989584040820673, 0.29268481638660027, 0.3000360486775876, 0.2926379232347463, 0.298212984233445, 0.29060427921371235, 0.2906431258880866], 'tv': [0.1994721686987787, 0.20083842176104405, 0.20055681126817215, 0.199470726526347, 0.20009229546566376, 0.20036471724572846, 0.19952596346633944, 0.1998974039247602, 0.1996827142100448, 0.19964616598206886]},
    'fdami_linear_dp_3': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18530758597298339, 0.1885762698119444, 0.18598591941439255, 0.1864935917047522, 0.1861569642259203, 0.18743320314808734, 0.18813842472011072, 0.18784686672389397, 0.18652210055295573, 0.18656482728903304], 'wasserstein': [0.24730443778526084, 0.24789818952919584, 0.24751672726894913, 0.24761398942878757, 0.24762384811761812, 0.2477297679644272, 0.24786567532003947, 0.24783055591764697, 0.2476526447770117, 0.24756706423046862], 'uf': [0.28653138708603737, 0.28990324453019345, 0.28204564533568666, 0.28866299769821785, 0.28549642700634115, 0.28279504250429416, 0.2841523035818454, 0.2810325169426372, 0.2870928418711313, 0.2872567258967307], 'tv': [0.18121832756765777, 0.18249342703871752, 0.18242630421189854, 0.18385706778981947, 0.18070228171055525, 0.18053579855805313, 0.1803904202250437, 0.18372896450872467, 0.1802373679890361, 0.18023005369241218]},
    'fdami_linear_dp_4': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18709957942153968, 0.18660658770897517, 0.18825969239732426, 0.18820747548047612, 0.18909720706720062, 0.185841989495523, 0.1870228947492438, 0.18678562467669435, 0.18688537926693652, 0.1879001321842974], 'wasserstein': [0.33879086737505876, 0.33878300733285843, 0.33915486484578633, 0.3388704897275324, 0.33953274483644635, 0.33868978460181653, 0.3384848191845441, 0.33882101378232415, 0.33887841771398813, 0.33913990330719945], 'uf': [0.2532258638095271, 0.25464951403490627, 0.2614423964499951, 0.24593734122045122, 0.2526281622936416, 0.25264533424195407, 0.2596615457547707, 0.24628725379739475, 0.2469083895638433, 0.2431970630890129], 'tv': [0.13854605861817992, 0.1402328733636493, 0.1445761564767547, 0.13550429387630536, 0.13998555326908546, 0.14110414215000133, 0.13770740246711455, 0.13482005004467523, 0.13392648760390025, 0.13769333353760127]},
    'fdami_linear_dp_5': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18743833027909443, 0.1870654505841185, 0.1873672023800412, 0.18629485609626514, 0.1874865919596118, 0.1876514150399921, 0.18579243559692923, 0.18743702903373508, 0.18679644328151865, 0.18812149633654643], 'wasserstein': [0.4176585598668006, 0.41702712641962253, 0.4170225263289267, 0.4158273728974382, 0.41648924869313775, 0.416800475693527, 0.4165787564440671, 0.41785408508814637, 0.4168518187352875, 0.41696011130659494], 'uf': [0.2219093241107128, 0.21823205287981934, 0.21047051370587613, 0.2192001654618741, 0.23238640157436075, 0.22100601428975652, 0.21276389858141972, 0.21722089183823706, 0.21763940581648406, 0.22133652376910776], 'tv': [0.10011923407972434, 0.10606212860756958, 0.09243059219527161, 0.1007153448580903, 0.10868570037450043, 0.1040832861276777, 0.10080103354754599, 0.10268945634618265, 0.10650010209398963, 0.10214280006800913]},
    'fdami_linear_dp_6': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.1879081750224465, 0.1862513428781448, 0.1860239487359584, 0.18714678320013434, 0.18591198163725253, 0.1859981608308864, 0.18621520831181154, 0.18629668268968486, 0.18782111417419542, 0.1862950355533933], 'wasserstein': [0.48156583768454847, 0.4817853280425294, 0.4815407701309813, 0.481056331880479, 0.48076817130314076, 0.4813117680071874, 0.4807640379406463, 0.4818727005523247, 0.48124160749891914, 0.4810421048346985], 'uf': [0.1843362636680731, 0.19770099578452896, 0.19729100880444445, 0.1871552486866777, 0.18958135643880014, 0.20147287459553403, 0.18030082410057163, 0.20434245300705822, 0.19514208110406586, 0.17872227780604022], 'tv': [0.0712280976129237, 0.07456618845725893, 0.07872397740920012, 0.0785361117468445, 0.07786214729984187, 0.08174395642722265, 0.07424049648478293, 0.08664085069562732, 0.07685315083952171, 0.07582889200898058]},
    'fdami_linear_dp_7': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18588630754323407, 0.1873438033030693, 0.1868803463902248, 0.18794855241831074, 0.1884003106084627, 0.1882597707649782, 0.18584744726221572, 0.18736623540724795, 0.188597530698272, 0.18749248822651837], 'wasserstein': [0.533188126831753, 0.5329475080887787, 0.5338709476593582, 0.5341344793528514, 0.5335590463590458, 0.5334509451093572, 0.5324776612000893, 0.5343990183706683, 0.5342910909091839, 0.5330026057399359], 'uf': [0.16902775287885274, 0.170152943852413, 0.1575349068793909, 0.16678999487003338, 0.18738098319715948, 0.18341930275458151, 0.16043345966703115, 0.1646792748575988, 0.16030684279137264, 0.16489447141405741], 'tv': [0.06181869098985049, 0.06393518228352602, 0.057440518050919764, 0.06124671068778886, 0.06321759669215754, 0.06739055039221653, 0.05405889406429354, 0.05479509257778792, 0.05930056384653226, 0.05896635347594836]},
    'fdami_linear_dp_8': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18718669954876654, 0.1858383962315001, 0.18565481709221598, 0.18498312192131264, 0.18842635306679292, 0.18734202985103574, 0.18673675668912332, 0.18604591835386894, 0.18734273154042405, 0.18582460413807864], 'wasserstein': [0.5748690264708225, 0.5739674585068701, 0.5748929838625367, 0.5727072454027112, 0.5756949278882345, 0.5749237963936052, 0.574928518343438, 0.5738891805682833, 0.5741978277386587, 0.5750144378626662], 'uf': [0.14982716999453805, 0.14986450438921536, 0.14940837153940717, 0.1497873579785689, 0.13904577408345206, 0.1426414426152739, 0.15924874492280464, 0.152150109802488, 0.15157480429777506, 0.15405346448443905], 'tv': [0.050036444926354395, 0.051611960696115844, 0.049218120969561574, 0.045542671296468606, 0.040548395431745377, 0.04539417432067383, 0.04953839630675472, 0.046647123570839444, 0.0478614368778949, 0.04410609072737981]},
    'fdami_linear_dp_9': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18578021566247588, 0.18677480665871718, 0.1877471995013798, 0.1865589923976373, 0.18808594539987067, 0.18444769280371057, 0.18541953260370925, 0.1883713058059061, 0.18773738022249614, 0.18662523253503935], 'wasserstein': [0.6067959355874791, 0.6060027745401555, 0.6073624817189776, 0.6054035856076448, 0.6072570705591389, 0.6069206895337292, 0.6072441111024153, 0.6066112620905113, 0.6078231098128117, 0.6077811886837728], 'uf': [0.09499902682639699, 0.12648366420147508, 0.12515078737552612, 0.12641159200134217, 0.1423081114144372, 0.1396676285337905, 0.1444916172119291, 0.11717386434767788, 0.12420983688479896, 0.15274902784496422], 'tv': [0.03098049629942068, 0.03828535451295245, 0.03355608909759222, 0.040093486307063, 0.039964551050118646, 0.03550857144764252, 0.04073739373988494, 0.035476667287096864, 0.03580502238016192, 0.04206896924626946]},
    'fdami_linear_dp_10': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18613957431513078, 0.18787221706001203, 0.18747491633575294, 0.18736928277989456, 0.18721658732626525, 0.18680156954299776, 0.18674406068480534, 0.1876331138039668, 0.18626981583889446, 0.18576894488044518], 'wasserstein': [0.6306341250002173, 0.6311018154956404, 0.6334034752005206, 0.6344528739649468, 0.6307401644504425, 0.6324941606772254, 0.6328756182352019, 0.6314421637702894, 0.6306018388397443, 0.6313410286734651], 'uf': [0.13714606168697763, 0.12959953239070426, 0.11813106155196501, 0.11362938796937794, 0.11795622759045973, 0.12727774947591522, 0.09147867144336365, 0.11889893634392573, 0.10665444993252263, 0.1298565161502268], 'tv': [0.03551712200902146, 0.03530660635755567, 0.028506007379786147, 0.02109524236284177, 0.035595696438672, 0.031666765367815874, 0.018695816130089082, 0.030615366064265626, 0.0303031493687671, 0.03611876614992904]},
    'fdami_linear_dp_11': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.18710551817495222, 0.18486469890180207, 0.18514808058285898, 0.18559431126813972, 0.1868822742158617, 0.1866009896757391, 0.18725323623176218, 0.1857678743732304, 0.18909793585616957, 0.18667068669199294], 'wasserstein': [0.6533380171193307, 0.6497240486036681, 0.651405487457065, 0.6500009074721586, 0.6498396824618156, 0.6524277939485029, 0.6527823592307717, 0.6515514644496551, 0.6520912072481574, 0.6516434520433223], 'uf': [0.10005156403738721, 0.10534884613696334, 0.10500058193605133, 0.12902690897289948, 0.1228014704970475, 0.10490652213493804, 0.11016146236155186, 0.11688189742335714, 0.1268215581333069, 0.11770844508032338], 'tv': [0.02587147835697512, 0.030781827409887552, 0.024626139622388465, 0.0300877376963945, 0.02996886451725511, 0.025690675041117106, 0.02506292044574754, 0.02262080906643915, 0.02959579690531755, 0.031957709340028506]},
    'fdami_linear_dp_12': {'precision': [], 'recall': [], 'auroc': [], 'dp': [], 'ftu': [], 'etasq': [0.1856722785947398, 0.1859049370829532, 0.1858029704147258, 0.18835446373403958, 0.18706079065484177, 0.18708222250650672, 0.18789129490881654, 0.1880389129676267, 0.1874378626021805, 0.1882791540898838], 'wasserstein': [0.7580483956531048, 0.7567192048461305, 0.7586432718952854, 0.7621474231291502, 0.7592121478361661, 0.7574401048296288, 0.7598604477665156, 0.7615915113775785, 0.758262964757489, 0.760104343929541], 'uf': [0.06339944053578822, 0.028390717011857282, 0.058467798250718774, 0.05674843978190085, 0.04670290063605606, 0.07010385685247836, 0.04392725142319785, 0.0627029555481474, 0.02574485146949468, 0.04922725156176637], 'tv': [0.007560941840623059, 0.0013577646086426753, 0.0010494757899690565, 0.002774006455502187, 0.003201016939100243, 0.002689022801585894, 0.00023599094224269201, 0.004786944031973439, 0.0027505688170388343, 0.0019808351672383484]},
}




# Per-run scores of the baseline methods (DECAF, FairGAN, OPPDP, TabFairGAN).
# Each dict has a single method entry holding the same three metric lists,
# 'wasserstein', 'uf' and 'tv', with one value per run. Unlike the FDA results
# there is no sweep: each baseline has one list per metric.
experiment_results_decaf = {
    'decaf': {'wasserstein': [0.28934748477096245, 0.23984055010068947, 0.25885325285830996, 0.19614975754497638, 0.2535665625550373, 0.29058919828809415, 0.21068924874417053, 0.3199095012967847, 0.26806195447982656, 0.30399800388459547], 'uf': [0.22964709889817345, 0.15963971168781937, 0.18474569965835735, 0.07096377050462786, 0.176049132356369, 0.24073901603166015, 0.09805918333498106, 0.2670156227778122, 0.20945612649812625, 0.25190537034021465], 'tv': [0.008555949, 0.00083988905, 0.0010824203, 0.0011399984, 0.0033919811, 0.0013713241, 0.002756536, 0.0020361543, 0.009393573, 0.0057342052]}
}

experiment_results_fairgan = {
    'fairgan': {'wasserstein': [0.4757227957410177, 0.41087281067541526, 0.3712433330713351, 0.3854114995867691, 0.47124884264909683, 0.47956526911499536, 0.43102255853855864, 0.33732882705799466, 0.41977287005674574, 0.4292879740518283], 'uf': [0.4527077268748765, 0.5448665712418499, 0.5042649630094785, 0.4807268004805594, 0.2527485637032103, 0.4533406381136227, 0.40481291505602307, 0.48569409864512736, 0.32452457999243367, 0.53534860693735], 'tv': [0.08872664108877826, 0.5557719640431831, 0.49007134874924785, 0.452853501569027, 0.2232363436727971, 0.0843308640164685, 0.2977997780442808, 0.4676237792645547, 0.24923762606216526, 0.5232701296372159]}
}

experiment_results_oppdp = {
    'oppdp': {'wasserstein': [0.1737869134870876, 0.16755595347791083, 0.1680434609191852, 0.16745828168093282, 0.1703640788850925, 0.1603716112530111, 0.1680434609191852, 0.1676535683731575, 0.17350425635207034, 0.1659863080742293], 'uf': [0.0369269920055919, 0.08343587592657364, 0.05632746156824146, 0.06201242397630145, 0.0754824623962946, 0.10212140338846495, 0.03631375372766313, 0.041847635533914816, 0.03554709879815077, 0.08467740349595229], 'tv': [0.0006717522072600679, 0.003216828553356643, 0.0021567428082083784, 0.002145678402278639, 0.0023607972313479664, 0.0055617770976891245, 0.004637750493479642, 0.0013470996738535468, 0.004165778677090615, 0.00015844098444628418]}
}

experiment_results_tabfairgan = {
    'tabfairgan': {'wasserstein': [0.024144366111445763, 0.0428319885945683, 0.11784708887168648, 0.06071494127782209, 0.08384328345168873, 0.05318857088156267, 0.06414783009067859, 0.1224142752444441, 0.04323043427775256, 0.11929311951789055], 'uf': [0.35329783011461385, 0.33976252856276457, 0.35679156294783454, 0.3428849726427128, 0.33862676324468816, 0.34822132521062604, 0.3452364178255971, 0.3547312475999297, 0.34354879708223457, 0.347053500962448], 'tv': [0.20865184632612055, 0.1871612321510524, 0.1282162297196967, 0.18647626728682776, 0.14023060672157017, 0.17579787330654184, 0.20355877697457547, 0.16737147294511134, 0.1775816539226197, 0.21896076527486996]}
}


def display_results(results, decimals=3):
    """Summarize per-run scores as a rounded ``[mean, std]`` array.

    Parameters
    ----------
    results : array_like
        Scores from repeated runs of one configuration.
    decimals : int, optional
        Number of decimal places to round both statistics to (default 3).

    Returns
    -------
    numpy.ndarray
        Two-element array ``[mean, std]``, each rounded to ``decimals``
        places. ``std`` is the population standard deviation (``np.std`` with
        its default ``ddof=0``). Empty input yields ``[nan, nan]`` directly,
        instead of going through ``np.mean``/``np.std`` on an empty slice,
        which emits RuntimeWarnings.
    """
    values = np.asarray(results, dtype=float)
    if values.size == 0:
        # Nothing to summarize; return the same nan pair without the warnings.
        return np.array([np.nan, np.nan])
    return np.round(np.array([values.mean(), values.std()]), decimals)

def _summarize_sweep(results, models, metric):
    """Return a transposed frame of rounded per-model (mean, std) summaries.

    For each name in ``models``, ``results[name][metric]`` is reduced with
    ``display_results``. The returned frame has row ``'mean'`` (position 0),
    row ``'std'`` (position 1) and row ``'model'`` (position 2). Column ``j``
    corresponds to ``models[j]``.
    """
    summary = pd.DataFrame(
        [display_results(results[name][metric]) for name in models],
        columns=['mean', 'std'],
    )
    summary['model'] = models
    return summary.transpose()


# FDA configurations in column order; column j is later plotted at the j-th
# x-position. Only the fdami_linear_dp_* entries are summarized: the 'original'
# entry holds empty lists and has no column in the plot.
linear_models = [f'fdami_linear_dp_{i}' for i in range(1, 13)]

df_linear = _summarize_sweep(experiment_results_linear, linear_models, p_metric)

# Same summary for 'etasq'; its means are used to place each configuration on
# the x-axis of the plot.
eta_list = _summarize_sweep(experiment_results_linear, linear_models, 'etasq')
print(eta_list)

def _constant_summary(values, n_points):
    """Return a 2 x ``n_points`` frame repeating the rounded mean/std of ``values``.

    Each baseline has a single list of scores per metric, so the same summary
    is repeated in every column. This lets it be drawn as a flat band across
    every x-position of the FDA curve. Row ``'mean'`` (position 0) holds the
    mean and row ``'std'`` (position 1) the standard deviation.
    """
    # Computed once rather than once per column: the value is loop-invariant.
    mean, std = display_results(values)
    return pd.DataFrame({'mean': [mean] * n_points, 'std': [std] * n_points}).transpose()


# One column per FDA configuration so the baselines line up with that curve.
n_points = df_linear.shape[1]

df_fairgan = _constant_summary(experiment_results_fairgan['fairgan'][p_metric], n_points)
df_decaf = _constant_summary(experiment_results_decaf['decaf'][p_metric], n_points)
df_oppdp = _constant_summary(experiment_results_oppdp['oppdp'][p_metric], n_points)
df_tabfairgan = _constant_summary(experiment_results_tabfairgan['tabfairgan'][p_metric], n_points)

def _plot_with_band(ax, x, summary, label, marker):
    """Plot row 0 (mean) of ``summary`` against ``x`` and shade +/- row 1 (std)."""
    mean = summary.iloc[0].astype(float)
    std = summary.iloc[1].astype(float)
    ax.plot(x, mean, label=label, marker=marker)
    ax.fill_between(x, mean - std, mean + std, alpha=.1)


fig, ax = plt.subplots(figsize=(7, 5))

# Noise levels, one per column of the summary frames: 0.0001, then 0.1 to 1.0
# in steps of 0.1, then 1e8. linspace fixes the count at exactly 10, whereas
# np.arange with a float step can gain or lose an element to rounding.
sig_list = np.concatenate(([0.0001], np.linspace(0.1, 1.0, 10), [100000000]))

# x-axis value: alpha = etasq / (sigma**2 + etasq), using each configuration's
# mean etasq.
eta_mean = eta_list.iloc[0].astype(float)
x_list = eta_mean / (sig_list ** 2 + eta_mean)
x_values = x_list.astype(float)

# Drawing order fixes the colour each curve gets from the default cycle.
for label, summary, marker in [
    ('FairGAN', df_fairgan, '.'),
    ('FDA', df_linear, '^'),
    ('TabFairGAN', df_tabfairgan, 's'),
    ('OPPDP', df_oppdp, 'p'),
    ('DECAF', df_decaf, 'x'),
]:
    _plot_with_band(ax, x_values, summary, label, marker)

# Plain ASCII 'ylim' (the original used a fullwidth 'ｙ' that only worked via
# NFKC identifier normalization).
ax.set(ylim=(0, 1))
ax.set_xlabel(r"$\alpha$", fontsize=18)

# Y-axis label per metric; unknown metrics are left unlabelled.
ylabels = {
    'wasserstein': r"$\hat{W}_2(\mu_\hat{Y}, \mu_{Y})$",
    'uf': r"$\hat{\mathcal{UF}}(\mathcal{P}_{\hat{\mathcal{D}}})$",
    'tv': r"$|P(\hat{Y}=1|S=1)-P(\hat{Y}=1|S=0)|$",
}
if p_metric in ylabels:
    ax.set_ylabel(ylabels[p_metric], fontsize=18)

ax.set_title("")
ax.legend(bbox_to_anchor=(0.3, 0.7), loc='lower left', fontsize=15, ncol=2)

# Set tick positions first, then the label size, so the size applies to the
# final ticks on both axes.
ax.set_xticks(np.linspace(0, 1, 6))
ax.tick_params(axis='both', labelsize=18)

model = 'combine'
data = 'adult'
fig.savefig(f'{model}_{p_metric}_{data}_nips')