import torch
import torch.nn as nn 
import math 
from StructPolicy.StructGen.utils import rotate
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder, PN_Structure_Encoder, Fusion_Structure_Encoder, PT_Structure_Encoder
from StructPolicy.StructGen.Uniform_Structure import Uniform_Cuboid, Uniform_Cylinder, Uniform_Rectangular_Ring, Uniform_Handle

def Rectangular_Ring(inner, outer, height, position, rotation, Sema):
    """
    Return the 16 corner keypoints of a rectangular ring, each tagged with a label.

    The first 8 keypoints are the corners of the outer box and the last 8 are
    the corners of the inner box. Both boxes share the same height.

    :param inner[B,2]: [0]=inner_length (local x), [1]=inner_width (local y)
    :param outer[B,2]: [0]=outer_length (local x), [1]=outer_width (local y)
    :height[B,1]: height of the ring (local z)
    :position[B,3], rotation[B,3]: the scaled corners are passed through
        ``rotate(..., rotation)`` and then translated by ``position``
    :Sema: semantic label written into the last channel
    :return: [B, 16, 4] tensor; columns 0-2 are xyz, column 3 is ``Sema``
    """
    batch = inner.shape[0]

    # Unit-cube corners, listed once for the outer box and once more for the inner box.
    unit_corners = inner.new_tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ]).repeat(2, 1)  # [16, 3]

    outer_dims = torch.cat([outer[:, 0:2], height], dim=-1)  # [B, 3]
    inner_dims = torch.cat([inner[:, 0:2], height], dim=-1)  # [B, 3]
    # [B, 2, 3] -> [B, 16, 3]: 8 rows of outer dims followed by 8 rows of inner dims.
    scale = torch.stack([outer_dims, inner_dims], dim=1).repeat_interleave(8, dim=1)

    scaled = unit_corners.expand(batch, -1, -1) * scale
    keypoints = rotate(scaled, rotation) + position.unsqueeze(1)
    labels = inner.new_ones((batch, 16, 1)) * Sema
    return torch.cat((keypoints, labels), dim=-1)

def Handle(inner, outer, height, position, rotation, Sema):
    """
    Return the 16 keypoints of a handle primitive, each tagged with a label.

    Local axes: length -> x, height -> y, width -> z.
      * keypoints 0-3  (template z = +0.5) are scaled by (outer_length, height, outer_width)
      * keypoints 4-7  (template z = -0.5) are scaled by (outer_length, height, inner_width)
      * keypoints 8-15 (a full unit cube)  are scaled by (inner_length, height, inner_width)
    Presumably this traces a U-shaped handle; confirm against the renderer.

    :param inner[B,2]: [0]=inner_length, [1]=inner_width
    :param outer[B,2]: [0]=outer_length, [1]=outer_width
    :height[B,1]: extent of the handle along local y
    :position[B,3], rotation[B,3]: the scaled keypoints are passed through
        ``rotate(..., rotation)`` and then translated by ``position``
    :Sema: semantic label written into the last channel
    :return: [B, 16, 4] tensor; columns 0-2 are xyz, column 3 is ``Sema``
    """
    batch = inner.shape[0]

    unit_points = inner.new_tensor([
        # outer keypoints: z = +0.5 group, then z = -0.5 group
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
        # inner box corners
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ])  # [16, 3]

    outer_len, inner_len = outer[:, 0:1], inner[:, 0:1]
    outer_wid, inner_wid = outer[:, 1:2], inner[:, 1:2]
    top = torch.cat([outer_len, height, outer_wid], dim=-1)     # [B, 3]
    bottom = torch.cat([outer_len, height, inner_wid], dim=-1)  # [B, 3]
    core = torch.cat([inner_len, height, inner_wid], dim=-1)    # [B, 3]

    # [B, 4 + 4 + 8, 3] per-keypoint scale factors.
    scale = torch.cat(
        [dims.unsqueeze(1).expand(-1, count, -1)
         for dims, count in ((top, 4), (bottom, 4), (core, 8))],
        dim=1,
    )

    scaled = unit_points.expand(batch, -1, -1) * scale
    keypoints = rotate(scaled, rotation) + position.unsqueeze(1)
    labels = inner.new_ones((batch, 16, 1)) * Sema
    return torch.cat((keypoints, labels), dim=-1)

def Cuboid(size, position, rotation, Sema):
    """
    Return the 8 corner keypoints of a scaled, rotated and translated box.

    :size[B,3]: edge lengths along local x, y, z
    :position[B,3], rotation[B,3]: the scaled corners are passed through
        ``rotate(..., rotation)`` and then translated by ``position``
    :Sema: semantic label written into the last channel
    :return: [B, 8, 4] tensor; columns 0-2 are xyz, column 3 is ``Sema``
    """
    batch = size.shape[0]

    unit_corners = position.new_tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ])  # [8, 3]

    # [8, 3] * [B, 1, 3] broadcasts to [B, 8, 3].
    scaled = unit_corners * size.unsqueeze(1)
    keypoints = rotate(scaled, rotation) + position.unsqueeze(1)
    labels = position.new_ones((batch, 8, 1)) * Sema
    return torch.cat((keypoints, labels), dim=-1)

