'''
Gym scheduling environment for n arms.
To be used with the REINFORCE algorithm.
Can be used with any OpenAI baseline agent.
This environment acts as a wrapper environment for N recoveringBandits Env.
'''


import gym
import time
import random
import itertools 
import numpy as np 
import pandas as pd 
import scipy.special
from gym import spaces
from envs.recoveringBanditsEnv import recoveringBanditsEnv
#from stable_baselines.common.env_checker import check_env #this package throws errors. it's normal

class recoveringBanditsMultipleArmsEnv(gym.Env):
    metadata = {'render.modes':['human']}

    def __init__(self, seed, numEpisodes, batchSize, train, numArms, 
        scheduleArms, noiseVar, maxWait, episodeLimit):
        super(recoveringBanditsMultipleArmsEnv, self).__init__()
        random.seed(seed)
        np.random.seed(seed)

        self.seed = seed
        self.time = 0
        self.numEpisodes = numEpisodes
        self.episodeLimit = episodeLimit
        self.observationSize = numArms # each arm has state-size = 1.
        self.numArms = numArms
        self.batchSize = batchSize
        self.train = train  
        self.maxWait = maxWait # maxWait is the maximum state value. Set the same for all arms
        self.noiseVar = noiseVar
        self.episodeTime = 0
        self.currentEpisode = 0
        self.scheduleArms = scheduleArms
        self._createActionTable()
        self.envs = {}

        lowState = np.ones(self.observationSize, dtype=np.float32)
        highState = np.array(np.tile(self.maxWait, self.observationSize), dtype=np.float32)

        self.action_space = spaces.Discrete(len(self.actionTable))
        self.state_space = spaces.Box(lowState, highState, dtype=np.float32)

        self.envSeeds = np.random.randint(0, 10000, size=self.numArms)
        classATheta = [10., 0.2, 0.0]
        classBTheta = [8.5, 0.4, 0.0]
        classCTheta = [7., 0.6, 0.0]
        classDTheta = [5.5, 0.8, 0.0]
        self.THETA = []

        for i in range(self.numArms): # environment length is a fixed N=4.
            self.THETA.append(classATheta)
            self.THETA.append(classBTheta)
            self.THETA.append(classCTheta)
            self.THETA.append(classDTheta)

        self._setTheArms()
    
    def _createActionTable(self):
        ''' function that creates a mapping of actions to take. Will be mapped with the action taken from the agent.'''
        self.actionTable = np.zeros(int(scipy.special.binom(self.numArms, self.scheduleArms)))
        n = int(self.numArms)
        self.actionTable  = list(itertools.product([0, 1], repeat=n))
        self.actionTable = [x for x in self.actionTable if not sum(x) != self.scheduleArms]

    def _setTheArms(self):
        ''' function that sets the N arms for training'''

        for i in range(self.numArms):

            self.envs[i] = recoveringBanditsEnv(seed=self.envSeeds[i], numEpisodes=1, episodeLimit=self.episodeLimit, train=True, 
        batchSize=self.batchSize, thetaVals=self.THETA[i], noiseVar=self.noiseVar, maxWait = self.maxWait)

    def _calculateReward(self, action):
        '''Function to calculate recovery function's reward based on supplied state.'''
        actionVector = self.actionTable[action]
        cumReward = 0
        state = []
        
        envCounter = 0
        for i in range(len(actionVector)):
            if actionVector[i] == 1:
                nextState, reward, done, info = self.envs[envCounter].step(1)
                state.append(nextState[0])
                cumReward += reward # only obtain reward from activated arm(s)
            elif actionVector[i] == 0:
                nextState, reward, done, info = self.envs[envCounter].step(0)
                state.append(nextState[0])
            envCounter += 1

        state = np.array(state, dtype=np.float32)

        return state, cumReward

    def step(self, action):
        ''' Standard Gym function for taking an action. Supplies nextstate, reward, and episode termination signal.'''
        assert self.action_space.contains(action)
        self.time += 1
        self.episodeTime += 1

        nextState, reward = self._calculateReward(action)

        done = bool(self.episodeTime == self.episodeLimit)

        if done:
            self.currentEpisode += 1
            self.episodeTime = 0

        info = {}

        return nextState, reward, done, info

    def reset(self):
        ''' Standard Gym function for supplying initial episode state.
        Episodes in the same mini-batch have the same trajectory for valid return comparison.'''
        self.state = []

        for i in self.envs:
            val = self.envs[i].reset()
            self.state.append(val[0])

        self.state = np.array(self.state, dtype=np.float32)
        #print(f'intial state: {self.state}')

        return self.state

##########################################################################################
'''
For environment validation purposes, the below code checks if the nextstate, reward matches
what is expected given a dummy action.
'''
'''
SEED = 25
x = np.array([0,1,2,3])
x = np.tile(x, 1000)
n_steps = np.size(x)

howManyToActivate = 1

env = recoveringBanditsMultipleArmsEnv(seed=SEED, numEpisodes=20, batchSize=5, train=True, numArms=4, 
        scheduleArms=1, noiseVar=0.0, maxWait=20, episodeLimit=10)
#env = deadlineMultipleArmsEnv(seed=SEED, numEpisodes=10, batchSize=5, train=True, numArms=4, processingCost=0.5, 
#        maxDeadline=12, maxLoad=9, newJobProb=0.7, episodeLimit=1, scheduleArms=howManyToActivate)

observation = env.reset()
print(env.actionTable)
#print(observation)

start = time.time()
for step in range(n_steps):
    nextState, reward, done, info = env.step(x[step])
    print(f'action: {x[step]} nextstate: {nextState}  reward: {reward} done: {done}')
    print("---------------------------------------------------------")
    if done:
        print(f'Finished episode {env.currentEpisode}/{env.numEpisodes}')
        if env.currentEpisode < env.numEpisodes:
            nextState = env.reset()
        if env.currentEpisode == env.numEpisodes:
            break
  

print(f'-------------------------------------\nDone. Time taken: {time.time() - start:.4f} seconds')
print(f'Action table: {env.actionTable}')
'''