import abc
from typing import List

import clip
import numpy as np
import torch
import torch.nn as nn

from lift3d.helpers.graphics import PointCloud
from lift3d.models.mlp.batchnorm_mlp import BatchNormMLP
from lift3d.models.mlp.mlp import MLP

# concept
from lift3d.models.concept.MetaWorld.Assembly import \
    Concept_Module_Assembly, PN_Concept_Module_Assembly, Fusion_Concept_Module_Assembly, PT_Concept_Module_Assembly, \
        Dense_Concept_Module_Assembly, Fusion_Dense_Concept_Module_Assembly, PT_Dense_Concept_Module_Assembly, PN_Dense_Concept_Module_Assembly
from lift3d.models.concept.MetaWorld.Shelf_Place import  \
    Concept_Module_ShelfPlace, PT_Concept_Module_ShelfPlace, Fusion_Concept_Module_ShelfPlace, PN_Concept_Module_ShelfPlace, \
        Dense_Concept_Module_ShelfPlace, Fusion_Dense_Concept_Module_ShelfPlace, PT_Dense_Concept_Module_ShelfPlace, PN_Dense_Concept_Module_ShelfPlace
from lift3d.models.concept.MetaWorld.Hand_Insert import \
    Concept_Module_HandInsert, PT_Concept_Module_HandInsert, Fusion_Concept_Module_HandInsert, PN_Concept_Module_HandInsert, \
        Dense_Concept_Module_HandInsert, Fusion_Dense_Concept_Module_HandInsert, PT_Dense_Concept_Module_HandInsert, PN_Dense_Concept_Module_HandInsert
from lift3d.models.concept.MetaWorld.Hammer import \
    Concept_Module_Hammer, PT_Concept_Module_Hammer, Fusion_Concept_Module_Hammer, PN_Concept_Module_Hammer, \
        Dense_Concept_Module_Hammer, PT_Dense_Concept_Module_Hammer, Fusion_Dense_Concept_Module_Hammer, PN_Dense_Concept_Module_Hammer
from lift3d.models.concept.MetaWorld.Sweep_Into import \
    Concept_Module_SweepInto, PT_Concept_Module_SweepInto, Fusion_Concept_Module_SweepInto, PN_Concept_Module_SweepInto, \
        Dense_Concept_Module_SweepInto, PT_Dense_Concept_Module_SweepInto, Fusion_Dense_Concept_Module_SweepInto, PN_Dense_Concept_Module_SweepInto
from lift3d.models.concept.MetaWorld.Bin_Picking import \
    Concept_Module_BinPicking, PT_Concept_Module_BinPicking, Fusion_Concept_Module_BinPicking, PN_Concept_Module_BinPicking, \
        Dense_Concept_Module_BinPicking, PT_Dense_Concept_Module_BinPicking, Fusion_Dense_Concept_Module_BinPicking, PN_Dense_Concept_Module_BinPicking
from lift3d.models.concept.MetaWorld.Push_Wall import \
    Concept_Module_PushWall, PT_Concept_Module_PushWall, Fusion_Concept_Module_PushWall, PN_Concept_Module_PushWall, \
        Dense_Concept_Module_PushWall, PT_Dense_Concept_Module_PushWall, Fusion_Dense_Concept_Module_PushWall, PN_Dense_Concept_Module_PushWall
from lift3d.models.concept.MetaWorld.Box_Close import \
    Concept_Module_BoxClose, PT_Concept_Module_BoxClose, Fusion_Concept_Module_BoxClose, PN_Concept_Module_BoxClose, \
        Dense_Concept_Module_BoxClose, PT_Dense_Concept_Module_BoxClose, Fusion_Dense_Concept_Module_BoxClose, PN_Dense_Concept_Module_BoxClose
from lift3d.models.concept.MetaWorld.Button_Press import \
    Concept_Module_ButtonPress, PT_Concept_Module_ButtonPress, Fusion_Concept_Module_ButtonPress, PN_Concept_Module_ButtonPress, \
        Dense_Concept_Module_ButtonPress, PT_Dense_Concept_Module_ButtonPress, Fusion_Dense_Concept_Module_ButtonPress, PN_Dense_Concept_Module_ButtonPress
from lift3d.models.concept.MetaWorld.Reach import Concept_Module_Reach
from lift3d.models.concept.MetaWorld.Drawer_Open import \
    Concept_Module_DrawerOpen, PT_Concept_Module_DrawerOpen, Fusion_Concept_Module_DrawerOpen, PN_Concept_Module_DrawerOpen, \
        Dense_Concept_Module_DrawerOpen, PT_Dense_Concept_Module_DrawerOpen, Fusion_Dense_Concept_Module_DrawerOpen, PN_Dense_Concept_Module_DrawerOpen
