from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from gymnasium.spaces import Discrete, Box, Dict

from envs import ConditionalActionEnv


class MNISTHyperGrid(ConditionalActionEnv):
    """
    HyperGrid walk whose observations are MNIST digit images.

    The underlying state is one integer coordinate per grid dimension. Each coordinate is shown to the agent as a
    randomly drawn MNIST image of that digit, flattened to 784 values and scaled to [0, 1]; the images of all
    dimensions are concatenated into a single vector of length `784 * ndim`.
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    # number of values in one flattened 28x28 MNIST image
    _IMG_SIZE = 784
    # MNIST only contains the digits 0-9, so a dimension cannot have more than 10 nodes
    _NUM_DIGITS = 10

    def __init__(self, dimensions=(4, 5), eps=0.0, max_episode_steps=20,
                 goal_conditioned=False, use_initiation_vector=False, render_mode=None):
        """
        Create a HyperGrid Walk domain with given size and specify whether noise is applied to the transitions.
        Observations are MNIST images of the node at which the agent is currently at, and actions allow the agent to
        increment and decrement each dimension of the grid. If the agent is at the upper or lower bound of the grid,
        it can only move in one direction.
        :param dimensions: dimensions of the grid; every entry must be between 1 and 10 because each coordinate
                           is displayed as an MNIST digit
        :param eps: probability of teleporting to a random state instead of performing a valid move
        :param max_episode_steps: maximum number of steps in a single episode
        :param goal_conditioned: if `True`, the observation is a dictionary with keys `observation`, `achieved_goal`,
                                 and `desired_goal`. This is compatible with Stable Baselines3.
        :param use_initiation_vector: if `True`, `available_mask` flags null actions (moves off the grid) as
                                      unavailable so that they can be filtered out, e.g. in `sample_action`.
        :param render_mode: `rgb_array` or `human`
        :raises ValueError: if `dimensions` is empty or contains an entry outside of [1, 10]
        """
        # fail fast: an invalid grid would otherwise only crash later, while sampling states or digit images
        if len(dimensions) == 0 or any(d < 1 or d > self._NUM_DIGITS for d in dimensions):
            raise ValueError(
                f"dimensions must be non-empty with every entry in [1, {self._NUM_DIGITS}], got {dimensions}"
            )

        self._dimensions = dimensions
        self._eps = eps
        self._ndim = len(dimensions)
        self._max_episode_steps = max_episode_steps
        self._gc = goal_conditioned
        self._use_init_vec = use_initiation_vector

        self._iter = 0
        self.action_space = Discrete(2 * self._ndim)  # increment/decrement for each dimension
        # action 2*i decrements dimension i, action 2*i+1 increments it (see `step`)
        self.action_names = []
        for i in range(self._ndim):
            self.action_names.append(f"d{i}_down")
            self.action_names.append(f"d{i}_up")

        dataset = torchvision.datasets.MNIST(root=".", train=True, download=True)
        self._data = dataset.data.flatten(1, -1)
        # for every digit that can appear on the grid: indices of all training images showing that digit
        self._labels = {i: torch.where(dataset.targets == i)[0] for i in range(max(dimensions))}

        # a digit for each dimension
        obs_shape = (self._IMG_SIZE * self._ndim,)
        if self._gc:
            self.observation_space = Dict({
                "observation": Box(low=0, high=1, shape=obs_shape, dtype=np.float64),
                "achieved_goal": Box(low=0, high=1, shape=obs_shape, dtype=np.float64),
                "desired_goal": Box(low=0, high=1, shape=obs_shape, dtype=np.float64),
            })
        else:
            self.observation_space = Box(low=0, high=1, shape=obs_shape, dtype=np.float64)
        self._observation = None
        self._state = np.zeros(self._ndim, dtype=int)  # this is the underlying node at which the agent is currently
        # default goal is the corner with the largest coordinate in every dimension
        self._goal = np.array(dimensions, dtype=int) - 1

        # rendering stuff
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self._viewer = None

    @property
    def available_mask(self) -> Tuple:
        """
        Return a binary tuple specifying which actions can be run at the current state.
        Entries are ordered like the actions: (down, up) for each dimension.
        """
        # all actions are available if we are not using the initiation vector
        if not self._use_init_vec:
            return (1, 1) * self._ndim

        mask = ()
        for i, dim in enumerate(self._dimensions):
            if dim == 1:
                mask += (0, 0)  # a single node: cannot move along this dimension at all
            elif self._state[i] == 0:
                mask += (0, 1)  # at the lower bound: only "up" moves the agent
            elif self._state[i] == dim - 1:
                mask += (1, 0)  # at the upper bound: only "down" moves the agent
            else:
                mask += (1, 1)
        return mask

    @property
    def observation(self):
        """Return the observation computed by the last `reset` or `step` (`None` before the first reset)."""
        return self._observation

    @property
    def reward(self):
        """Return 1.0 if the agent is at the goal, 0.0 otherwise."""
        return float(self.terminated)

    @property
    def terminated(self):
        """Return `True` if the current state equals the goal in every dimension."""
        return bool(np.array_equal(self._state, self._goal))

    @property
    def truncated(self):
        """Return `True` once `max_episode_steps` steps have been taken since the last reset."""
        return self._iter >= self._max_episode_steps

    @property
    def info(self):
        """
        Return the auxiliary info dictionary.
        `position` is the state perturbed with Gaussian noise (clipped to 3 std, scaled by 0.2); a new noise
        sample is drawn on every access. `state` and `goal` are copies of the underlying integer vectors.
        """
        return {
            # NOTE(review): always 0 and never updated from `self._iter` -- confirm what consumers expect here
            "steps": 0,
            "position": self._state.copy() + np.random.randn(self._ndim).clip(-3, 3)*0.2,
            "state": self._state.copy(),
            "goal": self._goal.copy(),
        }

    def compute_reward(self, achieved_goal, desired_goal, info):
        """
        Compute the goal-reaching reward from `info` (the image arguments are not used).
        :param achieved_goal: unused
        :param desired_goal: unused
        :param info: an info dictionary, or an iterable of info dictionaries, with keys `state` and `goal`
        :return: a float for a single dictionary, otherwise an array with one reward per dictionary
        """
        if isinstance(info, dict):
            return float(np.all(info["state"] == info["goal"]))
        return np.array([float(np.all(x["state"] == x["goal"])) for x in info])

    def _sample_images(self, digits):
        """Return the concatenation of one randomly chosen, flattened MNIST image per entry in `digits`."""
        size = self._IMG_SIZE
        flat = np.zeros((size*self._ndim,))
        for i, digit in enumerate(digits):
            idx = np.random.choice(self._labels[digit])
            flat[i*size:(i+1)*size] = self._data[idx].float().numpy() / 255
        return flat

    def _compute_obs(self, state):
        """
        Build the observation for `state`.
        Images are re-sampled on every call, so the same state (and the same goal) is generally shown with
        different images from step to step.
        :param state: one digit per dimension
        :return: the flat image vector, or the goal-conditioned dictionary if `goal_conditioned` is set
        """
        img = self._sample_images(state)
        if not self._gc:
            return img

        goal_img = self._sample_images(self._goal)
        return {
            "observation": img,
            "achieved_goal": img.copy(),
            "desired_goal": goal_img,
        }

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode from a uniformly random state; in goal-conditioned mode the goal is re-sampled too.
        NOTE: `seed` is only forwarded to the parent `reset`. The sampling in this class draws from the global
        `np.random` generator, so that one has to be seeded as well for reproducible episodes.
        :param seed: forwarded to the parent class
        :param options: unused
        :return: the initial observation and the info dictionary
        """
        super().reset(seed=seed)
        self._state = np.zeros(self._ndim, dtype=int)
        self._iter = 0
        for i in range(self._ndim):
            self._state[i] = np.random.randint(0, self._dimensions[i])
        if self._gc:
            for i in range(self._ndim):
                self._goal[i] = np.random.randint(0, self._dimensions[i])

        self._observation = self._compute_obs(self._state)

        if self.render_mode == "human":
            self._render_frame()

        return self._observation, self.info

    def step(self, action):
        """
        Apply `action` and return `(observation, reward, terminated, truncated, info)`.
        A move that would leave the grid is a no-op. Any other move is replaced, with probability `eps`, by a
        teleport to a uniformly random state.
        :param action: `2*i` decrements dimension `i`, `2*i + 1` increments it
        """
        assert self.action_space.contains(action)

        action_dim = action // 2
        action_dir = action % 2

        at_lower = self._state[action_dim] == 0
        at_upper = self._state[action_dim] == (self._dimensions[action_dim] - 1)
        blocked = (at_lower and action_dir == 0) or (at_upper and action_dir == 1)

        if not blocked:
            if np.random.rand() < self._eps:
                self._state = np.random.randint(0, self._dimensions, size=self._ndim)
            elif action_dir == 0:
                # decrement
                self._state[action_dim] -= 1
            else:
                # increment
                self._state[action_dim] += 1

        self._iter += 1
        self._observation = self._compute_obs(self._state)

        if self.render_mode == "human":
            self._render_frame()

        return self._observation, self.reward, self.terminated, self.truncated, self.info

    def render(self):
        """Return the current frame as an image array in `rgb_array` mode; return `None` otherwise."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _tile_digits(self, flat):
        """Arrange a flat `784 * ndim` image vector as one `28 x (28 * ndim)` strip of digits, side by side."""
        digits = flat.reshape(self._ndim, 28, 28)
        return np.transpose(digits, (1, 0, 2)).reshape(28, 28*self._ndim)

    def _render_frame(self):
        """
        Draw the current observation (followed by the goal in goal-conditioned mode) on the current pyplot figure.
        :return: `None` in `human` mode; the rendered figure as a `uint8` RGBA array otherwise
        """
        if not self._viewer and self.render_mode == "human":
            import matplotlib
            matplotlib.use('TkAgg')  # interactive mode

        plt.clf()

        if self._gc:
            img = np.concatenate([
                self._tile_digits(self._observation["observation"]),
                self._tile_digits(self._observation["desired_goal"]),
            ], axis=1)
        else:
            img = self._tile_digits(self._observation)

        # hand the float array to matplotlib directly: torchvision's ToPILImage rejects float64 ndarrays
        plt.imshow(img)
        plt.axis('off')
        if self._gc:
            plt.title(f"State: {self._state}, Goal: {self._goal}")
        else:
            plt.title(f"State: {self._state}")

        if self.render_mode == "human":
            plt.draw()
            plt.pause(1 / self.metadata["render_fps"])
            return None

        # rgb_array: force a synchronous draw so that the pixel buffer is up to date before it is read
        canvas = plt.gcf().canvas
        canvas.draw()
        return np.array(canvas.buffer_rgba(), dtype=np.uint8)


if __name__ == '__main__':

    # Demo: walk a 3x5 grid with sampled actions for 1000 steps, rendering every frame in a window.
    # Other grids to try: dimensions=(4,) or dimensions=(6, 2, 3, 1).
    env = MNISTHyperGrid(dimensions=(3, 5), render_mode="human")
    env.reset()

    for _ in range(1000):
        _, _, terminated, truncated, _ = env.step(env.sample_action())

        # With render_mode="rgb_array", `env.render()` returns the frame here (e.g. to save it as an image).

        # start a new episode as soon as the goal is reached or the step limit is hit
        if terminated or truncated:
            env.reset()

    env.close()
