"""
URDF/SRDF Parser for Robot Ontology

Parses robot description files without ROS dependencies.
Implements heuristic joint name matching between URDF and SRDF.
Computes end-effector offsets from kinematic chains.
"""

import xml.etree.ElementTree as ET
import json
import re
import difflib
import math
import numpy as np
from typing import Dict, List, Tuple, Optional, Any, Set
from pathlib import Path

from ontology import (
    TripleGraph, Predicates,
    robot_id, joint_id, link_id, group_id
)


def parse_urdf(urdf_path: str) -> Dict[str, Any]:
    """
    Parse a URDF file and extract the robot structure.

    Only the direct ``<joint>`` / ``<link>`` children of the root ``<robot>``
    element are read. A recursive search (``.//joint``) would also pick up the
    ``<joint name="..."/>`` references nested inside ``<transmission>`` blocks,
    yielding duplicate, parent-less joints defaulted to type 'fixed'.

    Args:
        urdf_path: Path to the URDF (XML) file.

    Returns:
        dict with:
        - robot_name: root ``name`` attribute ('unknown' if absent)
        - joints: list of joint dicts with ``name``, ``type`` (default
          'fixed'), ``origin_xyz`` / ``origin_rpy`` (float lists, zeros when
          there is no ``<origin>``) and, when present, ``parent_link``,
          ``child_link``, ``limits`` (lower/upper/effort/velocity floats) and
          ``axis`` (the raw ``xyz`` attribute string)
        - links: list of non-empty link names in document order

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        ValueError: if a numeric attribute cannot be parsed as a float.
    """
    root = ET.parse(urdf_path).getroot()
    robot_name = root.attrib.get('name', 'unknown')

    joints = []
    for joint_elem in root.findall('joint'):
        joint_info = {
            'name': joint_elem.attrib.get('name', ''),
            'type': joint_elem.attrib.get('type', 'fixed'),
        }

        # Parent/child links
        parent = joint_elem.find('parent')
        child = joint_elem.find('child')
        if parent is not None:
            joint_info['parent_link'] = parent.attrib.get('link', '')
        if child is not None:
            joint_info['child_link'] = child.attrib.get('link', '')

        # Origin (xyz and rpy); a missing <origin> means the identity transform.
        origin = joint_elem.find('origin')
        if origin is not None:
            xyz_str = origin.attrib.get('xyz', '0 0 0')
            rpy_str = origin.attrib.get('rpy', '0 0 0')
            joint_info['origin_xyz'] = [float(v) for v in xyz_str.split()]
            joint_info['origin_rpy'] = [float(v) for v in rpy_str.split()]
        else:
            joint_info['origin_xyz'] = [0.0, 0.0, 0.0]
            joint_info['origin_rpy'] = [0.0, 0.0, 0.0]

        # Limits (missing attributes default to 0)
        limit = joint_elem.find('limit')
        if limit is not None:
            joint_info['limits'] = {
                'lower': float(limit.attrib.get('lower', 0)),
                'upper': float(limit.attrib.get('upper', 0)),
                'effort': float(limit.attrib.get('effort', 0)),
                'velocity': float(limit.attrib.get('velocity', 0))
            }

        # Axis is kept as the raw attribute string.
        axis = joint_elem.find('axis')
        if axis is not None:
            joint_info['axis'] = axis.attrib.get('xyz', '0 0 1')

        joints.append(joint_info)

    links = [
        link_elem.attrib.get('name', '')
        for link_elem in root.findall('link')
        if link_elem.attrib.get('name', '')
    ]

    return {
        'robot_name': robot_name,
        'joints': joints,
        'links': links
    }


