import torch
import torch.nn as nn 
import math 
from lift3d.models.concept.MetaWorld.utils import rotate
from lift3d.models.concept.MetaWorld.utils import Para_Estimator, Knowledge_Encoder, PN_Knowledge_Encoder, PT_Knowledge_Encoder, Fusion_Knowledge_Encoder
from lift3d.models.concept.MetaWorld.Dense_Concept import Dense_Cuboid

def Cuboid(size, position, rotation, Sema):
    """
    Build the 8 labelled corner points of a batch of cuboids.

    :size[B,3]: [0]=length, [1]=width, [2]=height
    :position[B,3]: offset added to every corner after rotation
    :rotation[B,3]: forwarded unchanged to ``rotate``; the angle convention is
        whatever that helper defines (presumably Euler angles -- TODO confirm)
    :Sema: semantic label written into the last channel of every corner
    :return: corners with the label appended, i.e. [B,8,4] assuming ``rotate``
        preserves the [B,8,3] shape of its first argument
    """
    batch = size.shape[0]
    half = 1 / 2

    # Corners of a unit cube centred on the origin.
    # Ordering: x flips fastest, then z, then y.
    unit_corners = torch.tensor(
        [
            [x, y, z]
            for y in (half, -half)
            for z in (half, -half)
            for x in (-half, half)
        ],
        device=position.device,
        dtype=position.dtype,
    )

    # [1,8,3] * [B,1,3] broadcasts to [B,8,3]: every cuboid stretches the
    # unit cube by its own (length, width, height).
    scaled = unit_corners.unsqueeze(dim=0) * size.unsqueeze(dim=1)
    corners = rotate(scaled, rotation) + position.unsqueeze(dim=1)

    # Constant semantic channel, same device/dtype as `position`.
    label = position.new_ones([batch, 8, 1]) * Sema

    return torch.cat((corners, label), dim=-1)

def Concept_SweepInto(paras):
    """
    Assemble the sparse (corner-only) concept geometry for the sweep-into task.

    paras[:, 0:9]   parameters of the block cuboid: size, position, rotation
    paras[:, 9:18]  parameters of the hole cuboid:  size, position, rotation

    Returns a pair (Affordance, Grasp_pose):
      Affordance  block corners (semantic label 1) followed by hole corners
                  (semantic label 2), concatenated along dim 1
      Grasp_pose  [B,6]: the block position paras[:, 3:6] followed by the
                  constant direction (0, 0, -1)
    """
    batch = paras.shape[0]

    block_corners = Cuboid(paras[:, 0:3], paras[:, 3:6], paras[:, 6:9], 1)
    hole_corners = Cuboid(paras[:, 9:12], paras[:, 12:15], paras[:, 15:18], 2)
    affordance = torch.cat([block_corners, hole_corners], dim=1)

    # Fixed direction (0, 0, -1) for every batch element.
    approach = paras.new_zeros([batch, 3])
    approach[:, 2] = -1
    grasp_pose = torch.cat([paras[:, 3:6], approach], dim=-1)

    return affordance, grasp_pose

def process_paras(raw):
    """
    Normalize raw output from Para_Estimator into structured and bounded physical parameters.

    Input: raw tensor of shape [..., 18] (typically [B, 18]), unrestricted range.
        Columns 0:9 describe the block and 9:18 the hole; each group is laid
        out as size(3), position(3), rotation(3). Columns past 18 are ignored.
    Output: tensor with the same leading shape and 18 columns, scaled as
        size      sigmoid(x) * 0.08 + 0.02    -> [0.02, 0.10]
        position  tanh(x) * 3.0               -> [-3, 3]
        rotation  tanh(x) * pi                -> [-pi, pi]
    Raises: ValueError if the last dimension has fewer than 18 entries.
    """
    if raw.shape[-1] < 18:
        raise ValueError(
            f"process_paras expects at least 18 parameters in the last "
            f"dimension, got {raw.shape[-1]}"
        )

    def _bound(group):
        # One 9-column group (block or hole): size, position, rotation.
        size = torch.sigmoid(group[..., 0:3]) * 0.08 + 0.02
        pos = torch.tanh(group[..., 3:6]) * 3.0
        rot = torch.tanh(group[..., 6:9]) * math.pi
        return size, pos, rot

    # `...` indexing keeps any leading batch dimensions intact.
    block = _bound(raw[..., 0:9])
    hole = _bound(raw[..., 9:18])

    return torch.cat([*block, *hole], dim=-1)

class Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module: parameter estimator -> process_paras ->
    sparse cuboid geometry (Concept_SweepInto) -> Knowledge_Encoder.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    Concept_SweepInto slices columns 0:18 of the estimated parameters, so
    Geo_para_dim (presumably the estimator's output width -- TODO confirm)
    needs to cover them.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_feature_dim,
            Knowledge_attension_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = Knowledge_Encoder(
            Knowledge_feature_dim, Knowledge_attension_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # Raw estimator output is squashed into bounded ranges first.
        bounded = process_paras(self.para_estimator(feature))
        affordance, grasp_pose = Concept_SweepInto(bounded)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using PN_Knowledge_Encoder on the sparse
    cuboid geometry produced by Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    Entries [2] and [3] are read but not used by this variant; the encoder
    is built from Concept_out_dim alone.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        # Slots 2 and 3 are still indexed so the expected paralist length
        # matches the other variants.
        estimator_hidden_dim, Geo_para_dim, _, _, Concept_out_dim = (
            paralist[i] for i in range(5)
        )
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(Concept_out_dim)

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class Fusion_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using Fusion_Knowledge_Encoder on the sparse
    cuboid geometry produced by Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_Semantic_dim    [3] Knowledge_xyz_dim
        [4] Concept_out_dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_Semantic_dim,
            Knowledge_xyz_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        # 8 block corners + 8 hole corners: the number of points in the
        # affordance returned by Concept_SweepInto (presumably what the
        # encoder's second argument expects -- TODO confirm).
        point_count = 8 + 8
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(
            Knowledge_Semantic_dim, point_count, Knowledge_xyz_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PT_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using PT_Knowledge_Encoder on the sparse
    cuboid geometry produced by Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_feature_dim,
            Knowledge_attension_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(
            Knowledge_feature_dim, Knowledge_attension_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
#######################################################################

# Dense point cloud in concept space 

#######################################################################

def Dense_Concept_SweepInto(paras):
    """
    Assemble the dense concept geometry for the sweep-into task.

    paras[:, 0:9]   parameters of the block cuboid: size, position, rotation
    paras[:, 9:18]  parameters of the hole cuboid:  size, position, rotation

    Both cuboids are produced by Dense_Cuboid with N=40 (presumably the
    number of sampled points per cuboid -- TODO confirm against that helper).

    Returns a pair (Affordance, Grasp_pose):
      Affordance  block points (Sem=1) followed by hole points (Sem=2),
                  concatenated along dim 1
      Grasp_pose  [B,6]: the block position paras[:, 3:6] followed by the
                  constant direction (0, 0, -1)
    """
    batch = paras.shape[0]

    block_points = Dense_Cuboid(paras[:, 0:3], paras[:, 3:6], paras[:, 6:9], Sem=1, N=40)
    hole_points = Dense_Cuboid(paras[:, 9:12], paras[:, 12:15], paras[:, 15:18], Sem=2, N=40)
    affordance = torch.cat([block_points, hole_points], dim=1)

    # Fixed direction (0, 0, -1) for every batch element.
    approach = paras.new_zeros([batch, 3])
    approach[:, 2] = -1
    grasp_pose = torch.cat([paras[:, 3:6], approach], dim=-1)

    return affordance, grasp_pose

class Dense_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using Knowledge_Encoder on the dense geometry
    produced by Dense_Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_feature_dim,
            Knowledge_attension_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = Knowledge_Encoder(
            Knowledge_feature_dim, Knowledge_attension_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Dense_Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Dense_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using PN_Knowledge_Encoder on the dense
    geometry produced by Dense_Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    Entries [2] and [3] are read but not used by this variant; the encoder
    is built from Concept_out_dim alone.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        # Slots 2 and 3 are still indexed so the expected paralist length
        # matches the other variants.
        estimator_hidden_dim, Geo_para_dim, _, _, Concept_out_dim = (
            paralist[i] for i in range(5)
        )
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(Concept_out_dim)

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Dense_Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class Fusion_Dense_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using Fusion_Knowledge_Encoder on the dense
    geometry produced by Dense_Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_Semantic_dim    [3] Knowledge_xyz_dim
        [4] Concept_out_dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_Semantic_dim,
            Knowledge_xyz_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        # 40 + 40 mirrors the N=40 passed to Dense_Cuboid for the block and
        # the hole in Dense_Concept_SweepInto (presumably the total point
        # count the encoder's second argument expects -- TODO confirm).
        point_count = 40 + 40
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(
            Knowledge_Semantic_dim, point_count, Knowledge_xyz_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Dense_Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PT_Dense_Concept_Module_SweepInto(nn.Module):
    """
    Sweep-into concept module using PT_Knowledge_Encoder on the dense
    geometry produced by Dense_Concept_SweepInto.

    paralist is read by index:
        [0] estimator_hidden_dim      [1] Geo_para_dim
        [2] Knowledge_feature_dim     [3] Knowledge_attension_dim
        [4] Concept_out_dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        (
            estimator_hidden_dim,
            Geo_para_dim,
            Knowledge_feature_dim,
            Knowledge_attension_dim,
            Concept_out_dim,
        ) = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(
            Knowledge_feature_dim, Knowledge_attension_dim, Concept_out_dim
        )

    def forward(self, feature, robot_state):
        """
        Return (concept_emb, robot_state) where the returned robot_state is
        the 6-value grasp pose concatenated in front of the incoming
        robot_state along the last dimension.
        """
        # NOTE(review): unlike Concept_Module_SweepInto.forward, the estimator
        # output is used here without process_paras -- confirm intentional.
        estimated = self.para_estimator(feature)
        affordance, grasp_pose = Dense_Concept_SweepInto(estimated)
        concept_emb = self.knowledge_encoder(affordance)
        return concept_emb, torch.cat((grasp_pose, robot_state), dim=-1)
    