import torch
import torch.nn as nn 
import math 
from StructPolicy.StructGen.utils import rotate
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder, PT_Structure_Encoder, PN_Structure_Encoder, Fusion_Structure_Encoder
from StructPolicy.StructGen.Uniform_Structure import Uniform_Rectangular_Ring

def Rectangular_Ring(inner, outer, height, position, rotation, Sema):
    """
    Corner keypoints of a hollow rectangular ring (an outer box plus an inner box).

    :param inner[B,2]: inner footprint, [0]=inner_length (x), [1]=inner_width (y)
    :param outer[B,2]: outer footprint, [0]=outer_length (x), [1]=outer_width (y)
    :param height[B,1]: ring height (third axis), shared by both boxes
    :param position[B,3]: translation added after rotation
    :param rotation[B,3]: rotation parameters forwarded to ``rotate``
    :param Sema: semantic label written into the fourth column
    :return: Affordance[B,16,4]. Rows 0-7 are the outer-box corners and rows
        8-15 the inner-box corners (same corner order), each as [x, y, z, Sema].
    """
    batch = inner.shape[0]
    opts = dict(device=inner.device, dtype=inner.dtype)

    # The same 8 unit-cube corners serve as template for both boxes.
    unit_corners = torch.tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ], **opts)
    template = unit_corners.repeat(2, 1).unsqueeze(0)  # [1, 16, 3]

    # Per-box edge lengths, then one copy per corner: first 8 outer, last 8 inner.
    outer_dims = torch.cat([outer[:, :2], height], dim=-1)  # [B, 3]
    inner_dims = torch.cat([inner[:, :2], height], dim=-1)  # [B, 3]
    scales = torch.stack([outer_dims, inner_dims], dim=1).repeat_interleave(8, dim=1)  # [B, 16, 3]

    placed = rotate(template * scales, rotation) + position.unsqueeze(1)  # [B, 16, 3]
    labels = inner.new_ones((batch, 16, 1)) * Sema
    return torch.cat((placed, labels), dim=-1)

