# utility functions for manipulating MJCF XML models

import xml.etree.ElementTree as ET
import os
import numpy as np
from collections.abc import Iterable
from PIL import Image
from pathlib import Path
from copy import deepcopy

import robosuite

# Basic colors as 4-element lists; RED is the default @rgba value of new_site, whose docs describe it as
# (r, g, b, a) "solid red".
RED = [1, 0, 0, 1]
GREEN = [0, 1, 0, 1]
BLUE = [0, 0, 1, 1]
CYAN = [0, 1, 1, 1]
# Colors named after categories of collision geometry.
# NOTE(review): none of these are referenced in this file view; presumably they are passed as @rgba to
# recolor_collision_geoms by callers -- confirm.
ROBOT_COLLISION_COLOR = [0, 0.5, 0, 1]
MOUNT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
GRIPPER_COLLISION_COLOR = [0, 0, 0.5, 1]
OBJECT_COLLISION_COLOR = [0.5, 0, 0, 1]
ENVIRONMENT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
# Set of XML tag names that _element_filter classifies under the "sensors" key
SENSOR_TYPES = {
    "touch",
    "accelerometer",
    "velocimeter",
    "gyro",
    "force",
    "torque",
    "magnetometer",
    "rangefinder",
    "jointpos",
    "jointvel",
    "tendonpos",
    "tendonvel",
    "actuatorpos",
    "actuatorvel",
    "actuatorfrc",
    "ballangvel",
    "jointlimitpos",
    "jointlimitvel",
    "jointlimitfrc",
    "tendonlimitpos",
    "tendonlimitvel",
    "tendonlimitfrc",
    "framepos",
    "framequat",
    "framexaxis",
    "frameyaxis",
    "framezaxis",
    "framelinvel",
    "frameangvel",
    "framelinacc",
    "frameangacc",
    "subtreecom",
    "subtreelinvel",
    "subtreeangmom",
    "user",
}

# Attribute names that add_prefix prefixes when called with attribs="default" (documented there as
# "all attributes that reference names"). "objname" is listed twice below; this is harmless in a set literal.
MUJOCO_NAMED_ATTRIBUTES = {
    "class", "childclass", "name", "objname", "material", "texture",
    "joint", "joint1", "joint2", "jointinparent", "geom", "geom1", "geom2",
    "mesh", "fixed", "actuator", "objname", "tendon", "tendon1", "tendon2",
    "slidesite", "cranksite", "body", "body1", "body2", "hfield", "target",
    "prefix", "site",
}

# Maps an image convention name to a sign (+1 / -1).
# NOTE(review): not used in this file view; presumably the sign selects the vertical image orientation
# for each convention -- confirm against the callers.
IMAGE_CONVENTION_MAPPING = {
    "opengl": 1,
    "opencv": -1,
}

# Maps texture names accepted by CustomMaterial to png file names; CustomMaterial resolves each file name
# as "textures/<file>" through xml_path_completion.
TEXTURES = {
    "WoodRed": "red-wood.png",
    "WoodGreen": "green-wood.png",
    "WoodBlue": "blue-wood.png",
    "WoodLight": "light-wood.png",
    "WoodDark": "dark-wood.png",
    "WoodTiles": "wood-tiles.png",
    "WoodPanels": "wood-varnished-panels.png",
    "WoodgrainGray": "gray-woodgrain.png",
    "PlasterCream": "cream-plaster.png",
    "PlasterPink": "pink-plaster.png",
    "PlasterYellow": "yellow-plaster.png",
    "PlasterGray": "gray-plaster.png",
    "PlasterWhite": "white-plaster.png",
    "BricksWhite": "white-bricks.png",
    "Metal": "metal.png",
    "SteelBrushed": "steel-brushed.png",
    "SteelScratched": "steel-scratched.png",
    "Brass": "brass-ambra.png",
    "Bread": "bread.png",
    "Can": "can.png",
    "Ceramic": "ceramic.png",
    "Cereal": "cereal.png",
    "Clay": "clay.png",
    "Dirt": "dirt.png",
    "Glass": "glass.png",
    "FeltGray": "gray-felt.png",
    "Lemon": "lemon.png",
}

