import torch
import torch.nn as nn 
import math 
from StructPolicy.StructGen.utils import rotate
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder, PN_Structure_Encoder, PT_Structure_Encoder, Fusion_Structure_Encoder
from StructPolicy.StructGen.Uniform_Structure import Uniform_Cuboid, Uniform_Cylinder

def Cylinder (height, radius, position, rotation, theta=30, Sem=0, order=1):
    """
    Build the rim vertices of the two end caps of a cylinder and attach a semantic label.

    :param height: [B, 1], extent of the cylinder along its axis
    :param radius: [B, 1], radius of the cylinder in the plane orthogonal to its axis
    :param position, rotation: [B,3], [B,3]; rotation is forwarded unchanged to `rotate`
    :param theta: int, angular step in degrees between neighbouring rim vertices
    :param Sem: value written into the semantic (4th) channel of every vertex
    :param order: axis of the un-rotated cylinder: 1 -> Z, 2 -> Y, anything else -> X
    :return: [B, 2n, 4] with n = int(360 / theta); xyz followed by the semantic label
    """
    B = height.shape[0]
    n = int(360 / theta)
    theta_rad = theta * math.pi / 180.0

    # Unit circle samples shared by both end caps.
    rim = [(math.cos(i * theta_rad), math.sin(i * theta_rad)) for i in range(n)]
    caps = (-0.5, 0.5)

    # The per-axis scale has to follow the axis layout chosen by `order`;
    # the height must stretch the cap axis and the radius the two rim axes.
    if order == 1 :
        V = [[c, s, h] for h in caps for (c, s) in rim]
        scale = [radius, radius, height]
    elif order == 2 :
        V = [[c, h, s] for h in caps for (c, s) in rim]
        scale = [radius, height, radius]
    else :
        V = [[h, c, s] for h in caps for (c, s) in rim]
        scale = [height, radius, radius]

    vertices = torch.tensor(V, dtype=height.dtype, device=height.device).unsqueeze(dim=0).repeat(B,1,1)  # shape [B, 2n, 3]
    deformation = torch.cat(scale, dim=1).unsqueeze(dim=1)  # shape [B, 1, 3]
    vertices = rotate(vertices * deformation, rotation) + position.unsqueeze(dim=1)

    semantic = torch.ones([B,2*n,1], device=height.device, dtype=height.dtype) * Sem
    Affordance = torch.cat([vertices, semantic], dim=-1)

    return Affordance

def Cuboid(size, position, rotation, Sema):
    """
    Produce the eight corner points of a box together with a semantic label.

    :size[B,3]: [0]=length, [1]=width, [2]=height; scales the x, y, z axes of a unit cube
    :position[B,3], rotation[B,3]: box centre and the rotation forwarded to `rotate`
    :Sema: value written into the semantic (4th) channel of every corner
    :return: [B, 8, 4], xyz followed by the semantic label
    """
    batch = size.shape[0]

    # Unit-cube corners, x toggling fastest, then z, then y (+ before - for y and z).
    unit_corners = torch.tensor(
        [[x, y, z] for y in (1 / 2, -1 / 2) for z in (1 / 2, -1 / 2) for x in (-1 / 2, 1 / 2)],
        device=position.device,
        dtype=position.dtype,
    )

    # Broadcasting [8, 3] against [B, 1, 3] yields the scaled corners [B, 8, 3].
    scaled = unit_corners * size.unsqueeze(dim=1)
    corners = rotate(scaled, rotation) + position.unsqueeze(dim=1)

    label = torch.ones([batch, 8, 1], device=position.device, dtype=position.dtype) * Sema
    return torch.cat((corners, label), dim=-1)

