"""
testing a trained model for scheduling,
and comparing its performance with the 
size-aware whittle index heuristic.
"""
import os 
import torch
import random
import operator
import itertools 
import numpy as np
import pandas as pd 
import scipy.special
import sys
sys.path.insert(0,'../')
from neurwin import fcnn 
from qlearning import qLearningAgent
import matplotlib.pyplot as plt
from envs.sizeAwareIndexEnv import sizeAwareIndexEnv
from  reinforce import reinforceFcnn, REINFORCE # import REINFORCE

###########################-CONSTANT VALUES-########################################
STATESIZE = 2  # passed to fcnn(stateSize=...) for every arm's index network
numEpisodes = 1 # basically one iteration per run for testing
SEED = 30
filesSeed = 30  # seed embedded in the directory names of the trained checkpoints
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
################################-PARAMETERS-########################################
CASE = 1
SCHEDULE = 25  # number of arms to activate (schedule) in each timestep
REINFORCELR = 0.001 # selecting the REINFORCE learning rate it was trained on
BATCHSIZE = 5
ARMS = 100     # number of arms to test
BETA = 0.999   # per-timestep discount applied to the summed arm rewards

numClass1 = 50
numClass2 = 50
EPISODESEND = 1000000
EPISODERANGE = 10000
RUNS = 25
noiseVar = 0
NOISY = True # NOISY means test control policy from the noisy version

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if numClass1 + numClass2 != ARMS:
    raise ValueError(
        f'numClass1 + numClass2 ({numClass1 + numClass2}) must equal ARMS ({ARMS})')

# Per-case settings: only the holding costs and the good-channel
# probabilities differ between the supported cases.
if CASE == 1:
    HOLDINGCOST1 = 1
    HOLDINGCOST2 = 1
    GOODPROB1 = 0.75
    GOODPROB2 = 0.1
elif CASE == 2:
    HOLDINGCOST1 = 5
    HOLDINGCOST2 = 1
    GOODPROB1 = 0.5
    GOODPROB2 = 0.5
else:
    # Previously an unknown CASE left every constant below undefined and the
    # script failed later with a NameError far from the cause.
    raise ValueError(f'Unsupported CASE {CASE!r}; expected 1 or 2')

MAXLOAD1 = 1000000
MAXLOAD2 = 1000000
# Drawn once here (class 1 first, then class 2, so the seeded stream is
# consumed in the same order as before) and reused by every run.
TESTLOAD1 = np.random.randint(1, MAXLOAD1, size=ARMS)
TESTLOAD2 = np.random.randint(1, MAXLOAD2, size=ARMS)

GOODTRANS1 = 33600
BADTRANS1 = 8400

GOODTRANS2 = 33600
BADTRANS2 = 8400


# Result directory and NeurWIN checkpoint directories. The noisy and the
# noiseless experiments only differ by one path component.
if NOISY:
    directory = (f'../testResults/size_aware_env/noisy_results/case_{CASE}/')
    WINNMODEL1DIR = (f'../trainResults/neurwin/size_aware_env/noisy_version/case_{CASE}/class_1/')
    WINNMODEL2DIR = (f'../trainResults/neurwin/size_aware_env/noisy_version/case_{CASE}/class_2/')
else:
    directory = (f'../testResults/size_aware_env/case_{CASE}/')
    WINNMODEL1DIR = (f'../trainResults/neurwin/size_aware_env/case_{CASE}/class_1/')
    WINNMODEL2DIR = (f'../trainResults/neurwin/size_aware_env/case_{CASE}/class_2/')

# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(directory, exist_ok=True)

# Append a short description of this configuration to the results readme.
# The context manager closes the file even if the write fails.
readMeFileName = (f'{directory}'+'readme.txt')
with open(readMeFileName, 'a') as readMeFile:
    readMeFile.write(f'\nSelected case: {CASE}\nNumber of arms: {ARMS} \nNumber of class 1 arms: {numClass1} \nNumber of class 2 arms: {numClass2}\n-----------')


REINFORCEDIR = (f'../trainResults/reinforce/size_aware_env/case_{CASE}/')
##########################-- TESTING FUNCTIONS --#########################################

def calculateSecondaryIndex():
    """Fill the global ``goodIndex`` for every arm id listed in ``goodEnvs``.

    The value stored for an arm is ``holdingCost * goodTransVal`` divided by
    ``arm[0][0]`` of its environment (presumably the arm's remaining load --
    confirm against sizeAwareIndexEnv).
    """
    global goodEnvs, goodIndex
    for armId in goodEnvs:
        armEnv = envs[armId]
        weightedRate = armEnv.holdingCost * armEnv.goodTransVal
        goodIndex[armId] = weightedRate / armEnv.arm[0][0]
    
def getSelectionSizeAware(goodIndex, badIndex):
    """Pick up to ``SCHEDULE`` arms, preferring every good-channel arm.

    Args:
        goodIndex: dict mapping arm id -> index value for good-channel arms.
        badIndex: dict mapping arm id -> index value for bad-channel arms.

    Returns:
        List of arm ids: good-channel arms by descending index value first,
        then bad-channel arms by descending index value, truncated to
        ``min(SCHEDULE, len(goodIndex) + len(badIndex))`` entries. Neither
        input dict is modified.
    """
    byIndexValue = operator.itemgetter(1)
    # sorted() is stable even with reverse=True, so arms with equal index
    # values keep their dict order -- the same tie-break as repeatedly
    # taking max() and deleting the winner.
    ranked = sorted(goodIndex.items(), key=byIndexValue, reverse=True)
    ranked += sorted(badIndex.items(), key=byIndexValue, reverse=True)

    armsToActivate = max(SCHEDULE, 0)  # a non-positive SCHEDULE selects nothing
    return [armId for armId, _ in ranked[:armsToActivate]]

def getSelection(index):
    """Return the ids of the arms with the largest index values.

    Args:
        index: dict mapping arm id -> index value; it is left untouched.

    Returns:
        List of at most ``SCHEDULE`` arm ids ordered by descending index
        value. Ties go to the arm that comes first in the dict.
    """
    remaining = index.copy()
    picks = []

    for _ in range(min(SCHEDULE, len(remaining))):
        bestArm = max(remaining, key=remaining.get)
        picks.append(bestArm)
        remaining.pop(bestArm)

    return picks

def calculatePrimaryIndex():
    """Fill the global ``badIndex`` for every arm id listed in ``badEnvs``.

    The value stored for an arm is
    ``holdingCost / (goodProb * (goodTransVal / badTransVal - 1))``,
    read from the arm's environment.
    """
    global badEnvs, badIndex
    for armId in badEnvs:
        armEnv = envs[armId]
        holding = armEnv.holdingCost
        scale = armEnv.goodProb*((armEnv.goodTransVal/armEnv.badTransVal) - 1)
        badIndex[armId] = holding / scale

def initialize():
    """Create one test environment per arm and store it in the global ``envs``.

    While both classes still have unassigned arms, the class of the next arm
    is drawn with ``np.random.choice([1, 2])``; once one class is used up the
    remaining arms all get the other class. Each arm takes the next unused
    entry of its class's test-load array.
    """
    global numClass1, numClass2, TESTLOAD1, TESTLOAD2, envSeeds
    # class id -> (holding cost, r1, r2, q, test loads, max load)
    classSettings = {
        1: (HOLDINGCOST1, BADTRANS1, GOODTRANS1, GOODPROB1, TESTLOAD1, MAXLOAD1),
        2: (HOLDINGCOST2, BADTRANS2, GOODTRANS2, GOODPROB2, TESTLOAD2, MAXLOAD2),
    }
    unassigned = {1: numClass1, 2: numClass2}
    nextLoad = {1: 0, 2: 0}

    for armId in range(ARMS):
        if unassigned[1] != 0 and unassigned[2] != 0:
            armClass = np.random.choice([1,2]) # 1 is class1, 2 is class2
            unassigned[armClass] -= 1
        elif unassigned[1] != 0:
            armClass = 1
        else:
            armClass = 2

        holdingCost, badRate, goodRate, goodProb, loads, maxLoad = classSettings[armClass]
        # episode limit in testing doesn't matter
        envs[armId] = sizeAwareIndexEnv(
            numEpisodes=numEpisodes, HOLDINGCOST=holdingCost, seed=envSeeds[armId], Training=False,
            r1=badRate, r2=goodRate, q=goodProb, case=CASE, classVal=armClass,
            load=loads[nextLoad[armClass]], noiseVar=noiseVar, maxLoad=maxLoad,
            batchSize=5, episodeLimit=1000000, fixedSizeMDP=False)
        nextLoad[armClass] += 1

def initializeNN():
    """Create one test environment per arm and report each arm's class.

    Environments are stored in the global ``envs``. While both classes still
    have unassigned arms, the class of the next arm is drawn with
    ``np.random.choice([1, 2])``; afterwards the remaining arms all get the
    class that is left. Each arm takes the next unused entry of its class's
    test-load array.

    Returns:
        List with the class (1 or 2) of arm ``i`` at position ``i``, so the
        trained agents can be linked with their respective environment.
    """
    global numClass1, numClass2, TESTLOAD1, TESTLOAD2, envSeeds, envs
    # class id -> (holding cost, r1, r2, q, test loads, max load)
    classSettings = {
        1: (HOLDINGCOST1, BADTRANS1, GOODTRANS1, GOODPROB1, TESTLOAD1, MAXLOAD1),
        2: (HOLDINGCOST2, BADTRANS2, GOODTRANS2, GOODPROB2, TESTLOAD2, MAXLOAD2),
    }
    unassigned = {1: numClass1, 2: numClass2}
    nextLoad = {1: 0, 2: 0}
    classPerArm = []

    for armId in range(ARMS):
        if unassigned[1] != 0 and unassigned[2] != 0:
            armClass = np.random.choice([1,2]) # 1 is class1, 2 is class2
            unassigned[armClass] -= 1
        elif unassigned[1] != 0:
            armClass = 1
        else:
            armClass = 2

        holdingCost, badRate, goodRate, goodProb, loads, maxLoad = classSettings[armClass]
        envs[armId] = sizeAwareIndexEnv(
            numEpisodes=numEpisodes, HOLDINGCOST=holdingCost, seed=envSeeds[armId], Training=False,
            r1=badRate, r2=goodRate, q=goodProb, case=CASE, classVal=armClass,
            load=loads[nextLoad[armClass]], noiseVar=noiseVar, maxLoad=maxLoad,
            batchSize=5, episodeLimit=1000000, fixedSizeMDP=False)
        nextLoad[armClass] += 1
        # The returned sequence holds plain ints, never the numpy scalar drawn above.
        classPerArm.append(int(armClass))

    return classPerArm

def initializeAgents(seq):
    """Create one index network per arm and load its class's trained weights.

    Args:
        seq: sequence of arm classes; ``seq[i]`` is the class (1 or 2) of
            arm ``i``. Entries with any other value get no agent.

    The networks are stored in the global ``agents`` under the arm's
    position in ``seq``. Weights come from ``MODELNAME1`` for class 1 and
    ``MODELNAME2`` for class 2.
    """
    global MODELNAME1, MODELNAME2, agents

    # Each checkpoint is read from disk once and reused for every arm of
    # that class instead of calling torch.load per arm. load_state_dict
    # copies the values into the module, so the networks stay independent.
    loadedWeights = {}

    for armId, armClass in enumerate(seq):
        if armClass == 1:
            checkpoint = MODELNAME1
        elif armClass == 2:
            checkpoint = MODELNAME2
        else:
            continue

        agent = fcnn(stateSize=STATESIZE)
        if checkpoint not in loadedWeights:
            # Loaded lazily: a class absent from seq never touches its file.
            loadedWeights[checkpoint] = torch.load(checkpoint)
        agent.load_state_dict(loadedWeights[checkpoint])
        agents[armId] = agent

def resetEnvs():
    """Reset every environment in ``envs`` and record the state it returns.

    The value returned by each ``reset()`` is stored in the global
    ``states`` under the same arm id.
    """
    global states, envs
    for armId, armEnv in envs.items():
        states[armId] = armEnv.reset()

def calculateIndexNeuralNetwork():
    """Score every arm with its agent and choose the arms to activate.

    Each agent's ``forward`` is called on the arm's current state and the
    first element of the output is stored as that arm's index in the global
    ``indexNN``. The arms are then chosen with ``getSelection`` and
    ``indexNN`` is emptied for the next timestep.

    Returns:
        The list of arm ids returned by ``getSelection``.
    """
    global indexNN, states
    for armId, agent in agents.items():
        networkOutput = agent.forward(states[armId])
        indexNN[armId] = networkOutput.detach().numpy()[0]

    selected = getSelection(indexNN)

    # Start the next timestep with an empty index table.
    indexNN = {}
    return selected

def selectArmSizeAwareIndex():
    """Choose the arms to activate with the size-aware index policy.

    Every arm in ``envs`` is put into ``goodEnvs`` when its
    ``channelState[time]`` equals 1, otherwise into ``badEnvs``. Both index
    tables are then computed, the selection is made, and the four global
    working containers are emptied again.

    Returns:
        The list of arm ids returned by ``getSelectionSizeAware``.
    """
    global goodEnvs, badEnvs, goodIndex, badIndex, time
    for armId, armEnv in envs.items():
        bucket = goodEnvs if armEnv.channelState[time] == 1 else badEnvs
        bucket.append(armId)

    calculateSecondaryIndex() # calculate indices of good channels
    calculatePrimaryIndex() # calculate indices of bad channels
    chosen = getSelectionSizeAware(goodIndex, badIndex)

    # Clear the per-timestep working state.
    goodEnvs, badEnvs = [], []
    goodIndex, badIndex = {}, {}

    return chosen

def takeActionAndRecordNN(arms):
    """Advance every arm by one timestep under the NeurWIN selection.

    Args:
        arms: list of arm ids to activate (stepped with action 1). All other
            arms still in ``envs`` are stepped with action 0.

    Side effects on globals: ``states`` gets each stepped arm's next state;
    arms that report ``done`` after activation are removed from ``envs`` and
    ``agents``; the summed reward scaled by ``BETA**time`` is appended to
    ``rewards``; while arm 0 is still in ``envs``, the first entry of its
    state is appended to ``remainingLoad``.
    """
    global rewards, time, states, remainingLoad, envs

    stepReward = 0
    for activeArm in arms:
        observation, armReward, finished, _ = envs[activeArm].step(1)
        stepReward += armReward
        states[activeArm] = observation

        if finished:
            del envs[activeArm]
            del agents[activeArm]

    # Arms that were activated and are still running must not be stepped twice.
    for passiveArm in envs:
        if passiveArm not in arms:
            observation, armReward, _, _ = envs[passiveArm].step(0)
            states[passiveArm] = observation
            stepReward += armReward

    rewards.append((BETA**time)*stepReward)

    if 0 in envs:
        remainingLoad.append(states[0][0])

def takeActionAndRecordRewardSizeAwareIndex(arms):
    """Advance every arm by one timestep under the size-aware selection.

    Args:
        arms: list of arm ids to activate (stepped with action 1). All other
            arms still in ``envs`` are stepped with action 0.

    Arms that report ``done`` after activation are removed from the global
    ``envs``. The summed reward of all stepped arms, scaled by
    ``BETA**time``, is appended to the global ``rewards``.
    """
    global rewards, time, envs

    stepReward = 0
    for activeArm in arms:
        _, armReward, finished, _ = envs[activeArm].step(1)
        stepReward += armReward
        if finished:
            # the arm's job has terminated, so it leaves the system
            del envs[activeArm]

    # Arms that were activated and are still running must not be stepped twice.
    for passiveArm in envs:
        if passiveArm not in arms:
            _, armReward, _, _ = envs[passiveArm].step(0)
            stepReward += armReward

    rewards.append((BETA**time)*stepReward)

def getActionTableLength():
    """Build the table of joint actions that activate exactly SCHEDULE arms.

    Returns:
        List of tuples of length ``ARMS`` containing 0/1 flags, one tuple per
        way of choosing ``SCHEDULE`` active arms, in ascending lexicographic
        order (the order a filtered ``itertools.product([0, 1], repeat=ARMS)``
        would give). Empty when ``SCHEDULE`` is not an integer in
        ``[0, ARMS]``.

    The table has ``binom(ARMS, SCHEDULE)`` rows, so it is only practical for
    small arm counts.
    """
    scheduleArms = SCHEDULE
    n = int(ARMS)

    if not 0 <= scheduleArms <= n or scheduleArms != int(scheduleArms):
        return []

    # Enumerate only the valid vectors instead of filtering all 2**n of them.
    # Choosing the positions of the zeros in lexicographic order yields the
    # 0/1 vectors themselves in ascending lexicographic order.
    zeroCount = n - int(scheduleArms)
    actionTable = []
    for zeroSlots in itertools.combinations(range(n), zeroCount):
        row = [1] * n
        for slot in zeroSlots:
            row[slot] = 0
        actionTable.append(tuple(row))

    return actionTable

def initializeREINFORCE():
    """Create the REINFORCE test environments in the order of ``armsSeqeunce``.

    Arm ``i`` gets the class ``armsSeqeunce[i]`` (1 or 2) and the next unused
    entry of that class's test-load array. The environments are created with
    ``fixedSizeMDP=True`` and stored in the global ``envs``.

    Raises:
        ValueError: if ``armsSeqeunce`` holds a class other than 1 or 2.
            Previously such an entry silently reused the previous arm's
            environment, or raised a NameError for the first arm.
    """
    global TESTLOAD1, TESTLOAD2, envSeeds, envs, armsSeqeunce
    load1Index = 0
    load2Index = 0

    for i in range(ARMS):
        armClass = armsSeqeunce[i]
        if armClass == 1:
            env = sizeAwareIndexEnv(numEpisodes=numEpisodes, HOLDINGCOST=HOLDINGCOST1, seed=envSeeds[i], noiseVar = noiseVar,
            Training=False, fixedSizeMDP=True, r1=BADTRANS1, r2=GOODTRANS1, q=GOODPROB1, case=CASE, classVal=1,
            load=TESTLOAD1[load1Index], maxLoad = MAXLOAD1, batchSize=5, episodeLimit=1000000)
            load1Index += 1
        elif armClass == 2:
            env = sizeAwareIndexEnv(numEpisodes=numEpisodes, HOLDINGCOST=HOLDINGCOST2, seed=envSeeds[i], noiseVar = noiseVar,
            Training=False, fixedSizeMDP=True, r1=BADTRANS2, r2=GOODTRANS2, q=GOODPROB2, case=CASE, classVal=2,
            load=TESTLOAD2[load2Index], maxLoad = MAXLOAD2, batchSize=5, episodeLimit=1000000)
            load2Index += 1
        else:
            raise ValueError(f'armsSeqeunce[{i}] must be 1 or 2, got {armClass!r}')

        envs[i] = env

def REINFORCETakeActionAndRecordReward():
    """Sample a joint action from the REINFORCE policy and step every arm.

    The policy output for the global ``state`` is used as the probability
    vector over the rows of ``actionTable``. Arm ``i`` is stepped with 1 when
    the sampled row has a 1 at position ``i`` and with 0 otherwise. The first
    two entries of every arm's next state are concatenated into the new
    global ``state`` (a float32 array) and the summed reward is appended to
    ``rewards``.

    NOTE(review): unlike the other record functions in this file, the reward
    is appended without the ``BETA**time`` factor -- confirm this is intended.
    """
    global rewards, state, reinforceAgent, envs, actionTable

    action_probs = reinforceAgent.forward(state).detach().numpy()
    sampledRow = np.random.choice(np.arange(len(actionTable)), p=action_probs)
    chosenActions = actionTable[sampledRow]

    stepReward = 0
    observed = []
    for armId, flag in enumerate(chosenActions):
        armAction = 1 if flag == 1 else 0
        nextState, armReward, _, _ = envs[armId].step(armAction)
        observed.extend((nextState[0], nextState[1]))
        stepReward += armReward

    state = np.array(observed, dtype=np.float32)

    rewards.append(stepReward)

def resetREINFORCEEnvs():
    """Reset every environment and rebuild the joint REINFORCE state.

    The first two entries returned by each ``reset()`` are appended to the
    global ``state`` (expected to be a list on entry), which is then
    converted to a float32 numpy array.
    """
    global envs, state
    for armEnv in envs.values():
        initial = armEnv.reset()
        state.extend((initial[0], initial[1]))
    state = np.array(state, dtype=np.float32)


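##########################TESTING-STEP######################################
#######################  SIZE-AWARE INDEX ##################################
# NOTE: the size-aware index test below is disabled; it is wrapped in a
# string literal and therefore never executed.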
'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

cumReward = []
for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)

    time = 0
    envs = {}
    rewards = [] # records the negative total holding costs from all arms at each timestep
    goodEnvs = []
    badEnvs = []
    #index = {}
    goodIndex = {}
    badIndex = {}
    agents = {}
    states = {}
    indexNN = {}
    LOADS = []

    initialize()
    resetEnvs()

    while True:
        # size-aware index functions
        arms = selectArmSizeAwareIndex()
        takeActionAndRecordRewardSizeAwareIndex(arms)
        ###############################################
        time += 1
        if len(envs) == 0: # all arms done
            break

    total_reward = (np.cumsum(rewards))[-1]
    cumReward.append(total_reward)
    #print(time)
    print(f'Finished size aware index value for run {i+1} scheduling {SCHEDULE} arms')

data = {'run': range(RUNS), 'cumulative_reward':cumReward}
df = pd.DataFrame(data=data)
sizeAwareFileName = (f'{directory}'+f'sizeAwareIndexResults_arms_{ARMS}_schedule_{SCHEDULE}_arms.csv')
df.to_csv(sizeAwareFileName, index=False)
'''

#############################################################################
# NeurWIN neural network index testing


# Re-seed so this section gives the same results no matter which other
# sections of the file ran before it.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training checkpoints to evaluate: 0, EPISODERANGE, ..., EPISODESEND episodes.
ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE)

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    # One result file per run; it does not depend on the checkpoint.
    nnFileName = directory+(f'nnIndexResults_arms_{ARMS}_batchSize_{BATCHSIZE}_run_{i}_schedule_{SCHEDULE}_arms.csv')
    zeroEnvs = {}
    zeroSequence = []
    total_reward = []
    for x in ALLEPISODES:

        EPISODESTRAINED = x
        MODELNAME1 = WINNMODEL1DIR+(f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        MODELNAME2 = WINNMODEL2DIR+(f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')

        ############################## NN INDEX TEST ####################################
        time = 0
        envs = {}
        rewards = [] # records the negative total holding costs from all arms at each timestep
        goodEnvs = []
        badEnvs = []
        index = {}
        agents = {}
        states = {}
        indexNN = {}
        remainingLoad = []

        # Still called for every checkpoint so the random stream is consumed
        # exactly as before, even when its result is replaced below.
        agentSequence = initializeNN()

        # Every checkpoint of a run is evaluated on the environments created
        # for the first one. The class sequence is kept with them: the agents
        # used to be built from the freshly drawn sequence, which need not
        # match the classes of the reused environments.
        if x == 0:
            zeroEnvs = envs.copy()
            zeroSequence = agentSequence
        else:
            envs = zeroEnvs.copy()
            agentSequence = zeroSequence

        initializeAgents(agentSequence)
        resetEnvs()

        while True:
            # neural network index functions
            arms = calculateIndexNeuralNetwork()
            takeActionAndRecordNN(arms)
            time += 1
            if not envs: # all arms done
                break

        total_reward.append((np.cumsum(rewards))[-1])
        print(f'finished NN scheduling for episode {x}')
    data = {'episode': ALLEPISODES, 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(nnFileName, index=False)
    print(f'finished NN scheduling for run {i+1}')


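############################### REINFORCE TESTING ########################################
# NOTE: the REINFORCE test below is disabled; it is wrapped in a string
# literal and therefore never executed.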
'''
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

armsSeqeunce = [1,1,2,2] # sequence of class 1 and class 2 arms 
actionTable = getActionTableLength()

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    zeroEnvs = {}
    total_reward = []
    remainingLoad = 0
    for x in ALLEPISODES:

        EPISODESTRAINED = x
        REINFORCEMODELDIR = REINFORCEDIR+(f'seed_{filesSeed}_lr_{REINFORCELR}_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        reinforceFileName = directory+(f'reinforceResults_arms_{ARMS}_batchSize_{BATCHSIZE}_lr_{REINFORCELR}_run_{i}_schedule_{SCHEDULE}.csv')
        time = 0
        
        envs = {}
        rewards = []
        state = []
        initializeREINFORCE()
        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        resetREINFORCEEnvs()
        reinforceAgent = reinforceFcnn() 
        reinforceAgent.load_state_dict(torch.load(REINFORCEMODELDIR))
        
        while True:
            REINFORCETakeActionAndRecordReward()
            time += 1
            for b in envs:
                remainingLoad += envs[b].arm[0][0]
            if remainingLoad == 0:
                break
            remainingLoad = 0

        total_reward.append((np.cumsum(rewards))[-1])
        print(f'finished REINFORCE scheduling for episode {x}')

    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(reinforceFileName, index=False)
    print(f'finished REINFOCE scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')

'''
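############################### Q-learning offline training ########################################
# NOTE: the Q-learning section below is disabled (wrapped in a string literal)
# and unfinished: qLearningChooseArms has an empty `else:` branch, so the code
# would not parse if re-enabled, and TakeActionAndRecordQLearning refers to
# `new_state` and `key` before they are defined.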
'''

def createAgents():
    global agents, envs

    for key in range(arms):
        agents[key] = qLearningAgent(envs[key])

def updateAgentEnvs():
    global agents, envs 
    for key in range(arms):
        agents[key].env = envs[key]


def qLearningChooseArms():
    global states, time

    if time == 0:
        for key in agents:
            index[key] = agents[key]._getLamda(states[key])

    else:


    choice = getSelection(index)

    return choice

def TakeActionAndRecordQLearning(arms):
    global rewards, time, envs

    cumReward = 0
    # activating the selected arm(s)
    for arm in arms:
        nextState, reward, done, info = envs[arm].step(1)
        new_state[key] = nextState
        agents[key].currentReward = reward
        agents[arm].currentAction = 1

        cumReward += reward

        if done:
            del envs[arm] # remove the arm if its job terminates

    for key in envs:
        if key in arms:
            pass
        else:
            nextState, reward, done, info = envs[key].step(0)
            new_state[key] = nextState
            agents[key].currentAction = 0
            agents[key].currentReward = reward
            cumReward += reward

    rewards.append((BETA**time)*cumReward)

    for key in envs:
        agents[key]._updateQTable(states[key], new_state[key])


random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

for i in range(RUNS):
    time = 0
    envs = {}
    agents = {}
    index = {}
    total_rewards = [] # records the negative total holding costs from all arms at each timestep
    states = {}
    new_states = {}
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    
    initialize()
    createAgents()

    for episode in range(EPISODESEND):
        episode_rewards = []
        
        initialize()
        updateAgentEnvs()
        resetEnvs()

        for step in range(EPISODERANGE):

            arms = qLearningChooseArms()
            TakeActionAndRecordQLearning(arms)

            if len(envs) == 0: # all arms done (all jobs size equal zero)
                break

        print(f'finished q learning for episode {episode}')
        total_rewards.append((np.cumsum(episode_rewards))[-1])
        
    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_rewards}
    df = pd.DataFrame(data=data)
    qLearningFileName = directory+(f'qLearningResults_arms_{ARMS}_batchSize_{BATCHSIZE}_run_{i}_schedule_{SCHEDULE}_arms.csv')
    df.to_csv(qLearningFileName, index=False)
    print(f'finished Q-learning scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')




'''