from lift3d.models.concept.MetaWorld.Handle_Pull import \
    Concept_Module_HandlePull, PT_Concept_Module_HandlePull, Fusion_Concept_Module_HandlePull, PN_Concept_Module_HandlePull, \
        Dense_Concept_Module_HandlePull, PT_Dense_Concept_Module_HandlePull, Fusion_Dense_Concept_Module_HandlePull, PN_Dense_Concept_Module_HandlePull
from lift3d.models.concept.MetaWorld.Peg_Unplug_Side import \
    Concept_Module_PegUnplugSide, PT_Concept_Module_PegUnplugSide, Fusion_Concept_Module_PegUnplugSide, PN_Concept_Module_PegUnplugSide, \
        Dense_Concept_Module_PegUnplugSide, PT_Dense_Concept_Module_PegUnplugSide, Fusion_Dense_Concept_Module_PegUnplugSide, PN_Dense_Concept_Module_PegUnplugSide
from lift3d.models.concept.MetaWorld.Lever_Pull import \
    Concept_Module_LeverPull, PT_Concept_Module_LeverPull, Fusion_Concept_Module_LeverPull, PN_Concept_Module_LeverPull, \
        Dense_Concept_Module_LeverPull, PT_Dense_Concept_Module_LeverPull, Fusion_Dense_Concept_Module_LeverPull, PN_Dense_Concept_Module_LeverPull
from lift3d.models.concept.MetaWorld.Dial_Turn import \
    Concept_Module_DialTurn, PT_Concept_Module_DialTurn, Fusion_Concept_Module_DialTurn, PN_Concept_Module_DialTurn, \
        Dense_Concept_Module_DialTurn, PT_Dense_Concept_Module_DialTurn, Fusion_Dense_Concept_Module_DialTurn, PN_Dense_Concept_Module_DialTurn