def parse_srdf(srdf_path: str) -> Dict[str, Any]:
    """
    Parse an SRDF file and extract the semantic robot description.

    Only direct children of the root ``<robot>`` element are treated as
    top-level entries. A ``<group>`` may reference other groups through nested
    ``<group name="..."/>`` elements; those references are recorded in the
    parent's ``subgroups`` list instead of being parsed as extra, empty groups.

    Args:
        srdf_path: Path to the SRDF (XML) file.

    Returns:
        dict with:
        - robot_name: root ``name`` attribute ('unknown' if absent)
        - groups: list of dicts with ``name``, ``joints``, ``links``,
          ``chains`` (dicts with ``base_link`` / ``tip_link``) and ``subgroups``
        - group_states: list of dicts with ``name``, ``group`` and
          ``joint_values`` mapping joint name -> float, or -> list of floats
          when the ``value`` attribute holds several space-separated numbers
          (e.g. multi-DOF joints; an empty value yields an empty list)
        - end_effectors: list of dicts with ``name``, ``group``,
          ``parent_link``, ``parent_group``
        - disable_collisions: list of dicts with ``link1``, ``link2``, ``reason``
        - virtual_joints: list of dicts with ``name``, ``type`` (default
          'fixed'), ``parent_frame``, ``child_link``

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        ValueError: if a group_state joint value is not numeric.
    """
    root = ET.parse(srdf_path).getroot()
    robot_name = root.attrib.get('name', 'unknown')

    # Planning groups
    groups = []
    for group_elem in root.findall('group'):
        group_info = {
            'name': group_elem.attrib.get('name', ''),
            'joints': [],
            'links': [],
            'chains': [],
            'subgroups': [],
        }

        for joint_elem in group_elem.findall('joint'):
            joint_name = joint_elem.attrib.get('name', '')
            if joint_name:
                group_info['joints'].append(joint_name)

        for link_elem in group_elem.findall('link'):
            link_name = link_elem.attrib.get('name', '')
            if link_name:
                group_info['links'].append(link_name)

        for chain_elem in group_elem.findall('chain'):
            group_info['chains'].append({
                'base_link': chain_elem.attrib.get('base_link', ''),
                'tip_link': chain_elem.attrib.get('tip_link', '')
            })

        # Nested <group name="..."/> elements are references, not definitions.
        for sub_elem in group_elem.findall('group'):
            sub_name = sub_elem.attrib.get('name', '')
            if sub_name:
                group_info['subgroups'].append(sub_name)

        groups.append(group_info)

    # Named configurations
    group_states = []
    for state_elem in root.findall('group_state'):
        joint_values: Dict[str, Any] = {}
        for joint_elem in state_elem.findall('joint'):
            jname = joint_elem.attrib.get('name', '')
            if not jname:
                continue
            values = [float(v) for v in joint_elem.attrib.get('value', '0').split()]
            # Single-value joints keep the scalar form; multi-value entries
            # (several space-separated numbers) are stored as a list.
            joint_values[jname] = values[0] if len(values) == 1 else values
        group_states.append({
            'name': state_elem.attrib.get('name', ''),
            'group': state_elem.attrib.get('group', ''),
            'joint_values': joint_values
        })

    # End effectors
    end_effectors = [
        {
            'name': ee_elem.attrib.get('name', ''),
            'group': ee_elem.attrib.get('group', ''),
            'parent_link': ee_elem.attrib.get('parent_link', ''),
            'parent_group': ee_elem.attrib.get('parent_group', '')
        }
        for ee_elem in root.findall('end_effector')
    ]

    # Disabled collision pairs
    disable_collisions = [
        {
            'link1': dc_elem.attrib.get('link1', ''),
            'link2': dc_elem.attrib.get('link2', ''),
            'reason': dc_elem.attrib.get('reason', '')
        }
        for dc_elem in root.findall('disable_collisions')
    ]

    # Virtual joints
    virtual_joints = [
        {
            'name': vj_elem.attrib.get('name', ''),
            'type': vj_elem.attrib.get('type', 'fixed'),
            'parent_frame': vj_elem.attrib.get('parent_frame', ''),
            'child_link': vj_elem.attrib.get('child_link', '')
        }
        for vj_elem in root.findall('virtual_joint')
    ]

    return {
        'robot_name': robot_name,
        'groups': groups,
        'group_states': group_states,
        'end_effectors': end_effectors,
        'disable_collisions': disable_collisions,
        'virtual_joints': virtual_joints
    }


def normalize_joint_name(name: str, robot_prefixes: List[str] = None) -> str:
    """
    Normalize a joint name for matching.

    Steps:
    1. Lowercase
    2. Remove separators ('_' and '-')
    3. Strip robot prefixes (default: panda, j2n6s300, kinova, franka).
       Prefixes are normalized the same way as the name (steps 1-2), then
       tried longest-first, so the result does not depend on list order.
    4. Remove the tokens 'joint' and 'respondable'
    5. Normalize the synonym 'gripper' -> 'finger'

    Args:
        name: Raw joint name from URDF or SRDF.
        robot_prefixes: Prefixes to strip; ``None`` uses the defaults above.

    Returns:
        The normalized name (may be empty).

    Example:
        'panda_finger_joint1' -> 'finger1'
    """
    if robot_prefixes is None:
        robot_prefixes = ['panda', 'j2n6s300', 'kinova', 'franka']

    def _squash(text: str) -> str:
        # Lowercase and drop separators so 'UR5_Robotiq' and 'ur5robotiq' agree.
        return text.lower().replace('_', '').replace('-', '')

    normalized = _squash(name)

    # Without squashing the prefixes too, a prefix containing '_' or '-' (such
    # as a URDF robot name 'ur5_robotiq85') could never match the squashed name.
    # Longest-first (then alphabetical) makes stripping deterministic when one
    # prefix is itself a prefix of another ('ur5' vs 'ur5robotiq85').
    squashed_prefixes = sorted(
        {p for p in map(_squash, robot_prefixes) if p},
        key=lambda p: (-len(p), p),
    )
    for prefix in squashed_prefixes:
        if normalized.startswith(prefix):
            normalized = normalized[len(prefix):]

    # Remove common tokens
    for token in ['joint', 'respondable']:
        normalized = normalized.replace(token, '')

    # Normalize synonyms
    normalized = normalized.replace('gripper', 'finger')

    return normalized


