
import random 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import sys
sys.path.insert(0,'../')
from neurwin import fcnn 
import torch

# Matplotlib defaults applied to every figure created by this script.
plt.rcParams.update({
    'figure.figsize': (9, 6),
    'font.size': 14,
    'legend.fontsize': 22,
})
#plt.rcParams['text.usetex'] = True
#plt.rcParams['font.family'] = 'serif'


# Range of states [T, B] that gets plotted: B runs over 0..maxLoad on the
# x-axis, and one figure is produced for each T in 1..maxTimeBeforeDeadline.
maxLoad = 9
maxTimeBeforeDeadline = 12
TVALS = np.arange(1, maxTimeBeforeDeadline + 1)

# Parameters of the closed-form index in selectDeadlineIndex.
PROCESSINGCOST = 0.5
BETA = 0.999    # raised to the power T-1 there (presumably the discount factor)

# Checkpoints to compare: 0, INTERVAL, 2*INTERVAL, ..., MAXTRAINGEPISODE.
MAXTRAINGEPISODE = 1000
INTERVAL = 1000
TRAININGEPISODES = np.arange(0, MAXTRAINGEPISODE + INTERVAL, INTERVAL)

# True selects the 'noisy_version' checkpoints and the '..._noisy' plot folder.
NOISY = True

############### CONSTANT VALUES
# Both values are part of the checkpoint directory name read by testNN.
BATCHSIZE = 5
filesSeed = 30
###############

# Single network instance; testNN loads each checkpoint's weights into it.
agent = fcnn(stateSize=2)

# Path of the checkpoint whose weights are currently loaded into ``agent``.
# testNN is called once per plotted state, so remembering this avoids reading
# the same file from disk again for every state.
_loadedModelPath = None


def testNN(trainEpisode, state):
    """Evaluate the network checkpoint saved after ``trainEpisode`` episodes.

    The checkpoint path is built from NOISY, filesSeed and BATCHSIZE. Its
    weights are loaded into the module-level ``agent`` (only when a different
    checkpoint is currently loaded) and the network is evaluated at ``state``.

    Args:
        trainEpisode: number of training episodes identifying the checkpoint
            directory.
        state: the state passed to ``agent.forward``; callers in this file
            pass ``[T, B]``.

    Returns:
        Whatever ``agent.forward(state)`` returns, computed with autograd
        disabled (presumably a tensor - confirm against ``neurwin.fcnn``).

    Raises:
        Any error raised by ``torch.load`` / ``load_state_dict``, e.g. when
        the checkpoint file does not exist.
    """
    global _loadedModelPath

    versionDir = 'noisy_version/' if NOISY else ''
    modelPath = (f'../trainResults/neurwin/deadline_env/{versionDir}'
                 f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}'
                 f'_trainedNumEpisodes_{trainEpisode}/trained_model.pt')

    if modelPath != _loadedModelPath:
        # Invalidate first so a failed load is retried on the next call
        # instead of leaving a stale "already loaded" marker behind.
        _loadedModelPath = None
        agent.load_state_dict(torch.load(modelPath))
        _loadedModelPath = modelPath

    # Inference only: no autograd graph is needed, and tensors that require
    # grad cannot be converted to NumPy by the plotting code.
    with torch.no_grad():
        return agent.forward(state)

def selectDeadlineIndex(state, processingCost=None, beta=None):
    """Return the closed-form deadline index for ``state = [T, B]``.

    The value is piecewise in the load B (``state[1]``) and the time T
    (``state[0]``):

    * ``B == 0``            -> ``0``
    * ``T <= B``            -> ``beta**(T-1) * (0.2*(B-T+1)**2 - 0.2*(B-T)**2)
      + 1 - processingCost`` (evaluated as two separate products)
    * otherwise (``1 <= B <= T-1``) -> ``1 - processingCost``

    The previous condition for the middle case, ``B <= 1 or B <= T-1``, also
    captured the state ``T == 1, B == 1``, which belongs to the ``T <= B``
    case. That is the only valid integer state whose value changes.

    Args:
        state: indexable pair ``[T, B]``.
        processingCost: cost term; defaults to the module-level
            PROCESSINGCOST when None.
        beta: base of the ``beta**(T-1)`` factor; defaults to the
            module-level BETA when None.

    Returns:
        The index value as a number.
    """
    if processingCost is None:
        processingCost = PROCESSINGCOST
    if beta is None:
        beta = BETA

    timeLeft, load = state[0], state[1]

    if load == 0:   # B = 0
        return 0

    if timeLeft <= load:    # T <= B
        discount = beta**(timeLeft - 1)
        firstVal = discount*(0.2*(load - timeLeft + 1)**2)
        secondVal = discount*(0.2*(load - timeLeft)**2)
        return firstVal - secondVal + 1 - processingCost

    # 1 <= B <= T - 1
    return 1 - processingCost


# One figure per T: network output for each checkpoint versus the closed-form
# index, both plotted over the load B = 0..maxLoad.
linestyles = ['dashed', 'dashdot']
loadValues = range(maxLoad+1)

for T in TVALS:
    fig = plt.figure()

    # The closed-form index depends only on the state, not on the checkpoint,
    # so it is computed once per figure rather than once per checkpoint.
    indices = [selectDeadlineIndex([T, B]) for B in loadValues]

    for episodeNum, trainEpisode in enumerate(TRAININGEPISODES):
        nnIndex = [testNN(trainEpisode, [T, B]) for B in loadValues]
        # Cycle through the styles so that plotting more checkpoints than
        # there are line styles does not raise an IndexError.
        plt.plot(loadValues, nnIndex, label=f'Trained Episodes: {trainEpisode}', marker='.',
                 linestyle=linestyles[episodeNum % len(linestyles)])

    plt.plot(loadValues, indices, label='Deadline Whittle Index', color='r', marker='.', linestyle='solid')
    plt.legend()
    plt.xlabel('Load size B', fontsize=22)
    plt.ylabel('Index Value', fontsize=22)
    #plt.title(f'indices for state [T,B] with T = {T}. B values between 0 and 9.')
    plt.xticks(loadValues)

    # The file name carries the largest plotted load (maxLoad), which is the
    # value the leaked loop variable B used to hold at this point.
    plotDir = 'deadline_index_B_noisy' if NOISY else 'deadline_index_B_clean'
    plt.savefig(f'../plotResults/deadline_results/{plotDir}/deadline_B_{maxLoad}_T_{T}.pdf')
    #plt.show()

    # Release the figure; otherwise every figure stays alive until exit.
    plt.close(fig)


