import numpy as np
import pickle
from Algos import BUCB, UCBImproved, BKL, GLUE
import argparse

# Command-line interface. --Num_Iter and --Num_Inst are mandatory: without them
# the script would only fail much later (np.zeros with a None dimension inside
# run_algos), after the data set had already been loaded and filtered.
parser = argparse.ArgumentParser()
parser.add_argument('--Suffix', '-S', help='Suffix to save file',type=str)
parser.add_argument('--Num_Iter', '-Iter', help='Number of iterations',type=int, required=True)
parser.add_argument('--Num_Inst', '-Inst', help='Number of instances',type=int, required=True)
args =  parser.parse_args()
# Optional; when omitted it is None and str(None) ends up in the output file name.
suffix = args.Suffix
num_iter = args.Num_Iter
num_inst = args.Num_Inst


def get_ub_lb_new(means,dist):
    max_rew_vec = np.zeros(means.shape[0]) # \mu^*_{z,u}
    opt_act = np.zeros(means.shape[0]) # k^*_{z,u}
    rew_k = np.zeros(means.shape[1]) # \mu_z(k)
    prob = np.zeros(means.shape[1]) # p_z(k)
    
    del_min = 10
    del_max = 0
    
    for i in range(means.shape[0]):
        max_rew_vec[i] = max(means[i,:])
        opt_act[i] = np.argmax(means[i,:])
        temp = np.sort(means[i,:])
        del_min = min(del_min, temp[-1] - temp[-2])
        del_max = max(del_max, temp[-1] - temp[0])
        
    max_rew = np.average(max_rew_vec, weights = dist) # \mu_z
    
    for i in range(means.shape[1]):
        temp1,temp2 = 0,0
        for j in range(means.shape[0]):
            if opt_act[j] == i:
                temp1 += max_rew_vec[j]*dist[j]
                temp2 += dist[j]
        rew_k[i], prob[i] = temp1, temp2
    
    ub, lb = np.zeros(means.shape[1]), np.zeros(means.shape[1])
    
    for i in range(means.shape[1]):
        ub[i] = max_rew - del_min*(1-prob[i])
        
        set_greater = []
        for j in range(means.shape[1]):
            if j != i and rew_k[j]> del_max*prob[j]: set_greater += [j]
        acc = rew_k[i]
        for j in set_greater:
            acc += rew_k[j] - del_max*prob[j] 
        lb[i] = acc
        
    return ub,lb

def get_rew(means,u_dist,var):
    """Draw one reward vector with Gaussian noise clipped to [0, 1].

    A row of ``means`` is picked at random with probabilities ``u_dist``.
    Each entry of that row is then perturbed by ``np.random.normal`` and
    clipped into [0, 1].

    Note: despite its name, ``var`` is passed as the ``scale`` argument of
    ``np.random.normal``, i.e. it is the standard deviation, not the variance.
    """
    row_idx = np.random.choice(np.arange(means.shape[0]), p = u_dist)
    centre = means[row_idx, :]
    noisy = np.zeros(means.shape[1])
    for arm, mu in enumerate(centre):
        draw = np.random.normal(mu, var)
        noisy[arm] = min(1, max(draw, 0))   # clip into [0, 1]
    return noisy

def get_rew_bern(means, u_dist):
    """Draw one 0/1 reward vector from Bernoulli arms.

    A row of ``means`` is chosen with probabilities ``u_dist``. Entry ``i``
    of the result is 1.0 when a uniform draw from ``np.random.rand()`` is
    ``<= row[i]``, otherwise 0.0. The result is a float array.
    """
    row = means[np.random.choice(np.arange(means.shape[0]), p = u_dist),:]
    outcomes = [float(np.random.rand() <= p) for p in row]
    return np.array(outcomes, dtype=float)

def get_rew_unif(means, u_dist, half_width=0.1):
    """Draw one reward vector with uniform noise that stays inside [0, 1].

    A row of ``means`` is picked with probabilities ``u_dist``. Entry ``i``
    is drawn uniformly from ``[row[i] - b_i, row[i] + b_i]``. ``b_i`` starts
    at ``half_width`` and is halved until the interval fits in [0, 1].

    Args:
        means: 2-D array; rows are candidate mean vectors.
        u_dist: row-selection probabilities (passed to ``np.random.choice``).
        half_width: initial half-width of the uniform window (default 0.1).

    Returns:
        Float array of length ``means.shape[1]``.

    Raises:
        ValueError: if ``half_width`` is negative, or if the selected row has a
            mean outside [0, 1]. For such a mean, halving the window can never
            make it fit, so the loop below would otherwise never terminate.
    """
    if half_width < 0:
        raise ValueError(f"half_width must be non-negative, got {half_width!r}")
    inst = means[np.random.choice(np.arange(means.shape[0]), p=u_dist),:]
    # NaN entries pass this check, as before (they yield NaN rewards).
    out_of_range = (inst < 0) | (inst > 1)
    if np.any(out_of_range):
        bad = int(np.flatnonzero(out_of_range)[0])
        raise ValueError(
            f"mean reward {inst[bad]!r} of arm {bad} is outside [0, 1]; "
            "uniform noise window cannot be fitted")
    rew = np.zeros(means.shape[1])
    bound = half_width*np.ones(rew.size)
    for i in range(rew.size):
        # Shrink the window until it lies inside [0, 1]. For a mean of exactly
        # 0 or 1 the bound shrinks until it no longer changes the endpoints.
        while inst[i]+bound[i]>1 or inst[i]-bound[i]<0:
            bound[i] /= 2.0
        rew[i] = np.random.uniform(inst[i]-bound[i], inst[i]+bound[i])
    return rew

def run_algos(ub, lb, true_means, u_dist, num_iter, num_inst, save_path, suffix):
    """Run BUCB, UCBImproved, BKL and GLUE on the same reward streams and pickle the results.

    For each of ``num_inst`` independent instances, every algorithm is
    restarted. All four are then fed the same ``num_iter - 1`` reward
    vectors from ``get_rew_unif``. Each algorithm's ``cum_reg`` is stored as
    one row of a ``(num_inst, num_iter)`` matrix.

    The pickled dict written to ``save_path + '_' + str(suffix) + '.p'`` has
    keys 1-4 for the regret matrices (in the order above; key 5 is unused),
    6 for the ``u_dist``-weighted column means, 7/8 for ``ub``/``lb``,
    9 for ``true_means`` and 10 for ``u_dist``.

    Returns:
        None.
    """
    avg = np.average(true_means, 0, u_dist)

    regrets = [np.zeros((num_inst, num_iter)) for _ in range(4)]
    algos = [cls(ub, lb, avg) for cls in (BUCB, UCBImproved, BKL, GLUE)]

    for inst_idx in range(num_inst):
        for algo in algos:
            algo.restart()

        for _ in range(num_iter - 1):
            # Alternative noise model: get_rew(true_means, u_dist, 0.0001)
            rew_vec = get_rew_unif(true_means, u_dist)
            for algo in algos:
                algo.iterate(rew_vec)

        for reg_mat, algo in zip(regrets, algos):
            reg_mat[inst_idx, :] = np.asarray(algo.cum_reg)

    reg = {idx + 1: mat for idx, mat in enumerate(regrets)}
    reg.update({6: avg, 7: ub, 8: lb, 9: true_means, 10: u_dist})

    with open(save_path+'_'+str(suffix)+'.p', 'wb') as fp:
        pickle.dump(reg, fp, protocol=pickle.HIGHEST_PROTOCOL)

    return None

def filter_data(data, num_arms=15, seed=200):
    """Subsample arms, separate each row's top two entries, and attach bounds.

    Each value ``data[key]`` is read as ``[rew_mat, weights, ...]``:
    ``rew_mat`` is a 2-D reward matrix (rows weighted by ``weights``,
    columns are arms). For every key:

    1. ``weights`` is normalised to sum to 1 (``dist``).
    2. ``num_arms`` columns are chosen without replacement using a private
       ``np.random.RandomState(seed)``. The same seed gives the same subset
       for every key of equal width. The global NumPy RNG is not touched.
       The chosen columns are kept in their original order.
    3. In every row, while the two largest entries are equal, the second one
       (per ``argsort`` order) is lowered by 0.001. This makes the per-row
       best arm unique. Note this can push entries below 0 when a row's
       maximum is 0.
    4. ``get_ub_lb_new`` bounds and weighted column means are computed.

    Args:
        data: dict mapping keys to ``[rew_mat, weights, ...]``; mutated in place.
        num_arms: number of columns to keep (default 15, must be >= 2 and
            <= the number of columns of every ``rew_mat``).
        seed: seed for the column-subsampling generator (default 200).

    Returns:
        The same dict, with each value replaced by
        ``[rew_mat, mean, dist, ub, lb]``.
    """
    for key in data.keys():
        rew_mat = data[key][0]
        dist = (1/sum(data[key][1]))*data[key][1]

        # A local generator reproduces np.random.seed(seed); np.random.choice(...)
        # without resetting the caller's global random state afterwards.
        rng = np.random.RandomState(seed)
        n_cols = rew_mat.shape[1]
        ids = rng.choice(np.arange(n_cols), size=num_arms, replace=False)
        del_ids = np.setdiff1d(np.arange(n_cols), ids)
        rew_mat = np.delete(rew_mat, del_ids, 1)   # returns a copy

        # Break ties for the row maximum so every row has a unique best arm.
        for i in range(rew_mat.shape[0]):
            sort_ind = np.argsort(rew_mat[i,:])[::-1]
            while rew_mat[i,sort_ind[0]]==rew_mat[i,sort_ind[1]]:
                rew_mat[i,sort_ind[1]] -= 0.001
                sort_ind = np.argsort(rew_mat[i,:])[::-1]
        ub, lb = get_ub_lb_new(rew_mat,dist)

        mean = np.array([np.average(rew_mat[:,i],weights = dist)
                         for i in range(rew_mat.shape[1])])

        data[key] = [rew_mat, mean, dist, ub, lb]
    return data
                
if __name__ == '__main__':
    # Selects which pre-processed MovieLens pickle to load (used only in the file name).
    feature = 3
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally generated files.
    with open('./Files/Movielens_data_'+str(feature)+'_hidden.p','rb') as f:
        data_raw = pickle.load(f)
    data = filter_data(data_raw)

    # Entry of the filtered data to experiment on (presumably a user group -- confirm).
    key = 'academic'
    # filter_data stores [rew_mat, mean, dist, ub, lb] for every key.
    true_means, means, dist, ub, lb = data[key]

    savepath = './RegretRuns/regfile'

    run_algos(ub, lb, true_means, dist, num_iter, num_inst, savepath, suffix)