def Cuboid(size, position, rotation, Sema):
    """
    Corner keypoints of a solid box.

    :param size[B,3]: edge lengths along x, y, z
    :param position[B,3]: box centre, added after rotation
    :param rotation[B,3]: rotation parameters forwarded to ``rotate``
    :param Sema: semantic label written into the fourth column
    :return: Affordance[B,8,4], rows [x, y, z, Sema]
    """
    batch = size.shape[0]
    opts = dict(device=position.device, dtype=position.dtype)

    unit_corners = torch.tensor([
        [-0.5,  0.5,  0.5], [0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [0.5, -0.5, -0.5],
    ], **opts)  # [8, 3]

    # Broadcast [1,8,3] * [B,1,3] -> [B,8,3]: each corner scaled by its box's size.
    scaled = unit_corners.unsqueeze(0) * size.unsqueeze(1)
    placed = rotate(scaled, rotation) + position.unsqueeze(1)
    labels = torch.ones((batch, 8, 1), **opts) * Sema
    return torch.cat((placed, labels), dim=-1)

def Structure_BinPicking(paras):
    """
    Build the corner-keypoint structure of the bin-picking scene.

    :paras[B,33]: parameters laid out as produced by process_paras
        paras[:, 0:12]  bin1: position(3), rotation(3), outer size x/y(2),
                        wall thickness x/y(2), ring height(1), bottom thickness(1)
        paras[:, 12:24] bin2: same layout as bin1
        paras[:, 24:33] block: size(3), position(3), rotation(3)
    :return: (Affordance[B,56,4], Grasp_pose[B,6])
        Affordance rows are [x, y, z, semantic id], concatenated as
        bin1 ring (16, id 1), bin1 bottom (8, id 2), bin2 ring (16, id 3),
        bin2 bottom (8, id 4), block (8, id 5).
        Grasp_pose is the block position followed by the fixed approach
        direction (0, 0, -1).
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    parts = []
    # Both bins share one 12-column layout: (column offset, ring id, bottom id).
    # Building them in one loop keeps the two bins geometrically consistent
    # (bin2's bottom previously had a zero x/y footprint).
    for base, sem_ring, sem_bottom in ((0, 1, 2), (12, 3, 4)):
        position_ring = paras[:, base + 0:base + 3]
        rotation_bin = paras[:, base + 3:base + 6]
        outersize = paras[:, base + 6:base + 8]
        # A wall thicker than half the footprint would give a negative inner
        # size (a mirrored inner box); clamp so it degenerates to a solid.
        innersize = torch.clamp(outersize - 2 * paras[:, base + 8:base + 10], min=0)
        height = paras[:, base + 10:base + 11]
        bottom = paras[:, base + 11:base + 12]
        parts.append(Rectangular_Ring(innersize, outersize, height, position_ring, rotation_bin, sem_ring))

        # Bottom plate: full outer footprint, `bottom` thick, directly below
        # the ring along the bin's own -z axis (assuming +z is up, as the
        # (0, 0, -1) grasp approach direction suggests). The offset is
        # expressed in the bin frame and rotated with it, so the plate stays
        # attached for any orientation. rotate() is applied to a single
        # [B,1,3] point here, exactly as it is to [B,K,3] vertex sets.
        size_bottom = torch.cat([outersize, bottom], dim=-1)
        offset_local = torch.cat([zeros, zeros, -0.5 * height - 0.5 * bottom], dim=-1)
        position_bottom = position_ring + rotate(offset_local.unsqueeze(1), rotation_bin).squeeze(1)
        parts.append(Cuboid(size_bottom, position_bottom, rotation_bin, sem_bottom))

    # Block
    size_Block = paras[:, 24:27]
    position_Block = paras[:, 27:30]
    rotation_Block = paras[:, 30:33]
    parts.append(Cuboid(size_Block, position_Block, rotation_Block, 5))

    direction = torch.cat([zeros, zeros, -1 * ones], dim=-1)
    Grasp_pose = torch.cat([position_Block, direction], dim=-1)

    Affordance = torch.cat(parts, dim=1)

    return Affordance, Grasp_pose

def process_paras(raw):
    """
    Map the unconstrained output of Para_Estimator to bounded physical parameters.

    Positions and rotations go through ``tanh`` (symmetric range); every size
    or thickness goes through ``sigmoid`` (strictly positive range).

    :param raw: tensor [B, D] with D >= 33; columns beyond 33 are ignored.
    :return: tensor [B, 33] with layout
        0:3   bin1 position                 in [-3, 3]
        3:6   bin1 rotation                 in [-pi, pi]
        6:8   bin1 outer size (x, y)        in [0.05, 0.3]
        8:10  bin1 wall thickness (x, y)    in [0.005, 0.05]
        10:11 bin1 ring height (z)          in [0.05, 0.25]
        11:12 bin1 bottom thickness (z)     in [0.02, 0.2]
        12:24 bin2, same layout as bin1
        24:27 block size (x, y, z)          in [0.02, 0.1]
        27:30 block position                in [-3, 3]
        30:33 block rotation                in [-pi, pi]
    :raises ValueError: if ``raw`` has fewer than 33 columns (plain slicing
        would otherwise silently return a shorter tensor that fails later).
    """
    n_cols = 33
    if raw.shape[-1] < n_cols:
        raise ValueError(
            f"process_paras expects at least {n_cols} columns, got {raw.shape[-1]}"
        )

    def symmetric(start, stop, limit):
        # tanh -> (-limit, limit)
        return torch.tanh(raw[:, start:stop]) * limit

    def positive(start, stop, scale, low):
        # sigmoid -> (low, low + scale)
        return torch.sigmoid(raw[:, start:stop]) * scale + low

    pieces = []
    for base in (0, 12):  # bin1, bin2 share one 12-column layout
        pieces += [
            symmetric(base + 0, base + 3, 3),              # position
            symmetric(base + 3, base + 6, math.pi),        # rotation
            positive(base + 6, base + 8, 0.25, 0.05),      # outer size
            positive(base + 8, base + 10, 0.045, 0.005),   # wall thickness
            positive(base + 10, base + 11, 0.2, 0.05),     # ring height
            positive(base + 11, base + 12, 0.18, 0.02),    # bottom thickness
        ]
    pieces += [
        positive(24, 27, 0.08, 0.02),   # block size
        symmetric(27, 30, 3),           # block position
        symmetric(30, 33, math.pi),     # block rotation
    ]
    return torch.cat(pieces, dim=-1)

class Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the corner-keypoint structure.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Structure_BinPicking -> Structure_Encoder. The grasp pose produced
    alongside the structure is concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] feature dim, [3] attention dim, [4] output dim of Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the corner-keypoint structure, PN encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Structure_BinPicking -> PN_Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [4] output dim of PN_Structure_Encoder
        ([2] and [3] are not used by this variant.)
    """

    def __init__(self, in_dim, paralist):
        super(PN_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class Fusion_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the corner-keypoint structure, fusion encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Structure_BinPicking -> Fusion_Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] semantic dim, [3] xyz dim, [4] output dim of Fusion_Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(Fusion_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        sem_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Keypoint count of Structure_BinPicking: 2 bins x (ring 16 + bottom 8) + block 8.
        n_points = (8 + 16) * 2 + 8
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(sem_dim, n_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PT_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the corner-keypoint structure, PT encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Structure_BinPicking -> PT_Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] feature dim, [3] attention dim, [4] output dim of PT_Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(PT_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
#######################################################################

# Uniform Sampling Representation

#######################################################################    

def Uniform_Cuboid(size, position, rotation, Sem, N=40, area_weighted=False):
    """
    Sample N random points on the surface of a posed cuboid.

    The box is built from the 8 unit-cube corners scaled by ``size``, rotated
    by ``rotate`` and translated by ``position``; its 6 faces are split into
    12 triangles. Each point picks a triangle, then a uniformly distributed
    barycentric location inside it. Sampled coordinates are linear in the
    posed corners, so gradients flow to size/position/rotation; the random
    triangle choice itself is not differentiable.

    Args:
        size: [B,3] tensor - edge lengths of the cuboid (x, y, z)
        position: [B,3] tensor - centre position of the cuboid
        rotation: [B,3] tensor - rotation parameters forwarded to ``rotate``
        Sem: scalar - semantic label written into the 4th column
        N: int - number of points sampled per batch element (default: 40)
        area_weighted: bool - if False (default), every triangle is equally
            likely, so smaller faces are sampled more densely. If True,
            triangles are picked proportionally to their area (uniform
            density over the surface). A degenerate box with zero total area
            falls back to equal weights.

    Returns:
        SampledPoints: [B, N, 4] tensor - sampled points as [x, y, z, Sem],
        in the dtype of ``size``
    """
    B = size.shape[0]
    device = size.device
    dtype = size.dtype

    # Canonical vertices (unit cube centred at origin)
    vertices = torch.tensor([
        [-0.5,  0.5,  0.5], [ 0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [ 0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [ 0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [ 0.5, -0.5, -0.5]
    ], device=device, dtype=dtype)

    # Scale, then rotate and translate
    vertices = vertices.unsqueeze(0) * size.unsqueeze(1)          # [B, 8, 3]
    vertices = rotate(vertices, rotation) + position.unsqueeze(1)

    # 12 triangles, two per face (labelled by the local axis they face)
    faces = torch.tensor([
        [0, 1, 2], [1, 3, 2],  # +y
        [4, 5, 6], [5, 7, 6],  # -y
        [0, 4, 1], [1, 4, 5],  # +z
        [2, 3, 6], [3, 7, 6],  # -z
        [0, 2, 4], [2, 6, 4],  # -x
        [1, 5, 3], [3, 5, 7],  # +x
    ], device=device)
    n_faces = faces.shape[0]

    face_verts = vertices[:, faces]  # [B, 12, 3, 3]

    # 1. Triangle selection weights
    if area_weighted:
        v0, v1, v2 = face_verts.detach().unbind(2)
        # Explicit dim: without it torch.cross picks the first size-3 dim,
        # which is the batch dim when B == 3.
        weights = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 12]
        # multinomial needs a positive row sum; degenerate boxes fall back to uniform.
        degenerate = weights.sum(dim=1, keepdim=True) <= 0
        weights = torch.where(degenerate, torch.ones_like(weights), weights)
    else:
        weights = torch.ones(B, n_faces, device=device, dtype=dtype) / n_faces
    face_idx = torch.multinomial(weights, N, replacement=True)  # [B, N]

    # 2. Uniform barycentric coordinates: reflect (u, v) back into the
    #    triangle when u + v > 1. Draw in the default dtype (same random
    #    stream as before), then cast to the input dtype.
    u = torch.rand(B, N, 1, device=device).to(dtype)
    v = torch.rand(B, N, 1, device=device).to(dtype)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    w = 1 - u - v

    # 3. Gather each sample's triangle and interpolate
    batch_idx = torch.arange(B, device=device).unsqueeze(1)  # [B, 1]
    tri = face_verts[batch_idx, face_idx]                    # [B, N, 3, 3]
    sampled_points = tri[..., 0, :] * u + tri[..., 1, :] * v + tri[..., 2, :] * w

    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=dtype) * Sem
    return torch.cat([sampled_points, sampled_semantic], dim=-1)

def Uniform_Structure_BinPicking(paras):
    """
    Build the surface-sampled point structure of the bin-picking scene.

    Same parameter layout and geometry as the corner-keypoint variant, but
    each part is represented by random surface samples instead of corners.

    :paras[B,33]: parameters laid out as produced by process_paras
        paras[:, 0:12]  bin1: position(3), rotation(3), outer size x/y(2),
                        wall thickness x/y(2), ring height(1), bottom thickness(1)
        paras[:, 12:24] bin2: same layout as bin1
        paras[:, 24:33] block: size(3), position(3), rotation(3)
    :return: (Affordance[B,P,4], Grasp_pose[B,6])
        Affordance rows are [x, y, z, semantic id], concatenated as
        bin1 ring (Sem 1), bin1 bottom (20 points, Sem 2), bin2 ring (Sem 3),
        bin2 bottom (20 points, Sem 4), block (20 points, Sem 5). Each ring
        is requested from Uniform_Rectangular_Ring with N=40 (presumably 40
        points, giving P = 140 — confirm).
        Grasp_pose is the block position followed by the fixed approach
        direction (0, 0, -1).
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    parts = []
    # Both bins share one 12-column layout: (column offset, ring id, bottom id).
    # Building them in one loop keeps the two bins geometrically consistent
    # (bin2's bottom previously had a zero x/y footprint).
    for base, sem_ring, sem_bottom in ((0, 1, 2), (12, 3, 4)):
        position_ring = paras[:, base + 0:base + 3]
        rotation_bin = paras[:, base + 3:base + 6]
        outersize = paras[:, base + 6:base + 8]
        # A wall thicker than half the footprint would give a negative inner
        # size; clamp so it degenerates to a solid.
        innersize = torch.clamp(outersize - 2 * paras[:, base + 8:base + 10], min=0)
        height = paras[:, base + 10:base + 11]
        bottom = paras[:, base + 11:base + 12]
        parts.append(Uniform_Rectangular_Ring(innersize, outersize, height, position_ring, rotation_bin,
                                              Sem=sem_ring, order=2, N=40))

        # Bottom plate: full outer footprint, `bottom` thick, directly below
        # the ring along the bin's own -z axis (assuming +z is up, as the
        # (0, 0, -1) grasp approach direction suggests). The offset is
        # expressed in the bin frame and rotated with it, so the plate stays
        # attached for any orientation. rotate() is applied to a single
        # [B,1,3] point here, exactly as it is to [B,K,3] vertex sets.
        size_bottom = torch.cat([outersize, bottom], dim=-1)
        offset_local = torch.cat([zeros, zeros, -0.5 * height - 0.5 * bottom], dim=-1)
        position_bottom = position_ring + rotate(offset_local.unsqueeze(1), rotation_bin).squeeze(1)
        parts.append(Uniform_Cuboid(size_bottom, position_bottom, rotation_bin, Sem=sem_bottom, N=20))

    # Block
    size_Block = paras[:, 24:27]
    position_Block = paras[:, 27:30]
    rotation_Block = paras[:, 30:33]
    parts.append(Uniform_Cuboid(size_Block, position_Block, rotation_Block, Sem=5, N=20))

    direction = torch.cat([zeros, zeros, -1 * ones], dim=-1)
    Grasp_pose = torch.cat([position_Block, direction], dim=-1)

    Affordance = torch.cat(parts, dim=1)

    return Affordance, Grasp_pose

class Uniform_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the surface-sampled structure.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Uniform_Structure_BinPicking -> Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] feature dim, [3] attention dim, [4] output dim of Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(Uniform_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Uniform_Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PN_Uniform_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the surface-sampled structure, PN encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Uniform_Structure_BinPicking -> PN_Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [4] output dim of PN_Structure_Encoder
        ([2] and [3] are not used by this variant.)
    """

    def __init__(self, in_dim, paralist):
        super(PN_Uniform_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Uniform_Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class Fusion_Uniform_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the surface-sampled structure, fusion encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Uniform_Structure_BinPicking -> Fusion_Structure_Encoder. The grasp pose
    is concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] semantic dim, [3] xyz dim, [4] output dim of Fusion_Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(Fusion_Uniform_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        sem_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Point count of Uniform_Structure_BinPicking: 2 bins x (bottom 20 +
        # ring 40) + block 20; assumes Uniform_Rectangular_Ring returns N=40 points.
        n_points = (20 + 40) * 2 + 20
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(sem_dim, n_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Uniform_Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)
    
class PT_Uniform_Structure_Module_BinPicking(nn.Module):
    """
    Bin-picking scene encoder on the surface-sampled structure, PT encoder variant.

    Pipeline: feature -> Para_Estimator -> process_paras ->
    Uniform_Structure_BinPicking -> PT_Structure_Encoder. The grasp pose is
    concatenated in front of ``robot_state``.

    ``paralist`` is indexed positionally:
        [0] hidden size passed to Para_Estimator
        [1] last argument of Para_Estimator (presumably its output width;
            process_paras consumes 33 columns — confirm)
        [2] feature dim, [3] attention dim, [4] output dim of PT_Structure_Encoder
    """

    def __init__(self, in_dim, paralist):
        super(PT_Uniform_Structure_Module_BinPicking, self).__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Keep this construction order: it fixes the state_dict layout and the
        # order in which submodule initialisation may draw random numbers.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, robot_state with the 6-D grasp pose
            [block position, approach direction] concatenated in front)
        """
        geo_paras = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Uniform_Structure_BinPicking(geo_paras)
        embedding = self.Structure_encoder(structure)
        return embedding, torch.cat((grasp_pose, robot_state), dim=-1)