# Class for tabular Q-value computation (full Bellman-backup sweeps over a gridworld) with an arbitrary reward function

import numpy as np

class QLearningAgent:
    """Tabular Q-value solver for a gridworld with an arbitrary reward function.

    Despite the name, :meth:`solve` does not sample episodes. It repeatedly
    sweeps every grid cell not marked ``-1`` and applies a full Bellman
    optimality backup through the environment's single-step transition
    (``eval_env.get_single_transition_state``). Updates are written into
    ``self.Q`` in place, so cells later in a sweep already see values updated
    earlier in that same sweep.

    Args:
        eval_env: Environment exposing ``_layout`` (2-D array; cells equal to
            ``-1`` are skipped, presumably walls), ``get_single_transition_state``,
            ``get_state_list`` and ``plot_bf_function``.
        reward_fn: Callable receiving the next state exactly as returned by
            ``get_single_transition_state`` (indexable as ``(y, x)``) and
            returning a scalar reward.
        gamma: Discount factor applied to the best next-state Q-value.
        alpha: Learning rate. Stored for interface compatibility only; it is
            NOT used by :meth:`solve`, whose backups overwrite Q entirely.
        n_iters: Number of full sweeps performed by :meth:`solve`.
        num_actions: Number of discrete actions per cell (default 5:
            up, right, down, left, stay).
    """

    def __init__(self, eval_env, reward_fn, gamma=0.99, alpha=0.1, n_iters=1000, num_actions=5):
        self.eval_env = eval_env
        self.reward_fn = reward_fn  # called with the (y, x) next state
        self.gamma = gamma
        self.alpha = alpha  # unused by solve(); kept so existing callers still work
        self.n_iters = n_iters
        self.grid = eval_env._layout
        self.n, self.m = self.grid.shape[:2]
        self.num_actions = num_actions
        self.Q = np.zeros((self.n, self.m, self.num_actions))
        # Greedy action sets, filled by solve(). Created here so get_policy() and
        # plot_q_function() return empty policies (instead of raising
        # AttributeError) when called before solve().
        self.actions = {}

    def _free_cells(self):
        """Yield ``(y, x)`` for every cell whose layout value is not ``-1``, row-major."""
        for y in range(self.n):
            for x in range(self.m):
                if self.grid[y, x] != -1:
                    yield y, x

    def _greedy_actions(self, y, x):
        """Return every action index whose Q-value equals the maximum at ``(y, x)``.

        Uses exact float equality, so only bit-identical ties are kept.
        """
        q_row = self.Q[y, x]
        return np.flatnonzero(q_row == q_row.max()).tolist()

    def get_next_state(self, state, action):
        """Return the environment's deterministic successor of ``state`` under ``action``."""
        return self.eval_env.get_single_transition_state(state, action)

    def solve(self):
        """Run ``n_iters`` in-place backup sweeps, then cache the greedy action sets.

        Each backup sets ``Q[y, x, a] = reward_fn(s') + gamma * max_a' Q[s', a']``,
        where ``s'`` is the successor of ``(y, x)`` under ``a``. Afterwards,
        ``self.actions[(y, x)]`` lists all maximizing actions for every
        non-``-1`` cell.
        """
        for _ in range(self.n_iters):
            for y, x in self._free_cells():
                for a in range(self.num_actions):
                    next_state = self.get_next_state((y, x), a)
                    reward = self.reward_fn(next_state)
                    max_next_q = np.max(self.Q[next_state[0], next_state[1]])
                    self.Q[y, x, a] = reward + self.gamma * max_next_q

        # Extract all optimal (tied-maximum) actions from each state.
        self.actions = {
            (y, x): self._greedy_actions(y, x) for y, x in self._free_cells()
        }

    def get_policy(self, state):
        """Return the greedy action list for ``state`` (a ``(y, x)`` tuple).

        Returns ``[]`` for ``-1`` cells, unknown states, and before :meth:`solve`.
        """
        return self.actions.get(state, [])

    def plot_q_function(self, work_dir, task_str="", step=0):
        """Plot V(s) = max_a Q(s, a) and the greedy actions through the environment.

        Every grid cell is included. Cells marked ``-1`` are never updated by
        :meth:`solve`, so they report V = 0 and an empty action list.
        """
        state_list = self.eval_env.get_state_list()
        cells = [(y, x) for y in range(self.n) for x in range(self.m)]
        v_list = {cell: np.max(self.Q[cell]) for cell in cells}
        a_list = {cell: self.get_policy(cell) for cell in cells}
        self.eval_env.plot_bf_function(work_dir, state_list, v_list, a_list, f"q_step_{task_str}_{step}_v_function")

if __name__ == "__main__":
    # Smoke run: build the four-room gridworld, fetch the evaluation reward
    # returned by setup_eval('rni'), solve for Q, and plot into the CWD.
    from url_benchmark.gridworld.env import build_gridworld_task, ObservationType

    num_sweeps = 500
    gridworld = build_gridworld_task('fourroom', observation_type=ObservationType.AGENT_ONEHOT)
    gridworld.reset()
    # Only the reward callable (4th return value) is needed here.
    _, _, _, eval_reward = gridworld.setup_eval('rni')

    solver = QLearningAgent(gridworld, eval_reward, n_iters=num_sweeps)
    solver.solve()
    solver.plot_q_function(work_dir='.', step=num_sweeps)