def Structure_Assembly(paras, theta=30):
    """
    Assemble the peg / nut structure points and a grasp pose from a parameter vector.

    Input:
        paras: [B, >=20]. Columns 0:8 are the peg (position 0:3, rotation 3:6,
               radius 6:7, height 7:8); columns 8:20 are the nut (position 0:3,
               rotation 3:6, handle size 6:9, rod length 9:10, ring radii along
               x / y 10:12).
        theta: angular step in degrees between neighbouring ring vertices
    Output:
        [Affordance, Semantic]: [B, N_all, 4], xyz followed by the semantic label
                                (1 = peg, 2 = ring, 3 = rod, 4 = handle)
        Grasp pose: [B, 6], handle centre followed by the fixed direction (0, 0, -1)
    """
    B = paras.shape[0]
    para_peg = paras[:, 0:8]
    para_nut = paras[:, 8:20]
    zeros = torch.zeros_like(paras[:, 0:1])  # [B, 1], filler for unused axes

    ##################
    # Structure for peg
    ##################

    position_peg = para_peg[:, 0:3]
    rotation_peg = para_peg[:, 3:6]
    radius_peg = para_peg[:,6:7]
    high_peg = para_peg[:, 7:8]

    # The peg resolution is fixed at 30 degrees and does not follow `theta`.
    Affordance_peg = Cylinder(high_peg, radius_peg, position_peg, rotation_peg, 30, Sem=1)

    ##################
    # Structure for nut
    ##################

    position_nut = para_nut[:, 0:3].unsqueeze(dim=1)  # [B, 1, 3]
    rotation_nut = para_nut[:, 3:6]
    handle_nut = para_nut[:, 6:9]
    high_rod = para_nut[:, 9:10]
    radius_ring = para_nut[:, 10:12]

    # Unit circle in the z = 0 plane. The circle is emitted twice so the ring
    # contributes 2 * n_ring points, which downstream point counts rely on.
    n_ring = math.floor(360/theta)
    circle = [
        [math.cos(math.radians(i * theta)), math.sin(math.radians(i * theta)), 0]
        for i in range(n_ring)
    ]
    unit_ring = torch.tensor(circle + circle, dtype=paras.dtype, device=paras.device)  # [2n, 3]

    deformation_ring = torch.cat((radius_ring, zeros), dim=-1).unsqueeze(dim=1)  # [B, 1, 3]
    vertices_ring = unit_ring * deformation_ring  # broadcasts to [B, 2n, 3]
    vertices_ring = rotate(vertices_ring, rotation_nut) + position_nut

    Semantic_ring = paras.new_full((B, vertices_ring.shape[1], 1), 2)
    Affordance_ring = torch.cat((vertices_ring, Semantic_ring), dim=-1)

    # Rod: a two-point segment starting on the ring at y = -radius_y and
    # running `high_rod` further along -y (in the nut frame, before rotation).
    unit_rod = torch.tensor([[0, 0, 0], [0, -1, 0]], dtype=paras.dtype, device=paras.device)
    deformation_rod = torch.cat([zeros, high_rod, zeros], dim=-1).unsqueeze(dim=1)
    shift_rod = torch.cat([zeros, -1 * radius_ring[:, 1:2], zeros], dim=-1).unsqueeze(dim=1)
    vertices_rod = unit_rod * deformation_rod + shift_rod  # broadcasts to [B, 2, 3]
    vertices_rod = rotate(vertices_rod, rotation_nut) + position_nut

    Semantic_rod = paras.new_full((B, vertices_rod.shape[1], 1), 3)
    Affordance_rod = torch.cat((vertices_rod, Semantic_rod), dim=-1)

    # NOTE(review): the handle centre is offset by +(radius_y + rod + handle_y / 2)
    # along y and the offset is not passed through `rotate`, whereas the rod runs
    # along -y and is rotated. Kept as in the original; confirm this is intended.
    position_handle = position_nut.squeeze(dim=1) - torch.cat((zeros, - radius_ring[:, 1:2] - high_rod - 0.5 * handle_nut[:, 1:2], zeros), dim=-1)
    Affordance_handle = Cuboid(handle_nut, position_handle, rotation_nut, 4)
    vertices_handle = Affordance_handle[:, :, 0:3]

    #############
    # Affordance
    #############

    Affordance = torch.cat([
        Affordance_peg,
        Affordance_ring,
        Affordance_rod,
        Affordance_handle
    ], dim=1)  # shape: (B, N_all, 4)

    #############
    # Grasp pose
    #############

    center = vertices_handle.mean(dim=1)  # shape (B, 3)
    # Match dtype/device of `center` explicitly instead of relying on implicit
    # int64 -> float promotion inside torch.cat.
    direction = torch.tensor([0, 0, -1], dtype=center.dtype, device=center.device).expand(B, 3)
    grasp_pose = torch.cat([center, direction], dim=-1)  # shape (B, 6): [x, y, z, dx, dy, dz]

    return Affordance, grasp_pose

class Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a Structure_Encoder over the assembled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure feature dim, [3] structure attention dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Structure_Assembly(geo_paras, 30)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)

class PN_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a PN_Structure_Encoder over the assembled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [4] structure output dim ([2] and [3] are not used by this variant)
        """
        super().__init__()
        hidden_dim, para_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Structure_Assembly(geo_paras, 30)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)
    
class Fusion_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a Fusion_Structure_Encoder over the assembled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure semantic dim, [3] structure xyz dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        semantic_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Second encoder argument: 24 + 24 + 2 + 8, matching the point counts that
        # Structure_Assembly emits at theta=30 (peg, ring, rod, handle).
        num_points = 24+24+2+8
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Structure_Assembly(geo_paras, 30)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)
    
class PT_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a PT_Structure_Encoder over the assembled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure feature dim, [3] structure attention dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Structure_Assembly(geo_paras, 30)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)

#######################################################################

# Uniform Sample representation

#######################################################################

def Uniform_Structure_Assembly(paras, theta=12):
    """
    Assemble the peg / nut structure points, using the uniform samplers for the
    peg and the handle, and derive a grasp pose from the handle.

    Input:
        paras: [B, >=20]. Columns 0:8 are the peg (position 0:3, rotation 3:6,
               radius 6:7, height 7:8); columns 8:20 are the nut (position 0:3,
               rotation 3:6, handle size 6:9, rod length 9:10, ring radii along
               x / y 10:12).
        theta: angular step in degrees between neighbouring ring vertices
    Output:
        [Affordance, Semantic]: [B, N_all, 4], xyz followed by the semantic label
                                (2 = ring, 3 = rod; peg and handle labels are passed
                                to the uniform samplers as 1 and 4)
        Grasp pose: [B, 6], handle centre followed by the fixed direction (0, 0, -1)
    """
    B = paras.shape[0]
    para_peg = paras[:, 0:8]
    para_nut = paras[:, 8:20]
    zeros = torch.zeros_like(paras[:, 0:1])  # [B, 1], filler for unused axes

    ##################
    # Structure for peg
    ##################

    position_peg = para_peg[:, 0:3]
    rotation_peg = para_peg[:, 3:6]
    radius_peg = para_peg[:,6:7]
    high_peg = para_peg[:, 7:8]

    # N=40 is presumably the number of sampled peg points - TODO confirm in Uniform_Cylinder.
    Affordance_peg = Uniform_Cylinder(high_peg, radius_peg, position_peg, rotation_peg, 30, Sem=1, order=1, N=40)

    ##################
    # Structure for nut
    ##################

    position_nut = para_nut[:, 0:3].unsqueeze(dim=1)  # [B, 1, 3]
    rotation_nut = para_nut[:, 3:6]
    handle_nut = para_nut[:, 6:9]
    high_rod = para_nut[:, 9:10]
    radius_ring = para_nut[:, 10:12]

    # Unit circle in the z = 0 plane. The circle is emitted twice so the ring
    # contributes 2 * n_ring points, which downstream point counts rely on.
    n_ring = math.floor(360/theta)
    circle = [
        [math.cos(math.radians(i * theta)), math.sin(math.radians(i * theta)), 0]
        for i in range(n_ring)
    ]
    unit_ring = torch.tensor(circle + circle, dtype=paras.dtype, device=paras.device)  # [2n, 3]

    deformation_ring = torch.cat((radius_ring, zeros), dim=-1).unsqueeze(dim=1)  # [B, 1, 3]
    vertices_ring = unit_ring * deformation_ring  # broadcasts to [B, 2n, 3]
    vertices_ring = rotate(vertices_ring, rotation_nut) + position_nut

    Semantic_ring = paras.new_full((B, vertices_ring.shape[1], 1), 2)
    Affordance_ring = torch.cat((vertices_ring, Semantic_ring), dim=-1)

    # Rod: a two-point segment starting on the ring at y = -radius_y and
    # running `high_rod` further along -y (in the nut frame, before rotation).
    unit_rod = torch.tensor([[0, 0, 0], [0, -1, 0]], dtype=paras.dtype, device=paras.device)
    deformation_rod = torch.cat([zeros, high_rod, zeros], dim=-1).unsqueeze(dim=1)
    shift_rod = torch.cat([zeros, -1 * radius_ring[:, 1:2], zeros], dim=-1).unsqueeze(dim=1)
    vertices_rod = unit_rod * deformation_rod + shift_rod  # broadcasts to [B, 2, 3]
    vertices_rod = rotate(vertices_rod, rotation_nut) + position_nut

    Semantic_rod = paras.new_full((B, vertices_rod.shape[1], 1), 3)
    Affordance_rod = torch.cat((vertices_rod, Semantic_rod), dim=-1)

    # NOTE(review): the handle centre is offset by +(radius_y + rod + handle_y / 2)
    # along y and the offset is not passed through `rotate`, whereas the rod runs
    # along -y and is rotated. Kept as in the original; confirm this is intended.
    position_handle = position_nut.squeeze(dim=1) - torch.cat((zeros, - radius_ring[:, 1:2] - high_rod - 0.5 * handle_nut[:, 1:2], zeros), dim=-1)
    Affordance_handle = Uniform_Cuboid(handle_nut, position_handle, rotation_nut, 4, 40)
    vertices_handle = Affordance_handle[:, :, 0:3]

    #############
    # Affordance
    #############

    Affordance = torch.cat([
        Affordance_peg,
        Affordance_ring,
        Affordance_rod,
        Affordance_handle
    ], dim=1)  # shape: (B, N_all, 4)

    #############
    # Grasp pose
    #############

    center = vertices_handle.mean(dim=1)  # shape (B, 3)
    # Match dtype/device of `center` explicitly instead of relying on implicit
    # int64 -> float promotion inside torch.cat.
    direction = torch.tensor([0, 0, -1], dtype=center.dtype, device=center.device).expand(B, 3)
    grasp_pose = torch.cat([center, direction], dim=-1)  # shape (B, 6): [x, y, z, dx, dy, dz]

    return Affordance, grasp_pose

class Uniform_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a Structure_Encoder over the uniformly sampled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure feature dim, [3] structure attention dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Assembly(geo_paras, 12)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)
    
class PT_Uniform_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a PT_Structure_Encoder over the uniformly sampled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure feature dim, [3] structure attention dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        feat_dim, attn_dim, out_dim = paralist[2], paralist[3], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = PT_Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Assembly(geo_paras, 12)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)
    
class PN_Uniform_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a PN_Structure_Encoder over the uniformly sampled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [4] structure output dim ([2] and [3] are not used by this variant)
        """
        super().__init__()
        hidden_dim, para_dim, out_dim = paralist[0], paralist[1], paralist[4]
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = PN_Structure_Encoder(out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Assembly(geo_paras, 12)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)
    
