import itertools
import math

import numpy as np
import gymnasium
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple, Dict, MultiBinary
import gym
from gym.spaces import Box as gBox, Discrete as gDiscrete, MultiDiscrete as gMultiDiscrete, Tuple as gTuple, Dict as gDict, MultiBinary as gMultiBinary

from .task import Task, EnvType


def parse_gymnasium_environment(env: gymnasium.Env, env_type: EnvType, name: str = None) -> Task:
    '''
    Generates a regym.environments.Task by extracting information from the
    already built :param: env.

    This function makes the following Assumptions from :param: env:
        - Observation / Action space (its geometry, dimensionality) are identical for all agents

    :param env: Environment following Gymnasium interface
    :param env_type: Determines whether the parameter is (single/multi)-agent
                     and how the environment processes these actions
                     (i.e all actions simultaneously, or sequentially)
    :param name: Str defining the name to give to the task. When None, the
                 environment's ``spec.id`` is used, or, if the environment has
                 no spec (it was not created through ``make``), the class name
                 of the unwrapped environment.
    :returns: Task created from :param: env named :param: name
    '''
    if name is None:
        # env.spec is None for environments instantiated directly instead of
        # through gymnasium.make / gym.make, so `env.spec.id` would crash.
        spec = getattr(env, 'spec', None)
        name = spec.id if spec is not None else type(getattr(env, 'unwrapped', env)).__name__

    action_dims, action_type = get_action_dimensions_and_type(env)
    observation_shape, observation_type = get_observation_dimensions_and_type(env)
    # Optional, environment-specific extras; absent attributes map to None.
    state_space_size = getattr(env, 'state_space_size', None)
    action_space_size = getattr(env, 'action_space_size', None)
    hash_function = getattr(env, 'hash_state', None)
    goal_shape, goal_type = get_goal_dimensions_and_type(env)

    # TODO: find a better condition for check_env_compatibility_with_env_type;
    # the current one flags single-agent envs with Dict observation spaces as multiagent.
    #check_env_compatibility_with_env_type(env, env_type)

    # NOTE(review): 4th positional Task argument is passed as None; confirm its meaning in Task.
    return Task(name,
                env,
                env_type,
                None,
                state_space_size,
                action_space_size,
                observation_shape,
                observation_type,
                action_dims,
                action_type,
                hash_function,
                goal_shape,
                goal_type)


def parse_dimension_space(space, key="observation"):
    '''
    Extracts the dimensionality and a type tag from a (Gym or Gymnasium) space.

    :param space: Space to parse. Gym and Gymnasium variants of Discrete,
                  MultiDiscrete, MultiBinary, Box, Tuple and Dict are
                  supported, as well as any object exposing a ``size`` attribute.
    :param key: For Dict spaces, the entry to parse (e.g. "observation" or
                "desired_goal"). It is forwarded to the sub-spaces of a Tuple
                space, so per-agent Dict sub-spaces are looked up with it too.
    :returns: Tuple (dimensions, type_tag):
              - Discrete:      (space.n, 'Discrete')
              - MultiDiscrete: (space.nvec, 'MultiDiscrete')
              - MultiBinary:   (space.n, 'MultiBinary')
              - Box:           (space.shape, 'Continuous')
              - Tuple:         (sum of sub-space dimensions, type tag of the first sub-space)
              - Dict:          result of parsing ``space.spaces[key]``
              - objects with ``size``: (space.size, 'Discrete')
    :raises ValueError: If the space type is unsupported, a Dict space lacks
                        :param: key, or a Tuple space has no sub-spaces.
    '''
    if isinstance(space, (Discrete, gDiscrete)):
        return space.n, 'Discrete'
    if isinstance(space, (MultiDiscrete, gMultiDiscrete)):
        return space.nvec, 'MultiDiscrete'
    if isinstance(space, (MultiBinary, gMultiBinary)):
        return space.n, 'MultiBinary'
    if isinstance(space, (Box, gBox)):
        return space.shape, 'Continuous'
    if isinstance(space, (Tuple, gTuple)):
        if len(space.spaces) == 0:
            raise ValueError(f"Empty Tuple space: {space}")
        # Forward `key` so Dict sub-spaces (e.g. one per agent) are resolved
        # with the key the caller asked for, not always "observation".
        parsed = [parse_dimension_space(sub_space, key=key) for sub_space in space.spaces]
        # NOTE(review): summing assumes scalar dimensions (e.g. Discrete.n);
        # Box sub-spaces return shape tuples, for which this sum raises TypeError.
        # The type tag is taken from the first sub-space (assumes homogeneous sub-spaces).
        return sum(dims for dims, _ in parsed), parsed[0][1]
    if isinstance(space, (Dict, gDict)):
        if key not in space.spaces:
            raise ValueError(f"Wrongly formatted observation space: {space}")
        # The selected entry is parsed with the default key, as before.
        return parse_dimension_space(space.spaces[key])
    # Below space refers to OneHotEncoding space from 'https://github.com/Danielhp95/gym-rock-paper-scissors'
    if hasattr(space, 'size'):
        return space.size, 'Discrete'
    raise ValueError('Unknown observation space: {}'.format(space))


