from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from gymnasium.spaces import Discrete, Box

from envs import ConditionalActionEnv


class FourRooms(ConditionalActionEnv):
    """
    A four-rooms style domain whose actions are options that move the agent between room centres and doorways.

    The rooms are numbered 0-3 with centres at (25, 25), (75, 25), (25, 75) and (75, 75). The doorways are
    'N' at (50, 25), 'W' at (25, 50), 'E' at (75, 50) and 'S' at (50, 75). From a room centre only the two
    doorway options (actions 0 and 3) may be executed; from a doorway only the two room-centre options
    (actions 1 and 2) may be executed.
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, noise=0.0, distractors=False, render_mode=None):
        """
        Create an environment similar to the four rooms domain, with options to move between room centres and doorways.
        The observations are just the xy-position of the agent and optionally additional noisy state variables appended
        :param noise: the amount of noise to apply to the observation (a Gaussian with mean zero and specified standard deviation)
        :param distractors: whether or not additional state variables consisting of just noise should be appended to the xy-position
        to form the observation returned
        :param render_mode: None for no rendering, or one of the modes listed in metadata["render_modes"]
        """
        self.action_space = Discrete(4)  # northern door, centre north/left, centre south/right, southern door
        self.observation_space = Box(low=-100, high=100, shape=(10 if distractors else 2,),
                                     dtype=np.float32)  # (x, y) + potentially random stuff
        self._noise = noise
        self._distractors = distractors
        self.current_state = None  # last observation returned; None until reset() is called
        self.in_doorway = False
        self.doorway = ''  # label of the most recently entered doorway: 'N', 'W', 'E' or 'S'
        self.room = -1  # index of the most recently entered room; -1 until reset() is called
        self._viewer = False  # becomes True once the interactive backend has been selected

        # rendering stuff
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

    def __noisy(self, x: List[float]) -> np.ndarray:
        """
        Turn the given values into an observation vector, adding independent zero-mean Gaussian noise to every
        entry when a non-zero noise level was configured
        :param x: the noise-free values
        :return: a float32 array (the dtype declared by the observation space) with the same length as x
        """
        values = np.asarray(x, dtype=np.float64)
        if self._noise != 0:
            # Prefer the environment's own generator so that reset(seed=...) makes the noise reproducible;
            # fall back to the global numpy generator if the base class does not provide one
            rng = getattr(self, "np_random", None)
            if rng is None:
                rng = np.random
            values = values + rng.normal(0, self._noise, size=values.shape)
        # Cast so that the observation matches the dtype of the declared observation space
        return values.astype(np.float32)

    @property
    def available_mask(self) -> Tuple:
        """
        Return a binary array specifying which options can be run at the current state
        """
        if self.in_doorway:
            # only the two "move to a room centre" options are valid from a doorway
            return 0, 1, 1, 0
        # only the two "move to a doorway" options are valid from a room centre
        return 1, 0, 0, 1

    def reset(self, seed=None, options=None):
        """
        Put the agent back in room 0 and return the initial observation
        :param seed: forwarded to the parent class' reset
        :param options: unused
        :return: the initial observation and an info dictionary with a zero step count
        """
        super().reset(seed=seed)

        self.current_state = self._create_observation(23, 23)

        self.in_doorway = False
        self.doorway = ''
        self.room = 0
        if self.render_mode == "human":
            self._render_frame()
        return self.current_state, {"steps": 0}

    def _create_observation(self, x, y):
        """
        Given the xy position of the agent, generate the observation vector. This could include noise as well as
        distractor noise state variables
        """
        if self._distractors:
            # the eight distractor variables share a constant mean of 50 and only differ through the noise
            return self.__noisy([x, y] + [50] * 8)
        return self.__noisy([x, y])

    def step(self, action):
        """
        Execute the given option, moving the agent to a doorway (actions 0 and 3) or to a room centre (actions 1 and 2)
        :param action: the index of the option to execute
        :return: the new observation, the reward (the negative Euclidean distance between the previous and the new
        observed xy-position), terminated and truncated flags (always False) and an info dictionary holding the
        rounded distance under the key 'steps'
        :raises ValueError: if the option cannot be executed from the agent's current location
        """
        assert self.action_space.contains(action)

        if self.in_doorway and action not in [1, 2]:
            raise ValueError("Cannot move into doorway because already there!")
        elif not self.in_doorway and action in [1, 2]:
            raise ValueError("Cannot move into room because already there!")

        target = None
        if action == 0:
            # move to north door
            if self.room == 0 or self.room == 1:
                target = (50, 25)
                self.doorway = 'N'
            elif self.room == 2:
                target = (25, 50)
                self.doorway = 'W'
            else:
                target = (75, 50)
                self.doorway = 'E'
        elif action == 3:
            # move to south door
            if self.room == 0:
                target = (25, 50)
                self.doorway = 'W'
            elif self.room == 1:
                target = (75, 50)
                self.doorway = 'E'
            else:
                target = (50, 75)
                self.doorway = 'S'
        elif action == 1:
            # north/left centre
            if self.doorway == 'N' or self.doorway == 'W':
                target = (25, 25)
                self.room = 0
            elif self.doorway == 'E':
                target = (75, 25)
                self.room = 1
            else:
                target = (25, 75)
                self.room = 2
        elif action == 2:
            # south/right centre
            if self.doorway == 'N':
                target = (75, 25)
                self.room = 1
            elif self.doorway == 'S' or self.doorway == 'E':
                target = (75, 75)
                self.room = 3
            else:
                target = (25, 75)
                self.room = 2

        new_state = self._create_observation(*target)
        # The distance is measured between observed (i.e. possibly noisy) positions. It is computed in double
        # precision so that the reward keeps the float64 type regardless of the observation dtype
        displacement = new_state[0:2].astype(np.float64) - self.current_state[0:2].astype(np.float64)
        rew = -np.linalg.norm(displacement)
        self.current_state = new_state
        self.in_doorway = action in [0, 3]

        if self.render_mode == "human":
            self._render_frame()

        return self.current_state, rew, False, False, {'steps': int(round(-rew))}

    def render(self):
        """
        Return the current frame as an image array when the render mode is "rgb_array"; otherwise return None
        (in "human" mode the frames are drawn from within reset and step)
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """
        Draw the agent's observed position, the two dashed lines separating the rooms and a title describing the
        agent's location
        :return: None in "human" mode (the figure is shown instead), otherwise the canvas contents as a
        (height, width, channels) uint8 array read from the RGBA buffer
        """
        if not self._viewer and self.render_mode == "human":
            import matplotlib
            matplotlib.use('TkAgg')  # interactive mode
            self._viewer = True  # remember that the backend was selected so this only happens once

        plt.clf()
        if self.in_doorway:
            plt.title("In doorway {}".format(self.doorway))
        else:
            plt.title("In room {}".format(self.room))
        x, y = self.current_state[0], self.current_state[1]
        plt.axhline(50, linestyle='--', color='k')
        plt.axvline(50, linestyle='--', color='k')
        plt.scatter(x, 100 - y, s=100)  # flip y so that "north" (small y) is at the top of the plot

        plt.xlim(0, 100)
        plt.ylim(0, 100)

        plt.draw()
        if self.render_mode == "human":
            plt.pause(1 / self.metadata["render_fps"])
            return None

        # rgb_array: the figure has just been drawn, so the canvas buffer is up to date
        canvas = plt.gcf().canvas
        data = np.array(canvas.renderer.buffer_rgba(), dtype=np.uint8)
        w, h = canvas.get_width_height()
        im = data.reshape((int(h), int(w), -1))
        return im


if __name__ == '__main__':
    # Demo: run 100 randomly sampled options in a noisy environment with distractor variables,
    # drawing every step on screen and printing each observation.
    env = FourRooms(noise=3, distractors=True, render_mode="human")
    observation, info = env.reset()

    for _ in range(100):
        observation, reward, terminated, truncated, info = env.step(env.sample_action())
        print(observation)

        # With render_mode="rgb_array", env.render() could be called here to obtain the frame
        # as an array, which can then be saved as an image.

        episode_over = terminated or truncated
        if episode_over:
            observation, info = env.reset()

    env.close()
