import torch
import torch.nn as nn 
import math 
from lift3d.models.concept.MetaWorld.utils import rotate
from lift3d.models.concept.MetaWorld.utils import Para_Estimator, Knowledge_Encoder, PT_Knowledge_Encoder, Fusion_Knowledge_Encoder, PN_Knowledge_Encoder

def Cylinder(height, radius, position, rotation, theta, Sem, order):
    """
    Build the vertex set of a cylinder as two rings of n = int(360 / theta)
    points, one at each end, then scale, rotate, translate and tag them.

    :param height: [B, 1], length of the cylinder along its axis
    :param radius: [B, 1], radius of the cylinder
    :param position: [B, 3], centre of the cylinder (added after rotation)
    :param rotation: [B, 3], rotation (Euler angles) applied with ``rotate``
    :param theta: angular step between ring vertices, in degrees
    :param Sem: semantic label written into the 4th channel of every vertex
    :param order: cylinder axis in the local frame: 1 = Z, 2 = Y, otherwise X
    :return: [B, 2n, 4] tensor (xyz + semantic). The first n rows form the ring
        at axial -height/2 and the next n rows the ring at +height/2.
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    # Canonical local frame: columns are (cos, sin, axial); the unit cylinder
    # has radius 1 and spans axial -0.5 .. 0.5.
    ring = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]
    canonical = [[c, s, -0.5] for c, s in ring] + [[c, s, 0.5] for c, s in ring]

    # Column permutation moving the canonical axis onto the requested one:
    # order 1 -> (cos, sin, axial), order 2 -> (cos, axial, sin),
    # otherwise -> (axial, cos, sin).
    if order == 1:
        axis_perm = [0, 1, 2]
    elif order == 2:
        axis_perm = [0, 2, 1]
    else:
        axis_perm = [2, 0, 1]

    vertices = torch.tensor(canonical, dtype=height.dtype, device=height.device)[:, axis_perm]
    vertices = vertices.unsqueeze(dim=0).repeat(B, 1, 1)  # [B, 2n, 3]

    # Radius scales the two radial columns and height the axial one. Applying
    # the same permutation keeps the height on the true cylinder axis for
    # every `order`, not only for the Z-aligned case.
    deformation = torch.cat([radius, radius, height], dim=1)[:, axis_perm].unsqueeze(dim=1)
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B, 2 * n, 1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Concept_ButtonPress(paras):
    """
    Build the sparse button-press concept: a button cylinder with a stop
    cylinder stacked behind it along the press direction.

    :param paras: [B, >=10] tensor, laid out as produced by process_paras:
        paras[:, 0:3]  button centre position
        paras[:, 3:6]  button rotation, passed to ``rotate``
        paras[:, 6:7]  button radius
        paras[:, 7:8]  button height
        paras[:, 8:9]  stop height
        paras[:, 9:10] stop radius
    :return: (Affordance, Grasp_pose)
        Affordance: [B, 48, 4]. The first 24 rows are button vertices
            (semantic 1) and the last 24 are stop vertices (semantic 2).
            Each Cylinder call uses theta=30, giving 2 * 12 vertices.
        Grasp_pose: [B, 6], the button position followed by the press
            direction (local -Z rotated by the button rotation).
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Button
    position_Button = paras[:, 0:3]
    rotation_Button = paras[:, 3:6]
    radius_Button = paras[:, 6:7]
    height_Button = paras[:, 7:8]
    vertices_Button = Cylinder(height_Button, radius_Button, position_Button, rotation_Button, 30, 1, 1)

    # Press direction: the button's local -Z axis expressed in the world frame.
    direction = rotate(torch.cat([zeros, zeros, -1 * ones], dim=-1).unsqueeze(dim=1), rotation_Button).view([B, 3])
    Grasp_pose = torch.cat([position_Button, direction], dim=-1)

    # Stop
    height_Stop = paras[:, 8:9]
    radius_Stop = paras[:, 9:10]
    # Stack the stop directly behind the button along the button's own axis.
    # The offset has to be rotated like the cylinders themselves. A fixed
    # world -Z offset leaves the stop off-axis whenever the rotation is
    # non-zero, and equals this expression when the rotation is zero.
    position_Stop = position_Button + (0.5 * height_Button + 0.5 * height_Stop) * direction
    vertices_Stop = Cylinder(height_Stop, radius_Stop, position_Stop, rotation_Button, 30, 2, 1)

    Affordance = torch.cat([vertices_Button, vertices_Stop], dim=1)
    return Affordance, Grasp_pose

