import numpy as np
import pickle
import matplotlib.pyplot as plt
import argparse

# Command-line interface: choose the experiment case and reward type, then
# build the path(s) of the pickled result file(s) that will be plotted.
parser = argparse.ArgumentParser()
# Both options are needed to locate a data file; without them the script would
# otherwise fail later trying to open e.g. './Data/CaseNoneNone.p'.
parser.add_argument('--Case_number', '-cn', help="Case number", type=int, required=True)
parser.add_argument('--Reward_type', '-rew',
                    help='bern, clippedunif or clippedgaussian (loads ./Data/Case<N><type>.p), '
                         'or Composite (plots the ./Data/NewCase<N> clippedunif and bernoulli files side by side)',
                    type=str, required=True)
args = parser.parse_args()
case_number = args.Case_number
rew_type = args.Reward_type
if rew_type == 'Composite':
    # Composite mode reads two "NewCase" files: clipped-uniform and Bernoulli rewards.
    path1 = './Data/NewCase' + str(case_number) + 'clippedunif.p'
    path2 = './Data/NewCase' + str(case_number) + 'bernoulli.p'
else:
    path = './Data/Case' + str(case_number) + str(rew_type) + '.p'

def get_files(path):
    """Load one pickled experiment-result file and unpack the fields used here.

    The unpickled object must be indexable by integer position; only
    positions 1-4 and 6-9 are read.

    Args:
        path: path of the pickle file to open (opened in binary mode).

    Returns:
        A tuple ``(reg, avg, ub, lb, rew_type)`` where ``reg`` is the list
        ``[data[1], data[2], data[3], data[4]]`` (presumably the regret records
        of the four plotted algorithms — TODO confirm), ``avg``, ``ub`` and
        ``lb`` are ``data[6]``, ``data[7]`` and ``data[8]`` (plotted as the
        per-arm mean, upper bound and lower bound), and ``rew_type`` is
        ``data[9]``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open result
    # files produced by trusted experiment runs.
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    reg = [data[1], data[2], data[3], data[4]]
    avg, ub, lb, rew_type = data[6], data[7], data[8], data[9]
    return reg, avg, ub, lb, rew_type

def comp_reg(reg_array, avg, ub, lb, label_array, save_path, per = 20, save = False,
             errorevery = 10000):
    """Plot arm specifications and mean regret curves in a 1x2 figure.

    Left panel: per-arm mean (red circles), upper bound (green squares) and
    lower bound (blue squares), joined by black vertical lines; arms are
    numbered from 1. Right panel: mean regret over time for each algorithm,
    with +/- one standard deviation error bars.

    Args:
        reg_array: sequence of regret collections, one per algorithm. Each is
            reduced with mean/std over axis 0, so presumably a (runs, T)
            array-like of regret trajectories — TODO confirm.
        avg, ub, lb: per-arm mean, upper bound and lower bound rewards.
        label_array: legend label for each entry of ``reg_array``. Curves are
            drawn for as many entries as both sequences provide.
        save_path: output file used when ``save`` is True.
        per: currently unused; kept so existing callers passing it still work.
        save: if True, save the figure to ``save_path`` before showing it.
        errorevery: draw an error bar every ``errorevery`` time steps.

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.
    """
    plt.figure(figsize = (12,4))
    plt.rcParams["font.family"] = "sans-serif"

    # Left panel: arm reward specifications.
    plt.subplot(1,2,1)
    arms = range(1, len(avg) + 1)
    plt.plot(arms, avg, 'ro', label='Mean')
    plt.plot(arms, ub, 'gs', label='Upper bound')
    plt.plot(arms, lb, 'bs', label='Lower bound')
    plt.vlines(arms, ymin=ub, ymax=lb, color='k')
    plt.xticks(arms)
    plt.xlabel('Arms', fontsize=14)
    plt.ylabel('Reward', fontsize=14)
    plt.grid()
    plt.legend(loc='lower left')

    # Right panel: one mean-regret curve per algorithm. Colors cycle so that
    # more than four algorithms can still be drawn.
    colors = ['lightcoral', 'limegreen', 'tan', 'royalblue']
    plt.subplot(1,2,2)
    for i, (regrets, label) in enumerate(zip(reg_array, label_array)):
        regrets = np.asarray(regrets)
        reg_mean = np.mean(regrets, axis=0)
        reg_std = np.std(regrets, axis=0)
        plt.errorbar(np.arange(len(reg_std)), reg_mean, yerr=reg_std,
                     errorevery=errorevery, capsize=10,
                     fmt=colors[i % len(colors)], linewidth=3.0, label=label)

    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Regret', fontsize=14)
    plt.legend(loc = 'lower right',fontsize=14)
    plt.grid()
    if save:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close()
    return None

# NOTE(review): this is a second, identical definition of get_files; being later
# in the file it is the one in effect when the script runs. One copy can be deleted.
def get_files(path):
    """Load one pickled experiment-result file and unpack the fields used here.

    The unpickled object must be indexable by integer position; only
    positions 1-4 and 6-9 are read.

    Args:
        path: path of the pickle file to open (opened in binary mode).

    Returns:
        A tuple ``(reg, avg, ub, lb, rew_type)`` where ``reg`` is the list
        ``[data[1], data[2], data[3], data[4]]`` (presumably the regret records
        of the four plotted algorithms — TODO confirm), ``avg``, ``ub`` and
        ``lb`` are ``data[6]``, ``data[7]`` and ``data[8]`` (plotted as the
        per-arm mean, upper bound and lower bound), and ``rew_type`` is
        ``data[9]``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open result
    # files produced by trusted experiment runs.
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    reg = [data[1], data[2], data[3], data[4]]
    avg, ub, lb, rew_type = data[6], data[7], data[8], data[9]
    return reg, avg, ub, lb, rew_type