# Live keys view of TEXTURES (not a copy), used to validate requested texture names
ALL_TEXTURES = TEXTURES.keys()


class CustomMaterial(object):
    """
    Simple class to instantiate the necessary parameters to define an appropriate texture / material combo

    Instantiates a nested dict holding necessary components for procedurally generating a texture / material combo

    Please see http://www.mujoco.org/book/XMLreference.html#asset for specific details on
        attributes expected for Mujoco texture / material tags, respectively

    Note that the values in @tex_attrib and @mat_attrib can be in string or array / numerical form. The passed
    dicts are copied, so the caller's dicts are never modified.

    Args:
        texture (None or str or 4-array): Name of texture file to be imported. If a string, should be part of
            ALL_TEXTURES. If texture is a 4-array, then this argument will be interpreted as an rgba tuple value and
            a template png will be procedurally generated during object instantiation, with any additional
            texture / material attributes specified. If None, no file will be linked and no rgba value will be set
            Note, if specified, the RGBA values are expected to be floats between 0 and 1

        tex_name (str): Name to reference the imported texture

        mat_name (str): Name to reference the imported material

        tex_attrib (dict): Any other optional mujoco texture specifications.

        mat_attrib (dict): Any other optional mujoco material specifications.

        shared (bool): If True, this material should not have any naming prefixes added to all names

    Raises:
        AssertionError: [Invalid texture]
    """

    def __init__(
            self,
            texture,
            tex_name,
            mat_name,
            tex_attrib=None,
            mat_attrib=None,
            shared=False,
    ):
        # A string names one of the bundled texture files; anything else (None or an rgba array) takes the
        # "default" path. isinstance is used so that str subclasses (e.g. np.str_) are treated as names instead
        # of being mistaken for an rgba array.
        if isinstance(texture, str):
            default = False
            # Verify that requested texture is valid
            assert texture in ALL_TEXTURES, "Error: Requested invalid texture. Got {}. Valid options are:\n{}".format(
                texture, ALL_TEXTURES)
        else:
            default = True
            # If specified, this is an rgba value and a default texture is desired; make sure length of rgba array is 4
            if texture is not None:
                assert len(texture) == 4, "Error: Requested default texture. Got array of length {}. " \
                                          "Expected rgba array of length 4.".format(len(texture))

        # Setup the texture and material attributes (copied so the caller's dicts stay untouched)
        self.tex_attrib = {} if tex_attrib is None else tex_attrib.copy()
        self.mat_attrib = {} if mat_attrib is None else mat_attrib.copy()

        # Add in name values
        self.name = mat_name
        self.shared = shared
        self.tex_attrib["name"] = tex_name
        self.mat_attrib["name"] = mat_name
        self.mat_attrib["texture"] = tex_name

        # Loop through all attributes and convert all non-string values into strings
        # (only existing keys are reassigned, so mutating while iterating is safe here)
        for attrib in (self.tex_attrib, self.mat_attrib):
            for k, v in attrib.items():
                if type(v) is not str:
                    if isinstance(v, Iterable):
                        attrib[k] = array_to_string(v)
                    else:
                        attrib[k] = str(v)

        # Handle default and non-default cases separately for linking texture patch file locations
        if not default:
            # Add in the filepath to texture patch
            self.tex_attrib["file"] = xml_path_completion("textures/" + TEXTURES[texture])
        else:
            if texture is not None:
                # Scale the 0-1 rgba floats to 0-255 (truncating) and hand PIL plain Python ints
                rgba_255 = tuple((np.array(texture) * 255).astype('int').tolist())
                # Create a texture patch
                tex = Image.new('RGBA', (100, 100), rgba_255)
                # Create temp directory if it does not exist
                # NOTE: hard-coded POSIX temp location (MacOS / Linux)
                save_dir = "/tmp/robosuite_temp_tex"
                Path(save_dir).mkdir(parents=True, exist_ok=True)
                # Save this texture patch to the temp directory on disk; the file is keyed by @tex_name only,
                # so materials sharing a texture name overwrite each other's patch
                fpath = save_dir + "/{}.png".format(tex_name)
                tex.save(fpath, "PNG")
                # Link this texture file to the default texture dict
                self.tex_attrib["file"] = fpath

