'''
Test trained scheduling policies (NeurWIN, REINFORCE, Q-learning) on the
deadline scheduling problem and compare their performance with the
closed-form deadline index policy.
'''


import os 
import torch
import random
import operator
import itertools
import numpy as np 
import pandas as pd 
import scipy.special
import sys
sys.path.insert(0,'../')
from neurwin import fcnn 
import matplotlib.pyplot as plt
from qlearning import qLearningAgent
from reinforce import reinforceFcnn
from envs.deadlineSchedulingEnv import deadlineSchedulingEnv


###########################-CONSTANT VALUES-########################################
# SEED drives this script's random number generators; filesSeed is the seed
# embedded in the trained-model directory names (seed_<filesSeed>_...).
SEED = 30
filesSeed = 30
STATESIZE = 2       # handed to fcnn(stateSize=...)
TIMELIMIT = 3000    # timesteps per test episode (episodeLimit of every env)

# Seed every random number generator used by the script.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
################################-PARAMETERS-########################################
ARMS = 100  # number of positions in the queue
SCHEDULE = 25           # how many arms getSelection picks per timestep
numEpisodes = 1         # numEpisodes argument of the test environments
PROCESSINGCOST = 0.5    # processingCost argument of every environment
BATCHSIZE = 5           # batchSize argument; also part of model/result file names
BETA = 0.999            # discount applied to step rewards as BETA**time

# Trained checkpoints are addressed as 0, EPISODERANGE, ..., EPISODESEND.
EPISODESEND = 1000
EPISODERANGE = 50

REINFORCELR = 0.001     # learning-rate tag used in REINFORCE result file names
RUNS = 25               # number of repetitions in the driver loops

# Selects the noisy result/model directories when True.
NOISY = False

# Result and trained-model locations depend on whether the noisy variant is tested.
if NOISY:
    directory = '../testResults/deadline_env/noisy_results/'
    WINNMODELDIR = '../trainResults/neurwin/deadline_env/noisy_version/'
    # NOTE(review): REINFORCEDIR is not defined for the noisy variant, so the
    # REINFORCE section cannot run with NOISY = True -- confirm this is intended.
else:
    directory = '../testResults/deadline_env/'
    WINNMODELDIR = '../trainResults/neurwin/deadline_env/'
    REINFORCEDIR = '../trainResults/reinforce/deadline_env/'

# exist_ok replaces the racy "check os.path.exists, then create" sequence.
os.makedirs(directory, exist_ok=True)

# Append the tested configuration to the readme. The context manager closes
# the file even when the write raises.
readMeFileName = f'{directory}readme.txt'
with open(readMeFileName, 'a') as readMeFile:
    readMeFile.write(f'\nNumber of arms: {ARMS}\n-----------')


############################################################################
def initialize():
    """Fill the global ``envs`` dict with one test environment per arm.

    Arm ``i`` receives a ``deadlineSchedulingEnv`` seeded with
    ``envSeeds[i]``. Both ``envSeeds`` and ``envs`` must already exist at
    module level before this is called.
    """
    # Settings shared by every arm; only the seed differs between arms.
    sharedSettings = dict(
        numEpisodes=numEpisodes,
        episodeLimit=TIMELIMIT,
        maxDeadline=12,
        maxLoad=9,
        newJobProb=0.7,
        processingCost=PROCESSINGCOST,
        train=False,
        batchSize=BATCHSIZE,
        noiseVar=0,
    )
    for armId in range(ARMS):
        envs[armId] = deadlineSchedulingEnv(seed=envSeeds[armId], **sharedSettings)

