import torch
import torch.nn as nn 
import math 
from lift3d.models.concept.MetaWorld.utils import rotate
from lift3d.models.concept.MetaWorld.utils import Para_Estimator, Knowledge_Encoder, PT_Knowledge_Encoder, Fusion_Knowledge_Encoder, PN_Knowledge_Encoder

def Rectangular_Ring(inner, outer, height, position, rotation, Sema):
    """
    Build the 16 labelled corner vertices of a rectangular ring.

    The ring is described by two axis-aligned boxes that share the same
    height (y axis): an outer box and an inner box.

    :param inner[B,2]: [0]=inner_length (x), [1]=inner_width (z)
    :param outer[B,2]: [0]=outer_length (x), [1]=outer_width (z)
    :param height[B,1]: height of the ring (y extent of both boxes)
    :param position[B,3]: translation added after the rotation
    :param rotation[B,3]: rotation parameters forwarded to ``rotate``
    :param Sema: semantic label written into the 4th channel
    :return: [B,16,4] tensor (x, y, z, Sema); rows 0-7 are the outer box,
             rows 8-15 the inner box
    """
    batch = inner.shape[0]
    dev, dt = inner.device, inner.dtype

    # Corners of a unit cube centred at the origin, shared by both boxes.
    unit_corners = torch.tensor(
        [
            [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
            [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5],
            [-0.5, -0.5, 0.5], [0.5, -0.5, 0.5],
            [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
        ],
        device=dev,
        dtype=dt,
    )
    template = torch.cat([unit_corners, unit_corners], dim=0)  # [16, 3]

    # Per-box (x, y, z) scale, one copy per corner.
    outer_scale = torch.cat([outer[:, 0:1], height, outer[:, 1:2]], dim=-1)  # [B, 3]
    inner_scale = torch.cat([inner[:, 0:1], height, inner[:, 1:2]], dim=-1)  # [B, 3]
    scale = torch.cat(
        [
            outer_scale.unsqueeze(1).expand(-1, 8, -1),
            inner_scale.unsqueeze(1).expand(-1, 8, -1),
        ],
        dim=1,
    )  # [B, 16, 3]

    # Scale first, then rotate, then translate.
    scaled = template.unsqueeze(0) * scale  # [B, 16, 3]
    placed = rotate(scaled, rotation) + position.unsqueeze(1)

    label = torch.ones([batch, 16, 1], device=dev, dtype=dt) * Sema
    return torch.cat((placed, label), dim=-1)

def Cuboid(size, position, rotation, Sema):
    """
    Build the 8 labelled corner vertices of a cuboid.

    :param size[B,3]: [0]= x, [1]= y, [2]= z extents
    :param position[B,3]: translation added after the rotation
    :param rotation[B,3]: rotation parameters forwarded to ``rotate``
    :param Sema: semantic label written into the 4th channel
    :return: [B,8,4] tensor (x, y, z, Sema)
    """
    batch = size.shape[0]
    dev, dt = position.device, position.dtype

    # Corners of a unit cube centred at the origin.
    half = 1 / 2
    unit_corners = torch.tensor(
        [
            [-half, half, half],
            [half, half, half],
            [-half, half, -half],
            [half, half, -half],
            [-half, -half, half],
            [half, -half, half],
            [-half, -half, -half],
            [half, -half, -half],
        ],
        device=dev,
        dtype=dt,
    )

    # Broadcast the per-sample size over the 8 corners: [8,3] * [B,1,3] -> [B,8,3].
    scaled = unit_corners * size.unsqueeze(dim=1)

    placed = rotate(scaled, rotation) + position.unsqueeze(dim=1)
    label = torch.ones([batch, 8, 1], device=dev, dtype=dt) * Sema
    return torch.cat((placed, label), dim=-1)

def Concept_ShelfPlace(paras):
    """
    Expand a shelf-place parameter vector into sparse concept vertices.

    Parameter layout (columns of ``paras``):
    :paras[:, 0:3]: shelf position
    :paras[:, 3:6]: shelf rotation
    :paras[:, 6:9]: BackBoard(x, y, z)
    :paras[:, 9:12]: Rectangular_Ring(thickness_x, thickness_z, y)
    :paras[:, 12:14]: ClapBoard(z, z_offset)
    :paras[:, 14:17, 17:20, 20:23]: size, position, rotation of block

    :return: (Affordance, grasp_pose)
        - Affordance [B,40,4]: backboard (8, label 1), ring (16, label 2),
          clapboard (8, label 3) and block (8, label 4) vertices as (x, y, z, label)
        - grasp_pose [B,6]: block position followed by the fixed direction (0, 0, -1)

    NOTE: the clapboard thickness now reads column 12 and the grasp position
    reads columns 17:20, matching the layout above. Earlier revisions read
    column 13 and columns 18:21 respectively; checkpoints trained against the
    old indexing may need re-training.
    """
    B = paras.shape[0]
    position = paras[:, 0:3]
    rotation = paras[:, 3:6]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)

    # Backboard: centred on the shelf position.
    Backboard_size = paras[:, 6:9]
    Backboard_position = position
    vertices_backboard = Cuboid(Backboard_size, Backboard_position, rotation, 1)

    # Rectangular_Ring: shares the backboard's x/z footprint and sits directly
    # below it along y (offset is added in the un-rotated frame).
    outer_size = torch.cat([Backboard_size[:, 0:1], Backboard_size[:, 2:3]], dim=-1)
    inner_size = outer_size - 2 * paras[:, 9:11]  # wall thickness removed on both sides
    Rect_height = paras[:, 11:12]
    Rect_position = position + torch.cat(
        [zeros, -0.5 * Rect_height - 0.5 * Backboard_size[:, 1:2], zeros], dim=-1
    )
    vertices_Rect = Rectangular_Ring(inner_size, outer_size, Rect_height, Rect_position, rotation, 2)

    # Clapboard: spans the inner ring in x and the ring height in y; its z
    # thickness is paras[:, 12:13] and it is shifted by the z offset paras[:, 13:14].
    Clapboard_size = torch.cat([inner_size[:, 0:1], Rect_height, paras[:, 12:13]], dim=-1)
    Clapboard_position = Rect_position + torch.cat([zeros, zeros, paras[:, 13:14]], dim=-1)
    vertices_Clapboard = Cuboid(Clapboard_size, Clapboard_position, rotation, 3)

    # Block: independent size / position / rotation.
    Block_position = paras[:, 17:20]
    vertices_Block = Cuboid(paras[:, 14:17], Block_position, paras[:, 20:23], 4)
    direction = torch.tensor([0, 0, -1], device=paras.device, dtype=paras.dtype).expand(B, 3)
    # shape (B, 6): [x, y, z, dx, dy, dz]
    grasp_pose = torch.cat([Block_position, direction], dim=-1)

    Affordance = torch.cat(
        [vertices_backboard, vertices_Rect, vertices_Clapboard, vertices_Block], dim=1
    )  # [B, 40, 4]
    return Affordance, grasp_pose

def process_paras(raw):
    """
    Squash raw Para_Estimator output into bounded physical parameters.

    Only the first 23 entries of the last dimension are read; any further
    entries are ignored and not returned. Leading dimensions are preserved.

    Input: raw tensor of shape [..., >=23], unrestricted range
    Output: tensor of shape [..., 23] laid out as
        0-3   position (x, y, z)            tanh    -> (-3, 3)
        3-6   rotation (rx, ry, rz)         tanh    -> (-pi, pi)
        6-9   backboard size (x, y, z)      sigmoid -> (0, 1)
        9-11  ring thickness (tx, tz)       sigmoid -> (0.005, 0.05)
        11-12 ring height (y)               sigmoid -> (0.05, 0.2)
        12-13 clapboard z size              sigmoid -> (0.02, 0.15)
        13-14 clapboard z offset            tanh    -> (-0.25, 0.25)
        14-17 block size                    sigmoid -> (0.02, 0.1)
        17-20 block position                tanh    -> (-3, 3)
        20-23 block rotation                tanh    -> (-pi, pi)
    """
    # 0-3: Position (x,y,z)
    position = torch.tanh(raw[..., 0:3]) * 3

    # 3-6: Rotation (rx, ry, rz)
    rotation = torch.tanh(raw[..., 3:6]) * math.pi

    # 6-9: Backboard size (x,y,z)
    backboard = torch.sigmoid(raw[..., 6:9])

    # 9-12: Ring wall thickness (tx, tz) and ring height (y)
    ring_thickness = torch.sigmoid(raw[..., 9:11]) * 0.045 + 0.005
    ring_height = torch.sigmoid(raw[..., 11:12]) * 0.15 + 0.05

    # 12-14: Clapboard z size and z offset
    clap_z = torch.sigmoid(raw[..., 12:13]) * 0.13 + 0.02
    clap_offset = torch.tanh(raw[..., 13:14]) * 0.25

    # 14-17: Block size
    block_size = torch.sigmoid(raw[..., 14:17]) * 0.08 + 0.02

    # 17-20: Block position
    block_pos = torch.tanh(raw[..., 17:20]) * 3

    # 20-23: Block rotation
    block_rot = torch.tanh(raw[..., 20:23]) * math.pi

    return torch.cat([
        position, rotation,
        backboard,
        ring_thickness, ring_height,
        clap_z, clap_offset,
        block_size, block_pos, block_rot,
    ], dim=-1)

class Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using sparse vertices and ``Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge feature dim,
        knowledge attention dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feature_dim, attention_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = Knowledge_Encoder(feature_dim, attention_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the concept vertices and extend
        the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state
    
class PT_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using sparse vertices and ``PT_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge feature dim,
        knowledge attention dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feature_dim, attention_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feature_dim, attention_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the concept vertices and extend
        the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state
    
class Fusion_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using sparse vertices and ``Fusion_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge semantic dim,
        knowledge xyz dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        semantic_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        # 40 matches the vertex count emitted by Concept_ShelfPlace (8 + 16 + 8 + 8).
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, 40, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the concept vertices and extend
        the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state

class PN_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using sparse vertices and ``PN_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries; entries 0, 1 and 4 are
        used (estimator hidden dim, geometric parameter dim, concept output
        dim), entries 2 and 3 are not needed by this variant
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the concept vertices and extend
        the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state
    
#######################################################################

# Dense point cloud in concept space 

#######################################################################

def Dense_Rectangular_Ring(inner, outer, height, position, rotation, Sema, N=40):
    """
    Sample N labelled points on the surface of a rectangular ring.

    The ring is triangulated from 16 vertices (outer box + inner box); faces
    are drawn with probability proportional to their area and a point is
    placed inside each drawn triangle by barycentric interpolation. Gradients
    reach the inputs through the interpolated vertex positions; the face choice
    itself is discrete. Sampling uses the global torch RNG, so results vary
    between calls unless the RNG is seeded.

    :param inner[B,2]: [inner_length, inner_width]
    :param outer[B,2]: [outer_length, outer_width]
    :param height[B,1]: height of Rectangular_Ring (y extent of both boxes)
    :param position[B,3]: translation added after the rotation
    :param rotation[B,3]: rotation parameters forwarded to ``rotate``
    :param Sema: scalar semantic value written into the 4th channel
    :param N: number of points to sample per batch element
    :return: SampledPoints [B, N, 4] as (x, y, z, Sema)
    """
    B = inner.shape[0]
    device = inner.device
    dtype = inner.dtype

    # [B, 16, 3] vertex template
    vertices = torch.tensor([
        # Outer box vertices (0-7)
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
        # Inner box vertices (8-15)
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5]
    ], device=device, dtype=dtype).unsqueeze(0).repeat(B, 1, 1)

    # Per-vertex (x, y, z) scale: first 8 rows outer box, last 8 rows inner box
    outer_box = torch.cat([outer[:, 0:1], height, outer[:, 1:2]], dim=-1).unsqueeze(1).repeat(1, 8, 1)
    inner_box = torch.cat([inner[:, 0:1], height, inner[:, 1:2]], dim=-1).unsqueeze(1).repeat(1, 8, 1)
    deformation = torch.cat([outer_box, inner_box], dim=1)

    # Scale, rotate, then translate
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(1)

    # Fixed triangle template (indices into the 16 vertices)
    faces = torch.tensor([
        # Outer side walls (8 triangles)
        [0,1,4], [5,1,4],
        [1,3,5], [7,3,5],
        [2,3,6], [7,3,6],
        [0,2,4], [6,2,4],
        # Inner side walls (8 triangles)
        [8,9,12], [13,9,12],
        [9,11,13], [15,11,13],
        [10,11,14], [15,11,14],
        [8,10,12], [14,10,12],
        # Top and bottom rims connecting outer and inner boxes (16 triangles)
        [0,1,8], [9,1,8],
        [1,3,9], [11,3,9],
        [2,3,10], [11,3,10],
        [0,2,8], [10,2,8],
        [4,5,12], [13,5,12],
        [5,7,13], [15,7,13],
        [6,7,14], [15,7,14],
        [4,6,12], [14,6,12],
    ], device=device)

    # 1. Face areas. dim=-1 is explicit: without it torch.cross falls back to
    #    the first size-3 dimension, which is the batch axis when B == 3.
    face_verts = vertices[:, faces]  # [B, 32, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=2)  # [B, 32]

    # 2. Draw faces proportionally to their area
    probs = areas / (areas.sum(dim=1, keepdim=True) + 1e-10)
    face_idx = torch.multinomial(probs, N, replacement=True)  # [B, N]
    face_idx_expanded = face_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 3, 3)

    # 3. Uniform barycentric coordinates; (u, v) outside the triangle are
    #    reflected back in. Drawn in the vertices' dtype so the output dtype
    #    follows the inputs.
    u = torch.rand(B, N, 1, device=device, dtype=face_verts.dtype)
    v = torch.rand(B, N, 1, device=device, dtype=face_verts.dtype)
    mask = (u + v) > 1
    u = torch.where(mask, 1 - u, u)
    v = torch.where(mask, 1 - v, v)
    w = 1 - u - v

    # Corner coordinates of every drawn face
    selected_faces = torch.gather(face_verts, 1, face_idx_expanded)  # [B, N, 3, 3]

    # Interpolate points
    sampled_points = (selected_faces[..., 0, :] * u +
                      selected_faces[..., 1, :] * v +
                      selected_faces[..., 2, :] * w)

    # Append the semantic channel
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sema
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Cuboid(size, position, rotation, Sema, N=40):
    """
    Sample N labelled points on the surface of a transformed cuboid.

    The cuboid is triangulated into 12 faces; faces are drawn with probability
    proportional to their area and a point is placed inside each drawn
    triangle by barycentric interpolation. Gradients reach the inputs through
    the interpolated vertex positions; the face choice itself is discrete.
    Sampling uses the global torch RNG, so results vary between calls unless
    the RNG is seeded.

    Args:
        size: [B,3] tensor - Dimensions of the cuboid (x, y, z)
        position: [B,3] tensor - Center position of the cuboid
        rotation: [B,3] tensor - Rotation parameters forwarded to ``rotate``
            (described as Euler angles by the original author - confirm
            against ``rotate``)
        Sema: scalar - Semantic value written into the 4th channel
        N: int - Number of points to sample per batch element (default: 40)

    Returns:
        SampledPoints: [B, N, 4] tensor - Sampled points as (x, y, z, Sema)
    """
    B = size.shape[0]
    device = size.device
    dtype = size.dtype

    # Canonical vertices (unit cube centered at origin)
    vertices = torch.tensor([
        [-0.5,  0.5,  0.5], [ 0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [ 0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [ 0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [ 0.5, -0.5, -0.5]
    ], device=device, dtype=dtype)

    # Scale vertices to desired size
    vertices = vertices.unsqueeze(0).repeat(B, 1, 1) * size.unsqueeze(1)

    # Apply rotation and translation
    vertices = rotate(vertices, rotation) + position.unsqueeze(1)

    # Fixed triangle template (12 triangles)
    faces = torch.tensor([
        [0,1,2], [1,3,2],  # Top face
        [4,5,6], [5,7,6],  # Bottom face
        [0,4,1], [1,4,5],  # Front face
        [2,3,6], [3,7,6],  # Back face
        [0,2,4], [2,6,4],  # Left face
        [1,5,3], [3,5,7],  # Right face
    ], device=device)

    # 1. Face areas. dim=-1 is explicit: without it torch.cross falls back to
    #    the first size-3 dimension, which is the batch axis when B == 3.
    face_verts = vertices[:, faces]  # [B, 12, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=2)  # [B, 12]

    # 2. Draw faces proportionally to their area
    probs = areas / (areas.sum(dim=1, keepdim=True) + 1e-10)
    face_idx = torch.multinomial(probs, N, replacement=True)  # [B, N]
    face_idx_expanded = face_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 3, 3)

    # 3. Uniform barycentric coordinates; (u, v) outside the triangle are
    #    reflected back in. Drawn in the vertices' dtype so the output dtype
    #    follows the inputs.
    u = torch.rand(B, N, 1, device=device, dtype=face_verts.dtype)
    v = torch.rand(B, N, 1, device=device, dtype=face_verts.dtype)
    mask = (u + v) > 1
    u = torch.where(mask, 1 - u, u)
    v = torch.where(mask, 1 - v, v)
    w = 1 - u - v

    # Corner coordinates of every drawn face
    selected_faces = torch.gather(face_verts, 1, face_idx_expanded)  # [B, N, 3, 3]

    sampled_points = (selected_faces[..., 0, :] * u +
                      selected_faces[..., 1, :] * v +
                      selected_faces[..., 2, :] * w)

    # Append the semantic channel
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=sampled_points.dtype) * Sema
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Concept_ShelfPlace(paras):
    """
    Expand a shelf-place parameter vector into a dense concept point cloud.

    Parameter layout (columns of ``paras``):
    :paras[:, 0:3]: shelf position
    :paras[:, 3:6]: shelf rotation
    :paras[:, 6:9]: BackBoard(x, y, z)
    :paras[:, 9:12]: Rectangular_Ring(thickness_x, thickness_z, y)
    :paras[:, 12:14]: ClapBoard(z, z_offset)
    :paras[:, 14:17, 17:20, 20:23]: size, position, rotation of block

    :return: (Affordance, grasp_pose)
        - Affordance [B, 4*N, 4] with the samplers' default N=40 (160 points):
          backboard (label 1), ring (label 2), clapboard (label 3) and
          block (label 4) surface samples as (x, y, z, label)
        - grasp_pose [B,6]: block position followed by the fixed direction (0, 0, -1)

    NOTE: the clapboard thickness now reads column 12 and the grasp position
    reads columns 17:20, matching the layout above. Earlier revisions read
    column 13 and columns 18:21 respectively; checkpoints trained against the
    old indexing may need re-training.
    """
    B = paras.shape[0]
    position = paras[:, 0:3]
    rotation = paras[:, 3:6]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)

    # Backboard: centred on the shelf position.
    Backboard_size = paras[:, 6:9]
    Backboard_position = position
    vertices_backboard = Dense_Cuboid(Backboard_size, Backboard_position, rotation, 1)

    # Rectangular_Ring: shares the backboard's x/z footprint and sits directly
    # below it along y (offset is added in the un-rotated frame).
    outer_size = torch.cat([Backboard_size[:, 0:1], Backboard_size[:, 2:3]], dim=-1)
    inner_size = outer_size - 2 * paras[:, 9:11]  # wall thickness removed on both sides
    Rect_height = paras[:, 11:12]
    Rect_position = position + torch.cat(
        [zeros, -0.5 * Rect_height - 0.5 * Backboard_size[:, 1:2], zeros], dim=-1
    )
    vertices_Rect = Dense_Rectangular_Ring(inner_size, outer_size, Rect_height, Rect_position, rotation, 2)

    # Clapboard: spans the inner ring in x and the ring height in y; its z
    # thickness is paras[:, 12:13] and it is shifted by the z offset paras[:, 13:14].
    Clapboard_size = torch.cat([inner_size[:, 0:1], Rect_height, paras[:, 12:13]], dim=-1)
    Clapboard_position = Rect_position + torch.cat([zeros, zeros, paras[:, 13:14]], dim=-1)
    vertices_Clapboard = Dense_Cuboid(Clapboard_size, Clapboard_position, rotation, 3)

    # Block: independent size / position / rotation.
    Block_position = paras[:, 17:20]
    vertices_Block = Dense_Cuboid(paras[:, 14:17], Block_position, paras[:, 20:23], 4)
    direction = torch.tensor([0, 0, -1], device=paras.device, dtype=paras.dtype).expand(B, 3)
    # shape (B, 6): [x, y, z, dx, dy, dz]
    grasp_pose = torch.cat([Block_position, direction], dim=-1)

    Affordance = torch.cat(
        [vertices_backboard, vertices_Rect, vertices_Clapboard, vertices_Block], dim=1
    )
    return Affordance, grasp_pose

