import numpy as np
import gym
from gym.envs.toy_text import discrete
from gym import spaces

# NOISE = 0.05


class GridWorld(gym.Env):
    """Square grid world with a single absorbing goal cell.

    States are ``(x, y)`` tuples whose coordinates both lie in ``[0, size]``.
    Every episode starts at ``(0, 0)`` and the goal cell is ``(size, size)``.
    Four moves are available, indexed 0-3: ``(0, +1)``, ``(+1, 0)``,
    ``(0, -1)`` and ``(-1, 0)``. A move that would leave the grid is clamped
    to the border, and once the agent is on the goal cell it stays there.

    Episodes never terminate on reaching the goal; ``step`` only reports
    ``done=True`` once ``max_episode_steps`` steps have been taken.

    NOTE(review): ``observation_space`` is declared as
    ``Discrete((size + 1) ** 2)`` but ``reset``/``step`` return ``(x, y)``
    tuples, not flat indices. ``population_step_with_reward`` uses the
    flattening ``x * (size + 1) + y``; confirm what callers expect before
    relying on ``observation_space.contains``.
    """

    def __init__(self, noise: float, size=5, max_episode_steps=10):
        """Create the environment.

        Args:
            noise: probability that a non-deterministic transition replaces
                the requested action with one drawn uniformly from all
                actions (which may coincide with the requested one).
            size: largest valid coordinate; the grid has ``size + 1`` cells
                per side.
            max_episode_steps: number of ``step`` calls after which ``done``
                is reported as ``True``.
        """
        super().__init__()
        self.noise = noise
        self.size = size
        self.observation_space = spaces.Discrete((self.size+1)**2)
        self.num_act = 4
        self.action_space = spaces.Discrete(self.num_act)
        # Only the keys are used (as the set of goal cells); the value 10 is
        # not read anywhere in this class.
        self.end_point = {(self.size, self.size): 10}
        self.state = None
        self.move_dir = {0: (0, 1), 1: (1, 0), 2: (0, -1), 3: (-1, 0)}
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = None

    def seed(self, seed=None):
        """Seed the parent environment and NumPy's global random generator.

        Transition noise is drawn from the global ``np.random`` state, so this
        affects (and is affected by) any other code using that generator.
        """
        super().seed(seed)
        np.random.seed(seed)

    def get_reward(self, obs, next_obs):
        """Return ``(reward, done)`` for the transition ``obs -> next_obs``.

        The reward is ``1`` when either ``obs`` or ``next_obs`` is the goal
        cell and ``-1`` otherwise. ``done`` is always ``False`` here: reaching
        the goal does not end the episode, only the step limit enforced in
        ``step`` does.
        """
        touches_goal = obs in self.end_point or next_obs in self.end_point
        reward = 1 if touches_goal else -1
        return reward, False

    def _move(self, obs, act):
        """Return the cell reached from ``obs`` by move ``act``, clamped to the grid."""
        dx, dy = self.move_dir[act]
        # min/max instead of np.clip keeps plain Python ints for plain-int
        # inputs, so states produced by step() have the same element type as
        # the (0, 0) returned by reset().
        return (min(max(obs[0] + dx, 0), self.size),
                min(max(obs[1] + dy, 0), self.size))

    def get_next_obs(self, obs, act, is_deterministic=False):
        """Return the state that follows ``obs`` when ``act`` is requested.

        If ``obs`` is the goal cell the agent stays put and no random numbers
        are drawn. Otherwise, when ``is_deterministic`` is ``False``, the
        requested action is applied with probability ``1 - noise`` and a
        uniformly random action is applied with probability ``noise``. When
        ``is_deterministic`` is ``True`` the requested action is always
        applied.
        """
        if obs in self.end_point:
            return obs

        if not is_deterministic:
            # Draw order matters for reproducibility under a fixed seed: one
            # uniform draw, then (only on the noisy branch) one action draw.
            keep_requested = np.random.random() >= self.noise
            if not keep_requested:
                act = np.random.choice(self.num_act)

        return self._move(obs, act)

    def reset(self):
        """Start a new episode at ``(0, 0)`` and return that state."""
        self.state = (0, 0)
        self._elapsed_steps = 0

        return self.state

    def step(self, act, is_deterministic=False):
        """Advance the environment by one action.

        Args:
            act: action index in ``0..num_act - 1``.
            is_deterministic: when ``True``, bypass the transition noise.

        Returns:
            ``(state, reward, done, info)`` where ``info`` is an empty dict
            and ``done`` becomes ``True`` only once the step limit is reached.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.state is None or self._elapsed_steps is None:
            raise RuntimeError("reset() must be called before step()")

        next_state = self.get_next_obs(self.state, act, is_deterministic)
        reward, done = self.get_reward(obs=self.state, next_obs=next_state)
        self.state = next_state
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            done = True

        return self.state, reward, done, {}

    def population_step(self, obs, act):
        """List one outcome per action, with all probability mass on ``act``.

        Returns a list of ``num_act`` dicts with keys ``prob``, ``reward``,
        ``next_obs`` and ``done``. Entry ``i`` has ``prob`` 1.0 when
        ``i == act`` and 0.0 otherwise, and its ``next_obs`` comes from
        ``get_next_obs(obs, i)``.

        NOTE(review): ``get_next_obs`` is called with noise enabled, so for
        ``noise > 0`` each ``next_obs`` is a random sample (and consumes
        global RNG state) while ``prob`` ignores the noise. Confirm whether a
        deterministic model or the true noisy distribution is intended.
        """
        res = []
        for i in range(self.num_act):
            next_obs = self.get_next_obs(obs, i)
            reward, done = self.get_reward(obs=obs, next_obs=next_obs)
            prob = 1.0 if i == act else 0.0
            res.append(dict(prob=prob, reward=reward, next_obs=next_obs, done=done))

        return res

    def population_step_with_reward(self, obs, act, reward_matrix: np.ndarray):
        """Like ``population_step`` but with rewards read from ``reward_matrix``.

        Args:
            obs: current obs (x, y)
            act: current action
            reward_matrix: the estimated reward matrix with
                shape = (num_states, num_action); rows are indexed by
                ``x * (size + 1) + y``.

        Returns:
            A list of ``num_act`` dicts with keys ``prob``, ``reward``,
            ``next_obs`` and ``done``. ``reward`` for entry ``i`` is
            ``reward_matrix[state, i]``; the environment's own reward is
            ignored and only its ``done`` flag is kept.

        NOTE(review): as in ``population_step``, ``next_obs`` is sampled with
        noise enabled while ``prob`` is one-hot on ``act``.
        """
        # The flat row index depends only on obs, so compute it once.
        state = obs[0] * (self.size+1) + obs[1]

        res = []
        for i in range(self.num_act):
            next_obs = self.get_next_obs(obs, i)
            _, done = self.get_reward(obs=obs, next_obs=next_obs)
            prob = 1.0 if i == act else 0.0
            res.append(dict(prob=prob, reward=reward_matrix[state, i],
                            next_obs=next_obs, done=done))

        return res