"""
Based on 
https://github.com/aa14k/Exploration-in-RL.git
"""

import numpy as np

class Environment:
    '''Base class describing the interface of an RL environment.'''

    def __init__(self):
        '''The base class keeps no state, so there is nothing to set up.'''

    def reset(self):
        '''Placeholder reset; does nothing and returns None.'''
        return None

    def advance(self, action):
        '''
        Take a single step in the environment.

        The base implementation ignores the action and always hands back
        the same all-zero placeholder triple.
        Args:
            action - the action to apply (unused here)
        Returns:
            reward - double - reward
            newState - int - new state
            pContinue - 0/1 - flag for end of the episode
        '''
        placeholder = (0, 0, 0)
        return placeholder
    
    
def block_make_riverSwim(epLen=20, nState=5):
    '''
    Makes the Block-RiverSwim MDP.

    Action 0 moves one state to the left with probability 1 (staying in
    state 0 when already there). Action 1 moves stochastically; its
    transition probabilities depend on s % 3 for interior states and are
    set separately for the two end states.
    Args:
        epLen  - int - episode length handed to TabularMDP
        nState - int - number of states; must be of the form 3n+2 and >= 5
    Returns:
        riverSwim - Tabular MDP environment with R, P, states and eq_states
                    filled in, and reset() already called
    Raises:
        AssertionError - if nState is not of the form 3n+2 or is below 5
    '''
    # Validate before building anything. This is an explicit raise rather
    # than an `assert` statement so the check survives `python -O`; the
    # exception type and message are kept for existing callers.
    if not (nState % 3 == 2 and nState >= 5):
        raise AssertionError("nState should be 3n+2 and >= 5.")

    nAction = 2
    R_true = {}
    P_true = {}
    states = {}
    eq_states = {}
    for s in range(nState):
        states[s] = 0.0
        # Equivalence mapping: the two end states map to themselves, every
        # interior state maps to 1 (odd index) or 2 (even index).
        # NOTE(review): classes are split by parity while the transition
        # structure below is split by s % 3 - confirm this is intended.
        if s == 0 or s == nState - 1:
            eq_states[s] = s
        elif s % 2 == 1:
            eq_states[s] = 1
        else:
            eq_states[s] = 2
        # Start every (state, action) pair with zero reward, zero noise and
        # an all-zero transition vector that is filled in below.
        for a in range(nAction):
            R_true[s, a] = (0, 0)
            P_true[s, a] = np.zeros(nState)

    # Rewards: (mean, sd) pairs; only two (state, action) pairs pay out.
    R_true[0, 0] = (5/1000, 0)
    R_true[nState - 1, 1] = (1, 0)

    # Transitions for action 0: deterministic step left, clamped at state 0.
    for s in range(nState):
        P_true[s, 0][max(0, s-1)] = 1.

    # Transitions for action 1 on the interior states.
    for s in range(1, nState - 1):
        if s % 3 == 1:  # Internal State, Exit State of previous subMDP
            P_true[s, 1][min(nState - 1, s + 1)] = 0.35
            P_true[s, 1][min(nState - 1, s + 2)] = 0.6  # To exit state
            P_true[s, 1][max(0, s-1)] = 0.05
        elif s % 3 == 2:  # Internal State
            P_true[s, 1][min(nState - 1, s + 1)] = 0.35
            P_true[s, 1][s] = 0.6
            P_true[s, 1][max(0, s-1)] = 0.05
        else:  # Exit State
            P_true[s, 1][min(nState - 1, s + 1)] = 0.7  # Enter the next subMDP
            P_true[s, 1][s] = 0.3

    # Transitions for action 1 on the two end states.
    P_true[0, 1][0] = 0.4
    P_true[0, 1][1] = 0.6
    P_true[nState - 1, 1][nState - 1] = 0.6
    P_true[nState - 1, 1][nState - 2] = 0.4

    # Overwrite the environment's default tables with the ones built above.
    riverSwim = TabularMDP(nState, nAction, epLen)
    riverSwim.R = R_true
    riverSwim.P = P_true
    riverSwim.states = states
    riverSwim.eq_states = eq_states
    riverSwim.reset()

    return riverSwim

class TabularMDP(Environment):
    '''
    Tabular MDP
    R - dict by (s,a) - each R[s,a] = (meanReward, sdReward)
    P - dict by (s,a) - each P[s,a] = transition vector size S
    states    - dict - empty after construction; filled in by whoever
                builds the concrete MDP
    eq_states - dict - empty after construction; filled in by whoever
                builds the concrete MDP
    '''

    def __init__(self, nState, nAction, epLen):
        '''
        Initialize a tabular episodic MDP
        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
        Returns:
            Environment object
        '''
        super().__init__()

        self.nState = nState
        self.nAction = nAction
        self.epLen = epLen
        self.number_of_eq_states = 3

        self.timestep = 0
        self.state = 0

        # Default tables: every (state, action) pair has reward (mean 1,
        # sd 1) and a uniform transition vector over all states.
        self.R = {}
        self.P = {}
        self.states = {}
        # Always present, so reading it on a directly constructed instance
        # does not raise AttributeError.
        self.eq_states = {}
        for state in range(nState):
            for action in range(nAction):
                self.R[state, action] = (1, 1)
                self.P[state, action] = np.ones(nState) / nState

    def reset(self):
        "Resets the Environment"
        self.timestep = 0
        self.state = 0

    def advance(self, action):
        '''
        Move one step in the environment
        Args:
        action - int - chosen action
        Returns:
        reward - double - reward
        newState - int - new state
        episodeEnd - 0/1 - flag for end of the episode
        '''
        reward_spec = self.R[self.state, action]
        mean, sd = reward_spec[0], reward_spec[1]
        if sd < 1e-9:
            # Hack for no noise: return the mean without sampling
            reward = mean
        else:
            reward = np.random.normal(loc=mean, scale=sd)
        newState = np.random.choice(self.nState, p=self.P[self.state, action])

        # Update the environment
        self.state = newState
        self.timestep += 1

        episodeEnd = 0
        if self.timestep == self.epLen:
            episodeEnd = 1
            # The environment resets itself for the next episode, but the
            # returned newState is still the state reached by this step.
            self.reset()

        return reward, newState, episodeEnd

    def argmax(self, b):
        '''
        Index of the largest entry of b, ties broken uniformly at random.
        Args:
        b - array-like - values to compare (lists and tuples are accepted)
        Returns:
        index - int - position of a maximal entry
        '''
        values = np.asarray(b)
        return np.random.choice(np.where(values == values.max())[0])