import os
import json
import logging
from utils.child_offset import compute_child_visual_origin_offset_with_link_offset
import numpy as np
from scipy.spatial.transform import Rotation as R

def setup_logging():
    """Configure the root logger for command-line use.

    Messages at INFO and above are printed as ``<time> - <LEVEL> - <message>``.
    Delegates to ``logging.basicConfig``, so it has no effect if the root logger
    already has handlers.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)

def load_articulation(base_dir):
    """Return the articulation description stored under *base_dir*, or None.

    Lookup order:
      1. ``configs/articulation.json``, or ``articulation.json`` directly in
         *base_dir* when the configs copy is absent. Its whole JSON content is
         returned.
      2. Only when no articulation file exists: ``configs/workflow.json``, or
         ``workflow.json`` directly in *base_dir*. Its ``'articulation'`` entry
         is returned (None if the key is missing).

    Read/parse failures are logged and yield None. A broken articulation.json
    does NOT fall through to workflow.json.
    """
    def _prefer_configs(filename):
        # The configs/ copy wins; otherwise use the legacy top-level location.
        candidate = os.path.join(base_dir, 'configs', filename)
        if os.path.exists(candidate):
            return candidate
        return os.path.join(base_dir, filename)

    articulation_file = _prefer_configs('articulation.json')
    workflow_file = _prefer_configs('workflow.json')

    if os.path.exists(articulation_file):
        try:
            with open(articulation_file, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as err:
            logging.error(f"Failed to load articulation.json: {err}")
        return None

    if not os.path.exists(workflow_file):
        logging.warning(f"No articulation.json or workflow.json found in {base_dir}.")
        return None

    try:
        with open(workflow_file, 'r', encoding='utf-8') as handle:
            return json.load(handle).get('articulation', None)
    except Exception as err:
        logging.error(f"Failed to load workflow.json: {err}")
    return None

def build_child_joint_origins(articulation):
    """Map each child link of a revolute/continuous joint to that joint's origin.

    Args:
        articulation: list of joint dicts. Anything that is not a non-empty
            list yields an empty mapping.

    Returns:
        ``{child_link: {'xyz': [...], 'rpy': [...]}}``. A missing ``xyz`` or
        ``rpy`` defaults to ``[0, 0, 0]``. Joints with no child name, another
        type, or no ``origin`` are skipped. A later joint overwrites an earlier
        one that names the same child.
    """
    if not (articulation and isinstance(articulation, list)):
        return {}

    origins_by_child = {}
    for entry in articulation:
        child_name = entry.get('child', '')
        origin = entry.get('origin', None)
        is_rotational = entry.get('type', '') in ('revolute', 'continuous')
        if not child_name or not is_rotational or origin is None:
            continue
        origins_by_child[child_name] = {
            'xyz': origin.get('xyz', [0, 0, 0]),
            'rpy': origin.get('rpy', [0, 0, 0]),
        }
    return origins_by_child

def format_xyz(xyz):
    """Return the first three components of *xyz* as a space-separated string (URDF attribute format)."""
    return "{} {} {}".format(xyz[0], xyz[1], xyz[2])

def format_rpy(rpy):
    """Return roll/pitch/yaw as a space-separated string for a URDF ``rpy`` attribute.

    ``None`` or an empty sequence yields ``"0 0 0"``. Lists, tuples and numpy
    arrays are all accepted. A bare ``if rpy`` test would raise ValueError for
    a multi-element numpy array, so emptiness is checked explicitly.
    """
    if rpy is None or len(rpy) == 0:
        return "0 0 0"
    return f"{rpy[0]} {rpy[1]} {rpy[2]}"

def rpy_xyz_to_matrix(rpy, xyz):
    """Build a 4x4 homogeneous transform from Euler angles and a translation.

    *rpy* is interpreted by scipy as lowercase ``'xyz'`` (extrinsic) Euler
    angles in radians. *xyz* becomes the translation column.
    """
    transform = np.identity(4)
    transform[:3, :3] = R.from_euler('xyz', rpy, degrees=False).as_matrix()
    transform[:3, 3] = xyz
    return transform

def matrix_to_rpy_xyz(matrix):
    """Split a 4x4 homogeneous transform into ``(rpy, xyz)``.

    ``rpy`` is the upper-left 3x3 rotation expressed as scipy lowercase
    ``'xyz'`` Euler angles in radians. ``xyz`` is the translation column,
    returned as a view into *matrix*.
    """
    euler_angles = R.from_matrix(matrix[:3, :3]).as_euler('xyz', degrees=False)
    return euler_angles, matrix[:3, 3]

def inverse_rpy_xyz(rpy, xyz):
    """Invert the rigid transform described by Euler angles *rpy* and translation *xyz*.

    The arguments are taken as ``(rpy, xyz)`` but the result is returned as
    ``(xyz_inv, rpy_inv)``. Callers rely on this swapped order.
    """
    inverted = np.linalg.inv(rpy_xyz_to_matrix(rpy, xyz))
    inv_rpy, inv_xyz = matrix_to_rpy_xyz(inverted)
    return inv_xyz, inv_rpy

def _collect_joint_tree(articulation):
    """Index a list of joint dicts by parent and by child.

    Returns ``(parent_to_children, child_to_joint, all_links)``. Joints lacking
    a ``parent`` or ``child`` name are skipped. Anything other than a non-empty
    list yields empty structures. If several joints name the same child, the
    last one wins in ``child_to_joint``.
    """
    parent_to_children = {}
    child_to_joint = {}
    all_links = set()
    if articulation and isinstance(articulation, list):
        for joint in articulation:
            parent = joint.get('parent', '')
            child = joint.get('child', '')
            if parent and child:
                parent_to_children.setdefault(parent, []).append(child)
                child_to_joint[child] = joint
                all_links.update((parent, child))
    return parent_to_children, child_to_joint, all_links


def _compute_visual_offsets(parent_to_children, child_to_joint, root_links):
    """Walk the joint tree from every root and compute a visual offset per link.

    Offsets are ``{'xyz': ..., 'rpy': ...}`` dicts keyed by link name. Per
    joint type of the joint that produces the link:
      - root (no joint): zero offset; nothing is accumulated.
      - prismatic: zero offset; the joint origin is ignored, matching the
        zero origin emitted for prismatic joints in the URDF.
      - revolute/continuous: computed by
        ``compute_child_visual_origin_offset_with_link_offset``.
      - fixed: the inverse of the offset accumulated from the ancestors.
      - other types: no entry; the caller falls back to zero.
    Links reached more than once (duplicate or cyclic joints) are visited only
    the first time, so a malformed articulation cannot recurse forever.
    """
    offsets = {}
    visited = set()

    def visit(link, parent_offset_xyz, parent_offset_rpy):
        if link in visited:
            logging.warning(f"Link '{link}' reached more than once in articulation tree; skipping.")
            return
        visited.add(link)

        joint = child_to_joint.get(link)
        if joint is None:
            offsets[link] = {'xyz': [0, 0, 0], 'rpy': [0, 0, 0]}
            new_offset_xyz, new_offset_rpy = parent_offset_xyz, parent_offset_rpy
        else:
            origin = joint.get('origin') or {}
            joint_xyz = np.array(origin.get('xyz', [0.0, 0.0, 0.0]))
            joint_rpy = np.array(origin.get('rpy', [0.0, 0.0, 0.0]))
            joint_type = joint.get('type', '')

            if joint_type == 'prismatic':
                # Prismatic origins are zeroed in the emitted URDF, so they contribute nothing here either.
                offsets[link] = {'xyz': [0, 0, 0], 'rpy': [0, 0, 0]}
                new_offset_xyz, new_offset_rpy = parent_offset_xyz, parent_offset_rpy
            else:
                if joint_type in ('revolute', 'continuous'):
                    xyz, rpy = compute_child_visual_origin_offset_with_link_offset(
                        joint_xyz, joint_rpy,
                        link_offset_xyz=np.array([0.0, 0.0, 0.0]), link_offset_rpy=np.array([0.0, 0.0, 0.0]),
                        parent_offset_xyz=parent_offset_xyz, parent_offset_rpy=parent_offset_rpy,
                    )
                    offsets[link] = {'xyz': xyz, 'rpy': rpy}
                elif joint_type == 'fixed':
                    # Fixed joints are emitted without an origin; undo the ancestors' accumulated offset
                    # (presumably meshes share one export frame — confirm).
                    xyz_inv, rpy_inv = inverse_rpy_xyz(parent_offset_rpy, parent_offset_xyz)
                    offsets[link] = {'xyz': xyz_inv.tolist(), 'rpy': rpy_inv.tolist()}
                # NOTE(review): Euler angles are summed component-wise, which equals rotation
                # composition only for rotations about a single shared axis — confirm intended.
                new_offset_xyz = parent_offset_xyz + joint_xyz
                new_offset_rpy = parent_offset_rpy + joint_rpy

        for child in parent_to_children.get(link, []):
            visit(child, new_offset_xyz, new_offset_rpy)

    for root in root_links:
        visit(root, np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0]))
    return offsets


def _link_xml(link_name, mesh_rel_path, xyz, rpy):
    """Render one ``<link>`` whose ``<visual>`` and ``<collision>`` share the same mesh and origin.

    The ``<origin>`` tag is omitted when every offset component is within 1e-8 of zero.
    """
    # list() so numpy arrays are concatenated below instead of added element-wise.
    xyz, rpy = list(xyz), list(rpy)
    origin_tag = (
        f'\n      <origin xyz="{format_xyz(xyz)}" rpy="{format_rpy(rpy)}" />'
        if any(abs(v) > 1e-8 for v in xyz + rpy) else ''
    )
    geometry = (
        '      <geometry>\n'
        f'        <mesh filename="{mesh_rel_path}" />\n'
        f'      </geometry>{origin_tag}\n'
    )
    return (
        f'  <link name="{link_name}">\n'
        '    <visual>\n'
        f'{geometry}'
        '    </visual>\n'
        '    <collision>\n'
        f'{geometry}'
        '    </collision>\n'
        '  </link>\n'
    )


def _joint_xml(joint):
    """Render one ``<joint>`` element from an articulation joint dict.

    Prismatic joints always get a zero origin. Fixed joints never emit an
    ``<origin>`` tag. ``<limit>`` is written only for revolute/prismatic joints
    that carry a non-empty ``limit``.
    """
    joint_name = joint.get('joint_name', 'unnamed_joint')
    parent = joint.get('parent', '')
    child = joint.get('child', '')
    joint_type = joint.get('type', 'fixed')
    axis = joint.get('axis', None)
    limit = joint.get('limit', {})

    origin_xyz, origin_rpy = [0, 0, 0], [0, 0, 0]
    origin = joint.get('origin')
    if joint_type != 'prismatic' and isinstance(origin, dict):
        origin_xyz = origin.get('xyz', [0, 0, 0])
        origin_rpy = origin.get('rpy', [0, 0, 0])

    axis_str = f'<axis xyz="{format_xyz(axis)}" />' if axis else ''
    limit_str = ''
    if joint_type in ['revolute', 'prismatic'] and limit:
        lower = limit.get('lower', 0)
        upper = limit.get('upper', 0)
        effort = limit.get('effort', 0)
        velocity = limit.get('velocity', 0)
        limit_str = f'<limit lower="{lower}" upper="{upper}" effort="{effort}" velocity="{velocity}" />'

    origin_tag = (
        f'\n    <origin xyz="{format_xyz(origin_xyz)}" rpy="{format_rpy(origin_rpy)}" />'
        if joint_type != 'fixed' and any(abs(v) > 1e-8 for v in list(origin_xyz) + list(origin_rpy)) else ''
    )
    return (
        f'  <joint name="{joint_name}" type="{joint_type}">\n'
        f'    <parent link="{parent}" />\n'
        f'    <child link="{child}" />{origin_tag}\n'
        f'    {axis_str}\n'
        f'    {limit_str}\n'
        '  </joint>\n'
    )


def generate_urdf(mesh_folder, base_dir, output_path):
    """Write a URDF for the object stored in *base_dir* to *output_path*.

    Each ``<base_dir>/links/*.obj`` file becomes one link named after the file
    stem. Joints and per-link visual offsets come from ``load_articulation``.
    With no usable articulation, a links-only URDF is still written.

    Args:
        mesh_folder: unused; kept for backward compatibility with existing
            callers. Meshes are always read from ``<base_dir>/links``.
        base_dir: object directory. Its basename becomes the robot name.
        output_path: destination URDF file, written as UTF-8.

    Raises:
        FileNotFoundError: if ``<base_dir>/links`` does not exist.
        RuntimeError: if the articulation names links but none of them is a
            root (every link is some joint's child).
        Exception: any error raised while writing *output_path* is logged and
            re-raised.
    """
    mesh_group_output = os.path.join(base_dir, 'links')
    mesh_path_prefix = 'links'
    # Sorted so the URDF is deterministic (os.listdir order is arbitrary).
    mesh_files = sorted(f for f in os.listdir(mesh_group_output) if f.endswith('.obj'))

    robot_name = os.path.basename(os.path.abspath(base_dir))
    articulation = load_articulation(base_dir)

    parent_to_children, child_to_joint, all_links = _collect_joint_tree(articulation)
    root_links = sorted(link for link in all_links if link not in child_to_joint)
    if all_links and not root_links:
        raise RuntimeError("No root link found in articulation tree.")
    if len(root_links) > 1:
        logging.warning(f"Multiple root links in articulation tree: {root_links}")
    child_visual_offsets = _compute_visual_offsets(parent_to_children, child_to_joint, root_links)

    zero_offset = {'xyz': [0, 0, 0], 'rpy': [0, 0, 0]}
    urdf_links = []
    for mesh_file in mesh_files:
        link_name = os.path.splitext(mesh_file)[0]
        offset = child_visual_offsets.get(link_name, zero_offset)
        # URDF mesh filenames are URIs: always '/', never the OS separator.
        mesh_rel_path = f"{mesh_path_prefix}/{mesh_file}"
        urdf_links.append(_link_xml(link_name, mesh_rel_path, offset['xyz'], offset['rpy']))

    urdf_joints = []
    if articulation and isinstance(articulation, list):
        urdf_joints = [_joint_xml(joint) for joint in articulation]
    elif articulation is not None and not isinstance(articulation, list):
        logging.warning("Articulation info is not a list. No joints will be generated.")

    urdf_content = (
        f'<?xml version="1.0"?>\n<robot name="{robot_name}">\n'
        + ''.join(urdf_links)
        + ''.join(urdf_joints)
        + '</robot>\n'
    )

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(urdf_content)
    except Exception as e:
        logging.error(f"Failed to write URDF file: {e}")
        raise

# Command-line interface
if __name__ == "__main__":
    import argparse
    setup_logging()
    parser = argparse.ArgumentParser(description='Generate a URDF file from mesh files in a folder and articulation info.')
    parser.add_argument('--input', type=str, default=os.path.join('obj_parts', 'group_output'),
                        help='Input folder or base directory (default: obj_parts/group_output or its parent)')
    parser.add_argument('--output', type=str, default='generated.urdf',
                        help='Output URDF file path (default: generated.urdf)')
    args = parser.parse_args()

    # Normalise so a trailing separator (".../group_output/") still passes the suffix test below.
    input_path = os.path.normpath(args.input)
    if os.path.isdir(os.path.join(input_path, 'obj_parts', 'group_output')):
        # --input is the object's base directory.
        base_dir = input_path
        mesh_folder = os.path.join(base_dir, 'obj_parts', 'group_output')
    else:
        # --input is a mesh folder: <base>/obj_parts/group_output -> <base>, otherwise its parent.
        mesh_folder = input_path
        if mesh_folder.endswith('group_output'):
            base_dir = os.path.dirname(os.path.dirname(mesh_folder))
        else:
            base_dir = os.path.dirname(mesh_folder)

    try:
        generate_urdf(mesh_folder, base_dir, args.output)
    except Exception as e:
        logging.error(f"URDF generation failed: {e}")
        # Report failure to shells/pipelines instead of exiting with status 0.
        raise SystemExit(1)