# Bellman-Ford shortest-path agent for a gridworld evaluation environment (eval_env).

import numpy as np

class BellmanFordAgent:
    """Shortest-path solver for a gridworld evaluation environment.

    Runs Bellman-Ford from the goal cell over ``eval_env._layout`` (cells equal
    to -1 are treated as impassable) to get the number of steps from every cell
    to the goal, then records, for each cell, which action ids move one step
    closer to the goal.

    Attributes:
        eval_env: environment exposing ``_layout``, ``get_goal_obs()``,
            ``get_state_list()`` and ``plot_bf_function(...)``.
        grid: the environment layout (indexed as ``grid[row][col]``).
        n, m: number of rows / columns of ``grid``.
        inf: sentinel distance for walls, unreachable and off-grid cells.
        dist: (n, m) array of step distances to the goal.
        prev: maps ``(row, col)`` to the list of optimal action ids.
        vis: (n, m) array reset by ``initialize``; unused by this class.
    """

    # (row delta, col delta) per action id; id 4 is the (0, 0) no-op move.
    _MOVES = ((-1, 0), (0, 1), (1, 0), (0, -1), (0, 0))

    def __init__(self, eval_env):
        self.eval_env = eval_env
        self.grid = eval_env._layout
        self.n = self.grid.shape[0]
        self.m = self.grid.shape[1]
        # The sentinel must exceed every shortest-path length (at most n*m - 1),
        # otherwise long distances get silently capped at `inf`. It stays 1e2 for
        # grids of <= 100 cells so plotted values there are unchanged.
        self.inf = max(1e2, float(self.n * self.m))
        self.dist = np.zeros((self.n, self.m))
        self.prev = {}
        self.vis = np.zeros((self.n, self.m))

    def initialize(self):
        """Reset every distance to ``inf``, clear action lists and ``vis``."""
        self.dist.fill(self.inf)
        self.vis.fill(0)
        for i in range(self.n):
            for j in range(self.m):
                self.prev[(i, j)] = []

    def _in_bounds(self, i, j):
        """Return True if ``(i, j)`` is a valid grid index."""
        return 0 <= i < self.n and 0 <= j < self.m

    def _neighbor_dist(self, i, j):
        """Return ``dist[i, j]``, or ``inf`` when ``(i, j)`` is off the grid.

        Off-grid neighbours are treated like walls. This avoids an IndexError
        past the last row/column and NumPy's silent wrap-around for index -1.
        """
        return self.dist[i, j] if self._in_bounds(i, j) else self.inf

    def _relax(self):
        """Relax edges until distances stop changing.

        Runs at most n*m - 1 passes (the Bellman-Ford bound). It stops early
        once a full pass makes no update, which gives the same fixed point.
        """
        for _ in range(self.n * self.m - 1):
            changed = False
            for i in range(self.n):
                for j in range(self.m):
                    if self.grid[i][j] == -1:
                        continue
                    for dx, dy in self._MOVES:
                        if (dx, dy) == (0, 0):
                            continue
                        candidate = self._neighbor_dist(i + dx, j + dy) + 1
                        if self.dist[i, j] > candidate:
                            self.dist[i, j] = candidate
                            changed = True
            if not changed:
                break

    def _extract_actions(self, goal):
        """Append to ``self.prev`` the action ids that are optimal in each cell.

        For a non-goal cell, an action is optimal when it leads to a neighbour
        whose distance is exactly one less.

        At the goal, the kept actions are the (0, 0) no-op plus every move into
        a cell at distance ``inf`` (wall, unreachable or off-grid). The
        assumption is that such moves keep the agent at the goal; confirm this
        against the env dynamics.
        """
        for i in range(self.n):
            for j in range(self.m):
                if self.grid[i][j] == -1:
                    continue
                at_goal = (i, j) == goal
                if at_goal:
                    print("Goal")
                for act_id, (dx, dy) in enumerate(self._MOVES):
                    is_noop = (dx, dy) == (0, 0)
                    if at_goal:
                        keep = is_noop or self._neighbor_dist(i + dx, j + dy) == self.inf
                    else:
                        keep = (not is_noop
                                and self._neighbor_dist(i + dx, j + dy) == self.dist[i, j] - 1)
                    if not keep:
                        continue
                    if at_goal:
                        print(f"goal {i} {j} {dx} {dy} {act_id}")
                    else:
                        print(f"{i} {j} {dx} {dy} {act_id}")
                    self.prev[(i, j)].append(act_id)

    def solve(self):
        """Compute distances to the goal and the optimal actions of every cell.

        Overwrites ``self.dist`` / ``self.prev`` and prints debug output.
        """
        self.initialize()
        print("Solving using Bellman-Ford")
        print(self.grid)
        goal = self.eval_env.get_goal_obs()
        # Assumes the goal observation is normalised by the grid size (TODO
        # confirm). The +0.1 guards against float round-down, e.g. 0.9999 -> 0.
        goal = (int(goal[0] * self.n + 0.1), int(goal[1] * self.m + 0.1))
        print(goal)

        self.dist[goal[0], goal[1]] = 0
        self._relax()
        # Extract all optimal actions from each state.
        self._extract_actions(goal)

        print("Distances")
        print(self.dist)

    def plot_bf_function(self, work_dir, step):
        """Forward the Bellman-Ford value function to ``eval_env`` for plotting.

        Values are the negated distances, so cells closer to the goal get higher
        values. Actions are the per-cell lists built by ``solve``.

        Args:
            work_dir: forwarded unchanged to ``eval_env.plot_bf_function``.
            step: embedded in the plot name ``bf_step_{step}_v_function``.
        """
        state_list = self.eval_env.get_state_list()
        print(state_list)
        print('in plot_bf_function')
        v_list = -1.0 * self.dist
        a_list = self.prev
        self.eval_env.plot_bf_function(work_dir, state_list, v_list, a_list, f"bf_step_{step}_v_function")