class CustomMaterialFromPNG(object):
    """
    Texture / material combo whose texture comes from an explicit png file.

    Builds the two attribute dicts (@tex_attrib for the texture tag, @mat_attrib for the material tag) that are
    later turned into XML elements. The passed dicts are copied, and every non-string value is converted to a string.

    Please see http://www.mujoco.org/book/XMLreference.html#asset for specific details on
        attributes expected for Mujoco texture / material tags, respectively

    Args:
        texture_png (str): Path to png file. Resolved with xml_path_completion, so a non-absolute path is taken
            relative to the package assets root

        tex_name (str): Name to reference the imported texture

        mat_name (str): Name to reference the imported material

        tex_attrib (dict): Any other optional mujoco texture specifications.

        mat_attrib (dict): Any other optional mujoco material specifications.

        shared (bool): If True, this material should not have any naming prefixes added to all names
    """

    def __init__(
            self,
            texture_png,
            tex_name,
            mat_name,
            tex_attrib=None,
            mat_attrib=None,
            shared=False,
    ):
        # Bookkeeping values read by add_material
        self.name = mat_name
        self.shared = shared

        # Work on copies so that the caller's dicts are left alone
        self.tex_attrib = tex_attrib.copy() if tex_attrib is not None else {}
        self.mat_attrib = mat_attrib.copy() if mat_attrib is not None else {}

        # Register the names; the material points at the texture through its "texture" entry
        self.tex_attrib["name"] = tex_name
        self.mat_attrib["name"] = mat_name
        self.mat_attrib["texture"] = tex_name

        # Stringify every value that is not already a plain str
        for attributes in (self.tex_attrib, self.mat_attrib):
            for key in list(attributes):
                value = attributes[key]
                if type(value) is str:
                    continue
                attributes[key] = array_to_string(value) if isinstance(value, Iterable) else str(value)

        # Finally link the (completed) path of the png file to the texture
        self.tex_attrib["file"] = xml_path_completion(texture_png)
        

def xml_path_completion(xml_path):
    """
    Expands a local xml path into a full path.

    A path starting with "/" is returned unchanged; any other path is joined onto the
    assets root of the package (robosuite.models.assets_root).

    Args:
        xml_path (str): local xml path

    Returns:
        str: Full (absolute) xml path
    """
    # Already rooted: nothing to complete
    if xml_path.startswith("/"):
        return xml_path
    # Otherwise resolve against the assets shipped with the package
    return os.path.join(robosuite.models.assets_root, xml_path)


def array_to_string(array):
    """
    Joins the entries of an array into the space-separated string format used by mujoco.

    Examples:
        [0, 1, 2] => "0 1 2"

    Args:
        array (n-array): Array to convert to a string

    Returns:
        str: String equivalent of @array
    """
    # format(x) with no spec is what "{}".format(x) produces for each entry
    return " ".join(map(format, array))


def string_to_array(string):
    """
    Converts a array string in mujoco xml to np.array.

    Entries may be separated by any run of whitespace (spaces, tabs, newlines), and leading / trailing
    whitespace is ignored. A string holding no entries yields an empty array.

    Examples:
        "0 1 2" => [0, 1, 2]
        " 0  1\t2 " => [0, 1, 2]

    Args:
        string (str): String to convert to an array

    Returns:
        np.array: Numerical array equivalent of @string

    Raises:
        ValueError: [Entry cannot be parsed as a float]
    """
    # split() with no argument collapses repeated whitespace; split(" ") would produce empty entries
    # that float() cannot parse
    return np.array([float(x) for x in string.split()])


def convert_to_string(inp):
    """
    Converts any type of {bool, int, float, list, tuple, array, string, np.str_} into an mujoco-xml compatible string.
        Note that an input string / np.str_ results in a no-op action.

    Subclasses of these types and numpy scalars (np.bool_, np.integer, np.floating) are accepted as well.
    Booleans are lowercased ("true" / "false").

    Args:
        inp: Input to convert to string

    Returns:
        str: String equivalent of @inp

    Raises:
        ValueError: [Unsupported input type]
    """
    # Strings pass through untouched (np.str_ is a str subclass)
    if isinstance(inp, str):
        return inp
    # Array-like containers become space-separated strings
    if isinstance(inp, (list, tuple, np.ndarray)):
        return array_to_string(inp)
    # Scalars, including numpy scalars (which the exact built-in types do not match);
    # lower() turns "True" / "False" into "true" / "false"
    if isinstance(inp, (bool, int, float, np.bool_, np.integer, np.floating)):
        return str(inp).lower()
    raise ValueError("Unsupported type received: got {}".format(type(inp)))


