'''
Deep reinforcement learning environment, modelled after the OpenAI Gym API,
for learning Whittle index values.
Episodes in the same mini-batch share the same initial load and channel-state
sequence so that their returns can be compared.
'''

import gym
import math
import time
import random
import datetime 
import numpy as np
import pandas as pd
from gym import spaces
from numpy.random import RandomState
#from stable_baselines.common.env_checker import check_env #this test throws errors and needs tensorflow 1.x with python 3.6. it's normal

class sizeAwareIndexEnv(gym.Env):
    '''Single-arm Gym environment for learning size-aware Whittle indices.

    The arm's internal state is ``[remaining load, channel state]``. Observations
    are returned with the load divided by ``maxLoad``. The channel state is 1
    (good) or 0 (bad).

    Each step costs ``HOLDINGCOST``, scaled by ``(1 + noise)``. The noise is
    deterministic and seeded by the current load. The reward returned is the
    negated cost.

    Activating the arm (action 1) reduces the load by ``r2`` on a good channel
    and by ``r1`` on a bad one. The load is clamped at 0. Staying passive
    (action 0) only updates the channel state.

    During training, all episodes started within one mini-batch of
    ``batchSize`` resets share the same initial load and the same channel-state
    sequence, so their returns can be compared.
    '''
    metadata = {'render.modes': ['human']}

    def __init__(self, seed, numEpisodes, HOLDINGCOST, Training, r1, r2, q,
        case, classVal, batchSize, load, maxLoad, episodeLimit, fixedSizeMDP, noiseVar):
        '''Build the environment and pre-draw the initial load of every episode.

        Args:
            seed: seeds Python's ``random`` module and NumPy's global RNG.
            numEpisodes: number of episodes whose initial loads are generated.
            HOLDINGCOST: per-step cost before noise is applied.
            Training: training mode flag. It affects initial loads and the
                termination rule in ``step``.
            r1: load removed by action 1 when the channel is bad (0).
            r2: load removed by action 1 when the channel is good (1).
            q: probability that the channel is good at each time step.
            case: 1 for homogeneous users, 2 for heterogeneous users
                (only validated here).
            classVal: must be 1 or 2 (only validated here).
            batchSize: number of consecutive resets that share one mini-batch.
            load: in training, initial loads are drawn from ``[1, load - 1]``.
                Otherwise every episode starts at ``load``.
            maxLoad: divisor used to normalize the load in observations.
            episodeLimit: training episodes end after this many steps.
            fixedSizeMDP: when training, never signal ``done`` and return
                reward 0 once the load is exhausted.
            noiseVar: passed as the ``scale`` (standard deviation) of the
                Gaussian cost noise.
        '''
        super().__init__()
        # NOTE(review): this instance attribute shadows gym.Env.seed(). It is
        # kept because callers may read env.seed.
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)

        assert(case in [1,2]) # case 1 is homogenous users. Case 2 is heterogenous users.
        assert(classVal in [1,2])
        self.time = 0             # total steps taken across all episodes
        self.numEpisodes = numEpisodes
        self.episodeTime = 0      # steps taken in the current episode
        self.currentEpisode = 0   # index into self.initialLoad
        self.holdingCost = float(HOLDINGCOST)
        self.case = case
        self.classVal = classVal
        self.noiseVar = noiseVar

        self.goodTransVal = r2
        self.badTransVal = r1
        self.goodProb = q
        self.arm = {0:[1, 1]} # initialize arm to max state of (1,1)
        self.maxLoad = maxLoad

        self.train = Training
        self.load = load
        self.batchSize = batchSize
        self.miniBatchCounter = 0
        self.episodeLimit = episodeLimit
        self.fixedSizeMDP = fixedSizeMDP

        if Training:
            # Each mini-batch gets one random initial load, repeated batchSize times.
            numBatches = int(np.ceil(self.numEpisodes / self.batchSize))
            loadVals = []
            for _ in range(numBatches):
                # randint's upper bound is exclusive, so loads lie in [1, load - 1].
                batchLoad = np.float32(np.random.randint(1, self.load))
                loadVals.extend([batchLoad] * batchSize)
            loadVals = np.array(loadVals)
        else:
            loadVals = np.tile(load, numEpisodes)
        self.initialLoad = loadVals
        self.currentLoad = loadVals

        self.observationSize = 2

        lowState = np.zeros(self.observationSize, dtype=np.float32)
        highState = np.full(self.observationSize, 1.0, dtype=np.float32)

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(lowState, highState, dtype=np.float32)

    # Disabled: would build self.stateArray (the table _findStateIndex searches).
    # def _makeStateTable(self):
    #     self.stateArray = []
    #     for x in range(self.maxLoad+1):
    #         self.stateArray.append([x/self.maxLoad,0])
    #         self.stateArray.append([x/self.maxLoad,1])
    #     self.stateArray = np.array(self.stateArray, dtype=np.float32)
    # self.noiseVector = np.random.normal(0, self.noiseVar, np.shape(self.stateArray)[0])

    def _findStateIndex(self, state):
        '''Return the row index of ``state`` in ``self.stateArray``.

        Nothing in this class currently builds ``self.stateArray`` (its
        builder is disabled above). Calling this raises ``AttributeError``
        unless a caller sets that attribute. It raises ``IndexError`` if
        ``state`` matches no row exactly.
        '''
        stateLocation = np.where((self.stateArray == state).all(axis=1))[0][0]
        return stateLocation

    def _channelStateAt(self, t):
        '''Return the channel state (1 good, 0 bad) at episode time ``t``.

        The pre-drawn sequence only covers ``episodeLimit + 1`` steps. Longer
        episodes can occur in evaluation mode or with ``fixedSizeMDP``. For
        those, the sequence is extended with fresh draws (at least doubling
        its length) instead of raising ``IndexError``. The extension is kept,
        so later episodes in the same mini-batch see the same values.
        '''
        shortfall = t + 1 - len(self.channelState)
        if shortfall > 0:
            extra = np.random.choice([1,0], max(shortfall, len(self.channelState)),
                                     p=[self.goodProb, 1 - self.goodProb])
            self.channelState = np.concatenate((self.channelState, extra))
        return self.channelState[t]

    def _stateNoise(self, noiseSeed):
        '''Return the multiplicative cost noise for a given integer seed.

        A fresh ``RandomState`` is used, so the same seed (i.e. the same
        state/action) always yields the same noise, and NumPy's global RNG
        is not consumed.
        '''
        return RandomState(noiseSeed).normal(0, self.noiseVar)

    def _calReward(self, action):
        '''Apply ``action`` to the arm and return ``(nextState, reward)``.

        ``nextState`` is the un-normalized ``[load, channel]`` as float32.
        ``reward`` is the negated holding cost multiplied by ``(1 + noise)``.

        Raises:
            ValueError: if ``action`` is not 0 or 1.
        '''
        if action == 1:
            # The noise seed is the load before the transition.
            noise = self._stateNoise(int(self.arm[0][0]))
            reward = self.holdingCost
            goodChannel = self._channelStateAt(self.episodeTime) == 1
            self.arm[0][0] -= self.goodTransVal if goodChannel else self.badTransVal
            self.arm[0][1] = 1.0 if goodChannel else 0.0
            if self.arm[0][0] <= 0:
                self.arm[0][0] = 0
        elif action == 0:
            # Passive actions use load + 1 as the seed, so their noise differs
            # from the active action's noise in the same state.
            noise = self._stateNoise(int(self.arm[0][0] + 1))
            reward = self.holdingCost
            self.arm[0][1] = 1.0 if self._channelStateAt(self.episodeTime) == 1 else 0.0
        else:
            raise ValueError(f'invalid action: {action!r}')

        nextState = np.array([self.arm[0][0], self.arm[0][1]], dtype=np.float32)
        reward = reward + noise*reward
        return nextState, -1*reward

    def _normalizeState(self, state):
        '''Divide the load entry of ``state`` by ``maxLoad`` in place and return ``state``.'''
        state[0] = state[0] / self.maxLoad
        return state

    def step(self, action):
        '''Take ``action`` and return ``(normalizedNextState, reward, done, info)``.

        Termination rules:
        - Training: the episode ends when the load reaches 0 or after
          ``episodeLimit`` steps.
        - Training with ``fixedSizeMDP``: ``done`` is never signalled, and the
          reward is 0 once the load is 0.
        - Evaluation: the episode ends only when the load reaches 0.

        The reward is a negative cost. ``info`` is always an empty dict.
        '''
        assert self.action_space.contains(action)
        assert action in [0,1]
        self.time += 1
        self.episodeTime += 1

        nextState, reward = self._calReward(action)

        loadExhausted = bool(nextState[0] == 0)
        if self.train:
            done = loadExhausted or (self.episodeTime == self.episodeLimit)
            if self.fixedSizeMDP:  # fixed number of arms setting (wrapping this env with other arms)
                done = False
                if loadExhausted:
                    reward = 0
        else:
            done = loadExhausted

        nextState = self._normalizeState(nextState)

        if done:
            self.currentEpisode += 1
            self.episodeTime = 0
            if not self.train:
                # Evaluation episodes all share one initial load; stay at index 0.
                self.currentEpisode = 0

        info = {}
        return nextState, reward, done, info # reward here is a negative cost

    def reset(self):
        '''Start an episode and return its normalized initial observation.

        The first reset of each mini-batch draws a new channel-state sequence
        of length ``episodeLimit + 1``. Later resets in the same mini-batch
        reuse it. The initial load is ``initialLoad[currentEpisode]``.
        '''
        if self.miniBatchCounter % self.batchSize == 0:
            self.channelState = np.random.choice([1,0], self.episodeLimit+1, p=[self.goodProb, 1 - self.goodProb])
            self.miniBatchCounter = 0

        self.arm[0][0] = self.initialLoad[self.currentEpisode]
        self.arm[0][1] = 1.0 if self._channelStateAt(self.episodeTime) == 1 else 0.0

        initialState = np.array([self.arm[0][0], self.arm[0][1]], dtype=np.float32)
        initialState = self._normalizeState(initialState)

        self.miniBatchCounter += 1
        return initialState