class Actor(nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for policy networks that map observations to actions.

    Every call receives the full observation bundle; a concrete actor may
    ignore the modalities it does not consume.
    """

    @abc.abstractmethod
    def forward(self, images, point_clouds, robot_states, texts=None):
        """Return actions for a batch of observations.

        Args:
            images: Image observations (may be ignored by a concrete actor).
            point_clouds: Point-cloud observations (may be ignored).
            robot_states: Robot state observations.
            texts: Optional text inputs. All concrete actors in this module
                declare this parameter, so it is part of the abstract
                contract; it defaults to ``None`` to stay compatible with
                callers that omit it.

        Returns:
            The predicted actions.
        """
        raise NotImplementedError


class VisionGuidedMLP(Actor):
    """Image-conditioned MLP policy.

    Image features and a linear projection of the robot state, both of width
    ``image_encoder.feature_dim``, each go through their own dropout. They are
    then concatenated and decoded into actions by an ``MLP`` head.
    """

    def __init__(
        self,
        image_encoder: nn.Module,
        image_dropout_rate: float,
        robot_state_dim: int,
        robot_state_dropout_rate: float,
        action_dim: int,
        policy_hidden_dims: List[int],
        policy_head_init_method: str,
    ):
        super().__init__()
        width = image_encoder.feature_dim
        # Keep this construction order: parameter initialisation draws from the
        # global RNG in the order the layers are created.
        self.image_encoder = image_encoder
        self.image_dropout = nn.Dropout(image_dropout_rate)
        self.robot_state_encoder = nn.Linear(robot_state_dim, width)
        self.robot_state_dropout = nn.Dropout(robot_state_dropout_rate)
        self.policy_head = MLP(
            input_dim=width * 2,  # image features + projected robot state
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            init_method=policy_head_init_method,
        )

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from images and robot state (point clouds and texts are ignored)."""
        visual = self.image_dropout(self.image_encoder(images))
        proprio = self.robot_state_dropout(self.robot_state_encoder(robot_states))
        return self.policy_head(torch.cat((visual, proprio), dim=1))


class PointCloudGuidedMLP(Actor):
    """Point-cloud-conditioned MLP policy.

    The input point cloud is normalized and encoded. Its features and a linear
    projection of the robot state, both of width
    ``point_cloud_encoder.feature_dim``, each go through their own dropout.
    They are then concatenated and decoded into actions by an ``MLP`` head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        point_cloud_dropout_rate: float,
        robot_state_dim: int,
        robot_state_dropout_rate: float,
        action_dim: int,
        policy_hidden_dims: List[int],
        policy_head_init_method: str,
    ):
        super().__init__()
        width = point_cloud_encoder.feature_dim
        # Keep this construction order: parameter initialisation draws from the
        # global RNG in the order the layers are created.
        self.point_cloud_encoder = point_cloud_encoder
        self.point_cloud_dropout = nn.Dropout(point_cloud_dropout_rate)
        self.robot_state_encoder = nn.Linear(robot_state_dim, width)
        self.robot_state_dropout = nn.Dropout(robot_state_dropout_rate)
        self.policy_head = MLP(
            input_dim=width * 2,  # point-cloud features + projected robot state
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            init_method=policy_head_init_method,
        )

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot state (images and texts are ignored)."""
        # * Notice: the input point cloud is normalized before encoding.
        normalized = PointCloud.normalize(point_clouds)
        geometry = self.point_cloud_dropout(self.point_cloud_encoder(normalized))
        proprio = self.robot_state_dropout(self.robot_state_encoder(robot_states))
        return self.policy_head(torch.cat((geometry, proprio), dim=1))


class VisionGuidedBatchNormMLP(Actor):
    """Image-conditioned policy with a ``BatchNormMLP`` head.

    Image features are concatenated with the raw robot state and decoded into
    actions.
    """

    def __init__(
        self,
        image_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        head_input_dim = image_encoder.feature_dim + robot_state_dim
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias, depending on BatchNormMLP) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from images and robot state (point clouds and texts are ignored)."""
        features = torch.cat((self.image_encoder(images), robot_states), dim=1)
        return self.policy_head(features)


class PointCloudGuidedBatchNormMLP(Actor):
    """Point-cloud-conditioned policy with a ``BatchNormMLP`` head.

    The normalized point cloud is encoded, concatenated with the raw robot
    state, and decoded into actions.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        head_input_dim = point_cloud_encoder.feature_dim + robot_state_dim
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias, depending on BatchNormMLP) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot state (images and texts are ignored)."""
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        return self.policy_head(torch.cat((cloud_emb, robot_states), dim=1))

#################################################################################

# Main Experiments Code

#################################################################################

    
class Concept_Assembly(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Assembly task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_Assembly``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_Assembly(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_Assembly.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_ShelfPlace(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Shelf Place task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_ShelfPlace``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_ShelfPlace.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)

class Concept_HandInsert(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Hand Insert task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_HandInsert``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_HandInsert(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_HandInsert.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_Hammer(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Hammer task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_Hammer``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_Hammer(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_Hammer.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_SweepInto(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Sweep Into task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_SweepInto``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_SweepInto(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_SweepInto.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_BinPicking(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Bin Picking task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_BinPicking``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_BinPicking(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_BinPicking.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_PushWall(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Push Wall task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_PushWall``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_PushWall(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_PushWall.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_BoxClose(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Box Close task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_BoxClose``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_BoxClose(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_BoxClose.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_ButtonPress(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Button Press task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_ButtonPress``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_ButtonPress.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_Reach(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Reach task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_Reach``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_Reach(enc_dim, concept_para_list)
        # Unlike the other concept actors in this module, no extra 6 input
        # channels are reserved here; only concept_para_list[4] is added.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4]
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_DrawerOpen(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Drawer Open task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_DrawerOpen``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_DrawerOpen.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_HandlePull(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Handle Pull task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_HandlePull``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_HandlePull(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_HandlePull.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_PegUnplugSide(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Peg Unplug Side task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_PegUnplugSide``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_PegUnplugSide.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_LeverPull(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Lever Pull task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_LeverPull``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_LeverPull(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_LeverPull.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
class Concept_DialTurn(Actor):
    """Concept-augmented point-cloud policy for the MetaWorld Dial Turn task.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``Concept_Module_DialTurn``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Concept_Module_DialTurn(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against Concept_Module_DialTurn.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
    
#############################################################################

# Ablation Study Code

#############################################################################

# Ex1: PointNet + Dense
class Dense_PN_Concept_Assembly(Actor):
    """Ablation actor (PointNet + dense concept module) for MetaWorld Assembly.

    Steps:
    1. The normalized point cloud is encoded.
    2. The embedding and the robot state are passed through
       ``PN_Dense_Concept_Module_Assembly``.
    3. The encoder embedding, the returned concept feature and the returned
       robot state are concatenated.
    4. A ``BatchNormMLP`` head decodes the result into actions.

    ``images`` and ``texts`` are ignored.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_Assembly(enc_dim, concept_para_list)
        # NOTE(review): the head reserves concept_para_list[4] + 6 channels for
        # the concept branch output; the origin of the extra 6 is not visible
        # here — confirm against PN_Dense_Concept_Module_Assembly.
        head_input_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors (presumably the output
        # layer's weight and bias) by 1e-2.
        for tensor in list(self.policy_head.parameters())[-2:]:
            tensor.data = tensor.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: the input point cloud is normalized before encoding.
        cloud_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feature, state = self.concept(cloud_emb, robot_states)
        fused = torch.cat((cloud_emb, concept_feature, state), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_BinPicking(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the BinPicking task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_BinPicking``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_BinPicking(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_BoxClose(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the BoxClose task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_BoxClose``, and feeds the concatenation
    of the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_BoxClose(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_ButtonPress(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the ButtonPress task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_ButtonPress``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_DialTurn(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the DialTurn task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_DialTurn``, and feeds the concatenation
    of the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_DialTurn(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_DrawerOpen(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the DrawerOpen task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_DrawerOpen``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_Hammer(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the Hammer task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_Hammer``, and feeds the concatenation
    of the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_Hammer(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_HandInsert(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the HandInsert task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_HandInsert``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_HandInsert(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_HandlePull(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the HandlePull task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_HandlePull``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_HandlePull(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_LeverPull(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the LeverPull task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_LeverPull``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_LeverPull(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_PegUnplugSide(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the PegUnplugSide task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_PegUnplugSide``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_PushWall(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the PushWall task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_PushWall``, and feeds the concatenation
    of the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_PushWall(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_ShelfPlace(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the ShelfPlace task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_ShelfPlace``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class Dense_PN_Concept_SweepInto(Actor):
    """Ablation Ex1 (PointNet + Dense) actor for the SweepInto task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Dense_Concept_Module_SweepInto``, and feeds the
    concatenation of the encoder embedding and both concept-module outputs
    to a ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Dense_Concept_Module_SweepInto(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)

# Ex2: PointNet + Simple
class PN_Concept_Assembly(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the Assembly task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_Assembly``, and feeds the concatenation of the
    encoder embedding and both concept-module outputs to a ``BatchNormMLP``
    policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_Assembly(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_BinPicking(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the BinPicking task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_BinPicking``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_BinPicking(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_BoxClose(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the BoxClose task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_BoxClose``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_BoxClose(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_ButtonPress(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the ButtonPress task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_ButtonPress``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_DialTurn(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the DialTurn task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_DialTurn``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_DialTurn(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_DrawerOpen(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the DrawerOpen task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_DrawerOpen``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_Hammer(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the Hammer task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_Hammer``, and feeds the concatenation of the
    encoder embedding and both concept-module outputs to a ``BatchNormMLP``
    policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_Hammer(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_HandInsert(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the HandInsert task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_HandInsert``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_HandInsert(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_HandlePull(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the HandlePull task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_HandlePull``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_HandlePull(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_LeverPull(Actor):
    """Ablation Ex2 (PointNet + Simple) actor for the LeverPull task.

    Encodes a normalized point cloud, passes the embedding and robot state
    through ``PN_Concept_Module_LeverPull``, and feeds the concatenation of
    the encoder embedding and both concept-module outputs to a
    ``BatchNormMLP`` policy head.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_LeverPull(enc_dim, concept_para_list)
        # Must equal the dim-1 width of the concatenation built in forward();
        # NOTE(review): `concept_para_list[4] + 6` presumably covers the concept
        # module's extra output width -- confirm against the concept module.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the head's last two parameter tensors by 100x (presumably the
        # output layer's weight and bias -- TODO confirm).
        with torch.no_grad():
            for tensor in list(self.policy_head.parameters())[-2:]:
                tensor.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Return the actions predicted by the policy head.

        ``images`` and ``texts`` are not used by this actor.
        """
        # * Notice: the point cloud is normalized before encoding
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_emb, state_emb = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_emb, state_emb), dim=1)
        return self.policy_head(fused)
class PN_Concept_PegUnplugSide(Actor):
    """PegUnplugSide actor: point-cloud encoder -> ``PN_Concept_Module_PegUnplugSide`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class PN_Concept_PushWall(Actor):
    """PushWall actor: point-cloud encoder -> ``PN_Concept_Module_PushWall`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_PushWall(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class PN_Concept_ShelfPlace(Actor):
    """ShelfPlace actor: point-cloud encoder -> ``PN_Concept_Module_ShelfPlace`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class PN_Concept_SweepInto(Actor):
    """SweepInto actor: point-cloud encoder -> ``PN_Concept_Module_SweepInto`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PN_Concept_Module_SweepInto(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )

# Ex3: DenseFusion + Dense
class Dense_Fusion_Concept_Assembly(Actor):
    """Assembly actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_Assembly`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_Assembly(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_BinPicking(Actor):
    """BinPicking actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_BinPicking`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_BinPicking(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_BoxClose(Actor):
    """BoxClose actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_BoxClose`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_BoxClose(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_ButtonPress(Actor):
    """ButtonPress actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_ButtonPress`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_DialTurn(Actor):
    """DialTurn actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_DialTurn`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_DialTurn(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_DrawerOpen(Actor):
    """DrawerOpen actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_DrawerOpen`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_Hammer(Actor):
    """Hammer actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_Hammer`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_Hammer(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_HandInsert(Actor):
    """HandInsert actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_HandInsert`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_HandInsert(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_HandlePull(Actor):
    """HandlePull actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_HandlePull`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_HandlePull(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_LeverPull(Actor):
    """LeverPull actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_LeverPull`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_LeverPull(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_PegUnplugSide(Actor):
    """PegUnplugSide actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_PegUnplugSide`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_PushWall(Actor):
    """PushWall actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_PushWall`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_PushWall(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_ShelfPlace(Actor):
    """ShelfPlace actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_ShelfPlace`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Dense_Fusion_Concept_SweepInto(Actor):
    """SweepInto actor: point-cloud encoder -> ``Fusion_Dense_Concept_Module_SweepInto`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Dense_Concept_Module_SweepInto(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )

# Ex4: DenseFusion + Simple
class Fusion_Concept_Assembly(Actor):
    """Assembly actor: point-cloud encoder -> ``Fusion_Concept_Module_Assembly`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_Assembly(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Fusion_Concept_BinPicking(Actor):
    """BinPicking actor: point-cloud encoder -> ``Fusion_Concept_Module_BinPicking`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_BinPicking(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Fusion_Concept_BoxClose(Actor):
    """BoxClose actor: point-cloud encoder -> ``Fusion_Concept_Module_BoxClose`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_BoxClose(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Fusion_Concept_ButtonPress(Actor):
    """ButtonPress actor: point-cloud encoder -> ``Fusion_Concept_Module_ButtonPress`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Fusion_Concept_DialTurn(Actor):
    """DialTurn actor: point-cloud encoder -> ``Fusion_Concept_Module_DialTurn`` -> BatchNormMLP head.

    The policy head consumes the concatenation of the point-cloud embedding,
    the concept feature and the robot state returned by the concept module.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        """Assemble encoder, concept module and policy head.

        Args:
            point_cloud_encoder: Encoder module exposing ``feature_dim``.
            robot_state_dim: Width of the incoming robot-state vector.
            action_dim: Number of action outputs.
            policy_hidden_dims: Hidden widths of the ``BatchNormMLP`` head.
            nonlinearity: Activation name passed to ``BatchNormMLP``.
            dropout_rate: Dropout probability passed to ``BatchNormMLP``.
            concept_para_list: Hyper-parameters for the concept module; its
                element at index 4 also contributes to the head input width.
        """
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_DialTurn(enc_dim, concept_para_list)
        # NOTE(review): presumably the concept module's outputs add
        # concept_para_list[4] + 6 features beyond robot_state_dim; the
        # origin of the extra 6 is not visible here -- TODO confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Scale the head's last two parameter tensors by 1e-2 (presumably
        # the output layer's weight and bias, so early actions stay small).
        with torch.no_grad():
            for tail_param in list(self.policy_head.parameters())[-2:]:
                tail_param.mul_(1e-2)

    def forward(self, images, point_clouds, robot_states, texts):
        """Map point clouds and robot states to actions.

        ``images`` and ``texts`` are not used by this actor (presumably kept
        for a call signature shared with other actors).
        """
        # * Notice: the raw point cloud is normalized before encoding
        pc_feat = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        # The concept module also returns the robot-state tensor that is
        # concatenated below in place of the incoming one.
        concept_feat, state_feat = self.concept(pc_feat, robot_states)
        return self.policy_head(
            torch.cat([pc_feat, concept_feat, state_feat], dim=1)
        )
class Fusion_Concept_DrawerOpen(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_DrawerOpen``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_Hammer(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_Hammer``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_Hammer(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_HandInsert(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_HandInsert``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_HandInsert(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_HandlePull(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_HandlePull``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_HandlePull(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_LeverPull(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_LeverPull``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_LeverPull(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_PegUnplugSide(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_PegUnplugSide``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_PushWall(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_PushWall``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_PushWall(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_ShelfPlace(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_ShelfPlace``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Fusion_Concept_SweepInto(Actor):
    """Actor pairing a point-cloud encoder with ``Fusion_Concept_Module_SweepInto``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = Fusion_Concept_Module_SweepInto(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)

# Ex5: PointTransformer + Dense
class Dense_PT_Concept_Assembly(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_Assembly``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_Assembly(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_BinPicking(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_BinPicking``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_BinPicking(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_BoxClose(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_BoxClose``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_BoxClose(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_ButtonPress(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_ButtonPress``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_ButtonPress(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_DialTurn(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_DialTurn``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_DialTurn(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_DrawerOpen(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_DrawerOpen``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_DrawerOpen(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_Hammer(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_Hammer``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_Hammer(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_HandInsert(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_HandInsert``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_HandInsert(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_HandlePull(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_HandlePull``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_HandlePull(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_LeverPull(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_LeverPull``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_LeverPull(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_PegUnplugSide(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_PegUnplugSide``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_PegUnplugSide(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_PushWall(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_PushWall``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_PushWall(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_ShelfPlace(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_ShelfPlace``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_ShelfPlace(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_PT_Concept_SweepInto(Actor):
    """Actor pairing a point-cloud encoder with ``PT_Dense_Concept_Module_SweepInto``.

    Pipeline: normalize the point cloud, encode it, let the concept module
    produce concept features (and a robot-state tensor) from the embedding and
    the robot state, then map the concatenation through a ``BatchNormMLP``.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Width of the predicted action.
        policy_hidden_dims: Hidden widths of the policy MLP.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout probability forwarded to ``BatchNormMLP``.
        concept_para_list: Concept-module hyper-parameters; entry ``[4]`` also
            contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        self.point_cloud_encoder = point_cloud_encoder
        enc_dim = point_cloud_encoder.feature_dim
        self.concept = PT_Dense_Concept_Module_SweepInto(enc_dim, concept_para_list)
        # NOTE(review): assumes the concept module's two outputs together add
        # concept_para_list[4] + 6 channels beyond robot_state_dim — confirm.
        head_in_dim = enc_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_in_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Damp the last two head parameters by 100x (presumably the output
        # layer's weight and bias, so early actions stay small).
        for tail in list(self.policy_head.parameters())[-2:]:
            tail.data = tail.data * 1e-2

    def forward(self, images, point_clouds, robot_states, texts):
        """Return actions for a batch; ``images`` and ``texts`` are unused."""
        # * Notice: normalize the input point cloud
        pc_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(pc_emb, robot_states)
        fused = torch.cat((pc_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)

# Ex6: PointTransformer + Simple
class PT_Concept_Assembly(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_Assembly, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_Assembly(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_BinPicking(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_BinPicking, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_BinPicking(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_BoxClose(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_BoxClose, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_BoxClose(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_ButtonPress(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_ButtonPress, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_ButtonPress(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_DialTurn(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_DialTurn, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_DialTurn(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_DrawerOpen(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_DrawerOpen, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_DrawerOpen(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_Hammer(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_Hammer, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_Hammer(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_HandInsert(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_HandInsert, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_HandInsert(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_HandlePull(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_HandlePull, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_HandlePull(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_LeverPull(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_LeverPull, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_LeverPull(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_PegUnplugSide(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_PegUnplugSide, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_PegUnplugSide(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_PushWall(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_PushWall, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_PushWall(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_ShelfPlace(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_ShelfPlace, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_ShelfPlace(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   
class PT_Concept_SweepInto(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(PT_Concept_SweepInto, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = PT_Concept_Module_SweepInto(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions   

# Ex7: Adapted PT + Dense
class Dense_Concept_Assembly(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_Assembly, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_Assembly(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_BinPicking(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_BinPicking, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_BinPicking(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_BoxClose(Actor):

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_BoxClose, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_BoxClose(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_ButtonPress(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_ButtonPress, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_ButtonPress(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_DialTurn(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_DialTurn, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_DialTurn(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_DrawerOpen(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_DrawerOpen, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_DrawerOpen(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_Hammer(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_Hammer, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_Hammer(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions  
class Dense_Concept_HandInsert(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_HandInsert, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_HandInsert(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions 
class Dense_Concept_HandlePull(Actor):
    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super(Dense_Concept_HandlePull, self).__init__()
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_HandlePull(point_cloud_encoder.feature_dim, concept_para_list)
        self.policy_head = BatchNormMLP(
            input_dim=point_cloud_encoder.feature_dim + robot_state_dim + concept_para_list[4] + 6,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        for param in list(self.policy_head.parameters())[-2:]:
            param.data = 1e-2 * param.data

    def forward(self, images, point_clouds, robot_states, texts):
        # * Notice: normalize the input point cloud
        point_clouds = PointCloud.normalize(point_clouds)
        point_cloud_emb = self.point_cloud_encoder(point_clouds)
        concept_feature, robot_states = self.concept(point_cloud_emb, robot_states)
        emb = torch.cat([point_cloud_emb, concept_feature, robot_states], dim=1)
        actions = self.policy_head(emb)
        return actions 
class Dense_Concept_LeverPull(Actor):
    """Point-cloud actor conditioned on the dense LeverPull concept module.

    The normalized point cloud is encoded. The embedding and the robot state
    are passed through ``Dense_Concept_Module_LeverPull``. The encoder
    embedding, the concept feature and the state returned by the concept
    module are then concatenated and decoded into actions by a
    ``BatchNormMLP`` policy head.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Output width of the policy head.
        policy_hidden_dims: Hidden layer sizes of the policy head.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout rate forwarded to ``BatchNormMLP``.
        concept_para_list: Configuration forwarded to the concept module;
            entry 4 also contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        feature_dim = point_cloud_encoder.feature_dim
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_LeverPull(feature_dim, concept_para_list)
        # NOTE(review): presumably concept_para_list[4] is the concept feature
        # width and the extra 6 columns come from the state returned by the
        # concept module -- confirm against Dense_Concept_Module_LeverPull.
        head_input_dim = feature_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the last two parameter tensors of the head by 100x
        # (presumably the output layer's weight and bias) so the freshly
        # initialised policy emits small actions.
        tail_params = list(self.policy_head.parameters())[-2:]
        for tensor in tail_params:
            tensor.data = 1e-2 * tensor.data

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot states.

        ``images`` and ``texts`` are accepted but not used by this actor.
        """
        # The raw point cloud is normalized before it reaches the encoder.
        scene_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(scene_emb, robot_states)
        # Concatenate along dim 1 (the feature axis, assuming batch-first
        # tensors -- confirm).
        fused = torch.cat((scene_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_Concept_PegUnplugSide(Actor):
    """Point-cloud actor conditioned on the dense PegUnplugSide concept module.

    The normalized point cloud is encoded. The embedding and the robot state
    are passed through ``Dense_Concept_Module_PegUnplugSide``. The encoder
    embedding, the concept feature and the state returned by the concept
    module are then concatenated and decoded into actions by a
    ``BatchNormMLP`` policy head.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Output width of the policy head.
        policy_hidden_dims: Hidden layer sizes of the policy head.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout rate forwarded to ``BatchNormMLP``.
        concept_para_list: Configuration forwarded to the concept module;
            entry 4 also contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        feature_dim = point_cloud_encoder.feature_dim
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_PegUnplugSide(feature_dim, concept_para_list)
        # NOTE(review): presumably concept_para_list[4] is the concept feature
        # width and the extra 6 columns come from the state returned by the
        # concept module -- confirm against Dense_Concept_Module_PegUnplugSide.
        head_input_dim = feature_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the last two parameter tensors of the head by 100x
        # (presumably the output layer's weight and bias) so the freshly
        # initialised policy emits small actions.
        tail_params = list(self.policy_head.parameters())[-2:]
        for tensor in tail_params:
            tensor.data = 1e-2 * tensor.data

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot states.

        ``images`` and ``texts`` are accepted but not used by this actor.
        """
        # The raw point cloud is normalized before it reaches the encoder.
        scene_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(scene_emb, robot_states)
        # Concatenate along dim 1 (the feature axis, assuming batch-first
        # tensors -- confirm).
        fused = torch.cat((scene_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_Concept_PushWall(Actor):
    """Point-cloud actor conditioned on the dense PushWall concept module.

    The normalized point cloud is encoded. The embedding and the robot state
    are passed through ``Dense_Concept_Module_PushWall``. The encoder
    embedding, the concept feature and the state returned by the concept
    module are then concatenated and decoded into actions by a
    ``BatchNormMLP`` policy head.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Output width of the policy head.
        policy_hidden_dims: Hidden layer sizes of the policy head.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout rate forwarded to ``BatchNormMLP``.
        concept_para_list: Configuration forwarded to the concept module;
            entry 4 also contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        feature_dim = point_cloud_encoder.feature_dim
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_PushWall(feature_dim, concept_para_list)
        # NOTE(review): presumably concept_para_list[4] is the concept feature
        # width and the extra 6 columns come from the state returned by the
        # concept module -- confirm against Dense_Concept_Module_PushWall.
        head_input_dim = feature_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the last two parameter tensors of the head by 100x
        # (presumably the output layer's weight and bias) so the freshly
        # initialised policy emits small actions.
        tail_params = list(self.policy_head.parameters())[-2:]
        for tensor in tail_params:
            tensor.data = 1e-2 * tensor.data

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot states.

        ``images`` and ``texts`` are accepted but not used by this actor.
        """
        # The raw point cloud is normalized before it reaches the encoder.
        scene_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(scene_emb, robot_states)
        # Concatenate along dim 1 (the feature axis, assuming batch-first
        # tensors -- confirm).
        fused = torch.cat((scene_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_Concept_ShelfPlace(Actor):
    """Point-cloud actor conditioned on the dense ShelfPlace concept module.

    The normalized point cloud is encoded. The embedding and the robot state
    are passed through ``Dense_Concept_Module_ShelfPlace``. The encoder
    embedding, the concept feature and the state returned by the concept
    module are then concatenated and decoded into actions by a
    ``BatchNormMLP`` policy head.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Output width of the policy head.
        policy_hidden_dims: Hidden layer sizes of the policy head.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout rate forwarded to ``BatchNormMLP``.
        concept_para_list: Configuration forwarded to the concept module;
            entry 4 also contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        feature_dim = point_cloud_encoder.feature_dim
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_ShelfPlace(feature_dim, concept_para_list)
        # NOTE(review): presumably concept_para_list[4] is the concept feature
        # width and the extra 6 columns come from the state returned by the
        # concept module -- confirm against Dense_Concept_Module_ShelfPlace.
        head_input_dim = feature_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the last two parameter tensors of the head by 100x
        # (presumably the output layer's weight and bias) so the freshly
        # initialised policy emits small actions.
        tail_params = list(self.policy_head.parameters())[-2:]
        for tensor in tail_params:
            tensor.data = 1e-2 * tensor.data

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot states.

        ``images`` and ``texts`` are accepted but not used by this actor.
        """
        # The raw point cloud is normalized before it reaches the encoder.
        scene_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(scene_emb, robot_states)
        # Concatenate along dim 1 (the feature axis, assuming batch-first
        # tensors -- confirm).
        fused = torch.cat((scene_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)
class Dense_Concept_SweepInto(Actor):
    """Point-cloud actor conditioned on the dense SweepInto concept module.

    The normalized point cloud is encoded. The embedding and the robot state
    are passed through ``Dense_Concept_Module_SweepInto``. The encoder
    embedding, the concept feature and the state returned by the concept
    module are then concatenated and decoded into actions by a
    ``BatchNormMLP`` policy head.

    Args:
        point_cloud_encoder: Encoder module; must expose ``feature_dim``.
        robot_state_dim: Width of the robot-state input.
        action_dim: Output width of the policy head.
        policy_hidden_dims: Hidden layer sizes of the policy head.
        nonlinearity: Activation name forwarded to ``BatchNormMLP``.
        dropout_rate: Dropout rate forwarded to ``BatchNormMLP``.
        concept_para_list: Configuration forwarded to the concept module;
            entry 4 also contributes to the policy-head input width.
    """

    def __init__(
        self,
        point_cloud_encoder: nn.Module,
        robot_state_dim: int,
        action_dim: int,
        policy_hidden_dims: List[int],
        nonlinearity: str,
        dropout_rate: float,
        concept_para_list: List[int],
    ):
        super().__init__()
        feature_dim = point_cloud_encoder.feature_dim
        self.point_cloud_encoder = point_cloud_encoder
        self.concept = Dense_Concept_Module_SweepInto(feature_dim, concept_para_list)
        # NOTE(review): presumably concept_para_list[4] is the concept feature
        # width and the extra 6 columns come from the state returned by the
        # concept module -- confirm against Dense_Concept_Module_SweepInto.
        head_input_dim = feature_dim + robot_state_dim + concept_para_list[4] + 6
        self.policy_head = BatchNormMLP(
            input_dim=head_input_dim,
            hidden_dims=policy_hidden_dims,
            output_dim=action_dim,
            nonlinearity=nonlinearity,
            dropout_rate=dropout_rate,
        )
        # Shrink the last two parameter tensors of the head by 100x
        # (presumably the output layer's weight and bias) so the freshly
        # initialised policy emits small actions.
        tail_params = list(self.policy_head.parameters())[-2:]
        for tensor in tail_params:
            tensor.data = 1e-2 * tensor.data

    def forward(self, images, point_clouds, robot_states, texts):
        """Predict actions from point clouds and robot states.

        ``images`` and ``texts`` are accepted but not used by this actor.
        """
        # The raw point cloud is normalized before it reaches the encoder.
        scene_emb = self.point_cloud_encoder(PointCloud.normalize(point_clouds))
        concept_feat, state_feat = self.concept(scene_emb, robot_states)
        # Concatenate along dim 1 (the feature axis, assuming batch-first
        # tensors -- confirm).
        fused = torch.cat((scene_emb, concept_feat, state_feat), dim=1)
        return self.policy_head(fused)