def extract_number(name: str) -> Optional[int]:
    """
    Return the integer index embedded in a joint name, or None.

    A trailing run of digits wins ('j2n6s300_joint_1' -> 1); otherwise the
    first run of digits anywhere in the name is used ('x3y4z' -> 3).
    """
    # re.Match objects are always truthy, so `or` falls through only on None.
    found = re.search(r'(\d+)$', name) or re.search(r'(\d+)', name)
    return int(found.group(1)) if found else None


def compute_token_overlap(name1: str, name2: str) -> float:
    """
    Jaccard similarity of the token sets of two names.

    Tokens are maximal runs of lowercase letters or of digits, taken after
    lowercasing (so 'Panda-Joint1' -> {'panda', 'joint', '1'}). Returns 0.0
    when either name has no tokens.
    """
    token_pattern = r'[a-z]+|\d+'
    left, right = (set(re.findall(token_pattern, text.lower())) for text in (name1, name2))

    if not (left and right):
        return 0.0

    # Both sets are non-empty here, so the union is too.
    return len(left & right) / len(left | right)


def _joint_pair_score(
    srdf_name: str,
    urdf_name: str,
    srdf_norm: str,
    urdf_norm: str,
    srdf_num: Optional[int],
    urdf_num: Optional[int],
) -> float:
    """Heuristic similarity in [0, 1] between one SRDF and one URDF joint name."""
    # Identical after normalization: treat as a certain match.
    if srdf_norm == urdf_norm:
        return 1.0

    # Character-level similarity of the normalized forms.
    seq_ratio = difflib.SequenceMatcher(None, srdf_norm, urdf_norm).ratio()

    # Token overlap of the raw names.
    token_score = compute_token_overlap(srdf_name, urdf_name)

    # Matching indices are a strong signal; mismatching ones are penalized.
    num_bonus = 0.0
    if srdf_num is not None and urdf_num is not None:
        num_bonus = 0.3 if srdf_num == urdf_num else -0.2

    # Role keywords should appear on both sides or on neither.
    srdf_lower = srdf_name.lower()
    urdf_lower = urdf_name.lower()
    keyword_bonus = 0.0
    for kw in ('finger', 'gripper', 'hand', 'tip', 'sensor'):
        in_srdf = kw in srdf_lower
        in_urdf = kw in urdf_lower
        if in_srdf != in_urdf:
            keyword_bonus -= 0.15
        elif in_srdf:
            keyword_bonus += 0.1

    score = (seq_ratio * 0.4 + token_score * 0.3 + 0.3) + num_bonus + keyword_bonus
    return max(0.0, min(1.0, score))


def match_joint_names(
    srdf_joint_names: List[str],
    urdf_joint_names: List[str],
    robot_prefixes: List[str] = None,
    threshold: float = 0.72
) -> Tuple[Dict[str, str], Dict[str, float], List[str], List[str]]:
    """
    Match SRDF joint names to URDF joint names using heuristics.

    Algorithm:
    1. Normalize names (``normalize_joint_name``) and extract numeric indices
    2. Score every SRDF x URDF pair (``_joint_pair_score``): 1.0 for an exact
       normalized match, otherwise a clamped blend of difflib ratio, token
       overlap, index agreement and role-keyword agreement
    3. Greedy one-to-one matching in descending score order; pairs below
       ``threshold`` are never matched. Equal scores keep input order
       (stable sort), so results are reproducible for ordered inputs.

    Args:
        srdf_joint_names: Joint names referenced by the SRDF.
        urdf_joint_names: Joint names declared in the URDF.
        robot_prefixes: Prefixes to strip during normalization; ``None``
            uses ['panda', 'j2n6s300', 'kinova', 'franka'].
        threshold: Minimum score for a pair to be accepted.

    Returns:
        - mapping: dict of srdf_name -> urdf_name
        - confidence: dict of srdf_name -> score rounded to 4 decimals
        - unmapped_srdf: unmatched SRDF names, in input order
        - unmapped_urdf: unmatched URDF names, de-duplicated, in input order
    """
    if robot_prefixes is None:
        robot_prefixes = ['panda', 'j2n6s300', 'kinova', 'franka']

    # Per-name features are computed once instead of once per pair.
    srdf_features = {
        name: (normalize_joint_name(name, robot_prefixes), extract_number(name))
        for name in srdf_joint_names
    }
    urdf_features = {
        name: (normalize_joint_name(name, robot_prefixes), extract_number(name))
        for name in urdf_joint_names
    }

    scores: List[Tuple[float, str, str]] = []
    for srdf_name in srdf_joint_names:
        srdf_norm, srdf_num = srdf_features[srdf_name]
        for urdf_name in urdf_joint_names:
            urdf_norm, urdf_num = urdf_features[urdf_name]
            score = _joint_pair_score(srdf_name, urdf_name, srdf_norm, urdf_norm, srdf_num, urdf_num)
            scores.append((score, srdf_name, urdf_name))

    # Stable sort: ties keep their (srdf, urdf) input order.
    scores.sort(reverse=True, key=lambda item: item[0])

    mapping: Dict[str, str] = {}
    confidence: Dict[str, float] = {}
    available_urdf = set(urdf_joint_names)

    for score, srdf_name, urdf_name in scores:
        if score < threshold:
            break  # sorted descending: nothing after this can qualify
        if srdf_name in mapping or urdf_name not in available_urdf:
            continue
        mapping[srdf_name] = urdf_name
        confidence[srdf_name] = round(score, 4)
        available_urdf.discard(urdf_name)

    unmapped_srdf = [n for n in srdf_joint_names if n not in mapping]
    # Iterate the input list (deduplicated) rather than the set, whose order
    # depends on string hashing and changes between interpreter runs.
    unmapped_urdf = [n for n in dict.fromkeys(urdf_joint_names) if n in available_urdf]

    return mapping, confidence, unmapped_srdf, unmapped_urdf


