from matplotlib import pyplot as plt
import numpy as np
import os


# Accuracy log for the run plotted below as 'wd-1e-4' (y-axis: Top-1 Accuracy).
# Every value is stored twice consecutively (180 entries, 90 distinct samples);
# the script later keeps every other entry, one point per x-axis step.
# The tuple lists each distinct sample once; the comprehension repeats each one.
imnet_wd_small = [
    acc
    for acc in (
        20.807, 34.143, 40.178, 44.301, 47.181, 49.946,
        51.721, 53.327, 54.706, 55.058, 56.443, 56.383,
        56.893, 58.296, 58.537, 59.337, 60.212, 60.210,
        60.896, 60.612, 60.825, 61.470, 61.707, 62.314,
        61.854, 62.250, 62.289, 62.409, 62.575, 63.073,
        63.002, 63.400, 63.301, 63.330, 63.178, 62.685,
        63.668, 63.904, 64.080, 64.014, 63.900, 64.020,
        64.393, 64.026, 64.319, 64.729, 64.200, 64.656,
        64.690, 64.735, 64.613, 64.914, 64.837, 64.947,
        64.880, 64.708, 64.457, 64.789, 65.417, 65.380,
        65.397, 65.270, 65.044, 65.490, 65.088, 65.162,
        65.500, 65.137, 65.204, 65.463, 67.445, 67.472,
        67.495, 67.488, 67.478, 67.553, 67.532, 67.401,
        67.470, 67.561, 67.706, 67.646, 67.749, 67.704,
        67.582, 67.548, 67.604, 67.718, 67.760, 67.629,
    )
    for _ in range(2)
]
# Accuracy log for the run plotted below as 'wd-1e-2' (y-axis: Top-1 Accuracy).
# Every value is stored twice consecutively (180 entries, 90 distinct samples);
# the script later keeps every other entry, one point per x-axis step.
# The tuple lists each distinct sample once; the comprehension repeats each one.
imnet_wd_large = [
    acc
    for acc in (
        21.108, 33.476, 38.957, 44.283, 47.361, 48.727,
        51.617, 53.186, 54.038, 54.845, 55.189, 56.167,
        57.046, 57.509, 56.536, 58.203, 58.813, 58.118,
        59.092, 59.837, 59.660, 59.694, 59.835, 60.423,
        60.492, 60.438, 60.500, 60.535, 60.432, 61.153,
        61.045, 61.534, 61.773, 61.416, 61.151, 61.617,
        61.265, 60.840, 62.322, 61.767, 61.217, 61.825,
        61.953, 61.541, 62.158, 62.442, 61.814, 62.795,
        62.144, 61.806, 61.914, 62.113, 62.227, 62.196,
        61.385, 61.559, 62.279, 62.561, 62.217, 62.977,
        62.814, 63.064, 62.747, 62.776, 61.022, 62.123,
        62.509, 63.494, 63.027, 63.510, 68.174, 68.585,
        68.535, 68.757, 68.939, 69.128, 68.919, 69.008,
        69.161, 69.213, 69.590, 69.597, 69.586, 69.632,
        69.738, 69.570, 69.642, 69.721, 69.667, 69.760,
    )
    for _ in range(2)
]

# Each accuracy value appears twice consecutively in the raw lists, so keep
# only every other entry (indices 0, 2, 4, ...) to get one point per x step.
# NOTE(review): presumably the duplication comes from logging twice per
# epoch — confirm against the training logs.
#
# Both curves share one x-axis, so a length mismatch means the logs are
# inconsistent; fail loudly instead of silently truncating (as zip() did).
if len(imnet_wd_small) != len(imnet_wd_large):
    raise ValueError(
        f"accuracy logs differ in length: {len(imnet_wd_small)} (wd-1e-4) "
        f"vs {len(imnet_wd_large)} (wd-1e-2)"
    )

imnet_wd_small_mod = imnet_wd_small[::2]
imnet_wd_large_mod = imnet_wd_large[::2]

# Convert both down-sampled logs the same way for plotting.
imnet_small = np.asarray(imnet_wd_small_mod)
imnet_large = np.asarray(imnet_wd_large_mod)

# Plot both down-sampled accuracy curves against their sample index (the
# x-axis is labelled as training epochs) and export the figure as a PNG.
plt.figure(figsize=(7, 5))

plt.plot(imnet_large, '-', label='wd-1e-2', linewidth=1.5)
plt.plot(imnet_small, '-', label='wd-1e-4', linewidth=1.5)
plt.grid()
# Give the legend font size exactly once. Matplotlib ignores `fontsize` when
# `prop` is also passed, so the old `fontsize=14` never took effect; the
# rendered size was 9, which is kept here.
plt.legend(loc='lower right', ncol=2, fontsize=9)
# plt.title('Effect of weight decay on Adabelief')

y_axis_list = np.arange(35, 80, 5)  # ticks 35, 40, ..., 75
x_axis_list = range(0, 91, 10)      # ticks 0, 10, ..., 90

plt.yticks(y_axis_list)
plt.xticks(x_axis_list)
# NOTE: the first two samples of each curve are below 35, so they fall
# outside this visible range.
plt.ylim((35, 75))

plt.ylabel('Top-1 Accuracy')
plt.xlabel('Training epochs')
plt.savefig('imagenet_wd_effect.png', bbox_inches='tight', pad_inches=0.1, dpi=200)
# Release the figure once it has been written to disk.
plt.close()