############################################################################
'''
For environment validation: the disabled script below runs a repeating sequence of
dummy actions and prints each next state, reward, and done flag so they can be checked
by hand against the expected values.
'''
'''
SEED = 20
np.random.seed(SEED)
random.seed(SEED)
numEpisodes = 10
HOLDINGCOST = 1

BADTRANS = 1
GOODTRANS = 4
GOODPROB = 0.5
CASE = 1
CLASSVAL = 1
noiseVar = 0.05
LOAD = 100
BATCHSIZE = 5

EPISODELIMIT = 200
FIXEDSIZEMDP = False

env = sizeAwareIndexEnv(numEpisodes=numEpisodes, HOLDINGCOST=HOLDINGCOST, seed=SEED, maxLoad = LOAD,
Training=True, r1 =BADTRANS, r2 = GOODTRANS, q=GOODPROB, noiseVar = noiseVar, 
case=CASE, classVal=CLASSVAL, load=LOAD, batchSize = BATCHSIZE, episodeLimit=EPISODELIMIT, fixedSizeMDP=FIXEDSIZEMDP)
observation = env.reset()

#check_env(env, warn=True)

#print(f'the loads are: {env.initialLoad}')
x = np.array([0,1,1]) # dummy actions
x = np.tile(x, 100000)
n_steps = np.size(x)

start = time.time()
for step in range(n_steps):
    nextState, reward, done, info = env.step(x[step])
    print(f'action: {x[step]} next state: {nextState}  reward: {reward} done: {done}')
    if done:
        print(f'Finished episode {env.currentEpisode}/{env.numEpisodes}')
        if env.currentEpisode < env.numEpisodes:
            nextState = env.reset()
        if env.currentEpisode == env.numEpisodes:
            break

print('------------------------------------------')
print(f'DONE. Time taken: {time.time() - start:.4f} seconds')
#print(f'the initial loads are: {env.initialLoad}')
#print(env.noiseVector)
#print(env.stateArray)
'''