import gym
from gym import spaces
import numpy as np
from collections import defaultdict 
from gym.utils import seeding
from tqdm import tqdm


# Maze layout indexed as STANDARD_MAP[row, col]: "1" cells are walls and "0"
# cells are free corridor (the only cells that ever appear in a state).
_STANDARD_MAP_ROWS = (
    "1111111111",
    "1000000001",
    "1011111101",
    "1000000001",
    "1011110101",
    "1010000101",
    "1111111111",
)

STANDARD_MAP = np.array(
    [[int(cell) for cell in row] for row in _STANDARD_MAP_ROWS]
)


def fill_matrix(nrow, ncol, n_actions, n_ghosts, n_directions, food_x, food_y):
    """Enumerate the states of the maze and build its transition matrix.

    A state is the tuple
    ``(agent_y, agent_x, agent_direction, ghost_y, ghost_x, ghost_direction, food)``.
    Only poses standing on a free cell of ``STANDARD_MAP`` are kept, and a pose
    is dropped when the cells directly in front of *and* behind it are walls.

    Args:
        nrow: Number of rows of the grid (used to clip coordinates).
        ncol: Number of columns of the grid (used to clip coordinates).
        n_actions: Number of actions (at most 5: four moves plus action 4).
        n_ghosts: Number of ghosts; only ``1`` is supported.
        n_directions: Number of headings (at most 4).
        food_x: Column of the food cell.
        food_y: Row of the food cell.

    Returns:
        A tuple ``(matrix, state_map, n_states, reverse_state_map)`` where
        ``matrix[next_state, state, action]`` is the transition probability
        (``float32``), ``state_map`` maps state tuples to state numbers,
        ``n_states`` is the number of states and ``reverse_state_map`` maps
        state numbers back to state tuples.
    """
    # Total probability mass spread evenly over the ghost actions that are
    # NOT the argmax of the alignment score computed below.
    RANDOM_CHANCE = 0.6

    assert n_ghosts == 1

    # Offsets are (d_row, d_col) and are added to (y, x).
    action_map = {0: (0, -1),   # col - 1
                  1: (0, 1),    # col + 1
                  2: (1, 0),    # row + 1
                  3: (-1, 0),   # row - 1
                  4: (0, 0),    # no displacement
                  }

    direction_map = {0: (0, -1),   # col - 1
                     1: (0, 1),    # col + 1
                     2: (1, 0),    # row + 1
                     3: (-1, 0),   # row - 1
                     }

    # Heading opposite to each heading.
    reverse_map = {0: 1,
                   1: 0,
                   2: 3,
                   3: 2}

    assert n_actions <= len(action_map)
    assert n_directions <= len(direction_map)

    def shifted(y, x, delta, sign=1):
        """Return (y, x) moved by ``sign * delta``, clipped to the grid."""
        new_y = int(np.clip(y + sign * delta[0], 0, nrow - 1))
        new_x = int(np.clip(x + sign * delta[1], 0, ncol - 1))
        return new_y, new_x

    def is_wall(y, x):
        return STANDARD_MAP[y, x] == 1

    def is_free(y, x):
        return STANDARD_MAP[y, x] == 0

    # Valid (y, x, direction) poses, in row / column / direction order. The
    # same filter applies to the agent and to the ghost, so compute it once.
    poses = []
    for y in range(nrow):
        for x in range(ncol):
            for direction in range(n_directions):
                delta = direction_map[direction]
                boxed_in = (is_wall(*shifted(y, x, delta))
                            and is_wall(*shifted(y, x, delta, -1)))
                if not boxed_in and is_free(y, x):
                    poses.append((y, x, direction))

    # Number the states: agent pose outermost, then ghost pose, then food flag.
    state_map = {}
    reverse_state_map = {}
    state_number = 0
    for agent_pose in poses:
        for ghost_pose in poses:
            for food in [0, 1]:
                key = agent_pose + ghost_pose + (food,)
                state_map[key] = state_number
                reverse_state_map[state_number] = key
                state_number += 1

    n_states = state_number

    matrix = np.zeros((n_states, n_states, n_actions), dtype=np.float32)

    print("Computing successor states and probabilities ... ")

    for tup, state_number in tqdm(state_map.items()):

        agent_y, agent_x, agent_direction, ghost_y, ghost_x, ghost_direction, food = tup

        # The food flag is cleared once the agent stands on the food cell.
        if (agent_x == food_x) and (agent_y == food_y):
            next_food = 0
        else:
            next_food = food

        # ---- ghost policy -------------------------------------------------
        ghost_front_free = is_free(
            *shifted(ghost_y, ghost_x, direction_map[ghost_direction]))

        # Offset from the ghost to the agent; its dot product with a heading
        # scores how well that heading points toward the agent.
        offset = np.array([agent_y - ghost_y, agent_x - ghost_x], dtype=np.float32)
        prods = np.full(n_actions, -np.inf, dtype=np.float32)

        for ghost_act in range(n_actions):
            # The ghost may only reverse when the cell ahead is blocked.
            if ghost_front_free and ghost_act == reverse_map[ghost_direction]:
                continue
            # Action 4 is scored (and gated) with the current heading.
            heading = ghost_act if ghost_act <= 3 else ghost_direction
            if is_free(*shifted(ghost_y, ghost_x, direction_map[heading])):
                prods[ghost_act] = np.dot(offset, np.array(direction_map[heading]))

        available_ghost_acts = np.where(prods != -np.inf)[0]
        n = len(available_ghost_acts)
        assert n > 0, f"ghost has no available action in state {tup}"

        ghost_act_probs = np.zeros(n_actions, dtype=np.float32)
        ghost_act_probs[available_ghost_acts] = RANDOM_CHANCE / float(max(1, n - 1))
        ghost_act_probs[np.argmax(prods)] = 1.0 - RANDOM_CHANCE if n > 1 else 1.0

        # float32 arithmetic: compare with a tolerance rather than exactly.
        assert np.isclose(np.sum(ghost_act_probs), 1.0), f"sum: {np.sum(ghost_act_probs)}, probs: {ghost_act_probs}"

        # ---- agent move, combined with every ghost move -------------------
        agent_front = shifted(agent_y, agent_x, direction_map[agent_direction])
        agent_front_free = is_free(*agent_front)

        for agent_act in range(n_actions):
            # Reversing is ignored while the cell ahead is free; action 4
            # keeps the current heading.
            if agent_front_free and agent_act == reverse_map[agent_direction]:
                next_agent_direction = agent_direction
            else:
                next_agent_direction = agent_act if agent_act <= 3 else agent_direction

            next_agent_y, next_agent_x = shifted(
                agent_y, agent_x, direction_map[next_agent_direction])
            if is_wall(next_agent_y, next_agent_x):
                # Requested heading is blocked: keep the old heading and move
                # forward if possible, otherwise stay in place.
                next_agent_y, next_agent_x = agent_front
                if is_wall(next_agent_y, next_agent_x):
                    next_agent_y, next_agent_x = agent_y, agent_x
                next_agent_direction = agent_direction

            # If the agent steps onto the ghost's current cell the ghost does
            # not move, so all ghost-action mass goes to one successor.
            agent_hits_ghost = (next_agent_y, next_agent_x) == (ghost_y, ghost_x)

            for ghost_act in available_ghost_acts:
                ghost_act = int(ghost_act)
                if agent_hits_ghost:
                    ghost_part = (ghost_y, ghost_x, ghost_direction)
                else:
                    next_ghost_y, next_ghost_x = shifted(
                        ghost_y, ghost_x, action_map[ghost_act])
                    next_ghost_direction = ghost_act if ghost_act <= 3 else ghost_direction
                    ghost_part = (next_ghost_y, next_ghost_x, next_ghost_direction)

                successor = state_map[
                    (next_agent_y, next_agent_x, next_agent_direction)
                    + ghost_part + (next_food,)]
                matrix[successor, state_number, agent_act] += ghost_act_probs[ghost_act]

    return matrix, state_map, n_states, reverse_state_map

