import torch
import torch.nn as nn 
import math 
from mdt.StructPolicy.utils import rotate
# from utils import rotate

def Cuboid(size, position, rotation, label):
    """
    Return the eight labelled corners of an oriented box.

    :param size: [B, 3] edge lengths; [0]=length (x), [1]=width (y), [2]=height (z)
    :param position: [B, 3] offset added to every corner after rotation
    :param rotation: [B, 3] rotation forwarded to ``rotate``
    :param label: semantic id; multiplied into an all-ones 4th channel
    :return: corners concatenated with the label channel along the last dim,
        i.e. [B, 8, 4] assuming ``rotate`` returns [B, 8, 3]
    """
    assert size.shape == position.shape, f"Shape mismatch: size.shape={size.shape}, position.shape={position.shape}"
    batch = size.shape[0]

    # Unit cube centred at the origin. Corner order: y (+ then -), z (+ then -), x (- then +).
    unit_corners = position.new_tensor(
        [[x, y, z] for y in (0.5, -0.5) for z in (0.5, -0.5) for x in (-0.5, 0.5)]
    )
    # [8, 3] * [B, 1, 3] broadcasts to per-batch scaled corners [B, 8, 3].
    scaled = unit_corners * size.unsqueeze(dim=1)
    corners = rotate(scaled, rotation) + position.unsqueeze(dim=1)

    labels = position.new_ones((batch, 8, 1)) * label
    return torch.cat((corners, labels), dim=-1)

def Trapezoid_Y(size, position, rotation, label):
    """
    Return the eight labelled corners of a prism whose two z-faces reach
    different depths along +y.

    In the local frame the shape spans x in [-x/2, x/2] and z in [-z/2, z/2].
    On the z=+z/2 face, y runs from 0 to y_upper; on the z=-z/2 face, y runs
    from 0 to y_bottom.

    :param size: [B, 4]; [0]=x, [1]=y_upper, [2]=y_bottom, [3]=z
    :param position: [B, 3] offset added after rotation
    :param rotation: [B, 3] rotation forwarded to ``rotate``
    :param label: semantic id; multiplied into an all-ones 4th channel
    :return: corners concatenated with the label channel along the last dim
        ([B, 8, 4] assuming ``rotate`` returns [B, 8, 3])
    """
    batch = size.shape[0]

    # Corner order: far side (y=1) then near side (y=0); within each, z=+1/2
    # then z=-1/2; within each, x=-1/2 then x=+1/2.
    unit_corners = position.new_tensor(
        [[x, y, z] for y in (1.0, 0.0) for z in (0.5, -0.5) for x in (-0.5, 0.5)]
    )

    # Per-corner y scale: corners on the z=+1/2 face use y_upper, those on
    # z=-1/2 use y_bottom (only matters for the y=1 corners).
    y_scale = size[:, [1, 1, 2, 2, 1, 1, 2, 2]]               # [B, 8]
    x_scale = size[:, 0:1].expand_as(y_scale)                   # [B, 8]
    z_scale = size[:, 3:4].expand_as(y_scale)                   # [B, 8]
    scale = torch.stack((x_scale, y_scale, z_scale), dim=-1)    # [B, 8, 3]

    corners = rotate(unit_corners * scale, rotation) + position.unsqueeze(dim=1)
    labels = position.new_ones((batch, 8, 1)) * label
    return torch.cat((corners, labels), dim=-1)