def set_alpha(node, alpha=0.1):
    """
    Sets all a(lpha) field of the rgba attribute to be @alpha
    for @node and all subnodes
    used for managing display

    Elements without an rgba attribute are left untouched.

    Args:
        node (ET.Element): Specific node element within XML tree
        alpha (float): Value to set alpha value of rgba tuple
    """
    # iter() yields @node itself followed by all of its descendants; an XPath query such as ".//*[@rgba]"
    # only matches descendants and would skip @node's own rgba
    for element in node.iter():
        rgba_str = element.get("rgba")
        if rgba_str is None:
            continue
        # Keep the original r, g, b entries and swap in the new alpha
        rgb = string_to_array(rgba_str)[0:3]
        element.set("rgba", array_to_string(list(rgb) + [alpha]))


def new_element(tag, name, **kwargs):
    """
    Builds a new @tag element whose attributes are given by @**kwargs.

    Attributes whose value is None are dropped; every other value is converted with convert_to_string.

    Args:
        tag (str): Type of element to create
        name (None or str): Name for this element. Should only be None for elements that do not have an explicit
            name attribute (e.g.: inertial elements)
        **kwargs: Specified attributes for the new element

    Returns:
        ET.Element: new specified xml element
    """
    # The name is just another attribute, added only when given
    if name is not None:
        kwargs["name"] = name
    # Filter out unset attributes and stringify the remaining ones
    attributes = {key: convert_to_string(value) for key, value in kwargs.items() if value is not None}
    return ET.Element(tag, attrib=attributes)


def new_joint(name, **kwargs):
    """
    Builds a "joint" element; thin wrapper around new_element.

    Args:
        name (str): Name for this joint
        **kwargs: Specified attributes for the new joint

    Returns:
        ET.Element: new joint xml element
    """
    joint_element = new_element("joint", name, **kwargs)
    return joint_element


def new_actuator(name, joint, act_type="actuator", **kwargs):
    """
    Builds an actuator element of tag @act_type with attributes specified by @**kwargs.

    Args:
        name (str): Name for this actuator
        joint (str): Value written verbatim to the element's "joint" attribute.
            see http://mujoco.org/book/modeling.html#actuator for actuator transmissions
        act_type (str): actuator type, used as the element tag. Defaults to "actuator"
        **kwargs: Any additional specified attributes for the new actuator

    Returns:
        ET.Element: new actuator xml element
    """
    actuator = new_element(act_type, name, **kwargs)
    # The joint reference is set directly on the element (no string conversion is applied to it)
    actuator.set("joint", joint)
    return actuator


def new_site(name, rgba=RED, pos=(0, 0, 0), size=(0.005,), **kwargs):
    """
    Builds a "site" element with attributes specified by @**kwargs.

    NOTE: Any argument set to None is dropped by new_element before the XML element is created.

    Args:
        name (str): Name for this site
        rgba (4-array): (r,g,b,a) color and transparency. Defaults to solid red.
        pos (3-array): (x,y,z) 3d position of the site.
        size (n-array of float): site size (sites are spherical by default).
        **kwargs: Any additional specified attributes for the new site

    Returns:
        ET.Element: new site xml element
    """
    # Merge the explicit arguments into the attribute set, after any extra attributes
    kwargs.update(pos=pos, size=size, rgba=rgba)
    return new_element("site", name, **kwargs)


