'''
Deep reinforcement learning environment, modelled after the OpenAI Gym API,
used to learn Whittle index values.
From the paper: 
"Deadline Scheduling as Restless Bandits"
'''

import gym
import math
import time
import torch 
import random
import datetime 
import numpy as np
import pandas as pd
from gym import spaces
#from stable_baselines.common.env_checker import check_env #this package throws errors. it's normal. requires python 3.6.

class deadlineSchedulingEnv(gym.Env):
    '''
    Custom Gym environment modelled after the restless multi-armed bandit (RMAB)
    description in the paper "Deadline Scheduling as Restless Bandits".

    The environment represents one position (arm) in the N-length queue. The
    observation is the pair (T, B): T is the laxity (steps left until the job's
    deadline) and B is the remaining load. Action 1 activates (processes) the arm;
    action 0 leaves it passive.
    '''
    metadata = {'render.modes': ['human']}

    def __init__(self, seed, numEpisodes, episodeLimit, maxDeadline, maxLoad, newJobProb,
        processingCost, train, batchSize, noiseVar):
        '''
        Args:
            seed: seed applied to Python's `random`, NumPy's global RNG and torch.
                NOTE: stored as `self.seed`, which shadows `gym.Env.seed()` on instances.
            numEpisodes: number of episodes; stored for the caller, not used by this class.
            episodeLimit: steps per episode in training mode; also the length of the
                job arrays sampled in `reset`.
            maxDeadline: largest laxity T covered by the state table / observation space.
            maxLoad: largest load B covered by the state table / observation space.
            newJobProb: probability that a new job arrives when the position is empty.
            processingCost: cost c subtracted from the unit reward of processing.
            train: if truthy, episodes end after `episodeLimit` steps; otherwise never.
            batchSize: number of consecutive `reset` calls that reuse the same sampled jobs.
            noiseVar: used as the `scale` (standard deviation) of `np.random.normal`
                for the per-(state, action) reward noise, despite its name.
        '''
        super().__init__()
        self.seed = seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        self.observationSize = 2
        # arm[0] = [laxity T, load B, deadline d]; initialised to all ones.
        self.arm = {0: [1, 1, 1]}

        self.newJobProb = newJobProb
        self.noiseVar = noiseVar
        self.numEpisodes = numEpisodes
        self.currentEpisode = 0
        self.episodeTime = 0
        self.episodeLimit = episodeLimit
        self.train = train
        self.processingCost = processingCost

        self.maxDeadline = maxDeadline
        self.maxLoad = maxLoad
        self.batchSize = batchSize
        self.miniBatchCounter = 0
        self.loadIndex = 0

        lowState = np.zeros(self.observationSize, dtype=np.float32)
        highState = np.full(self.observationSize, 1.0, dtype=np.float32)

        self.action_space = spaces.Discrete(2)  # 0 = passive, 1 = activate
        # NOTE(review): kept unchanged for backward compatibility; its [0, 1] bounds
        # do not contain the raw (T, B) observations returned by step/reset.
        self.state_space = spaces.Box(lowState, highState, dtype=np.float32)
        self.createStateTable()
        # One Gaussian noise value per (state, action): indices [0, S) are used for
        # action 1 and [S, 2S) for action 0, where S is the number of table states.
        self.noiseVector = np.random.normal(0, self.noiseVar, np.shape(self.stateArray)[0]*2)

        # Standard Gym attribute with bounds that actually contain the observations.
        # Created after the noise draw so the seeded NumPy stream is consumed as before.
        self.observation_space = spaces.Box(
            low=np.zeros(self.observationSize, dtype=np.float32),
            high=np.array([self.maxDeadline, self.maxLoad], dtype=np.float32),
            dtype=np.float32)

    def _observation(self):
        ''' Return the arm's current (T, B) as a float32 observation array. '''
        return np.array([self.arm[0][0], self.arm[0][1]], dtype=np.float32)

    def _calRewardState(self, action):
        '''
        Apply `action` to the arm and return (nextState, reward).

        Rewards before noise, with c = processingCost and F(x) = 0.2 * x**2:
          - (T, B) == (0, 0): reward 0; a new job may arrive (see `_newArrival`).
          - T > 1: action 1 earns (1 - c) and removes one unit of load if B > 0,
            otherwise 0; action 0 earns 0. T decreases by one for either action.
          - T == 1: action 1 earns (1 - c) - F(B - 1) if B > 0, otherwise 0;
            action 0 earns -F(B). The position is then emptied to (0, 0).
        The returned reward is reward * (1 + noise), where noise is looked up in
        `self.noiseVector` for the *current* state and the chosen action.

        Raises:
            ValueError: if `action` is not 0 or 1, or the arm is in a state not
                covered by the cases above.
        '''
        laxity, load = self.arm[0][0], self.arm[0][1]

        if action == 1:
            noiseOffset = 0
        elif action == 0:
            noiseOffset = np.shape(self.stateArray)[0]
        else:
            raise ValueError(f'action must be 0 or 1, got {action!r}')
        noise = self.noiseVector[self._findStateIndex(self._observation()) + noiseOffset]

        if laxity == 0 and load == 0:
            reward = 0
            nextState = self._newArrival()
        elif load >= 0 and laxity > 1:
            if action == 1 and load > 0:
                reward = 1 - self.processingCost
                self.arm[0][1] = load - 1
            else:
                # Passive, or nothing left to process: no reward.
                reward = 0
            self.arm[0][0] = laxity - 1
            nextState = self._observation()
        elif load >= 0 and laxity == 1:
            if action == 0:
                reward = -0.2 * (load ** 2)  # -F(B): unfinished load is penalised
            elif load > 0:
                reward = (1 - self.processingCost) - 0.2 * ((load - 1) ** 2)  # (1 - c) - F(B - 1)
            else:
                # B == 0: activating an empty arm earns nothing (matches the T > 1 case).
                reward = 0
            self.arm[0][0] = 0
            self.arm[0][1] = 0
            nextState = self._observation()
        else:
            raise ValueError(f'unexpected arm state (T={laxity}, B={load})')

        reward = reward + noise * reward
        return nextState, reward

    def _findStateIndex(self, state):
        ''' Return the row of `self.stateArray` equal to `state` (IndexError if absent). '''
        stateLocation = np.where((self.stateArray == state).all(axis=1))[0][0]
        return stateLocation

    def createStateTable(self):
        '''
        Build `self.stateArray`: every (T, B) pair with 0 <= T <= maxDeadline and
        0 <= B <= maxLoad, ordered by B (outer) then T (inner), as float32 rows.
        '''
        self.stateArray = np.array(
            [[T, B] for B in range(self.maxLoad + 1) for T in range(self.maxDeadline + 1)],
            dtype=np.float32)

    def step(self, action):
        '''
        Standard Gym step.

        Returns:
            (nextState, reward, done, info). In training mode `done` becomes True
            when `episodeTime` reaches `episodeLimit`; `episodeTime` then restarts
            at 0 and `currentEpisode` is incremented. In test mode `done` is always False.
        '''
        assert self.action_space.contains(action)
        self.episodeTime += 1

        nextState, reward = self._calRewardState(action)

        done = bool(self.train and self.episodeTime == self.episodeLimit)
        if done:
            self.currentEpisode += 1
            self.episodeTime = 0

        info = {}
        return nextState, reward, done, info

    def _newArrival(self):
        '''
        With probability `newJobProb`, load the next pre-sampled job into the arm;
        otherwise leave the position empty. Returns the resulting observation.
        Requires `reset` to have been called (it creates the job arrays).
        '''
        job = np.random.choice([1, 0], p=[self.newJobProb, 1 - self.newJobProb])

        if job == 1:
            # Wrap around so long runs (test mode never ends an episode) do not
            # run past the `episodeLimit`-long arrays sampled in `reset`.
            jobIndex = self.loadIndex % len(self.deadline)
            self.arm[0][2] = self.deadline[jobIndex]
            self.arm[0][0] = self.timeUntilDeadline[jobIndex]
            self.arm[0][1] = self.load[jobIndex]
            self.loadIndex += 1
        else:
            self.arm[0][2] = 0
            self.arm[0][0] = 0
            self.arm[0][1] = 0

        return self._observation()

    def reset(self):
        '''
        Standard Gym reset: start a new episode and return the initial (T, B).

        Every `batchSize`-th call samples fresh job arrays of length `episodeLimit`;
        the calls in between replay the same jobs. Each episode starts with job 0.
        NOTE(review): `np.random.randint` excludes its upper bound, so deadlines lie
        in [1, maxDeadline - 1] and loads in [1, maxLoad - 1]; confirm this is intended.
        '''
        self.loadIndex = 0
        self.episodeTime = 0  # a mid-episode reset must not shorten the new episode

        if self.miniBatchCounter % self.batchSize == 0:
            self.deadline = np.random.randint(1, self.maxDeadline, size=self.episodeLimit)
            self.timeUntilDeadline = self.deadline.copy()
            self.load = np.random.randint(1, self.maxLoad, size=self.episodeLimit)  # at least a load of 1
            self.miniBatchCounter = 0

        # Always start the episode with the first sampled job.
        self.arm[0][2] = self.deadline[0]
        self.arm[0][0] = self.timeUntilDeadline[0]
        self.arm[0][1] = self.load[0]

        initialState = self._observation()
        self.loadIndex += 1
        self.miniBatchCounter += 1

        return initialState

#########################################################################################
'''
For environment validation: the code below steps the environment with a fixed sequence
of dummy actions and prints each next state, reward and done flag for manual inspection.
'''

'''
SEED = 20

env = deadlineSchedulingEnv(seed = SEED, numEpisodes=10, episodeLimit=40, maxDeadline=12,
maxLoad=9, newJobProb=0.7, train=True, processingCost = 0.5, batchSize = 5, noiseVar=0.05)

observation = env.reset()

#check_env(env, warn=True)

x = np.array([0,1])
x = np.tile(x, 10000)
n_steps = np.size(x)

start = time.time()
for step in range(n_steps):
    nextState, reward, done, info = env.step(x[step])
    print(f'action: {x[step]} nextstate: {nextState}  reward: {reward} done: {done}')
    print("---------------------------------------------------------")
    if done:
        print(f'Finished episode {env.currentEpisode}/{env.numEpisodes}')
        if env.currentEpisode < env.numEpisodes:
            nextState = env.reset()
        if env.currentEpisode == env.numEpisodes:
            break
  

print(f'-------------------------------------\nDone. Time taken: {time.time() - start:.4f} seconds')
'''