class Dense_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using the dense point cloud and ``Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge feature dim,
        knowledge attention dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feature_dim, attention_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = Knowledge_Encoder(feature_dim, attention_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the sampled concept points and
        extend the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Dense_Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state

class PT_Dense_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using the dense point cloud and ``PT_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge feature dim,
        knowledge attention dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feature_dim, attention_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feature_dim, attention_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the sampled concept points and
        extend the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Dense_Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state

class Fusion_Dense_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using the dense point cloud and ``Fusion_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries, in order:
        estimator hidden dim, geometric parameter dim, knowledge semantic dim,
        knowledge xyz dim, concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        semantic_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        # 160 matches the point count emitted by Dense_Concept_ShelfPlace
        # (4 parts x 40 samples with the samplers' default N).
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, 160, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the sampled concept points and
        extend the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Dense_Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state
    
class PN_Dense_Concept_Module_ShelfPlace(nn.Module):
    """
    Shelf-place concept module using the dense point cloud and ``PN_Knowledge_Encoder``.

    :param in_dim: input dimension handed to ``Para_Estimator``
    :param paralist: indexable with at least 5 entries; entries 0, 1 and 4 are
        used (estimator hidden dim, geometric parameter dim, concept output
        dim), entries 2 and 3 are not needed by this variant
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, para_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        Estimate shelf parameters, encode the sampled concept points and
        extend the robot state with the predicted grasp pose.

        :return: (concept_emb, robot_state) where robot_state is the input
                 robot_state with the 6-value grasp pose prepended on the last dim
        """
        # The estimator output is used as-is; process_paras is not applied here.
        estimated = self.para_estimator(feature)
        concept_points, grasp = Dense_Concept_ShelfPlace(estimated)
        embedding = self.knowledge_encoder(concept_points)
        extended_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, extended_state