def collect_srdf_joints(srdf_data: Dict[str, Any]) -> Set[str]:
    """
    Return every joint name the parsed SRDF mentions.

    Names are gathered from each group's ``joints`` list and from the keys of
    each group state's ``joint_values``; missing sections are treated as empty.
    """
    in_groups = {
        jname
        for group in srdf_data.get('groups', [])
        for jname in group.get('joints', [])
    }
    in_states = {
        jname
        for state in srdf_data.get('group_states', [])
        for jname in state.get('joint_values', {})
    }
    return in_groups | in_states


# ============================================================================
# End-Effector Offset Computation from URDF Kinematic Chain
# ============================================================================

def rpy_to_rotation_matrix(rpy: List[float]) -> np.ndarray:
    """
    Build the 3x3 rotation matrix for a URDF ``rpy`` triple.

    URDF applies roll about X, then pitch about Y, then yaw about Z (fixed
    axes), which composes to R = Rz(yaw) @ Ry(pitch) @ Rx(roll). The product
    is written out element-wise below.

    Args:
        rpy: (roll, pitch, yaw) in radians.

    Returns:
        3x3 float ndarray.
    """
    roll, pitch, yaw = rpy

    c_r, s_r = math.cos(roll), math.sin(roll)
    c_p, s_p = math.cos(pitch), math.sin(pitch)
    c_y, s_y = math.cos(yaw), math.sin(yaw)

    row_x = [c_y*c_p, c_y*s_p*s_r - s_y*c_r, c_y*s_p*c_r + s_y*s_r]
    row_y = [s_y*c_p, s_y*s_p*s_r + c_y*c_r, s_y*s_p*c_r - c_y*s_r]
    row_z = [-s_p, c_p*s_r, c_p*c_r]

    return np.array([row_x, row_y, row_z])


def build_kinematic_tree(joints: List[Dict]) -> Tuple[Dict[str, Dict], Dict[str, str]]:
    """
    Index joints by the links they connect.

    Joints lacking a non-empty ``parent_link`` or ``child_link`` are skipped.

    Returns:
    - parent_to_children: parent_link -> list of (joint_info, child_link),
      in joint-list order
    - child_to_parent: child_link -> parent_link (a later joint naming the
      same child overwrites an earlier one)
    """
    children_of: Dict[str, list] = {}
    parent_of: Dict[str, str] = {}

    for joint in joints:
        parent_link = joint.get('parent_link', '')
        child_link = joint.get('child_link', '')
        if not (parent_link and child_link):
            continue
        children_of.setdefault(parent_link, []).append((joint, child_link))
        parent_of[child_link] = parent_link

    return children_of, parent_of


