from load_data import synthetic, Bandit_multi
from ConservativeSquareCB import ConservativeSquareCB

import numpy as np
import os


if __name__ == '__main__':
    # Experiment driver: for every dataset, run ConservativeSquareCB several
    # times on a freshly built bandit environment, record the cumulative
    # regret after each round, and save all regret curves of that dataset to
    # <cwd>/results/ConservativeSquareCB_<dataset>-new.npy.

    datasets = ['covertype', 'MagicTelescope', 'shuttle', 'mushroom', 'fashion', 'Plants']
    # Dataset names that are served by the synthetic generator instead of
    # Bandit_multi (none of them is in the list above, but the switch is kept
    # so the list can be edited freely).
    synthetic_names = ('cos', 'square', 'quad')

    n_runs = 10        # independent repetitions per dataset
    horizon = 5000     # rounds per repetition
    train_every = 10   # call learner.train() once every this many rounds
    log_every = 50     # print progress once every this many rounds

    for d in datasets:

        regrets_all = []    # one cumulative-regret curve (list) per run
        # Last-round baseline / attained reward of every run. They are
        # collected here but not written to disk by this script.
        baseline_rew = []
        attained_rew = []

        for i in range(n_runs):

            # Fresh environment and fresh learner for every repetition.
            if d in synthetic_names:
                b = synthetic(d)
            else:
                b = Bandit_multi(d)
            neuralad = ConservativeSquareCB(b.dim, b.n_arm, alpha=0.1)

            regrets = []
            sum_regret = 0

            for t in range(horizon):
                # Draw input sample.
                context, rwd, true_rwd, baseline = b.step()
                arm_select = neuralad.select(context, rwd, true_rwd, baseline, t)

                if arm_select == 'baseline':
                    # The learner deferred to the baseline arm: take its
                    # reward from true_rwd and do not feed the learner.
                    reward = true_rwd[baseline]
                else:
                    reward = rwd[arm_select]
                    neuralad.update(context[arm_select], reward)

                # The original code had separate t < 1000 / t >= 1000
                # branches with identical bodies; the schedule is simply
                # "train every `train_every` rounds".
                if t % train_every == 0:
                    neuralad.train(t)

                # Regret is measured against the smallest entry of true_rwd.
                # NOTE(review): this presumably means the values are losses
                # (lower is better) - confirm against load_data.
                regret = reward - np.min(true_rwd)
                sum_regret += regret
                regrets.append(sum_regret)

                if t % log_every == 0:
                    print('{}: {:}, {:.4f}'.format(t, sum_regret, sum_regret/(t+1)))

            print("run:", i, "; ", "regret:", sum_regret)
            regrets_all.append(regrets)

            baseline_rew.append(true_rwd[baseline])
            attained_rew.append(reward)

        path = os.getcwd()
        # Make sure the output directory exists; without this np.save raises
        # FileNotFoundError on a fresh checkout and all runs of this dataset
        # are lost.
        os.makedirs('{}/results'.format(path), exist_ok=True)
        np.save('{}/results/ConservativeSquareCB_{}-new.npy'.format(path,d), regrets_all)
