#!/usr/bin/python

import numpy as np
import IPython

from IPython.core.debugger import set_trace

# Unit displacements on the grid, written as (dx, dy).
STAY = np.array([0, 0])
EAST = np.array([1, 0])
WEST = np.array([-1, 0])
NORTH = np.array([0, 1])
SOUTH = np.array([0, -1])

# One row per intended move, in the same row order as DIRECTION below.
# Columns are [left, intended_direction, right]: the move actually executed
# when the agent veers left of, follows, or veers right of the intended
# heading.
TRANSLATION_TABLE = [
    [WEST,  NORTH, EAST],   # intended NORTH
    [EAST,  SOUTH, WEST],   # intended SOUTH
    [SOUTH, WEST,  NORTH],  # intended WEST
    [NORTH, EAST,  SOUTH],  # intended EAST
    [STAY,  STAY,  STAY],   # intended STAY (no drift possible)
]

# Float (5, 2) array of the candidate moves; row i is the intended move of
# TRANSLATION_TABLE[i].
DIRECTION = np.array([NORTH, SOUTH, WEST, EAST, STAY], dtype=float)

class Agent(object):

    """Base class for an agent living on a wrap-around 2-D grid.

    The agent keeps its grid position plus the bookkeeping of the
    macro-action it is currently executing (`cur_action`,
    `cur_action_time_left`, `cur_action_done`). Subclasses implement `step`.

    Despite its name, `astar_move` performs no search: each call takes a
    single move, chosen as the candidate cell closest to the goal and then
    perturbed by transition noise.
    """

    def __init__(self, idx, grid_dim, agent_trans_noise=0.1):
        """Create an agent at a uniformly random cell.

        idx               -- identifier stored on the agent
        grid_dim          -- (x_len, y_len) size of the grid
        agent_trans_noise -- total probability that a move veers off the
                             intended direction (split evenly left/right)
        """
        self.idx = idx
        self.grid_dim = grid_dim
        self.x_len, self.y_len = self.grid_dim

        self.position = self.rand_position(*self.grid_dim)
        self.agt_trans_noise = agent_trans_noise

        # No macro-action is running yet, so the first step() call adopts
        # whatever action it is given.
        self.cur_action = None
        self.cur_action_time_left = 0.0
        self.cur_action_done = True

    def step(self, action, goal):
        """Advance the agent by one time step; implemented by subclasses."""
        raise NotImplementedError

    def astar_move(self, goal):
        """Take one noisy move towards `goal` and update termination flags.

        Every row of DIRECTION is applied to the current position (with
        wrap-around) and the candidate with the smallest Euclidean distance
        to `goal` is the intended move; ties are broken uniformly at random.
        The executed move is then drawn from the matching TRANSLATION_TABLE
        row: left / intended / right with probabilities
        noise/2, 1 - noise, noise/2.

        The macro-action is marked done once the agent is within 0.1 of
        `goal`.
        """
        moves = self.wrap_positions(DIRECTION + self.position)
        h = np.linalg.norm(goal - moves, axis=1)
        dest_idx = np.random.choice(np.where(h == h.min())[0], size=1)[0]

        noise = self.agt_trans_noise
        drift = np.random.choice(3, p=[noise / 2, 1 - noise, noise / 2])
        trans = TRANSLATION_TABLE[dest_idx][drift]

        # Wrap every axis by its own length. Wrapping both axes by x_len
        # corrupts the y coordinate on non-square grids.
        self.position = self.wrap_positions(self.position + trans)

        dist = np.linalg.norm(goal - self.position)
        if dist < 0.1:
            self.cur_action_done = True
            self.cur_action_time_left = 0.0

    ############################################################################
    # helper functions

    def _get_position_from_one_hot(self, goal):
        """Decode a flattened one-hot grid into an (x, y) position.

        A set entry at flat index i maps to x = i % x_len, y = i // x_len.
        If no entry is set the result is an empty array.
        """
        index = goal.nonzero()[0]
        X = index % self.x_len
        Y = index // self.x_len
        return np.concatenate([X, Y])

    def _get_position_from_normalized(self, goal):
        """Decode the position stored from index 2 onwards of `goal`.

        If all of those entries are -1 they are returned unchanged, so
        callers can test the result for the -1 sentinel. Otherwise they are
        scaled by x_len.
        """
        # NOTE(review): both coordinates are scaled by x_len; confirm that
        # the producer of `goal` also normalises y by x_len on non-square
        # grids.
        if all(goal[2:] == -1):
            return goal[2:]
        else:
            return goal[2:] * self.x_len

    @staticmethod
    def rand_position(x_range, y_range):
        """Return a random integer cell with 0 <= x < x_range, 0 <= y < y_range."""
        return np.array([np.random.randint(x_range), np.random.randint(y_range)])

    def wrap_positions(self, positions):
        """Wrap positions onto the grid, x modulo x_len and y modulo y_len.

        Accepts either a single (2,) position or an (N, 2) array of
        positions; the result has the same shape as the input.
        """
        return np.mod(positions, np.array([self.x_len, self.y_len]))


