import torch
import torch.nn as nn 
import math 
from StructPolicy.StructGen.utils import rotate
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder, PT_Structure_Encoder, Fusion_Structure_Encoder, PN_Structure_Encoder

def Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1):
    """
    Build the keypoints of a posed cylinder: two rings of vertices, one per cap.

    :param height: [B, 1], extent of the cylinder along its axis
    :param radius: [B, 1], radius of the cylinder perpendicular to its axis
    :param position, rotation: [B,3], [B,3]; rotation is passed to ``rotate``
    :param theta: int, angular step (degrees) between ring vertices; each ring
                  has int(360 / theta) vertices
    :param Sem: semantic label written into the 4th channel of every vertex
    :param order: cylinder axis: 1 = Z, 2 = Y, anything else = X
    :return: [B, 2n, 4] tensor (xyz + semantic). In the local frame the first
             n rows are the ring at axial -height/2 and the last n rows the
             ring at axial +height/2.
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    # The unit template puts the axial coordinate at +/-0.5 and the ring on the
    # unit circle. The per-axis scale must put `height` on the axial
    # coordinate and `radius` on the two planar ones, whichever axis is used.
    if order == 1:
        V = [
            [math.cos(i * theta_rad), math.sin(i * theta_rad), -0.5] for i in range(n)
        ] + [
            [math.cos(i * theta_rad), math.sin(i * theta_rad), 0.5] for i in range(n)
        ]
        deformation = torch.cat([radius, radius, height], dim=1)
    elif order == 2:
        V = [
            [math.cos(i * theta_rad), -0.5, math.sin(i * theta_rad)] for i in range(n)
        ] + [
            [math.cos(i * theta_rad), 0.5, math.sin(i * theta_rad)] for i in range(n)
        ]
        deformation = torch.cat([radius, height, radius], dim=1)
    else:
        V = [
            [-0.5, math.cos(i * theta_rad), math.sin(i * theta_rad)] for i in range(n)
        ] + [
            [0.5, math.cos(i * theta_rad), math.sin(i * theta_rad)] for i in range(n)
        ]
        deformation = torch.cat([height, radius, radius], dim=1)

    vertices = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B, 1, 1)  # [B, 2n, 3]
    deformation = deformation.unsqueeze(dim=1)  # [B, 1, 3], broadcast over vertices
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B, 2 * n, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Cuboid(size, position, rotation, Sema):
    """
    Return the 8 corner keypoints of a posed cuboid, with a semantic channel.

    :size[B,3]: [0]=length, [1]=width, [2]=height (extents along local x, y, z)
    :position[B,3], rotation[B,3]: translation, and rotation passed to ``rotate``
    :Sema: semantic label written into the 4th channel of every corner
    :return: [B, 8, 4] tensor (xyz + semantic)
    """
    batch = size.shape[0]
    like = dict(device=position.device, dtype=position.dtype)

    # Unit-cube corners: y outermost (+ then -), then z (+ then -), x innermost (- then +).
    unit_corners = torch.tensor(
        [[x, y, z] for y in (0.5, -0.5) for z in (0.5, -0.5) for x in (-0.5, 0.5)],
        **like,
    )

    per_corner_size = size.unsqueeze(dim=1).repeat(1, 8, 1)  # [B, 8, 3]
    corners = rotate(unit_corners * per_corner_size, rotation) + position.unsqueeze(dim=1)

    labels = torch.ones([batch, 8, 1], **like) * Sema
    return torch.cat((corners, labels), dim=-1)

def Structure_PushWall(paras):
    """
    Build the PushWall keypoint structure and the grasp pose from parameters.

    :paras[:, 0:8]: cylinder - position (0:3), rotation (3:6), height (6:7), radius (7:8)
    :paras[:, 8:17]: wall - position (8:11), rotation (11:14), size (14:17)
    :return: (Affordance, Grasp_pose)
        Affordance: [B, 24, 4] - 16 cylinder ring vertices (theta=45, semantic 1)
                    followed by 8 wall corners (semantic 2)
        Grasp_pose: [B, 6] - cylinder position followed by the local +Y unit
                    vector rotated with ``rotate`` by the cylinder rotation
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    # Must be ones: the grasp direction is the rotated local +Y axis. A zeros
    # tensor here made the direction (and thus half the grasp pose) always 0.
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # cylinder
    position_cylinder = paras[:, 0:3]
    rotation_cylinder = paras[:, 3:6]
    height_cylinder = paras[:, 6:7]
    radius_cylinder = paras[:, 7:8]
    vertices_cylinder = Cylinder(height_cylinder, radius_cylinder, position_cylinder, rotation_cylinder, 45, 1, 1)

    local_y = torch.cat([zeros, ones, zeros], dim=-1).unsqueeze(dim=1)  # [B, 1, 3]
    direction = rotate(local_y, rotation_cylinder).reshape(B, 3)
    Grasp_pose = torch.cat([position_cylinder, direction], dim=-1)

    # wall
    position_wall = paras[:, 8:11]
    rotation_wall = paras[:, 11:14]
    size_wall = paras[:, 14:17]
    vertices_wall = Cuboid(size_wall, position_wall, rotation_wall, 2)

    Affordance = torch.cat([vertices_cylinder, vertices_wall], dim=1)

    return Affordance, Grasp_pose

