import os

from collections import OrderedDict
from numbers import Number

import numpy as np
from gym.spaces import Box

# Directory named "assets" located alongside this module's source file.
ENV_ASSET_DIR = os.path.join(
    os.path.dirname(__file__),
    'assets',
)

class ImageandProprio(Box):
    """Flat ``Box`` space that packs an image and a proprio array together.

    A flat observation is the raveled image followed by the raveled proprio
    array along the last axis. ``to_flat`` / ``from_flat`` convert between
    the packed and unpacked forms; any leading (e.g. batch) axes in front of
    ``image_shape`` / ``proprio_shape`` are preserved.

    Bounds are fixed to ``[0, 1]`` for every flat entry.
    """

    def __init__(self, image_shape, proprio_shape):
        """
        :param image_shape: shape of a single, unbatched image.
        :param proprio_shape: shape of a single, unbatched proprio array.
        """
        self.image_shape = image_shape
        self.proprio_shape = proprio_shape
        # Pack a zero-filled example to discover the flat observation size.
        example = self.to_flat(np.zeros(self.image_shape), np.zeros(self.proprio_shape))
        super(ImageandProprio, self).__init__(0, 1, shape=example.shape)

    def to_flat(self, image, proprio):
        """Pack ``image`` and ``proprio`` into one array along the last axis.

        :param image: array of shape ``(*lead, *image_shape)``.
        :param proprio: array of shape ``(*lead, *proprio_shape)``.
        :return: array of shape ``(*lead, image_size + proprio_size)``.
        """
        # Slice with ``ndim - len(shape)`` rather than ``-len(shape)``: for a
        # zero-length shape, ``[:-0]`` equals ``[:0]`` and would drop every
        # leading axis instead of keeping them all.
        image_lead = image.shape[:len(image.shape) - len(self.image_shape)]
        proprio_lead = proprio.shape[:len(proprio.shape) - len(self.proprio_shape)]
        image = image.reshape(*image_lead, -1)
        proprio = proprio.reshape(*proprio_lead, -1)
        return np.concatenate([image, proprio], axis=-1)

    def from_flat(self, s):
        """Split a packed array back into ``(image, proprio)``.

        :param s: array of shape ``(*lead, image_size + proprio_size)``.
        :return: tuple ``(image, proprio)`` with shapes
            ``(*lead, *image_shape)`` and ``(*lead, *proprio_shape)``.
        """
        # ``np.prod(())`` is the float 1.0, which is not a valid slice bound.
        image_size = int(np.prod(self.image_shape))
        image = s[..., :image_size]
        image = image.reshape(*image.shape[:-1], *self.image_shape)
        proprio = s[..., image_size:]
        proprio = proprio.reshape(*proprio.shape[:-1], *self.proprio_shape)
        return image, proprio


def create_stats_ordered_dict(
        name,
        data,
        stat_prefix=None,
        always_show_all_stats=True,
        exclude_max_min=False,
):
    """Summarize ``data`` as an OrderedDict of named statistics.

    :param name: base name of the statistic.
    :param data: a number, tuple, list, or numpy array.
        - A number is returned unchanged under ``name``.
        - A tuple is summarized element by element under ``"{name}_{i}"``.
        - A list whose first element is iterable is joined with
          ``np.concatenate`` before summarizing.
        - Otherwise ``"{name} Mean"``, ``"{name} Std"`` and, unless
          ``exclude_max_min``, ``"{name} Max"`` / ``"{name} Min"`` are reported.
    :param stat_prefix: if not None, prepended to ``name`` separated by a space.
    :param always_show_all_stats: if False, a size-1 array is reported as a
        single float under ``name`` instead of Mean/Std/Max/Min.
    :param exclude_max_min: if True, omit the Max and Min entries.
    :return: OrderedDict of statistics; empty if ``data`` is empty.
    """
    if stat_prefix is not None:
        name = "{} {}".format(stat_prefix, name)
    if isinstance(data, Number):
        return OrderedDict({name: data})

    # ``len`` raises on 0-d arrays and misses zero-size arrays such as shape
    # (3, 0), so use ``size`` for ndarrays.
    if isinstance(data, np.ndarray):
        is_empty = data.size == 0
    else:
        is_empty = len(data) == 0
    if is_empty:
        return OrderedDict()

    if isinstance(data, tuple):
        ordered_dict = OrderedDict()
        for number, d in enumerate(data):
            # ``name`` already carries the prefix, so stat_prefix is not passed
            # again; the formatting flags are forwarded so every element is
            # summarized the same way as the caller requested.
            sub_dict = create_stats_ordered_dict(
                "{0}_{1}".format(name, number),
                d,
                always_show_all_stats=always_show_all_stats,
                exclude_max_min=exclude_max_min,
            )
            ordered_dict.update(sub_dict)
        return ordered_dict

    if isinstance(data, list):
        try:
            iter(data[0])
        except TypeError:
            # List of scalars: summarize it directly.
            pass
        else:
            data = np.concatenate(data)
            # e.g. a list of empty arrays: nothing to summarize.
            if data.size == 0:
                return OrderedDict()

    if (isinstance(data, np.ndarray) and data.size == 1
            and not always_show_all_stats):
        # ``.item()`` avoids NumPy's deprecated ndim>0 array-to-scalar conversion.
        return OrderedDict({name: float(data.item())})

    stats = OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
    ])
    if not exclude_max_min:
        stats[name + ' Max'] = np.max(data)
        stats[name + ' Min'] = np.min(data)
    return stats


