import numpy as np

# Offsets to the four orthogonal neighbours of a cell, as (row, col) deltas.
move = [
    (-1, 0),  # previous row
    (1, 0),   # next row
    (0, -1),  # previous column
    (0, 1),   # next column
]

def get_road_blocks(w, h, difficulty):
    """Return the index expressions that select the road cells of a grid.

    Each entry is built with ``np.s_`` and is meant to be used directly as an
    index into an ``(h, w)`` array (rows first, columns second).

    Args:
        w: Grid width (number of columns). Note that width comes first.
        h: Grid height (number of rows).
        difficulty: One of ``'easy'``, ``'medium'`` or ``'hard'``.

    Returns:
        A list of index expressions, one per road: horizontal roads first,
        then vertical roads.

    Raises:
        KeyError: If ``difficulty`` is not one of the supported levels.
    """
    # assuming 1 is the lane width for each direction.
    road_blocks = {
        # One single-lane road per axis.
        'easy': [
            np.s_[h // 2, :],
            np.s_[:, w // 2],
        ],

        # One two-lane road per axis.
        'medium': [
            np.s_[h // 2 - 1: h // 2 + 1, :],
            np.s_[:, w // 2 - 1: w // 2 + 1],
        ],

        # Two two-lane roads per axis.
        'hard': [
            np.s_[h // 3 - 2: h // 3, :],
            np.s_[2 * h // 3: 2 * h // 3 + 2, :],

            np.s_[:, w // 3 - 2: w // 3],
            # Column slices must be derived from the width; using the height
            # here only gives the right columns on a square grid.
            np.s_[:, 2 * w // 3: 2 * w // 3 + 2],
        ],
    }

    return road_blocks[difficulty]

def goal_reached(place_i, curr, finish_points):
    """Tell whether ``curr`` is one of the finish points other than ``place_i``.

    The finish point stored at index ``place_i`` is left out of the search,
    so reaching it does not count as reaching a goal.
    """
    before = finish_points[:place_i]
    after = finish_points[place_i + 1:]
    return curr in before or curr in after


def get_add_mat(dims, grid, difficulty):
    """Build the auxiliary data needed to trace routes over the road grid.

    Args:
        dims: ``(h, w)`` size of the grid (rows, columns).
        grid: 2-D array of the road layout. It is copied, never modified.
        difficulty: ``'medium'`` (one junction) or ``'hard'`` (four junctions).

    Returns:
        A tuple ``(arrival_points, finish_points, road_dir, junction)``:

        - ``arrival_points``: ``(row, col)`` border cells where routes start.
        - ``finish_points``: ``(row, col)`` border cells where routes end,
          listed in the same order as ``arrival_points``.
        - ``road_dir``: copy of ``grid`` in which most lanes are overwritten
          with a distinct integer label; cells that are not relabelled keep
          the value they have in ``grid``.
        - ``junction``: array shaped like ``grid`` with 1 on junction cells
          and 0 elsewhere.

    Raises:
        ValueError: If ``difficulty`` is neither ``'medium'`` nor ``'hard'``.
    """
    # Fail early with a clear message; otherwise neither branch below runs and
    # the return statement dies on names that were never assigned.
    if difficulty not in ('medium', 'hard'):
        raise ValueError(
            "difficulty must be 'medium' or 'hard', got {!r}".format(difficulty))

    h, w = dims

    road_dir = grid.copy()
    junction = np.zeros_like(grid)

    if difficulty == 'medium':
        arrival_points = [(0, w//2 - 1),      # TOP
                          (h - 1, w//2),      # BOTTOM
                          (h//2, 0),          # LEFT
                          (h//2 - 1, w - 1)]  # RIGHT

        finish_points = [(0, w//2),           # TOP
                         (h - 1, w//2 - 1),   # BOTTOM
                         (h//2 - 1, 0),       # LEFT
                         (h//2, w - 1)]       # RIGHT

        # mark road direction; column w//2 - 1 is not relabelled, so outside
        # the two rows set here it keeps the value copied from grid.
        # The column is written last, so it overrides the row labels where
        # the roads cross.
        road_dir[h//2, :] = 2
        road_dir[h//2 - 1, :] = 3
        road_dir[:, w//2] = 4

        # mark the Junction
        junction[h//2 - 1:h//2 + 1, w//2 - 1:w//2 + 1] = 1

    else:  # 'hard'
        arrival_points = [(0, w//3 - 2),        # TOP-left
                          (0, 2*w//3),          # TOP-right

                          (h//3 - 1, 0),        # LEFT-top
                          (2*h//3 + 1, 0),      # LEFT-bottom

                          (h - 1, w//3 - 1),    # BOTTOM-left
                          (h - 1, 2*w//3 + 1),  # BOTTOM-right

                          (h//3 - 2, w - 1),    # RIGHT-top
                          (2*h//3, w - 1)]      # RIGHT-bottom

        finish_points = [(0, w//3 - 1),         # TOP-left
                         (0, 2*w//3 + 1),       # TOP-right

                         (h//3 - 2, 0),         # LEFT-top
                         (2*h//3, 0),           # LEFT-bottom

                         (h - 1, w//3 - 2),     # BOTTOM-left
                         (h - 1, 2*w//3),       # BOTTOM-right

                         (h//3 - 1, w - 1),     # RIGHT-top
                         (2*h//3 + 1, w - 1)]   # RIGHT-bottom

        # mark road direction; row h//3 - 2 is not relabelled, so outside the
        # columns set below it keeps the value copied from grid.
        road_dir[h//3 - 1, :] = 2
        road_dir[2*h//3, :] = 3
        road_dir[2*h//3 + 1, :] = 4

        # Columns are written after rows, so they override the row labels on
        # the cells where roads cross.
        road_dir[:, w//3 - 2] = 5
        road_dir[:, w//3 - 1] = 6
        road_dir[:, 2*w//3] = 7
        road_dir[:, 2*w//3 + 1] = 8

        # mark the Junctions
        junction[h//3 - 2:h//3, w//3 - 2:w//3] = 1
        junction[2*h//3:2*h//3 + 2, w//3 - 2:w//3] = 1

        junction[h//3 - 2:h//3, 2*w//3:2*w//3 + 2] = 1
        junction[2*h//3:2*h//3 + 2, 2*w//3:2*w//3 + 2] = 1

    return arrival_points, finish_points, road_dir, junction


def next_move(curr, turn, turn_step, start, grid, road_dir, junction, visited):
    """Pick the cell to step onto next while tracing a route.

    Every orthogonal neighbour of ``curr`` that is inside the grid, truthy in
    ``grid`` and not in ``visited`` is tested against the turning rules below.

    Args:
        curr: Current ``(row, col)`` cell.
        turn: Manoeuvre for the junction being handled; the rules below treat
            0 as straight, 1 as right and 2 as left.
        turn_step: Left-turn progress counter; compared against 1 and 2 here.
        start: Reference ``(row, col)`` cell of the current leg, used to check
            alignment and lane labels.
        grid: 2-D array; truthy cells are road.
        road_dir: Array of lane labels, indexed like ``grid``.
        junction: Array indexed like ``grid``; 1 marks junction cells.
        visited: Container of cells already used by the route.

    Returns:
        ``(next_cell, turn_prog, turn_completed)``. ``next_cell`` is the first
        accepted neighbour in ``move`` order; when several neighbours are
        accepted the others are ignored, but the two flags reflect every
        accepted neighbour.

    Raises:
        RuntimeError: If no neighbour is accepted.
    """
    n_rows, n_cols = grid.shape
    candidates = []
    left_in_progress = False
    leg_finished = False

    for step in move:
        nxt = (curr[0] + step[0], curr[1] + step[1])

        # Discard neighbours that are off the grid, not road, or already used.
        if not (0 <= nxt[0] <= n_rows - 1 and 0 <= nxt[1] <= n_cols - 1):
            continue
        if not grid[nxt] or nxt in visited:
            continue

        if junction[nxt] == junction[curr] == 1:
            # Moving from one junction cell to another.
            if turn in (0, 2) and (nxt[0] == start[0] or nxt[1] == start[1]):
                # Still in line with the entry lane: both going straight and
                # turning left begin by driving straight on.
                candidates.append(nxt)
                if turn == 2:
                    left_in_progress = True
            elif turn == 2 and turn_step == 1:
                # Second step of a left turn: swing across the junction.
                candidates.append(nxt)
                left_in_progress = True
            # Any other junction-to-junction move is a dead end.

        elif (junction[curr] and not junction[nxt]
              and turn == 2 and turn_step == 2
              and (abs(start[0] - nxt[0]) == 2 or abs(start[1] - nxt[1]) == 2)):
            # Leaving the junction at the end of a left turn.
            candidates.append(nxt)
            leg_finished = True

        elif junction[nxt] and not junction[curr]:
            # A junction lies ahead: drive onto it.
            candidates.append(nxt)

        elif turn == 1 and not junction[nxt] and junction[curr]:
            # Right turn: leave the junction straight away.
            candidates.append(nxt)
            leg_finished = True

        elif turn == 0 and junction[curr] and road_dir[nxt] == road_dir[start]:
            # Going straight: leave onto the lane with the entry lane's label.
            candidates.append(nxt)
            leg_finished = True

        elif road_dir[nxt] == road_dir[curr] and not junction[curr]:
            # Plain road: stay on the same lane.
            candidates.append(nxt)

    if not candidates:
        raise RuntimeError("next move should be of len 1. Reached ambiguous situation.")
    return candidates[0], left_in_progress, leg_finished



def get_routes(dims, grid, difficulty):
    '''
    Enumerate the routes that start at each arrival point of the road grid.

    For every arrival point, one route is traced per combination of turn
    choices (first junction, then second junction), stepping cell by cell
    until a finish point other than the arrival point's own is reached.

    args
        - dims: (h, w) size of the grid.
        - grid: 2-D array of the road layout; truthy cells are road. It is
          converted to an integer copy internally and is not modified.
        - difficulty: 'medium' or 'hard'.

    returns
        - routes: type list of list
        list for each arrival point of list of routes from that arrival point.
        Each route is the list of visited cells, arrival point first.

    raises
        - ValueError if difficulty is neither 'medium' nor 'hard'.
        - Exceptions raised by next_move are propagated unchanged.
    '''
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if difficulty not in ('medium', 'hard'):
        raise ValueError(
            "difficulty must be 'medium' or 'hard', got {!r}".format(difficulty))

    # Cast into a local integer copy. Assigning ``grid.dtype`` would instead
    # reinterpret the caller's buffer in place: it mutates the argument,
    # garbles non-integer data and fails when the item sizes differ.
    grid = np.asarray(grid).astype(int)

    arrival_points, finish_points, road_dir, junction = get_add_mat(dims, grid, difficulty)

    n_turn1 = 3 # 0 - straight, 1-right, 2-left
    # A second turn choice is only enumerated on the 'hard' layout.
    n_turn2 = 1 if difficulty == 'medium' else 3

    routes = []
    # routes for each arrival point
    for i, arrival in enumerate(arrival_points):
        paths = []
        # turn 1
        for turn_1 in range(n_turn1):
            # turn 2
            for turn_2 in range(n_turn2):
                total_turns = 0
                curr_turn = turn_1
                current = arrival
                start = current
                path = [current]
                visited = set()
                turn_step = 0
                while not goal_reached(i, current, finish_points):
                    visited.add(current)
                    current, turn_prog, turn_completed = next_move(
                        current, curr_turn, turn_step, start,
                        grid, road_dir, junction, visited)
                    # Count how far a left turn has progressed.
                    if curr_turn == 2 and turn_prog:
                        turn_step += 1
                    if turn_completed:
                        # A turn is done: switch to the second turn choice and
                        # measure the next one from the current cell.
                        total_turns += 1
                        curr_turn = turn_2
                        turn_step = 0
                        start = current
                    # keep going straight till the exit if 2 turns made already.
                    if total_turns == 2:
                        curr_turn = 0
                    path.append(current)
                paths.append(path)
                # early stopping, if first turn leads to exit
                if total_turns == 1:
                    break
        routes.append(paths)
    return routes