from typing import Callable, Optional, List

import gym
import tree
from gym import Env, spaces
import numpy as np
import popgym
from gym.core import ObsType
from popgym.core.popgym_env import POPGymEnv
from popgym import Observability
import tensorflow as tf
from belief_learner.utils.costs import get_cost_fn


def losses(space) -> Callable:
    """Return a loss function matching the structure of a gym space.

    Args:
        space: the gym space whose (encoded) values the loss compares.

    Returns:
        A two-argument callable:

        * ``Box`` (1-D only): the ``'l22'`` cost obtained from ``get_cost_fn``.
        * ``Discrete``: ``tf.nn.softmax_cross_entropy_with_logits``.
        * ``MultiDiscrete``: both arguments are split along the last axis into
          chunks of sizes ``space.nvec`` and the softmax cross-entropies of
          corresponding chunks are summed.
        * ``MultiBinary``: ``tf.nn.sigmoid_cross_entropy_with_logits``.

    Raises:
        AssertionError: for a ``Box`` space that is not 1-D.
        NotImplementedError: for ``Tuple`` and ``Dict`` spaces.
        ValueError: for any other space type.
    """
    # NOTE(review): the TF losses receive their two arguments positionally; in
    # TF2 the first parameter is ``labels`` and the second ``logits`` -- confirm
    # callers pass targets first.
    if isinstance(space, gym.spaces.Box):
        assert len(space.shape) == 1
        return get_cost_fn('l22')
    elif isinstance(space, gym.spaces.Discrete):
        return tf.nn.softmax_cross_entropy_with_logits
    elif isinstance(space, gym.spaces.MultiDiscrete):
        # assumes a 1-D nvec whose sum equals the last-axis size of both inputs
        def multi_discrete_loss(x, y):
            return sum(tf.nn.softmax_cross_entropy_with_logits(x_, y_)
                       for x_, y_ in zip(tf.split(x, space.nvec, axis=-1),
                                         tf.split(y, space.nvec, axis=-1)))

        return multi_discrete_loss
    elif isinstance(space, gym.spaces.MultiBinary):
        return tf.nn.sigmoid_cross_entropy_with_logits
    elif isinstance(space, (gym.spaces.Tuple, gym.spaces.Dict)):
        # Composite spaces are not supported yet. (Previously the Tuple branch
        # referenced an undefined helper and raised NameError before reaching this.)
        raise NotImplementedError
    raise ValueError(f"Unsupported space type: {type(space).__name__}")


class POPGymWrapper(gym.Wrapper, POPGymEnv):
    """``gym.Wrapper`` that also exposes the POPGym state interface.

    Besides the usual wrapper behaviour, it forwards ``state_space`` and
    ``get_state`` to the wrapped environment. ``state_space`` may be overridden
    by assignment; until then the wrapped environment's space is returned.
    """

    def __init__(self, env: Env):
        super().__init__(env)
        # The wrapped env must implement the POPGym interface. Wrappers of this
        # class qualify too, since they subclass POPGymEnv.
        assert isinstance(env, POPGymEnv)
        # None means "not overridden": defer to self.env.state_space.
        self._state_space: Optional[spaces.Space] = None

    @property
    def state_space(self) -> spaces.Space:
        """Return the state space exposed by this wrapper.

        This is the space assigned through the setter, or the wrapped
        environment's ``state_space`` when none has been assigned.
        """
        if self._state_space is None:
            return self.env.state_space
        return self._state_space

    @state_space.setter
    def state_space(self, space: spaces.Space):
        """Override the exposed state space (assigning ``None`` restores forwarding)."""
        self._state_space = space

    def get_state(self):
        """Return the wrapped environment's current state, unmodified."""
        return self.env.get_state()


