import torch
import torch.nn as nn 
import math 
from StructPolicy.StructGen.utils import rotate
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder, PT_Structure_Encoder, PN_Structure_Encoder, Fusion_Structure_Encoder
from StructPolicy.StructGen.Uniform_Structure import Uniform_Cuboid, Uniform_Cylinder

def Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1):
    """
    Sample the rim points of a cylinder (two circles of n = int(360 / theta)
    points each) and tag every point with a semantic label.

    :param height: [B, 1], length of the cylinder along its own axis
    :param radius: [B, 1], radius of the circular cross-section
    :param position, rotation: [B, 3], [B, 3]; the scaled unit template is passed
        to `rotate` together with `rotation` and then shifted by `position`
    :param theta: angular step in degrees between neighbouring rim points
        (if it does not divide 360, the gap between the last and first point is
        larger than theta)
    :param Sem: semantic label written into the 4th column of every point
    :param order: which local axis is the cylinder axis before rotation:
        1 -> Z (circle in X-Y), 2 -> Y (circle in X-Z), anything else -> X (circle in Y-Z)
    :return: [B, 2n, 4] tensor of (x, y, z, Sem); the first n points form the
        -0.5 * height end circle, the last n points the +0.5 * height end circle
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0
    ring = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]
    ends = (-0.5, 0.5)

    # The axis coordinate must be scaled by `height` and the two in-circle
    # coordinates by `radius`, so the deformation has to follow `order`.
    # A fixed [radius, radius, height] is only right for order == 1; for
    # orders 2/3 it would stretch one circle coordinate by `height` and shrink
    # the axis to `radius`, giving an elliptical slab instead of a cylinder.
    if order == 1:
        V = [[c, s, e] for e in ends for c, s in ring]
        deformation = torch.cat([radius, radius, height], dim=1)
    elif order == 2:
        V = [[c, e, s] for e in ends for c, s in ring]
        deformation = torch.cat([radius, height, radius], dim=1)
    else:
        V = [[e, c, s] for e in ends for c, s in ring]
        deformation = torch.cat([height, radius, radius], dim=1)

    template = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B, 1, 1)  # [B, 2n, 3]
    vertices = rotate(template * deformation.unsqueeze(dim=1), rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B, 2 * n, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Cuboid(size, position, rotation, Sema):
    """
    Return the 8 corner points of a box, each tagged with a semantic label.

    :param size: [B, 3] edge lengths along the local x, y, z axes
    :param position: [B, 3] box centre, added after rotation
    :param rotation: [B, 3] passed to `rotate` together with the scaled corners
    :param Sema: semantic label written into the 4th column of every corner
    :return: [B, 8, 4] tensor of (x, y, z, Sema)
    """
    batch = size.shape[0]
    opts = dict(device=position.device, dtype=position.dtype)

    # Unit-cube corners; y varies slowest, then z, then x (fixed corner order).
    unit_corners = torch.tensor(
        [[x, y, z] for y in (0.5, -0.5) for z in (0.5, -0.5) for x in (-0.5, 0.5)],
        **opts,
    )  # [8, 3]

    # Broadcasting [8, 3] * [B, 1, 3] scales every corner of every sample -> [B, 8, 3].
    scaled = unit_corners * size.unsqueeze(dim=1)
    corners = rotate(scaled, rotation) + position.unsqueeze(dim=1)

    labels = torch.ones([batch, 8, 1], **opts) * Sema
    return torch.cat((corners, labels), dim=-1)
    
def Structure_Hammer(paras):
    """
    Build the vertex representation of a T-shaped hammer plus a collision box.

    :param paras: [B, 19] tensor (in this file produced by process_paras)
        paras[:, 0:3]   handle centre position
        paras[:, 3:6]   handle rotation (also reused for the head)
        paras[:, 6:7]   handle height (along the handle's local Z axis, Cylinder order=1)
        paras[:, 7:8]   handle radius
        paras[:, 8:9]   head height (along the head's local Y axis, Cylinder order=2)
        paras[:, 9:10]  head radius
        paras[:, 10:13] collision-box size
        paras[:, 13:16] collision-box position
        paras[:, 16:19] collision-box rotation
    :return: (Affordance, Grasp_pose)
        Affordance: [B, 24 + 24 + 8, 4] points (x, y, z, semantic id) with
            semantic 1 = handle, 2 = head, 3 = collision box
        Grasp_pose: [B, 6] handle centre followed by the fixed direction (0, 0, -1)
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Handle: cylinder along local Z (order=1), semantic id 1, theta=30 -> 24 points.
    position_handle = paras[:, 0:3]
    rotation_handle = paras[:, 3:6]
    height_handle = paras[:, 6:7]
    radius_handle = paras[:, 7:8]
    vertices_handle = Cylinder(height_handle, radius_handle, position_handle, rotation_handle, 30, 1, 1)

    # Head: cylinder along local Y (order=2), semantic id 2, same rotation as the handle.
    height_hammer = paras[:, 8:9]
    radius_hammer = paras[:, 9:10]
    # The head centre sits on the handle axis, half a handle length plus one
    # head radius above the handle centre. The offset is expressed in the
    # handle's local frame and rotated with the same `rotate` that places the
    # handle's own vertices. Adding the [B, 1] scalar directly to a [B, 3]
    # position would broadcast it onto x, y AND z and ignore the rotation.
    local_offset = torch.cat([zeros, zeros, 0.5 * height_handle + radius_hammer], dim=-1)  # [B, 3]
    world_offset = rotate(local_offset.unsqueeze(dim=1), rotation_handle).squeeze(dim=1)  # [B, 3]
    position_hammer = position_handle + world_offset
    rotation_hammer = rotation_handle
    vertices_hammer = Cylinder(height_hammer, radius_hammer, position_hammer, rotation_hammer, 30, 2, 2)

    # Grasp at the handle centre; the direction is a constant (0, 0, -1) and is
    # not rotated with the handle.
    direction = torch.cat([zeros, zeros, ones * (-1)], dim=-1)
    Grasp_pose = torch.cat([position_handle, direction], dim=-1)

    # Collision box, semantic id 3.
    box_size = paras[:, 10:13]
    box_position = paras[:, 13:16]
    box_rotation = paras[:, 16:19]
    vertices_box = Cuboid(box_size, box_position, box_rotation, 3)

    Affordance = torch.cat([vertices_handle, vertices_hammer, vertices_box], dim=1)

    return Affordance, Grasp_pose

def process_paras(raw):
    """
    Map the unconstrained Para_Estimator output onto bounded geometric parameters.

    Each column group is squashed with tanh (symmetric ranges) or sigmoid
    (positive ranges), then scaled and shifted into the interval listed below.

    :param raw: [B, 19] tensor with unrestricted values
    :return: [B, 19] tensor laid out as
        0:3   handle position    in (-3, 3)
        3:6   handle rotation    in (-pi, pi)
        6     handle height      in (0.05, 0.3)
        7     handle radius      in (0.005, 0.05)
        8     hammerhead height  in (0.02, 0.15)
        9     hammerhead radius  in (0.01, 0.08)
        10:13 box size           in (0.02, 0.2)
        13:16 box position       in (-3, 3)
        16:19 box rotation       in (-pi, pi)
    """
    def symmetric(cols, bound):
        # tanh -> (-1, 1), stretched to (-bound, bound)
        return torch.tanh(raw[:, cols]) * bound

    def positive(cols, span, low):
        # sigmoid -> (0, 1), mapped to (low, low + span)
        return torch.sigmoid(raw[:, cols]) * span + low

    pieces = (
        symmetric(slice(0, 3), 3.0),
        symmetric(slice(3, 6), math.pi),
        positive(slice(6, 7), 0.25, 0.05),
        positive(slice(7, 8), 0.045, 0.005),
        positive(slice(8, 9), 0.13, 0.02),
        positive(slice(9, 10), 0.07, 0.01),
        positive(slice(10, 13), 0.18, 0.02),
        symmetric(slice(13, 16), 3.0),
        symmetric(slice(16, 19), math.pi),
    )
    return torch.cat(pieces, dim=-1)

class Structure_Module_Hammer(nn.Module):
    """
    Feature -> bounded hammer parameters -> vertex structure -> embedding (Structure_Encoder).

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure feature dim, structure attention dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Structure_Hammer(bounded)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)
    
class PN_Structure_Module_Hammer(nn.Module):
    """
    Same pipeline as Structure_Module_Hammer, but embeds the vertices with PN_Structure_Encoder.

    paralist is read at indices 0..4; only entries 0, 1 and 4 (estimator hidden dim,
    geometric-parameter dim, structure output dim) are used.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, _, _, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Structure_Hammer(bounded)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)
    
class Fusion_Structure_Module_Hammer(nn.Module):
    """
    Same pipeline as Structure_Module_Hammer, but embeds the vertices with Fusion_Structure_Encoder.

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure semantic dim, structure xyz dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, sem_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        # Structure_Hammer emits two theta=30 cylinders (24 points each) plus
        # 8 cuboid corners; presumably this is the point count the encoder expects.
        num_points = 24 * 2 + 8
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(sem_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Structure_Hammer(bounded)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)

class PT_Structure_Module_Hammer(nn.Module):
    """
    Same pipeline as Structure_Module_Hammer, but embeds the vertices with PT_Structure_Encoder.

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure feature dim, structure attention dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Structure_Hammer(bounded)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)

