import numpy as np
import random

class Environment:
    '''Base interface for the RL environments defined in this module.'''

    def __init__(self):
        '''Nothing to set up: the base environment carries no state.'''

    def reset(self):
        '''Placeholder reset; does nothing and returns None.'''
        return None

    def advance(self, action):
        '''
        Takes a single step in the environment.
        This placeholder ignores the action and reports zeros.
        Args:
            action
        Returns:
            reward - double - reward
            newState - int - new state
            pContinue - 0/1 - flag for end of the episode
        '''
        reward, newState, pContinue = 0, 0, 0
        return reward, newState, pContinue
    
    
def make_riverSwim(epLen=20, nState=5):
    '''
    Builds the benchmark RiverSwim chain as a tabular MDP.
    Args:
        epLen  - int - episode length
        nState - int - number of states in the chain
    Returns:
        riverSwim - Tabular MDP environment, already reset
    '''
    LEFT, RIGHT = 0, 1
    lastState = nState - 1
    env = TabularMDP(nState, 2, epLen)

    # Only two (state, action) pairs carry reward: a small one for going
    # left in the first state and a large one for going right in the last.
    env.R[0, LEFT] = 5/1000
    env.R[lastState, RIGHT] = 1

    # Left action: move one state down with certainty (state 0 stays put).
    states = np.arange(nState)
    env.P[states, LEFT, np.maximum(states - 1, 0)] = 1

    # Right action in the interior states: (offset from s, probability).
    rightMoves = ((1, 0.35), (0, 0.6), (-1, 0.05))
    for s in range(1, lastState):
        for offset, prob in rightMoves:
            env.P[s, RIGHT, s + offset] = prob

    # Right action at the two ends of the chain.
    env.P[0, RIGHT, 0] = 0.4
    env.P[0, RIGHT, 1] = 0.6
    env.P[lastState, RIGHT, lastState - 1] = 0.4
    env.P[lastState, RIGHT, lastState] = 0.6

    env.reset()
    return env

class TabularMDP(Environment):
    '''
    Tabular episodic MDP
    R - array (nState, nAction) - R[s,a] = reward returned for taking a in s
    P - array (nState, nAction, nState) - P[s,a] = transition vector size S
    '''

    def __init__(self, nState, nAction, epLen):
        '''
        Initialize a tabular episodic MDP
        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
        Returns:
            Environment object
        '''
        super().__init__()

        self.nState = nState
        self.nAction = nAction
        self.epLen = epLen

        self.timestep = 0
        self.state = 0

        # R and P start at zero; the caller fills them in afterwards.
        self.R = np.zeros((self.nState, self.nAction))
        self.P = np.zeros((self.nState, self.nAction, self.nState))

    def reset(self):
        "Resets the Environment"
        self.timestep = 0
        self.state = 0

    def advance(self, action):
        '''
        Move one step in the environment
        When the step completes the episode (timestep reaches epLen) the
        environment resets itself, but the sampled newState is still returned.
        Args:
        action - int - chosen action
        Returns:
        reward - double - reward
        newState - int - new state
        episodeEnd - 0/1 - flag for end of the episode
        '''
        reward = self.R[self.state, action]

        # P[state, action] is used as the sampling weights over next states.
        # range is indexable, so it serves as the population directly and no
        # list has to be rebuilt on every step.
        newState = random.choices(
            range(self.nState), weights=self.P[self.state, action]
        )[0]

        # Update the environment
        self.state = newState
        self.timestep += 1

        episodeEnd = 0
        if self.timestep == self.epLen:
            episodeEnd = 1
            self.reset()

        return reward, newState, episodeEnd

    def calculateOptimalValue(self):
        '''
        Backward induction over the horizon, maximizing over actions
        Returns:
        V - array (epLen + 1, nState) - V[h, s] = best value from s at step h,
            with V[epLen] = 0
        '''
        V = np.zeros((self.epLen + 1, self.nState))
        for h in range(self.epLen - 1, -1, -1):
            # (S, A, S) @ (S,) -> (S, A): expected next-step value per (s, a)
            V[h] = np.max(self.R + self.P @ V[h + 1], axis=-1)
        return V

    def evaluate(self, policy):
        '''
        Backward induction over the horizon for a fixed policy
        Args:
        policy - indexable by step h - policy[h] gives the integer action
                 for each state (array or plain list)
        Returns:
        V - array (epLen + 1, nState) - V[h, s] = value of following policy
            from s at step h, with V[epLen] = 0
        '''
        V = np.zeros((self.epLen + 1, self.nState))
        for h in range(self.epLen - 1, -1, -1):
            Q = self.R + self.P @ V[h + 1]
            # asarray lets policy[h] be a plain list as well as an ndarray
            actions = np.asarray(policy[h]).reshape(-1, 1)
            V[h] = np.take_along_axis(Q, actions, axis=1).squeeze(1)
        return V
