'''
Main script for running NeurWIN and REINFORCE training experiments.
To run a training case, enable its section below; all sections are disabled by default.
'''

import os 
import gym
import time
import torch
import random
import numpy as np 
import pandas as pd
from reinforce import REINFORCE
from neurwin import NEURWIN, fcnn
from envs.sizeAwareIndexEnv import sizeAwareIndexEnv
from envs.recoveringBanditsEnv import recoveringBanditsEnv
from envs.deadlineSchedulingEnv import deadlineSchedulingEnv
from envs.sizeAwareIndexMultipleArmsEnv import sizeAwareIndexMultipleArmsEnv
from envs.recoveringBanditsMultipleArmsEnv import recoveringBanditsMultipleArmsEnv
from envs.deadlineSchedulingMultipleArmsEnv import deadlineSchedulingMultipleArmsEnv


###########################PARAMETERS########################################
STATESIZE = 2
SEED = 30
# Seed Python's, NumPy's and PyTorch's global RNGs (in that order) with SEED.
for _seedRng in (random.seed, np.random.seed, torch.manual_seed):
    _seedRng(SEED)
del _seedRng

# Sigmoid function sensitivity: 1 for deadline, 5 for recovering, 0.01 for size-aware.
sigmoidParam = 5
# Batch size, i.e. every n episodes.
BATCHSIZE = 5
learningRate = 1e-3
numEpisodes = 50_000
discountFactor = 0.999

TRAIN = True
# Episode limit: 3000 for size-aware and deadline, 100 for recovering.
EPISODELIMIT = 100
noiseVar = 0.05
#####################################################
# Size-aware index environment cases. CASE selects HOLDINGCOST, GOODTRANS,
# BADTRANS, GOODPROB and LOAD; CLASSVAL is the class value that goes with it.
# Only the (CASE, CLASSVAL) combinations handled below are defined. CASE 4 is a
# small debugging setup and accepts any CLASSVAL.
CASE = 1
CLASSVAL = 1

if CASE == 1 and CLASSVAL == 1:
    HOLDINGCOST = 1
    GOODTRANS = 33600
    BADTRANS = 8400
    GOODPROB = 0.5
    LOAD = 1000000

elif CASE == 2 and CLASSVAL == 1:  # old case 3
    HOLDINGCOST = 5
    GOODTRANS = 33600
    BADTRANS = 8400
    GOODPROB = 0.5
    LOAD = 1000000

elif CASE == 3 and CLASSVAL == 2:  # old case 4
    HOLDINGCOST = 1
    GOODTRANS = 16800
    BADTRANS = 8400
    GOODPROB = 0.5
    LOAD = 1000000

elif CASE == 4:  # test case for debugging
    HOLDINGCOST = 1
    GOODTRANS = 4
    BADTRANS = 1
    GOODPROB = 0.5
    LOAD = 120

else:
    # sys.exit() does not rely on the site module (unlike the interactive
    # exit() helper). Given a string, it prints it to stderr and exits with status 1.
    import sys
    sys.exit(f'entered case/class combination (case={CASE}, class={CLASSVAL}) '
             'not in list. Exiting...')

print(f'selected case: {CASE}. selected class: {CLASSVAL}')


########################################----TRAINING SETTINGS----#########################################################

########################################## NEURWIN DEADLINE SCHEDULING ###################################################
RUN_NEURWIN_DEADLINE = False  # set to True to run this experiment


def train_neurwin_deadline():
    """Train a NeurWIN agent on the single-arm deadline scheduling environment.

    Reads the module-level hyperparameters (SEED, numEpisodes, EPISODELIMIT,
    TRAIN, BATCHSIZE, noiseVar, STATESIZE, learningRate, sigmoidParam,
    discountFactor). Adjust them per the PARAMETERS comments before running.
    Results go to trainResults/neurwin/deadline_env/, or to its
    noisy_version/ subdirectory when noiseVar > 0.
    """
    if noiseVar > 0:
        deadlineDirectory = 'trainResults/neurwin/deadline_env/noisy_version/'
    else:
        deadlineDirectory = 'trainResults/neurwin/deadline_env/'

    deadlineEnv = deadlineSchedulingEnv(
        seed=SEED, numEpisodes=numEpisodes, episodeLimit=EPISODELIMIT, maxDeadline=12,
        maxLoad=9, newJobProb=0.7, processingCost=0.5, train=TRAIN, batchSize=BATCHSIZE,
        noiseVar=noiseVar)

    agent = NEURWIN(
        stateSize=STATESIZE, lr=learningRate, env=deadlineEnv, sigmoidParam=sigmoidParam,
        numEpisodes=numEpisodes, seed=SEED, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=deadlineDirectory, episodeSaveInterval=5)
    agent.learn()


