from typing import Tuple, List

import gym
import numpy as np
from torchvision import datasets

from util import check_valid_split, get_split_shape, crop, crop_batch

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


class PartialMNISTEnv(gym.Env):
    """
    Batched multi-agent MNIST classification environment with partial observability.

    Every image is cut into ``(x_splits + 1) * (y_splits + 1)`` rectangular views, one per agent.
    ``num_envs`` images are served in parallel; all of them share one step counter and one ``done`` flag.
    The observation does not change during an episode and the reward is only non-zero on the final step.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self, x_splits=0, y_splits=0, train=True, data_root="data", normalize_obs=True, max_steps=2,
                 num_envs=1, sample_all_sequentially=False):
        """
        A MNIST classification environment with multiple agents and partial observability.

        :param x_splits: The number of splits along the x axis
        :param y_splits: The number of splits along the y axis
        :param train: Whether to use train or test data
        :param data_root: The root directory where the data is/will be stored
        :param normalize_obs: Whether to normalize observations
        :param max_steps: The maximum number of steps before the episode ends
                          (reward is returned at the end of the episode)
        :param num_envs: The number of parallel environments
        :param sample_all_sequentially: Whether to sample sequentially from the given data (warning: will result in
                                        dynamic batch sizes, reset returns an empty observation when done)
        """

        self.x_splits = x_splits
        self.y_splits = y_splits
        self.max_steps = max_steps
        self.data_root = data_root
        self.normalize_obs = normalize_obs

        self.train = train
        self.dataset = datasets.MNIST(data_root, train=train, download=True)
        self.data = self.dataset.data.numpy()
        self.image_shape = self.data.shape[1:]
        self.targets = self.dataset.targets.numpy()

        self.num_envs = num_envs
        self.sample_all_sequentially = sample_all_sequentially

        if normalize_obs:
            # both statistics are computed on the raw data before self.data is reassigned
            self.data = (self.data - self.get_data_mean()) / self.get_data_std()

        # image_shape is indexed as (rows, columns), i.e. (y extent, x extent)
        check_valid_split(x_splits, y_splits, self.image_shape[1], self.image_shape[0])
        self.split_shape = get_split_shape(x_splits, y_splits, self.image_shape[1], self.image_shape[0])

        self.img_channels = 1
        self.num_agents = (x_splits + 1) * (y_splits + 1)
        self.num_actions = 10

        self.single_obs_shape = (self.img_channels, *self.split_shape)
        # NOTE(review): all_obs uses numpy's default dtype while observation_space below is declared with
        # self.data.dtype; the two may differ when normalize_obs is False - confirm which one callers rely on.
        self.all_obs = np.empty((len(self.data), self.num_agents, *self.single_obs_shape))

        # crop all observations once to make the resets faster
        # (the rectangles come from the same helper that is exposed to callers, so both always agree)
        self.agents_view_rect = self.get_view_rects(self.image_shape, x_splits, y_splits)
        for agent_idx, (x, y, width, height) in enumerate(self.agents_view_rect):
            cropped = crop_batch(self.data, x, y, width, height)
            # insert the channel axis expected by single_obs_shape
            self.all_obs[:, agent_idx] = np.expand_dims(cropped, 1)

        self.indices = None
        self.reset_index()
        self.step_count = 0
        self.target = None
        self.done = False
        self.obs = None

        obs_shape = (self.num_agents, self.num_envs, *self.single_obs_shape)
        self.observation_space = gym.spaces.Box(
            self.data.min(), self.data.max(), shape=obs_shape, dtype=self.data.dtype
        )

        # one discrete choice per environment for each agent, nested as (num_agents, num_envs)
        agent_action_space = gym.spaces.Tuple(
            [gym.spaces.Discrete(self.num_actions) for _ in range(self.num_envs)]
        )
        self.action_space = gym.spaces.Tuple([agent_action_space for _ in range(self.num_agents)])

        self.fig = None

    def reset_index(self):
        """
        Forget the last sampled indices so that sequential sampling starts again at index 0.
        """
        self.indices = None

    def eval(self, test=True):
        """
        Create a copy of this environment with sample_all_sequentially = True and the given value for test.

        :param test: Whether to use test data
        :return: A new PartialMNISTEnv instance
        """

        return PartialMNISTEnv(x_splits=self.x_splits, y_splits=self.y_splits, train=(not test), data_root=self.data_root,
                               normalize_obs=self.normalize_obs, max_steps=self.max_steps, num_envs=self.num_envs,
                               sample_all_sequentially=True)

    def get_data_mean(self):
        """
        Get the mean over all samples of the per-sample mean (taken over axes 1 and 2 of the data).

        :return: the averaged mean as a scalar
        """
        return np.mean(self.data, axis=(1, 2)).mean()

    def get_data_std(self):
        """
        Get the mean over all samples of the per-sample standard deviation (taken over axes 1 and 2 of the data).

        Note that this is the average of the per-sample values, not the standard deviation of the whole array.

        :return: the averaged standard deviation as a scalar
        """
        return np.std(self.data, axis=(1, 2)).mean()

    def get_next_sample_index(self) -> int:
        """
        Get the next sample index for sequential sampling.

        :return: the next index; 0 if nothing has been sampled since the last reset_index(), and len(self.data)
                 if the previous batch was empty (so further resets keep yielding empty batches)
        """
        if self.indices is None:
            # start at zero
            return 0

        if len(self.indices) == 0:
            # the previous batch was already empty: stay exhausted instead of failing on indices[-1]
            return len(self.data)

        return int(self.indices[-1]) + 1

    def reset(self):
        """
        Start a new episode by selecting a new batch of samples.

        With sample_all_sequentially the batch continues after the previously used indices and is cut off at the
        end of the data, so it can hold fewer than num_envs samples (or none at all). Otherwise num_envs indices
        are drawn with np.random.randint.

        :return: the observations, indexed as (num_agents, batch, *single_obs_shape)
        """
        if self.sample_all_sequentially:
            next_index = self.get_next_sample_index()
            # create a sequence of indices with max self.num_envs elements
            self.indices = np.arange(next_index, min(len(self.data), next_index + self.num_envs))
        else:
            # choose random indices
            self.indices = np.random.randint(0, len(self.data), size=self.num_envs)

        # create observation with shape (num_agents, num_envs, *self.single_obs_shape)
        self.obs = self.all_obs[self.indices].swapaxes(0, 1)
        self.target = self.targets[self.indices]

        self.step_count = 0
        self.done = False

        return self.obs

    @staticmethod
    def reward_to_accuracy(reward):
        """
        Convert a reward on the -1 (wrong) / +1 (correct) scale used by step() to the 0..1 scale.

        :param reward: a reward or an average of rewards
        :return: (reward + 1) / 2
        """
        return (reward + 1) / 2

    @staticmethod
    def get_view_rects(image_shape, x_splits, y_splits):
        """
        Compute the view rectangle of every agent.

        :param image_shape: The image shape, indexed as (size along y, size along x)
        :param x_splits: The number of splits along the x axis
        :param y_splits: The number of splits along the y axis
        :return: A list of (x, y, width, height) tuples, ordered row by row (x varies fastest)
        """
        split_shape = get_split_shape(x_splits, y_splits, image_shape[1], image_shape[0])
        view_rects = []
        for y in range(0, image_shape[0], split_shape[0]):
            for x in range(0, image_shape[1], split_shape[1]):
                view_rects.append((x, y, split_shape[1], split_shape[0]))

        return view_rects

    def step(self, actions) -> Tuple[np.ndarray, np.ndarray, bool, dict]:
        """
        Execute a step in all environments.

        The observation is the one produced by the last reset(). Rewards are zero until the step counter
        reaches max_steps; on that step each action is rewarded with +1 if it equals the target and -1 otherwise.

        :param actions: The actions of all agents for each environment. Expected shape: (num_agents, num_envs)
        :return: observations, rewards, done (single value for all envs), info (contains the targets)
        """
        # target is None until reset() has been called and empty once sequential sampling is exhausted
        assert self.target is not None and len(self.target) > 0, "Cannot execute step on empty target"

        self.step_count += 1

        if self.step_count >= self.max_steps:
            self.done = True

        if self.done:
            # map correct/incorrect (True/False) to +1/-1
            rewards = (np.array(actions) == self.target) * 2 - 1
        else:
            rewards = np.zeros_like(actions)

        return self.obs, rewards, self.done, dict(target=self.target)

    def render(self, mode='human', env_index=0, interactive=True):
        """
        Plot the full image of one environment next to the views of all agents.

        :param mode: The render mode (not used by this implementation)
        :param env_index: The index of the parallel environment to show
        :param interactive: If True plt.show() is called, otherwise plt.draw() followed by a short plt.pause()
        """
        if self.fig is not None:
            # hotfix kept from the original code: close and recreate the figure instead of calling fig.clear()
            plt.close(self.fig)
        self.fig = plt.figure(figsize=(10, 5))

        self.fig.suptitle(f"Environment index {env_index}/{self.num_envs - 1}")

        outer = gridspec.GridSpec(1, 2)

        # left half: the complete image
        inner = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[0])
        ax = plt.Subplot(self.fig, inner[0])
        ax.set_title(f"Image: Digit {self.target[env_index]}")
        ax.imshow(self.data[self.indices[env_index]])
        self.fig.add_subplot(ax)

        # right half: one cell per agent, laid out like the split grid
        inner = gridspec.GridSpecFromSubplotSpec(self.y_splits + 1, self.x_splits + 1, subplot_spec=outer[1])
        for a in range(self.num_agents):
            ax = plt.Subplot(self.fig, inner[a])
            ax.set_title(f"Obs {a}")
            ax.imshow(self.obs[a, env_index, 0])
            self.fig.add_subplot(ax)

        if interactive:
            plt.show()
        else:
            plt.draw()
            plt.pause(0.001)
