import random

import numpy as np

from d4rl_alt.pointmaze import q_iteration
from d4rl_alt.pointmaze.gridcraft import grid_env, grid_spec

MAZE = (
    "###################\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOOOOOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "####O#########O####\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "#OOOOOOOOOOOOOOOOO#\\"
    + "#OOOOOOOO#OOOOOOOO#\\"
    + "###################\\"
)


# NLUDR -> RDLU
TRANSLATE_DIRECTION = {
    0: None,
    1: 3,  # 3,
    2: 1,  # 1,
    3: 2,  # 2,
    4: 0,  # 0,
}

RIGHT = 1
LEFT = 0
FORWARD = 2


class FourRoomController(object):
    def __init__(self):
        self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE))
        self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY)))

    def sample_target(self):
        return random.choice(self.reset_locations)

    def set_target(self, target):
        self.target = target
        self.env.gs[target] = grid_spec.REWARD
        self.q_values = q_iteration.q_iteration(
            env=self.env, num_itrs=32, discount=0.99
        )
        self.env.gs[target] = grid_spec.EMPTY

    def get_action(self, pos, orientation):
        if tuple(pos) == tuple(self.target):
            done = True
        else:
            done = False
        env_pos_idx = self.env.gs.xy_to_idx(pos)
        qvalues = self.q_values[env_pos_idx]
        direction = TRANSLATE_DIRECTION[np.argmax(qvalues)]
        # tgt_pos, _ = self.env.step_stateless(env_pos_idx, np.argmax(qvalues))
        # tgt_pos = self.env.gs.idx_to_xy(tgt_pos)
        # print('\tcmd_dir:', direction, np.argmax(qvalues), qvalues, tgt_pos)
        # infos = {}
        # infos['tgt_pos'] = tgt_pos
        if orientation == direction or direction == None:
            return FORWARD, done
        else:
            return get_turn(orientation, direction), done


# RDLU
TURN_DIRS = [
    [None, RIGHT, RIGHT, LEFT],  # R
    [LEFT, None, RIGHT, RIGHT],  # D
    [RIGHT, LEFT, None, RIGHT],  # L
    [RIGHT, RIGHT, LEFT, None],  # U
]


def get_turn(ori, tgt_ori):
    return TURN_DIRS[ori][tgt_ori]
