import argparse
import os
import pickle
import time

import numpy as np

from Algos import BUCB, UCBImproved, BKL, GLUE

# Command-line interface: which bandit case to simulate and which reward distribution to use.
# Both flags are required; without them the script used to fail much later with a NameError
# (no case selected) or a TypeError (unknown reward type) instead of a usage message.
parser = argparse.ArgumentParser()
parser.add_argument('--Case_number', '-cn', help="Case number", type=int, required=True)
parser.add_argument('--Reward_type', '-rew', help="Reward Distribution", type=str, required=True,
                    # Must match the reward samplers dispatched in run_algos.
                    choices=['bernoulli', 'clippedgaussian', 'clippedunif'])
args = parser.parse_args()
case_number = args.Case_number
rew_type = args.Reward_type

def get_rew_vec_bern(avg):
    """Draw one Bernoulli reward per arm.

    Arm ``i`` pays 1.0 when a uniform draw on [0, 1) is ``<= avg[i]`` and 0.0
    otherwise, i.e. ``avg[i]`` acts as the arm's success probability.

    Parameters
    ----------
    avg : np.ndarray
        1-D array of per-arm success probabilities.

    Returns
    -------
    np.ndarray
        Float array of 0.0 / 1.0 rewards, one entry per arm.
    """
    # A single vectorised call consumes the global RNG stream in the same
    # order as one scalar draw per arm would.
    draws = np.random.rand(avg.size)
    return np.where(draws <= avg, 1.0, 0.0)

def get_rew_vec_clippedgaussian(avg, var=0.01):
    """Draw one Gaussian reward per arm, clipped into [0, 1].

    Each arm's reward is ``avg[i] + sqrt(var) * z`` with ``z`` standard normal,
    then clipped to the interval [0, 1]. Because of the clipping, the realised
    mean differs from ``avg[i]`` for means close to 0 or 1.

    Parameters
    ----------
    avg : np.ndarray
        1-D array of per-arm noise centres.
    var : float, optional
        Variance of the additive Gaussian noise (default 0.01).

    Returns
    -------
    np.ndarray
        Float array of clipped rewards, one entry per arm.
    """
    noise = np.sqrt(var) * np.random.randn(avg.size)
    return np.clip(avg + noise, 0, 1)

def get_rew_vec_clippedunif(avg, half_width=0.1):
    """Draw one reward per arm uniformly from a window centred on its mean.

    For arm ``i`` the window starts as ``[avg[i] - half_width, avg[i] + half_width]``.
    The half-width is halved until the window lies inside [0, 1], so every
    reward is in [0, 1] and the window stays symmetric around ``avg[i]``.

    Parameters
    ----------
    avg : array_like
        1-D per-arm means; every entry must lie in [0, 1].
    half_width : float, optional
        Initial (maximum) half-width of the sampling window (default 0.1).

    Returns
    -------
    np.ndarray
        Float array of rewards, one entry per arm.

    Raises
    ------
    ValueError
        If any mean lies outside [0, 1] (the halving loop could never make the
        window fit, so it would spin forever), or if ``half_width`` is negative.
    """
    avg = np.asarray(avg, dtype=float)
    if np.any((avg < 0) | (avg > 1)):
        raise ValueError('avg must lie in [0, 1] for clipped-uniform rewards')
    if half_width < 0:
        raise ValueError('half_width must be non-negative')

    rew = np.zeros(avg.size)
    for i, mean in enumerate(avg):
        width = half_width
        # Shrink geometrically until the window fits; for means of exactly 0 or 1
        # this underflows to width 0 and the reward equals the mean.
        while mean + width > 1 or mean - width < 0:
            width /= 2.0
        rew[i] = np.random.uniform(mean - width, mean + width)
    return rew

