import gym
from gym import spaces
import numpy as np
from collections import defaultdict 
from gym.utils import seeding

def fill_matrix(grid_size, n_states, n_actions, prob=0.0):
    """Construct the transition matrix for a square grid world with action noise.

    States are numbered row-major: the cell in row ``y``, column ``x`` is
    state ``y * grid_size + x``. A move that would leave the grid is clipped
    to the nearest cell on the border, so the agent stays against the wall.

    Args:
        grid_size: side length of the square grid.
        n_states: number of states; must equal ``grid_size ** 2``.
        n_actions: number of actions to use, taken from the front of the
            action map below (between 1 and 9).
        prob: probability that the chosen action is replaced by one of the
            other ``n_actions - 1`` actions, picked uniformly. Must lie in
            ``[0, 1]``. With a single action there is nothing to swap to, so
            the noise has no effect.

    Returns:
        Array of shape ``(n_states, n_states, n_actions)`` indexed as
        ``[next_state, state, action]``; every ``[:, state, action]`` column
        sums to 1.
    """
    assert n_states == grid_size**2
    assert 0.0 <= prob <= 1.0

    # (dy, dx) offsets: dy moves between rows, dx between columns.
    # The four diagonals are distinct combinations of the cardinal moves above them.
    action_map = {0: (0, -1),   # left
                  1: (0, 1),    # right
                  2: (1, 0),    # up
                  3: (-1, 0),   # down
                  4: (0, 0),    # stay
                  5: (-1, -1),  # down-left
                  6: (1, -1),   # up-left
                  7: (-1, 1),   # down-right
                  8: (1, 1),    # up-right
                  }

    # `<=`: using every entry of the action map is allowed.
    assert 1 <= n_actions <= len(action_map)

    # grid[y, x] is the row-major state index of cell (y, x)
    grid = np.arange(n_states).reshape(grid_size, grid_size)

    def _next_state(y, x, a):
        """Return the state reached from cell (y, x) under action a, clipped at the borders."""
        dy, dx = action_map[a]
        next_y = min(max(y + dy, 0), grid_size - 1)
        next_x = min(max(x + dx, 0), grid_size - 1)
        return grid[next_y, next_x]

    # Probability mass for the intended action and for each of the other actions.
    # Guard n_actions == 1: there are no other actions to divide the noise among.
    if n_actions > 1:
        intended_prob = 1.0 - prob
        rand_prob = prob / (n_actions - 1)
    else:
        intended_prob = 1.0
        rand_prob = 0.0

    # next state, state, action
    matrix = np.zeros((n_states, n_states, n_actions))

    for y in range(grid_size):
        for x in range(grid_size):
            state = grid[y, x]
            # successor of this cell under each action, computed once per cell
            successors = [_next_state(y, x, a) for a in range(n_actions)]
            for a in range(n_actions):
                matrix[successors[a], state, a] += intended_prob
                if not rand_prob:
                    continue
                # `+=` because several actions can clip to the same successor at a wall
                for other_a, other_next in enumerate(successors):
                    if other_a != a:
                        matrix[other_next, state, a] += rand_prob

    return matrix