def Structure_BoxClose(paras):
    """
    Build the keypoints of a box (ring + bottom cuboid) and its cover
    (cuboid + handle), together with a grasp pose on the handle.

    :paras[:, 0:14]:    parameters of Box
        0:3   ring centre position         3:6   box rotation
        6:8   ring outer (length, width)   8:10  ring wall thickness (per side)
        10:11 ring height                  11:14 bottom cuboid size (x, y, z)
    :paras[:, 14:28]:   parameters of Cover
        14:17 cover cuboid size (x, y, z)  17:20 cover cuboid position
        20:23 cover rotation               23:25 handle outer (length, width)
        25:27 handle thickness (per side)  27:28 handle height
    :return: (Affordance [B, 48, 4], Grasp_pose [B, 6])
        Affordance rows are ring (16 pts, label 1), bottom (8, label 2),
        cover (8, label 3), handle (16, label 4); columns are xyz + label.
        Grasp_pose is the handle position followed by the vector (0, 0, -1)
        passed through ``rotate`` with the cover rotation.
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Box
    position_Ring = paras[:, 0:3]
    rotation_Box = paras[:, 3:6]
    outer_size = paras[:, 6:8]
    inner_size = outer_size - 2 * paras[:, 8:10]
    height_Ring = paras[:, 10:11]
    vertices_Ring = Rectangular_Ring(inner_size, outer_size, height_Ring, position_Ring, rotation_Box, 1)

    size_under = paras[:, 11:14]
    # The bottom cuboid sits directly below the ring along z. The drop is [B, 1],
    # so it must be placed in the z column only; subtracting it from the [B, 3]
    # position directly would broadcast over x and y as well.
    drop_under = height_Ring * 0.5 + size_under[:, -1:] * 0.5
    position_under = position_Ring - torch.cat([zeros, zeros, drop_under], dim=-1)
    # NOTE(review): like the handle offset below, this offset is not passed
    # through rotate(); assuming rotate() rotates about the local origin, the
    # parts only stay attached when the rotation is zero — confirm intended.
    vertices_Under = Cuboid(size_under, position_under, rotation_Box, 2)

    # Cover
    size_Cuboid = paras[:, 14:17]
    position_Cuboid = paras[:, 17:20]
    rotation_Cover = paras[:, 20:23]
    vertices_Cuboid = Cuboid(size_Cuboid, position_Cuboid, rotation_Cover, 3)

    outer_Handle = paras[:, 23:25]
    inner_Handle = outer_Handle - 2 * paras[:, 25:27]
    # Handle centre: top face of the cover cuboid, raised by the inner handle
    # width (which Handle() maps to local z).
    # NOTE(review): Handle's lowest keypoints sit at -inner_width/2 in local z,
    # so this leaves a gap of inner_width/2 above the cover — confirm intended.
    position_handle = position_Cuboid + torch.cat([zeros, zeros, 0.5 * size_Cuboid[:, -1:]], dim=-1) \
                                        + torch.cat([zeros, zeros, inner_Handle[:, -1:]], dim=-1)
    height_Handle = paras[:, 27:28]
    vertices_Handle = Handle(inner_Handle, outer_Handle, height_Handle, position_handle, rotation_Cover, 4)

    # Approach direction: local -z passed through the cover rotation.
    direction = rotate(torch.cat([zeros, zeros, -1 * ones], dim=-1).unsqueeze(dim=1), rotation_Cover).reshape(B, 3)
    Grasp_pose = torch.cat([position_handle, direction], dim=-1)

    Affordance = torch.cat([vertices_Ring, vertices_Under, vertices_Cuboid, vertices_Handle], dim=1)

    return Affordance, Grasp_pose

def process_paras(raw):
    """
    Squash raw Para_Estimator output into bounded geometric parameters.

    Signed quantities (positions, rotations) go through ``tanh`` and positive
    quantities (sizes, thicknesses, heights) go through ``sigmoid``. Each slice
    is then affinely rescaled into the interval noted beside it. The column
    layout matches the slices read by ``Structure_BoxClose``.

    :param raw: tensor of shape [B, N] with N >= 28, unrestricted range.
        Columns beyond the 28th are ignored.
    :return: tensor of shape [B, 28].
    :raises ValueError: if ``raw`` is not 2-D or has fewer than 28 columns.

    NOTE(review): these ranges do not guarantee positive inner sizes downstream.
    For example, an outer size of 0.05 with a ring thickness of 0.05 gives
    0.05 - 2 * 0.05 = -0.05; the handle has the same issue (outer 0.01 vs
    thickness 0.05).
    """
    if raw.dim() != 2 or raw.shape[1] < 28:
        # Fewer columns would silently yield empty slices and a short output.
        raise ValueError(
            f"process_paras expects a [B, >=28] tensor, got shape {tuple(raw.shape)}"
        )

    # -------- Box --------
    # 0-3: Box position ∈ [-3, 3]
    box_pos = torch.tanh(raw[:, 0:3]) * 3

    # 3-6: Box rotation ∈ [-π, π]
    box_rot = torch.tanh(raw[:, 3:6]) * math.pi

    # 6-8: ring outer size (length -> x, width -> y) ∈ [0.05, 0.3]
    outer_size = torch.sigmoid(raw[:, 6:8]) * 0.25 + 0.05

    # 8-10: ring wall thickness per side ∈ [0.005, 0.05]
    ring_thickness = torch.sigmoid(raw[:, 8:10]) * 0.045 + 0.005

    # 10-11: ring height ∈ [0.05, 0.2]
    ring_height = torch.sigmoid(raw[:, 10:11]) * 0.15 + 0.05

    # 11-14: bottom box size ∈ [0.05, 0.3]
    box_bottom_size = torch.sigmoid(raw[:, 11:14]) * 0.25 + 0.05

    # -------- Cover --------
    # 14-17: cover cuboid size ∈ [0.05, 0.3]
    cover_size = torch.sigmoid(raw[:, 14:17]) * 0.25 + 0.05

    # 17-20: cover position ∈ [-3, 3]
    cover_pos = torch.tanh(raw[:, 17:20]) * 3

    # 20-23: cover rotation ∈ [-π, π]
    cover_rot = torch.tanh(raw[:, 20:23]) * math.pi

    # 23-25: handle outer size ∈ [0.01, 0.1]
    handle_outer = torch.sigmoid(raw[:, 23:25]) * 0.09 + 0.01

    # 25-27: handle thickness per side ∈ [0.005, 0.05]
    handle_thickness = torch.sigmoid(raw[:, 25:27]) * 0.045 + 0.005

    # 27-28: handle height ∈ [0.01, 0.1]
    handle_height = torch.sigmoid(raw[:, 27:28]) * 0.09 + 0.01

    # Concatenate in the same order as the column indices above.
    return torch.cat([
        box_pos, box_rot,
        outer_size, ring_thickness, ring_height,
        box_bottom_size,
        cover_size, cover_pos, cover_rot,
        handle_outer, handle_thickness, handle_height,
    ], dim=-1)

class Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, encode the resulting
    keypoints with Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                    Structure_attension_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BoxClose(bounded)
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, encode the resulting
    keypoints with PN_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, _, _, Structure_out_dim];
    entries 2 and 3 are not used by this variant.
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        estimator_hidden_dim = paralist[0]
        Geo_para_dim = paralist[1]
        Structure_out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.Structure_encoder = PN_Structure_Encoder(Structure_out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        paras = process_paras(self.para_estimator(feature))
        Structure, grasp_pose = Structure_BoxClose(paras)
        Structure_emb = self.Structure_encoder(Structure)
        robot_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return Structure_emb, robot_state

class Fusion_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, encode the resulting
    keypoints with Fusion_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                    Structure_xyz_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[i] for i in range(5))
        # Keypoint count emitted by Structure_BoxClose: ring 16 + bottom 8 + cover 8 + handle 16.
        num_points = 16 + 8 + 8 + 16
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BoxClose(bounded)
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)

class PT_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, encode the resulting
    keypoints with PT_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                    Structure_attension_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BoxClose(bounded)
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)

#######################################################################

# Uniform Sampling Representation

#######################################################################    

def Uniform_Structure_BoxClose(paras):
    """
    Uniformly-sampled variant of Structure_BoxClose: build point samples for a
    box (ring + bottom cuboid) and its cover (cuboid + handle), together with a
    grasp pose on the handle.

    :paras[:, 0:14]:    parameters of Box
        0:3   ring centre position         3:6   box rotation
        6:8   ring outer (length, width)   8:10  ring wall thickness (per side)
        10:11 ring height                  11:14 bottom cuboid size (x, y, z)
    :paras[:, 14:28]:   parameters of Cover
        14:17 cover cuboid size (x, y, z)  17:20 cover cuboid position
        20:23 cover rotation               23:25 handle outer (length, width)
        25:27 handle thickness (per side)  27:28 handle height
    :return: (Affordance, Grasp_pose [B, 6])
        Affordance concatenates, along dim 1, the samples of the ring (Sem=1),
        bottom (Sem=2), cover (Sem=3) and handle (Sem=4), each requested with N=40.
        Grasp_pose is the handle position followed by the vector (0, 0, -1)
        passed through ``rotate`` with the cover rotation.
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Box
    position_Ring = paras[:, 0:3]
    rotation_Box = paras[:, 3:6]
    outer_size = paras[:, 6:8]
    inner_size = outer_size - 2 * paras[:, 8:10]
    height_Ring = paras[:, 10:11]
    vertices_Ring = Uniform_Rectangular_Ring(inner_size, outer_size, height_Ring, position_Ring, rotation_Box, Sem=1, order=2, N=40)

    size_under = paras[:, 11:14]
    # The bottom cuboid sits directly below the ring along z. The drop is [B, 1],
    # so it must be placed in the z column only; subtracting it from the [B, 3]
    # position directly would broadcast over x and y as well.
    drop_under = height_Ring * 0.5 + size_under[:, -1:] * 0.5
    position_under = position_Ring - torch.cat([zeros, zeros, drop_under], dim=-1)
    # NOTE(review): like the handle offset below, this offset is not passed
    # through rotate(); assuming rotate() rotates about the local origin, the
    # parts only stay attached when the rotation is zero — confirm intended.
    vertices_Under = Uniform_Cuboid(size_under, position_under, rotation_Box, Sem=2, N=40)

    # Cover
    size_Cuboid = paras[:, 14:17]
    position_Cuboid = paras[:, 17:20]
    rotation_Cover = paras[:, 20:23]
    vertices_Cuboid = Uniform_Cuboid(size_Cuboid, position_Cuboid, rotation_Cover, Sem=3, N=40)

    outer_Handle = paras[:, 23:25]
    inner_Handle = outer_Handle - 2 * paras[:, 25:27]
    # Handle centre: top face of the cover cuboid, raised by the inner handle width.
    position_handle = position_Cuboid + torch.cat([zeros, zeros, 0.5 * size_Cuboid[:, -1:]], dim=-1) \
                                        + torch.cat([zeros, zeros, inner_Handle[:, -1:]], dim=-1)
    height_Handle = paras[:, 27:28]
    vertices_Handle = Uniform_Handle(inner_Handle, outer_Handle, height_Handle, position_handle, rotation_Cover, Sem=4, order=1, N=40)

    # Approach direction: local -z passed through the cover rotation.
    direction = rotate(torch.cat([zeros, zeros, -1 * ones], dim=-1).unsqueeze(dim=1), rotation_Cover).reshape(B, 3)
    Grasp_pose = torch.cat([position_handle, direction], dim=-1)

    Affordance = torch.cat([vertices_Ring, vertices_Under, vertices_Cuboid, vertices_Handle], dim=1)

    return Affordance, Grasp_pose

