import torch


# Joint indices treated as "feet" by the metrics below.
# NOTE(review): presumably the SMPL-H left/right ankle (7, 8) and left/right
# foot (10, 11) joints -- confirm against the skeleton definition in use.
SMPLH_FOOT_JOINTS = [7, 8, 10, 11]


def _mean_abs_or_zero(values):
    """Return mean(|values|), or a 0-dim zero on the same device if empty."""
    if values.numel() == 0:
        # Build the fallback from the input so it never lands on a device
        # that differs from the data (or that does not exist on this host).
        return values.new_zeros(())
    return values.abs().mean()


def _ground_distance_metrics(lowest_y, tol):
    """Return (penetration, floating) mean distances for floor-relative heights.

    Heights strictly below ``-tol`` count as penetration; heights at or above
    ``tol`` count as floating. Heights inside the tolerance band are ignored.
    """
    penetration = _mean_abs_or_zero(lowest_y[lowest_y < -tol])
    floating = _mean_abs_or_zero(lowest_y[lowest_y >= tol])
    return penetration, floating


def _foot_skate_metrics(foot_xz, foot_y, skate_tol, skate_th=0.0):
    """Return (mean, per-batch std, per-batch sum, per-batch fraction) of skating.

    ``foot_xz`` is indexed as [batch, frame, foot, 2] (horizontal coordinates)
    and ``foot_y`` as [batch, frame, foot] (floor-relative heights). A foot is
    counted as skating between two consecutive frames when it is below
    ``skate_tol`` in both frames; its horizontal displacement is the amount.
    """
    batch_size = foot_y.shape[0]
    if foot_y.numel() == 0 or foot_y.shape[1] < 2:
        # No frame pair exists, so there is nothing to measure. Return
        # zeros with the same shapes as the regular path instead of NaN.
        return (
            foot_y.new_zeros(()),
            foot_y.new_zeros(batch_size),
            foot_y.new_zeros(batch_size),
            foot_y.new_zeros(batch_size),
        )

    # Horizontal distance travelled by each foot joint between frames.
    step_xz = torch.norm(foot_xz[:, 1:] - foot_xz[:, :-1], dim=-1)

    # Contact must hold at both ends of the step.
    in_contact = foot_y < skate_tol
    contact_both = torch.logical_and(in_contact[:, :-1], in_contact[:, 1:])

    # Zero out displacement of feet that are not in contact.
    slide = step_xz * contact_both

    return (
        slide.mean(),
        slide.std(dim=(1, 2)),
        slide.sum(dim=(1, 2)),
        (slide > skate_th).float().mean(dim=(1, 2)),
    )


def all_physics_metrics(verts, joints, tol=0.005, device="cuda"):
    """Compute penetration, floating and foot-skating metrics from vertices.

    The floor is taken to be the lowest vertex height (y, index 1) of the
    first frame of each batch element; all heights are measured relative to it.

    Args:
        verts: 4-D tensor indexed as [batch, frame, vertex, xyz].
            NOTE(review): axis meaning is inferred from the indexing -- confirm.
        joints: 4-D tensor indexed as [batch, frame, joint, xyz]; must contain
            every index listed in ``SMPLH_FOOT_JOINTS``.
        tol: Height tolerance around the floor for penetration/floating. The
            skating contact threshold is ``tol + 0.01``.
        device: Kept for backward compatibility and no longer used; all
            outputs live on the device of the input tensors.

    Returns:
        Tuple ``(penetration_dist, float_dist, foot_ground_disp_mean,
        foot_ground_disp_std, foot_ground_disp_sum, foot_ground_disp_perc)``.
        The first three are 0-dim tensors, the last three have shape [batch].
    """
    # Floor height per batch element: lowest vertex of the first frame.
    floor_y = verts[:, 0, :, 1].min(dim=1).values

    # Lowest vertex per frame, relative to the floor. Shape [batch, frame].
    lowest_vertex_y = verts[..., 1].min(dim=2).values - floor_y[:, None]
    penetration_dist, float_dist = _ground_distance_metrics(lowest_vertex_y, tol)

    foot_joints = joints[:, :, SMPLH_FOOT_JOINTS, :]
    foot_xz = foot_joints[..., [0, 2]]
    foot_y = foot_joints[..., 1] - floor_y[:, None, None]

    return (
        penetration_dist,
        float_dist,
        *_foot_skate_metrics(foot_xz, foot_y, tol + 0.01),
    )


def all_physics_metrics_fast(joints, tol=0.005, device="cuda", floor_offset=1.1618):
    """Compute the same metrics as ``all_physics_metrics`` from joints only.

    Instead of vertices, the lowest foot joint per frame stands in for the
    lowest body point, and the floor is a fixed height rather than estimated.

    Args:
        joints: 4-D tensor indexed as [batch, frame, joint, xyz]; must contain
            every index listed in ``SMPLH_FOOT_JOINTS``.
        tol: Height tolerance around the floor for penetration/floating. The
            skating contact threshold is ``tol + 0.01``.
        device: Kept for backward compatibility and no longer used; all
            outputs live on the device of ``joints``.
        floor_offset: Value added to joint heights to make them floor-relative.
            NOTE(review): the default presumably means the floor sits at
            y = -1.1618 in the joints' frame -- confirm with the data source.

    Returns:
        Tuple ``(penetration_dist, float_dist, foot_ground_disp_mean,
        foot_ground_disp_std, foot_ground_disp_sum, foot_ground_disp_perc)``.
        The first three are 0-dim tensors, the last three have shape [batch].
    """
    foot_joints = joints[:, :, SMPLH_FOOT_JOINTS, :]

    # Floor-relative foot heights, used for every metric below so that
    # penetration, floating and contact all share one reference frame.
    foot_y = foot_joints[..., 1] + floor_offset

    lowest_joint_y = foot_y.min(dim=2).values
    penetration_dist, float_dist = _ground_distance_metrics(lowest_joint_y, tol)

    foot_xz = foot_joints[..., [0, 2]]

    return (
        penetration_dist,
        float_dist,
        *_foot_skate_metrics(foot_xz, foot_y, tol + 0.01),
    )

def analyze_axis_lengths(verts=None, joints=None):
    """Return the extent (max minus min) of the data along x, y and z.

    ``verts`` takes precedence; ``joints`` is only used when ``verts`` is
    None. The extent of each axis is taken over every element of the tensor,
    where the axis is selected by indices 0, 1 and 2 of the last dimension.

    Returns:
        Tuple ``(x_range, y_range, z_range)`` of Python numbers.

    Raises:
        ValueError: If both ``verts`` and ``joints`` are None.
    """
    if verts is None and joints is None:
        raise ValueError("Must provide at least one of 'verts' or 'joints'")

    points = joints if verts is None else verts

    # Compute all three extents as tensors first, then convert to Python
    # numbers, so an invalid axis fails before any value is extracted.
    extents = []
    for axis in range(3):
        coords = points[..., axis]
        extents.append(coords.max() - coords.min())

    return tuple(extent.item() for extent in extents)