def process_paras(raw):
    """
    Squash the unconstrained Para_Estimator output into bounded physical parameters.

    Column layout (shared by input and output):
        0:3    cylinder position   tanh * 3             -> (-3, 3)
        3:6    cylinder rotation   tanh * pi            -> (-pi, pi)
        6:7    cylinder height     sigmoid * 0.15 + 0.05   -> (0.05, 0.2)
        7:8    cylinder radius     sigmoid * 0.085 + 0.015 -> (0.015, 0.1)
        8:11   wall position       tanh * 3.5           -> (-3.5, 3.5)
        11:14  wall rotation       tanh * pi            -> (-pi, pi)
        14:17  wall size           sigmoid * 0.95 + 0.05   -> (0.05, 1.0)

    :param raw: tensor [B, 17], unrestricted range
    :return: tensor [B, 17], each column group mapped into its range above
    """
    def symmetric(start, stop, half_width):
        # tanh maps R -> (-1, 1); scale to (-half_width, half_width).
        return torch.tanh(raw[:, start:stop]) * half_width

    def shifted(start, stop, span, low):
        # sigmoid maps R -> (0, 1); scale and shift to (low, low + span).
        return torch.sigmoid(raw[:, start:stop]) * span + low

    pieces = [
        symmetric(0, 3, 3),
        symmetric(3, 6, math.pi),
        shifted(6, 7, 0.15, 0.05),
        shifted(7, 8, 0.085, 0.015),
        symmetric(8, 11, 3.5),
        symmetric(11, 14, math.pi),
        shifted(14, 17, 0.95, 0.05),
    ]
    return torch.cat(pieces, dim=-1)

class Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using keypoints and ``Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> process_paras -> Structure_PushWall
    -> Structure_Encoder. The grasp pose returned by Structure_PushWall is
    prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                      Structure_attension_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_PushWall(bounded)
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state

class PT_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using keypoints and ``PT_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> process_paras -> Structure_PushWall
    -> PT_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                      Structure_attension_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_PushWall(bounded)
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state
    
class Fusion_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using keypoints and ``Fusion_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> process_paras -> Structure_PushWall
    -> Fusion_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                      Structure_xyz_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        # NOTE(review): Structure_PushWall emits 16 cylinder vertices
        # (theta=45 -> 8 per ring x 2 rings) + 8 wall corners = 24 points,
        # whereas 24 + 8 = 32 would match a theta=30 cylinder. If this argument
        # is the point count, confirm the value against Fusion_Structure_Encoder.
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, 24 + 8, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_PushWall(bounded)
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state

class PN_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using keypoints and ``PN_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> process_paras -> Structure_PushWall
    -> PN_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                      Structure_xyz_dim, Structure_out_dim]; only indices 0, 1
                      and 4 are used by this variant
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        structure, grasp_pose = Structure_PushWall(bounded)
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state
    
    
#######################################################################

# Uniform Sampling Representation

#######################################################################

def Uniform_Cuboid(size, position, rotation, Sema, N=80):
    """
    Sample N points on the surface of a posed cuboid, area-weighted per face.

    The cuboid is a unit cube centred at the origin, scaled per axis by
    ``size``, rotated with ``rotate(..., rotation)`` and translated by
    ``position``. Each point picks one of the 12 surface triangles with
    probability proportional to its area, then draws barycentric coordinates
    inside it. Gradients reach size/position/rotation through the vertex
    positions; the face choice itself (multinomial) is not differentiable.

    Args:
        size: [B,3] tensor - Extents of the cuboid along local x, y, z
        position: [B,3] tensor - Center position of the cuboid
        rotation: [B,3] tensor - Rotation parameters passed to ``rotate``
        Sema: scalar - Semantic value written into the 4th channel
        N: int - Number of points to sample (default: 80)

    Returns:
        SampledPoints: [B, N, 4] tensor - Sampled points (xyz + semantic)
    """
    B = size.shape[0]
    device = size.device
    dtype = size.dtype

    # Canonical vertices (unit cube centered at origin)
    vertices = torch.tensor([
        [-0.5,  0.5,  0.5], [ 0.5,  0.5,  0.5],
        [-0.5,  0.5, -0.5], [ 0.5,  0.5, -0.5],
        [-0.5, -0.5,  0.5], [ 0.5, -0.5,  0.5],
        [-0.5, -0.5, -0.5], [ 0.5, -0.5, -0.5]
    ], device=device, dtype=dtype)

    # Scale to size (broadcast over batch), then rotate and translate.
    vertices = vertices.unsqueeze(0) * size.unsqueeze(1)  # [B, 8, 3]
    vertices = rotate(vertices, rotation) + position.unsqueeze(1)

    # Two triangles per cube face (12 triangles total)
    faces = torch.tensor([
        [0, 1, 2], [1, 3, 2],  # Top face
        [4, 5, 6], [5, 7, 6],  # Bottom face
        [0, 4, 1], [1, 4, 5],  # Front face
        [2, 3, 6], [3, 7, 6],  # Back face
        [0, 2, 4], [2, 6, 4],  # Left face
        [1, 5, 3], [3, 5, 7],  # Right face
    ], device=device)

    # 1. Triangle areas. dim=-1 must be explicit: without it torch.cross uses
    #    the first dimension of size 3, which is the batch dim when B == 3.
    face_verts = vertices[:, faces]  # [B, 12, 3, 3]
    v0, v1, v2 = face_verts.unbind(2)
    areas = 0.5 * torch.norm(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)  # [B, 12]

    # 2. Pick a triangle per point, proportional to area. A fully degenerate
    #    cuboid (total area 0) would make multinomial raise, so fall back to
    #    uniform weights for such rows; all its points coincide anyway.
    total = areas.sum(dim=1, keepdim=True)
    probs = torch.where(total > 0, areas / (total + 1e-10), torch.ones_like(areas))
    face_idx = torch.multinomial(probs, N, replacement=True)  # [B, N]
    gather_idx = face_idx[:, :, None, None].expand(-1, -1, 3, 3)
    selected_faces = torch.gather(face_verts, 1, gather_idx)  # [B, N, 3, 3]

    # 3. Barycentric coordinates; reflecting (u, v) when u + v > 1 keeps the
    #    sample inside the triangle and uniform over it.
    u = torch.rand(B, N, 1, device=device, dtype=dtype)
    v = torch.rand(B, N, 1, device=device, dtype=dtype)
    mask = (u + v) > 1
    u = torch.where(mask, 1 - u, u)
    v = torch.where(mask, 1 - v, v)
    w = 1 - u - v

    sampled_points = (selected_faces[..., 0, :] * u +
                      selected_faces[..., 1, :] * v +
                      selected_faces[..., 2, :] * w)

    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=dtype) * Sema
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def _cylinder_axis_stack(planar_a, planar_b, axial, order):
    """
    Assemble [B, M, 3] xyz points from per-point planar and axial coordinates.

    ``planar_a``/``planar_b`` are the two coordinates perpendicular to the
    cylinder axis and ``axial`` is the coordinate along it (all [B, M]).
    order 1: axis = Z -> (a, b, axial); order 2: axis = Y -> (a, axial, b);
    anything else: axis = X -> (axial, a, b).
    """
    if order == 1:
        columns = (planar_a, planar_b, axial)
    elif order == 2:
        columns = (planar_a, axial, planar_b)
    else:
        columns = (axial, planar_a, planar_b)
    return torch.stack(columns, dim=-1)


def Uniform_Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1, N=80):
    """
    Sample N points on the closed surface of a posed cylinder (side wall + caps).

    N // 2 points go on the side wall, and the rest are split as evenly as
    possible between the top and bottom caps: 40 / 20 / 20 for the default
    N=80. Side points are uniform in axial position and angle. Cap points use
    r = radius * sqrt(U) with a uniform angle, which is uniform over the disk.

    Args:
        height: [B,1] - Height of cylinder along its axis
        radius: [B,1] - Radius of cylinder
        position: [B,3] - Center position
        rotation: [B,3] - Rotation parameters passed to ``rotate``
        theta: int - unused; kept for signature compatibility with ``Cylinder``
        Sem: int - Semantic value written into the 4th channel
        order: int - Axis orientation (1=Z, 2=Y, otherwise X)
        N: int - Number of points to sample

    Returns:
        SampledPoints: [B, N, 4] - Sampled surface points with semantics (xyz + semantic)
    """
    B = height.shape[0]
    device = height.device
    dtype = height.dtype

    # Derive the split from N (it used to be hard-coded to 40/20/20, which
    # raised a shape error for any N != 80).
    n_side = N // 2
    n_top = (N - n_side) // 2
    n_bottom = N - n_side - n_top
    n_caps = n_top + n_bottom

    # 1. Side wall
    u = torch.rand((B, n_side), device=device, dtype=dtype)  # axial fraction in [0, 1)
    v = torch.rand((B, n_side), device=device, dtype=dtype)  # angle fraction in [0, 1)
    angles = 2 * math.pi * v
    side = _cylinder_axis_stack(
        radius * torch.cos(angles),
        radius * torch.sin(angles),
        (u - 0.5) * height,
        order,
    )

    # 2. Caps: the first n_top points at +height/2, the remaining n_bottom at -height/2
    cap_angles = torch.rand((B, n_caps), device=device, dtype=dtype) * 2 * math.pi
    r = torch.sqrt(torch.rand((B, n_caps), device=device, dtype=dtype)) * radius
    cap_axial = torch.cat([
        torch.ones((B, n_top), device=device, dtype=dtype) * (height / 2),
        torch.ones((B, n_bottom), device=device, dtype=dtype) * (-height / 2),
    ], dim=1)
    caps = _cylinder_axis_stack(r * torch.cos(cap_angles), r * torch.sin(cap_angles), cap_axial, order)

    # Rotate and translate into the world frame
    sampled_points = torch.cat([side, caps], dim=1)  # [B, N, 3]
    sampled_points = rotate(sampled_points, rotation) + position.unsqueeze(1)

    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Uniform_Structure_PushWall(paras):
    """
    Build the densely sampled PushWall structure and the grasp pose.

    :paras[:, 0:8]: cylinder - position (0:3), rotation (3:6), height (6:7), radius (7:8)
    :paras[:, 8:17]: wall - position (8:11), rotation (11:14), size (14:17)
    :return: (Affordance, Grasp_pose)
        Affordance: [B, 160, 4] with the default sample counts - 80 cylinder
                    surface points (semantic 1) followed by 80 wall surface
                    points (semantic 2)
        Grasp_pose: [B, 6] - cylinder position followed by the local +Y unit
                    vector rotated with ``rotate`` by the cylinder rotation
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    # Must be ones: the grasp direction is the rotated local +Y axis. A zeros
    # tensor here made the direction (and thus half the grasp pose) always 0.
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # cylinder
    position_cylinder = paras[:, 0:3]
    rotation_cylinder = paras[:, 3:6]
    height_cylinder = paras[:, 6:7]
    radius_cylinder = paras[:, 7:8]
    vertices_cylinder = Uniform_Cylinder(height_cylinder, radius_cylinder, position_cylinder, rotation_cylinder, 30, 1, 1)

    local_y = torch.cat([zeros, ones, zeros], dim=-1).unsqueeze(dim=1)  # [B, 1, 3]
    direction = rotate(local_y, rotation_cylinder).reshape(B, 3)
    Grasp_pose = torch.cat([position_cylinder, direction], dim=-1)

    # wall
    position_wall = paras[:, 8:11]
    rotation_wall = paras[:, 11:14]
    size_wall = paras[:, 14:17]
    vertices_wall = Uniform_Cuboid(size_wall, position_wall, rotation_wall, 2)

    Affordance = torch.cat([vertices_cylinder, vertices_wall], dim=1)

    return Affordance, Grasp_pose

class Uniform_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using sampled surface points and ``Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> Uniform_Structure_PushWall
    -> Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                      Structure_attension_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        # NOTE(review): unlike the keypoint modules, process_paras is not
        # applied here, so raw estimator outputs are used directly as geometry
        # (no range limits). Confirm this is intentional.
        structure, grasp_pose = Uniform_Structure_PushWall(self.para_estimator(feature))
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state

class PT_Uniform_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using sampled surface points and ``PT_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> Uniform_Structure_PushWall
    -> PT_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_feature_dim,
                      Structure_attension_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        # NOTE(review): process_paras is not applied here, so raw estimator
        # outputs are used directly as geometry. Confirm this is intentional.
        structure, grasp_pose = Uniform_Structure_PushWall(self.para_estimator(feature))
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state

class Fusion_Uniform_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using sampled surface points and ``Fusion_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> Uniform_Structure_PushWall
    -> Fusion_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                      Structure_xyz_dim, Structure_out_dim]
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, semantic_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        # 160 = 80 cylinder + 80 wall points with the default sample counts of
        # Uniform_Cylinder / Uniform_Cuboid.
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, 160, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        # NOTE(review): process_paras is not applied here, so raw estimator
        # outputs are used directly as geometry. Confirm this is intentional.
        structure, grasp_pose = Uniform_Structure_PushWall(self.para_estimator(feature))
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state
    
class PN_Uniform_Structure_Module_PushWall(nn.Module):
    """
    PushWall structure branch using sampled surface points and ``PN_Structure_Encoder``.

    Pipeline: feature -> Para_Estimator -> Uniform_Structure_PushWall
    -> PN_Structure_Encoder. The grasp pose is prepended to the robot state.

    :param in_dim: input feature dimension for Para_Estimator
    :param paralist: [estimator_hidden_dim, Geo_para_dim, Structure_Semantic_dim,
                      Structure_xyz_dim, Structure_out_dim]; only indices 0, 1
                      and 4 are used by this variant
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return (structure embedding, robot_state with grasp pose prepended on the last dim)."""
        # NOTE(review): process_paras is not applied here, so raw estimator
        # outputs are used directly as geometry. Confirm this is intentional.
        structure, grasp_pose = Uniform_Structure_PushWall(self.para_estimator(feature))
        embedding = self.Structure_encoder(structure)
        augmented_state = torch.cat((grasp_pose, robot_state), dim=-1)
        return embedding, augmented_state