import torch
import math

def get_rodrigues_matrix(axis, angle):
    '''
    Build batched rotation matrices about one fixed axis using Rodrigues' formula:
        R = cos(a) * I + sin(a) * K + (1 - cos(a)) * k k^T
    where K is the skew-symmetric cross-product matrix of the axis k.

    axis  = 3 numbers, e.g. [1,0,0] / [0,1,0] / [0,0,1]. The axis is used as
            given (it is NOT normalised), so it must be unit length for the
            result to be a proper rotation.
    angle = [batch_size] or [batch_size, 1], in radians
    return -> [batch_size, 3, 3], same dtype and device as angle
    '''
    dtype, device = angle.dtype, angle.device

    # Plain Python floats keep the constant matrices cheap to build and let
    # `axis` be a list, tuple or 1-D tensor alike.
    kx, ky, kz = (float(c) for c in axis)

    k = torch.tensor([kx, ky, kz], dtype=dtype, device=device)
    identity = torch.eye(3, dtype=dtype, device=device)             # [3, 3]
    skew = torch.tensor(
        [
            [0.0, -kz, ky],
            [kz, 0.0, -kx],
            [-ky, kx, 0.0],
        ],
        dtype=dtype,
        device=device,
    )                                                               # [3, 3]
    outer = k.unsqueeze(1) * k.unsqueeze(0)                         # k k^T -> [3, 3]

    # Flatten so that [batch_size] and [batch_size, 1] behave the same; an
    # un-flattened [batch_size, 1] would broadcast to [batch_size, batch_size, 3, 3].
    theta = angle.reshape(-1)
    cos_angle = torch.cos(theta)[:, None, None]                     # [batch_size, 1, 1]
    sin_angle = torch.sin(theta)[:, None, None]                     # [batch_size, 1, 1]

    # The [3, 3] constants broadcast over the batch dimension.
    return cos_angle * identity + sin_angle * skew + (1 - cos_angle) * outer

def apply_transformation(vertices, position, rotation, rotation_order="XYZ", offset_first=False):
    '''
    Rotate and translate a batch of vertex sets.

    vertices -> [N, V, 3]   (V is arbitrary; Block passes V = 8)
    position -> [N, 3]      translation added to every vertex of the item
    rotation -> [N, 3]      per-axis angles in radians, columns = (X, Y, Z)
    rotation_order -> string of axis letters from "X", "Y", "Z". Rotations are
                      applied left to right, so "XYZ" yields Rz @ Ry @ Rx @ v.
                      Any other letter raises KeyError.
    offset_first -> if True translate before rotating, otherwise after.
    return -> [N, V, 3]
    '''
    # [N, 3] -> [N, 1, 3]; broadcasting covers any number of vertices.
    offset = position.unsqueeze(1)

    # process position first
    if offset_first:
        vertices = vertices + offset

    # process rotation
    axis_spec = {
        "X": ([1, 0, 0], 0),
        "Y": ([0, 1, 0], 1),
        "Z": ([0, 0, 1], 2),
    }
    rot_mat = {}    # built on demand so unused axes cost nothing
    for s in rotation_order:
        if s not in rot_mat:
            axis, column = axis_spec[s]
            rot_mat[s] = get_rodrigues_matrix(axis, rotation[:, column])
        # Vertices are row vectors, hence v @ R^T rather than R @ v.
        vertices = torch.matmul(vertices, rot_mat[s].transpose(1, 2))

    # process position second
    if not offset_first:
        vertices = vertices + offset

    return vertices

def transformation_matrix(position, rotation):
    '''
    Get RT: homogeneous 4x4 transforms with rotation Rz @ Ry @ Rx and the
    given translation in the last column.

    position -> [N, 3]   (Block passes N = batch_size * 3)
    rotation -> [N, 3]   per-axis angles in radians, columns = (X, Y, Z)
    return   -> [N, 4, 4], in the floating dtype of rotation
    '''
    count = rotation.shape[0]
    device = rotation.device

    cos_x, sin_x = torch.cos(rotation[:, 0]), torch.sin(rotation[:, 0])
    cos_y, sin_y = torch.cos(rotation[:, 1]), torch.sin(rotation[:, 1])
    cos_z, sin_z = torch.cos(rotation[:, 2]), torch.sin(rotation[:, 2])

    # Build the constants in the same dtype as the trig terms; otherwise a
    # float64 input would be silently written into a float32 result.
    dtype = cos_x.dtype
    ones = torch.ones(count, dtype=dtype, device=device)
    zeros = torch.zeros(count, dtype=dtype, device=device)

    # calculate Rotation matrix (inner stack = one row, outer stack = rows)
    Rx_matrix = torch.stack([
        torch.stack([ones, zeros, zeros], dim=-1),
        torch.stack([zeros, cos_x, -sin_x], dim=-1),
        torch.stack([zeros, sin_x, cos_x], dim=-1),
    ], dim=-2)
    Ry_matrix = torch.stack([
        torch.stack([cos_y, zeros, sin_y], dim=-1),
        torch.stack([zeros, ones, zeros], dim=-1),
        torch.stack([-sin_y, zeros, cos_y], dim=-1),
    ], dim=-2)
    Rz_matrix = torch.stack([
        torch.stack([cos_z, -sin_z, zeros], dim=-1),
        torch.stack([sin_z, cos_z, zeros], dim=-1),
        torch.stack([zeros, zeros, ones], dim=-1),
    ], dim=-2)
    rotation_matrix = Rz_matrix @ Ry_matrix @ Rx_matrix             # [N, 3, 3]

    # concatenate position: start from identity so the bottom row is [0, 0, 0, 1]
    RT = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).repeat(count, 1, 1)
    RT[:, :3, :3] = rotation_matrix
    RT[:, :3, -1] = position
    return RT