def initializeNN():
    """Create one test environment and one trained network per arm.

    Fills the global ``envs`` and ``agents`` dicts (keys ``0 .. ARMS-1``).
    Every agent is an ``fcnn`` loaded with the checkpoint at the global
    ``MODELNAME``; ``envSeeds`` supplies the per-arm environment seed.
    """
    jobProbs = 0.7

    # All arms use the same checkpoint, so read it from disk once instead of
    # once per arm. load_state_dict copies the values into each network
    # (assuming fcnn is a regular torch.nn.Module -- TODO confirm), so the
    # agents do not end up sharing parameters.
    trainedWeights = torch.load(MODELNAME)

    for i in range(ARMS):
        env = deadlineSchedulingEnv(
            seed=envSeeds[i], numEpisodes=numEpisodes, episodeLimit=TIMELIMIT,
            maxDeadline=12, maxLoad=9, newJobProb=jobProbs,
            processingCost=PROCESSINGCOST, train=False, batchSize=BATCHSIZE,
            noiseVar=0)
        agent = fcnn(stateSize=STATESIZE)
        agent.load_state_dict(trainedWeights)
        envs[i] = env
        agents[i] = agent

def initializeREINFORCE():
    """Fill the global ``envs`` dict with the test environments for REINFORCE.

    The REINFORCE test uses exactly the same per-arm environments as the other
    policies, so this delegates to ``initialize()`` rather than keeping a
    second copy of the construction code that could drift out of sync.
    """
    initialize()

def resetEnvs():
    """Reset every environment in ``envs`` and store its initial state.

    The state returned by ``reset()`` is written to the global ``states``
    dict under the same key as the environment.
    """
    for armId, env in envs.items():
        states[armId] = env.reset()


def getSelection(index, numSelect=None):
    """Return the keys of the arms with the largest index values.

    Args:
        index: dict mapping arm key -> index value. It is not modified.
        numSelect: how many arms to pick. Defaults to the module-level
            ``SCHEDULE`` when omitted, which is what existing callers rely on.

    Returns:
        A list of arm keys ordered from highest to lowest index value. Ties
        are broken by the insertion order of ``index``, exactly as the
        previous repeated-``max`` implementation did. If ``index`` holds
        fewer than ``numSelect`` arms, all of them are returned instead of
        raising ``ValueError``.
    """
    if numSelect is None:
        numSelect = SCHEDULE

    # One stable descending sort replaces numSelect passes of max() + delete;
    # sorted(..., reverse=True) keeps equal-valued keys in their original order.
    ranked = sorted(index, key=index.get, reverse=True)
    return ranked[:max(numSelect, 0)]

def calculateIndexNeuralNetwork():
    """Pick the arms to activate from the neural-network index values.

    For every agent, the first element of the network output for the arm's
    current state is stored in the global ``indexNN``. The arms are then
    chosen with ``getSelection`` and ``indexNN`` is rebound to a fresh empty
    dict for the next timestep.

    Returns:
        The list of selected arm keys.
    """
    global indexNN

    for armId, agent in agents.items():
        indexNN[armId] = agent.forward(states[armId]).detach().numpy()[0]

    selected = getSelection(indexNN)

    # Start the next timestep with an empty index table.
    indexNN = {}
    return selected

def takeActionAndRecordNN(arms):
    """Advance every environment one step and record the discounted reward.

    Arms listed in ``arms`` are stepped with action 1, all others with
    action 0. Each arm's next state replaces its entry in ``states``. The
    summed reward, scaled by ``BETA**time``, is appended to ``rewards``.
    If an arm with key 0 exists, its first state entry is appended to
    ``remainingLoad``.

    Args:
        arms: keys of the arms to activate this timestep.
    """
    stepReward = 0

    # Activated arms go first, in the order they were selected.
    for armId in arms:
        nextState, reward, _done, _info = envs[armId].step(1)
        stepReward += reward
        states[armId] = nextState

    # Every remaining arm stays passive.
    activated = set(arms)
    for armId, env in envs.items():
        if armId not in activated:
            nextState, reward, _done, _info = env.step(0)
            states[armId] = nextState
            stepReward += reward

    rewards.append((BETA**time)*stepReward)

    if 0 in envs:
        remainingLoad.append(states[0][0])