class Fusion_Uniform_Structure_Module_Assembly(nn.Module):
    """Parameter estimator followed by a Fusion_Structure_Encoder over the uniformly sampled structure."""

    def __init__ (self, in_dim, paralist):
        """
        :param in_dim: first argument handed to Para_Estimator
        :param paralist: indexable; [0] estimator hidden dim, [1] geometry parameter dim,
                         [2] structure semantic dim, [3] structure xyz dim,
                         [4] structure output dim
        """
        super().__init__()
        hidden_dim, para_dim = paralist[0], paralist[1]
        semantic_dim, xyz_dim, out_dim = paralist[2], paralist[3], paralist[4]
        # Second encoder argument: 40 + 60 + 2 + 40; presumably the per-part point
        # counts of Uniform_Structure_Assembly at theta=12 - TODO confirm.
        num_points = 40+60+2+40
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, para_dim)
        self.Structure_encoder = Fusion_Structure_Encoder(semantic_dim, num_points, xyz_dim, out_dim)

    def forward(self, feature, robot_state):
        """Return the structure embedding and robot_state prefixed with the grasp pose."""
        geo_paras = self.para_estimator(feature)
        points, grasp = Uniform_Structure_Assembly(geo_paras, 12)
        embedding = self.Structure_encoder(points)
        # The grasp pose goes in front of the incoming robot state on the last axis.
        return embedding, torch.cat((grasp, robot_state), dim=-1)