import torch
import math
from lift3d.models.concept.MetaWorld.utils import rotate

def Dense_Cuboid(size, position, rotation, Sem, N=40):
    """
    Sample area-weighted points on the surface of a batch of cuboids.

    A unit cube centred at the origin is scaled per-axis by ``size``, passed
    through ``rotate(vertices, rotation)`` and translated by ``position``.
    Its surface is split into 12 triangles. ``N`` points are drawn by
    choosing triangles with probability proportional to their area and
    sampling uniform barycentric coordinates inside each chosen triangle.
    Gradients reach ``size``, ``position`` and ``rotation`` through the
    barycentric interpolation; the discrete triangle choice is not
    differentiable.

    Args:
        size: [B, 3] tensor - cuboid extents along x, y, z.
        position: [B, 3] tensor - cuboid centre.
        rotation: [B, 3] tensor - rotation forwarded to ``rotate``
            (Euler angles).
        Sem: scalar - semantic value written into the 4th channel.
        N: int - number of points to sample (default: 40).

    Returns:
        SampledPoints: [B, N, 4] tensor - sampled xyz followed by ``Sem``.
    """
    B = size.shape[0]
    device = size.device
    dtype = size.dtype

    # Canonical unit-cube corners; the index layout is relied on by `faces`.
    vertices = torch.tensor([
        [-0.5,  0.5,  0.5], [ 0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [ 0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [ 0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [ 0.5, -0.5, -0.5],
    ], device=device, dtype=dtype)

    # Scale by broadcasting: [1, 8, 3] * [B, 1, 3] -> [B, 8, 3].
    vertices = vertices.unsqueeze(0) * size.unsqueeze(1)

    # Apply rotation, then translation.
    vertices = rotate(vertices, rotation) + position.unsqueeze(1)

    # Two triangles per cube face (12 triangles total).
    faces = torch.tensor([
        [0, 1, 2], [1, 3, 2],  # y = +0.5
        [4, 5, 6], [5, 7, 6],  # y = -0.5
        [0, 4, 1], [1, 4, 5],  # z = +0.5
        [2, 3, 6], [3, 7, 6],  # z = -0.5
        [0, 2, 4], [2, 6, 4],  # x = -0.5
        [1, 5, 3], [3, 5, 7],  # x = +0.5
    ], device=device)

    # 1. Triangle areas.
    face_verts = vertices[:, faces]  # [B, 12, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)  # each [B, 12, 3]
    # dim=-1 must be explicit: without it torch.cross uses the FIRST size-3
    # dimension, which is the batch dimension whenever B == 3.
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 12]

    # 2. Choose triangles proportionally to area. The choice is discrete, so
    # the weights are detached. multinomial normalises internally, but it
    # raises on an all-zero row (fully degenerate cuboid), so fall back to
    # uniform weights for such rows.
    weights = areas.detach()
    weights = torch.where(weights.sum(dim=1, keepdim=True) > 0, weights, torch.ones_like(weights))
    face_idx = torch.multinomial(weights, N, replacement=True)  # [B, N]

    # 3. Uniform barycentric coordinates: points with u + v > 1 are reflected
    # back into the triangle.
    u = torch.rand(B, N, 1, device=device)
    v = torch.rand(B, N, 1, device=device)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    w = 1 - u - v

    gather_idx = face_idx[:, :, None, None].expand(-1, -1, 3, 3)
    selected_faces = torch.gather(face_verts, 1, gather_idx)  # [B, N, 3, 3]

    sampled_points = (selected_faces[..., 0, :] * u
                      + selected_faces[..., 1, :] * v
                      + selected_faces[..., 2, :] * w)  # [B, N, 3]

    # Semantic channel in the same dtype as the points so the concat is uniform.
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Cuboid_rot(size, position, rotation, rot, Sem, N=40):
    """
    Sample area-weighted surface points on a batch of cuboids whose canonical
    corners are first modulated element-wise by ``(cos(rot), sin(rot), 0)``.

    The modulated unit-cube corners are scaled by ``size``, passed through
    ``rotate(vertices, rotation)`` and translated by ``position``. ``N``
    points are then sampled over 12 triangles, with triangle probability
    proportional to area and uniform barycentric coordinates.

    NOTE(review): the third modulation factor is always zero, so every corner
    has z = 0 before ``rotate`` and the cuboid collapses into a flat shape.
    Confirm this is intended. When ``rot == 0``, sin(rot) = 0 collapses a
    second axis and every triangle has zero area. Triangle choice then falls
    back to uniform instead of raising inside ``torch.multinomial``.

    Args:
        size: [B, 3] tensor - cuboid extents along x, y, z.
        position: [B, 3] tensor - cuboid centre.
        rotation: [B, 3] tensor - rotation forwarded to ``rotate``
            (Euler angles).
        rot: angle tensor; assumed [B, 1] so the modulation factors broadcast
            to [B, 1, 3] (TODO confirm with callers).
        Sem: scalar - semantic value written into the 4th channel.
        N: int - number of points to sample (default: 40).

    Returns:
        SampledPoints: [B, N, 4] tensor - sampled xyz followed by ``Sem``.
    """
    B = size.shape[0]
    device = size.device
    dtype = size.dtype

    # Canonical unit-cube corners; the index layout is relied on by `faces`.
    vertices = torch.tensor([
        [-0.5,  0.5,  0.5], [ 0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [ 0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [ 0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [ 0.5, -0.5, -0.5],
    ], device=device, dtype=dtype)

    # Per-axis modulation (cos(rot), sin(rot), 0), broadcast over the 8 corners.
    Rot = torch.cat([torch.cos(rot), torch.sin(rot), torch.zeros_like(rot)], dim=-1).unsqueeze(dim=1)
    vertices = vertices * Rot

    # Scale to the requested extents.
    vertices = vertices * size.unsqueeze(dim=1).repeat(1, 8, 1)

    # Apply rotation, then translation.
    vertices = rotate(vertices, rotation) + position.unsqueeze(1)

    # Two triangles per cube face (12 triangles total).
    faces = torch.tensor([
        [0, 1, 2], [1, 3, 2],  # y = +0.5
        [4, 5, 6], [5, 7, 6],  # y = -0.5
        [0, 4, 1], [1, 4, 5],  # z = +0.5
        [2, 3, 6], [3, 7, 6],  # z = -0.5
        [0, 2, 4], [2, 6, 4],  # x = -0.5
        [1, 5, 3], [3, 5, 7],  # x = +0.5
    ], device=device)

    # 1. Triangle areas.
    face_verts = vertices[:, faces]  # [B, 12, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)  # each [B, 12, 3]
    # dim=-1 must be explicit: without it torch.cross uses the FIRST size-3
    # dimension, which is the batch dimension whenever B == 3.
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 12]

    # 2. Choose triangles proportionally to area (discrete, so detached).
    # Rows whose total area is zero would make multinomial raise, so they
    # fall back to uniform weights.
    weights = areas.detach()
    weights = torch.where(weights.sum(dim=1, keepdim=True) > 0, weights, torch.ones_like(weights))
    face_idx = torch.multinomial(weights, N, replacement=True)  # [B, N]

    # 3. Uniform barycentric coordinates, reflecting u + v > 1 into the triangle.
    u = torch.rand(B, N, 1, device=device)
    v = torch.rand(B, N, 1, device=device)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    w = 1 - u - v

    gather_idx = face_idx[:, :, None, None].expand(-1, -1, 3, 3)
    selected_faces = torch.gather(face_verts, 1, gather_idx)  # [B, N, 3, 3]

    sampled_points = (selected_faces[..., 0, :] * u
                      + selected_faces[..., 1, :] * v
                      + selected_faces[..., 2, :] * w)  # [B, N, 3]

    # Semantic channel in the same dtype as the points so the concat is uniform.
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1, N=80):
    """
    Sample points on the side wall and both end caps of a batch of cylinders.

    The point budget is split with a fixed ratio that does not depend on
    surface area:
    - ``N // 2`` points on the side wall;
    - ``N // 4`` points on the cap at axial coordinate ``+height / 2``;
    - the remainder on the cap at ``-height / 2``.

    Side points are uniform in (axial position, angle). Cap points are
    uniform over the disk, with the radial distance drawn as
    ``sqrt(U) * radius``. Points are built around the origin, passed through
    ``rotate(points, rotation)`` and shifted by ``position``.

    Args:
        height: [B, 1] tensor - cylinder length along its axis.
        radius: [B, 1] tensor - cylinder radius.
        position: [B, 3] tensor - cylinder centre.
        rotation: [B, 3] tensor - rotation forwarded to ``rotate``
            (Euler angles).
        theta: int - not used by the sampling; kept for interface
            compatibility.
        Sem: scalar - semantic value written into the 4th channel.
        order: int - cylinder axis before rotation: 1 = Z, 2 = Y, any other
            value = X.
        N: int - total number of points to sample.

    Returns:
        SampledPoints: [B, N, 4] tensor - xyz followed by ``Sem``.
    """
    B = height.shape[0]
    device = height.device
    dtype = height.dtype

    side_count = N // 2
    top_count = N // 4
    bottom_count = N - side_count - top_count
    cap_count = top_count + bottom_count

    def to_xyz(radial_a, radial_b, axial):
        # Route the two in-plane coordinates and the axial coordinate onto
        # x/y/z according to `order`, giving a [B, n, 3] tensor.
        if order == 1:
            columns = (radial_a, radial_b, axial)
        elif order == 2:
            columns = (radial_a, axial, radial_b)
        else:
            columns = (axial, radial_a, radial_b)
        return torch.stack(columns, dim=-1)

    points = torch.zeros(B, N, 3, device=device, dtype=dtype)

    # Side wall: uniform axial offset in [-h/2, h/2) and uniform angle.
    if side_count > 0:
        axial_u = torch.rand((B, side_count), device=device)
        phase = torch.rand((B, side_count), device=device)
        angle = 2 * math.pi * phase
        points[:, :side_count, :] = to_xyz(
            radius * torch.cos(angle),
            radius * torch.sin(angle),
            (axial_u - 0.5) * height,
        )

    # End caps: one shared draw of angle and sqrt-distributed radius; the
    # first `top_count` columns go to the top cap, the rest to the bottom.
    if cap_count > 0:
        cap_angle = torch.rand((B, cap_count), device=device) * 2 * math.pi
        cap_radius = torch.sqrt(torch.rand((B, cap_count), device=device)) * radius
        cos_part = cap_radius * torch.cos(cap_angle)
        sin_part = cap_radius * torch.sin(cap_angle)
        top_level = height / 2
        bottom_level = -height / 2

        if top_count > 0:
            top_axial = torch.ones((B, top_count), device=device) * top_level
            points[:, side_count:side_count + top_count, :] = to_xyz(
                cos_part[:, :top_count],
                sin_part[:, :top_count],
                top_axial,
            )

        if bottom_count > 0:
            bottom_axial = torch.ones((B, bottom_count), device=device) * bottom_level
            points[:, side_count + top_count:, :] = to_xyz(
                cos_part[:, top_count:],
                sin_part[:, top_count:],
                bottom_axial,
            )

    # Rotate about the origin, then translate.
    points = rotate(points, rotation) + position.unsqueeze(1)

    semantic = torch.ones(B, N, 1, device=device, dtype=dtype) * Sem
    return torch.cat([points, semantic], dim=-1)

def Dense_Rectangular_Ring(inner, outer, height, position, rotation, Sem, order=1, N=40):
    """
    Sample area-weighted surface points on a batch of rectangular rings
    (a box with a rectangular through-hole).

    Vertices 0-7 are the outer box corners and 8-15 are the inner (hole) box
    corners; both boxes share ``height``. The surface is 32 triangles:
    - 4 outer walls;
    - 4 inner walls;
    - 8 strips joining outer and inner edges on the two end layers.

    ``N`` points are drawn with triangle probability proportional to area
    and uniform barycentric coordinates inside each triangle. Gradients reach
    the shape parameters through the interpolation; the triangle choice is
    discrete.

    Args:
        inner: [B, 2] tensor - [inner_length, inner_width] of the hole.
        outer: [B, 2] tensor - [outer_length, outer_width] of the ring.
        height: [B, 1] tensor - extent along the hole axis.
        position: [B, 3] tensor - ring centre.
        rotation: [B, 3] tensor - rotation forwarded to ``rotate``.
        Sem: scalar - semantic value written into the 4th channel.
        order: int - 1 puts the height along y (length/width along x/z);
            any other value puts it along z (length/width along x/y).
        N: int - number of points to sample (default: 40).

    Returns:
        SampledPoints: [B, N, 4] tensor - sampled xyz followed by ``Sem``.
    """
    B = inner.shape[0]
    device = inner.device
    dtype = inner.dtype

    # Unit-cube corners laid out with column 1 as the height axis. Rows 0-3
    # are the +height layer and rows 4-7 the -height layer. The same 8 corners
    # are used for both the outer box (0-7) and the inner box (8-15).
    corners = torch.tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ], device=device, dtype=dtype)
    template = torch.cat([corners, corners], dim=0)  # [16, 3]

    # Per-vertex scale [B, 16, 3]: (length, height, width) for each box.
    outer_scale = torch.cat([outer[:, 0:1], height, outer[:, 1:2]], dim=-1).unsqueeze(1)
    inner_scale = torch.cat([inner[:, 0:1], height, inner[:, 1:2]], dim=-1).unsqueeze(1)
    deformation = torch.cat([outer_scale.expand(-1, 8, -1), inner_scale.expand(-1, 8, -1)], dim=1)
    local = template * deformation  # [B, 16, 3]

    if order != 1:
        # Height along z: swapping y and z of the order-1 geometry produces
        # exactly the same vertices and vertex indices, so `faces` is shared.
        local = local[..., [0, 2, 1]]

    vertices = rotate(local, rotation) + position.unsqueeze(1)

    faces = torch.tensor([
        # Outer walls (8 triangles)
        [0, 1, 4], [5, 1, 4],
        [1, 3, 5], [7, 3, 5],
        [2, 3, 6], [7, 3, 6],
        [0, 2, 4], [6, 2, 4],
        # Inner walls (8 triangles)
        [8, 9, 12], [13, 9, 12],
        [9, 11, 13], [15, 11, 13],
        [10, 11, 14], [15, 11, 14],
        [8, 10, 12], [14, 10, 12],
        # Strips joining outer and inner edges on both end layers (16 triangles)
        [0, 1, 8], [9, 1, 8],
        [1, 3, 9], [11, 3, 9],
        [2, 3, 10], [11, 3, 10],
        [0, 2, 8], [10, 2, 8],
        [4, 5, 12], [13, 5, 12],
        [5, 7, 13], [15, 7, 13],
        [6, 7, 14], [15, 7, 14],
        [4, 6, 12], [14, 6, 12],
    ], device=device)

    # 1. Triangle areas.
    face_verts = vertices[:, faces]  # [B, 32, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)  # each [B, 32, 3]
    # dim=-1 must be explicit: without it torch.cross uses the FIRST size-3
    # dimension, which is the batch dimension whenever B == 3.
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 32]

    # 2. Choose triangles proportionally to area (discrete, so detached).
    # An all-zero row would make multinomial raise; use uniform weights there.
    weights = areas.detach()
    weights = torch.where(weights.sum(dim=1, keepdim=True) > 0, weights, torch.ones_like(weights))
    face_idx = torch.multinomial(weights, N, replacement=True)  # [B, N]

    # 3. Uniform barycentric coordinates, reflecting u + v > 1 into the triangle.
    u = torch.rand(B, N, 1, device=device)
    v = torch.rand(B, N, 1, device=device)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    w = 1 - u - v

    gather_idx = face_idx[:, :, None, None].expand(-1, -1, 3, 3)
    selected_faces = torch.gather(face_verts, 1, gather_idx)  # [B, N, 3, 3]

    sampled_points = (selected_faces[..., 0, :] * u
                      + selected_faces[..., 1, :] * v
                      + selected_faces[..., 2, :] * w)  # [B, N, 3]

    # Semantic channel in the same dtype as the points so the concat is uniform.
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Handle(inner, outer, height, position, rotation, Sem, order, N=40):
    """
    Sample area-weighted surface points on a batch of handle shapes: a
    rectangular ring with one side left open.

    Vertices 0-7 form the outer box and 8-15 the inner box. The outer box's
    +height layer (vertices 0-3) uses the outer length/width. Its -height
    layer (vertices 4-7) uses the outer length but the INNER width, so the
    outer walls taper. NOTE(review): confirm this taper is intended. The
    walls and joining strips on the -width side are not triangulated, which
    leaves the shape open there. The surface is 24 triangles. ``N`` points
    are drawn with triangle probability proportional to area and uniform
    barycentric coordinates inside each triangle.

    Args:
        inner: [B, 2] tensor - [inner_length, inner_width].
        outer: [B, 2] tensor - [outer_length, outer_width].
        height: [B, 1] tensor - extent along the hole axis.
        position: [B, 3] tensor - handle centre.
        rotation: [B, 3] tensor - rotation forwarded to ``rotate``.
        Sem: scalar - semantic value written into the 4th channel.
        order: int - 1 puts the height along y (length/width along x/z);
            any other value puts it along z (length/width along x/y).
        N: int - number of points to sample (default: 40).

    Returns:
        SampledPoints: [B, N, 4] tensor - sampled xyz followed by ``Sem``.
    """
    B = inner.shape[0]
    device = inner.device
    dtype = inner.dtype

    # Unit-cube corners laid out with column 1 as the height axis. Rows 0-3
    # are the +height layer and rows 4-7 the -height layer. The same 8 corners
    # are used for both the outer box (0-7) and the inner box (8-15).
    corners = torch.tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ], device=device, dtype=dtype)
    template = torch.cat([corners, corners], dim=0)  # [16, 3]

    # Per-vertex scale [B, 16, 3] as (length, height, width):
    # outer +height layer, outer -height layer (inner width), inner box.
    outer_top = torch.cat([outer[:, 0:1], height, outer[:, 1:2]], dim=-1).unsqueeze(1)
    outer_bottom = torch.cat([outer[:, 0:1], height, inner[:, 1:2]], dim=-1).unsqueeze(1)
    inner_scale = torch.cat([inner[:, 0:1], height, inner[:, 1:2]], dim=-1).unsqueeze(1)
    deformation = torch.cat([
        outer_top.expand(-1, 4, -1),
        outer_bottom.expand(-1, 4, -1),
        inner_scale.expand(-1, 8, -1),
    ], dim=1)
    local = template * deformation  # [B, 16, 3]

    if order != 1:
        # Height along z: swapping y and z of the order-1 geometry produces
        # exactly the same vertices and vertex indices, so `faces` is shared.
        local = local[..., [0, 2, 1]]

    vertices = rotate(local, rotation) + position.unsqueeze(1)

    faces = torch.tensor([
        # Outer walls (6 triangles; the -width wall is omitted)
        [0, 1, 4], [5, 1, 4],
        [1, 3, 5], [7, 3, 5],
        [0, 2, 4], [6, 2, 4],
        # Inner walls (6 triangles)
        [8, 9, 12], [13, 9, 12],
        [9, 11, 13], [15, 11, 13],
        [8, 10, 12], [14, 10, 12],
        # Strips joining outer and inner edges on both end layers (12 triangles)
        [0, 1, 8], [9, 1, 8],
        [1, 3, 9], [11, 3, 9],
        [0, 2, 8], [10, 2, 8],
        [4, 5, 12], [13, 5, 12],
        [5, 7, 13], [15, 7, 13],
        [4, 6, 12], [14, 6, 12],
    ], device=device)

    # 1. Triangle areas.
    face_verts = vertices[:, faces]  # [B, 24, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)  # each [B, 24, 3]
    # dim=-1 must be explicit: without it torch.cross uses the FIRST size-3
    # dimension, which is the batch dimension whenever B == 3.
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 24]

    # 2. Choose triangles proportionally to area (discrete, so detached).
    # An all-zero row would make multinomial raise; use uniform weights there.
    weights = areas.detach()
    weights = torch.where(weights.sum(dim=1, keepdim=True) > 0, weights, torch.ones_like(weights))
    face_idx = torch.multinomial(weights, N, replacement=True)  # [B, N]

    # 3. Uniform barycentric coordinates, reflecting u + v > 1 into the triangle.
    u = torch.rand(B, N, 1, device=device)
    v = torch.rand(B, N, 1, device=device)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    w = 1 - u - v

    gather_idx = face_idx[:, :, None, None].expand(-1, -1, 3, 3)
    selected_faces = torch.gather(face_verts, 1, gather_idx)  # [B, N, 3, 3]

    sampled_points = (selected_faces[..., 0, :] * u
                      + selected_faces[..., 1, :] * v
                      + selected_faces[..., 2, :] * w)  # [B, N, 3]

    # Semantic channel in the same dtype as the points so the concat is uniform.
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints
