import torch
import math

def get_rodrigues_matrix(axis, angle):
    '''
    Build a batch of rotation matrices about a fixed axis via Rodrigues' formula:

        R = cos(t) * I + sin(t) * [k]_x + (1 - cos(t)) * k k^T

    axis   -> 3-element sequence or tensor, e.g. [1,0,0] / [0, 1, 0] / [0, 0, 1]
              (used as-is; it is not normalized here)
    angle  -> tensor of shape [batch_size] or [batch_size, 1], in radians
    return -> [batch_size, 3, 3], in the dtype/device of `angle`
    '''
    # Flatten to [B, 1, 1] so both [B] and [B, 1] inputs broadcast against [3, 3]
    # (a [B, 1] input previously broadcast to [B, B, 3, 3]).
    angle = angle.reshape(-1, 1, 1)
    dtype, device = angle.dtype, angle.device

    k = torch.as_tensor(axis, dtype=dtype, device=device).reshape(3)
    zero = torch.zeros((), dtype=dtype, device=device)

    # Skew-symmetric cross-product matrix [k]_x.
    skew = torch.stack([
        torch.stack([zero, -k[2], k[1]]),
        torch.stack([k[2], zero, -k[0]]),
        torch.stack([-k[1], k[0], zero]),
    ])
    # Outer product k k^T.
    outer = k.unsqueeze(1) * k.unsqueeze(0)
    identity = torch.eye(3, dtype=dtype, device=device)

    cos_angle = torch.cos(angle)
    sin_angle = torch.sin(angle)
    # [B, 1, 1] * [3, 3] broadcasts to [B, 3, 3].
    return cos_angle * identity + sin_angle * skew + (1 - cos_angle) * outer

def _as_tensor_like(value, reference):
    '''
    Return `value` unchanged if it is already a tensor; otherwise convert a Python
    sequence to a tensor with the dtype/device of `reference`. torch.stack is used
    (not torch.tensor) so tensor elements keep their autograd history.
    '''
    if torch.is_tensor(value):
        return value
    return torch.stack([
        torch.as_tensor(v, dtype=reference.dtype, device=reference.device) for v in value
    ])


def apply_transformation(vertices, position, rotation, rotation_order="XYZ", offset_first=False):
    '''
    Rotate (per-axis Euler angles, radians) and translate sets of vertices.

    vertices       -> [batch_size * 3, num_vertices, 3] (e.g. [batch_size * 3, 8, 3]),
                      or an unbatched [num_vertices, 3]
    position       -> [batch_size * 3, 3]; [3] (tensor or Python sequence) when unbatched
    rotation       -> [batch_size * 3, 3] angles about X, Y, Z; [3] when unbatched
    rotation_order -> axes to apply, left to right ("XYZ" = X first, Z last)
    offset_first   -> if True translate before rotating, otherwise rotate then translate
    return         -> tensor with the same shape as `vertices`
    '''
    position = _as_tensor_like(position, vertices)
    rotation = _as_tensor_like(rotation, vertices)

    # Unbatched input is treated as a batch of one and squeezed back at the end.
    unbatched = vertices.dim() == 2
    if unbatched:
        vertices = vertices.unsqueeze(0)
        position = position.reshape(1, 3)
        rotation = rotation.reshape(1, 3)

    # [B, 1, 3] broadcasts over any vertex count instead of repeating to a fixed 8.
    position = position.unsqueeze(1)

    # process position first
    if offset_first:
        vertices = vertices + position

    # process rotation
    rot_mat = {
        "X": get_rodrigues_matrix([1, 0, 0], rotation[:, 0]),
        "Y": get_rodrigues_matrix([0, 1, 0], rotation[:, 1]),
        "Z": get_rodrigues_matrix([0, 0, 1], rotation[:, 2]),
    }
    # Vertices are row vectors, so v' = v @ R^T applies R to each vertex.
    for axis_name in rotation_order:
        vertices = torch.matmul(vertices, rot_mat[axis_name].transpose(1, 2))

    # process position second
    if not offset_first:
        vertices = vertices + position

    return vertices[0] if unbatched else vertices