def new_geom(name, type, size, pos=(0, 0, 0), group=0, **kwargs):
    """
    Builds a "geom" element with attributes specified by @**kwargs.

    NOTE: Any argument set to None is dropped by new_element before the XML element is created.

    Args:
        name (str): Name for this geom
        type (str): type of the geom.
            see all types here: http://mujoco.org/book/modeling.html#geom
        size (n-array of float): geom size parameters.
        pos (3-array): (x,y,z) 3d position of the geom.
        group (int): the integer group that the geom belongs to. useful for
            separating visual and physical elements.
        **kwargs: Any additional specified attributes for the new geom

    Returns:
        ET.Element: new geom xml element
    """
    # Merge the explicit arguments into the attribute set, after any extra attributes
    kwargs.update(type=type, size=size, pos=pos, group=group)
    return new_element("geom", name, **kwargs)


def new_body(name, pos=(0, 0, 0), **kwargs):
    """
    Builds a "body" element with attributes specified by @**kwargs.

    Args:
        name (str): Name for this body
        pos (3-array): (x,y,z) 3d position of the body frame.
        **kwargs: Any additional specified attributes for the new body

    Returns:
        ET.Element: new body xml element
    """
    kwargs.update(pos=pos)
    return new_element("body", name, **kwargs)


def new_inertial(pos=(0, 0, 0), mass=None, **kwargs):
    """
    Builds an "inertial" element (which carries no name) with attributes specified by @**kwargs.

    Args:
        pos (3-array): (x,y,z) 3d position of the inertial frame.
        mass (float): The mass of inertial. If None, no mass attribute is written
        **kwargs: Any additional specified attributes for the new inertial element

    Returns:
        ET.Element: new inertial xml element
    """
    # A None mass is filtered out downstream by new_element
    kwargs.update(mass=mass, pos=pos)
    return new_element("inertial", None, **kwargs)


def get_size(size,
             size_max,
             size_min,
             default_max,
             default_min):
    """
    Returns either the explicitly requested size, or a size sampled uniformly from a range

    When @size is None, each entry is drawn with np.random.uniform between the matching entries of the
    min / max arrays; @size_max / @size_min fall back to @default_max / @default_min when they are None.
    The number of sampled entries is always len(@default_max).

    Args:
        size (n-array): Array of numbers that explicitly define the size
        size_max (n-array): Array of numbers that define the custom max size from which to randomly sample
        size_min (n-array): Array of numbers that define the custom min size from which to randomly sample
        default_max (n-array): Array of numbers that define the default max size from which to randomly sample
        default_min (n-array): Array of numbers that define the default min size from which to randomly sample

    Returns:
        np.array: size generated

    Raises:
        ValueError: [Inconsistent array sizes]
    """
    n_dims = len(default_max)
    if n_dims != len(default_min):
        raise ValueError('default_max = {} and default_min = {}'
                         .format(str(default_max), str(default_min)) +
                         ' have different lengths')

    if size is None:
        # Resolve the sampling bounds, preferring the custom ones
        upper = default_max if size_max is None else size_max
        lower = default_min if size_min is None else size_min
        size = np.array([np.random.uniform(lower[i], upper[i]) for i in range(n_dims)])
    elif size_max is not None or size_min is not None:
        # An explicit size cannot be combined with a custom sampling range
        raise ValueError('size = {} overrides size_max = {}, size_min = {}'
                         .format(size, size_max, size_min))

    return np.array(size)


def postprocess_model_xml(xml_str):
    """
    This function postprocesses the model.xml collected from a MuJoCo demonstration
    in order to make sure that the STL files can be found.

    Every mesh / texture "file" path that contains a "robosuite" directory component is rewritten so that
    everything after the last such component is re-rooted at the directory of the locally installed robosuite
    package. Paths without a "robosuite" component (e.g. generated textures stored in a temp directory) and
    elements without a "file" attribute are left unchanged.

    Args:
        xml_str (str): Mujoco sim demonstration XML file as string

    Returns:
        str: Post-processed xml file as string
    """

    # Directory of the local robosuite package, split into its path components
    path = os.path.split(robosuite.__file__)[0]
    path_split = path.split("/")

    root = ET.fromstring(xml_str)
    asset = root.find("asset")

    # A model without an <asset> section has no file paths to fix
    if asset is not None:
        # replace mesh and texture file paths
        for elem in asset.findall("mesh") + asset.findall("texture"):
            old_path = elem.get("file")
            if old_path is None:
                continue
            old_path_split = old_path.split("/")
            robosuite_locs = [loc for loc, val in enumerate(old_path_split) if val == "robosuite"]
            # Not a package asset: nothing to re-root (max() on an empty sequence would raise ValueError)
            if not robosuite_locs:
                continue
            ind = robosuite_locs[-1]  # last occurrence index
            new_path_split = path_split + old_path_split[ind + 1:]
            elem.set("file", "/".join(new_path_split))

    return ET.tostring(root, encoding="utf8").decode("utf8")


def add_to_dict(dic, fill_in_defaults=True, default_value=None, **kwargs):
    """
    Helper function to add key-values to dictionary @dic where each entry is its own array (list).

    @dic is modified in place. With @fill_in_defaults, all lists are kept at the same length: a key that is
    new to @dic is first padded with @default_value up to the current length of the existing lists, and
    existing keys that are not part of @kwargs get @default_value appended.

    Args:
        dic (dict): Dictionary to which new key / value pairs will be added. If the key already exists,
            will append the value to that key entry
        fill_in_defaults (bool): If True, will automatically add @default_value to all dictionary entries that are
            not explicitly specified in @kwargs
        default_value (any): Default value to fill (None by default)
        **kwargs: key / value pairs to add. Note that keys cannot be named "dic", "fill_in_defaults" or
            "default_value", since those names are taken by this function's own parameters

    Returns:
        dict: Modified dictionary
    """
    # Keys present before this call; the ones left at the end are those that still need a default
    keys = set(dic.keys())
    # Current length of the lists, measured on an existing VALUE before anything is appended
    # (the length of a key string says nothing about how much padding is needed)
    n = len(next(iter(dic.values()))) if dic else 0
    for k, v in kwargs.items():
        if k in dic:
            dic[k].append(v)
            keys.discard(k)
        else:
            dic[k] = [default_value] * n + [v] if fill_in_defaults else [v]
    # If filling in defaults, fill in remaining default values
    if fill_in_defaults:
        for k in keys:
            dic[k].append(default_value)
    return dic


def add_prefix(
        root,
        prefix,
        tags="default",
        attribs="default",
        exclude=None,
):
    """
    Walks through @root and all of its descendants and prepends @prefix to the requested attributes of every
    element whose tag matches @tags. An attribute is only modified if it exists and does not already begin
    with @prefix.

    Args:
        root (ET.Element): Root of the xml element tree to search through.
        prefix (str): Prefix to add to all specified attributes
        tags (str or list of str or set): Tag(s) to search for in this ElementTree. "default" corresponds to all tags
        attribs (str or list of str or set): Element attribute(s) to add the prefix to. "default" corresponds
            to all attributes that reference names
        exclude (None or function): Filtering function that should take in an ET.Element or a string (attribute) and
            return True if we should exclude the given element / attribute from having any prefixes added
    """
    # Normalize both selections into sets ("default" tags means: accept every tag)
    any_tag = tags == "default"
    if not any_tag:
        tags = {tags} if type(tags) is str else set(tags)
    if attribs == "default":
        attribs = MUJOCO_NAMED_ATTRIBUTES
    attribs = {attribs} if type(attribs) is str else set(attribs)

    # iter() visits @root first and then every descendant
    for element in root.iter():
        if not (any_tag or element.tag in tags):
            continue
        # The filter is consulted once for the element itself ...
        if exclude is not None and exclude(element):
            continue
        for attrib in attribs:
            value = element.get(attrib, None)
            if value is None or value.startswith(prefix):
                continue
            # ... and once more for each candidate attribute value
            if exclude is not None and exclude(value):
                continue
            element.set(attrib, prefix + value)


