import matplotlib.pyplot as plt
import numpy as np
# Font settings for the x/y axis labels.
font1 = dict(family='Times New Roman', weight='normal', size=14)

# Font settings for the legend text (same values as font1, separate dict).
font_legend = dict(family='Times New Roman', weight='normal', size=14)


# Named colours for the plotted series.
c1, c2, c3, c4, c5 = 'gold', 'grey', '#d7191c', '#2b83ba', 'red'

# Method names. Not referenced by the plotting loop, which spells out its own
# legend labels (e.g. 'Drnet (D)').
algs = ['Drnet', 'Vcnet', 'SCIGAN', 'TransTEE']
# Ground-truth response curves, drawn as the 'Truth' line.
# One row per output figure (row i -> tr{i}curve.pdf). Each row is paired
# point-by-point with the treatment grid x, so it must hold exactly 65 values.
y0 = [[7.643178727034977, 8.637840374333054, 9.548064637559705, 10.376598098746186, 11.12618733992374, 11.799578943123624, 12.399519490377081, 12.928755563715363, 13.390033745169728, 13.786100616771416, 14.11970276055168, 14.393586758541769, 14.610499192772936, 14.77318664527643, 14.884395698083502, 14.9468729332254, 14.963364932733374, 14.936618278638674, 14.86937955297255, 14.764395337766253, 14.624412215051033, 14.452176766858138, 14.25043557521882, 14.021935222164327, 13.769422289725917, 13.495643359934828, 13.203345014822316, 12.895273836419632, 12.574176406758022, 12.24279930786874, 11.903889121783033, 11.560192430532155, 11.21445581614735, 10.869425860659874, 10.527849146100973, 10.1924722545019, 9.866041767893904, 9.551304268308233, 9.251006337776138, 8.96789455832887, 8.70471551199768, 8.464215780813815, 8.249141946808527, 8.062240592013064, 7.9062582984586784, 7.783941648176619, 7.698037223198137, 7.65129160555448, 7.6464513772769, 7.686263120396647, 7.773473416944969, 7.910828848953119, 8.101075998452345, 8.346961447473896, 8.651231778049024, 9.01663357220898, 9.445913411985012, 9.941817879408369, 10.507093556510302, 11.144487025322064, 11.8567448678749, 12.646613666200066, 13.516840002328806, 14.47017045829237, 15.509351616122014], 
[7.704382003485508, 8.193855913078744, 8.682156414786165, 9.1681129137137, 9.650560434207222, 10.128342412629852, 10.600313469973237, 11.065342157656142, 11.522313669927813, 11.970132516373779, 12.407725148117306, 12.834042531420739, 13.24806266251713, 13.648793017643326, 14.035272932401133, 14.40657590474254, 14.761811816058092, 15.10012906504383, 15.42071660923133, 15.722805909286684, 16.005672771417416, 16.268639083470546, 16.5110744405599, 16.732397656325514, 16.932078156202333, 17.109637249357988, 17.264649276250644, 17.396742629055662, 17.505600642515095, 17.590962353074143, 17.652623124484883, 17.690435138377424, 17.704307748622504, 17.694207698635985, 17.660159201104317, 17.602243879939877, 17.520600574605258, 17.41542500727569, 17.286969313637435, 17.135541438446968, 16.961504397300008, 16.765275406380102, 16.547324882273017, 16.308175314244657, 16.048400011686013, 15.76862172972783, 15.469511176319788, 15.15178540435321, 14.816206092681753, 14.463577720161007, 14.09474563708435, 13.710594038638382, 13.312043845236072, 12.90005049480919, 12.475601652352363, 12.039714842209808, 11.593435008780634, 11.13783201149057, 10.673998060035261, 10.203045096043624, 9.726102127438178, 9.2443125218826, 8.758831265804947, 8.270822195567334, 7.781455207419896], 
[7.575222310150349, 8.997330511267009, 10.37440116605067, 11.706434274501332, 12.993429836618997, 14.235387852403665, 15.432308321855333, 16.584191244974, 17.691036621759675, 18.75284445221235, 19.769614736332024, 20.7413474741187, 21.668042665572386, 22.549700310693062, 23.386320409480746, 24.177902961935427, 24.92444796805712, 25.62595542784581, 26.2824253413015, 26.89385770842419, 27.460252529213882, 27.98160980367058, 28.45792953179428, 28.88921171358497, 29.27545634904268, 29.61666343816738, 29.912832980959088, 30.1639649774178, 30.3700594275435, 30.531116331336214, 30.647135688795927, 30.718117499922634, 30.744061764716363, 30.72496848317708, 30.660837655304796, 30.551669281099514, 30.397463360561233, 30.19821989368996, 29.953938880485687, 29.664620320948416, 29.330264215078145, 28.950870562874876, 28.526439364338607, 28.05697061946934, 27.542464328267098, 26.982920490731832, 26.378339106863578, 25.728720176662318, 25.034063700128062, 24.294369677260804, 23.509638108060553, 22.679868992527304, 21.805062330661052, 20.885218122461815, 19.92033636792957, 18.910417067064323, 17.855460219866096, 16.755465826334856, 15.61043388647062, 14.420364400273382, 13.185257367743153, 11.905112788879919, 10.579930663683681, 9.209710992154452, 7.794453774292231]]
# Drnet predictions (legend label 'Drnet (D)'). Same layout as y0: one row
# per figure, 65 values per row.
y1 = [[11.76574, 11.896426, 12.027107, 12.157791, 12.288475, 12.41916, 12.549842, 12.680525, 12.811211, 12.941893, 13.0725765, 13.20326, 13.333943, 13.985896, 14.050802, 14.115711, 14.180618, 14.245526, 14.310434, 14.37534, 14.44025, 14.505158, 14.570065, 14.6349745, 14.699881, 14.764789, 11.06401, 11.05994, 11.055872, 11.051802, 11.047732, 11.043661, 11.039593, 11.035522, 11.031454, 11.027384, 11.0233135, 11.019245, 11.015176, 7.65496, 7.7031236, 7.751287, 7.7994514, 7.847614, 7.895777, 7.94394, 7.992104, 8.040266, 8.088429, 8.136593, 8.184756, 8.232919, 9.5904875, 9.834543, 10.078599, 10.322658, 10.566712, 10.810769, 11.054825, 11.2988825, 11.542936, 11.786993, 12.03105, 12.275107, 12.519163],
[10.418328, 10.564691, 10.711049, 10.8574095, 11.003768, 11.150128, 11.296486, 11.442849, 11.589207, 11.735567, 11.881926, 12.028285, 12.174644, 14.414253, 14.605555, 14.796858, 14.988157, 15.179459, 15.370758, 15.562059, 15.753363, 15.944662, 16.135963, 16.327265, 16.518566, 16.709867, 16.97889, 17.045097, 17.111311, 17.177517, 17.243727, 17.309938, 17.376146, 17.442354, 17.508566, 17.574776, 17.640984, 17.707195, 17.773403, 15.804968, 15.778455, 15.751943, 15.725428, 15.698913, 15.672401, 15.645886, 15.619372, 15.592861, 15.566347, 15.539833, 15.51332, 15.486806, 11.271006, 11.339977, 11.408946, 11.477917, 11.546886, 11.615858, 11.684828, 11.7537985, 11.82277, 11.891739, 11.96071, 12.029678, 12.09865],
[14.102627, 14.574996, 15.047365, 15.5197315, 15.9921, 16.46447, 16.936836, 17.409206, 17.881575, 18.35394, 18.82631, 19.298677, 19.771044, 23.349823, 23.833439, 24.317057, 24.800676, 25.284292, 25.76791, 26.251524, 26.735144, 27.218758, 27.702377, 28.185993, 28.669611, 29.153225, 29.879837, 29.912884, 29.945929, 29.978971, 30.012022, 30.04507, 30.078114, 30.11116, 30.144209, 30.17726, 30.210306, 30.243347, 30.276396, 29.099895, 28.634285, 28.168674, 27.703064, 27.237455, 26.771847, 26.306238, 25.840626, 25.375015, 24.909409, 24.443798, 23.978188, 23.51258, 18.400671, 18.325207, 18.249743, 18.174274, 18.098806, 18.023344, 17.947876, 17.87241, 17.796944, 17.721476, 17.646013, 17.570545, 17.495079]]
# Vcnet predictions (legend label 'Vcnet (D)'). Same layout as y0.
y2 = [[12.330246, 12.547377, 12.75565, 12.953739, 13.141205, 13.317605, 13.482501, 13.635452, 13.776026, 13.903777, 14.018273, 14.119081, 14.205759, 14.277877, 14.334998, 14.376692, 14.402524, 14.412061, 14.3876505, 14.342562, 14.279574, 14.198388, 14.096966, 13.972352, 13.824062, 13.651666, 13.454739, 13.232876, 12.9856825, 12.712782, 12.413811, 12.088418, 11.736269, 11.357042, 10.950435, 10.53675, 10.11037, 9.726765, 9.375082, 9.0780735, 8.83952, 8.5913105, 8.387188, 8.201518, 8.014122, 7.8258047, 7.697151, 7.685704, 7.7083383, 7.7424707, 7.8079023, 7.8947487, 8.123806, 8.390825, 8.680347, 9.05595, 9.509546, 10.060078, 10.674885, 11.324486, 12.010243, 12.733552, 13.495831, 14.298528, 15.120244],
[9.188688, 9.559096, 9.928629, 10.296094, 10.661271, 11.023808, 11.38345, 11.739967, 12.093126, 12.442689, 12.788415, 13.130054, 13.467361, 13.800082, 14.127956, 14.450726, 14.768124, 15.079881, 15.385723, 15.685377, 15.978557, 16.264982, 16.525257, 16.754526, 16.96412, 17.154125, 17.323904, 17.472818, 17.600224, 17.70547, 17.7879, 17.846855, 17.88166, 17.89165, 17.876135, 17.834435, 17.765854, 17.669697, 17.545258, 17.39183, 17.208694, 16.995129, 16.750412, 16.473398, 16.166473, 15.846457, 15.557603, 15.2392435, 14.889774, 14.508534, 14.094866, 13.685346, 13.281353, 12.85381, 12.395729, 11.919935, 11.478256, 11.008449, 10.568577, 10.139461, 9.685379, 9.206232, 8.701813, 8.171387, 7.6270275],
[10.567877, 11.527222, 12.486235, 13.444023, 14.399701, 15.352392, 16.301235, 17.24538, 18.18399, 19.116234, 20.041302, 20.958387, 21.8667, 22.765461, 23.653902, 24.497982, 25.325535, 26.139973, 26.94056, 27.489395, 27.969696, 28.432, 28.869444, 29.27145, 29.635498, 29.960773, 30.247818, 30.492851, 30.697319, 30.861547, 30.984924, 31.066902, 31.107014, 31.10484, 31.057482, 30.962029, 30.821148, 30.633667, 30.400188, 30.120491, 29.794416, 29.421871, 29.002802, 28.536732, 28.003319, 27.415943, 26.77843, 26.090958, 25.379684, 24.600813, 23.757057, 22.914566, 22.042799, 21.127775, 20.170195, 19.170794, 18.130388, 17.049858, 15.930128, 14.772195, 13.577114, 12.346013, 11.080066, 9.780535, 8.448704]]
# SCIGAN predictions (legend label 'SCIGAN'). Same layout as y0.
y3 = [[7.761956492843169, 7.813896512793237, 7.868978190552527, 7.927170688685038, 7.988459716671647, 8.052815189208937, 8.120185209148516, 8.190533674126286, 8.263789131546632, 8.339920743973552, 8.418832238436117, 8.500451410206129, 8.584694020434018, 8.671444992834223, 8.76058849898859, 8.851999684887941, 8.945539406003979, 9.041059116151086, 9.13837604711101, 9.237296900809302, 9.337623866376003, 9.439108363991636, 9.541517984687314, 9.644276594902617, 9.746405294077647, 9.847446785600479, 9.946788833546597, 10.043827099383629, 10.134581485502778, 10.212777701876627, 10.272638430083559, 10.306358038152277, 10.3110092260604, 10.285983518546336, 10.230664542956351, 10.14451001169637, 10.027071653745494, 9.878117060134834, 9.708762562443475, 9.537534690515027, 9.371518345270827, 9.234098645023028, 9.13920658951844, 9.084943420045212, 9.06938574554933, 9.090660379826849, 9.146919521148584, 9.236460905440628, 9.357619687225194, 9.508812251343716, 9.688465700526978, 9.89437572220031, 10.124899282730906, 10.377815146560927, 10.65004615125035, 10.939448342864491, 11.244768292449773, 11.564909390696684, 11.898878822592494, 12.245759738515599, 12.604683425329858, 12.974820280793576, 13.355361762377443, 13.745588455263526, 14.144841491307012],
[13.117835283377321, 13.382112862042417, 13.640213175783611, 13.89197301182987, 14.137207345565187, 14.375737169434235, 14.60738648441203, 14.831977035075834, 15.04930800202534, 15.259162018943373, 15.461362334672359, 15.655694591425462, 15.841945935681009, 16.019900505386993, 16.189334165032967, 16.34996185636906, 16.50152559791848, 16.64376289540892, 16.775861445648147, 16.897290243030874, 17.007802588069076, 17.107145764214074, 17.195096388087993, 17.271223487719404, 17.335229970737068, 17.386844317277657, 17.427573801042314, 17.45940856485202, 17.4821560627649, 17.49456023336334, 17.494271414450544, 17.479493513412578, 17.446933693790548, 17.395551004086265, 17.3170877806446, 17.209090566971973, 17.072401751564037, 16.907624544754288, 16.715265131772693, 16.49579886041323, 16.249756736281192, 15.977627645557092, 15.680162968693741, 15.360692396646716, 15.020599101240567, 14.660319319374828, 14.280271236766984, 13.883357333064097, 13.47258012154709, 13.048390881767205, 12.611271730711682, 12.161662665942984, 11.700266931428459, 11.22810141042357, 10.745838509459666, 10.254023148594872, 9.753028949690986, 9.243209415063143, 8.724903756507363, 8.198458331079223, 7.66422607699442, 7.122515744679264, 6.5736278111016215, 6.017837368754603, 5.4554200742307595],
[9.889957507360515, 10.91874363705356, 11.938656482570329, 12.949383156755298, 13.950590464873134, 14.941932049868257, 15.923050648762606, 16.893547631285912, 17.852814522176594, 18.79905974161622, 19.73105054429731, 20.647833978234175, 21.544756599120927, 22.416042028701685, 23.259017187767935, 24.065429677592643, 24.827892051922007, 25.546804445291453, 26.219445642007003, 26.844619751257866, 27.422958564415783, 27.9550261809198, 28.441309982685258, 28.88224018955101, 29.278161278241168, 29.627116199768516, 29.928318223448724, 30.1834054940528, 30.392909017186035, 30.55711309896571, 30.676136576340323, 30.747565103710727, 30.77053824139923, 30.745122177073345, 30.67095889562242, 30.54846056570325, 30.378257474422416, 30.16062189377582, 29.89625330706789, 29.585637591948796, 29.22897030289074, 28.82648911379071, 28.378351972491643, 27.88460250268349, 27.345291849382058, 26.76236051465768, 26.137096149496568, 25.469429557700863, 24.759259953504124, 24.006420363472387, 23.213416893380213, 22.390756835445657, 21.550390072185586, 20.699582204374966, 19.83987359954871, 18.971709520197365, 18.09552319469013, 17.21173581727483, 16.32075654807792, 15.42292911169093, 14.51858369431885, 13.608103883580238, 12.692303862998786, 11.771588279905444, 10.846250466008517]]
# TransTEE predictions (legend label 'TransTEE'), laid out like y0..y3:
# one row per figure, each row a flat list of 65 floats paired point-by-point
# with the treatment grid x. Rows must not carry an extra level of nesting
# ([[...]]); otherwise y4[i] has length 1 instead of 65.
y4 = [[8.032097, 8.735841, 9.444112, 10.224908, 10.974344, 11.705074, 12.177707, 12.665221, 13.151246, 13.578028, 13.948596, 14.209378, 14.487204, 14.711556, 14.750396, 14.768389, 14.786932, 14.818327, 14.784561, 14.69032, 14.582246, 14.423483, 14.248031, 14.05462, 13.76896, 13.470395, 13.170047, 12.885495, 12.574884, 12.279283, 11.99319, 11.705953, 11.417492, 11.073726, 10.729903, 10.387326, 10.052514, 9.729682, 9.413056, 9.133272, 8.860497, 8.615723, 8.411274, 8.241976, 8.08986, 7.9471927, 7.826347, 7.7812176, 7.7511697, 7.7214727, 7.845902, 8.026198, 8.215667, 8.410281, 8.6104965, 9.020629, 9.490735, 9.983645, 10.529898, 11.132846, 11.8424635, 12.665301, 13.503483, 14.332845, 15.073415],
[6.8472347, 7.376135, 7.9731836, 8.574817, 9.088793, 9.619316, 10.320796, 10.819872, 11.272069, 11.772446, 12.283967, 12.795445, 13.279771, 13.720281, 14.160762, 14.622903, 15.030082, 15.384303, 15.71363, 16.022793, 16.311874, 16.60223, 16.86365, 17.087252, 17.301626, 17.518173, 17.70851, 17.79841, 17.892012, 18.02624, 18.150871, 18.147074, 18.131716, 18.117525, 18.085907, 18.01841, 17.90585, 17.793253, 17.68015, 17.548534, 17.307673, 17.056824, 16.790539, 16.505219, 16.228815, 15.921081, 15.61237, 15.293966, 14.893072, 14.442056, 13.972039, 13.52631, 13.089016, 12.655929, 12.2043085, 11.779246, 11.338997, 10.909665, 10.41186, 9.938121, 9.43746, 8.923979, 8.484108, 8.234106, 8.225737],
[10.018796, 10.575811, 11.312051, 12.160485, 13.198487, 14.228258, 15.353373, 16.53209, 17.54001, 18.54263, 19.526402, 20.470276, 21.397009, 22.297655, 23.105259, 23.887775, 24.603756, 25.260538, 25.900059, 26.496464, 27.06532, 27.554428, 27.938541, 28.32125, 28.702623, 29.08074, 29.36602, 29.63119, 29.89052, 29.993046, 30.05685, 30.11472, 30.165253, 30.213022, 30.12928, 29.969486, 29.783611, 29.589394, 29.31937, 28.991768, 28.660082, 28.325, 27.988052, 27.510881, 26.977192, 26.408213, 25.827034, 25.237953, 24.636944, 23.866356, 23.05534, 22.223234, 21.372982, 20.518791, 19.689333, 18.771105, 17.78654, 16.783915, 15.694862, 14.517206, 13.407857, 12.323341, 11.281536, 10.479175, 9.846892]]

