from baselines.load_data import load_mnist_1d, Bandit_multi, load_mnist_adv,synthetic
from NeuralADv2 import NeuralADv2

import numpy as np
import os



def error_rates(error, count):
    """Return the element-wise ratio ``error / count`` without dividing by zero.

    Entries whose ``count`` is zero get a rate of 0.0.  A plain
    ``error / count`` would produce NaN there (0/0), and ``np.argmax`` returns
    the position of the first NaN, so a never-observed arm would always win
    over the arm with the highest measured error rate.

    Args:
        error: per-arm number of recorded errors (array-like).
        count: per-arm number of observations (array-like, same shape).

    Returns:
        A float ndarray with the same shape as ``error``.
    """
    error = np.asarray(error, dtype=float)
    count = np.asarray(count, dtype=float)
    rates = np.zeros_like(error)
    # `where` skips the division for unobserved arms; `out` keeps their 0.0.
    np.divide(error, count, out=rates, where=count > 0)
    return rates


if __name__ == '__main__':
    # Datasets to evaluate.  Other names this script has been run with:
    # 'covertype', 'MagicTelescope', 'adult', 'mushroom', 'fashion', 'Plants',
    # 'leaf', 'eucalyptus', 'mnist', and the synthetic 'quad', 'cos', 'square'.
    dataset = ['shuttle']

    n_runs = 20      # independent repetitions per dataset
    horizon = 5000   # rounds per repetition
    block = 500      # rounds between re-selections of the targeted arm

    for d in dataset:
        regrets_all = []
        for i in range(n_runs):
            # Build a fresh environment for every repetition.
            if d == 'mnist':
                b = load_mnist_adv()
            elif d in ('cos', 'square', 'quad'):
                b = synthetic(d)
            else:
                b = Bandit_multi(d)

            regrets = []
            sum_regret = 0
            neuralad = NeuralADv2(b.dim, b.n_arm, gamma=b.n_arm)

            # Per-arm bookkeeping, indexed by the `arm` value returned by
            # b.step(): how often it was drawn and how often the learner
            # received zero reward on it.
            error = np.zeros(b.n_arm)
            count = np.zeros(b.n_arm)

            # Argument passed to b.step().  It stays at -1 for the first block
            # (presumably "no forced arm" -- confirm against the environment's
            # step() implementation) and is then re-chosen every `block` rounds.
            target = -1
            for t in range(horizon):
                # Draw input sample
                if t >= block and t % block == 0:
                    rates = error_rates(error, count)
                    target = np.argmax(rates)
                    print(error, count, rates, target)
                context, rwd, arm = b.step(target)

                arm_select = neuralad.select(context, t)
                reward = rwd[arm_select]

                count[arm] += 1
                if reward == 0:
                    error[arm] += 1

                neuralad.update(context[arm_select], reward)
                # Train every 10 rounds during the first 1000 rounds, then
                # every 100 rounds.
                train_every = 10 if t < 1000 else 100
                if t % train_every == 0:
                    neuralad.train(t)

                regret = np.max(rwd) - reward
                sum_regret += regret
                regrets.append(sum_regret)
                if t % 50 == 0:
                    print('{}: {:}, {:.4f}'.format(t, sum_regret, sum_regret/(t+1)))

            print("run:", i, "; ", "regret:", sum_regret)
            regrets_all.append(regrets)

        path = os.getcwd()
        # Create the output directory up front so np.save cannot fail (and
        # discard every finished run) just because ./results is missing.
        os.makedirs(os.path.join(path, 'results'), exist_ok=True)
        np.save('{}/results/neuraladv2_1_by_t_results_{}.npy'.format(path, d), regrets_all)