def find_path_to_tip(
    start_link: str,
    joints: List[Dict],
    tip_keywords: List[str] = None
) -> Tuple[Optional[str], List[Dict]]:
    """
    Find the path from start_link down to a tip link (finger tip, TCP, cup, ...).

    A breadth-first walk of the subtree below ``start_link`` records two
    candidates:
    - the deepest link whose lowercased name contains any of ``tip_keywords``
      (plain substring test, so short keywords such as 'ee' match broadly);
    - the deepest leaf link (no outgoing joints).
    The keyword candidate is preferred; the deepest leaf is the fallback. On
    equal depth the first link reached in BFS order wins. ``start_link``
    itself (depth 0) is never selected.

    Args:
        start_link: Starting link name (e.g., last arm link)
        joints: List of joint dicts from URDF
        tip_keywords: Keywords to identify tip links (None uses the defaults)

    Returns:
        (tip_link_name, joints_from_start_to_tip); (None, []) if nothing
        below start_link qualifies.
    """
    from collections import deque

    if tip_keywords is None:
        tip_keywords = [
            # Explicit tip / tool-centre-point names
            'tip', 'fingertip', 'finger_tip', 'tcp', 'tool',
            # Grippers (incl. Robotiq)
            'finger', 'gripper', 'hand', 'ee', 'end_effector',
            'robotiq', 'knuckle', 'pad',
            # Suction end-effectors
            'suction', 'vacuum', 'cup', 'nozzle'
        ]

    children_of, _ = build_kinematic_tree(joints)

    keyword_tip, keyword_path, keyword_depth = None, [], 0
    leaf_tip, leaf_path, leaf_depth = None, [], 0

    seen = set()
    frontier = deque([(start_link, [])])

    while frontier:
        link, path = frontier.popleft()
        if link in seen:
            continue
        seen.add(link)

        depth = len(path)
        lowered = link.lower()
        is_tip = any(kw in lowered for kw in tip_keywords)
        if is_tip and depth > keyword_depth:
            keyword_tip, keyword_path, keyword_depth = link, path, depth

        outgoing = children_of.get(link)
        if not outgoing:
            # Leaf: candidate for the fallback answer.
            if depth > leaf_depth:
                leaf_tip, leaf_path, leaf_depth = link, path, depth
            continue

        for joint, child in outgoing:
            if child not in seen:
                frontier.append((child, path + [joint]))

    if keyword_tip is not None:
        return keyword_tip, keyword_path
    return leaf_tip, leaf_path