def add_material(root, naming_prefix="", custom_material=None):
    """
    Goes through @root and all of its descendants and assigns a material to every visual geom (group "1") that
    does not already have a material specified.

    NOTE: unless the material is shared or its name already begins with @naming_prefix, the names stored in
    @custom_material are prefixed in place.

    Args:
        root (ET.Element): Root of the xml element tree to search through.
        naming_prefix (str): Adds this prefix to all material and texture names
        custom_material (None or CustomMaterial): If specified, will add this material to all visual geoms.
            Else, will add a default "no-change" material.

    Returns:
        4-tuple: (ET.Element, ET.Element, CustomMaterial, bool) (tex_element, mat_element, material, used)
            corresponding to the added material and whether the material was actually used or not.
    """
    # Fall back to a flat white builtin texture when no material is given
    if custom_material is None:
        custom_material = CustomMaterial(
            texture=None,
            tex_name="default_tex",
            mat_name="default_mat",
            tex_attrib={
                "type": "cube",
                "builtin": "flat",
                "width": 100,
                "height": 100,
                "rgb1": np.ones(3),
                "rgb2": np.ones(3),
            },
        )

    # Prefix the material / texture names once, unless already prefixed or shared
    if not (custom_material.name.startswith(naming_prefix) or custom_material.shared):
        custom_material.name = naming_prefix + custom_material.name
        for attributes, key in (
                (custom_material.tex_attrib, "name"),
                (custom_material.mat_attrib, "name"),
                (custom_material.mat_attrib, "texture"),
        ):
            attributes[key] = naming_prefix + attributes[key]

    # Assign the material to each visual geom lacking one, remembering whether any geom was touched
    used = False
    for element in root.iter():
        is_visual_geom = element.tag == "geom" and element.get("group", None) == "1"
        if is_visual_geom and element.get("material", None) is None:
            element.set("material", custom_material.name)
            used = True

    # Lastly, build the texture and material elements describing the material
    tex_element = new_element(tag="texture", **custom_material.tex_attrib)
    mat_element = new_element(tag="material", **custom_material.mat_attrib)
    return tex_element, mat_element, custom_material, used


def recolor_collision_geoms(root, rgba, exclude=None):
    """
    Goes through @root and all of its descendants, and for every geom belonging to group 0 (or having no group
    specified) sets the rgba value to the specified @rgba argument. Note: also removes any material values for
    these elements.

    Args:
        root (ET.Element): Root of the xml element tree to search through
        rgba (4-array): (R, G, B, A) values to assign to all geoms with this group.
        exclude (None or function): Filtering function that should take in an ET.Element and
            return True if we should exclude the given element / attribute from having its collision geom impacted.
    """
    for element in root.iter():
        # Only geoms in the collision group (explicit "0" or unspecified) are of interest
        if element.tag != "geom":
            continue
        if element.get("group") not in {None, "0"}:
            continue
        if exclude is not None and exclude(element):
            continue
        element.set("rgba", array_to_string(rgba))
        # Drop any material from the recolored geom
        element.attrib.pop("material", None)


def _element_filter(element, parent):
    """
    Default element filter to be used in sort_elements. This will filter for the following groups:

        :`'root_body'`: Top-level body element
        :`'bodies'`: Any body elements
        :`'joints'`: Any joint elements
        :`'actuators'`: Any actuator elements
        :`'sites'`: Any site elements
        :`'sensors'`: Any sensor elements
        :`'contact_geoms'`: Any geoms used for collision (as specified by group 0 (default group) geoms)
        :`'visual_geoms'`: Any geoms used for visual rendering (as specified by group 1 geoms)

    Args:
        element (ET.Element): Current XML element that we are filtering
        parent (ET.Element): Parent XML element for the current element

    Returns:
        str or None: Assigned filter key for this element. None if no matching filter is found.
    """
    # Check for actuator first since this is dependent on the parent element
    if parent is not None and parent.tag == "actuator":
        return "actuators"
    elif element.tag == "joint":
        # Make sure this is not a tendon (this should not have a "joint", "joint1", or "joint2" attribute specified)
        if element.get("joint") is None and element.get("joint1") is None:
            return "joints"
    elif element.tag == "body":
        # If the parent of this does not have a tag "body", then this is the top-level body element
        if parent is None or parent.tag != "body":
            return "root_body"
        return "bodies"
    elif element.tag == "site":
        return "sites"
    elif element.tag in SENSOR_TYPES:
        return "sensors"
    elif element.tag == "geom":
        # Only get collision and visual geoms (group 0 / None, or 1, respectively)
        group = element.get("group")
        if group in {None, "0", "1"}:
            return "visual_geoms" if group == "1" else "contact_geoms"
    else:
        # If no condition met, return None
        return None