def get_observation_dimensions_and_type(env):
    '''
    Returns (dimensions, type_tag) describing :param: env.observation_space.

    Dict observation spaces are resolved through their "observation" entry;
    see parse_dimension_space for the returned values per space type.
    '''
    # ASSUMPTION: Multi agent environment. Symmetrical observation space --> Dict space
    obs_space = env.observation_space
    return parse_dimension_space(obs_space, key="observation")


def get_goal_dimensions_and_type(env):
    '''
    Returns (dimensions, type_tag) describing the goal part of :param: env.observation_space.

    Dict observation spaces are resolved through their "desired_goal" entry
    (a ValueError is raised by parse_dimension_space if that entry is missing).
    For non-Dict spaces the key is ignored, so the result describes the whole
    observation space.
    '''
    obs_space = env.observation_space
    return parse_dimension_space(obs_space, key="desired_goal")


def get_action_dimensions_and_type(env):
    '''
    Extracts the action dimensionality and a type tag from :param: env.action_space.

    Composite action spaces (anything exposing ``spaces``, e.g. Tuple or Dict)
    are treated as multi-agent: only the first sub-space is parsed, under the
    assumption that every agent shares the same action space.

    :param env: Environment following the Gym / Gymnasium interface.
    :returns: Tuple (action_dims, action_type):
              - Discrete:      (n, 'Discrete')
              - MultiDiscrete: (number of joint actions, 'Discrete'), i.e. the
                               space is treated as its flattened product
              - Box:           (shape[0], 'Continuous')
    :raises ValueError: If the (sub-)action space is of an unsupported type, or
                        a composite action space has no sub-spaces.
    '''
    # Named distinctly so it does not shadow the module-level parse_dimension_space.
    def _parse_action_space(space):
        if isinstance(space, (Discrete, gDiscrete)):
            return space.n, 'Discrete'
        if isinstance(space, (MultiDiscrete, gMultiDiscrete)):
            return compute_multidiscrete_space_size(space.nvec), 'Discrete'
        if isinstance(space, (Box, gBox)):
            return space.shape[0], 'Continuous'
        raise ValueError('Unknown action space: {}'.format(space))

    action_space = env.action_space
    if not hasattr(action_space, 'spaces'):
        return _parse_action_space(action_space)  # Single agent environment

    # Multi agent environment
    sub_spaces = action_space.spaces
    if len(sub_spaces) == 0:
        raise ValueError('Unknown action space: {}'.format(action_space))
    # Dict spaces keep their sub-spaces in a dict keyed by name, where
    # positional indexing with [0] would raise KeyError.
    first_space = next(iter(sub_spaces.values())) if isinstance(sub_spaces, dict) else sub_spaces[0]
    return _parse_action_space(first_space)


def compute_multidiscrete_space_size(flattened_multidiscrete_space):
    """
    Computes size of the combinatorial space generated by :param: flattened_multidiscrete_space

    The size is the product of the number of values per dimension. It is
    computed arithmetically instead of enumerating every joint action, which
    costs exponential time and memory.

    :param flattened_multidiscrete_space: ``nvec`` of a MultiDiscrete space
        (any iterable / array of per-dimension value counts; multi-dimensional
        arrays are flattened).
    :returns: Size (Python int) of the 'flattened' :param: flattened_multidiscrete_space.
              An empty nvec yields 1, and any non-positive count yields 0.
    """
    counts = np.asarray(flattened_multidiscrete_space).ravel()
    # range(k) is empty for k <= 0, so such a dimension admits no action at all;
    # clamping to 0 preserves that, and int() avoids fixed-width numpy overflow.
    return math.prod(max(int(count), 0) for count in counts)


def check_env_compatibility_with_env_type(env, env_type):
    '''
    Checks that :param: env_type agrees with the structure of :param: env's
    observation space.

    An observation space exposing ``spaces`` (e.g. Tuple or Dict) is taken to
    mean a multiagent environment. Note that this also flags single-agent
    environments with Dict observation spaces.

    :param env: Environment following the Gym / Gymnasium interface.
    :param env_type: EnvType the environment was declared with.
    :raises ValueError: If a multiagent-looking environment is declared
                        EnvType.SINGLE_AGENT, or a single-agent-looking
                        environment is declared with any other EnvType.
    '''
    looks_multiagent = hasattr(env.observation_space, 'spaces')
    declared_single_agent = env_type == EnvType.SINGLE_AGENT

    # env.spec is None for environments created without make(); fall back to
    # the class name so the intended ValueError is raised, not AttributeError.
    spec = getattr(env, 'spec', None)
    env_id = spec.id if spec is not None else type(getattr(env, 'unwrapped', env)).__name__

    # Environment is multiagent but it has been declared single agent
    if looks_multiagent and declared_single_agent:
        error_msg = (
            f"\nThe environment ({env_id}) appears to be multiagent (it has multiple observation spaces).\n"
            "But parameter 'env_type' was set to EnvType.SINGLE_AGENT (default value).\n"
            "Suggestion: Change to a multiagent EnvType.\n"
        )
        raise ValueError(error_msg)

    # Environment is single agent but it has been declared multiagent
    if not looks_multiagent and env_type != EnvType.SINGLE_AGENT:
        error_msg = (
            f"\nThe environment ({env_id}) appears to be single agent\n"
            f"But parameter 'env_type' was set to {env_type}\n"
            "Suggestion: Change to a EnvType.SINGLE_AGENT\n"
        )
        raise ValueError(error_msg)