def run_algos(ub, lb, avg, num_iter, num_inst, rew_type):
    """Run BUCB, UCBImproved, BKL and GLUE side by side and collect cumulative regret.

    For each of ``num_inst`` independent instances all four algorithms are
    restarted and then fed the *same* sampled reward vector at each of the
    ``num_iter - 1`` steps, so they are compared on common random draws.

    Parameters
    ----------
    ub, lb : np.ndarray
        Per-arm upper / lower bounds, passed unchanged to each algorithm's constructor.
    avg : np.ndarray
        Per-arm means used to sample rewards; also passed to the constructors.
    num_iter : int
        Length of each stored regret trace. Only ``num_iter - 1`` rewards are fed;
        assumes each algorithm's ``cum_reg`` then holds ``num_iter`` entries
        (presumably an initial entry after ``restart()``) — TODO confirm in Algos.
    num_inst : int
        Number of independent repetitions.
    rew_type : str
        One of ``'bernoulli'``, ``'clippedgaussian'``, ``'clippedunif'``.

    Returns
    -------
    tuple of np.ndarray
        Four ``(num_inst, num_iter)`` arrays of cumulative regret, in the order
        BUCB, UCBImproved, BKL, GLUE.

    Raises
    ------
    ValueError
        If ``rew_type`` is not one of the supported reward distributions.
    """
    samplers = {
        'bernoulli': get_rew_vec_bern,
        'clippedgaussian': get_rew_vec_clippedgaussian,
        'clippedunif': get_rew_vec_clippedunif,
    }
    # Resolve the sampler once, before any work: previously an unknown type was only
    # noticed inside the innermost loop and reported by returning None, which made
    # the caller's 4-way unpack fail with an unrelated TypeError.
    try:
        sample_rewards = samplers[rew_type]
    except KeyError:
        raise ValueError(f'Reward type not defined: {rew_type!r}') from None

    algos = [BUCB(ub, lb, avg), UCBImproved(ub, lb, avg), BKL(ub, lb, avg), GLUE(ub, lb, avg)]
    regs = [np.zeros((num_inst, num_iter)) for _ in algos]

    for k in range(num_inst):
        for algo in algos:
            algo.restart()

        for _ in range(num_iter - 1):
            rew_vec = sample_rewards(avg)
            # Every algorithm sees the identical reward vector for this step.
            for algo in algos:
                algo.iterate(rew_vec)

        for reg, algo in zip(regs, algos):
            reg[k, :] = np.asarray(algo.cum_reg)

    return tuple(regs)

# Bandit instances: case number -> (avg, ub, lb). avg holds the per-arm reward means;
# ub / lb are per-arm upper / lower bounds handed to every algorithm (each case below
# satisfies lb <= avg <= ub). Lists are converted with np.asarray on selection, so
# all-integer rows (e.g. lb in cases 4-5, ub in case 6) keep an integer dtype.
CASES = {
    1: ([0.96, 0.2, 0.5], [1, 1, 0.6], [0.95, 0, 0.4]),
    2: ([0.96, 0.2, 0.5], [0.98, 1, 0.6], [0.95, 0, 0.4]),
    3: ([0.5, 0.2, 0.3], [0.75, 0.5, 0.55], [0.3, 0.1, 0.15]),
    4: ([0.08, 0.02, 0.05], [0.11, 0.09, 0.075], [0, 0, 0]),
    5: ([0.08, 0.02, 0.05], [0.1, 0.11, 0.075], [0, 0, 0]),
    6: ([0.96, 0.7, 0.5], [1, 1, 1], [0.1, 0.2, 0.3]),
    7: ([0.9, 0.65, 0.2], [1.0, 1.0, 0.5], [0.85, 0.6, 0]),
    8: ([0.92, 0.2, 0.5], [1, 1, 0.6], [0.9, 0, 0.4]),
}

# Fail fast with a usage message instead of a NameError on avg/ub/lb later on.
if case_number not in CASES:
    parser.error(f'case number must be one of {sorted(CASES)}, got {case_number}')
avg, ub, lb = (np.asarray(values) for values in CASES[case_number])
  
num_iter, num_inst = int(5e4), 500

path = f'./Data/NewCase{case_number}{rew_type}.p'
# Create the output directory before the (long) simulation so a missing folder
# cannot discard all results at the final open().
os.makedirs(os.path.dirname(path), exist_ok=True)

results = run_algos(ub, lb, avg, num_iter, num_inst, rew_type)
if results is None:
    # A None result means nothing was simulated (e.g. an unknown reward type).
    raise SystemExit(f'No results for reward type {rew_type!r}; nothing saved.')
reg1, reg2, reg3, reg4 = results

# Integer keys define the pickle layout (presumably relied on by whatever loads
# these files — keep them stable): 1-4 are the regret arrays of BUCB, UCBImproved,
# BKL, GLUE; key 5 is unused; 6-9 record avg, ub, lb and the reward type.
reg = {1: reg1, 2: reg2, 3: reg3, 4: reg4, 6: avg, 7: ub, 8: lb, 9: rew_type}

with open(path, 'wb') as fp:
    pickle.dump(reg, fp, protocol=pickle.HIGHEST_PROTOCOL)