#######################################################################

# Uniform Sampling Representation

#######################################################################

def Uniform_Structure_Hammer(paras):
    """
    Build the uniformly-sampled point representation of a T-shaped hammer plus a collision box.

    :param paras: [B, 19] tensor
        paras[:, 0:3]   handle centre position
        paras[:, 3:6]   handle rotation (also reused for the head)
        paras[:, 6:7]   handle height (Uniform_Cylinder order=1)
        paras[:, 7:8]   handle radius
        paras[:, 8:9]   head height (Uniform_Cylinder order=2)
        paras[:, 9:10]  head radius
        paras[:, 10:13] collision-box size
        paras[:, 13:16] collision-box position
        paras[:, 16:19] collision-box rotation
    :return: (Affordance, Grasp_pose)
        Affordance: handle, head and box samples concatenated on dim 1, with
            semantic ids 1 / 2 / 3 (sample counts come from the N= arguments;
            presumably 60 + 60 + 40 points — confirm against Uniform_* helpers)
        Grasp_pose: [B, 6] handle centre followed by the fixed direction (0, 0, -1)
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Handle
    position_handle = paras[:, 0:3]
    rotation_handle = paras[:, 3:6]
    height_handle = paras[:, 6:7]
    radius_handle = paras[:, 7:8]
    vertices_handle = Uniform_Cylinder(height_handle, radius_handle, position_handle, rotation_handle, 30, Sem=1, order=1, N=60)

    # Head, sharing the handle's rotation
    height_hammer = paras[:, 8:9]
    radius_hammer = paras[:, 9:10]
    # The head centre sits on the handle axis, half a handle length plus one
    # head radius above the handle centre. The offset is built in the handle's
    # local frame (axis = local Z for order=1) and rotated with `rotate`.
    # Adding the [B, 1] scalar directly to a [B, 3] position would broadcast it
    # onto x, y AND z in the world frame and ignore the rotation.
    local_offset = torch.cat([zeros, zeros, 0.5 * height_handle + radius_hammer], dim=-1)  # [B, 3]
    world_offset = rotate(local_offset.unsqueeze(dim=1), rotation_handle).squeeze(dim=1)  # [B, 3]
    position_hammer = position_handle + world_offset
    rotation_hammer = rotation_handle
    vertices_hammer = Uniform_Cylinder(height_hammer, radius_hammer, position_hammer, rotation_hammer, 30, Sem=2, order=2, N=60)

    # Grasp at the handle centre; the direction is a constant (0, 0, -1), not rotated.
    direction = torch.cat([zeros, zeros, ones * (-1)], dim=-1)
    Grasp_pose = torch.cat([position_handle, direction], dim=-1)

    # Collision box
    box_size = paras[:, 10:13]
    box_position = paras[:, 13:16]
    box_rotation = paras[:, 16:19]
    vertices_box = Uniform_Cuboid(box_size, box_position, box_rotation, Sem=3, N=40)

    Affordance = torch.cat([vertices_handle, vertices_hammer, vertices_box], dim=1)

    return Affordance, Grasp_pose

class Uniform_Structure_Module_Hammer(nn.Module):
    """
    Feature -> hammer parameters -> uniformly sampled structure -> embedding (Structure_Encoder).

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure feature dim, structure attention dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        # NOTE(review): process_paras is not applied here, so the raw, unbounded
        # estimator output drives the geometry — confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Hammer(raw)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)
    