def deadlineWhittleIndex(deadline, load, beta, cost):
    """Closed-form deadline index of a single arm in state ``(T, B)``.

    Args:
        deadline: T, the first state entry.
        load: B, the second state entry.
        beta: discount factor.
        cost: processing cost.

    Returns:
        ``0`` when ``B == 0``; ``1 - cost`` when ``1 <= B <= T - 1``; the
        discounted difference of the quadratic terms ``0.2 * x**2`` plus
        ``1 - cost`` when ``T <= B``; ``None`` when the state matches none of
        these cases.
    """
    if load == 0:
        return 0
    if 1 <= load <= deadline - 1:
        return 1 - cost
    if deadline <= load:
        discount = beta ** (deadline - 1)
        firstVal = discount * (0.2 * (load - deadline + 1) ** 2)
        secondVal = discount * (0.2 * (load - deadline) ** 2)
        return firstVal - secondVal + 1 - cost
    return None


def selectDeadlineIndex():
    """Pick the arms to activate using the closed-form deadline index.

    Updates the global ``index`` dict from the current ``states`` (each state
    is read as ``[T, B]``) and returns the arms chosen by ``getSelection``.

    The middle case used to be tested as ``B <= 1 or B <= T - 1``, which
    contradicted its own ``1 <= B <= T - 1`` comment: a job with ``B == 1``
    and ``T == 1`` received ``1 - cost`` instead of the ``T <= B`` value.
    The condition now follows the documented case split.
    """
    for key in states:
        value = deadlineWhittleIndex(states[key][0], states[key][1], BETA, PROCESSINGCOST)
        # A state matching no case keeps whatever value index[key] already
        # had, as before.
        if value is not None:
            index[key] = value

    return getSelection(index)

def takeActionAndRecordDeadlineIndex(arms):
    """Advance every environment one step for the deadline-index policy.

    Arms listed in ``arms`` are stepped with action 1, all others with
    action 0. The first two entries of each arm's next state are copied into
    its existing ``states`` entry in place. The summed reward, scaled by
    ``BETA**time``, is appended to ``rewards``.

    Args:
        arms: keys of the arms to activate this timestep.
    """
    stepReward = 0

    # Activated arms go first, in the order they were selected.
    for armId in arms:
        nextState, reward, _done, _info = envs[armId].step(1)
        states[armId][0], states[armId][1] = nextState[0], nextState[1]
        stepReward += reward

    # Every remaining arm stays passive.
    activated = set(arms)
    for armId, env in envs.items():
        if armId in activated:
            continue
        nextState, reward, _done, _info = env.step(0)
        stepReward += reward
        states[armId][0], states[armId][1] = nextState[0], nextState[1]

    rewards.append((BETA**time)*stepReward)


def getActionTableLength(arms=None, scheduleArms=None):
    """Build the table of all joint actions that activate exactly k arms.

    Args:
        arms: number of arms n. Defaults to the module-level ``ARMS``.
        scheduleArms: number of arms k activated per step. Defaults to the
            module-level ``SCHEDULE``.

    Returns:
        A list of length-n tuples of 0/1 containing exactly k ones, in
        ascending lexicographic order. This is the same content and order
        the previous "enumerate all 2**n vectors, then filter" code produced.
        The list is empty when no such vector exists.

    The table still has C(n, k) entries, so it is only practical for small
    n. Unlike before, the function no longer allocates a throw-away
    C(n, k)-sized array or materialises all 2**n candidate vectors first.
    """
    n = int(ARMS if arms is None else arms)
    k = SCHEDULE if scheduleArms is None else scheduleArms

    # No binary vector of length n can sum to a fractional or out-of-range k.
    if k != int(k) or not 0 <= k <= n:
        return []
    k = int(k)

    actionTable = []
    # Choosing the positions of the zeros in lexicographic order yields the
    # vectors themselves in ascending lexicographic order.
    for idlePositions in itertools.combinations(range(n), n - k):
        vector = [1] * n
        for position in idlePositions:
            vector[position] = 0
        actionTable.append(tuple(vector))

    return actionTable