def transformation_matrix(position, rotation):
    '''
    Get RT: homogeneous 4x4 transforms [[R, t], [0, 0, 0, 1]].

    position -> [batch_size * 3, 3] translation t
    rotation -> [batch_size * 3, 3] Euler angles (x, y, z) in radians
    return   -> [batch_size * 3, 4, 4] with R = Rz @ Ry @ Rx (X applied first),
                in the floating dtype of cos(rotation) and on rotation's device
    '''
    cos = torch.cos(rotation)
    sin = torch.sin(rotation)
    # Use one dtype everywhere; previously ones/zeros/eye defaulted to float32 and
    # silently downcast float64 rotations when written into RT.
    dtype, device = cos.dtype, rotation.device
    batch = rotation.shape[0]

    cx, cy, cz = cos.unbind(-1)
    sx, sy, sz = sin.unbind(-1)
    ones = torch.ones(batch, dtype=dtype, device=device)
    zeros = torch.zeros(batch, dtype=dtype, device=device)

    def _rows(*rows):
        # Stack three lists of [B] entries into a [B, 3, 3] matrix.
        return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

    Rx_matrix = _rows([ones, zeros, zeros], [zeros, cx, -sx], [zeros, sx, cx])
    Ry_matrix = _rows([cy, zeros, sy], [zeros, ones, zeros], [-sy, zeros, cy])
    Rz_matrix = _rows([cz, -sz, zeros], [sz, cz, zeros], [zeros, zeros, ones])
    rotation_matrix = Rz_matrix @ Ry_matrix @ Rx_matrix

    # concatenate position
    RT = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).repeat(batch, 1, 1)
    RT[:, :3, :3] = rotation_matrix
    RT[:, :3, 3] = position
    return RT

