
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 




# Global matplotlib defaults for every figure produced by this script:
# figure size, base font size and legend font size.
plt.rcParams.update({
    'figure.figsize': (9, 6),
    'font.size': 14,
    'legend.fontsize': 24,
})
#plt.rcParams['text.usetex'] = True
#plt.rcParams['font.family'] = 'serif'



# --- Experiment identifiers -------------------------------------------------
# Most of these values are interpolated into the result / figure file names
# used further down, so they must match the run that is being plotted.
ARMS = 100
SCHEDULE = 25
CASE = 5
BATCHSIZE = 5
TIMELIMIT = 3000
REINFORCELR = 0.001

# When True, the NeurWIN results are read from (and the figure is written to)
# the "noisy" sub-directories.
NOISY = True

# --- Training-episode axis ----------------------------------------------------
# One x-axis point every INTERVAL episodes, from 0 up to and including
# EPISODEEND (the stop value is pushed one step past the end for that reason).
EPISODEEND = 50_000
INTERVAL = 100
plotNumEpisodesXAxis = np.arange(0, EPISODEEND + INTERVAL, INTERVAL)


# 1-lookahead baseline: the rewards are taken from the second column of the
# results CSV and kept as a plain Python list.
d1LookAheadFileName = (f'../testResults/recovering_env/dLookAheadResults_arms_{ARMS}_timeLimit_{TIMELIMIT}_d_{1}_case_{CASE}_schedule_{SCHEDULE}.csv')
lookAheadFrame = pd.read_csv(d1LookAheadFileName)
d1rewards = list(lookAheadFrame.iloc[:, 1])

# 5th and 95th percentile of the baseline rewards.
lookAheadPercentiles = np.percentile(d1rewards, [5, 95])
d15Percentile = lookAheadPercentiles[0]
d195Percentile = lookAheadPercentiles[1]


# Placeholder; replaced by the CSV column read below.
nnRewards = []

# NeurWIN results: the noisy variant is stored in its own sub-directory.
neurwinFileName = (
    (f'../testResults/recovering_env/noisy_results/nnIndexResults_arms_{ARMS}_batchSize_{BATCHSIZE}_case_{CASE}_timeLimit_{TIMELIMIT}_schedule_{SCHEDULE}.csv')
    if NOISY else
    (f'../testResults/recovering_env/nnIndexResults_arms_{ARMS}_batchSize_{BATCHSIZE}_case_{CASE}_timeLimit_{TIMELIMIT}_schedule_{SCHEDULE}.csv')
)

# The rewards are in the second column of the CSV.
df = pd.read_csv(neurwinFileName)
nnRewards = df.iloc[:, 1]

# Disabled loaders for the REINFORCE and Q-learning result CSVs. The code is
# wrapped in a bare string literal, so none of it is executed; it mirrors the
# loaders above (read the CSV, keep the second column).
'''
reinforceRewards = []
reinforceFileName = (f'../testResults/recovering_env/reinforceResults_arms_{ARMS}_batchSize_{BATCHSIZE}\
_lr_{REINFORCELR}_run_{0}_schedule_{SCHEDULE}.csv')
df = pd.read_csv(reinforceFileName)
reinforceRewards = df.iloc[:,1]


qLearningRewards = []

qLearningFileName = (f'../testResults/recovering_env/qLearningResults_arms_{ARMS}_run_{0}_schedule_{SCHEDULE}.csv')
df = pd.read_csv(qLearningFileName)
qLearningRewards = df.iloc[:,1]
'''

# NeurWIN rewards against the number of training episodes.
plt.plot(
    plotNumEpisodesXAxis,
    nnRewards,
    label='NeurWIN',
    color='b',
    linewidth=3.0,
    linestyle='solid',
)

# Disabled baseline curves (their result loaders are disabled as well):
#plt.plot(plotNumEpisodesXAxis, qLearningRewards, label=f'QWIC', color='g', linewidth=3.0, linestyle='dashed')
#plt.plot(plotNumEpisodesXAxis, reinforceRewards, label=f'REINFORCE', color='k', linewidth=3.0, linestyle='dotted')

# 1-lookahead baseline as horizontal line(s) across the whole episode range.
# NOTE(review): d1rewards is a list, so one line is drawn per entry - confirm
# the results CSV holds the value(s) that are meant to be shown.
plt.hlines(
    y=d1rewards,
    xmin=0,
    xmax=EPISODEEND,
    label='1-Lookahead',
    color='r',
    linewidth=3.0,
    linestyle='dashdot',
)

# Axis labels, tick rotation and legend.
plt.xlabel('Number of Training Episodes', fontsize=24)
plt.ylabel('Total Discounted Rewards', fontsize=24)
plt.yticks(rotation=60)
plt.legend()

# The noisy variant is saved into its own sub-directory.
if NOISY:
    figurePath = f'../plotResults/recovering_results/noisy_version/recovering_scheduling_arms_{ARMS}_activate_{SCHEDULE}.pdf'
else:
    figurePath = f'../plotResults/recovering_results/recovering_scheduling_arms_{ARMS}_activate_{SCHEDULE}.pdf'
plt.savefig(figurePath)


plt.show()



#########################PLOTTING THE RECOVERING FUNCTIONS ##################################################
# Disabled code: it is wrapped in a bare string literal, so none of it runs.
# For each THETA it evaluates THETA[0] * (1 - exp(-THETA[1] * z + THETA[2]))
# for z = 1..maxWait and plots the resulting curve.
#
# The literal is raw (r-prefix) because the LaTeX axis label inside it contains
# backslash sequences (\i, \{, \}) that are not valid escapes in a normal
# string and make Python emit an invalid-escape-sequence warning when the file
# is compiled. The inner label is raw too, so the code stays warning-free if it
# is ever re-enabled by removing the surrounding quotes.
r'''
maxWait = 20
STYLES = ['solid','dashed','dotted','dashdot']
LABELS = ['A','B','C','D']
THETAS = [[10., 0.2, 0.0],[8.5, 0.4, 0.0],[7., 0.6, 0.0],[5.5, 0.8, 0.0]]

for x in range(len(THETAS)):
    rewards = []
    THETA = THETAS[x]

    for i in range(1,maxWait+1):

        reward = THETA[0] * (1 - np.exp(-1*THETA[1] * i + THETA[2]))
        rewards.append(reward)
    plt.plot(range(1,maxWait+1), rewards, linewidth=3.0, label=f'Recovering function {LABELS[x]}',linestyle=STYLES[x])

plt.ylabel('$f(z)$', fontsize=24)
plt.xlabel(r'$z \in \{1, z_{max}\}$', fontsize=24)
plt.legend()
plt.xticks(np.arange(1,maxWait+1))
#plt.grid('on')
#plt.title(f'Class {CASE} Recovery Function')
plt.savefig(f'../plotResults/recovering_results/recovering_functions.pdf')
plt.show()





'''