def resetREINFORCEEnvs():
    """Reset all environments and build the joint REINFORCE state vector.

    The first two entries of every arm's initial state are appended to the
    global ``state`` list in ``envs`` order. ``state`` is then rebound to a
    float32 numpy array of those values.
    """
    global state

    for env in envs.values():
        initial = env.reset()
        state.extend((initial[0], initial[1]))

    state = np.array(state, dtype=np.float32)

def REINFORCETakeActionAndRecordCost():
    """Sample a joint action from the REINFORCE policy and apply it.

    The policy output for the current joint ``state`` is used as the
    probability vector over the rows of ``actionTable``. Each arm is stepped
    with the 0/1 entry of the sampled row. The first two entries of every
    arm's next state form the new joint ``state`` (float32 array), and the
    summed step value scaled by ``BETA**time`` is appended to ``rewards``.
    """
    global state

    actionProbs = reinforceAgent.forward(state).detach().numpy()
    sampledRow = np.random.choice(np.arange(len(actionTable)), p=actionProbs)
    actionVector = actionTable[sampledRow]

    stepReward = 0
    jointState = []
    for armId, flag in enumerate(actionVector):
        # Anything other than an explicit 1 means the arm stays passive.
        nextState, cost, _done, _info = envs[armId].step(1 if flag == 1 else 0)
        jointState.extend((nextState[0], nextState[1]))
        stepReward += cost

    state = np.array(jointState, dtype=np.float32)

    rewards.append((BETA**time)*stepReward)

# Trace of arm 0: takeActionAndRecordNN appends states[0][0] after every step.
# NOTE(review): elsewhere in this file states are read as [T, B], so entry 0
# looks like the deadline rather than the load -- confirm which is intended.
remainingLoad = []

'''
###########################-    TESTING SETTINGS    -######################################
#######################- DEADLINE CLOSED-FORM INDEX SCHEDULING -#######################
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

cumReward = []
for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)

    time = 0
    envs = {}
    rewards = [] # records the holding rewards from all arms at each timestep
    index = {}
    states = {}

    initialize()
    resetEnvs()

    print(f'doing the closed-form deadline index value for run {i+1}')
    while True:
        # size-aware index functions
        arm = selectDeadlineIndex()
        takeActionAndRecordDeadlineIndex(arm)
        ###############################################
        time += 1
        if time == TIMELIMIT: # all arms done
            break

    total_reward = (np.cumsum(rewards))[-1]
    cumReward.append(total_reward)

data = {'run': range(RUNS), 'cumulative_reward':cumReward}
df = pd.DataFrame(data=data)
deadlineFileName = (f'{directory}'+f'deadlineIndexResults_arms_{ARMS}_timeLimit_{TIMELIMIT}_schedule_{SCHEDULE}.csv')
df.to_csv(deadlineFileName, index=False)

'''
##########################- NEURWIN INDEX SCHEDULING -###############################