def Handle(inner, outer, height, position, rotation, Semantic):
    """
    Return 16 labelled points describing a handle as an "outer" and an "inner" box.

    :param inner: [B, 2]; [0]=inner_x, [1]=inner_y
    :param outer: [B, 2]; [0]=outer_x, [1]=outer_y
    :param height: [B, 1] extent of the handle along z
    :param position: [B, 3] offset added after rotation
    :param rotation: [B, 3] rotation forwarded to ``rotate``
    :param Semantic: semantic id; multiplied into an all-ones 4th channel
    :return: points concatenated with the label channel along the last dim
        ([B, 16, 4] assuming ``rotate`` returns [B, 16, 3])
    """
    batch = inner.shape[0]

    # Points 0-7 ("outer" group): ordered z (+,-), then y (+,-), then x (-,+).
    outer_template = [[x, y, z] for z in (0.5, -0.5) for y in (0.5, -0.5) for x in (-0.5, 0.5)]
    # Points 8-15 ("inner" group): ordered y (+,-), then z (+,-), then x (-,+).
    inner_template = [[x, y, z] for y in (0.5, -0.5) for z in (0.5, -0.5) for x in (-0.5, 0.5)]
    template = inner.new_tensor(outer_template + inner_template)  # [16, 3]

    # Per-point scale factors:
    #   x: outer_x for points 0-7, inner_x for points 8-15
    #   y: outer_y for points 0-3 only, inner_y for points 4-15
    #   z: height for every point
    # NOTE(review): points 0-3 (z=+1/2) use outer_y while points 4-7 (z=-1/2)
    # use inner_y, so the "outer" group is not a box; confirm this is intended.
    x_scale = torch.cat((outer[:, 0:1].expand(-1, 8), inner[:, 0:1].expand(-1, 8)), dim=1)
    y_scale = torch.cat((outer[:, 1:2].expand(-1, 4), inner[:, 1:2].expand(-1, 12)), dim=1)
    z_scale = height.expand(-1, 16)
    scale = torch.stack((x_scale, y_scale, z_scale), dim=-1)  # [B, 16, 3]

    # Scale, rotate, then translate.
    points = rotate(template * scale, rotation) + position.unsqueeze(dim=1)
    labels = inner.new_ones((batch, 16, 1)) * Semantic
    return torch.cat((points, labels), dim=-1)