if RUN_NEURWIN_DEADLINE:
    train_neurwin_deadline()
########################################## NEURWIN SIZE-AWARE SCHEDULING #################################################
RUN_NEURWIN_SIZE_AWARE = False  # set to True to run this experiment


def train_neurwin_size_aware():
    """Train a NeurWIN agent on the single-arm size-aware index environment.

    Uses the case parameters selected by CASE/CLASSVAL (HOLDINGCOST,
    GOODTRANS, BADTRANS, GOODPROB, LOAD) and the module-level
    hyperparameters. Results go to
    trainResults/neurwin/size_aware_env/[noisy_version/]case_<CASE>/class_<CLASSVAL>/.
    The noisy_version/ segment is present only when noiseVar > 0.
    """
    if noiseVar > 0:
        sizeAwareDirectory = (f'trainResults/neurwin/size_aware_env/noisy_version/'
                              f'case_{CASE}/class_{CLASSVAL}/')
    else:
        sizeAwareDirectory = f'trainResults/neurwin/size_aware_env/case_{CASE}/class_{CLASSVAL}/'

    sizeAwareEnv = sizeAwareIndexEnv(
        numEpisodes=numEpisodes, HOLDINGCOST=HOLDINGCOST, seed=SEED, Training=TRAIN,
        r1=BADTRANS, r2=GOODTRANS, q=GOODPROB, case=CASE, classVal=CLASSVAL, noiseVar=noiseVar,
        load=LOAD, batchSize=BATCHSIZE, maxLoad=LOAD, episodeLimit=EPISODELIMIT,
        fixedSizeMDP=False)

    agent = NEURWIN(
        stateSize=STATESIZE, lr=learningRate, env=sizeAwareEnv, sigmoidParam=sigmoidParam,
        numEpisodes=numEpisodes, seed=SEED, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=sizeAwareDirectory, episodeSaveInterval=1000)
    agent.learn()


if RUN_NEURWIN_SIZE_AWARE:
    train_neurwin_size_aware()

###################################### NEURWIN RECOVERING BANDITS SCHEDULING ##############################################
RUN_NEURWIN_RECOVERING = False  # set to True to run this experiment

# Recovery-function parameters [Theta0, Theta1, Theta2] for each recovering
# bandits case (A, B, C, D are different recovery functions).
RECOVERING_THETAS = {
    'A': [10., 0.2, 0.0],
    'B': [8.5, 0.4, 0.0],
    'C': [7., 0.6, 0.0],
    'D': [5.5, 0.8, 0.0],
}


def train_neurwin_recovering(recoveryCase='D', maxWait=20):
    """Train a NeurWIN agent on the single-arm recovering bandits environment.

    Args:
        recoveryCase: key into RECOVERING_THETAS selecting the recovery function.
        maxWait: maximum time before refreshing the arm. It is passed to the
            environment and recorded in used_parameters.txt.

    The state size (1) and the case are local. Running this therefore does
    not overwrite the module-level STATESIZE or CASE used by the size-aware
    experiments. The save directory is created if needed. An existing one is
    reused, and its used_parameters.txt is overwritten.

    Raises:
        KeyError: if recoveryCase is not a key of RECOVERING_THETAS.
    """
    stateSize = 1
    # Copy so the environment cannot mutate the shared table entry.
    THETA = list(RECOVERING_THETAS[recoveryCase])

    if noiseVar > 0:
        recoveringDirectory = (f'trainResults/neurwin/recovering_bandits_env/noisy_version/'
                               f'recovery_function_{recoveryCase}/')
    else:
        recoveringDirectory = f'trainResults/neurwin/recovering_bandits_env/recovery_function_{recoveryCase}/'

    os.makedirs(recoveringDirectory, exist_ok=True)
    with open(recoveringDirectory + 'used_parameters.txt', 'w') as paramsFile:
        paramsFile.write(f'Theta0, Theta1, Theta2: {THETA}\n')
        paramsFile.write(f'max wait for recovery function: {maxWait}\n')

    print(f'selected theta: {THETA}')
    recoveringEnv = recoveringBanditsEnv(
        seed=SEED, numEpisodes=numEpisodes, episodeLimit=EPISODELIMIT, train=TRAIN,
        batchSize=BATCHSIZE, thetaVals=THETA, noiseVar=noiseVar, maxWait=maxWait)

    agent = NEURWIN(
        stateSize=stateSize, lr=learningRate, env=recoveringEnv, sigmoidParam=sigmoidParam,
        numEpisodes=numEpisodes, seed=SEED, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=recoveringDirectory, episodeSaveInterval=100)
    agent.learn()


if RUN_NEURWIN_RECOVERING:
    train_neurwin_recovering()