'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#indexDf = pd.DataFrame(data={'timesteps': range(TIMELIMIT)}) # data for storing the index

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    zeroEnvs = {}
    total_reward = []
    for x in ALLEPISODES:
        EPISODESTRAINED = x
        MODELNAME = WINNMODELDIR+(f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        nnFileName = directory+(f'nnIndexResults_arms_{ARMS}_batchSize_{BATCHSIZE}_run_{i}_timeLimit_{TIMELIMIT}_schedule_{SCHEDULE}.csv')

        time = 0
        envs = {}
        rewards = [] 
        index = {}
        agents = {}
        states = {}
        indexNN = {}
        whittleIndex = []
        remainingLoad = []
        initializeNN()

        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        resetEnvs()
        
        while True:
            # neural network index functions
            arm = calculateIndexNeuralNetwork()
            takeActionAndRecordNN(arm)
            ###############################################
            time += 1
            if time == TIMELIMIT: # all arms done
                break
            
        total_reward.append((np.cumsum(rewards))[-1])
        print(f'finish NeurWIN for trained episodes: {x}')
    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(nnFileName, index=False)
    print(f'finished NeurWIN scheduling for run {i+1}')


'''
##############################- REINFORCE SCHEDULING -######################################

'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

actionTable = getActionTableLength()

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    zeroEnvs = {}
    total_reward = []
    for x in ALLEPISODES:

        EPISODESTRAINED = x
        REINFORCEMODELDIR = REINFORCEDIR+(f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        reinforceFileName = directory+(f'reinforceResults_arms_{ARMS}_batchSize_{BATCHSIZE}_lr_{REINFORCELR}_run_{i}_schedule_{SCHEDULE}.csv')
        time = 0
        
        envs = {}
        rewards = []
        state = []
        initializeREINFORCE()
        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        resetREINFORCEEnvs()
        reinforceAgent = reinforceFcnn()
        reinforceAgent.load_state_dict(torch.load(REINFORCEMODELDIR))
        
        while True:
            REINFORCETakeActionAndRecordCost()
            time += 1
            if time == TIMELIMIT: # all arms done
                break

        total_reward.append((np.cumsum(rewards))[-1])
        print(f'finished for trained episodes: {x}')

    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(reinforceFileName, index=False)
    print(f'finished REINFOCE scheduling for run {i+1}')

'''
############################### Q-learning offline training ########################################


def initializeQLearningTraining():
    """Fill the global ``trainingEnvs`` dict with one training env per arm.

    Arm ``i`` receives a ``deadlineSchedulingEnv`` in training mode, seeded
    with ``trainingEnvSeeds[i]`` and configured for ``EPISODESEND`` episodes.
    """
    # Settings shared by every arm; only the seed differs between arms.
    sharedSettings = dict(
        numEpisodes=EPISODESEND,
        episodeLimit=TIMELIMIT,
        maxDeadline=12,
        maxLoad=9,
        newJobProb=0.7,
        processingCost=PROCESSINGCOST,
        train=True,
        batchSize=BATCHSIZE,
        noiseVar=0,
    )
    for armId in range(ARMS):
        trainingEnvs[armId] = deadlineSchedulingEnv(
            seed=trainingEnvSeeds[armId], **sharedSettings)

def createAgents():
    """Create one ``qLearningAgent`` per training environment.

    Each agent is built from its environment and the shared ``stateArray``
    and stored in the global ``agents`` dict under the environment's key.
    """
    for armId, env in trainingEnvs.items():
        agents[armId] = qLearningAgent(env, stateArray)

def updateAgentEnvs():
    """Point every agent's ``env`` attribute at the matching entry of ``envs``."""
    for armId, env in envs.items():
        agents[armId].env = env


def qLearningChooseArmsTraining():
    """Select the arms to activate during Q-learning training.

    Stores each training arm's ``_getLamda`` value for its current state in
    the global ``index`` dict, then returns the arms picked by
    ``getSelection``.
    """
    for armId in trainingEnvs:
        agent = agents[armId]
        index[armId] = agent._getLamda(states[armId])

    return getSelection(index)

def TakeActionAndRecordQLearning(arms, episode):
    """Let every Q-learning agent act for one training timestep.

    Agents of the arms in ``arms`` take action 1, all other agents take
    action 0. Each returned next state replaces the arm's entry in
    ``states``. The returned rewards are not recorded during training.

    Args:
        arms: keys of the arms to activate this timestep.
        episode: current training episode, forwarded to ``_takeAction``.
    """
    # Activated arms go first, in the order they were selected.
    for armId in arms:
        nextState, _reward = agents[armId]._takeAction(1, episode)
        states[armId] = nextState

    # Every remaining arm stays passive.
    activated = set(arms)
    for armId in trainingEnvs:
        if armId not in activated:
            nextState, _reward = agents[armId]._takeAction(0, episode)
            states[armId] = nextState


def resetQLearningEnvs():
    """Reset every training environment and store its initial state.

    The state returned by ``reset()`` is written to the global ``states``
    dict under the same key as the environment.
    """
    for armId, env in trainingEnvs.items():
        states[armId] = env.reset()


def createStateTable():
    """Enumerate every ``[T, B]`` state and store it in ``stateArray``.

    Rows are appended to the existing global ``stateArray`` list with the
    load ``B`` (0..9) as the outer loop and the deadline ``T`` (0..12) as the
    inner loop. ``stateArray`` is then rebound to a float32 numpy array.
    """
    global stateArray

    maxLoad = 9
    maxDeadline = 12
    stateArray.extend(
        [T, B]
        for B in range(maxLoad + 1)
        for T in range(maxDeadline + 1)
    )

    stateArray = np.array(stateArray, dtype=np.float32)


'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


TIMELIMIT = 3000 # for training, it's 3000 
episodeRange = np.arange(0, (EPISODESEND + EPISODERANGE), EPISODERANGE) 
trainingEnvs = {}
states = {}
new_states = {}
agents = {}
index = {} 
stateArray = []
trainingEnvSeeds = np.random.randint(0, 10000, size=ARMS)

createStateTable()
initializeQLearningTraining()
createAgents()


for episode in range(EPISODESEND):
    resetQLearningEnvs()
    print(f'current episode: {episode+1}')

    if (episode == 0):
        index_vals = []
        for key in agents: 
            index_vals.append(agents[key].init_lamda_index)
        index_vals = np.array(index_vals)
        
        saveDir = (f'../trainResults/qLearning/deadline_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{episode}')
        if not os.path.exists(saveDir):
            os.makedirs(saveDir)
        
        np.save(saveDir, index_vals)  # saving current q-table
    for step in range(TIMELIMIT):

        arms = qLearningChooseArmsTraining()
        TakeActionAndRecordQLearning(arms, episode)

    if (episode+1 in episodeRange) or (episode+1 == EPISODESEND): # test lamda table up until current episode.
        print(f'saving for episode {episode+1}')
        index_vals = []
        for key in agents: 
            index_vals.append(agents[key].init_lamda_index)
        index_vals = np.array(index_vals)
        
        saveDir = (f'../trainResults/qLearning/deadline_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{episode+1}')
        if not os.path.exists(saveDir):
            os.makedirs(saveDir)
        
        np.save(saveDir, index_vals)  # saving current q-table
 
'''
######################################## Q-learning testing #################################################


def createAgentsTesting():
    """Create one Q-learning agent per test environment with trained values.

    Each agent is built from its environment and the shared ``stateArray``,
    stored in the global ``agents`` dict, and given ``trained_vals[key]`` as
    its ``init_lamda_index``.
    """
    for armId, env in envs.items():
        agent = qLearningAgent(env, stateArray)
        agents[armId] = agent
        agent.init_lamda_index = trained_vals[armId]


def qLearningTestSelectArms():
    """Select the arms to activate when testing the Q-learning policy.

    Stores each test arm's ``_getLamda`` value for its current state in the
    global ``index`` dict, then returns the arms picked by ``getSelection``.
    """
    for armId in envs:
        agent = agents[armId]
        index[armId] = agent._getLamda(states[armId])

    return getSelection(index)

'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

stateArray = []
createStateTable()

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    zeroEnvs = {}
    total_reward = []
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    for x in ALLEPISODES:
        EPISODESTRAINED = x
        print(f'doing for episode count : {x}')
        currentTrainedModel = (f'../trainResults/qLearning/deadline_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{x}.npy')
        qLearningFileName = directory+(f'qLearningResults_arms_{ARMS}_run_{i}_schedule_{SCHEDULE}.csv')

        time = 0
        trained_vals = np.load(currentTrainedModel)
        envs = {}
        rewards = [] # records the rewards from all arms at each timestep
        index = {}
        agents = {}
        states = {}

        initialize()

        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()
        
        createAgentsTesting()
        resetEnvs()
        
        while True:
            arm = qLearningTestSelectArms()
            takeActionAndRecordNN(arm) # same function as NeurWIN
            time += 1
            if time == TIMELIMIT: # all arms done
                break
            
        total_reward.append((np.cumsum(rewards))[-1])

        print(f'finished Q-learning for trained episodes: {x}')
        print(f'Q-learning for episodes: {x}. rewards: {total_reward[-1]}')

    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(qLearningFileName, index=False)
    print(f'finished Q-learning deadline scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')
        
'''