import torch
import torch.nn as nn 
import math 
from lift3d.models.concept.MetaWorld.utils import rotate
from lift3d.models.concept.MetaWorld.utils import Para_Estimator, Knowledge_Encoder, PN_Knowledge_Encoder, PT_Knowledge_Encoder, Fusion_Knowledge_Encoder
from lift3d.models.concept.MetaWorld.Dense_Concept import Dense_Cylinder, Dense_Cuboid

def Cylinder (height, radius, position, rotation, theta, Sem, order):
    """
    Build the rim vertices of a cylinder (bottom ring, then top ring) and
    append a semantic label to every vertex.

    :param height: [B, 1], length of the cylinder along its own axis
    :param radius: [B, 1], radius of the circular cross-section
    :param position: [B, 3], translation added to every vertex
    :param rotation: [B, 3], forwarded unchanged to ``rotate``
    :param theta: angular step in degrees between neighbouring rim vertices
    :param Sem: semantic label written into the 4th channel
    :param order: cylinder axis in the local frame: 1 -> Z, 2 -> Y,
                  any other value -> X
    :return: [B, 2n, 4] with n = int(360 / theta); xyz followed by Sem
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    # Unit circle sampled every `theta` degrees, shared by both end rings.
    ring = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]

    # The +-0.5 column is the cylinder axis. The scale vector must put
    # `height` on that same column and `radius` on the two ring columns;
    # using [radius, radius, height] for every order would stretch the
    # cross-section by `height` and the axis by `radius` when order != 1.
    if order == 1:
        V = [[c, s, end] for end in (-0.5, 0.5) for c, s in ring]
        scale = [radius, radius, height]
    elif order == 2:
        V = [[c, end, s] for end in (-0.5, 0.5) for c, s in ring]
        scale = [radius, height, radius]
    else:
        V = [[end, c, s] for end in (-0.5, 0.5) for c, s in ring]
        scale = [height, radius, radius]

    vertices = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B, 1, 1)  # shape [B, 2n, 3]
    deformation = torch.cat(scale, dim=1).unsqueeze(dim=1)  # shape [B, 1, 3], broadcast over vertices
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B, 2 * n, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Cuboid(size, position, rotation, Sema):
    """
    Produce the 8 corner vertices of a box, each tagged with a semantic label.

    :param size: [B, 3], edge lengths along x, y, z
    :param position: [B, 3], translation added to every corner
    :param rotation: [B, 3], forwarded unchanged to ``rotate``
    :param Sema: semantic label stored in the 4th channel
    :return: [B, 8, 4], corner xyz followed by the label
    """
    batch = size.shape[0]
    half = 1 / 2

    # Corners of the unit cube centred at the origin. The sign of y varies
    # slowest, then z, then x, which fixes the vertex order of the output.
    unit_corners = torch.tensor(
        [
            [sx * half, sy * half, sz * half]
            for sy in (1, -1)
            for sz in (1, -1)
            for sx in (-1, 1)
        ],
        device=position.device,
        dtype=position.dtype,
    )

    # One copy of the edge lengths per corner: [B, 8, 3]
    per_corner_scale = size[:, None, :].expand(-1, 8, -1)

    placed = rotate(unit_corners * per_corner_scale, rotation) + position[:, None, :]
    labels = torch.ones([batch, 8, 1], device=position.device, dtype=position.dtype) * Sema

    return torch.cat((placed, labels), dim=-1)

def Concept_PegUnplugSide (paras):
    """
    Assemble the sparse PegUnplugSide concept: a peg (cylinder along local Y,
    label 1) and a box (cuboid, label 2), plus a grasp pose.

    :param paras: [B, 11]; column 0 peg radius, 1 peg height,
                  2:5 peg position, 5:8 rotation, 8:11 box size
    :return: (Affordance, Grasp_pose)
             Affordance: peg vertices followed by box vertices, joined on dim 1
             Grasp_pose: [B, 6], peg position followed by the rotated +Y direction
    """
    batch = paras.shape[0]
    col_zero = paras.new_zeros([batch, 1])
    col_one = paras.new_ones([batch, 1])

    peg_radius, peg_height = paras[:, 0:1], paras[:, 1:2]
    peg_center, euler = paras[:, 2:5], paras[:, 5:8]
    box_size = paras[:, 8:11]

    # Peg: 30-degree angular step, semantic label 1, axis order 2
    peg_vertices = Cylinder(peg_height, peg_radius, peg_center, euler, 30, 1, 2)

    # Grasp direction: the unit +Y vector carried through the peg rotation
    y_axis = torch.cat([col_zero, col_one, col_zero], dim=-1).unsqueeze(dim=1)
    grasp_dir = rotate(y_axis, euler).view([batch, 3])
    grasp_pose = torch.cat([peg_center, grasp_dir], dim=-1)

    # Box: centre shifted along Y by half the peg height minus half the box
    # Y-extent (the shift is added after, not rotated with, the peg rotation)
    y_shift = 0.5 * peg_height - 0.5 * box_size[:, 1:2]
    box_center = peg_center + torch.cat([col_zero, y_shift, col_zero], dim=-1)
    box_vertices = Cuboid(box_size, box_center, euler, 2)

    return torch.cat([peg_vertices, box_vertices], dim=1), grasp_pose

def process_paras(raw):
    """
    Squash raw estimator output into bounded PegUnplugSide parameters.

    Input: [B, 11]
    Output: [B, 11], laid out as
        0     radius   in [0.005, 0.05]
        1     height   in [0.05, 0.2]
        2:5   position in [-3, 3]^3
        5:8   rotation in [-pi, pi]^3
        8:11  box size in [0.05, 0.3]^3
    """
    def bounded(cols, low, span):
        # sigmoid maps to (0, 1); rescale into [low, low + span]
        return torch.sigmoid(cols) * span + low

    def symmetric(cols, limit):
        # tanh maps to (-1, 1); rescale into [-limit, limit]
        return torch.tanh(cols) * limit

    normalized = (
        bounded(raw[:, 0:1], 0.005, 0.045),   # radius
        bounded(raw[:, 1:2], 0.05, 0.15),     # height
        symmetric(raw[:, 2:5], 3.0),          # position
        symmetric(raw[:, 5:8], math.pi),      # rotation
        bounded(raw[:, 8:11], 0.05, 0.25),    # box size
    )
    return torch.cat(normalized, dim=-1)
    
class Concept_Module_PegUnplugSide(nn.Module):
    """
    Sparse PegUnplugSide concept module: estimates geometry parameters from a
    feature, builds the peg/box vertex set and encodes it with
    ``Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge feature dim, [3] knowledge attention dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # Raw estimate -> bounded parameters -> concept vertices + grasp pose
        geo_params = process_paras(self.para_estimator(feature))
        knowledge, grasp_pose = Concept_PegUnplugSide(geo_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state
    
class PN_Concept_Module_PegUnplugSide(nn.Module):
    """
    Sparse PegUnplugSide concept module whose vertex set is encoded with
    ``PN_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable; [0] estimator hidden dim,
                         [1] geometry parameter dim, [4] concept output dim.
                         Entries [2] and [3] are not used by this variant.
        """
        super().__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # Raw estimate -> bounded parameters -> concept vertices + grasp pose
        geo_params = process_paras(self.para_estimator(feature))
        knowledge, grasp_pose = Concept_PegUnplugSide(geo_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state

class Fusion_Concept_Module_PegUnplugSide(nn.Module):
    """
    Sparse PegUnplugSide concept module whose vertex set is encoded with
    ``Fusion_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge semantic dim, [3] knowledge xyz dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        # 24 peg vertices (two rings at a 30-degree step) + 8 box corners
        num_vertices = 24 + 8
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, num_vertices, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # Raw estimate -> bounded parameters -> concept vertices + grasp pose
        geo_params = process_paras(self.para_estimator(feature))
        knowledge, grasp_pose = Concept_PegUnplugSide(geo_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state
    
class PT_Concept_Module_PegUnplugSide(nn.Module):
    """
    Sparse PegUnplugSide concept module whose vertex set is encoded with
    ``PT_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge feature dim, [3] knowledge attention dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # Raw estimate -> bounded parameters -> concept vertices + grasp pose
        geo_params = process_paras(self.para_estimator(feature))
        knowledge, grasp_pose = Concept_PegUnplugSide(geo_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state
    
#######################################################################

# Dense point cloud in concept space 

#######################################################################

def Dense_Concept_PegUnplugSide (paras):
    """
    Assemble the dense PegUnplugSide concept: a peg from ``Dense_Cylinder``
    (label 1) and a box from ``Dense_Cuboid`` (label 2), plus a grasp pose.

    :param paras: [B, 11]; column 0 peg radius, 1 peg height,
                  2:5 peg position, 5:8 rotation, 8:11 box size
    :return: (Affordance, Grasp_pose)
             Affordance: peg points followed by box points, joined on dim 1
             Grasp_pose: [B, 6], peg position followed by the rotated +Y direction
    """
    batch = paras.shape[0]
    col_zero = paras.new_zeros([batch, 1])
    col_one = paras.new_ones([batch, 1])

    peg_radius, peg_height = paras[:, 0:1], paras[:, 1:2]
    peg_center, euler = paras[:, 2:5], paras[:, 5:8]
    box_size = paras[:, 8:11]

    # Peg: 30-degree angular step, semantic label 1, axis order 2, N=40
    peg_points = Dense_Cylinder(peg_height, peg_radius, peg_center, euler, 30, Sem=1, order=2, N=40)

    # Grasp direction: the unit +Y vector carried through the peg rotation
    y_axis = torch.cat([col_zero, col_one, col_zero], dim=-1).unsqueeze(dim=1)
    grasp_dir = rotate(y_axis, euler).view([batch, 3])
    grasp_pose = torch.cat([peg_center, grasp_dir], dim=-1)

    # Box: centre shifted along Y by half the peg height minus half the box
    # Y-extent (the shift is added after, not rotated with, the peg rotation)
    y_shift = 0.5 * peg_height - 0.5 * box_size[:, 1:2]
    box_center = peg_center + torch.cat([col_zero, y_shift, col_zero], dim=-1)
    box_points = Dense_Cuboid(box_size, box_center, euler, Sem=2, N=40)

    return torch.cat([peg_points, box_points], dim=1), grasp_pose

class Dense_Concept_Module_PegUnplugSide(nn.Module):
    """
    Dense PegUnplugSide concept module: estimates geometry parameters from a
    feature, builds the dense peg/box point set and encodes it with
    ``Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge feature dim, [3] knowledge attention dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # NOTE(review): the estimator output is used as-is here; the
        # process_paras normalization is not applied - confirm this is intended.
        raw_params = self.para_estimator(feature)
        knowledge, grasp_pose = Dense_Concept_PegUnplugSide(raw_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state
    
class PN_Dense_Concept_Module_PegUnplugSide(nn.Module):
    """
    Dense PegUnplugSide concept module whose point set is encoded with
    ``PN_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable; [0] estimator hidden dim,
                         [1] geometry parameter dim, [4] concept output dim.
                         Entries [2] and [3] are not used by this variant.
        """
        super().__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # NOTE(review): the estimator output is used as-is here; the
        # process_paras normalization is not applied - confirm this is intended.
        raw_params = self.para_estimator(feature)
        knowledge, grasp_pose = Dense_Concept_PegUnplugSide(raw_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state

class Fusion_Dense_Concept_Module_PegUnplugSide(nn.Module):
    """
    Dense PegUnplugSide concept module whose point set is encoded with
    ``Fusion_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge semantic dim, [3] knowledge xyz dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        # presumably one term per primitive, matching N=40 passed for the peg
        # and for the box in Dense_Concept_PegUnplugSide - TODO confirm
        num_points = 40 + 40
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # NOTE(review): the estimator output is used as-is here; the
        # process_paras normalization is not applied - confirm this is intended.
        raw_params = self.para_estimator(feature)
        knowledge, grasp_pose = Dense_Concept_PegUnplugSide(raw_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state
    
class PT_Dense_Concept_Module_PegUnplugSide(nn.Module):
    """
    Dense PegUnplugSide concept module whose point set is encoded with
    ``PT_Knowledge_Encoder``.
    """
    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to ``Para_Estimator``
        :param paralist: indexable with at least 5 entries:
                         [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] knowledge feature dim, [3] knowledge attention dim,
                         [4] concept output dim
        """
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (concept_emb, robot_state) where the returned robot_state is
                 the grasp pose concatenated in front of the incoming
                 robot_state along the last axis
        """
        # NOTE(review): the estimator output is used as-is here; the
        # process_paras normalization is not applied - confirm this is intended.
        raw_params = self.para_estimator(feature)
        knowledge, grasp_pose = Dense_Concept_PegUnplugSide(raw_params)
        concept_emb = self.knowledge_encoder(knowledge)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return concept_emb, augmented_state