###################################### REINFORCE SIZE-AWARE SCHEDULING ##############################################

RUN_REINFORCE_SIZE_AWARE = False  # set to True to run this experiment


def train_reinforce_size_aware():
    """Train a REINFORCE agent on the multiple-arms size-aware index environment.

    Uses 4 arms (2 of each class), schedules activateArms = 1 arm and trains
    without noise. Results go to trainResults/reinforce/size_aware_env/case_<CASE>/.
    """
    reinforceSizeAwareDirectory = f'trainResults/reinforce/size_aware_env/case_{CASE}/'

    activateArms = 1

    # Bound to a new name so the imported sizeAwareIndexMultipleArmsEnv class
    # is not replaced by an instance of itself.
    sizeAwareMultipleArmsEnv = sizeAwareIndexMultipleArmsEnv(
        seed=SEED, numEpisodes=numEpisodes, train=TRAIN, noiseVar=0, batchSize=BATCHSIZE,
        class1Arms=2, class2Arms=2, numArms=4, scheduleArms=activateArms, case=CASE,
        episodeLimit=EPISODELIMIT)

    reinforceAgent = REINFORCE(
        lr=learningRate, env=sizeAwareMultipleArmsEnv, seed=SEED, activateArms=activateArms,
        numEpisodes=numEpisodes, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=reinforceSizeAwareDirectory, episodeSaveInterval=1000)
    reinforceAgent.learn()


if RUN_REINFORCE_SIZE_AWARE:
    train_reinforce_size_aware()
###################################### REINFORCE DEADLINE SCHEDULING ##############################################
RUN_REINFORCE_DEADLINE = False  # set to True to run this experiment


def train_reinforce_deadline():
    """Train a REINFORCE agent on the multiple-arms deadline scheduling environment.

    The environment settings are local to this function, so running it does
    not overwrite the module-level noiseVar. Results go to
    trainResults/reinforce/deadline_env/.
    """
    reinforceDeadlineDirectory = 'trainResults/reinforce/deadline_env/'

    newJobProb = 0.7
    numArms = 4
    activateArms = 1
    PROCESSINGCOST = 0.5
    MAXDEADLINE = 12
    MAXLOAD = 9
    deadlineNoiseVar = 0  # this experiment uses zero noise variance

    deadlineMultipleArmsEnv = deadlineSchedulingMultipleArmsEnv(
        seed=SEED, numEpisodes=numEpisodes, batchSize=BATCHSIZE, train=True, numArms=numArms,
        processingCost=PROCESSINGCOST, maxDeadline=MAXDEADLINE, maxLoad=MAXLOAD,
        newJobProb=newJobProb, episodeLimit=EPISODELIMIT, scheduleArms=activateArms,
        noiseVar=deadlineNoiseVar)

    reinforceAgent = REINFORCE(
        lr=learningRate, env=deadlineMultipleArmsEnv, seed=SEED, activateArms=activateArms,
        numEpisodes=numEpisodes, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=reinforceDeadlineDirectory, episodeSaveInterval=5)
    reinforceAgent.learn()


if RUN_REINFORCE_DEADLINE:
    train_reinforce_deadline()
################################# REINFORCE RECOVERING BANDITS SCHEDULING #########################################
RUN_REINFORCE_RECOVERING = False  # set to True to run this experiment


def train_reinforce_recovering():
    """Train a REINFORCE agent on the multiple-arms recovering bandits environment.

    Uses 4 arms, schedules 1 arm per step, noise variance 0.05 and a maximum
    wait of 20. Results go to trainResults/reinforce/recovering_env/.
    """
    # Previously misnamed reinforceDeadlineDirectory (copy-paste from the deadline case).
    reinforceRecoveringDirectory = 'trainResults/reinforce/recovering_env/'

    NUMARMS = 4
    SCHEDULEARMS = 1
    NOISEVAR = 0.05
    MAXWAIT = 20

    recoveringMultipleArmsEnv = recoveringBanditsMultipleArmsEnv(
        seed=SEED, numEpisodes=numEpisodes, batchSize=BATCHSIZE, train=True, numArms=NUMARMS,
        scheduleArms=SCHEDULEARMS, noiseVar=NOISEVAR, maxWait=MAXWAIT, episodeLimit=EPISODELIMIT)

    reinforceAgent = REINFORCE(
        lr=learningRate, env=recoveringMultipleArmsEnv, seed=SEED, activateArms=SCHEDULEARMS,
        numEpisodes=numEpisodes, batchSize=BATCHSIZE, discountFactor=discountFactor,
        saveDir=reinforceRecoveringDirectory, episodeSaveInterval=100)
    reinforceAgent.learn()


if RUN_REINFORCE_RECOVERING:
    train_reinforce_recovering()