class Pacman(gym.Env):
    """Tabular Pacman-style environment on ``STANDARD_MAP`` with one ghost.

    Observations are integer state numbers (see ``fill_matrix``); actions are
    integers in ``[0, n_actions)``. Transitions are sampled from the
    precomputed ``transition_matrix[next_state, state, action]``.

    ``reset`` returns ``(observation, info)`` while ``step`` returns the
    4-tuple ``(observation, reward, done, info)``; the separate
    terminated / truncated flags are exposed through ``info``.
    """

    metadata = {"render_modes": ["ascii"]}

    def __init__(self, seed=0, episode_length=100, render_mode=None):
        """Build the state space, transition matrix, labels and rewards.

        Args:
            seed: Seed for the environment's random number generator.
            episode_length: Number of steps after which an episode is truncated.
            render_mode: ``None`` or ``"ascii"`` (print the grid every step).
        """
        self.np_random, _ = seeding.np_random(seed)

        self.ncol = 10
        self.nrow = 7
        self.n_ghosts = 1
        self.n_directions = 4

        self.n_actions = 5

        self.episode_length = episode_length

        self._food_x = 7
        self._food_y = 3

        self.transition_matrix, self.state_map, self.n_states, self.reverse_state_map = fill_matrix(self.nrow, self.ncol, self.n_actions, self.n_ghosts, self.n_directions, self._food_x, self._food_y)

        print(f"total number of states {self.n_states}")

        # Sanity check: for every (state, action) pair the successor
        # distribution must sum to 1 (within float32 tolerance).
        column_sums = self.transition_matrix.sum(axis=0)
        assert np.allclose(column_sums, 1.0, atol=1e-6), "transition probabilities do not sum to 1 for every (state, action) pair"

        self.action_space = spaces.Discrete(self.n_actions)
        self.observation_space = spaces.Discrete(self.n_states)

        self._safe_state_dict = defaultdict(bool)
        self._unsafe_state_dict = defaultdict(bool)

        self._agent_start_x = 4
        self._agent_start_y = 1
        self._agent_start_direction = 1
        self._ghost_start_x = 3
        self._ghost_start_y = 5
        self._ghost_start_direction = 1

        self._start_state = self.state_map[(self._agent_start_y, self._agent_start_x, self._agent_start_direction, self._ghost_start_y, self._ghost_start_x, self._ghost_start_direction, 1)]

        self.atomic_predicates = {"unsafe"}

        # Unlabelled states map to an empty *set*, matching the type used for
        # labelled states.
        self.labelling_fn = defaultdict(set)
        self.labelling_fn[self._start_state] = {"start"}
        self.reward_fn = defaultdict(float)

        for state, state_number in self.state_map.items():
            agent_y, agent_x, _, ghost_y, ghost_x, _, food = state
            caught = (agent_y, agent_x) == (ghost_y, ghost_x)
            on_food_cell = (agent_y == self._food_y) and (agent_x == self._food_x)
            # Rewarded states: agent on the food cell while the food is still
            # present and the ghost is not on the same cell.
            if on_food_cell and (not caught) and food:
                self.labelling_fn[state_number] = {"food"}
                self.reward_fn[state_number] = 1.0
            if caught:
                self.labelling_fn[state_number] = {"ghost"}

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        self._step_counter = 0
        # Start from a defined state even if ``step`` is called before ``reset``.
        self._state = self._start_state

    def _transition(self, action):
        """Sample the next state number for ``action`` from the current state.

        Raises:
            ValueError: If the sampler rejects the successor distribution
                (for example because it does not sum to 1).
        """
        probs = self.transition_matrix[:, self._state, action]
        try:
            return self.np_random.choice(self.n_states, p=probs)
        except ValueError as err:
            # Surface the failure instead of silently returning None.
            raise ValueError(
                f"sum: {np.sum(probs)}, probs: {probs} state: {self._state}, action: {action}"
            ) from err

    def _get_labels(self):
        """return the labels for the current state"""
        return self.labelling_fn[self._state]

    def _get_obs(self):
        """return the current state number"""
        return self._state

    def _get_info(self):
        """return the info for the current state"""
        return {"labels": self._get_labels()}

    def _get_reward(self):
        """return the reward for the current state"""
        return self.reward_fn[self._state]

    def _get_terminated(self):
        """check if the termination condition is satisfied"""
        # The environment never terminates by itself; episodes end by truncation.
        return False

    def _get_truncated(self):
        """check if the episode length has been reached"""
        return self._step_counter >= self.episode_length

    def reset(self, seed=None, options=None):
        """reset the environment and return the start obs

        Args:
            seed: If given, re-seed the environment's random number generator.
            options: Unused.

        Returns:
            ``(observation, info)`` for the start state.
        """
        if seed is not None:
            self.np_random, _ = seeding.np_random(seed)

        self._state = self._start_state
        observation = self._get_obs()
        info = self._get_info()
        self._step_counter = 0

        if self.render_mode == "ascii":
            self._render_frame()

        return observation, info

    def step(self, action):
        """play a given action in the environment

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is
            ``terminated or truncated`` and ``info`` additionally carries
            ``"is_truncated"`` and ``"is_terminated"``.
        """
        next_state = self._transition(action)
        self._state = next_state

        # increment step counter
        self._step_counter += 1

        terminated = self._get_terminated()
        truncated = self._get_truncated()
        done = terminated or truncated
        reward = self._get_reward()
        observation = self._get_obs()
        info = self._get_info()
        info["is_truncated"] = truncated
        info["is_terminated"] = terminated

        if self.render_mode == "ascii":
            self._render_frame()

        return observation, reward, done, info

    def _render_frame(self):
        """print the grid row by row: A agent, G ghost, X food, ☐ wall, - free"""
        grid = STANDARD_MAP
        agent_y, agent_x, agent_direction, ghost_y, ghost_x, ghost_direction, food = self.reverse_state_map[self._state]
        for y in range(grid.shape[0]):
            row = []
            for x in range(grid.shape[1]):
                # Agent is drawn on top of the ghost, which is drawn on top of
                # the food.
                if agent_y == y and agent_x == x:
                    row.append("A")
                elif ghost_y == y and ghost_x == x:
                    row.append("G")
                elif food and self._food_y == y and self._food_x == x:
                    row.append("X")
                elif grid[y, x] == 1:
                    row.append("☐")
                elif grid[y, x] == 0:
                    row.append("-")
            print(row)

    @property
    def _agent_location(self):
        # Returns the full state number, not a grid coordinate.
        return self._state




