import os
import xml.dom.minidom
import xml.etree.ElementTree as ET
from collections import defaultdict, deque

import numpy as np
from transforms3d.affines import compose, decompose
from transforms3d.euler import euler2mat, mat2euler

class UrdfEditor:
    """Flatten a URDF kinematic tree onto its root link and merge per-link meshes.

    ``modify_urdf`` re-parents every joint reachable from the root link directly
    onto that root, rewriting each joint origin so the child frame keeps its
    zero-configuration pose. It also replaces each link's mesh-based
    ``visual``/``collision`` elements with a single element pointing at
    ``<link>_combined_mesh.obj``.
    """

    # ------------------------------------------------------------------ parsing

    def _parse_vector(self, text: str) -> np.ndarray:
        """Parse a whitespace-separated list of floats; ``None`` yields ``zeros(3)``."""
        if text is None:
            return np.zeros(3)
        return np.array([float(s) for s in text.split()])

    def _parse_xyz(self, xyz_str: str) -> np.ndarray:
        """Parse an ``xyz`` attribute value (``None`` -> zero translation)."""
        return self._parse_vector(xyz_str)

    def _parse_rpy(self, rpy_str: str) -> np.ndarray:
        """Parse an ``rpy`` attribute value (``None`` -> zero rotation)."""
        return self._parse_vector(rpy_str)

    def _format_vector(self, values) -> str:
        """Format a numeric sequence as space-separated ``str()`` values."""
        return ' '.join(map(str, values))

    def _format_xyz(self, xyz_arr: np.ndarray) -> str:
        """Format a translation for an ``xyz`` attribute."""
        return self._format_vector(xyz_arr)

    def _format_rpy(self, rpy_arr: np.ndarray) -> str:
        """Format roll/pitch/yaw angles for an ``rpy`` attribute."""
        return self._format_vector(rpy_arr)

    # --------------------------------------------------------------- transforms

    def _to_transform_matrix(self, xyz: np.ndarray, rpy: np.ndarray) -> np.ndarray:
        """Build a 4x4 homogeneous transform from a translation and static-xyz Euler angles."""
        rotation_matrix = euler2mat(rpy[0], rpy[1], rpy[2], axes='sxyz')
        return compose(xyz, rotation_matrix, np.ones(3))

    def _from_transform_matrix(self, matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Split a 4x4 transform into ``(xyz, rpy)``.

        ``rpy`` uses static-xyz axes, the inverse of ``_to_transform_matrix``.
        Zoom and shear are discarded.
        """
        xyz, rotation_matrix, _, _ = decompose(matrix)
        rpy = mat2euler(rotation_matrix, axes='sxyz')
        return xyz, rpy

    # ---------------------------------------------------------------- top level

    def modify_urdf(self, input_filename: str, output_filename: str, root_link: str = 'base'):
        """Flatten the joint tree of *input_filename* and write the result to *output_filename*.

        Every joint whose child link is reachable from *root_link* is
        re-parented onto *root_link*. Its origin becomes the child's
        zero-configuration pose relative to the root. After flattening, a link
        no longer follows the motion of its former ancestors' joints; only the
        rest pose is preserved. Joints not reachable from *root_link*, and
        joints missing a ``<parent>``/``<child>`` tag, are left unchanged.

        Each non-root link's mesh-based ``visual``/``collision`` elements are
        merged into one element referencing ``<link>_combined_mesh.obj``.

        Args:
            input_filename: Path of the URDF to read.
            output_filename: Path the pretty-printed result is written to (UTF-8).
            root_link: Name of the link every joint is re-parented onto.

        Raises:
            xml.etree.ElementTree.ParseError: If the input is not well-formed XML.
            OSError: If the input cannot be read or the output cannot be written.
        """
        tree = ET.parse(input_filename)
        root = tree.getroot()

        joints = self._collect_joints(root)
        root_to_link = self._compute_root_transforms(joints, root_link)
        self._reparent_joints_to_root(joints, root_to_link, root_link)

        for link in root.findall('link'):
            link_name = link.get('name')
            if not link_name or link_name == root_link:
                continue
            new_mesh_filename = f"{link_name}_combined_mesh.obj"
            self._consolidate_meshes(link, 'visual', new_mesh_filename)
            self._consolidate_meshes(link, 'collision', new_mesh_filename)

        self._write_pretty_xml(root, output_filename)

    # ------------------------------------------------------------------ helpers

    def _collect_joints(self, root: ET.Element) -> list[tuple[ET.Element, str, str, np.ndarray]]:
        """Return ``(joint, parent_link, child_link, T_parent_child)`` for each well-formed joint."""
        joints = []
        for joint in root.findall('joint'):
            parent_tag = joint.find('parent')
            child_tag = joint.find('child')
            if parent_tag is None or child_tag is None:
                continue  # malformed joint: cannot place it in the tree, leave it untouched
            origin_tag = joint.find('origin')
            xyz = self._parse_xyz(origin_tag.get('xyz') if origin_tag is not None else None)
            rpy = self._parse_rpy(origin_tag.get('rpy') if origin_tag is not None else None)
            joints.append((joint, parent_tag.get('link'), child_tag.get('link'),
                           self._to_transform_matrix(xyz, rpy)))
        return joints

    def _compute_root_transforms(self, joints, root_link: str) -> dict[str, np.ndarray]:
        """Breadth-first walk from *root_link*, composing joint origins into root->link transforms.

        Links not reachable from *root_link* are absent from the result.
        """
        # child link -> (parent link, T_parent_child); a later joint for the same child wins.
        parent_of = {child: (parent, transform) for _, parent, child, transform in joints}
        children_of = defaultdict(list)
        for child, (parent, transform) in parent_of.items():
            children_of[parent].append((child, transform))

        root_to_link = {root_link: np.identity(4)}
        queue = deque([root_link])
        while queue:
            parent = queue.popleft()
            T_root_parent = root_to_link[parent]
            for child, T_parent_child in children_of[parent]:
                if child in root_to_link:
                    continue  # already placed; also guards against cycles
                root_to_link[child] = T_root_parent @ T_parent_child
                queue.append(child)
        return root_to_link

    def _reparent_joints_to_root(self, joints, root_to_link: dict[str, np.ndarray], root_link: str):
        """Point every reachable joint at *root_link* and rewrite its origin accordingly."""
        for joint, parent, child, _ in joints:
            if parent == root_link:
                continue  # already attached to the root; origin is correct as-is
            T_root_child = root_to_link.get(child)
            if T_root_child is None:
                continue  # child not reachable from the root: leave the joint unchanged
            joint.find('parent').set('link', root_link)
            new_xyz, new_rpy = self._from_transform_matrix(T_root_child)
            origin_tag = joint.find('origin')
            if origin_tag is None:
                origin_tag = ET.SubElement(joint, 'origin')
            origin_tag.set('xyz', self._format_xyz(new_xyz))
            origin_tag.set('rpy', self._format_rpy(new_rpy))

    def _write_pretty_xml(self, root: ET.Element, output_filename: str):
        """Serialize *root* with two-space indentation, dropping whitespace-only lines."""
        rough_string = ET.tostring(root, 'utf-8')
        pretty = xml.dom.minidom.parseString(rough_string).toprettyxml(indent="  ")
        # Re-indenting input that already carried whitespace text leaves blank lines; drop them.
        final_xml = '\n'.join(line for line in pretty.split('\n') if line.strip())
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(final_xml)

    def _consolidate_meshes(self, link: ET.Element, tag_name: str, new_mesh_filename: str):
        """Replace *link*'s mesh-based *tag_name* elements with one combined-mesh element.

        Elements whose geometry is not a mesh (e.g. box, cylinder, sphere) are
        kept as-is. If *link* has no mesh-based *tag_name* element, it is not
        modified. The new element copies the ``origin`` of the first mesh
        element (if any) and is appended at the end of *link*.

        NOTE(review): using a single origin assumes the combined mesh file was
        baked in that frame — confirm against the mesh-merging step.
        """
        mesh_elements = [elem for elem in link.findall(tag_name)
                         if elem.find('.//mesh') is not None]
        if not mesh_elements:
            return
        origin_tag = mesh_elements[0].find('origin')
        for elem in mesh_elements:
            link.remove(elem)
        new_element = ET.SubElement(link, tag_name)
        if origin_tag is not None:
            ET.SubElement(new_element, 'origin', attrib=dict(origin_tag.attrib))
        geometry = ET.SubElement(new_element, 'geometry')
        ET.SubElement(geometry, 'mesh', attrib={'filename': new_mesh_filename})

def process_dataset(base_path: str):
    """Rewrite each ``<base_path>/<folder>/mobility.urdf`` into a sibling ``new.urdf``.

    Entries of *base_path* that are not directories, and directories without a
    ``mobility.urdf``, are skipped. Entries are visited in ``os.listdir`` order.
    Any exception raised while processing one file propagates and stops the run.
    """
    urdf_editor = UrdfEditor()
    for entry in os.listdir(base_path):
        entry_dir = os.path.join(base_path, entry)
        if not os.path.isdir(entry_dir):
            continue
        source_urdf = os.path.join(entry_dir, 'mobility.urdf')
        if not os.path.exists(source_urdf):
            continue
        target_urdf = os.path.join(entry_dir, 'new.urdf')
        urdf_editor.modify_urdf(source_urdf, target_urdf)