def process_paras(raw):
    """
    Squash raw Para_Estimator output into bounded button-press parameters.

    Input: raw tensor of shape [B, K] with K >= 10; only the first 10 columns
    are used.
    Output: tensor of shape [B, 10]. tanh and sigmoid never reach their
    limits, so every range below is open:
        0:3  button position, each in (-3, 3)
        3:6  button rotation (Euler angles), each in (-pi, pi)
        6    button radius in (0.01, 0.05)
        7    button height in (0.02, 0.10)
        8    stop height   in (0.02, 0.20)
        9    stop radius   in (0.02, 0.08)
    Raises: ValueError if raw is not 2-D or has fewer than 10 columns.
    """
    # Fail fast: a short input would otherwise silently yield fewer than 10
    # columns and break the geometry builders downstream.
    if raw.dim() != 2 or raw.shape[1] < 10:
        raise ValueError(
            f"process_paras expects a [B, >=10] tensor, got shape {tuple(raw.shape)}"
        )

    # 0-3: Button Position in (-3, 3)
    position = torch.tanh(raw[:, 0:3]) * 3.0

    # 3-6: Rotation (Euler angles) in (-pi, pi)
    rotation = torch.tanh(raw[:, 3:6]) * math.pi

    # sigmoid(x) * (hi - lo) + lo maps the real line onto (lo, hi).
    # 6: Button Radius in (0.01, 0.05)
    radius_button = torch.sigmoid(raw[:, 6:7]) * 0.04 + 0.01

    # 7: Button Height in (0.02, 0.1)
    height_button = torch.sigmoid(raw[:, 7:8]) * 0.08 + 0.02

    # 8: Stop Height in (0.02, 0.2)
    height_stop = torch.sigmoid(raw[:, 8:9]) * 0.18 + 0.02

    # 9: Stop Radius in (0.02, 0.08)
    radius_stop = torch.sigmoid(raw[:, 9:10]) * 0.06 + 0.02

    return torch.cat([
        position, rotation,
        radius_button, height_button,
        height_stop, radius_stop,
    ], dim=-1)

class Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head over the sparse cylinder vertices.

    forward: Para_Estimator -> process_paras (bounded parameters)
    -> Concept_ButtonPress (button/stop vertices + grasp pose)
    -> Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (process_paras reads the first 10 columns)
        2 knowledge feature dim
        3 knowledge attention dim
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim = paralist[0]
        geo_dim = paralist[1]
        feat_dim = paralist[2]
        attn_dim = paralist[3]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Concept_ButtonPress(bounded)
        embedding = self.knowledge_encoder(points)
        augmented_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, augmented_state

class PT_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head whose sparse vertices are encoded by PT_Knowledge_Encoder.

    forward: Para_Estimator -> process_paras -> Concept_ButtonPress
    -> PT_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (process_paras reads the first 10 columns)
        2 knowledge feature dim
        3 knowledge attention dim
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        raw = self.para_estimator(feature)
        points, grasp = Concept_ButtonPress(process_paras(raw))
        return self.knowledge_encoder(points), torch.cat((grasp, robot_state), dim=-1)
  
