import sys
import numpy as np
sys.path.append('./agents')
sys.path.append('./..')
sys.path.append('.')
from base import HL

# per room optimal policy
#dist_to_exit = np.array([[6, 5, 4, 3],
#                         [5, 4, 3, 2],
#                         [4, 3, 2, 1],
#                         [5, 4, 3, 2],
#                         [6, 5, 4, 3]])


class QlearningHL(HL):
    """Tabular Q-learning agent choosing high-level subgoals.

    The high-level state is a ``(room, door)`` pair, where ``door`` is one of
    the keys of ``self.doors`` ("N", "S", "W", "E") identifying the door
    through which the agent entered the room.  Q-values are stored in an
    array of shape ``(num_rooms, 4, 4)`` indexed as
    ``Q[room, door_index, subgoal]``.

    Transitions passed to :meth:`update` are 6-tuples
    ``(x, sg, r, nxt_x, done, sub_done)`` where ``x`` and ``nxt_x`` are
    ``(room, door)`` pairs, ``sg`` is the subgoal index, ``r`` the reward,
    ``done`` the episode-termination flag and ``sub_done`` the flag reported
    by the low level for the subgoal.
    """

    def __init__(self, env, num_rooms, alpha=0.3,
                 gamma=1, eps=0.5, mode='qlearning', bonus=0.0,
                 action_mask=True):
        """Create the agent with a zero-initialised Q-table.

        Args:
            env: environment; only ``env.get_available_action(room)`` is used,
                and only when ``action_mask`` is true.
            num_rooms: number of rooms (size of the first Q-table axis).
            alpha: learning rate.
            gamma: discount factor.
            eps: probability of picking a random subgoal (epsilon-greedy).
            mode: update rule, one of ``'qlearning'``, ``'stationary'`` or
                ``'optimistic'`` (see :meth:`update`).
            bonus: reward bonus used by the ``'optimistic'`` mode.
            action_mask: restrict subgoals to those the environment reports
                as available in the current room.
        """
        # NOTE(review): HL.__init__ is not invoked here; confirm the base
        # class needs no initialisation of its own.
        self.env = env
        self.n_actions = 4
        self.doors = {"N": 0, "S": 1, "W": 2, "E": 3}
        # The high-level state space also contains the door through which
        # the agent entered the room, hence the extra axis.
        self.Q = np.zeros((num_rooms, len(self.doors), self.n_actions))

        self.alpha = alpha
        self.eps = eps
        self.gamma = gamma
        self.mode = mode
        self.bonus = bonus
        self.action_mask = action_mask

    def _candidate_actions(self, room):
        """Return the subgoal indices considered in ``room``.

        With action masking this is whatever the environment reports
        (presumably a 1-D integer array -- verify against the environment);
        otherwise it is every subgoal index.
        """
        if self.action_mask:
            return self.env.get_available_action(room)
        return np.arange(self.n_actions)

    def select_subgoal(self, context, greedy=False):
        """Pick a subgoal for ``context`` with an epsilon-greedy policy.

        Args:
            context: ``(room, door)`` pair; ``door`` is a key of
                ``self.doors``.
            greedy: when true, never explore and return the best-valued
                candidate subgoal.

        Returns:
            The selected subgoal index, taken from the candidate actions.
        """
        room, door = context
        # Read the same Q-table slice that _qlearning_update writes to, so
        # values learned for every entry door are actually used.
        door_idx = self.doors[door]
        actions = self._candidate_actions(room)

        # The random draw is made even when ``greedy`` is set, which keeps
        # the random-number stream independent of the flag.
        explore = np.random.uniform(0, 1) < self.eps and not greedy
        if explore:
            return np.random.choice(actions)

        # argmax is taken over the candidates only; map the position back to
        # the actual subgoal index.
        best = np.argmax(self.Q[room, door_idx, actions])
        return actions[best]

    def update(self, transition):
        """Apply one learning step according to ``self.mode``.

        * ``'qlearning'``: always perform a Q-learning update.
        * ``'stationary'``: update only when the last transition element
          (the low-level done flag) is true.
        * ``'optimistic'``: add ``self.bonus`` to the reward when the
          subgoal flag is false, then update.

        Raises:
            ValueError: if ``self.mode`` is not one of the modes above.
        """
        if self.mode == 'qlearning':
            self._qlearning_update(transition)

        elif self.mode == 'stationary':
            low_done = transition[5]
            if low_done:
                self._qlearning_update(transition)

        elif self.mode == 'optimistic':
            # TODO: add bonus proportional to diff with opt. value fct
            # not sure what the opt. value function should look like...
            x, sg, r, nxt_x, done, sub_done = transition
            if not sub_done:
                r += self.bonus
            self._qlearning_update((x, sg, r, nxt_x, done, sub_done))

        else:
            # Raise instead of print + exit(): exit() is an interactive
            # helper and would terminate the process with a success status.
            raise ValueError(f"Unknown algorithm: {self.mode}")

    def _qlearning_update(self, transition):
        """Move ``Q[x, sg]`` toward the one-step Q-learning target.

        The target is ``r + gamma * max_a Q[nxt_x, a]``; the bootstrap term
        is dropped when ``done`` is true.  With action masking the maximum is
        taken over the subgoals available in the next room only.
        """
        x, sg, r, nxt_x, done, _ = transition
        room, door_idx = x[0], self.doors[x[1]]
        nxt_room, nxt_door_idx = nxt_x[0], self.doors[nxt_x[1]]

        candidates = self._candidate_actions(nxt_room)
        max_nxt_q = np.max(self.Q[nxt_room, nxt_door_idx, candidates])

        q_sg = self.Q[room, door_idx, sg]
        target = r + (1 - int(done)) * self.gamma * max_nxt_q
        self.Q[room, door_idx, sg] = q_sg + self.alpha * (target - q_sg)

    def get_V(self):
        """Placeholder: not implemented, returns ``None``."""
        pass

    def print_policy(self):
        """Print the Q-values of every room for door index 0 ("N") only."""
        print(self.Q[:, 0, :])

