"""
This file implements a wrapper for visualizing important sites in a given environment.

By default, this visualizes all sites possible for the environment. Visualization options
for a given environment can be found by calling `get_visualization_settings()`, and can
be set individually by calling `set_visualization_setting(setting, visible)`.
"""
import numpy as np
from robosuite.wrappers import Wrapper
from robosuite.utils.mjcf_utils import new_site, new_geom, new_body
from copy import deepcopy

# Site attributes used whenever an indicator config is given as the string "default":
# a sphere of size 0.03 rendered in red at 50% opacity. The "name" attribute is
# intentionally absent here; it is filled in per-indicator by the wrapper.
DEFAULT_INDICATOR_SITE_CONFIG = dict(
    type="sphere",
    size=[0.03],
    rgba=[1, 0, 0, 0.5],
)


class VisualizationWrapper(Wrapper):
    """
    Environment wrapper that controls which visualizations are rendered and optionally adds
    movable indicator sites to the simulation model.

    Visualization settings are taken from the wrapped environment's `_visualizations` and are all enabled by
    default. Indicators are injected into the model through a model post-processor, and are re-applied to the
    environment after every `reset()` and `step()` call.
    """

    def __init__(self, env, indicator_configs=None):
        """
        Initializes the visualization wrapper. Note that this automatically conducts a (hard) reset initially to make
        sure indicators are properly added to the sim model.

        Args:
            env (MujocoEnv): The environment to visualize

            indicator_configs (None or str or dict or list): Configurations to use for indicator objects.

                If None, no indicator objects will be used

                If a string, this should be `'default'`, which corresponds to single default spherical indicator

                If a dict, should specify a single indicator object config

                If a list, should specify specific indicator object configs to use for multiple indicators (which in
                turn can either be `'default'` or a dict)

                As each indicator object is essentially a site element, each dict should map site attribute keywords to
                values. Note that, at the very minimum, the `'name'` attribute MUST be specified for each indicator. See
                http://www.mujoco.org/book/XMLreference.html#site for specific site attributes that can be specified.

        Raises:
            AssertionError: [If an indicator config is a string other than `'default'`, or lacks a `'name'` entry]
        """
        super().__init__(env)
        # Standardize indicator configs
        self.indicator_configs = None
        if indicator_configs is not None:
            self.indicator_configs = []
            # Use isinstance rather than an exact type match so that str / dict subclasses (e.g. OrderedDict)
            # are also treated as a single config instead of being iterated element-by-element
            if isinstance(indicator_configs, (str, dict)):
                indicator_configs = [indicator_configs]
            for i, indicator_config in enumerate(indicator_configs):
                if indicator_config == "default":
                    indicator_config = deepcopy(DEFAULT_INDICATOR_SITE_CONFIG)
                    indicator_config["name"] = f"indicator{i}"
                # Any remaining string is invalid. Reject it explicitly, since the `in` check below would
                # otherwise degrade into a substring test and let strings containing "name" slip through
                assert not isinstance(indicator_config, str), \
                    f"Indicator configs must be 'default' or a dict, got {indicator_config!r}"
                # Make sure name attribute is specified
                assert "name" in indicator_config, "Name must be specified for all indicator object configurations!"
                # Add this configuration to the internal array
                self.indicator_configs.append(indicator_config)

        # Create internal dict to store visualization settings (set to True by default)
        self._vis_settings = {vis: True for vis in self.env._visualizations}

        # Add the post-processor to make sure indicator objects get added to model before it's actually loaded in sim
        self.env.set_model_postprocessor(postprocessor=self._add_indicators_to_model)

        # Conduct a (hard) reset to make sure visualization changes propagate
        reset_mode = self.env.hard_reset
        self.env.hard_reset = True
        try:
            self.reset()
        finally:
            # Always restore the env's original reset mode, even if the reset above fails
            self.env.hard_reset = reset_mode

    def get_indicator_names(self):
        """
        Gets all indicator object names for this environment.

        Returns:
            list: Indicator names for this environment (empty if no indicators were configured).
        """
        if self.indicator_configs is None:
            return []
        return [ind_config["name"] for ind_config in self.indicator_configs]

    def set_indicator_pos(self, indicator, pos):
        """
        Sets the specified @indicator to the desired position @pos

        Args:
            indicator (str): Name of the indicator to set
            pos (3-array): (x, y, z) Cartesian world coordinates to set the specified indicator to

        Raises:
            AssertionError: [If @indicator is not the name of a configured indicator]
        """
        # Make sure indicator is valid
        indicator_names = set(self.get_indicator_names())
        assert indicator in indicator_names, "Invalid indicator name specified. Valid options are {}, got {}".\
            format(indicator_names, indicator)
        # Set the specified indicator. Each indicator site lives in its own "<name>_body" body (see
        # _add_indicators_to_model), so moving the indicator means overwriting that body's position
        self.env.sim.model.body_pos[self.env.sim.model.body_name2id(indicator + "_body")] = np.array(pos)

    def get_visualization_settings(self):
        """
        Gets all settings for visualizing this environment

        Returns:
            dict_keys: View over the visualization keywords for this environment.
        """
        return self._vis_settings.keys()

    def set_visualization_setting(self, setting, visible):
        """
        Sets the specified @setting to have visibility = @visible.

        The new value is stored internally and is passed to the environment on the next `reset()` or `step()` call.

        Args:
            setting (str): Visualization keyword to set
            visible (bool): True if setting should be visualized.

        Raises:
            AssertionError: [If @setting is not a known visualization keyword]
        """
        assert setting in self._vis_settings, "Invalid visualization setting specified. Valid options are {}, got {}".\
            format(self._vis_settings.keys(), setting)
        self._vis_settings[setting] = visible

    def reset(self):
        """
        Extends vanilla reset() function call to accommodate visualization

        Returns:
            OrderedDict: Environment observation space after reset occurs
        """
        ret = super().reset()
        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)
        return ret

    def step(self, action):
        """
        Extends vanilla step() function call to accommodate visualization

        Args:
            action (np.array): Action to take in environment

        Returns:
            4-tuple:

                - (OrderedDict) observations from the environment
                - (float) reward from the environment
                - (bool) whether the current episode is completed or not
                - (dict) misc information
        """
        ret = super().step(action)

        # Update any visualization
        self.env.visualize(vis_settings=self._vis_settings)

        return ret

    def _add_indicators_to_model(self, model):
        """
        Adds indicators to the mujoco simulation model

        Each indicator is created as a site wrapped in its own body named `"<name>_body"`, which is appended to
        the model's worldbody. An optional `'pos'` entry in the config is used as the body position (defaulting
        to the origin); all remaining entries are forwarded as site attributes.

        Args:
            model (Task): Task instance including all mujoco models for the current simulation to be loaded
        """
        if self.indicator_configs is None:
            return
        for indicator_config in self.indicator_configs:
            # Work on a copy so that popping "pos" does not alter the stored config, which is needed again
            # every time this post-processor runs
            config = deepcopy(indicator_config)
            indicator_body = new_body(name=config["name"] + "_body", pos=config.pop("pos", (0, 0, 0)))
            indicator_body.append(new_site(**config))
            model.worldbody.append(indicator_body)