def _stack_path_values(paths, key):
    """Join ``path[key]`` across ``paths`` into one array for statistics.

    1-D per-path values are joined end to end with ``np.hstack`` (so paths
    may have different lengths); higher-rank values are stacked along their
    first axis with ``np.vstack``.
    """
    values = [path[key] for path in paths]
    if np.ndim(values[0]) == 1:
        return np.hstack(values)
    return np.vstack(values)


def get_generic_path_information(paths, stat_prefix=''):
    """
    Get an OrderedDict with a bunch of statistic names and values.

    Reports statistics of per-step rewards ("Rewards"), per-path summed
    rewards ("Returns") and actions ("Actions"), plus "Num Paths".

    :param paths: list of path dicts, each with "rewards" and "actions".
    :param stat_prefix: forwarded to ``create_stats_ordered_dict`` for every
        statistic except "Num Paths".
    :return: OrderedDict mapping statistic name to value. For an empty
        ``paths`` only ``{"Num Paths": 0}`` is returned.
    """
    statistics = OrderedDict()
    if len(paths) == 0:
        # Nothing to summarize; stacking an empty list would raise.
        statistics['Num Paths'] = 0
        return statistics

    returns = [sum(path["rewards"]) for path in paths]
    # Same stacking rule as actions, so 1-D rewards of unequal path lengths
    # work (the statistics are identical to the vstack result when they do
    # have equal lengths).
    rewards = _stack_path_values(paths, "rewards")
    statistics.update(create_stats_ordered_dict('Rewards', rewards,
                                                stat_prefix=stat_prefix))
    statistics.update(create_stats_ordered_dict('Returns', returns,
                                                stat_prefix=stat_prefix))
    actions = _stack_path_values(paths, "actions")
    statistics.update(create_stats_ordered_dict(
        'Actions', actions, stat_prefix=stat_prefix
    ))
    statistics['Num Paths'] = len(paths)

    return statistics


def get_average_returns(paths):
    """Return the mean, across ``paths``, of each path's summed rewards."""
    per_path_totals = list(map(sum, (p["rewards"] for p in paths)))
    return np.mean(per_path_totals)


def get_path_lengths(paths):
    """Return a list holding ``len(path['observations'])`` for each path."""
    return list(map(len, (p['observations'] for p in paths)))


def get_stat_in_paths(paths, dict_name, scalar_name):
    """Collect ``scalar_name`` from the ``dict_name`` entry of every path.

    Two layouts of ``path[dict_name]`` are supported, detected from the
    first path only:

    * a dict (rllab interface): ``path[dict_name][scalar_name]`` is returned
      as-is for each path;
    * a sequence of per-step info dicts: ``info[scalar_name]`` is gathered
      into one list per path.

    :return: list with one entry per path, or ``np.array([[]])`` when
        ``paths`` is empty.
    """
    if len(paths) == 0:
        return np.array([[]])

    # isinstance (not ``type(...) == dict``) so dict subclasses such as
    # OrderedDict/defaultdict also take this branch instead of being
    # iterated key by key below.
    if isinstance(paths[0][dict_name], dict):
        # Support rllab interface
        return [path[dict_name][scalar_name] for path in paths]

    return [
        [info[scalar_name] for info in path[dict_name]]
        for path in paths
    ]


def get_asset_full_path(file_name):
    """Resolve ``file_name`` against the ``ENV_ASSET_DIR`` directory."""
    full_path = os.path.join(ENV_ASSET_DIR, file_name)
    return full_path

def concatenate_box_spaces(*spaces, dtype=np.float32):
    """
    Concatenate the bounds of several ``Box`` spaces into one ``Box``.

    The ``low`` and ``high`` arrays of ``spaces`` are joined along their
    first axis, in the order given. Assumes the spaces are of the same dtype
    and that all axes except the first match.

    :param spaces: one or more ``Box`` spaces.
    :param dtype: dtype of the resulting space (float32 by default).
    :return: a new ``Box`` with the concatenated bounds.
    :raises ValueError: if no spaces are given.
    """
    if not spaces:
        raise ValueError("concatenate_box_spaces needs at least one space")
    low = np.concatenate([space.low for space in spaces])
    high = np.concatenate([space.high for space in spaces])
    return Box(low=low, high=high, dtype=dtype)