class PN_Uniform_Structure_Module_Hammer(nn.Module):
    """
    Uniform-sampling variant that embeds the structure with PN_Structure_Encoder.

    paralist is read at indices 0..4; only entries 0, 1 and 4 (estimator hidden dim,
    geometric-parameter dim, structure output dim) are used.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, _, _, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        # NOTE(review): process_paras is not applied here, so the raw, unbounded
        # estimator output drives the geometry — confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Hammer(raw)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)
    
class Fusion_Uniform_Structure_Module_Hammer(nn.Module):
    """
    Uniform-sampling variant that embeds the structure with Fusion_Structure_Encoder.

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure semantic dim, structure xyz dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, sem_dim, xyz_dim, out_dim = (paralist[k] for k in range(5))
        # Matches the N=60 (two cylinders) and N=40 (box) sample counts requested
        # in Uniform_Structure_Hammer; presumably the encoder's expected point count.
        num_points = 60 * 2 + 40
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(sem_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        # NOTE(review): process_paras is not applied here, so the raw, unbounded
        # estimator output drives the geometry — confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Hammer(raw)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)

class PT_Uniform_Structure_Module_Hammer(nn.Module):
    """
    Uniform-sampling variant that embeds the structure with PT_Structure_Encoder.

    paralist is read at indices 0..4 as: estimator hidden dim, geometric-parameter
    dim, structure feature dim, structure attention dim, structure output dim.
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (paralist[k] for k in range(5))
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        :return: (structure embedding, grasp pose concatenated before robot_state on the last dim)
        """
        # NOTE(review): process_paras is not applied here, so the raw, unbounded
        # estimator output drives the geometry — confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Hammer(raw)
        embedding = self.Structure_encoder(points)
        return embedding, torch.cat([grasp, robot_state], dim=-1)