'''
Gym scheduling environment for n arms.
To be used with the REINFORCE algorithm.
Can also be used with any OpenAI baseline agent.
This environment acts as a wrapper environment for N sizeAwareIndex environments.
'''

import gym
import random
import itertools 
import numpy as np
import pandas as pd
import scipy.special
from gym import spaces
from envs.sizeAwareIndexEnv import sizeAwareIndexEnv
#from stable_baselines.common.env_checker import check_env #this package throws errors. it's normal

class sizeAwareIndexMultipleArmsEnv(gym.Env):
    '''
    Gym environment that wraps ``numArms`` sizeAwareIndexEnv arms.

    At every step the agent picks one discrete action; the action indexes into
    ``actionTable``, a list of 0/1 tuples with exactly ``scheduleArms`` ones.
    Entry ``i`` of the chosen tuple is forwarded as the action of arm ``i``.
    The observation is the concatenation of the first two state entries of
    every arm (``2 * numArms`` values) and the reward is the sum of the arm
    rewards.
    '''
    metadata = {'render.modes':['human']}

    def __init__(self, seed, numEpisodes, batchSize, train, noiseVar,
                 class1Arms, class2Arms, numArms, scheduleArms, case, episodeLimit):
        '''
        Build the wrapper and all of its arms.

        seed         -- seeds both ``random`` and ``np.random`` (global state).
        numEpisodes  -- forwarded to every arm.
        batchSize    -- forwarded to every arm.
        train        -- forwarded to every arm as ``Training``.
        noiseVar     -- forwarded to every arm.
        class1Arms   -- number of arms of class 1.
        class2Arms   -- number of arms of class 2; class1Arms + class2Arms must equal numArms.
        numArms      -- total number of arms.
        scheduleArms -- number of arms activated by each action.
        case         -- 1: homogeneous users (same parameters); 2: heterogeneous users (different holding costs).
        episodeLimit -- maximum number of steps per episode.
        '''
        super().__init__()
        random.seed(seed)
        np.random.seed(seed)
        assert(case in [1,2]) # first case is homogenous user (same parameters). Case 2 is heterogenous user (different holding costs).
        assert(class1Arms+class2Arms == numArms)
        self.time = 0
        self.numEpisodes = numEpisodes
        self.episodeTime = 0
        self.currentEpisode = 0
        self.numArms = numArms
        self.class1Arms = class1Arms
        self.class2Arms = class2Arms
        # Bind the constructor argument; reading self.noiseVar here would raise
        # AttributeError because the attribute does not exist yet.
        self.noiseVar = noiseVar

        self.scheduleArms = scheduleArms
        self.case = case
        self.episodeLimit = episodeLimit
        self.batchSize = batchSize
        self.state = []
        self.envs = {}
        self.q1 = 0.75 # good channel prob for class 1
        self.q2 = 0.1 # forwarded as ``q`` for class 2 arms
        self.train = train

        # The two cases differ only in the class-1 holding cost.
        if case == 1:
            self.holdingCost1 = 1.0
        elif case == 2:
            self.holdingCost1 = 5.0
        else:
            # Only reachable when asserts are stripped (python -O).
            print('Case not in case list. exiting...')
            raise SystemExit(1)
        self.holdingCost2 = 1.0
        self.goodTrans1 = 33600.0
        self.badTrans1 = 8400.0
        self.goodTrans2 = 33600.0
        self.badTrans2 = 8400.0
        self.loadClass1 = 1000000.0
        self.loadClass2 = 1000000.0

        self._createActionTable()
        # One seed per arm, drawn from the freshly seeded global generator.
        self.envSeeds = np.random.randint(0, 10000, size=self.numArms)

        self.observationSize = numArms*2 # each arm has two elements in the state definition
        self.low = np.zeros(self.observationSize, dtype=np.float32)
        self.high = np.full(self.observationSize, 1.0, dtype=np.float32)

        self.action_space = spaces.Discrete(len(self.actionTable))
        self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)

        self.miniBatchCounter = 0
        self.prevLoads = []
        self.prevChannelState = []
        self._setTheArms()

    def _createActionTable(self):
        '''Create the mapping from discrete action index to per-arm 0/1 vector.

        ``actionTable`` holds every 0/1 tuple of length ``numArms`` whose
        entries sum to ``scheduleArms`` (N choose M entries), in the order
        produced by ``itertools.product([0, 1], repeat=numArms)``.'''
        n = int(self.numArms)
        self.actionTable = [
            candidate
            for candidate in itertools.product([0, 1], repeat=n)
            if sum(candidate) == self.scheduleArms
        ]

    def _calReward(self, action):
        '''Step every arm according to ``action`` and aggregate the results.

        Returns ``(state, cumulativeReward, remaining)`` where ``state`` is a
        float32 array with the first two next-state entries of each arm,
        ``cumulativeReward`` is the sum of the arm rewards and ``remaining`` is
        the sum of every arm's first next-state entry (the remaining load used
        for the episode termination signal).'''
        actionVector = self.actionTable[action]

        cumulativeReward = 0
        state = []
        remaining = 0 # remaining total load for episode termination signal
        emptiedClasses = [] # classVal of every arm whose load is zero after this step

        for i, arm in self.envs.items():
            # actionVector[i] is 1 for a scheduled arm and 0 otherwise.
            nextState, reward, _done, _info = arm.step(actionVector[i])
            state.append(nextState[0])
            state.append(nextState[1])
            remaining += nextState[0]
            cumulativeReward += reward
            if nextState[0] == 0:
                emptiedClasses.append(arm.classVal)

        state = np.array(state, dtype=np.float32)

        # When every arm reports zero load, one extra holding cost is charged,
        # chosen by the class of the last arm.
        if len(emptiedClasses) == self.numArms:
            if emptiedClasses[-1] == 1:
                cumulativeReward += -1*self.holdingCost1
            else:
                cumulativeReward += -1*self.holdingCost2

        return state, cumulativeReward, remaining

    def step(self, action):
        ''' Standard Gym function for taking an action.

        Returns ``(nextState, reward, done, info)``. The episode is done when
        the total remaining load is zero or ``episodeLimit`` steps have been
        taken; in that case the episode counters of this environment and of
        every arm are advanced and their episode clocks are reset to zero.'''
        assert self.action_space.contains(action)
        self.time += 1
        self.episodeTime += 1

        nextState, reward, remainingLoad = self._calReward(action)
        done = bool((remainingLoad == 0)) or (self.episodeTime == self.episodeLimit)

        if done:
            self.currentEpisode += 1
            self.episodeTime = 0
            for arm in self.envs.values():
                arm.currentEpisode += 1
                arm.episodeTime = 0

        info = {}
        return nextState, reward, done, info

    def _buildArm(self, index, classVal):
        '''Construct arm ``index`` with the parameters of class ``classVal`` (1 or 2).'''
        if classVal == 1:
            holdingCost, q, load = self.holdingCost1, self.q1, self.loadClass1
            badTrans, goodTrans = self.badTrans1, self.goodTrans1
        else:
            holdingCost, q, load = self.holdingCost2, self.q2, self.loadClass2
            badTrans, goodTrans = self.badTrans2, self.goodTrans2

        return sizeAwareIndexEnv(
            seed=self.envSeeds[index], numEpisodes=self.numEpisodes,
            HOLDINGCOST=holdingCost, Training=self.train,
            r1=badTrans, r2=goodTrans, q=q, case=self.case, classVal=classVal,
            noiseVar=self.noiseVar, batchSize=self.batchSize,
            load=load, maxLoad=load, episodeLimit=self.episodeLimit,
            fixedSizeMDP=True)

    def _setTheArms(self):
        ''' Initialization of all arms that define the MDP. The learning algorithm is trained on this set of arms.

        While arms of both classes are still to be placed, the class of the
        next arm is drawn uniformly at random; once one class is exhausted the
        remaining arms all take the other class. The class of every arm is
        printed afterwards.'''
        num1 = self.class1Arms
        num2 = self.class2Arms

        for i in range(self.numArms):
            if num1 != 0 and num2 != 0:
                if np.random.choice([1,2]) == 1:
                    classVal = 1
                    num1 -= 1
                else:
                    classVal = 2
                    num2 -= 1
            else:
                classVal = 1 if num1 != 0 else 2
            self.envs[i] = self._buildArm(i, classVal)

        for arm in self.envs.values():
            print(arm.classVal)

    def reset(self):
        ''' Standard Gym function for supplying initial episode state.

        Resets every arm and returns the concatenation of the first two
        entries of each arm's initial state as a float32 array.'''
        state = []
        for arm in self.envs.values():
            vals = arm.reset()
            state.append(vals[0])
            state.append(vals[1])

        self.state = np.array(state, dtype=np.float32)

        return self.state