class Uniform_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, sample it uniformly, encode
    it with Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                    Structure_attension_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        # NOTE(review): the raw estimator output is used directly; process_paras
        # is not applied in this variant — confirm this is intended.
        structure, grasp_pose = Uniform_Structure_BoxClose(self.para_estimator(feature))
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Uniform_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, sample it uniformly, encode
    it with PN_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, _, _, Structure_out_dim];
    entries 2 and 3 are not used by this variant.
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        estimator_hidden_dim = paralist[0]
        Geo_para_dim = paralist[1]
        Structure_out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, estimator_hidden_dim, Geo_para_dim)
        self.Structure_encoder = PN_Structure_Encoder(Structure_out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        # NOTE(review): the raw estimator output is used directly; process_paras
        # is not applied in this variant — confirm this is intended.
        paras = self.para_estimator(feature)
        Structure, grasp_pose = Uniform_Structure_BoxClose(paras)
        Structure_emb = self.Structure_encoder(Structure)
        robot_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return Structure_emb, robot_state
    
class Fusion_Uniform_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, sample it uniformly, encode
    it with Fusion_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                    Structure_xyz_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[i] for i in range(5))
        # Presumably 4 primitives x N=40 samples each — confirm against Uniform_* outputs.
        num_points = 40 * 4
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        # NOTE(review): the raw estimator output is used directly; process_paras
        # is not applied in this variant — confirm this is intended.
        structure, grasp_pose = Uniform_Structure_BoxClose(self.para_estimator(feature))
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)

class PT_Uniform_Structure_Module_BoxClose(nn.Module):
    """
    Predict box-with-cover geometry from a feature, sample it uniformly, encode
    it with PT_Structure_Encoder, and prepend the grasp pose to robot_state.

    ``paralist`` = [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                    Structure_attension_dim, Structure_out_dim]
    """
    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[i] for i in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with the grasp pose prepended)."""
        # NOTE(review): the raw estimator output is used directly; process_paras
        # is not applied in this variant — confirm this is intended.
        structure, grasp_pose = Uniform_Structure_BoxClose(self.para_estimator(feature))
        return self.Structure_encoder(structure), torch.cat((grasp_pose, robot_state), dim=-1)