import trimesh
import numpy as np
import torch

def read_vertex_laplacian(mesh):
    """Build a padded neighbour table for uniform Laplacian coordinates.

    Arguments:
        mesh: object exposing ``vertex_neighbors`` (e.g. a ``trimesh.Trimesh``),
            a sequence holding one collection of neighbour indices per vertex.

    Returns:
        np.ndarray of dtype int32 and shape (num_vert, max_degree + 1).
        Row i holds the neighbour indices of vertex i, padded with -1 up to
        ``max_degree`` entries; the last column holds vertex i's neighbour
        count. An empty mesh yields an array of shape (0, 1).
    """
    vertex_neighbors_list = mesh.vertex_neighbors
    num_vert = len(vertex_neighbors_list)

    length_np = np.array([len(nbrs) for nbrs in vertex_neighbors_list], dtype='int32')
    # ``initial`` keeps max() well-defined for a mesh with no vertices.
    max_len = int(length_np.max(initial=0))

    # Pad with -1: give_laplacian_coordinates appends a zero vertex at the end
    # of the vertex array, and a -1 index selects that zero vertex.
    neighbors_np = np.full((num_vert, max_len), -1, dtype='int32')
    for row, vertex_neighbors in enumerate(vertex_neighbors_list):
        neighbors_np[row, :len(vertex_neighbors)] = vertex_neighbors

    lape_idx = np.concatenate([neighbors_np, length_np[:, None]], axis=1)
    return lape_idx

def give_laplacian_coordinates(pred, lape_idx):
    r''' Returns the uniform Laplacian coordinates of a batch of vertex sets.

        For every vertex v_i with neighbour set N(i) this computes
        delta_i = v_i - (1 / |N(i)|) * sum_{j in N(i)} v_j.

        ``lape_idx`` uses the layout built by ``read_vertex_laplacian``
        (converted to a torch tensor): each row holds one vertex's neighbour
        indices padded with -1, and the last column holds its neighbour count.
        A zero vertex is appended to ``pred`` so the -1 padding selects it and
        adds nothing to the neighbour sum.

    Arguments:
        pred (tensor): vertex predictions of shape (batch, num_vert, dim)
        lape_idx (tensor): integer tensor of shape (num_vert, max_neighbours + 1)

    Returns:
        tensor of shape (batch, num_vert, dim), same dtype as ``pred``.
        A vertex with zero neighbours gets delta_i = v_i (its neighbour mean
        is taken as zero) instead of NaN.
    '''
    batch_size, num_vert, dim = pred.shape
    device = pred.device

    # Zero vertex at index num_vert; negative (-1) indices wrap around to it.
    # new_zeros matches pred's dtype and device.
    zero_vertex = pred.new_zeros(batch_size, 1, dim)
    vertex = torch.cat([pred, zero_vertex], dim=1)

    lape_idx = lape_idx.to(device)
    indices = lape_idx[:, :-1].long()
    counts = lape_idx[:, -1].to(pred.dtype)
    # Without the clamp an isolated vertex gives 0 * (1 / 0) = NaN, which
    # would poison any loss summed over all vertices.
    weights = torch.reciprocal(counts.clamp(min=1))

    neighbour_sum = vertex[:, indices, :].sum(dim=2)  # (batch, num_vert, dim)
    laplace = pred - neighbour_sum * weights[:, None]
    assert laplace.shape == (batch_size, num_vert, dim)
    return laplace

################
# As-Rigid-As-Possible (ARAP) style regularization on Laplacian coordinates

def batch_euler2Rotation(angles):
    '''
    Convert Euler angles to rotation matrices, both as torch tensors.

    With (x, y, z) = angles[..., 0], angles[..., 1], angles[..., 2] (radians),
    the result is R = Rz(z) @ Ry(y) @ Rx(x): rotate about the X axis first,
    then Y, then Z (fixed axes).

    :param angles: torch.tensor of shape [..., 3], e.g. [batch, num_vert, 3]
    :return: torch.tensor of shape [..., 3, 3] with the same dtype and device
             as ``angles``, so it can be matmul'ed with tensors of that dtype
    '''
    sinX, cosX = torch.sin(angles[..., 0]), torch.cos(angles[..., 0])
    sinY, cosY = torch.sin(angles[..., 1]), torch.cos(angles[..., 1])
    sinZ, cosZ = torch.sin(angles[..., 2]), torch.cos(angles[..., 2])

    # Row-major entries of Rz @ Ry @ Rx.
    entries = [
        cosY * cosZ, cosZ * sinX * sinY - cosX * sinZ, sinX * sinZ + cosX * cosZ * sinY,
        cosY * sinZ, cosX * cosZ + sinX * sinY * sinZ, cosX * sinY * sinZ - cosZ * sinX,
        -sinY, cosY * sinX, cosX * cosY,
    ]
    # Stacking (instead of writing into a float32 zeros buffer) keeps the
    # input dtype and builds the graph without in-place writes.
    return torch.stack(entries, dim=-1).reshape(*angles.shape[:-1], 3, 3)

def lapla_term(original_pos, deformed_pos, euler, vertex_adjacency):
    """ARAP-style penalty between deformed and rotated rest-pose Laplacians.

    Only ``vertex_adjacency[0]`` is used as the neighbour table, and it is
    applied to every batch entry. (Assumes all samples share one mesh
    topology; confirm with callers.)

    Arguments:
        original_pos (tensor): rest-pose vertices, (batch, num_vert, 3)
        deformed_pos (tensor): deformed vertices, same shape as original_pos
        euler (tensor): per-vertex Euler angles, (batch, num_vert, 3)
        vertex_adjacency (tensor): rank-3 stack of neighbour tables

    Returns:
        scalar tensor: sum over batch and vertices of
        || delta_deformed - R(euler) @ delta_original ||^2
    """
    shared_idx = vertex_adjacency[0, :, :].long()
    rest_delta = give_laplacian_coordinates(original_pos, shared_idx)
    deformed_delta = give_laplacian_coordinates(deformed_pos, shared_idx)

    rotations = batch_euler2Rotation(euler)  # (batch, num_vert, 3, 3)
    rotated_rest = torch.matmul(rotations, rest_delta.unsqueeze(-1)).squeeze(-1)

    residual = deformed_delta - rotated_rest
    squared_norm = (residual ** 2).sum(dim=-1)  # per-vertex squared error
    return squared_norm.sum()