class Fusion_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head whose sparse vertices are encoded by Fusion_Knowledge_Encoder.

    forward: Para_Estimator -> process_paras -> Concept_ButtonPress
    -> Fusion_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (process_paras reads the first 10 columns)
        2 knowledge semantic dim
        3 knowledge xyz dim
        4 concept output dim
    """

    # Point count emitted by Concept_ButtonPress: two Cylinder calls with
    # theta=30, each giving 2 * 12 = 24 vertices.
    _NUM_POINTS = 24 + 24

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim = paralist[0]
        geo_dim = paralist[1]
        semantic_dim = paralist[2]
        xyz_dim = paralist[3]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, self._NUM_POINTS, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Concept_ButtonPress(bounded)
        embedding = self.knowledge_encoder(points)
        return embedding, torch.cat((grasp, robot_state), dim=-1)

class PN_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head whose sparse vertices are encoded by PN_Knowledge_Encoder.

    forward: Para_Estimator -> process_paras -> Concept_ButtonPress
    -> PN_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (process_paras reads the first 10 columns)
        2, 3 unused by this variant
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        bounded = process_paras(self.para_estimator(feature))
        points, grasp = Concept_ButtonPress(bounded)
        embedding = self.knowledge_encoder(points)
        augmented_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, augmented_state
    
#######################################################################

# Dense point cloud in concept space 

#######################################################################

def Dense_Cylinder(height, radius, position, rotation, theta=30, Sem=0, order=1, N=80):
    """
    Randomly sample N points on the closed surface of a cylinder (side + both caps).

    N // 2 points go on the lateral surface and the remainder is split between
    the top and bottom caps; the default N=80 gives 40 / 20 / 20. Points are
    uniform within each part: uniform axial position and angle on the side,
    and area-uniform on each disk via sqrt-radius sampling. The split itself is
    fixed, so density is not proportional to the area of each part.

    Args:
        height: [B,1] - Length of the cylinder along its axis
        radius: [B,1] - Radius of the cylinder
        position: [B,3] - Centre position, added after rotation
        rotation: [B,3] - Rotation angles (Euler), applied via ``rotate``
        theta: int - Unused (points are sampled randomly, not on an angular grid)
        Sem: int - Semantic value written into the 4th channel
        order: int - Cylinder axis in the local frame (1=Z, 2=Y, otherwise X)
        N: int - Total number of points to sample

    Returns:
        SampledPoints: [B, N, 4] - xyz + semantic. The first n_side rows lie on
        the side, the next n_top on the +height/2 cap, and the rest on the
        -height/2 cap.
    """
    B = height.shape[0]
    device = height.device
    dtype = height.dtype

    # Derive the split from N; the previous hard-coded 40/20/20 made any
    # N != 80 fail with a shape mismatch.
    n_side = N // 2
    n_top = (N - n_side) // 2
    n_bottom = N - n_side - n_top

    # Lateral surface, in canonical columns (cos, sin, axial).
    u = torch.rand((B, n_side), device=device)  # axial fraction in [0, 1)
    v = torch.rand((B, n_side), device=device)  # angle fraction in [0, 1)
    side_angles = 2 * math.pi * v
    side = torch.stack([
        radius * torch.cos(side_angles),
        radius * torch.sin(side_angles),
        (u - 0.5) * height,
    ], dim=-1)

    # Caps share one angle/radius draw; the first n_top columns form the top
    # cap and the rest the bottom cap.
    n_caps = n_top + n_bottom
    cap_angles = torch.rand((B, n_caps), device=device) * 2 * math.pi
    cap_radii = torch.sqrt(torch.rand((B, n_caps), device=device)) * radius
    cap_axial = torch.cat([
        torch.ones((B, n_top), device=device) * (height / 2),
        torch.ones((B, n_bottom), device=device) * (-height / 2),
    ], dim=1)
    caps = torch.stack([
        cap_radii * torch.cos(cap_angles),
        cap_radii * torch.sin(cap_angles),
        cap_axial,
    ], dim=-1)

    local = torch.cat([side, caps], dim=1)  # [B, N, 3]

    # Move the canonical axis column onto the requested axis:
    # order 1 -> (cos, sin, axial), order 2 -> (cos, axial, sin),
    # otherwise -> (axial, cos, sin).
    if order == 1:
        axis_perm = [0, 1, 2]
    elif order == 2:
        axis_perm = [0, 2, 1]
    else:
        axis_perm = [2, 0, 1]
    sampled_points = local[..., axis_perm].to(dtype)

    # Apply rotation and translation
    sampled_points = rotate(sampled_points, rotation) + position.unsqueeze(1)

    # Add semantic information
    sampled_semantic = torch.ones(B, N, 1, device=device, dtype=dtype) * Sem
    SampledPoints = torch.cat([sampled_points, sampled_semantic], dim=-1)

    return SampledPoints

def Dense_Concept_ButtonPress(paras):
    """
    Build the dense button-press concept: randomly sampled surface points of a
    button cylinder and of a stop cylinder stacked behind it along the press
    direction.

    :param paras: [B, >=10] tensor with the same column layout as process_paras output:
        paras[:, 0:3]  button centre position
        paras[:, 3:6]  button rotation, passed to ``rotate``
        paras[:, 6:7]  button radius
        paras[:, 7:8]  button height
        paras[:, 8:9]  stop height
        paras[:, 9:10] stop radius
    :return: (Affordance, Grasp_pose)
        Affordance: [B, 160, 4]. Dense_Cylinder's default N=80 points of the
            button (semantic 1) are followed by 80 points of the stop
            (semantic 2).
        Grasp_pose: [B, 6], the button position followed by the press
            direction (local -Z rotated by the button rotation).
    """
    B = paras.shape[0]
    zeros = torch.zeros([B, 1], device=paras.device, dtype=paras.dtype)
    ones = torch.ones([B, 1], device=paras.device, dtype=paras.dtype)

    # Button
    position_Button = paras[:, 0:3]
    rotation_Button = paras[:, 3:6]
    radius_Button = paras[:, 6:7]
    height_Button = paras[:, 7:8]
    vertices_Button = Dense_Cylinder(height_Button, radius_Button, position_Button, rotation_Button,
                                     theta=30, Sem=1, order=1)

    # Press direction: the button's local -Z axis expressed in the world frame.
    direction = rotate(torch.cat([zeros, zeros, -1 * ones], dim=-1).unsqueeze(dim=1), rotation_Button).view([B, 3])
    Grasp_pose = torch.cat([position_Button, direction], dim=-1)

    # Stop
    height_Stop = paras[:, 8:9]
    radius_Stop = paras[:, 9:10]
    # Stack the stop behind the button along the button's own axis. The
    # offset must be rotated like the cylinders; a world -Z offset leaves the
    # stop off-axis whenever the rotation is non-zero.
    position_Stop = position_Button + (0.5 * height_Button + 0.5 * height_Stop) * direction
    vertices_Stop = Dense_Cylinder(height_Stop, radius_Stop, position_Stop, rotation_Button,
                                   theta=30, Sem=2, order=1)

    Affordance = torch.cat([vertices_Button, vertices_Stop], dim=1)
    return Affordance, Grasp_pose

class Dense_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head over densely sampled cylinder surfaces, encoded by Knowledge_Encoder.

    forward: Para_Estimator -> Dense_Concept_ButtonPress -> Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (Dense_Concept_ButtonPress reads columns 0..9)
        2 knowledge feature dim
        3 knowledge attention dim
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim = paralist[0]
        geo_dim = paralist[1]
        feat_dim = paralist[2]
        attn_dim = paralist[3]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        # NOTE(review): process_paras is not applied in the dense variants, so
        # radii/heights come straight from the estimator and are unbounded
        # (possibly negative). Confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Dense_Concept_ButtonPress(raw)
        embedding = self.knowledge_encoder(points)
        return embedding, torch.cat((grasp, robot_state), dim=-1)

class PT_Dense_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head over densely sampled cylinder surfaces, encoded by PT_Knowledge_Encoder.

    forward: Para_Estimator -> Dense_Concept_ButtonPress -> PT_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (Dense_Concept_ButtonPress reads columns 0..9)
        2 knowledge feature dim
        3 knowledge attention dim
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PT_Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        # NOTE(review): raw estimator output is used without process_paras,
        # so geometry parameters are unbounded. Confirm this is intended.
        points, grasp = Dense_Concept_ButtonPress(self.para_estimator(feature))
        embedding = self.knowledge_encoder(points)
        augmented_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, augmented_state

class Fusion_Dense_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head over densely sampled cylinder surfaces, encoded by Fusion_Knowledge_Encoder.

    forward: Para_Estimator -> Dense_Concept_ButtonPress -> Fusion_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (Dense_Concept_ButtonPress reads columns 0..9)
        2 knowledge semantic dim
        3 knowledge xyz dim
        4 concept output dim
    """

    # Point count emitted by Dense_Concept_ButtonPress: two Dense_Cylinder
    # calls at the default N=80 each.
    _NUM_POINTS = 160

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim = paralist[0]
        geo_dim = paralist[1]
        semantic_dim = paralist[2]
        xyz_dim = paralist[3]
        out_dim = paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Fusion_Knowledge_Encoder(semantic_dim, self._NUM_POINTS, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        # NOTE(review): raw estimator output is used without process_paras,
        # so geometry parameters are unbounded. Confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Dense_Concept_ButtonPress(raw)
        return self.knowledge_encoder(points), torch.cat((grasp, robot_state), dim=-1)
    
class PN_Dense_Concept_Module_ButtonPress(nn.Module):
    """Button-press concept head over densely sampled cylinder surfaces, encoded by PN_Knowledge_Encoder.

    forward: Para_Estimator -> Dense_Concept_ButtonPress -> PN_Knowledge_Encoder.

    paralist, by index:
        0 estimator hidden dim
        1 geometry-parameter dim (Dense_Concept_ButtonPress reads columns 0..9)
        2, 3 unused by this variant
        4 concept output dim
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        hidden_dim, geo_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = PN_Knowledge_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return (concept embedding, robot_state with the 6-D grasp pose prepended on the last dim)."""
        # NOTE(review): raw estimator output is used without process_paras,
        # so geometry parameters are unbounded. Confirm this is intended.
        raw = self.para_estimator(feature)
        points, grasp = Dense_Concept_ButtonPress(raw)
        embedding = self.knowledge_encoder(points)
        augmented_state = torch.cat((grasp, robot_state), dim=-1)
        return embedding, augmented_state