import warnings

import mujoco
import pinocchio as pin

# Default (site name, parent joint name) pairs for the G1 model: sole sites are
# attached to the ankle-roll joints, end-effector sites to the elbow-roll joints.
_G1_SITE_PARENTS = (
    ("right_sole", "right_ankle_roll_joint"),
    ("right_sole_p1", "right_ankle_roll_joint"),
    ("right_sole_p2", "right_ankle_roll_joint"),
    ("right_sole_p3", "right_ankle_roll_joint"),
    ("right_sole_p4", "right_ankle_roll_joint"),
    ("left_sole", "left_ankle_roll_joint"),
    ("left_sole_p1", "left_ankle_roll_joint"),
    ("left_sole_p2", "left_ankle_roll_joint"),
    ("left_sole_p3", "left_ankle_roll_joint"),
    ("left_sole_p4", "left_ankle_roll_joint"),
    ("right_ee", "right_elbow_roll_joint"),
    ("left_ee", "left_elbow_roll_joint"),
)


def _add_mujoco_site_as_frame(pin_model, mj_model, site_name: str, parent_name: str):
    """Add the MuJoCo site ``site_name`` to ``pin_model`` as an OP_FRAME.

    The site's ``pos``/``quat`` are used directly as the frame placement
    relative to the joint ``parent_name``. NOTE(review): this is only exact
    when the site's body frame coincides with that joint's frame in the
    Pinocchio model -- confirm for the XML in use.

    Args:
        pin_model (pinocchio.Model): Model that receives the new frame.
        mj_model (mujoco.MjModel): Model the site is read from.
        site_name (str): Name of the MuJoCo site (also used as the frame name).
        parent_name (str): Name of the Pinocchio joint the frame is attached to.

    Raises:
        ValueError: If ``parent_name`` is not a joint of ``pin_model``.

    If the site does not exist in ``mj_model`` a warning is emitted and no
    frame is added (the previous behavior skipped it silently).
    """
    site_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_SITE, site_name)
    if site_id < 0:
        # stacklevel=3 points the warning at the caller of create_g1_models.
        warnings.warn(
            f"MuJoCo site {site_name!r} not found; no frame added.",
            stacklevel=3,
        )
        return

    # getJointId returns njoints for unknown names, which is not a valid
    # parent index for addFrame -- fail early with a readable message.
    if not pin_model.existJointName(parent_name):
        raise ValueError(
            f"Parent joint {parent_name!r} for site {site_name!r} "
            "not found in the Pinocchio model."
        )
    parent_joint_id = pin_model.getJointId(parent_name)

    site = mj_model.site(site_id)
    # MuJoCo stores quaternions as [w, x, y, z]; Pinocchio's XYZQUAT layout
    # is [x, y, z, qx, qy, qz, qw].
    qw, qx, qy, qz = site.quat.tolist()
    placement = pin.XYZQUATToSE3(site.pos.tolist() + [qx, qy, qz, qw])

    pin_model.addFrame(
        pin.Frame(
            name=site_name,
            parent_joint=parent_joint_id,
            placement=placement,
            type=pin.FrameType.OP_FRAME,
        )
    )


def create_g1_models(model_path: str, *, site_parents=_G1_SITE_PARENTS):
    """
    Create Pinocchio and MuJoCo models for the G1 robot from the provided XML file.

    The Pinocchio model gets a free-flyer root joint whose translation limits
    are widened to +/-1e10, and every (site, parent joint) pair in
    ``site_parents`` is added as an operational frame.

    Args:
        model_path (str): Path to the MuJoCo XML model file.
        site_parents: Iterable of ``(site_name, parent_joint_name)`` pairs to
            add as frames. Defaults to the G1 sole and end-effector sites.

    Returns:
        tuple: A tuple containing:
            - pin_model (pinocchio.Model): The Pinocchio model of the robot.
            - mj_model (mujoco.MjModel): The MuJoCo model of the robot.

    Raises:
        ValueError: If a listed parent joint does not exist in the Pinocchio model.
    """
    mj_model = mujoco.MjModel.from_xml_path(model_path)
    pin_model = pin.buildModelFromMJCF(model_path, pin.JointModelFreeFlyer())

    # The free-flyer root occupies the first configuration entries; q[:3] is
    # its translation. Copy, edit and reassign so the change sticks whether
    # or not the binding returns a view of the model's internal vector.
    lower = pin_model.lowerPositionLimit.copy()
    upper = pin_model.upperPositionLimit.copy()
    lower[:3] = -1e10
    upper[:3] = 1e10
    pin_model.lowerPositionLimit = lower
    pin_model.upperPositionLimit = upper

    for site_name, parent_name in site_parents:
        _add_mujoco_site_as_frame(pin_model, mj_model, site_name, parent_name)

    return pin_model, mj_model