from types import SimpleNamespace
import numpy as np
class SimpleMaze():
    """
    A 3x5 grid world with two objectives, where the primary objective should
    be sacrificed both early and late in the episode, but not in between.

    Primary objective: reach the goal G in as few steps as possible.
    Secondary objective: avoid the secondary bad tiles (AA, aa) as much as
    possible.

    Primary reward (the ``reward`` element returned by ``step``):
        intrinsic_reward=True:  -1 on every non-terminal step, 0 on reaching G.
        intrinsic_reward=False:  0 on every non-terminal step, 1 on reaching G.

    Secondary reward (the ``secondary_reward`` element returned by ``step``):
        -20 on AA tiles, -2 on aa tiles, 1 on G, 0 everywhere else.

    Map (x = column, y = row; S is the start, G the goal):
        __________
        |__|G_|__| 4
        |AA|AA|__| 3
        |__|__|__| 2
        |__|aa|aa| 1
        |__|S_|__| 0
         0  1  2

    Actions (a move that would leave the grid leaves the agent in place):
           0            0 = up (+y), 1 = right (+x),
        3__|__1         2 = down (-y), 3 = left (-x)
           |
           2
    """

    WIDTH = 3        # number of columns; x ranges over 0..WIDTH-1
    HEIGHT = 5       # number of rows; y ranges over 0..HEIGHT-1
    START = (1, 0)   # position set by reset()
    GOAL = (1, 4)    # reaching this position ends the episode

    def __init__(self, intrinsic_reward=True, flip_rewards=False):
        """Create the maze; call ``reset()`` before the first ``step()``.

        Args:
            intrinsic_reward: selects the primary reward scheme (see class
                docstring).
            flip_rewards: stored on the instance but not applied by ``step``;
                the returned reward order is always (secondary, primary).
        """
        self.intrinsic = intrinsic_reward
        self.position = None  # set by reset()
        self.flip_rewards = flip_rewards
        spec = {"reward1_threshold": -1.05, "reward2_threshold": -5.1}
        self.spec = SimpleNamespace(**spec)

    def reset(self):
        """Place the agent on the start tile and return its one-hot observation."""
        self.position = list(self.START)
        return onehot_state_maker(self.position)

    def step(self, a):
        """Apply action ``a`` and return the resulting transition.

        Args:
            a: 0 = up (+y), 1 = right (+x), 2 = down (-y), 3 = left (-x).

        Returns:
            (observation, secondary_reward, reward, done, info): the one-hot
            encoding of the new position, the secondary reward, the primary
            reward, whether the goal was reached, and an empty dict. Note the
            secondary reward comes *before* the primary reward.

        Raises:
            RuntimeError: if called before ``reset()``.
            ValueError: if ``a`` is not one of 0, 1, 2, 3.
        """
        pos = self.position
        if pos is None:
            raise RuntimeError("reset() must be called before step()")

        # Moves are clamped to the grid: bumping into a wall is a no-op.
        if a == 0:
            pos[1] = min(pos[1] + 1, self.HEIGHT - 1)
        elif a == 1:
            pos[0] = min(pos[0] + 1, self.WIDTH - 1)
        elif a == 2:
            pos[1] = max(pos[1] - 1, 0)
        elif a == 3:
            pos[0] = max(pos[0] - 1, 0)
        else:
            raise ValueError("Invalid action")

        done = self._terminal(pos)
        if self.intrinsic:
            reward = 0 if done else -1
        else:
            reward = 1 if done else 0
        secondary_reward = self._secondary_reward(pos)

        # flip_rewards is not consulted: the order is always (secondary, primary).
        return onehot_state_maker(pos), secondary_reward, reward, done, {}

    def _secondary_reward(self, position):
        """Return the secondary reward for standing on ``position``."""
        x, y = position[0], position[1]
        if x < 2 and y == 3:   # AA tiles: (0, 3) and (1, 3)
            return -20
        if x > 0 and y == 1:   # aa tiles: (1, 1) and (2, 1)
            return -2
        if self._terminal(position):
            return 1
        return 0

    def _terminal(self, position):
        """Return True iff ``position`` (list, tuple or array) is the goal."""
        # tuple() lets lists, tuples and numpy arrays all compare cleanly.
        return tuple(position) == self.GOAL
        
def state_enumerator(state):
    """Map a grid position ``(x, y)`` to its row-major cell index.

    Rows are 3 cells wide and numbered from the bottom, so (0, 0) -> 0,
    (2, 0) -> 2, (0, 1) -> 3 and (2, 4) -> 14. Only ``state[0]`` and
    ``state[1]`` are read.
    """
    column = state[0]
    row = state[1]
    return row * 3 + column

# def state_maker(position):
#     return np.array([(position[0]-1)/2, (position[1]-5)/5])

def onehot_state_maker(position):
    """Return a length-15 float vector with a single 1.0 at ``position``'s cell.

    The cell index comes from ``state_enumerator``; all other entries are 0.0.
    """
    n_cells = 3 * 5
    cell = state_enumerator(position)
    encoding = np.zeros(n_cells)
    encoding[cell] = 1.
    return encoding
    
# def action_mapper(a):
#     if a == 0:
#         return [0,1]
#     elif a == 1:
#         return [1,0]
#     elif a == 2:
#         return [0,-1]
#     elif a == 3:
#         return [-1,0]
#     else:
#         raise Exception("Invalid action")    
        
# Direction vector (dx, dy) -> single-letter label.
_ACTION_STRINGS = {(0, 1): "U", (1, 0): "R", (0, -1): "D", (-1, 0): "L"}


def action_stringer(a):
    """Return the letter ("U", "R", "D", "L") for a direction vector ``a``.

    ``a`` may be any length-2 sequence (list, tuple or numpy array) of
    (dx, dy), e.g. ``[0, 1]`` -> "U".

    Raises:
        ValueError: if ``a`` is not one of the four unit direction vectors.
    """
    try:
        # tuple() normalises lists/arrays so the dict lookup works for all.
        return _ACTION_STRINGS[tuple(a)]
    except (TypeError, KeyError):
        raise ValueError("Invalid action") from None