class POPGymObservationWrapper(POPGymWrapper):
    """Base class for wrappers that transform observations and states alike.

    Subclasses implement two static hooks:

    * ``map_space`` returns the transformed version of a space.
    * ``get_obs_mapper`` returns a function mapping values of the original
      space to values of the transformed one.

    Both hooks are applied to the observation space and to the state space.
    """

    def __init__(self, env: Env):
        super().__init__(env)
        assert isinstance(env.unwrapped, POPGymEnv)
        inner = self.env
        # Exposed spaces are transformed; the value mappers are built from the
        # wrapped env's original, untransformed spaces.
        self.observation_space = self.map_space(self.observation_space)
        self.preprocess_fn = self.get_obs_mapper(inner.observation_space)
        self.state_space = self.map_space(self.state_space)
        self.state_preprocess_fn = self.get_obs_mapper(inner.state_space)

    def observation(self, observation):
        """Map a raw observation of the wrapped env into this wrapper's space."""
        mapper = self.preprocess_fn
        return mapper(observation)

    def get_state(self):
        """Return the wrapped env's state, mapped into this wrapper's state space."""
        raw_state = super().get_state()
        return self.state_preprocess_fn(raw_state)

    def reset(self, **kwargs):
        """Reset the wrapped env and map the initial observation.

        With ``return_info=True`` the wrapped env's ``(obs, info)`` pair is
        returned with ``obs`` mapped; otherwise only the mapped observation.
        """
        result = self.env.reset(**kwargs)
        if not kwargs.get("return_info", False):
            return self.observation(result)
        obs, info = result
        return self.observation(obs), info

    def step(self, action):
        """Step the wrapped env and map the resulting observation."""
        obs, rew, done, info = self.env.step(action)
        return self.observation(obs), rew, done, info

    @staticmethod
    def map_space(space: gym.Space) -> gym.Space:
        """Return the transformed counterpart of ``space`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def get_obs_mapper(space: gym.Space) -> Callable[[ObsType], ObsType]:
        """Return a function mapping values of ``space`` (subclass hook)."""
        raise NotImplementedError


class FlattenObservationSpaceWrapper(POPGymObservationWrapper):
    """Flatten observations and states without changing their encoding.

    * Box, MultiDiscrete and MultiBinary components are reshaped to 1-D.
    * Discrete components pass through unchanged.
    * Tuple/Dict spaces are flattened with ``tree.flatten`` into a flat Tuple
      of such components.
    """

    @staticmethod
    def map_space(space: gym.Space) -> gym.Space:
        """Return the flattened counterpart of ``space``.

        Raises:
            ValueError: if ``space`` is not a Box, Discrete, MultiDiscrete,
                MultiBinary, Tuple or Dict space.
        """
        if isinstance(space, gym.spaces.Box):
            # The dtype is kept because the observation mapper only reshapes
            # values. int(...) keeps the dimension integral even for a 0-d Box,
            # where np.prod(()) is the float 1.0.
            space = gym.spaces.Box(space.low.flatten(), space.high.flatten(),
                                   (int(np.prod(space.shape)),), dtype=space.dtype)
        elif isinstance(space, gym.spaces.Discrete):
            pass
        elif isinstance(space, gym.spaces.MultiDiscrete):
            space = gym.spaces.MultiDiscrete(nvec=space.nvec.flatten(), dtype=space.dtype)
        elif isinstance(space, gym.spaces.MultiBinary):
            space = gym.spaces.MultiBinary(n=int(np.prod(space.shape)))
        elif isinstance(space, (gym.spaces.Tuple, gym.spaces.Dict)):
            space = gym.spaces.Tuple(tuple(map(FlattenObservationSpaceWrapper.map_space, tree.flatten(space))))
        else:
            raise ValueError
        return space

    @staticmethod
    def get_obs_mapper(space: gym.Space):
        """Return a function mapping values of ``space`` into ``map_space(space)``.

        Raises:
            ValueError: for unsupported space types, matching ``map_space``.
        """
        if isinstance(space, (gym.spaces.Box, gym.spaces.MultiBinary, gym.spaces.MultiDiscrete)):
            def flatten(obs: ObsType):
                return obs.flatten()

            return flatten
        elif isinstance(space, gym.spaces.Discrete):
            def identity(obs: ObsType):
                return obs

            return identity
        elif isinstance(space, (gym.spaces.Tuple, gym.spaces.Dict)):
            sub_fun = list(map(FlattenObservationSpaceWrapper.get_obs_mapper, tree.flatten(space)))

            def flatten(obs: ObsType):
                # assumes tree.flatten orders an observation's leaves like the
                # space's leaves -- TODO confirm for Dict spaces
                return tuple(f(o) for f, o in zip(sub_fun, tree.flatten(obs)))

            return flatten
        # Previously fell through and returned None, failing later with an
        # opaque "NoneType is not callable".
        raise ValueError


class BoxifyWrapper(POPGymObservationWrapper):
    """Convert observations and states into flat ``float32`` vectors.

    * Box and MultiBinary components are flattened.
    * Discrete and MultiDiscrete components are one-hot encoded.
    * Tuple spaces are converted component-wise and concatenated along the
      last axis.

    Attributes set in ``__init__``:
        pre_boxify_obs_space / pre_boxify_state_space: the wrapped env's
            original observation / state spaces.
        pre_boxify_obs_space_shapes / pre_boxify_state_space_shapes: width of
            each top-level component inside the boxified vector, in
            concatenation order (a single entry for non-Tuple spaces).
    """

    def __init__(self, env: Env):
        super().__init__(env)
        self.pre_boxify_obs_space = self.env.observation_space
        self.pre_boxify_state_space = self.env.state_space
        self.pre_boxify_obs_space_shapes = BoxifyWrapper._component_sizes(self.pre_boxify_obs_space)
        self.pre_boxify_state_space_shapes = BoxifyWrapper._component_sizes(self.pre_boxify_state_space)

    @staticmethod
    def _component_sizes(space: gym.Space) -> List[int]:
        """Return the boxified width of each top-level component of ``space``."""
        if isinstance(space, gym.spaces.Tuple):
            return [BoxifyWrapper.space_box_shape(sub) for sub in space.spaces]
        return [BoxifyWrapper.space_box_shape(space)]

    @staticmethod
    def space_box_shape(space: gym.Space) -> int:
        """Return the length of the flat vector that ``space`` is boxified into.

        Raises:
            ValueError: for unsupported space types.
        """
        if isinstance(space, (gym.spaces.Box, gym.spaces.MultiBinary)):
            # int(...) also handles 0-d spaces, where np.prod(()) is the float 1.0.
            return int(np.prod(space.shape))
        if isinstance(space, gym.spaces.Discrete):
            return int(space.n)
        if isinstance(space, gym.spaces.MultiDiscrete):
            # One one-hot block per sub-variable. np.sum (unlike builtin sum)
            # also totals multi-dimensional nvec arrays and yields a scalar.
            return int(np.sum(space.nvec))
        if isinstance(space, gym.spaces.Tuple):
            # Nested Tuples are supported, consistently with map_space.
            return sum(BoxifyWrapper.space_box_shape(sub) for sub in space.spaces)
        raise ValueError

    @staticmethod
    def map_space(space: gym.Space) -> gym.Space:
        """Return the flat ``float32`` Box that ``space`` is boxified into.

        Raises:
            ValueError: for unsupported space types.
        """
        if isinstance(space, gym.spaces.Box):
            return gym.spaces.Box(space.low.flatten(), space.high.flatten(),
                                  (BoxifyWrapper.space_box_shape(space),), dtype=np.float32)
        if isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
            # One-hot / binary encodings live in [0, 1].
            return gym.spaces.Box(0, 1, (BoxifyWrapper.space_box_shape(space),), dtype=np.float32)
        if isinstance(space, gym.spaces.Tuple):
            sub_boxes = [BoxifyWrapper.map_space(sub) for sub in space.spaces]
            dim = sum(box.shape[0] for box in sub_boxes)
            low = np.concatenate([box.low for box in sub_boxes], 0)
            high = np.concatenate([box.high for box in sub_boxes], 0)
            return gym.spaces.Box(low, high, (dim,), dtype=np.float32)
        raise ValueError

    @staticmethod
    def get_obs_mapper(space: gym.Space):
        """Return a function converting a value of ``space`` into its boxified vector.

        Raises:
            ValueError: for unsupported space types.
        """

        def flatten(obs: ObsType):
            # astype copies, so the result never aliases the input.
            return obs.flatten().astype(np.float32)

        def get_one_hot(n, start=0):
            table = np.eye(n, dtype=np.float32)

            def one_hot(i):
                # Shift by the space's start value, so a negative index cannot
                # silently wrap. Copy the row: basic indexing returns a view,
                # and callers mutating it would corrupt the shared table.
                return table[i - start].copy()

            return one_hot

        def multi_discrete_one_hot(nvec):
            one_hot_fns = [get_one_hot(n) for n in nvec]

            def multi_one_hot(values):
                # Flattened in the same (C) order as nvec.
                values = np.asarray(values).reshape(-1)
                return np.concatenate([fn(v) for fn, v in zip(one_hot_fns, values)], axis=0)

            return multi_one_hot

        def concat_tuple(tuple_space):
            fns = tuple(map(BoxifyWrapper.get_obs_mapper, tuple_space.spaces))

            def apply_and_concat(obs_tuple):
                return np.concatenate([fn(obs) for fn, obs in zip(fns, obs_tuple)], axis=0)

            return apply_and_concat

        if isinstance(space, (gym.spaces.Box, gym.spaces.MultiBinary)):
            return flatten
        if isinstance(space, gym.spaces.Discrete):
            # Older gym versions have no Discrete.start; treat it as 0.
            return get_one_hot(int(space.n), getattr(space, "start", 0))
        if isinstance(space, gym.spaces.MultiDiscrete):
            return multi_discrete_one_hot(np.asarray(space.nvec).reshape(-1))
        if isinstance(space, gym.spaces.Tuple):
            return concat_tuple(space)
        raise ValueError


def split_boxify_np(array: np.ndarray, pre_boxify_obs_space_shapes: List[int]):
    """Split a boxified array along its last axis into per-component arrays.

    ``pre_boxify_obs_space_shapes`` gives the width of each component in
    concatenation order. The cut points are the running totals of every
    width except the last.
    """
    cut_points = np.cumsum(pre_boxify_obs_space_shapes[:-1])
    return np.split(array, cut_points, axis=-1)


def split_boxify_tf(tensor: tf.Tensor, pre_boxify_obs_space_shapes: List[int]):
    """Split a boxified tensor along its last axis into per-component tensors.

    ``pre_boxify_obs_space_shapes`` gives the width of each component, in the
    order the components were concatenated.
    """
    pieces = tf.split(value=tensor, num_or_size_splits=pre_boxify_obs_space_shapes, axis=-1)
    return pieces
