
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

from Env import FiniteStateFiniteActionLinearMDP

def smoother(l, n=20):
    """Return a centred moving average of the 1-D sequence *l*.

    Element ``i`` of the result is the mean of ``l[i-n : i+n+1]``: up to ``n``
    neighbours on each side of ``l[i]`` plus ``l[i]`` itself. The window is
    truncated at both ends of the sequence.

    Args:
        l: 1-D sequence (list or array) of numbers.
        n: Half-width of the averaging window. ``n=0`` gives back the values
            themselves (as floats for integer input).

    Returns:
        A NumPy array with the same length as *l*. Integer input produces
        fractional averages instead of truncated integers.
    """
    values = np.asarray(l)
    length = len(values)
    # The slice end is exclusive, so `i + n + 1` keeps the window symmetric
    # around i. Building the result from np.mean's outputs (rather than
    # writing into zeros_like(values)) avoids truncating means for int input.
    return np.array([
        np.mean(values[max(0, i - n):min(length, i + n + 1)])
        for i in range(length)
    ])

def _build_parser():
    """Create the command-line parser for this reward-vs-epoch plotting script.

    Every option is optional at the argparse level, so any flag that is not
    given on the command line is ``None`` on the parsed namespace.
    """
    cli = argparse.ArgumentParser(description='Draw the reward vs. epoch')
    # (flag, converter, help text), registered in this exact order.
    options = (
        ('-H', int, 'the length of horizon'),
        ('-S', int, 'the number of states'),
        ('-A', int, 'the number of actions'),
        ('-d', int, 'the number of dimensions of feature'),
        ('-env', str, 'the name of environment'),
        ('-k', float, 'the k to generate baseline'),
        ('-alpha', float, 'the alpha used in constraint'),
        ('-lbeta', float, 'the beta of LCB'),
        ('-ubeta', float, 'the beta of UCB'),
        ('-N1', int, 'the number of offline trajectories'),
        ('-N2', int, 'the number of online episodes'),
        ('-M', int, 'the number of total trials'),
    )
    for flag, converter, text in options:
        cli.add_argument(flag, type=converter, help=text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Keys stored in each per-seed results pickle. This script reads only
# index 1 ('hist_random').
l3 = [
    'hist_expectation',
    'hist_random',
    'h0_list',
    'rho_list',
]

# Construct the MDP with the CLI dimensions, then load the named environment
# into it (presumably the saved parameters used in training -- confirm in Env).
env = FiniteStateFiniteActionLinearMDP(
    H=args.H,
    S=args.S,
    A=args.A,
    d=args.d,
)
env.load_env(args.env)

# Baseline (Boltzmann policy with temperature k) performance. Its
# (1 - alpha) fraction is the safety threshold used when plotting.
baseline = env.baseline_gen(temprature_k=args.k)
perf, actions = baseline
print('k=', args.k)
print('boltzmann policy with temprature k average total rewards: ', perf)

def _load_trial(seed):
    """Load the results pickle for one trial *seed* under the current CLI settings.

    NOTE(review): pickle.load can execute arbitrary code; only point this at
    result files produced by this project's own experiment runs.
    """
    path = (
        f'Results/SafeRLwithOffline_env{args.env}_alpha{args.alpha}'
        f'_Lbeta{args.lbeta}_Ubeta{args.ubeta}_N1{args.N1}_N2{args.N2}'
        f'_tempk{args.k}_seed{seed}.pkl'
    )
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# One loaded results record per trial, seeds 0 .. M-1.
temp = [_load_trial(seed) for seed in range(args.M)]


# Gather the reward histories into one array indexed as
# [algorithm, trial, episode]. For each of the first three entries of a
# trial's l3[1] ('hist_random') record, stack that entry across all loaded
# trials. Using a comprehension avoids leaking loop temporaries into module
# scope.
tmp = np.array([
    [np.asarray(trial[l3[1]][i]) for trial in temp]
    for i in range(3)
])
# tmp.shape == (3, args.M, history_length). history_length is presumably
# args.N2 (the baseline line is drawn with that length) -- TODO confirm.

# Episode totals below (1 - alpha) x baseline count as constraint violations.
threshold = perf * (1 - args.alpha)

# Legend name, z-order and colour for each algorithm, in the same order as
# axis 0 of `tmp`.
algorithms = [
    ('StepMix', 10, 'steelblue'),
    ('EpsMix', 9, 'mediumpurple'),
    ('LSVI-UCB', 7, 'tab:orange'),
]

plt.figure(dpi=200)
for rewards, (name, z, color) in zip(tmp, algorithms):
    # `rewards` holds one row per trial. Count violations over every trial
    # and episode, then draw the across-trial mean: smoothed and opaque on
    # top, raw and faint underneath.
    count = np.sum(rewards < threshold)
    mean_curve = np.mean(rewards, axis=0)
    plt.plot(smoother(mean_curve), label=name + ', # of violation:' + str(count),
             alpha=0.8, zorder=z, color=color)
    plt.plot(mean_curve, alpha=0.2, color=color)

plt.plot(np.ones(args.N2) * perf, '-.', label='baseline: ' + str(round(perf, 2)),
         alpha=0.5, zorder=4)
plt.xlabel('online epochs')
plt.ylabel('total reward')
plt.legend()

# savefig does not create missing directories; make sure the target exists
# so the run does not fail at the very last step.
os.makedirs('figure', exist_ok=True)
plt.savefig(
    f'figure/env{args.env}_alpha{args.alpha}_Lbeta{args.lbeta}_Ubeta{args.ubeta}'
    f'_N1{args.N1}_N2{args.N2}_tempk{args.k}.png'
)
plt.close()