##############################################################################################
'''
For environment validation purposes, the below code checks if the nextstate, reward matches
what is expected given a dummy action.
'''
'''
numEpisodes = 10
SEED = 50
training = True
class1Arms = 2
class2Arms = 2
numArms = 4
case = 2
BATCHSIZE = 5
howManyToSchedule = 1
EPISODELIMIT = 1000
env = sizeAwareIndexMultipleArmsEnv(numEpisodes=numEpisodes, seed=SEED, batchSize = BATCHSIZE, train=training,
    class1Arms=class1Arms, class2Arms=class2Arms, numArms=numArms, scheduleArms = howManyToSchedule, case=case, episodeLimit = EPISODELIMIT)
#check_env(env, warn=True)
observation = env.reset()

x = np.array([0,1,2,3]) # dummy actions for N = 4, M = 1.
x = np.tile(x, 2000000)

n_steps = np.size(x)
for step in range(n_steps):
    nextState, reward, done, info = env.step(x[step])
    print(f'action: {x[step]} next state: {nextState}  reward: {reward} done: {done}')
    print("---------------------------------------------------------")
    if done:
        print(f'Finished episode {env.currentEpisode}/{env.numEpisodes}')
        if env.currentEpisode < env.numEpisodes:
            nextState = env.reset()
        if env.currentEpisode == env.numEpisodes:
            break

#print(f'The action table: {env.actionTable}')
'''