import numpy as np
import matplotlib.pyplot as plt

# Plot mean +/- std cumulative regret of several bandit methods on one dataset.
#
# For every method, three runs are loaded from
#   ./regret/{dataset}_{method}_regret{suffix}.npy   (suffix in "", "_1", "_2")
# and their per-round mean and (population) standard deviation are plotted as a
# line plus a shaded band. The figure is saved as regret_{dataset}.jpg.

# Select the dataset to plot (the figure title follows automatically).
# dataset = "MagicTelescope"
# dataset = "fashion"
# dataset = "mnist"
dataset = "mushroom"

# Figure title for each dataset. These replace the hand-toggled plt.title()
# lines, which had to be kept in sync with `dataset` manually.
DATASET_TITLES = {
    "MagicTelescope": "Magictelescope",
    "fashion": "Fashion-Mnist",
    "mnist": "Mnist",
    "mushroom": "Mushroom",
}

# (file key, legend label, colour, line style) for each method, in plot order.
METHODS = [
    ("club", "CLUB", "yellow", "-."),
    ("cofiba", "COFIBA", "m", "-."),
    ("sclub", "SCLUB", "orange", "-."),
    ("locb", "LOCB", "grey", "-."),
    ("neuucb_one", "NeuUCB-ONE", "green", ":"),
    ("neuucb_ind", "NeuUCB-IND", "blue", ":"),
    ("meta_ban", "Meta-Ban (Ours)", "red", "-"),
]

# File-name suffixes of the independent runs saved for each method.
RUN_SUFFIXES = ("", "_1", "_2")


def load_runs(dataset, method, suffixes=RUN_SUFFIXES):
    """Load all runs of `method` on `dataset` and stack them along axis 0.

    Reads ./regret/{dataset}_{method}_regret{suffix}.npy for each suffix.
    Returns an array whose first axis indexes the runs. np.stack requires
    every run to have the same shape.
    """
    runs = [
        np.load("./regret/{}_{}_regret{}.npy".format(dataset, method, suffix))
        for suffix in suffixes
    ]
    return np.stack(runs, axis=0)


plt.rcParams["figure.figsize"] = (7, 4)

stats = {}     # method key -> (mean over runs, std over runs)
handles = []   # line artists only, used to build the legend explicitly
for key, label, color, style in METHODS:
    runs = load_runs(dataset, key)
    mean = np.mean(runs, axis=0)
    std = np.std(runs, axis=0)
    stats[key] = (mean, std)
    # x follows the saved data length instead of a hard-coded 10000 rounds.
    # Assumes each run is a 1-D per-round regret curve; TODO confirm.
    x = np.arange(len(mean))
    # No 'k-' fmt string: it conflicted with color=/linestyle= (matplotlib
    # warns and the kwargs win anyway).
    line, = plt.plot(x, mean, color=color, linewidth=2.0, linestyle=style,
                     label=label)
    plt.fill_between(x, mean - std, mean + std, facecolor=color, alpha=0.3)
    handles.append(line)

our_mean = stats["meta_ban"][0]
sclub_mean = stats["sclub"][0]
neuind_mean = stats["neuucb_ind"][0]
# Final-round regret of Meta-Ban, SCLUB and NeuUCB-IND.
print(our_mean[-1], sclub_mean[-1], neuind_mean[-1])
# Relative regret reduction of Meta-Ban with respect to SCLUB.
print((-our_mean[-1] + sclub_mean[-1]) / (sclub_mean[-1]))

# Pass the line handles explicitly. With labels only, newer matplotlib pairs
# the labels with artists in creation order, which interleaves the
# fill_between bands with the lines and mislabels the series.
plt.legend(handles=handles)
plt.xlabel('Rounds')
plt.ylabel('Regret')
plt.title(DATASET_TITLES.get(dataset, dataset))

plt.savefig('regret_{}.jpg'.format(dataset), dpi=500, bbox_inches='tight')