def Block(percepture, return_gripper_pose=False):
    '''
    Decode 3 box-shaped parts per sample into pose matrices and box meshes.

    percepture -> [batch_size, 27], read as 3 parts x 9 values:
                  [length, height, width, position[3], rotation[3]]
    return_gripper_pose -> if True, additionally return the gripper pose
                  (see below). Defaults to False, which returns only `result`.

    return -> result [batch_size, 228], concatenated along the last dim as
                  RT       [batch_size, 3 * 16]  flattened 4x4 pose per part
                  vertices [batch_size, 3 * 24]  8 transformed box corners per part
                  faces    [batch_size, 3 * 36]  12 triangles per part, stored as
                                                 float vertex indices so they can be
                                                 concatenated with the other fields
              and, when return_gripper_pose is True, a tuple
              (result, gripper_pose) with gripper_pose -> [batch_size, 18]:
              per part [position[3], rotation[3]], positioned height / 2 along
              the part's local +y axis and with pi added to the x rotation.
    '''
    batch_size = percepture.shape[0]
    parameters = torch.reshape(percepture, (batch_size * 3, 9))
    part_count = parameters.shape[0]                    # batch_size * 3
    dtype, device = parameters.dtype, parameters.device

    # 1. Part pose
    # NOTE(review): rotation angles are forced non-negative here; the reason is
    # not visible in this file -- confirm it is intended.
    Rotation = torch.abs(parameters[:, -3:])
    Position = parameters[:, -6:-3]
    RT = transformation_matrix(Position, Rotation)      # [batch_size * 3, 4, 4]

    # 2. Affordance
    # Unit cube centred at the origin -> [8, 3]
    unit_cube = torch.tensor([
            [-1 / 2, 1 / 2, 1 / 2],
            [1 / 2, 1 / 2, 1 / 2],
            [-1 / 2, 1 / 2, -1 / 2],
            [1 / 2, 1 / 2, -1 / 2],
            [-1 / 2, -1 / 2, 1 / 2],
            [1 / 2, -1 / 2, 1 / 2],
            [-1 / 2, -1 / 2, -1 / 2],
            [1 / 2, -1 / 2, -1 / 2]
    ], dtype=dtype, device=device)

    # faces -> [batch_size * 3, 12, 3]
    faces = torch.tensor([
            [0, 1, 2], [1, 3, 2],
            [4, 6, 5], [5, 6, 7],
            [0, 4, 5], [0, 5, 1],
            [2, 7, 6], [2, 3, 7],
            [0, 6, 4], [0, 2, 6],
            [1, 5, 7], [1, 7, 3]
    ], dtype=dtype, device=device).unsqueeze(0).repeat(part_count, 1, 1)

    # resize: scale x / y / z of every corner by (length, height, width)
    sizes = parameters[:, :3]                                       # [batch_size * 3, 3]
    vertices = unit_cube.unsqueeze(0) * sizes.unsqueeze(1)          # [batch_size * 3, 8, 3]

    vertices = apply_transformation(vertices=vertices, position=Position, rotation=Rotation)        # vertices -> [batch_size * 3, 8, 3]

    # 3. result
    # Explicit sizes (instead of -1) keep the reshape well-defined for any batch size.
    result = torch.cat((
        torch.reshape(RT, (batch_size, 3 * 16)),
        torch.reshape(vertices, (batch_size, 3 * 24)),
        torch.reshape(faces, (batch_size, 3 * 36)),
    ), dim=-1)                                                      # [batch_size, 228]

    if not return_gripper_pose:
        return result

    # 4. Gripper pose (global)
    height = parameters[:, 1].unsqueeze(1)                          # [batch_size * 3, 1]
    gripper_offset_local = torch.cat([
        torch.zeros_like(height),             # x offset
        height / 2,                           # y offset
        torch.zeros_like(height)              # z offset
    ], dim=1)                                                       # [batch_size * 3, 3]

    # apply rotation to offset; the rotation part of RT does not depend on the
    # position, so it is reused instead of being recomputed. bmm needs both
    # operands in the same dtype, hence the cast.
    R_block = RT[:, :3, :3].to(dtype)                               # [batch_size * 3, 3, 3]
    gripper_offset_global = torch.bmm(R_block, gripper_offset_local.unsqueeze(2)).squeeze(2)  # [batch_size * 3, 3]

    gripper_pos = Position + gripper_offset_global                  # [batch_size * 3, 3]
    gripper_rot = Rotation + torch.tensor([math.pi, 0.0, 0.0], dtype=dtype, device=device)  # [batch_size * 3, 3]

    gripper_pose = torch.cat((gripper_pos, gripper_rot), dim=1)     # [batch_size * 3, 6]
    gripper_pose = torch.reshape(gripper_pose, (batch_size, 3 * 6)) # [batch_size, 18]

    return result, gripper_pose