# Treatment grid shared by every figure: 65 evenly spaced points on (0, 1].
# It starts at machine epsilon rather than 0, presumably to avoid an exact
# zero treatment value. NOTE(review): confirm before changing.
x = np.linspace(np.finfo(float).eps, 1, 65)

# One figure per row of the curve tables. The ground truth is drawn as a thick
# line and each method's predictions as a scatter; figure i is saved to
# tr{i}curve.pdf. Iterating with zip follows the tables' row count instead of
# a hard-coded 3.
for i, (truth, drnet, vcnet, scigan, transtee) in enumerate(zip(y0, y1, y2, y3, y4)):
    fig, ax = plt.subplots(figsize=(5, 5))
    # np.ravel flattens any extra nesting (a row stored as [[...]]) so every
    # series lines up one-to-one with x.
    ax.plot(x, np.ravel(truth), marker='', ls='-', label='Truth', linewidth=4, color=c1)
    ax.scatter(x, np.ravel(drnet), marker='x', label='Drnet (D)', alpha=0.7, zorder=2, color=c4, s=15)
    ax.scatter(x, np.ravel(vcnet), marker='h', label='Vcnet (D)', alpha=0.7, zorder=2, color=c2, s=15)
    ax.scatter(x, np.ravel(scigan), marker='H', label='SCIGAN', alpha=0.7, zorder=2, color='#abdda4', s=15)
    # zorder=3 layers TransTEE above the other scatters (zorder=2).
    ax.scatter(x, np.ravel(transtee), marker='*', label='TransTEE', alpha=0.9, zorder=3, color=c3, s=15)

    ax.grid()
    ax.legend(prop=font_legend, loc='best')
    ax.set_xlabel('Treatment', fontdict=font1)
    ax.set_ylabel('Response', fontdict=font1)

    fig.savefig(f"tr{i}curve.pdf", bbox_inches='tight')
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)