class Cylinder():
    def __init__(self, dtype, device, height, top_radius, bottom_radius=None, top_radius_z=None,
                 bottom_radius_z=None, is_half=False, is_quarter=False, position=None,
                 rotation=None, rotation_order="XYZ"):
        """
        Build a (possibly tapered, elliptic, half or quarter) cylinder triangle mesh.

        A unit template with its axis along Y (y in [-1/2, 1/2]) is swept through
        256 segments, scaled per cap by the radii and height, and finally rotated and
        translated with `apply_transformation`.

        :param dtype: torch dtype of the generated vertices
        :param device: torch device of the generated vertices and faces
        :param height: height of the cylinder in the Y-axis direction
        :param top_radius: radius of the top surface of the cylinder in the X-axis direction
        :param bottom_radius: radius of the bottom surface in the X-axis direction (defaults to top_radius)
        :param top_radius_z: radius of the top surface in the Z-axis direction (defaults to top_radius)
        :param bottom_radius_z: radius of the bottom surface in the Z-axis direction (defaults to bottom_radius)
        :param is_half: sweep only half a revolution
        :param is_quarter: sweep only a quarter revolution (ignored when is_half is set)
        :param position: position (x, y, z) of the cylinder; defaults to the origin
        :param rotation: Euler angles (x, y, z) in radians; defaults to no rotation
        :param rotation_order: order in which the three axis rotations are applied

        Sets `self.vertices` (float, `dtype`) and `self.faces` (int64 vertex indices).
        """
        # Filling Missing Values (None sentinels avoid shared mutable default lists)
        if position is None:
            position = [0, 0, 0]
        if rotation is None:
            rotation = [0, 0, 0]
        if bottom_radius is None:
            bottom_radius = top_radius
        if top_radius_z is None:
            top_radius_z = top_radius
        if bottom_radius_z is None:
            bottom_radius_z = bottom_radius

        # Record Parameters
        self.dtype = dtype
        self.device = device
        self.height = height
        self.top_radius = top_radius
        self.bottom_radius = bottom_radius
        self.top_radius_z = top_radius_z
        self.bottom_radius_z = bottom_radius_z
        self.is_half = is_half
        self.is_quarter = is_quarter
        self.position = position
        self.rotation = rotation
        self.rotation_order = rotation_order

        # Unit template: index 0 = top centre, 1 = bottom centre, then for each sweep
        # step i a (top, bottom) rim pair at indices (2i + 2, 2i + 3).
        num_of_segment = 256
        vertices = [[0, 1 / 2, 0], [0, -1 / 2, 0]]
        for i in range(num_of_segment + 1):
            rotation_tmp = math.pi * 2 / num_of_segment * i
            if is_half:
                rotation_tmp = rotation_tmp / 2
            elif is_quarter:
                rotation_tmp = rotation_tmp / 4
            vertices.append([math.cos(rotation_tmp), 1 / 2, math.sin(rotation_tmp)])
            vertices.append([math.cos(rotation_tmp), -1 / 2, math.sin(rotation_tmp)])
        self.vertices = torch.tensor(vertices, dtype=dtype, device=device)

        # Two triangles spanning the axis and the first rim edge (start of the sweep).
        faces = [[1, 0, 3], [0, 2, 3]]
        for i in range(num_of_segment):
            faces.append([2*i+2, 2*i+4, 2*i+3])  # side wall
            faces.append([2*i+5, 2*i+3, 2*i+4])  # side wall
            faces.append([2*i+2, 0, 2*i+4])      # top cap
            faces.append([1, 2*i+3, 2*i+5])      # bottom cap
        # Two triangles spanning the axis and the last rim edge (end of the sweep).
        last_top = 2 * (num_of_segment + 1)
        faces.append([0, 1, last_top + 1])
        faces.append([last_top + 1, 0, last_top])
        # Face entries are vertex indices, so store them as int64.
        self.faces = torch.tensor(faces, dtype=torch.long, device=device)

        # Differentiable Deformation: torch.stack keeps tensor-valued dimensions
        # in the autograd graph (torch.tensor would detach them).
        def _scalar(value):
            return torch.as_tensor(value, dtype=dtype, device=device).reshape(())

        height_value = _scalar(height)
        scale = torch.stack([
            torch.stack([_scalar(top_radius), height_value, _scalar(top_radius_z)]),
            torch.stack([_scalar(bottom_radius), height_value, _scalar(bottom_radius_z)]),
        ])
        # Rows alternate top / bottom, matching the vertex layout above.
        scale = scale.repeat(self.vertices.shape[0] // 2, 1)
        self.vertices = self.vertices * scale

        # Global Transformation
        self.vertices = apply_transformation(self.vertices, position, rotation, rotation_order)

class Round_Switch():
    def __init__(self, number_of_switch, size, offset, offset_Z, switch_rotation, position=None,
                 rotation=None, dtype=torch.float32, device=None):
        """
        Round switch(es): one cylinder per switch, then a shared global transform.

        :param number_of_switch: sequence whose first entry is the number of switches
        :param size: size[0] is used as the cylinder radius, size[1] as its height
        :param offset: (x, y) offset of a single switch, or a sequence of (x, y)
                       offsets with one entry per switch
        :param offset_Z: sequence whose first entry is the z offset of every switch
        :param switch_rotation: angles in degrees; only the first entry is used, as an
                                extra rotation about X on top of pi / 2
        :param position: global position (x, y, z); defaults to the origin
        :param rotation: global Euler angles (x, y, z) in degrees; defaults to none
        :param dtype: torch dtype of the generated geometry
        :param device: torch device of the generated geometry (None = torch default)
        :raises ValueError: if fewer offsets than switches are supplied
        """
        if position is None:
            position = [0, 0, 0]
        if rotation is None:
            rotation = [0, 0, 0]

        # Process rotation param: degrees -> radians
        rotation = [x / 180 * math.pi for x in rotation]
        switch_rotation = [x / 180 * math.pi for x in switch_rotation]

        self.rotation = rotation
        self.position = position

        # Record Parameters
        self.number_of_switch = number_of_switch
        self.size = size
        self.offset = offset
        self.offset_Z = offset_Z
        self.switch_rotation = switch_rotation

        # `offset` is either one (x, y) pair or a sequence of pairs (one per switch).
        # NOTE(review): the per-switch layout is an assumption; the previous
        # locals()['offset_%d'] lookup referenced parameters that no longer exist.
        head = offset[0] if len(offset) else None
        nested = isinstance(head, (list, tuple)) or (torch.is_tensor(head) and head.dim() > 0)
        offsets = list(offset) if (nested or head is None) else [offset]
        num_switches = number_of_switch[0]
        if len(offsets) < num_switches:
            raise ValueError(
                f"Round_Switch: {num_switches} switches requested but {len(offsets)} offset(s) given"
            )

        # Instantiate component geometries
        vertices_list = []
        faces_list = []
        total_num_vertices = 0

        for i in range(num_switches):
            switch_offset = offsets[i]
            base_mesh_position = [switch_offset[0], switch_offset[1], offset_Z[0]]
            base_mesh_rotation = [math.pi / 2 + switch_rotation[0], 0, 0]
            # Cylinder(dtype, device, height, top_radius, bottom_radius, ...)
            self.base_mesh = Cylinder(dtype, device, size[1], size[0], size[0],
                                      position=base_mesh_position,
                                      rotation=base_mesh_rotation)
            vertices_list.append(self.base_mesh.vertices)
            # Shift face indices past the vertices of previously added switches.
            faces_list.append(self.base_mesh.faces + total_num_vertices)
            total_num_vertices += len(self.base_mesh.vertices)

        self.vertices = torch.cat(vertices_list, dim=0)
        self.faces = torch.cat(faces_list, dim=0)

        # Global Transformation: translate first, then rotate.
        self.vertices = apply_transformation(self.vertices, position, rotation, offset_first=True)

        self.semantic = 'Switch'

def Switch1 (dtype, device, size, offset, rotation, position) :
    '''
    Gripper pose (position + Euler angles) for a single round switch.

    size, offset -> accepted for interface compatibility; they do not affect the result
    rotation     -> Euler angles in units of pi (multiplied by pi below); assumed [3]
    position     -> switch position, tensor or Python sequence; assumed [3]
    return       -> [6]: `position` rotated by `rotation` (X, then Y, then Z),
                    followed by the rotation in radians with pi added about X
    '''
    rotation = torch.tensor(math.pi, dtype=dtype, device=device) * rotation
    if not torch.is_tensor(position):
        position = torch.tensor(position, dtype=dtype, device=device)

    # The gripper sits at the switch origin: translate the origin by `position`,
    # then rotate it. Inputs are passed in batched form ([1, 1, 3] vertices,
    # [1, 3] position/rotation) and the single resulting point is taken out.
    origin = torch.zeros((1, 1, 3), dtype=dtype, device=device)
    gripper_pos = apply_transformation(
        origin,
        position.reshape(1, 3),
        rotation.reshape(1, 3),
        offset_first=True,
    )[0, 0]

    gripper_rot = rotation + torch.tensor([math.pi, 0.0, 0.0], dtype=dtype, device=device)

    return torch.cat((gripper_pos, gripper_rot), dim=0)

def Switch (para):
    '''
    Gripper pose for a batch of switches from a packed parameter tensor.

    para   -> [batch_size, 7] laid out as [size, x, y, z, rx, ry, rz], with the
              rotation given in units of pi
    return -> [batch_size, 6]: the position shifted by size / 2 along Y, followed
              by the rotation in radians with pi added about X
    '''
    half_size = para[:, :1] / 2
    centre = para[:, 1:4]
    angles = para[:, 4:7] * math.pi

    # [0, size / 2, 0] per row.
    empty_column = torch.zeros_like(half_size)
    y_shift = torch.cat([empty_column, half_size, empty_column], dim=1)

    # pi about X, built in the dtype/device of the incoming parameters.
    flip_about_x = torch.tensor([math.pi, 0.0, 0.0], dtype=para.dtype, device=para.device)

    return torch.cat((centre + y_shift, angles + flip_about_x), dim=1)