class Agent_v0(Agent):

    """Move_To_Target macro-action is terminated by either reaching the goal or
       not seeing the target. The low level controller automatically sets the
       latest observed target's position as the goal."""

    def __init__(self, idx, grid_dim, agent_trans_noise=0.1):
        super(Agent_v0, self).__init__(idx, grid_dim, agent_trans_noise=agent_trans_noise)

    def step(self, action, goal):
        """Advance the agent by one time step.

        action -- macro-action requested by the caller; it is only adopted
                  when the previous macro-action has finished, otherwise the
                  running one is continued.
        goal   -- array holding the target observation. Vectors longer than
                  2 * len(grid_dim) are treated as one-hot (the first
                  x_len * y_len entries are skipped), shorter ones as
                  normalised coordinates stored from index 2 onwards.
                  (presumably the leading entries describe the agent itself;
                  verify against the caller)

        Action 1 finishes immediately without moving. Any other action moves
        one step towards the decoded goal, or finishes the macro-action when
        the target is not observed.
        """
        if self.cur_action_done:
            self.cur_action = action
        else:
            action = self.cur_action
        self.cur_action_done = False
        self.cur_action_time_left = -1.0

        if action == 1:
            self.cur_action_done = True
            self.cur_action_time_left = 0.0
            return

        if len(goal) > len(self.grid_dim) * 2:
            goal = self._get_position_from_one_hot(goal[self.x_len*self.y_len:])
        else:
            goal = self._get_position_from_normalized(goal)

        # An unobserved target decodes to an empty array (one-hot) or to the
        # all -1 sentinel (normalised). Neither is a position on the grid,
        # so end the macro-action instead of moving towards it.
        if goal.size == 0 or all(goal == -1):
            self.cur_action_done = True
            self.cur_action_time_left = 0.0
        else:
            self.astar_move(goal)
   
class Agent_v1(Agent):

    """Agent that remembers the last observed target position.

       While the target is observed, every step moves towards the freshly
       observed position and stores it in `pre_goal`. When the observation
       drops out, the agent keeps moving towards `pre_goal`; if no position
       has been stored yet, the macro-action finishes instead."""

    def __init__(self, idx, grid_dim, agent_trans_noise=0.1):
        super(Agent_v1, self).__init__(idx, grid_dim, agent_trans_noise=agent_trans_noise)
        # (-1, -1) means "no target position remembered yet".
        self.pre_goal = np.array([-1,-1])

    def step(self, action, goal):
        """Advance one time step; see the class docstring for the policy.

        A new `action` is only adopted when the previous macro-action has
        finished. Action 1 finishes immediately without moving.
        """
        # Keep executing the running macro-action until it reports done.
        if not self.cur_action_done:
            action = self.cur_action
        else:
            self.cur_action = action
        self.cur_action_done, self.cur_action_time_left = False, -1.0

        if action == 1:
            self.cur_action_done, self.cur_action_time_left = True, 0.0
            return

        # Long vectors are one-hot encoded, short ones hold normalised
        # coordinates.
        is_one_hot = len(goal) > len(self.grid_dim) * 2
        if is_one_hot:
            target = self._get_position_from_one_hot(goal[self.x_len*self.y_len:])
        else:
            target = self._get_position_from_normalized(goal)

        if not all(target==-1):
            # Target observed: chase it and remember where it was.
            self.astar_move(target)
            self.pre_goal = target
            return

        # Target not observed: fall back to the remembered position, if any.
        if all(self.pre_goal==-1):
            self.cur_action_done, self.cur_action_time_left = True, 0.0
            return
        self.astar_move(self.pre_goal)

class Agent_v2(Agent):

    """Agent whose Move_To_Target macro-action ends when the goal is reached
       or when the target observation is missing.

       Each step moves towards the position decoded from the `goal` passed
       to that step.
       NOTE(review): an earlier description said the controller does not
       automatically follow the latest observed target position; confirm
       which behaviour is intended."""

    def __init__(self, idx, grid_dim, agent_trans_noise=0.1):
        super(Agent_v2, self).__init__(idx, grid_dim, agent_trans_noise=agent_trans_noise)

    def step(self, action, goal):
        """Advance one time step.

        A new `action` is only adopted when the previous macro-action has
        finished. Action 1 finishes immediately without moving; any other
        action moves towards the decoded goal, or finishes when the decoded
        goal is all -1.
        """
        # Keep executing the running macro-action until it reports done.
        if not self.cur_action_done:
            action = self.cur_action
        else:
            self.cur_action = action
        self.cur_action_done, self.cur_action_time_left = False, -1.0

        if action == 1:
            self.cur_action_done, self.cur_action_time_left = True, 0.0
            return

        # Long vectors are one-hot encoded, short ones hold normalised
        # coordinates.
        is_one_hot = len(goal) > len(self.grid_dim) * 2
        if is_one_hot:
            target = self._get_position_from_one_hot(goal[self.x_len*self.y_len:])
        else:
            target = self._get_position_from_normalized(goal)

        # A missing target observation ends the macro-action.
        if all(target==-1):
            self.cur_action_done, self.cur_action_time_left = True, 0.0
            return
        self.astar_move(target)