def sort_elements(root, parent=None, element_filter=None, _elements_dict=None):
    """
    Utility method to recursively sort all elements of an XML tree. @root and each of its descendants is passed
    (together with its parent) to @element_filter, and all elements for which the same key is returned are grouped
    as a list entry in the returned dictionary. Elements are visited parent first, then children in document order.

    Args:
        root (ET.Element): Root of the xml element tree to start recursively searching through
        parent (ET.Element): Parent of the root node. Default is None (no parent node initially)
        element_filter (None or function): Function used to filter the incoming elements. Should take in two
            ET.Elements (current_element, parent_element) and return a string filter_key if the element
            should be added to the list of values sorted by filter_key, and return None if no value should be added.
            If no element_filter is specified, defaults to _element_filter.
        _elements_dict (dict): Dictionary that gets passed to recursive calls. Should not be modified externally by
            top-level call.

    Returns:
        dict: Filtered key-specific lists of the corresponding elements
    """
    # Resolve the accumulator and the filter to use
    grouped = {} if _elements_dict is None else _elements_dict
    chosen_filter = _element_filter if element_filter is None else element_filter

    # File this element under its key, if it has one
    key = chosen_filter(root, parent)
    if key is not None:
        grouped.setdefault(key, []).append(root)

    # Then descend into each child, with this element as its parent
    for child in root:
        grouped = sort_elements(
            root=child,
            parent=root,
            element_filter=chosen_filter,
            _elements_dict=grouped,
        )

    return grouped


def find_parent(root, child):
    """
    Finds the parent element of the specified @child node, searching depth-first through @root.

    Args:
        root (ET.Element): Root of the xml element tree to search through.
        child (ET.Element): Child element whose parent is to be found

    Returns:
        None or ET.Element: Matching parent if found, else None
    """
    # Explicit DFS: each stack entry pairs an element with an iterator over its remaining children, so that
    # every child is compared first and its own subtree is searched right after
    pending = [(root, iter(root))]
    while pending:
        candidate_parent, remaining = pending[-1]
        current = next(remaining, None)
        if current is None:
            # All children of this element are exhausted; go back up
            pending.pop()
            continue
        if current == child:
            return candidate_parent
        pending.append((current, iter(current)))
    # If we get here, we didn't find anything ):
    return None


def find_elements(root, tags, attribs=None, return_first=True):
    """
    Find all element(s) matching the requested @tag and @attributes. If @return_first is True, then will return the
    first element found matching the criteria specified. Otherwise, will return a list of elements that match the
    criteria.

    @root itself is considered first, followed by all of its descendants in document (depth-first) order.

    Args:
        root (ET.Element): Root of the xml element tree to start searching through.
        tags (str or list of str or set): Tag(s) to search for in this ElementTree.
        attribs (None or dict of str): Element attribute(s) to check against for a filtered element. A match is
            considered found only if all attributes match. Each attribute key should have a corresponding value with
            which to compare against.
        return_first (bool): Whether to immediately return once the first matching element is found.

    Returns:
        None or ET.Element or list of ET.Element: Matching element(s) found. Returns None if there was no match.
    """
    # Make sure a single tag is wrapped in a list, so that "in" tests tag membership and not substrings
    tags = [tags] if isinstance(tags, str) else tags

    def _matches(element):
        # An element matches if its tag is requested and every requested attribute has the requested value
        if element.tag not in tags:
            return False
        return attribs is None or all(element.get(k) == v for k, v in attribs.items())

    # Lazily walk the tree (iter() yields @root first), so a first-match query stops as early as possible
    # and no copies of the elements are ever made
    matching = (element for element in root.iter() if _matches(element))

    if return_first:
        return next(matching, None)

    elements = list(matching)
    return elements if elements else None


def save_sim_model(sim, fname):
    """
    Writes the current model xml from @sim to the file location @fname.

    Args:
        sim (MjSim): Simulation whose model gets saved; its save(file=..., format="xml") method is called
        fname (str): Absolute filepath to the location to save the file
    """
    # The file is opened in text write mode, so an existing file at @fname is overwritten
    with open(fname, "w") as xml_file:
        sim.save(file=xml_file, format="xml")
