import torch
import math
from StructPolicy.StructGen.utils import rotate

def Cuboid(size, position, rotation, Sema):
    """
    Build the 8 corner points of a box and append a constant semantic channel.

    :size[B,3]: [0]=length, [1]=width, [2]=height (scale along x, y, z of the unit cube)
    :position[B,3], rotation[B,3]: rotation is passed to ``rotate``; position is then
        added as a per-batch translation
    :Sema: semantic label multiplied into a ones-tensor, one value per corner
    :return: corners with xyz followed by the label in the last dimension
    """
    batch = size.shape[0]
    like = dict(device=position.device, dtype=position.dtype)

    # Unit cube centred at the origin. Row order: y (+ then -) outermost,
    # then z (+ then -), then x (- then +) innermost.
    half = 0.5
    corners = torch.tensor(
        [[x, y, z] for y in (half, -half) for z in (half, -half) for x in (-half, half)],
        **like,
    )

    # [8, 3] * [B, 1, 3] broadcasts to [B, 8, 3]: each axis is scaled by its size.
    scaled = corners * size.unsqueeze(dim=1)
    placed = rotate(scaled, rotation) + position.unsqueeze(dim=1)

    labels = torch.ones([batch, 8, 1], **like) * Sema
    return torch.cat((placed, labels), dim=-1)

def Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1):
    """
    Sample points on the two rims of a cylinder and attach a semantic label.

    The unit template holds ``n = int(360 / theta)`` points on a unit circle at
    axial offset -0.5, followed by the same ``n`` points at axial offset +0.5.
    ``order`` selects which coordinate is the cylinder axis:

    * ``order == 1``: axis along Z (circle in the X-Y plane)
    * ``order == 2``: axis along Y (circle in the X-Z plane)
    * otherwise:      axis along X (circle in the Y-Z plane)

    The two circle coordinates are scaled by ``radius`` and the axial
    coordinate by ``height``. The result is then rotated and translated.

    :param height: [B, 1], height of the cylinder along its axis
    :param radius: [B, 1], radius of the cylinder orthogonal to its axis
    :param position, rotation: [B,3], [B,3]
    :param theta: int, angular step in degrees
    :param Sem: semantic label written into the last channel
    :param order: int, axis selector (see above)
    :return: 2n points per batch item, xyz followed by the semantic label
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    # Unit-circle samples shared by both rims.
    circle = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]

    # The per-axis scale must follow the chosen axis. The circle coordinates get
    # ``radius`` and the axial coordinate gets ``height``. Using
    # [radius, radius, height] for every order would distort the circle into an
    # ellipse whenever the axis is not Z.
    if order == 1:
        V = [[c, s, -0.5] for c, s in circle] + [[c, s, 0.5] for c, s in circle]
        scale = [radius, radius, height]
    elif order == 2:
        V = [[c, -0.5, s] for c, s in circle] + [[c, 0.5, s] for c, s in circle]
        scale = [radius, height, radius]
    else:
        V = [[-0.5, c, s] for c, s in circle] + [[0.5, c, s] for c, s in circle]
        scale = [height, radius, radius]

    vertices = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B, 1, 1)  # [B, 2n, 3]
    deformation = torch.cat(scale, dim=1).unsqueeze(dim=1)  # [B, 1, 3]
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B, 2 * n, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Rectangular_Ring(inner, outer, height, position, rotation, Sema):
    """
    Build the corners of an outer box and an inner box sharing one height.

    Rows 0-7 are the corners of the outer box and rows 8-15 are the corners of
    the inner box. Both come from the same unit-cube template centred at the
    origin. They are scaled, then rotated and translated together.

    :param inner[B,2]: [0]=inner_length, [1]=inner_width
    :param outer[B,2]: [0]=outer_length, [1]=outer_width
    :height[B,1]: height of Rectangular_Ring (shared by both boxes)
    :position[B,3], rotation[B,3]: position and rotation
    :return: 16 points per batch item, xyz followed by the semantic label ``Sema``
    """
    batch = inner.shape[0]
    opts = dict(device=inner.device, dtype=inner.dtype)

    # Unit cube, y-major then z then x ordering; duplicated for outer + inner.
    unit_box = [[x, y, z] for y in (0.5, -0.5) for z in (0.5, -0.5) for x in (-0.5, 0.5)]
    template = torch.tensor(unit_box + unit_box, **opts).unsqueeze(0).repeat(batch, 1, 1)  # [B, 16, 3]

    def _box_scale(footprint):
        # (length, width) from ``footprint`` plus the shared height -> [B, 8, 3]
        dims = torch.cat([footprint[:, :2], height], dim=-1)
        return dims.unsqueeze(1).repeat(1, 8, 1)

    scale = torch.cat([_box_scale(outer), _box_scale(inner)], dim=1)  # [B, 16, 3]

    placed = rotate(template * scale, rotation) + position.unsqueeze(1)
    labels = torch.ones([batch, 16, 1], **opts) * Sema
    return torch.cat((placed, labels), dim=-1)

def Sphere(Radius, position, rotation, theta=30, Sem=0):
    """
    Sample points on a (possibly anisotropic) sphere and attach a semantic label.

    Points lie on a longitude/latitude grid with step ``theta`` degrees. There
    are ``n = int(360 / theta)`` azimuth values and ``n // 2`` polar values
    starting at the +Z pole. For each azimuth, all of its polar samples are
    emitted before moving on to the next azimuth.

    NOTE(review): polar angles are ``j * theta`` for ``j < n // 2``. As a result,
    the -Z pole is never sampled, while the +Z pole (j = 0) is emitted once per
    azimuth (n coincident points). Changing this alters the point count, so it
    is left as-is pending a check of downstream consumers.

    :param Radius: [B, 3], per-axis radius (x, y, z scale); broadcast against the points
    :param position, rotation: [B,3], [B,3]
    :param theta: int, angular resolution in degrees
    :param Sem: int, semantic label
    :return: n * (n // 2) points per batch item, xyz followed by the label
    """
    batch = Radius.shape[0]
    n = int(360 / theta)
    step = theta * math.pi / 180.0

    # Outer loop: azimuth i; inner loop: polar index j.
    grid = [
        [
            math.sin(j * step) * math.cos(i * step),
            math.sin(j * step) * math.sin(i * step),
            math.cos(j * step),
        ]
        for i in range(n)
        for j in range(n // 2)
    ]

    unit = torch.tensor(grid, dtype=Radius.dtype, device=Radius.device).unsqueeze(dim=0).repeat(batch, 1, 1)
    stretched = unit * Radius.unsqueeze(dim=1)  # [B, 1, 3] scales each axis
    world = rotate(stretched, rotation) + position.unsqueeze(dim=1)

    labels = torch.ones([batch, n * (n // 2), 1], device=Radius.device, dtype=Radius.dtype) * Sem
    return torch.cat([world, labels], dim=-1)
    
def Trianguler_Prism(size, position, rotation, Sem=0):
    """
    Build the 6 corner points of a triangular prism and attach a semantic label.

    The unit template has a triangular cross-section in the X-Y plane. Its base
    edge runs from x=-0.5 to x=0.5 at y=0 and its apex is at (0, 1). The
    triangle is extruded along Z from +0.5 (first three rows) to -0.5 (last
    three rows).

    :param size: [B, 3], per-axis scale applied to the unit template
    :param position, rotation: [B,3], [B,3]; rotation is passed to ``rotate``,
        then position is added as a translation
    :param Sem: int, semantic label
    :return: 6 points per batch item, xyz followed by the semantic label
    """
    B = size.shape[0]

    # Create the template on the caller's device/dtype so the product with
    # ``size`` neither fails for GPU inputs nor silently promotes dtypes.
    vertices = torch.tensor([
        [0.5, 0.0, 0.5], [-0.5, 0.0, 0.5], [0.0, 1.0, 0.5],
        [0.5, 0.0, -0.5], [-0.5, 0.0, -0.5], [0.0, 1.0, -0.5],
    ], device=position.device, dtype=position.dtype)
    num_vertices = vertices.shape[0]

    deformation = size.unsqueeze(dim=1).repeat(1, num_vertices, 1)  # [B, 6, 3]

    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)
    # One label per vertex; the count must match the 6 template rows for the cat below.
    Semantic = torch.ones([B, num_vertices, 1], device=position.device, dtype=position.dtype) * Sem
    Affordance = torch.cat((vertices, Semantic), dim=-1)

    return Affordance

def Cone(radius, height, position, rotation, theta=30, Sem=0):
    """
    Sample a cone as its base circle plus apex and attach a semantic label.

    The unit template has ``n = int(360 / theta)`` points on a unit circle at
    z=0, followed by a single apex point at (0, 0, 1). X/Y are scaled by
    ``radius`` and Z by ``height``; the result is then rotated and translated.

    :param radius: [B, 1]
    :param height: [B, 1]
    :param position, rotation: [B,3], [B,3]
    :param theta: int, angular step in degrees for the base circle
    :param Sem: int, semantic label
    :return: n + 1 points per batch item, xyz followed by the semantic label
    """
    B = radius.shape[0]

    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    V = [
        [math.cos(i * theta_rad), math.sin(i * theta_rad), 0.0] for i in range(n)
    ]
    V.append([0.0, 0.0, 1.0])  # apex
    num_vertices = len(V)  # n base points + 1 apex

    vertices = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B, 1, 1)  # [B, n + 1, 3]

    deformation = torch.cat([radius, radius, height], dim=1).unsqueeze(dim=1)
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    # The label count must equal the vertex count (n + 1), otherwise torch.cat fails.
    semantic = torch.ones([B, num_vertices, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Ring(inner_radius, outer_radius, position, rotation, theta=30, Sem=0):
    """
    Sample a flat annulus in the z=0 plane as a strip of quads and attach a label.

    For each of the ``n = int(360 / theta)`` angular steps i, four points are
    emitted: outer(i), inner(i), outer(i+1), inner(i+1), with i+1 wrapping to 0.
    Outer points lie on a circle of ``outer_radius`` and inner points on a circle
    of ``inner_radius``. Only X/Y are scaled; Z stays 0. The points are then
    rotated and translated.

    :param inner_radius: [B, 1], inner radius of the ring in X-Y plane
    :param outer_radius: [B, 1], outer radius of the ring in X-Y plane
    :param position, rotation: [B,3], [B,3]
    :param theta: int, angular resolution
    :param Sem: int, semantic label
    :return: 4n points per batch item, xyz followed by the semantic label
    """
    batch = inner_radius.shape[0]
    n = int(360 / theta)
    step = theta * math.pi / 180.0

    def _unit_point(k):
        # Unit-circle point at angular index k, in the z=0 plane.
        return [math.cos(k * step), math.sin(k * step), 0.0]

    template = []
    for i in range(n):
        here, ahead = _unit_point(i), _unit_point((i + 1) % n)
        # outer, inner, outer_next, inner_next (same unit coords, scaled differently below)
        template.extend([here, here, ahead, ahead])

    base = torch.tensor(template, dtype=inner_radius.dtype, device=inner_radius.device).unsqueeze(dim=0).repeat(batch, 1, 1)  # [B, 4n, 3]

    # Per-row radius following the outer/inner/outer/inner pattern: [B, 4n, 1]
    quad_radii = torch.cat([outer_radius, inner_radius, outer_radius, inner_radius], dim=1)
    radial = quad_radii.repeat(1, n).unsqueeze(-1)

    # Scale x/y only; keep the template's z column untouched.
    scaled = torch.cat([base[..., :2] * radial, base[..., 2:]], dim=-1)

    world = rotate(scaled, rotation) + position.unsqueeze(dim=1)
    labels = torch.ones([batch, 4 * n, 1], device=inner_radius.device, dtype=inner_radius.dtype) * Sem
    return torch.cat([world, labels], dim=-1)
    