def comp_reg_composite(reg_array1, reg_array2, avg, ub, lb, label_array, save_path, per = 20, save = False,
                       errorevery = 10000, titles = ('Clipped Uniform Rewards', 'Bernoulli Rewards')):
    """Plot arm specifications plus two regret panels in a 1x3 figure.

    Panel 1 shows the per-arm mean (red circles), upper bound (green squares)
    and lower bound (blue squares), arms numbered from 1. Panels 2 and 3 show
    the mean regret over time for each algorithm in ``reg_array1`` and
    ``reg_array2`` respectively, with +/- one standard deviation error bars.

    Args:
        reg_array1, reg_array2: sequences of regret collections, one per
            algorithm. Each is reduced with mean/std over axis 0, so presumably
            a (runs, T) array-like of regret trajectories — TODO confirm.
        avg, ub, lb: per-arm mean, upper bound and lower bound rewards.
        label_array: legend label per algorithm. Curves are drawn for as many
            entries as both the regret sequence and ``label_array`` provide.
        save_path: output file used when ``save`` is True.
        per: currently unused; kept so existing callers passing it still work.
        save: if True, save the figure to ``save_path`` before showing it.
        errorevery: draw an error bar every ``errorevery`` time steps.
        titles: titles of the second and third panels.

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.
    """
    plt.figure(figsize = (16,4))
    plt.rcParams["font.family"] = "sans-serif"

    # Panel 1: arm reward specifications.
    plt.subplot(1,3,1)
    arms = range(1, len(avg) + 1)
    plt.plot(arms, avg, 'ro', label='Mean')
    plt.plot(arms, ub, 'gs', label='Upper bound')
    plt.plot(arms, lb, 'bs', label='Lower bound')
    plt.vlines(arms, ymin=ub, ymax=lb, color='k')
    plt.xticks(arms)
    plt.xlabel('Arms', fontsize=14)
    plt.ylabel('Reward', fontsize=14)
    plt.title('Arm Specifications', fontsize=15)
    plt.grid()
    plt.legend()

    # Panels 2 and 3: one regret panel per reward type. Statistics are computed
    # fresh inside the loop so nothing from the first panel leaks into the second.
    colors = ['lightcoral', 'limegreen', 'tan', 'royalblue']
    panels = ((2, reg_array1, titles[0]), (3, reg_array2, titles[1]))
    for position, reg_array, title in panels:
        plt.subplot(1, 3, position)
        for i, (regrets, label) in enumerate(zip(reg_array, label_array)):
            regrets = np.asarray(regrets)
            reg_mean = np.mean(regrets, axis=0)
            reg_std = np.std(regrets, axis=0)
            plt.errorbar(np.arange(len(reg_std)), reg_mean, yerr=reg_std,
                         errorevery=errorevery, capsize=10,
                         fmt=colors[i % len(colors)], linewidth=3.0, label=label)
        plt.xlabel('Time', fontsize=14)
        plt.ylabel('Regret', fontsize=14)
        plt.legend(loc='lower right', fontsize=14)
        plt.grid()
        plt.title(title, fontsize=15)

    if save:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close()
    return None

# Script entry: load the result file(s) selected on the command line, plot them
# and save the figure under ./Results/.
labels = ['B-UCB', 'ImprovedUCB', 'B-KL-UCB', 'GLUE']
if rew_type != 'Composite':
    # Single reward type: note rew_type is rebound to the value stored in the file.
    reg, avg, ub, lb, rew_type = get_files(path)
    save_path = './Results/Case' + str(case_number) + rew_type + '.png'
    comp_reg(reg, avg, ub, lb, labels, save_path, per=15, save=True)
else:
    # Arm specifications come from the first file; only the regrets are
    # taken from the second one.
    reg1, avg, ub, lb, rew_type = get_files(path1)
    reg2 = get_files(path2)[0]
    save_path = './Results/NewCase' + str(case_number) + 'Composite.png'
    comp_reg_composite(reg1, reg2, avg, ub, lb, labels, save_path, per=15, save=True)
    
