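'''
Tests trained scheduling policies on the Recovering Bandits environment
and records their discounted cumulative rewards for comparison.
The NeurWIN test is the active section; the 1-step lookahead, REINFORCE and
Q-learning sections are disabled by wrapping them in triple-quoted strings.
'''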

import os 
import torch
import random
import numpy as np 
import scipy.special
import itertools 
import pandas as pd 
import matplotlib.pyplot as plt
import sys
sys.path.insert(0,'../')
from neurwin import fcnn 
from envs.recoveringBanditsEnv import recoveringBanditsEnv
from qlearning import qLearningAgent 

import operator
from reinforce import reinforceFcnn


###########################-CONSTANT VALUES-########################################
# One seed for Python's `random`, NumPy and PyTorch.
SEED = 30
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
STATESIZE = 1    # presumably the per-arm state dimension — confirm against fcnn
numEpisodes = 1  # passed as `numEpisodes` to every test environment
filesSeed = 30 # the seed value embedded in the trained-model directory names
################################-PARAMETERS-########################################
SCHEDULE = 25         # number of arms picked per timestep (see getSelection)
TIMELIMIT = 3000      # timesteps per test run; also passed to the envs as `episodeLimit`
NOISEVAR = 0          # passed to the envs as `noiseVar`
MAXZ = 20             # passed to the envs as `maxWait`; also bounds the Whittle index and state table
BATCHSIZE = 5         # passed to the envs as `batchSize`; part of checkpoint and result file names
EPISODESEND = 50000   # last training-episode checkpoint to evaluate
EPISODERANGE = 100    # spacing between evaluated checkpoints (0, 100, ..., EPISODESEND)
BETA = 0.999          # discount: the reward at step k is scaled by BETA**k
REINFORCELR = 0.001   # only appears in the REINFORCE checkpoint/result file names
RUNS = 1              # number of outer repetitions of each test loop
d = 1                 # only appears in the lookahead result file name
CASE = 5  # arm layout to test: 1, 2, 3, 4 or 5; case 1 is also the REINFORCE setting

NOISY = True

# Results folder and NeurWIN checkpoint root; the noisy variant only adds a
# sub-folder to each of the two base paths.
directory = '../testResults/recovering_env/'
WINNMODELDIR = '../trainResults/neurwin/recovering_bandits_env/'
if NOISY:
    directory += 'noisy_results/'
    WINNMODELDIR += 'noisy_version/'

# Make sure the results folder exists before anything is written into it.
if not os.path.exists(directory):
    os.makedirs(directory)

REINFORCEDIR = '../trainResults/reinforce/recovering_env/'

# Per-arm configuration, filled below: modelDirs[i] is the NeurWIN checkpoint
# directory of arm i and THETA[i] the parameter list passed to its environment
# as `thetaVals`.
modelDirs = []
THETA = []

# One parameter triple per recovery-function class.
classATheta = [10., 0.2, 0.0]
classBTheta = [8.5, 0.4, 0.0]
classCTheta = [7., 0.6, 0.0]
classDTheta = [5.5, 0.8, 0.0]

_CLASS_THETAS = {'A': classATheta, 'B': classBTheta, 'C': classCTheta, 'D': classDTheta}

# CASE -> (repetitions, class pattern). The pattern is repeated `repetitions`
# times, so arms of different classes are interleaved (A, B, C, D, A, B, ...).
_CASE_LAYOUTS = {
    1: (1, 'ABCD'),   # 4 arms: one of each class (also the REINFORCE setting)
    2: (3, 'AD'),     # 6 arms: three class A, three class D
    3: (3, 'ABCD'),   # 12 arms: three of each class
    4: (6, 'AD'),     # 12 arms: six class A, six class D
    5: (25, 'ABCD'),  # 100 arms: 25 of each class
}

if CASE not in _CASE_LAYOUTS:
    print(f'case not list. exiting...')
    # sys.exit instead of the bare exit(): the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist in every run mode.
    sys.exit(1)

_repeats, _pattern = _CASE_LAYOUTS[CASE]
for _ in range(_repeats):
    for _cls in _pattern:
        modelDirs.append(WINNMODELDIR+f'recovery_function_{_cls}/')
        THETA.append(_CLASS_THETAS[_cls])

# Total number of arms follows from the layout instead of being repeated by hand.
ARMS = len(THETA)


# Append the arm count of this run to the readme kept in the results folder.
# The context manager closes the file even if the write fails.
readMeFileName = (f'{directory}'+'readme.txt')
with open(readMeFileName, 'a') as readMeFile:
    readMeFile.write(f'\nNumber of arms: {ARMS}\n')


def initialize():
    """Build one test-mode recovering-bandits environment per arm.

    Stores the environments in the module-level ``envs`` dict under the keys
    ``0 .. ARMS-1``. Arm ``i`` is seeded with ``envSeeds[i]`` and receives
    ``THETA[i]`` as its ``thetaVals``; every other setting is shared.
    """
    global envSeeds, envs
    sharedSettings = dict(numEpisodes=numEpisodes, episodeLimit=TIMELIMIT, train=False,
                          batchSize=BATCHSIZE, noiseVar=NOISEVAR, maxWait=MAXZ)
    for armId in range(ARMS):
        envs[armId] = recoveringBanditsEnv(seed=envSeeds[armId], thetaVals=THETA[armId],
                                           **sharedSettings)

def initializeNN():
    """Build the test environment and load the trained index network of each arm.

    For every arm ``i`` in ``0 .. ARMS-1`` this creates a test-mode environment
    (seed ``envSeeds[i]``, ``thetaVals=THETA[i]``) and an ``fcnn`` whose weights
    are read from ``modelDirs[i] + currentTrainedModel``. Both are stored in the
    module-level dicts ``envs`` and ``agents`` under key ``i``.
    """
    # The original declaration named a non-existent `envsSeeds`; the seeds
    # actually read below live in `envSeeds`.
    global envSeeds, envs, agents

    for i in range(ARMS):
        env = recoveringBanditsEnv(seed=envSeeds[i], numEpisodes=numEpisodes, episodeLimit=TIMELIMIT, train=False,
                                   batchSize=BATCHSIZE, thetaVals=THETA[i], noiseVar=NOISEVAR, maxWait=MAXZ)
        # Use the shared constant rather than a second hard-coded 1.
        agent = fcnn(stateSize=STATESIZE)
        agent.load_state_dict(torch.load(modelDirs[i]+currentTrainedModel))
        envs[i] = env
        agents[i] = agent


def takeActionAndRecordNN(arms):
    """Step every environment once and log the discounted reward of the step.

    Arms listed in ``arms`` are stepped with action 1 (in the given order) and
    their rewards are summed; all remaining arms are then stepped with action 0
    and their rewards ignored. ``states`` is refreshed with each arm's next
    state, and the summed reward, scaled by ``BETA**time``, is appended to the
    module-level ``rewards`` list.
    """
    global rewards, time, states, envs

    stepReward = 0
    for chosen in arms:
        states[chosen], armReward, _, _ = envs[chosen].step(1)
        stepReward += armReward

    for armId in envs:
        if armId not in arms:
            states[armId], _, _, _ = envs[armId].step(0)

    rewards.append((BETA**time)*stepReward)


def getSelection(index):
    """Return the keys of the SCHEDULE largest values in ``index``.

    Keys are listed from the highest value down; among equal values the key
    that comes first in the dict's iteration order wins. ``index`` itself is
    not modified. Raises ValueError when ``index`` has fewer than SCHEDULE
    entries.
    """
    remaining = dict(index)
    picked = []
    while len(picked) < SCHEDULE:
        best = max(remaining, key=remaining.get)
        picked.append(best)
        remaining.pop(best)
    return picked

def resetEnvs():
    """Reset every environment in ``envs`` and store what ``reset()`` returns in ``states``."""
    global states, envs
    for armId, env in envs.items():
        states[armId] = env.reset()


def calculateIndexNeuralNetwork():
    """Pick the arms to activate from the per-arm network outputs.

    Each agent is evaluated on its arm's current state; the first element of the
    output is stored in ``indexNN``. The SCHEDULE arms with the largest values
    are returned and ``indexNN`` is replaced by a fresh empty dict.
    """
    global indexNN, states
    for armId, net in agents.items():
        output = net.forward(states[armId])
        indexNN[armId] = output.detach().numpy()[0]

    selected = getSelection(indexNN)
    indexNN = {}
    return selected

def selectArmsWhittleIndex():
    """Pick the SCHEDULE arms with the largest closed-form index.

    With z = ``env.arm[0]`` and f(z) = ``env._calReward(1, z)`` the index of an
    arm is (z+1)*f(z) - z*f(z+1), except at z == MAXZ where it is simply f(z).
    The module-level ``index`` dict is filled, used for the selection and then
    replaced by a fresh empty dict.
    """
    global states, index, envs
    for armId, env in envs.items():
        z = env.arm[0]
        if z == MAXZ:
            index[armId] = env._calReward(1, z)
        else:
            activeTerm = (z + 1) * env._calReward(1, z)  # reward when we activate the arm
            waitTerm = z * env._calReward(1, z + 1)
            index[armId] = activeTerm - waitTerm

    selected = getSelection(index)
    index = {}
    return selected


def activateArmsAndRecordReward(arms):
    """Step every environment once and log the discounted reward of the step.

    Arms in ``arms`` are stepped with action 1 and their rewards summed; the
    other arms are stepped with action 0. ``states`` receives every arm's next
    state. The summed reward is scaled by ``BETA**t`` (the module-level loop
    variable ``t``, not ``time``) and appended to ``rewards``.
    """
    global rewards, t, time, states, envs

    stepReward = 0
    for chosen in arms:
        states[chosen], armReward, _, _ = envs[chosen].step(1)
        stepReward += armReward

    for armId in envs:
        if armId not in arms:
            states[armId], _, _, _ = envs[armId].step(0)

    rewards.append((BETA**t)*stepReward)

def lookAhead1SelectArms():
    """Pick the SCHEDULE arms with the largest immediate activation reward.

    The score of an arm is ``env._calReward(1, env.arm[0])``, i.e. the reward
    the environment reports for activating it in its current state.

    Returns:
        list: keys of the selected arms, best first (see getSelection).
    """
    global envs, states
    # The scores live in their own local name so the module-level `rewards`
    # log is not shadowed; the previously built, never-read `stateVals` list
    # is gone.
    immediateRewards = {armId: env._calReward(1, env.arm[0]) for armId, env in envs.items()}
    return getSelection(immediateRewards)

def takeAction1LookAhead(arms):
    """Apply the lookahead choice to all environments and log the reward.

    Arms in ``arms`` are stepped with action 1 and their rewards summed; every
    other arm is stepped with action 0. The sum, scaled by ``BETA**t``, is
    appended to the module-level ``rewards`` list. ``states`` is not updated.
    """
    global rewards, envs, t

    gained = 0
    for chosen in arms:
        _, activationReward, _, _ = envs[chosen].step(1)
        gained += activationReward

    for armId in envs:
        if armId not in arms:
            envs[armId].step(0)

    rewards.append((BETA**t)*gained)  # discounting the reward

def getActionTableLength():
    """Enumerate every 0/1 activation vector with exactly SCHEDULE ones.

    Returns:
        list[tuple[int, ...]]: all length-ARMS tuples over {0, 1} whose entries
        sum to SCHEDULE, in ascending lexicographic order, which is the order in
        which ``itertools.product([0, 1], repeat=ARMS)`` would visit them. The
        position in this list is the action id, so the order must not change.
        Empty when SCHEDULE is not an integer in ``0 .. ARMS``.
    """
    numArms = int(ARMS)
    scheduleArms = SCHEDULE
    if scheduleArms != int(scheduleArms) or not 0 <= scheduleArms <= numArms:
        return []

    # Build only the C(ARMS, SCHEDULE) valid vectors instead of filtering all
    # 2**ARMS candidates. Choosing the *idle* positions in lexicographic order
    # yields the vectors in ascending order, because the first position where
    # two vectors differ holds a 0 in the smaller one.
    numIdle = numArms - int(scheduleArms)
    actionTable = []
    for idleArms in itertools.combinations(range(numArms), numIdle):
        action = [1] * numArms
        for arm in idleArms:
            action[arm] = 0
        actionTable.append(tuple(action))

    return actionTable

def resetREINFORCEEnvs():
    """Reset all environments and rebuild the joint REINFORCE state vector.

    The module-level ``state`` becomes a float32 array holding the first element
    of each environment's ``reset()`` result, in ``envs`` iteration order.

    The vector is built from scratch rather than appended to the previous
    ``state``, so the function can be called repeatedly: the old version failed
    on a second call (``state`` was by then an ndarray without ``append``) and
    kept stale entries if ``state`` was not emptied first.
    """
    global envs, state
    state = np.array([envs[key].reset()[0] for key in envs], dtype=np.float32)

def REINFORCETakeActionAndRecordReward():
    """Sample a joint action from the REINFORCE policy and apply it to all arms.

    The policy network maps the joint ``state`` to a probability vector over the
    rows of ``actionTable``; one row is drawn with ``np.random.choice``. Arms
    whose entry in that row is 1 are stepped with action 1 and their rewards
    summed, the others are stepped with action 0. ``state`` is replaced by the
    float32 vector of first elements of the next states, and the summed reward,
    scaled by ``BETA**time``, is appended to ``rewards``.
    """
    global rewards, state, reinforceAgent, envs, actionTable

    policy = reinforceAgent.forward(state).detach().numpy()
    sampled = np.random.choice(np.arange(len(actionTable)), p=policy)
    activation = actionTable[sampled]

    stepReward = 0
    observations = []
    for armId, flag in enumerate(activation):
        if flag == 1:
            nextObs, armReward, _, _ = envs[armId].step(1)
            stepReward += armReward
        else:
            nextObs, _, _, _ = envs[armId].step(0)
        observations.append(nextObs[0])

    state = np.array(observations, dtype=np.float32)

    rewards.append((BETA**time)*stepReward)

def initializeREINFORCE():
    """Create the per-arm test environments for the REINFORCE evaluation.

    The setup is the same as for the other tests (one test-mode environment per
    arm, seeded from ``envSeeds`` and stored in ``envs``), so this simply
    delegates to ``initialize()``.
    """
    global envSeeds, envs

    initialize()

###########################-    TESTING SETTINGS    -######################################
############################- 1-lookAhead Scheduling -#####################################

# the 1-lookahead scheduling policy 
'''

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

cumReward = []

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)

    time = 0
    envs = {}
    rewards = []
    states = {}

    print(f'doing the d-lookahead rewards\' value for run {i+1}')
    initialize()
    resetEnvs()
    selectedSequence = []
    
    for t in range(0, TIMELIMIT):

        arms = lookAhead1SelectArms()
        takeAction1LookAhead(arms)

    total_reward = (np.cumsum(rewards))[-1]
    print(f'total reward: {total_reward}')
    cumReward.append(total_reward)

data = {'run': range(RUNS), 'cumulative_reward':cumReward}

df = pd.DataFrame(data=data)
dLookAheadFileName = (f'{directory}'+f'dLookAheadResults_arms_{ARMS}_timeLimit_{TIMELIMIT}_d_{d}_case_{CASE}_schedule_{SCHEDULE}.csv')
df.to_csv(dLookAheadFileName, index=False)

############################- NeurWIN Scheduling -#####################################
'''
# NeurWIN test: for every saved training checkpoint, load the per-arm index
# networks, schedule SCHEDULE arms per step for TIMELIMIT steps, and record the
# discounted cumulative reward. One CSV is written per run.

# Re-seed so this section draws the same numbers no matter which of the other
# sections are enabled.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    # Checkpoints to evaluate: 0, EPISODERANGE, ..., EPISODESEND (inclusive).
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    zeroEnvs = {}
    total_reward = []
    for x in ALLEPISODES:
        EPISODESTRAINED = x
        print(f'doing for episode count : {x}')
        # Checkpoint path relative to each arm's model directory (see initializeNN).
        currentTrainedModel = (f'seed_{filesSeed}_lr_0.001_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        # NOTE(review): the file name has no run index, so with RUNS > 1 every
        # run writes to the same CSV and overwrites the previous one.
        nnFileName = directory+(f'nnIndexResults_arms_{ARMS}_batchSize_{BATCHSIZE}_case_{CASE}_timeLimit_{TIMELIMIT}_schedule_{SCHEDULE}.csv')


        # Fresh module-level state for this checkpoint; the helper functions
        # read and write these names as globals.
        time = 0
        envs = {}
        rewards = [] # records the rewards from all arms at each timestep
        index = {}
        agents = {}
        states = {}
        indexNN = {}
        initializeNN()

        # The environment objects built for checkpoint 0 are reused for every
        # later checkpoint (the dict copy is shallow); the environments that
        # initializeNN just created for x != 0 are discarded, only the freshly
        # loaded agents are kept.
        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        resetEnvs()
        
        # Run exactly TIMELIMIT scheduling steps; `time` is also the discount
        # exponent used inside takeActionAndRecordNN.
        while True:
            arm = calculateIndexNeuralNetwork()
            takeActionAndRecordNN(arm)
            time += 1
            if time == TIMELIMIT: # time horizon reached
                break
            
        # Last entry of the running sum = total discounted reward of the run.
        total_reward.append((np.cumsum(rewards))[-1])

        print(f'finished NN for trained episodes: {x}')
        print(f'NeurWIN for episodes: {x}. rewards: {total_reward[-1]}')

    # One row per evaluated checkpoint.
    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(nnFileName, index=False)
    print(f'finished NN recovering bandits scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')
'''
############################- REINFORCE Scheduling -#####################################

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

actionTable = getActionTableLength()

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    zeroEnvs = {}
    total_reward = []
    for x in ALLEPISODES:

        EPISODESTRAINED = x
        REINFORCEMODELDIR = REINFORCEDIR+(f'seed_{filesSeed}_lr_{REINFORCELR}_batchSize_{BATCHSIZE}_trainedNumEpisodes_{EPISODESTRAINED}/trained_model.pt')
        reinforceFileName = directory+(f'reinforceResults_arms_{ARMS}_batchSize_{BATCHSIZE}_lr_{REINFORCELR}_run_{i}_schedule_{SCHEDULE}.csv')
        time = 0
        
        envs = {}
        rewards = []
        state = []
        initializeREINFORCE()
        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        resetREINFORCEEnvs()
        reinforceAgent = reinforceFcnn() # fixed for now. 4 arms with state size 2: 2x4 = 8
        reinforceAgent.load_state_dict(torch.load(REINFORCEMODELDIR))
        
        while True:
            REINFORCETakeActionAndRecordReward()
            time += 1
            if time == TIMELIMIT: # all arms done
                break

        total_reward.append((np.cumsum(rewards))[-1])
        print(f'finished for trained episodes: {x}. cumulative_reward: {total_reward[-1]}')

    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(reinforceFileName, index=False)
    print(f'finished REINFOCE scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')



############################### Q-learning offline training ########################################
'''

def initializeQLearningTraining():
    """Build one training-mode environment per arm for Q-learning.

    Stores the environments in the module-level ``trainingEnvs`` dict under the
    keys ``0 .. ARMS-1``. Arm ``i`` is seeded with ``trainingEnvSeeds[i]`` and
    receives ``THETA[i]`` as ``thetaVals``; the environments are created with
    ``train=True`` and ``numEpisodes=EPISODESEND``.
    """
    global trainingEnvSeeds, trainingEnvs
    sharedSettings = dict(numEpisodes=EPISODESEND, episodeLimit=TIMELIMIT, train=True,
                          batchSize=BATCHSIZE, noiseVar=NOISEVAR, maxWait=MAXZ)
    for armId in range(ARMS):
        trainingEnvs[armId] = recoveringBanditsEnv(seed=trainingEnvSeeds[armId], thetaVals=THETA[armId],
                                                   **sharedSettings)

def createAgents():
    """Create one ``qLearningAgent`` per training environment, stored in ``agents`` under the same key."""
    global agents, trainingEnvs, stateArray

    for armId, env in trainingEnvs.items():
        agents[armId] = qLearningAgent(env, stateArray)

def updateAgentEnvs():
    """Point every agent's ``env`` attribute at the environment stored under the same key in ``envs``."""
    global agents, envs
    for armId, env in envs.items():
        agents[armId].env = env


def qLearningChooseArmsTraining():
    """Pick the arms to activate during Q-learning training.

    Each agent's ``_getLamda`` is evaluated on its arm's current state and the
    result stored in the module-level ``index`` dict; the SCHEDULE arms with
    the largest values are returned.
    """
    global states, trainingEnvs
    for armId in trainingEnvs:
        agent = agents[armId]
        index[armId] = agent._getLamda(states[armId])

    return getSelection(index)

def TakeActionAndRecordQLearning(arms, episode):
    """Let every Q-learning agent take its action for the current step.

    Agents of the arms in ``arms`` call ``_takeAction(1, episode)``, all others
    ``_takeAction(0, episode)``. Only the returned next state is kept (in
    ``states``); the returned rewards are not recorded here.
    """
    global rewards, time, trainingEnvs

    for chosen in arms:
        states[chosen], _ = agents[chosen]._takeAction(1, episode)

    for armId in trainingEnvs:
        if armId not in arms:
            states[armId], _ = agents[armId]._takeAction(0, episode)


def resetQLearningEnvs():
    """Reset every training environment and store what ``reset()`` returns in ``states``."""
    global states, trainingEnvs
    for armId, env in trainingEnvs.items():
        states[armId] = env.reset()


def createStateTable():
    """Set the module-level ``stateArray`` to the states 1 .. MAXZ as uint32.

    MAXZ is the max recovering function value. The array is built in one go
    rather than appended to the previous ``stateArray``, so the function no
    longer requires ``stateArray`` to be an empty list beforehand and can be
    called more than once (the old version failed on a second call because an
    ndarray has no ``append``).
    """
    global stateArray
    stateArray = np.arange(1, MAXZ + 1, dtype=np.uint32)

'''
# training process 
TIMELIMIT = 100 # for training, it's 100 
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

stateArray = []

episodeRange = np.arange(0, (EPISODESEND + EPISODERANGE), EPISODERANGE) 
trainingEnvs = {}
states = {}
new_states = {}
agents = {}
index = {} 
trainingEnvSeeds = np.random.randint(0, 10000, size=ARMS)

createStateTable()
initializeQLearningTraining()
createAgents()


for episode in range(EPISODESEND):
    resetQLearningEnvs()
    print(f'current episode: {episode+1}')

    if (episode == 0):
        index_vals = []
        for key in agents: 
            index_vals.append(agents[key].init_lamda_index)
        index_vals = np.array(index_vals)
        
        saveDir = (f'../trainResults/qLearning/recovering_bandits_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{episode}')
        if not os.path.exists(saveDir):
            os.makedirs(saveDir)
        
        np.save(saveDir, index_vals)  # saving current mapped indices

    for step in range(TIMELIMIT):

        arms = qLearningChooseArmsTraining()
        TakeActionAndRecordQLearning(arms, episode)

    if (episode+1 in episodeRange) or (episode+1 == EPISODESEND): # save lamda table up until current episode.
        print(f'saving for episode {episode+1}')
        index_vals = []
        for key in agents:
            index_vals.append(agents[key].init_lamda_index)
        index_vals = np.array(index_vals)

        saveDir = (f'../trainResults/qLearning/recovering_bandits_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{episode+1}')
        if not os.path.exists(saveDir):
            os.makedirs(saveDir)
        
        np.save(saveDir, index_vals)  # saving current q-table

'''
######################################## Q-learning testing #################################################


def createAgentsTesting():
    """Create the Q-learning agents used for testing.

    One ``qLearningAgent`` is built per environment in ``envs`` and stored in
    ``agents`` under the same key; its ``init_lamda_index`` is then overwritten
    with the matching entry of the loaded ``trained_vals``.
    """
    global agents, trained_vals, envs, stateArray
    for armId, env in envs.items():
        agent = agents[armId] = qLearningAgent(env, stateArray)
        agent.init_lamda_index = trained_vals[armId]

def qLearningTestSelectArms():
    """Pick the arms to activate when testing the Q-learning agents.

    Each agent's ``_getLamda`` is evaluated on its arm's current state and the
    result stored in the module-level ``index`` dict; the SCHEDULE arms with
    the largest values are returned.
    """
    global states, envs, index, agents

    for armId in envs:
        agent = agents[armId]
        index[armId] = agent._getLamda(states[armId])

    return getSelection(index)
'''

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

stateArray = []
createStateTable()

for i in range(RUNS):
    envSeeds = np.random.randint(0, 10000, size=ARMS)
    zeroEnvs = {}
    total_reward = []
    ALLEPISODES = np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE) 
    for x in ALLEPISODES:
        EPISODESTRAINED = x
        print(f'doing for episode count : {x}')
        currentTrainedModel = (f'../trainResults/qLearning/recovering_bandits_env/arms_{ARMS}_schedule_{SCHEDULE}/episode_{x}.npy')
        qLearningFileName = directory+(f'qLearningResults_arms_{ARMS}_run_{i}_schedule_{SCHEDULE}.csv')

        time = 0
        trained_vals = np.load(currentTrainedModel)
        envs = {}
        rewards = [] # records the rewards from all arms at each timestep
        index = {}
        agents = {}
        states = {}

        initialize()

        if x == 0:
            zeroEnvs = envs.copy()
        else:
            envs = zeroEnvs.copy()

        createAgentsTesting()
        resetEnvs()
        
        while True:
            arm = qLearningTestSelectArms()
            takeActionAndRecordNN(arm) # same function as NeurWIN
            time += 1
            if time == TIMELIMIT: # all arms done
                break

        total_reward.append((np.cumsum(rewards))[-1])

        print(f'finished Q-learning for trained episodes: {x}')
        print(f'Q-learning for episodes: {x}. rewards: {total_reward[-1]}')

    data = {'episode': np.arange(0, EPISODESEND+EPISODERANGE, EPISODERANGE), 'cumulative_reward':total_reward}
    df = pd.DataFrame(data=data)
    df.to_csv(qLearningFileName, index=False)
    print(f'finished Q-learning recovering bandits scheduling for run {i+1}') #case {CASE} for number of episodes: {EPISODESTRAINED}')
'''