import matplotlib.pyplot as plt

def draw_result(lst_iter, loss_gnn, loss_egnn, loss_baseline, title):
    """Plot three loss curves against a shared x axis and display the figure.

    Args:
        lst_iter: x values (iteration indices) shared by all three curves.
        loss_gnn: y values drawn with the '-b' format and labelled 'gnn'.
        loss_egnn: y values drawn with the '-r' format and labelled 'egnn'.
        loss_baseline: y values drawn with the '--g' format and labelled
            'baseline'.
        title: text used as the figure title.

    Returns:
        None. The figure is handed to ``plt.show()``.
    """
    # (y values, matplotlib format string, legend label), in drawing order.
    curves = (
        (loss_gnn, '-b', 'gnn'),
        (loss_egnn, '-r', 'egnn'),
        (loss_baseline, '--g', 'baseline'),
    )
    for losses, line_format, legend_label in curves:
        plt.plot(lst_iter, losses, line_format, label=legend_label)

    plt.title(title)
    plt.xlabel("n iteration")
    plt.legend(loc='upper left')

    # Saving is currently disabled. If it is re-enabled, the savefig call
    # (e.g. to title + ".png") has to be placed before plt.show().
    plt.show()



if __name__ == "__main__":
    # Recorded loss values, one entry per iteration, for the 'gnn' curve.
    # The literal is wrapped ten values per row; previously one value was
    # split across a line break ("0" / ".0075"), which is a syntax error.
    gnn = [
        0.7161, 0.4581, 0.4035, 0.0858, 0.0301, 0.0267, 0.0257, 0.0243, 0.0235, 0.0226,
        0.0218, 0.0212, 0.0209, 0.0201, 0.0199, 0.0199, 0.0198, 0.0199, 0.0196, 0.0193,
        0.0192, 0.0188, 0.0184, 0.0180, 0.0183, 0.0180, 0.0179, 0.0183, 0.0175, 0.0173,
        0.0182, 0.0196, 0.0171, 0.0185, 0.0187, 0.0173, 0.0155, 0.0155, 0.0183, 0.0164,
        0.0146, 0.0173, 0.0172, 0.0154, 0.0181, 0.0189, 0.0186, 0.0162, 0.0183, 0.0158,
        0.0156, 0.0138, 0.0206, 0.0148, 0.0141, 0.0155, 0.0155, 0.0163, 0.0127, 0.0167,
        0.0133, 0.0146, 0.0194, 0.0169, 0.0145, 0.0144, 0.0129, 0.0160, 0.0165, 0.0136,
        0.0123, 0.0123, 0.0124, 0.0129, 0.0122, 0.0186, 0.0133, 0.0138, 0.0123, 0.0114,
        0.0131, 0.0128, 0.0123, 0.0114, 0.0116, 0.0121, 0.0121, 0.0115, 0.0119, 0.0109,
        0.0112, 0.0117, 0.0110, 0.0108, 0.0110, 0.0119, 0.0114, 0.0108, 0.0121, 0.0111,
        0.0111, 0.0109, 0.0100, 0.0117, 0.0106, 0.0113, 0.0118, 0.0128, 0.0101, 0.0114,
        0.0106, 0.0112, 0.0104, 0.0114, 0.0097, 0.0106, 0.0105, 0.0117, 0.0103, 0.0100,
        0.0097, 0.0107, 0.0097, 0.0129, 0.0102, 0.0118, 0.0119, 0.0099, 0.0104, 0.0111,
        0.0095, 0.0104, 0.0122, 0.0090, 0.0114, 0.0089, 0.0095, 0.0097, 0.0088, 0.0116,
        0.0103, 0.0106, 0.0100, 0.0170, 0.0085, 0.0094, 0.0107, 0.0100, 0.0092, 0.0106,
        0.0084, 0.0096, 0.0093, 0.0092, 0.0102, 0.0091, 0.0105, 0.0088, 0.0110, 0.0092,
        0.0124, 0.0092, 0.0099, 0.0091, 0.0096, 0.0093, 0.0085, 0.0091, 0.0086, 0.0085,
        0.0078, 0.0086, 0.0085, 0.0093, 0.0084, 0.0084, 0.0093, 0.0082, 0.0073, 0.0091,
        0.0090, 0.0089, 0.0095, 0.0076, 0.0078, 0.0078, 0.0077, 0.0085, 0.0084, 0.0084,
        0.0085, 0.0078, 0.0079, 0.0095, 0.0072, 0.0075, 0.0075, 0.0094, 0.0081, 0.0088,
        0.0077, 0.0083, 0.0080, 0.0075, 0.0069, 0.0088, 0.0080, 0.0079, 0.0076, 0.0122,
        0.0084, 0.0083, 0.0082, 0.0082, 0.0078, 0.0074, 0.0092, 0.0069, 0.0089, 0.0071,
        0.0079, 0.0099, 0.0079, 0.0070, 0.0086, 0.0070, 0.0093, 0.0075, 0.0078, 0.0079,
        0.0085, 0.0088, 0.0104, 0.0075, 0.0079, 0.0078, 0.0080, 0.0091, 0.0089, 0.0080,
        0.0079, 0.0079, 0.0076, 0.0077, 0.0071, 0.0080, 0.0073, 0.0065, 0.0082, 0.0116,
        0.0075, 0.0074, 0.0073, 0.0086, 0.0068, 0.0082, 0.0079, 0.0095, 0.0071, 0.0088,
        0.0104, 0.0088, 0.0072, 0.0084, 0.0068, 0.0076, 0.0080, 0.0074, 0.0077, 0.0076,
        0.0070, 0.0079, 0.0079, 0.0068, 0.0084, 0.0077, 0.0100, 0.0078, 0.0084, 0.0071,
        0.0072, 0.0078, 0.0067, 0.0077, 0.0075, 0.0065, 0.0076, 0.0091, 0.0074, 0.0082,
        0.0071, 0.0074, 0.0072, 0.0085, 0.0073, 0.0069, 0.0067, 0.0091, 0.0085, 0.0086,
        0.0077, 0.0067, 0.0077, 0.0086, 0.0074, 0.0066, 0.0075, 0.0071, 0.0078, 0.0070,
        0.0074, 0.0093, 0.0074, 0.0070, 0.0092, 0.0070, 0.0071, 0.0114, 0.0074, 0.0074,
        0.0071, 0.0076, 0.0090, 0.0075, 0.0073, 0.0079, 0.0103, 0.0075, 0.0081, 0.0078,
        0.0089, 0.0103, 0.0071, 0.0069, 0.0076, 0.0069, 0.0067, 0.0074, 0.0082, 0.0079,
        0.0081, 0.0069, 0.0070, 0.0062, 0.0072, 0.0121, 0.0066, 0.0065, 0.0075, 0.0074,
        0.0076, 0.0071, 0.0078, 0.0066, 0.0080, 0.0069, 0.0066, 0.0070, 0.0072, 0.0080,
        0.0078, 0.0074, 0.0084, 0.0070, 0.0078, 0.0081, 0.0071, 0.0065, 0.0071, 0.0081,
        0.0094, 0.0085, 0.0070, 0.0068, 0.0080, 0.0085, 0.0070, 0.0069, 0.0075, 0.0065,
        0.0071, 0.0076, 0.0077, 0.0071, 0.0066, 0.0075, 0.0086, 0.0068, 0.0065, 0.0072,
        0.0073, 0.0097, 0.0066, 0.0073, 0.0075, 0.0073, 0.0067, 0.0064, 0.0065, 0.0068,
    ]

    # Recorded loss values, one entry per iteration, for the 'egnn' curve.
    egnn = [
        0.1020, 0.1055, 0.1058, 0.1029, 0.0986, 0.0948, 0.0913, 0.0883, 0.0857, 0.0836,
        0.0818, 0.0804, 0.0794, 0.0786, 0.0779, 0.0773, 0.0769, 0.0767, 0.0766, 0.0765,
        0.0766, 0.0770, 0.0775, 0.0784, 0.0791, 0.0798, 0.0808, 0.0823, 0.0829, 0.0846,
        0.0855, 0.0866, 0.0879, 0.0893, 0.0907, 0.0923, 0.0936, 0.0947, 0.0957, 0.0963,
        0.0967, 0.0974, 0.0982, 0.0989, 0.0998, 0.1008, 0.1022, 0.1051, 0.1075, 0.1093,
        0.1115, 0.1142, 0.1177, 0.1219, 0.1259, 0.1295, 0.1346, 0.1396, 0.1433, 0.1462,
        0.1497, 0.1530, 0.1564, 0.1591, 0.1617, 0.1643, 0.1667, 0.1688, 0.1710, 0.1730,
        0.1754, 0.1780, 0.1802, 0.1824, 0.1851, 0.1879, 0.1907, 0.1933, 0.1959, 0.1984,
        0.2014, 0.2045, 0.2073, 0.2101, 0.2124, 0.2148, 0.2173, 0.2196, 0.2212, 0.2227,
        0.2239, 0.2252, 0.2267, 0.2283, 0.2299, 0.2314, 0.2332, 0.2352, 0.2372, 0.2391,
        0.2408, 0.2424, 0.2441, 0.2455, 0.2471, 0.2489, 0.2507, 0.2525, 0.2543, 0.2561,
        0.2578, 0.2598, 0.2617, 0.2635, 0.2651, 0.2667, 0.2683, 0.2694, 0.2706,
    ]

    # Plot only entries 1..99 of each series: the first entry is dropped and
    # everything from index 100 onwards is ignored.
    gnn = gnn[1:100]
    egnn = egnn[1:100]

    # Constant reference line with the same number of points as the curves.
    num_epochs = len(gnn)
    baseline = [0.1075] * num_epochs

    # NOTE(review): the x axis is renumbered from 0 after slicing, so x = 0
    # shows what was entry 1 of the recorded data -- confirm this is intended.
    epochs = list(range(num_epochs))
    draw_result(epochs, gnn, egnn, baseline, title="Comparison")