def Cylinder(height, radius, position, rotation, theta=30, Semantic=0, order="z"):
    """
    Sample labelled points on the two rim circles of an oriented cylinder.

    The cylinder is centred at the local origin with its axis along ``order``.
    It is scaled so that it spans ``height`` along the axis and ``radius``
    perpendicular to it. It is then rotated by ``rotation`` and translated by
    ``position``.

    :param height: [B, 1] extent of the cylinder along its axis
    :param radius: [B, 1] radius of the rim circles
    :param position, rotation: [B, 3], [B, 3]
    :param theta: angular step between rim samples, in degrees; each rim gets
        int(360 / theta) samples
    :param Semantic: semantic id; multiplied into an all-ones 4th channel
    :param order: cylinder axis; "z", "y", or anything else for "x"
    :return: rim points concatenated with the label channel along the last dim;
        the first n points are the -1/2 rim and the next n the +1/2 rim
        ([B, 2n, 4] assuming ``rotate`` returns [B, 2n, 3])
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0
    ring = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]

    # The scale must follow the axis choice. The axial coordinate is scaled by
    # height and the two radial coordinates by radius. Previously [r, r, h] was
    # used for every axis, which distorted "x"- and "y"-aligned cylinders.
    if order == "z":
        V = [[c, s, a] for a in (-0.5, 0.5) for c, s in ring]
        deformation = torch.cat([radius, radius, height], dim=1)
    elif order == "y":
        V = [[c, a, s] for a in (-0.5, 0.5) for c, s in ring]
        deformation = torch.cat([radius, height, radius], dim=1)
    else:  # "x" (and any unrecognised value, as before)
        V = [[a, c, s] for a in (-0.5, 0.5) for c, s in ring]
        deformation = torch.cat([height, radius, radius], dim=1)

    # [2n, 3] * [B, 1, 3] broadcasts to [B, 2n, 3].
    vertices = torch.tensor(V, dtype=height.dtype, device=height.device)
    vertices = rotate(vertices * deformation.unsqueeze(dim=1), rotation) + position.unsqueeze(dim=1)

    semantic = height.new_ones((B, 2 * n, 1)) * Semantic
    return torch.cat([vertices, semantic], dim=-1)

def StructureMap(paras):
    """
    Build labelled keypoints of the desk scene and the interaction target poses.

    ``paras`` is first mapped into physical ranges by ``preprocess_parameters``.
    Column layout after preprocessing:
        0-3   desk position            3-6   desk rotation (Euler)
        6-9   desk body size           9-12  desk back (y_upper, y_bottom, z)
        12-24 drawer                   24    drawer shift (currently unused)
        25-28 button                   28-32 switch
        32-38 slide                    38-40 light button
        40-42 legs                     42-69 red / blue / pink blocks
                                             (size, position, rotation; 9 each)

    Semantic labels in the 4th vertex channel: 1 desk body, 2 desk back,
    3 drawer body, 4 drawer front board, 5 drawer handle, 6 button, 7 switch,
    8 slide board, 9 slide handle, 10 light button, 11 light bulb, 12-15 legs,
    16 red block, 17 blue block, 18 pink block.

    :param paras: [B, 69] raw parameters
    :return: (Rotated_Vertices, TargetPose)
        Rotated_Vertices -- list of 12 tensors of (x, y, z, label) points. Every
            group, blocks included, has its xyz rotated by the desk rotation
            (via ``rotate``, about the frame origin).
        TargetPose -- [B, 42]: seven (position[3], direction[3]) pairs for the
            drawer, button, switch, slide, red, blue and pink block.
            NOTE(review): these are not rotated by the desk rotation, unlike
            the vertices; confirm this is intended.
    """
    paras = preprocess_parameters(paras)
    B = paras.shape[0]
    device = paras.device
    dtype = paras.dtype
    Segmentation_theta = 30  # angular step (degrees) between cylinder rim samples

    def zeros(cols=1):
        # [B, cols] zeros on the parameters' device/dtype
        return torch.zeros([B, cols], device=device, dtype=dtype)

    def batch_rows(values):
        # Repeat a constant row for every batch element -> [B, len(values)]
        return torch.tensor(values, device=device, dtype=dtype).unsqueeze(dim=0).repeat(B, 1)

    Vertices = []
    TargetPose = []

    # Desk frame
    Position_Desk = paras[:, 0:3]
    Rotation_Desk = paras[:, 3:6]  # Euler angles

    ## Desk body (paras 6-9)
    Size_Desk_Body = paras[:, 6:9]
    Position_Desk_Body = Position_Desk
    Vertices.append(Cuboid(Size_Desk_Body, Position_Desk_Body, zeros(3), 1))

    ## Desk back (paras 9-12); shares its x length with the body
    Size_Desk_Back = torch.cat([Size_Desk_Body[:, 0:1], paras[:, 9:12]], dim=1)  # x, y_upper, y_bottom, z
    # NOTE(review): the z offset is half of Size_Desk_Back[:, 2:3] (y_bottom), not of
    # the back's height Size_Desk_Back[:, 3:4] or the body height; confirm intended.
    Position_Desk_Back = Position_Desk_Body + torch.cat([
        zeros(),
        0.5 * Size_Desk_Body[:, 1:2],
        0.5 * Size_Desk_Back[:, 2:3],
    ], dim=1)
    Vertices.append(Trapezoid_Y(Size_Desk_Back, Position_Desk_Back, zeros(3), 2))

    ## Drawer (paras 12-24)
    Rotation_Desk_Drawer_Body = batch_rows([0, 0, -0.5 * math.pi])
    Rotation_Desk_Drawer_Handle = batch_rows([0.5 * math.pi, 0.5 * math.pi, 0])
    Rotation_Desk_Drawer_FrontBoard = zeros(3)

    Inner_Desk_Drawer_Body = paras[:, 12:14]
    Outer_Desk_Drawer_Body = Inner_Desk_Drawer_Body + paras[:, 14:16]
    Height_Desk_Drawer_Body = Size_Desk_Body[:, 1:2]

    Size_Desk_Drawer_FrontBoard = paras[:, 16:19]

    Inner_Desk_Drawer_Handle = paras[:, 19:21]
    Outer_Desk_Drawer_Handle = Inner_Desk_Drawer_Handle + paras[:, 21:23]
    Height_Desk_Drawer_Handle = paras[:, 23:24]

    # NOTE(review): the drawer shift paras[:, 24:25] is never used, while the
    # drawer's x offset is paras[:, 13:14] (the inner drawer y size); confirm
    # which one is intended here.
    Position_Desk_Drawer_Body = Position_Desk + torch.cat([paras[:, 13:14], zeros(), zeros()], dim=1)
    # Front board is placed on the -y side of the drawer body, touching it.
    Position_Desk_Drawer_FrontBoard = Position_Desk_Drawer_Body + torch.cat([
        zeros(),
        -0.5 * (Height_Desk_Drawer_Body + Size_Desk_Drawer_FrontBoard[:, 1:2]),
        zeros(),
    ], dim=1)
    # Handle is placed further along -y from the front board.
    Position_Desk_Drawer_Handle = Position_Desk_Drawer_FrontBoard + torch.cat([
        zeros(),
        -0.5 * (Size_Desk_Drawer_FrontBoard[:, 1:2] + Inner_Desk_Drawer_Handle[:, 0:1]),
        zeros(),
    ], dim=1)

    Vertices.append(torch.cat([
        Handle(Inner_Desk_Drawer_Body, Outer_Desk_Drawer_Body, Height_Desk_Drawer_Body,
               Position_Desk_Drawer_Body, Rotation_Desk_Drawer_Body, 3),
        Cuboid(Size_Desk_Drawer_FrontBoard, Position_Desk_Drawer_FrontBoard,
               Rotation_Desk_Drawer_FrontBoard, 4),
        Handle(Inner_Desk_Drawer_Handle, Outer_Desk_Drawer_Handle, Height_Desk_Drawer_Handle,
               Position_Desk_Drawer_Handle, Rotation_Desk_Drawer_Handle, 5),
    ], dim=1))
    TargetPose.append(torch.cat([Position_Desk_Drawer_Handle, batch_rows([0, 1, 0])], dim=1))

    ## Button (paras 25-28)
    Height_Desk_Button = paras[:, 25:26]
    Radius_Desk_Button = paras[:, 26:27]
    Position_Desk_Button = torch.cat([
        paras[:, 27:28],
        Position_Desk_Body[:, 1:2] - Size_Desk_Back[:, 2:3],
        # half body height + half button height: the button's base is at the body top
        Position_Desk_Body[:, 2:3] + 0.5 * (Size_Desk_Body[:, 2:3] + Height_Desk_Button),
    ], dim=1)
    Vertices.append(Cylinder(Height_Desk_Button, Radius_Desk_Button, Position_Desk_Button, zeros(3),
                             Segmentation_theta, 6, "z"))
    TargetPose.append(torch.cat([Position_Desk_Button, batch_rows([0, 0, -1])], dim=1))

    ## Switch (paras 28-32)
    Radius_Desk_Switch = paras[:, 30:31]
    Height_Desk_Switch = paras[:, 31:32]

    # Slope of the desk back: horizontal run (y_bottom - y_upper) over rise z.
    Slope_dy = Size_Desk_Back[:, 2:3] - Size_Desk_Back[:, 1:2]
    Slope_Length = torch.sqrt(Slope_dy ** 2 + Size_Desk_Back[:, 3:4] ** 2)
    cos_theta = torch.abs(Slope_dy) / Slope_Length
    sin_theta = Size_Desk_Back[:, 3:4] / Slope_Length

    # paras[:, 29:30] moves the switch along the slope; its height lifts it off the slope.
    Position_Desk_Switch = torch.cat([
        paras[:, 28:29],
        Position_Desk_Body[:, 1:2] + 0.5 * Size_Desk_Body[:, 1:2] - Size_Desk_Back[:, 2:3]
        + paras[:, 29:30] * cos_theta - Height_Desk_Switch * sin_theta,
        Position_Desk_Body[:, 2:3] + 0.5 * Size_Desk_Body[:, 2:3]
        + paras[:, 29:30] * sin_theta + Height_Desk_Switch * cos_theta,
    ], dim=1)
    Vertices.append(Cylinder(Height_Desk_Switch, Radius_Desk_Switch, Position_Desk_Switch, zeros(3),
                             Segmentation_theta, 7, "x"))
    TargetPose.append(torch.cat([Position_Desk_Switch, batch_rows([0, 1, 0])], dim=1))

    ## Slide (paras 32-38); the board's y length equals the desk-back slope length
    Size_Desk_Slide_Board = torch.cat([paras[:, 32:33], Slope_Length, paras[:, 33:34]], dim=1)
    Outer_Desk_Slide_Handle = torch.cat([paras[:, 34:35], Slope_Length], dim=1)
    Inner_Desk_Slide_Handle = Outer_Desk_Slide_Handle - paras[:, 35:36].repeat(1, 2)
    Height_Desk_Slide_Handle = paras[:, 36:37]

    Position_Desk_Slide_Board = torch.cat([
        paras[:, 37:38],
        Position_Desk_Back[:, 1:2] - 0.5 * (Size_Desk_Back[:, 1:2] + Size_Desk_Back[:, 2:3]),
        Position_Desk_Back[:, 2:3],
    ], dim=1)
    Position_Desk_Slide_Handle = Position_Desk_Slide_Board + torch.cat([
        zeros(),
        -0.5 * Inner_Desk_Slide_Handle[:, 0:1] * sin_theta,
        0.5 * Inner_Desk_Slide_Handle[:, 0:1] * cos_theta,
    ], dim=1)

    # First Euler component is the slope angle; the other two are zero.
    Rotation_Desk_Slide_Board = torch.cat([torch.arccos(cos_theta), zeros(2)], dim=1)
    Rotation_Desk_Slide_Handle = batch_rows([0, -0.5 * math.pi, 0])

    Vertices_Desk_Slide_Board = Cuboid(Size_Desk_Slide_Board, Position_Desk_Slide_Board,
                                       Rotation_Desk_Slide_Board, 8)
    Vertices_Desk_Slide_Handle = Handle(Inner_Desk_Slide_Handle, Outer_Desk_Slide_Handle,
                                        Height_Desk_Slide_Handle, Position_Desk_Slide_Handle,
                                        Rotation_Desk_Slide_Handle, 9)
    # Tilt the handle points by the board's slope rotation. A new tensor is built
    # instead of assigning into Vertices_Desk_Slide_Handle[:, :, 0:3] in place.
    # That slice is the input to rotate, and overwriting it in place can make
    # backward() fail if rotate saved it for the gradient.
    # NOTE(review): this rotation is about the frame origin and is applied after
    # Handle already translated the points by Position_Desk_Slide_Handle.
    Vertices_Desk_Slide_Handle = torch.cat([
        rotate(Vertices_Desk_Slide_Handle[:, :, 0:3], Rotation_Desk_Slide_Board),
        Vertices_Desk_Slide_Handle[:, :, 3:],
    ], dim=2)
    Vertices.append(torch.cat([Vertices_Desk_Slide_Board, Vertices_Desk_Slide_Handle], dim=1))
    TargetPose.append(torch.cat([Position_Desk_Slide_Handle, batch_rows([0, 1, 0])], dim=1))

    ## Lights (paras 38-40)
    Size_Desk_LightButton = torch.cat([paras[:, 38:39], Size_Desk_Back[:, 1:2], paras[:, 39:40]], dim=1)
    Position_Desk_LightButton = torch.cat([
        Position_Desk_Button[:, 0:1],
        Position_Desk_Back[:, 1:2] - 0.5 * Size_Desk_Back[:, 1:2],
        Position_Desk_Back[:, 2:3] + 0.5 * (Size_Desk_Back[:, 3:4] + Size_Desk_LightButton[:, 2:3]),
    ], dim=1)
    Position_Desk_LightBulb = torch.cat([
        Position_Desk_Switch[:, 0:1],
        Position_Desk_Back[:, 1:2] - 0.5 * Size_Desk_Back[:, 1:2],
        Position_Desk_Back[:, 2:3],
    ], dim=1)
    Vertices.append(Cuboid(Size_Desk_LightButton, Position_Desk_LightButton, zeros(3), 10))
    # The bulb is a single labelled point: [B, 1, 4]
    Vertices.append(torch.cat([
        Position_Desk_LightBulb.reshape(B, 1, 3),
        11 * torch.ones([B, 1, 1], device=device, dtype=dtype),
    ], dim=2))

    ## Legs (paras 40-42): square cross-section paras[:, 40], length paras[:, 41]
    Size_Desk_Leg = torch.cat([paras[:, 40:41].repeat(1, 2), paras[:, 41:42]], dim=1)
    # Each leg's outer faces are flush with a body corner; the leg hangs below the body.
    Leg_dx = 0.5 * (Size_Desk_Body[:, 0:1] - Size_Desk_Leg[:, 0:1])
    Leg_dy = 0.5 * (Size_Desk_Body[:, 1:2] - Size_Desk_Leg[:, 1:2])
    Leg_dz = -0.5 * (Size_Desk_Body[:, 2:3] + Size_Desk_Leg[:, 2:3])
    Vertices.append(torch.cat([
        Cuboid(Size_Desk_Leg,
               Position_Desk_Body + torch.cat([sign_x * Leg_dx, sign_y * Leg_dy, Leg_dz], dim=1),
               zeros(3), label)
        for sign_x, sign_y, label in ((1, -1, 12), (1, 1, 13), (-1, 1, 14), (-1, -1, 15))
    ], dim=1))

    # Blocks (paras 42-69): red, blue, pink; 9 columns each = size, position, rotation
    for start, label in ((42, 16), (51, 17), (60, 18)):
        Size_Block = paras[:, start:start + 3]
        Position_Block = paras[:, start + 3:start + 6]
        Rotation_Block = paras[:, start + 6:start + 9]
        Vertices.append(Cuboid(Size_Block, Position_Block, Rotation_Block, label))
        # NOTE(review): the direction is the block's size vector rotated by its
        # rotation (not a unit vector); confirm this is the intended direction.
        Direction_Block = rotate(Size_Block.unsqueeze(dim=1), Rotation_Block).squeeze(dim=1)
        # The target position is the block's centre (its position columns), not its size.
        TargetPose.append(torch.cat([Position_Block, Direction_Block], dim=1))

    ###########
    # Result  #
    ###########
    # Apply the desk rotation to the xyz of every group, keeping the label channel.
    Rotated_Vertices = [
        torch.cat([rotate(part[:, :, 0:3], Rotation_Desk), part[:, :, 3:]], dim=2)
        for part in Vertices
    ]
    TargetPose = torch.cat(TargetPose, dim=1)

    return Rotated_Vertices, TargetPose

import torch
import math

def preprocess_parameters(paras):
    """
    Squash unconstrained raw parameters into physically plausible ranges.

    Signed quantities use ``tanh(x) * limit``, giving the range (-limit, limit).
    Bounded sizes use ``sigmoid(x) * (high - low) + low``, giving the range
    (low, high). Lengths are in metres and angles in radians.

    :param paras: [B, 69] tensor of unbounded values
    :return: tensor shaped like ``paras`` with columns 0-68 mapped into their
        ranges (any further columns are left at zero)
    """
    result = torch.zeros_like(paras)

    def signed(start, stop, limit):
        # Columns start:stop -> (-limit, limit)
        result[:, start:stop] = torch.tanh(paras[:, start:stop]) * limit

    def bounded(start, stop, low, high):
        # Columns start:stop -> (low, high)
        result[:, start:stop] = torch.sigmoid(paras[:, start:stop]) * (high - low) + low

    # Desk pose: centre position in (-1, 1) m, Euler rotation in (-pi, pi) rad.
    signed(0, 3, 1.0)
    signed(3, 6, math.pi)

    # Desk body (6-9) and desk back (9-12) sizes: (0.1, 2.0) m.
    bounded(6, 9, 0.1, 2.0)
    bounded(9, 12, 0.1, 2.0)

    # Drawer (12-25).
    bounded(12, 14, 0.05, 0.5)    # inner size
    bounded(14, 16, 0.01, 0.1)    # outer-size increment
    bounded(16, 19, 0.05, 0.3)    # front board size
    bounded(19, 21, 0.03, 0.15)   # handle inner size
    bounded(21, 23, 0.01, 0.05)   # handle outer-size increment
    bounded(23, 24, 0.02, 0.1)    # handle height
    result[:, 24:25] = torch.sigmoid(paras[:, 24:25]) * 0.2  # drawer shift, (0, 0.2)

    # Button (25-28).
    bounded(25, 26, 0.02, 0.1)    # height
    bounded(26, 27, 0.01, 0.05)   # radius
    signed(27, 28, 0.5)           # x position, within the desk width

    # Switch (28-32).
    signed(28, 30, 0.5)           # position
    bounded(30, 31, 0.005, 0.03)  # radius
    bounded(31, 32, 0.02, 0.1)    # height

    # Slide (32-38).
    bounded(32, 34, 0.05, 0.3)    # board size
    bounded(34, 36, 0.05, 0.2)    # handle outer size
    bounded(36, 37, 0.02, 0.1)    # handle height
    signed(37, 38, 0.5)           # board position

    # Light button size (38-40) and leg size (40-42).
    bounded(38, 40, 0.02, 0.1)
    bounded(40, 42, 0.02, 0.1)

    # Blocks: 9 columns each starting at 42, 51, 60 -> size, position, rotation.
    for base in (42, 51, 60):
        bounded(base, base + 3, 0.05, 0.3)     # size, (0.05, 0.3) m
        signed(base + 3, base + 6, 1.0)        # position, (-1, 1) m
        signed(base + 6, base + 9, math.pi)    # rotation, (-pi, pi) rad

    return result

if __name__ == "__main__":
    # Smoke test: build the structure map for a batch of constant parameters.
    # This import rebinds the module-level `rotate` to the local `utils`
    # implementation; the package import at the top of the file must still
    # succeed first.
    from utils import rotate
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    paras = torch.ones([128, 69], device=device, dtype=torch.float32)
    vertices, target_pose = StructureMap(paras)
    print(f"{len(vertices)} vertex groups: {[tuple(v.shape) for v in vertices]}")
    print(f"target pose: {tuple(target_pose.shape)}")