from types import SimpleNamespace
import numpy as np
class SimpleMazeConcave:
    '''
    A grid world with two objectives in which the primary objective should be
    sacrificed both early and late in the episode, but not in between.

    Primary objective: reach the goal tile G.
    Secondary objective: avoid the secondary bad tiles (AA and aa) as much as
    possible.

    Primary reward (per step):
        intrinsic_reward=True  -> -1 on every non-terminal step, 0 on reaching G.
        intrinsic_reward=False ->  0 on every non-terminal step, 1 on reaching G.

    Secondary reward (per step, for the tile the agent lands on):
        -5 on an AA tile, -4 on an aa tile, +1 on G, 0 elsewhere.

    ``step`` returns ``(obs, primary, secondary, done, info)``, or
    ``(obs, secondary, primary, done, info)`` when ``flip_rewards=True``.
    Observations are one-hot vectors built by ``onehot_state_maker``.

    Actions (a move that would leave the grid leaves the agent in place):
        0
     3__|__1
        |
        2

    Combined maze (columns are x, rows are y; S = start, G = goal):
    ____________
    |__|G_|__|__| 4
    |__|aa|aa|aa| 3
    |__|__|__|__| 2
    |AA|AA|AA|__| 1
    |S_|__|__|__| 0
     0  1  2  3
    '''

    WIDTH = 4       # number of columns; x ranges over 0..3
    HEIGHT = 5      # number of rows; y ranges over 0..4
    GOAL = (1, 4)   # goal tile G as (x, y)

    def __init__(self, intrinsic_reward=True, flip_rewards=False):
        '''
        Args:
            intrinsic_reward: if True the primary reward is -1 per non-terminal
                step and 0 at the goal; otherwise 0 per step and 1 at the goal.
            flip_rewards: if True ``step`` returns the secondary reward before
                the primary reward.
        '''
        self.intrinsic = intrinsic_reward
        self.flip_rewards = flip_rewards
        self.position = None  # [x, y] once reset() has been called

        # Per-reward thresholds exposed via ``self.spec``.
        # NOTE(review): presumably success thresholds consumed by training
        # code; the secondary threshold is 0.95 unflipped but -3.05 flipped,
        # confirm that asymmetry is intended.
        primary_threshold = -6.1 if intrinsic_reward else 0.95
        if flip_rewards:
            spec = {"reward1_threshold": -3.05,
                    "reward2_threshold": primary_threshold}
        else:
            spec = {"reward1_threshold": primary_threshold,
                    "reward2_threshold": 0.95}
        self.spec = SimpleNamespace(**spec)

    def reset(self):
        '''Place the agent on the start tile S (0, 0) and return its one-hot observation.'''
        self.position = [0, 0]
        return onehot_state_maker(self.position)

    def _move(self, a):
        '''Apply action ``a`` to ``self.position`` in place, staying on the grid.

        Raises:
            ValueError: if ``a`` is not 0 (up), 1 (right), 2 (down) or 3 (left).
        '''
        # Dispatch with == (not a dict lookup) so numpy/tensor scalar actions
        # keep working exactly as with the original if/elif chain.
        if a == 0:
            dx, dy = 0, 1
        elif a == 1:
            dx, dy = 1, 0
        elif a == 2:
            dx, dy = 0, -1
        elif a == 3:
            dx, dy = -1, 0
        else:
            raise ValueError("Invalid action")
        # Clamp to the grid: moving into a wall leaves that coordinate unchanged.
        self.position[0] = min(max(self.position[0] + dx, 0), self.WIDTH - 1)
        self.position[1] = min(max(self.position[1] + dy, 0), self.HEIGHT - 1)

    def step(self, a):
        '''Take action ``a`` and return ``(obs, reward1, reward2, done, info)``.

        ``reward1``/``reward2`` are (primary, secondary), or (secondary, primary)
        when ``flip_rewards`` is set. ``info`` is always an empty dict. Stepping
        after ``done`` is not prevented; the caller is expected to ``reset()``.

        Raises:
            RuntimeError: if called before ``reset()``.
            ValueError: if ``a`` is not a valid action (0-3).
        '''
        if self.position is None:
            raise RuntimeError("step() called before reset()")
        self._move(a)

        done = self._terminal(self.position)
        if self.intrinsic:
            reward = 0 if done else -1
        else:
            reward = 1 if done else 0
        secondary_reward = self._secondary_reward(self.position)

        obs = onehot_state_maker(self.position)
        if self.flip_rewards:
            return obs, secondary_reward, reward, done, {}
        return obs, reward, secondary_reward, done, {}

    def _secondary_reward(self, position):
        '''Secondary reward for standing on ``position`` ([x, y]).'''
        x, y = position[0], position[1]
        if y == 1 and x < 3:    # AA tiles: row 1, columns 0-2
            return -5
        if y == 3 and x > 0:    # aa tiles: row 3, columns 1-3
            return -4
        if self._terminal(position):
            return 1
        return 0

    def _terminal(self, position):
        '''Return True iff ``position`` is the goal tile G.'''
        # tuple() makes lists, tuples and 1-D arrays all compare element-wise.
        return tuple(position) == self.GOAL
        
def state_enumerator(state, width=4, height=5):
    '''
    Map a grid position ``[x, y]`` to its row-major cell index ``y * width + x``.

    Args:
        state: ``[x, y]`` (any length-2 indexable) with 0 <= x < width and
            0 <= y < height.
        width: number of grid columns (default 4, the maze's width).
        height: number of grid rows (default 5, the maze's height).

    Returns:
        The cell index, in ``range(width * height)``.

    Raises:
        ValueError: if the position lies outside the grid. Without this check an
            out-of-range x silently aliases a different cell, and a negative
            result would wrap around when used to index an array.
    '''
    x, y = state[0], state[1]
    if not (0 <= x < width and 0 <= y < height):
        raise ValueError(f"position ({x}, {y}) is outside the {width}x{height} grid")
    return y * width + x

# def state_maker(position):
#     return np.array([(position[0]-1)/2, (position[1]-5)/5])

def onehot_state_maker(position):
    '''
    Encode a grid position as a one-hot float vector over the 4 x 5 = 20 cells.

    The single 1.0 entry sits at ``state_enumerator(position)``; all other
    entries are 0.0.
    '''
    num_cells = 4 * 5
    hot_index = state_enumerator(position)
    encoding = np.zeros(num_cells)
    encoding[hot_index] = 1.0
    return encoding
    
# def action_mapper(a):
#     if a == 0:
#         return [0,1]
#     elif a == 1:
#         return [1,0]
#     elif a == 2:
#         return [0,-1]
#     elif a == 3:
#         return [-1,0]
#     else:
#         raise Exception("Invalid action")    
        
def action_stringer(a):
    '''
    Return the single-letter name of a movement vector ``[dx, dy]``.

    ``[0, 1]`` -> "U", ``[1, 0]`` -> "R", ``[0, -1]`` -> "D", ``[-1, 0]`` -> "L".
    Any length-2 sequence is accepted (list, tuple, 1-D numpy array).

    Raises:
        ValueError: if ``a`` is not one of the four unit moves. ValueError is an
            ``Exception`` subclass, so existing ``except Exception`` callers
            keep working.
    '''
    try:
        move = tuple(a)
    except TypeError:  # scalars, None, ... are not movement vectors
        raise ValueError("Invalid action") from None
    for vector, name in (((0, 1), "U"), ((1, 0), "R"), ((0, -1), "D"), ((-1, 0), "L")):
        # Tuple equality compares element-wise with ==, so numpy scalars match too.
        if move == vector:
            return name
    raise ValueError("Invalid action")