def compute_chain_offset(joints_in_path: List[Dict]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Accumulate the fixed origin transforms along a joint chain.

    Only each joint's ``origin_xyz`` / ``origin_rpy`` is used (missing entries
    count as zero), so the result corresponds to every movable joint sitting
    at its zero position.

    Args:
        joints_in_path: Joint dicts ordered from base to tip.

    Returns:
        (position_offset, cumulative_rotation_matrix), both expressed in the
        frame of the chain's first parent link.
    """
    offset = np.zeros(3)
    frame = np.eye(3)

    for joint in joints_in_path:
        # Each origin translation is expressed in the current (parent) frame.
        step = np.array(joint.get('origin_xyz', [0, 0, 0]))
        offset = offset + frame @ step
        frame = frame @ rpy_to_rotation_matrix(joint.get('origin_rpy', [0, 0, 0]))

    return offset, frame


def identify_gripper_type(joints: List[Dict], links: List[str]) -> str:
    """
    Guess the gripper family from URDF joint and link names.

    All names are joined and lowercased, then checked in priority order:
    Robotiq 85, Robotiq 140, suction/vacuum, Baxter, generic gripper/finger.

    Returns one of: "gripper", "robotiq85", "robotiq140", "suction", "baxter_gripper", "unknown"
    """
    haystack = ' '.join([joint.get('name', '') for joint in joints] + links).lower()

    def mentions(*needles: str) -> bool:
        return any(needle in haystack for needle in needles)

    if mentions('robotiq_85', 'robotiq85'):
        return 'robotiq85'
    if mentions('robotiq_140', 'robotiq140'):
        return 'robotiq140'
    if mentions('suction', 'vacuum'):
        return 'suction'
    if mentions('baxter'):
        return 'baxter_gripper'
    # Panda hands need no special case: they are caught by this generic check.
    if mentions('gripper', 'finger'):
        return 'gripper'
    return 'unknown'


def find_ee_base_link(
    srdf_data: Dict[str, Any],
    urdf_data: Dict[str, Any]
) -> Optional[str]:
    """
    Find the link that connects the arm to the end-effector.

    The returned link must have outgoing joints so the gripper subtree can be
    traversed. Candidates are tried in this order, first hit wins:
    1. SRDF end_effector ``parent_link`` values
    2. SRDF group chain ``tip_link`` values
    3. URDF links whose name contains a well-known wrist/flange pattern
    For 1-3 the link (or something below it) must have a gripper-like name.
    4. Any URDF link with a directly attached gripper-like child

    Returns:
        The link name, or None if no candidate is found.
    """
    children_of, _ = build_kinematic_tree(urdf_data.get('joints', []))

    # Substrings that mark a link as part of an end-effector.
    subtree_keywords = (
        # Generic grippers
        'gripper', 'finger', 'tip', 'hand',
        # Robotiq
        'robotiq', 'knuckle', 'pad',
        # Suction
        'suction', 'vacuum', 'cup', 'nozzle',
        # Generic EE / tool frames
        'tool', 'tcp', 'ee', 'end_effector', 'attachment',
    )
    # Narrower list used by the last-resort direct-child scan.
    child_keywords = (
        'gripper', 'finger', 'hand', 'attachment',
        'robotiq', 'suction', 'vacuum', 'cup', 'tool', 'tcp', 'ee',
    )

    def leads_to_gripper(root_link: str) -> bool:
        """True if root_link or any link below it has a gripper-like name."""
        pending = [root_link]
        seen = set()
        while pending:
            link = pending.pop()
            if link in seen:
                continue
            seen.add(link)
            if any(kw in link.lower() for kw in subtree_keywords):
                return True
            pending.extend(child for _, child in children_of.get(link, []))
        return False

    def usable(link: str) -> bool:
        return bool(link) and link in children_of and leads_to_gripper(link)

    # 1. SRDF end-effector parent links
    for ee in srdf_data.get('end_effectors', []):
        candidate = ee.get('parent_link', '')
        if usable(candidate):
            return candidate

    # 2. Tips of SRDF group chains
    for group in srdf_data.get('groups', []):
        for chain in group.get('chains', []):
            candidate = chain.get('tip_link', '')
            if usable(candidate):
                return candidate

    links = urdf_data.get('links', [])

    # 3. Well-known wrist / flange names (pattern order is the priority)
    wrist_patterns = ('link7', 'link8', 'wrist_3', 'wrist3', 'ee_link', 'right_hand', 'left_hand')
    for pattern in wrist_patterns:
        for link in links:
            if pattern in link.lower() and usable(link):
                return link

    # 4. Last resort: any link with a gripper-like direct child
    for link in links:
        for _, child in children_of.get(link, []):
            if any(kw in child.lower() for kw in child_keywords):
                return link

    return None


def _failed_ee_offset(ee_type: str, ee_base_link: Optional[str], error: str) -> Dict[str, Any]:
    """Build the zero-offset result returned when no EE chain can be resolved."""
    return {
        'ee_type': ee_type,
        'ee_base_link': ee_base_link,
        'tip_link': None,
        'offset_down': [0, 0, 0],
        'offset_left': [0, 0, 0],
        'offset_right': [0, 0, 0],
        'raw_offset': [0, 0, 0],
        'raw_rotation': np.eye(3).tolist(),
        'joints_in_path': [],
        'error': error
    }


def compute_ee_offset(
    urdf_data: Dict[str, Any],
    srdf_data: Dict[str, Any],
    ee_base_link: Optional[str] = None
) -> Dict[str, Any]:
    """
    Compute the end-effector offset from the URDF kinematic chain.

    The offset is the zero-configuration translation from ``ee_base_link`` to
    the tip link chosen by ``find_path_to_tip``.

    Args:
        urdf_data: Parsed URDF data
        srdf_data: Parsed SRDF data
        ee_base_link: Optional starting link (auto-detected if None)

    Returns:
        dict with:
        - ee_type: gripper type string
        - ee_base_link: the starting link
        - tip_link: the tip link found (None on failure)
        - offset_down: [x, y, z] offset when pointing down (the raw offset)
        - offset_left: raw offset rotated by Rx(-90 deg): (x, z, -y)
        - offset_right: raw offset rotated by Rx(+90 deg): (x, -z, y)
        - raw_offset: [x, y, z] raw computed offset
        - raw_rotation: 3x3 rotation matrix (nested lists)
        - joints_in_path: joint names from base link to tip ([] on failure)
        - error: present only on failure; all offsets are then zero
        Numeric values are rounded to 6 decimals.
    """
    joints = urdf_data.get('joints', [])
    links = urdf_data.get('links', [])

    if ee_base_link is None:
        ee_base_link = find_ee_base_link(srdf_data, urdf_data)
    if ee_base_link is None:
        return _failed_ee_offset('unknown', None, 'Could not find EE base link')

    ee_type = identify_gripper_type(joints, links)

    tip_link, joints_in_path = find_path_to_tip(ee_base_link, joints)
    if tip_link is None or not joints_in_path:
        return _failed_ee_offset(ee_type, ee_base_link, f'Could not find tip from {ee_base_link}')

    position, rotation = compute_chain_offset(joints_in_path)
    # tolist() yields plain Python floats for JSON output.
    x, y, z = position.tolist()

    # "down" is the default orientation: the offset is used unchanged.
    offset_down = [x, y, z]
    # "left": rotate about x by -90 deg so the z component maps onto +y.
    # This is a proper rotation (det +1); swapping y and z outright would be a
    # mirror image and inconsistent with offset_right.
    offset_left = [x, z, -y]
    # "right": rotate about x by +90 deg so the z component maps onto -y.
    offset_right = [x, -z, y]

    return {
        'ee_type': ee_type,
        'ee_base_link': ee_base_link,
        'tip_link': tip_link,
        'offset_down': [round(v, 6) for v in offset_down],
        'offset_left': [round(v, 6) for v in offset_left],
        'offset_right': [round(v, 6) for v in offset_right],
        'raw_offset': [round(v, 6) for v in offset_down],
        'raw_rotation': [[round(v, 6) for v in row] for row in rotation.tolist()],
        'joints_in_path': [j['name'] for j in joints_in_path]
    }


def _build_robot_graph(
    robot_name: str,
    urdf_data: Dict[str, Any],
    srdf_data: Dict[str, Any],
    mapping: Dict[str, str],
) -> TripleGraph:
    """Build the structural triples: robot, joints, links, groups, end effectors, DoF."""
    graph = TripleGraph()
    rid = robot_id(robot_name)

    graph.add(rid, Predicates.HAS_TYPE, "Robot")

    # Joints (with their links and limits) from the URDF
    for joint in urdf_data['joints']:
        jid = joint_id(robot_name, joint['name'])
        graph.add(rid, Predicates.HAS_JOINT, jid)
        graph.add(jid, Predicates.JOINT_TYPE, joint['type'])

        if 'parent_link' in joint:
            graph.add(jid, Predicates.PARENT_LINK, link_id(robot_name, joint['parent_link']))
        if 'child_link' in joint:
            graph.add(jid, Predicates.CHILD_LINK, link_id(robot_name, joint['child_link']))

        if 'limits' in joint:
            limits = joint['limits']
            graph.add(jid, Predicates.HAS_LIMIT_LOWER, str(limits['lower']))
            graph.add(jid, Predicates.HAS_LIMIT_UPPER, str(limits['upper']))
            graph.add(jid, Predicates.HAS_LIMIT_EFFORT, str(limits['effort']))
            graph.add(jid, Predicates.HAS_LIMIT_VELOCITY, str(limits['velocity']))

    # Links
    for link_name in urdf_data['links']:
        graph.add(rid, Predicates.HAS_LINK, link_id(robot_name, link_name))

    # Groups from the SRDF; joint references are translated to URDF names
    # through the mapping when a match was found.
    for group in srdf_data['groups']:
        gid = group_id(robot_name, group['name'])
        graph.add(rid, Predicates.HAS_GROUP, gid)

        for jname in group['joints']:
            urdf_jname = mapping.get(jname, jname)
            graph.add(gid, Predicates.GROUP_HAS_JOINT, joint_id(robot_name, urdf_jname))

        for lname in group['links']:
            graph.add(gid, Predicates.GROUP_HAS_LINK, link_id(robot_name, lname))

        for chain in group['chains']:
            graph.add(gid, Predicates.CHAIN_BASE, link_id(robot_name, chain['base_link']))
            graph.add(gid, Predicates.CHAIN_TIP, link_id(robot_name, chain['tip_link']))

    # End effectors
    for ee in srdf_data['end_effectors']:
        ee_node = f"ee:{robot_name}:{ee['name']}"
        graph.add(rid, Predicates.HAS_END_EFFECTOR, ee_node)
        if ee['group']:
            graph.add(ee_node, Predicates.HAS_GROUP, group_id(robot_name, ee['group']))
        if ee['parent_link']:
            graph.add(ee_node, Predicates.PARENT_LINK, link_id(robot_name, ee['parent_link']))

    # DoF = number of non-fixed URDF joints
    dof_count = sum(1 for j in urdf_data['joints'] if j['type'] != 'fixed')
    graph.add(rid, Predicates.HAS_DOF, str(dof_count))

    return graph


def parse_robot(
    robot_name: str,
    urdf_path: str,
    srdf_path: str,
    output_dir: str = None
) -> Tuple[Dict[str, Any], TripleGraph]:
    """
    Parse a robot's URDF and SRDF, match joint names, and build a triple graph.

    Args:
        robot_name: Identifier used for graph node ids and output file names.
        urdf_path: Path to the URDF file.
        srdf_path: Path to the SRDF file.
        output_dir: If given, ``joint_mapping_<robot_name>.json`` and
            ``ee_offset_<robot_name>.json`` are written there; the directory
            is created if it does not exist.

    Returns:
    - robot_summary: dict with all parsed info, the joint mapping and the
      end-effector offset (``ee_offset``)
    - graph: TripleGraph with robot structure and EE offset triples
    """
    urdf_data = parse_urdf(urdf_path)
    srdf_data = parse_srdf(srdf_path)

    urdf_joint_names = [j['name'] for j in urdf_data['joints']]
    # collect_srdf_joints returns a set, whose iteration order varies between
    # interpreter runs (string hash randomization). The greedy matcher breaks
    # score ties by input order, so sort to make the mapping reproducible.
    srdf_joint_names = sorted(collect_srdf_joints(srdf_data))

    # Order-preserving de-duplication, for the same reproducibility reason.
    prefix_candidates = (
        robot_name.lower(),
        urdf_data['robot_name'].lower(),
        srdf_data['robot_name'].lower(),
    )
    robot_prefixes = list(dict.fromkeys(p for p in prefix_candidates if p))

    mapping, confidence, unmapped_srdf, unmapped_urdf = match_joint_names(
        srdf_joint_names,
        urdf_joint_names,
        robot_prefixes=robot_prefixes
    )

    notes = []
    if unmapped_srdf:
        notes.append(f"SRDF joints not matched to URDF: {unmapped_srdf}")
    if unmapped_urdf and len(unmapped_urdf) < 20:  # Don't list if too many
        notes.append(f"URDF joints not referenced in SRDF: {unmapped_urdf}")

    robot_summary = {
        'robot': robot_name,
        'urdf': {
            'path': urdf_path,
            'robot_name': urdf_data['robot_name'],
            'joint_count': len(urdf_data['joints']),
            'link_count': len(urdf_data['links']),
            'joints': urdf_data['joints'],
            'links': urdf_data['links']
        },
        'srdf': {
            'path': srdf_path,
            'robot_name': srdf_data['robot_name'],
            'groups': srdf_data['groups'],
            'group_states': srdf_data['group_states'],
            'end_effectors': srdf_data['end_effectors'],
            'virtual_joints': srdf_data['virtual_joints'],
            'disable_collision_count': len(srdf_data['disable_collisions'])
        },
        'joint_name_mapping': mapping,
        'joint_mapping_confidence': confidence,
        'unmapped_srdf_joints': unmapped_srdf,
        'unmapped_urdf_joints': unmapped_urdf,
        'notes': notes
    }

    out_dir = Path(output_dir) if output_dir else None
    if out_dir is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        with open(out_dir / f"joint_mapping_{robot_name}.json", 'w', encoding='utf-8') as f:
            json.dump({
                'mapping': mapping,
                'confidence': confidence,
                'unmapped_srdf': unmapped_srdf,
                'unmapped_urdf': unmapped_urdf
            }, f, indent=2)

    graph = _build_robot_graph(robot_name, urdf_data, srdf_data, mapping)
    rid = robot_id(robot_name)

    # End-effector offset from the kinematic chain
    ee_offset_info = compute_ee_offset(urdf_data, srdf_data)
    robot_summary['ee_offset'] = ee_offset_info

    # NOTE(review): these guards only skip missing/empty values. When the EE
    # chain cannot be resolved, the offsets are zero lists (still truthy), so
    # zero offsets are recorded alongside the 'error' entry.
    if ee_offset_info.get('ee_type'):
        graph.add(rid, Predicates.HAS_EE_TYPE, ee_offset_info['ee_type'])
    if ee_offset_info.get('offset_down'):
        graph.add(rid, Predicates.HAS_EE_OFFSET_DOWN, json.dumps(ee_offset_info['offset_down']))
    if ee_offset_info.get('offset_left'):
        graph.add(rid, Predicates.HAS_EE_OFFSET_LEFT, json.dumps(ee_offset_info['offset_left']))
    if ee_offset_info.get('offset_right'):
        graph.add(rid, Predicates.HAS_EE_OFFSET_RIGHT, json.dumps(ee_offset_info['offset_right']))

    if out_dir is not None:
        with open(out_dir / f"ee_offset_{robot_name}.json", 'w', encoding='utf-8') as f:
            json.dump(ee_offset_info, f, indent=2)

    return robot_summary, graph


if __name__ == "__main__":
    # Smoke-test the parser against the sample robots under <repo>/robots/.
    BASE_DIR = Path(__file__).parent.parent

    robots_to_test = [
        ("panda", BASE_DIR / "robots/panda/panda.urdf", BASE_DIR / "robots/panda/panda.srdf"),
        ("ur5", BASE_DIR / "robots/ur5/ur5_robotiq85.urdf", BASE_DIR / "robots/ur5/ur5_robotiq85.srdf"),
        ("sawyer", BASE_DIR / "robots/sawyer/sawyer_with_baxter_gripper.urdf", BASE_DIR / "robots/sawyer/sawyer_with_baxter_gripper.srdf"),
    ]

    for robot_name, urdf_path, srdf_path in robots_to_test:
        if not (urdf_path.exists() and srdf_path.exists()):
            print(f"\nSkipping {robot_name}: URDF or SRDF not found")
            continue

        print(f"\n{'='*60}")
        print(f"Testing {robot_name.upper()}")
        print(f"{'='*60}")

        summary, _ = parse_robot(robot_name, str(urdf_path), str(srdf_path))
        print(f"  Joints: {summary['urdf']['joint_count']}, Links: {summary['urdf']['link_count']}")
        print(f"  Joint mapping: {len(summary['joint_name_mapping'])} matched")

        ee_info = summary.get('ee_offset', {})
        print("\n  End-Effector Info:")
        print(f"    Type: {ee_info.get('ee_type', 'N/A')}")
        print(f"    Base link: {ee_info.get('ee_base_link', 'N/A')}")
        print(f"    Tip link: {ee_info.get('tip_link', 'N/A')}")
        print(f"    Offset (down): {ee_info.get('offset_down', 'N/A')}")
        print(f"    Offset (left): {ee_info.get('offset_left', 'N/A')}")
        print(f"    Offset (right): {ee_info.get('offset_right', 'N/A')}")
        print(f"    Joints in path: {ee_info.get('joints_in_path', 'N/A')}")

        if 'error' in ee_info:
            print(f"    Error: {ee_info['error']}")