class ColourGridWorld(gym.Env):
    """
    Colour Grid World environment

    A 9x9 grid world whose cells are numbered row-major
    (state = row * grid_size + column). A handful of cells carry coloured
    labels. Entering the goal state (the last state) yields a reward of 1.0.
    The next step taken from the goal always sends the agent back to the start
    state, whatever the action. Episodes never terminate; they are only
    truncated after ``episode_length`` steps.

    Input attributes:
        seed: seed for the environment's random number generator
        random_action_probability: probability that the chosen action is replaced
            by one of the other actions, picked uniformly
        episode_length: number of steps after which the episode is truncated
        render_mode: how to render the environment [currently not implemented]

    Other attributes:
        grid_size: size of the grid world
        ncol: number of columns
        nrow: number of rows
        n_states: number of states (grid_size^2)
        n_actions: number of actions (left, right, up, down, stay)
        observation_space: gym spaces object
        action_space: gym spaces object
        transition_matrix: the full state-action transition matrix of the environment,
            indexed as [next_state, state, action]
        atomic_predicates: set of atomic predicates associated with the environment
        labelling_fn: mapping from a state to the set of atomic predicates that hold
            in it (an empty set for unlabelled states)
        reward_fn: mapping from a state to the reward for entering it (0.0 by default)
        _step_counter: number of steps taken since the last reset

    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, seed=0, random_action_probability=0.0, episode_length=1000, render_mode=None):

        self.np_random, _ = seeding.np_random(seed)

        self.grid_size = 9
        self.ncol = self.grid_size
        self.nrow = self.grid_size

        self.n_states = self.grid_size**2
        self.n_actions = 5

        self.random_action_probability = random_action_probability
        self.episode_length = episode_length

        self.observation_space = spaces.Discrete(self.n_states)
        self.action_space = spaces.Discrete(self.n_actions)

        # modify these if you want to change the location of the coloured states
        # (row-major numbering: state = row * grid_size + column)
        self._start_state = 0
        self._goal_state = self.n_states - 1
        self._blue_state = 4*self.grid_size
        self._green_state = self.n_states//2
        self._purple_state = 4

        self.transition_matrix = fill_matrix(self.grid_size, self.n_states, self.n_actions, prob=random_action_probability)

        # the goal state sends the agent back to the start state, whatever the action
        self.transition_matrix[:, self._goal_state, :] = 0.0
        self.transition_matrix[self._start_state, self._goal_state, :] = 1.0

        self.atomic_predicates = {"blue", "goal", "green", "purple", "start", "colour"}

        # unlabelled states map to an empty set, matching the set type of labelled ones
        self.labelling_fn = defaultdict(set)
        self.labelling_fn[self._start_state] = {"start"}
        self.labelling_fn[self._goal_state] = {"goal"}
        self.labelling_fn[self._green_state] = {"green", "colour"}
        self.labelling_fn[self._purple_state] = {"purple", "colour"}
        self.labelling_fn[self._blue_state] = {"blue", "colour"}

        self.reward_fn = defaultdict(float)
        self.reward_fn[self._goal_state] = 1.0

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        self._step_counter = 0
        # defined up front so that step() before reset() does not raise AttributeError
        self._agent_location = self._start_state

    def _transition(self, action):
        """Sample a next state from the transition matrix column for (current state, action)."""
        return self.np_random.choice(self.n_states, p=self.transition_matrix[:, self._agent_location, action])

    def _get_labels(self):
        """Return the set of atomic predicates that hold in the current state."""
        return self.labelling_fn[self._agent_location]

    def _get_obs(self):
        """Return the observation for the current state (the state index itself)."""
        return self._agent_location

    def _get_info(self):
        """Return the info dict for the current state."""
        return {"labels": self._get_labels()}

    def _get_reward(self):
        """Return the reward for the current state."""
        return self.reward_fn[self._agent_location]

    def _get_terminated(self):
        """The environment has no terminal states; episodes only end by truncation."""
        return False

    def _get_truncated(self):
        """Check whether the episode has reached its step limit."""
        return self._step_counter >= self.episode_length

    def reset(self, seed=None, options=None):
        """Reset the environment and return the start observation.

        Args:
            seed: if given, re-seed the environment's random number generator
            options: unused; accepted for compatibility with the gym API

        Returns:
            (observation, info) for the start state
        """
        if seed is not None:
            self.np_random, _ = seeding.np_random(seed)

        self._agent_location = self._start_state
        self._step_counter = 0
        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """Play a given action in the environment.

        Returns:
            (observation, reward, done, info), where ``done`` is
            ``terminated or truncated``. The two flags are also reported
            separately in ``info["is_terminated"]`` and ``info["is_truncated"]``.
        """
        next_state = self._transition(action)
        self._agent_location = next_state

        # increment step counter
        self._step_counter += 1

        terminated = self._get_terminated()
        truncated = self._get_truncated()
        done = terminated or truncated
        reward = self._get_reward()
        observation = self._get_obs()
        info = self._get_info()
        info["is_truncated"] = truncated
        info["is_terminated"] = terminated

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, done, info

    def _render_frame(self):
        """Render the current frame; not implemented yet."""
        raise NotImplementedError("rendering is not implemented for ColourGridWorld")
