import torch


def calc_C3D4_element_features(elements, new_elements_pos, old_elements_pos, _safe_zero, device):
    """Compute volume and surface-area features for 4-node tetrahedral (C3D4) elements.

    For each of the two node-position sets (new first, then old), two
    per-element quantities are computed:

    * volume: ``|(p0 - p3) . ((p1 - p3) x (p2 - p3))| / 6``;
    * surface area: sum of the four triangular face areas, each obtained from
      Heron's formula with the radicand clamped below by ``_safe_zero``.

    Args:
        elements: element connectivity; only ``elements.shape[0]`` (the element
            count) is used here.
        new_elements_pos: per-element node coordinates indexed as
            ``pos[:, node, xyz]`` with nodes 0..3 (presumably shape (E, 4, 3)).
        old_elements_pos: same layout as ``new_elements_pos``.
        _safe_zero: tensor used as the lower clamp of Heron's radicand so that
            round-off never feeds a negative value into ``sqrt``.
        device: device on which the surface-area accumulators are allocated.

    Returns:
        Tensor of shape (E, 4) with columns
        [new volume, new surface area, old volume, old surface area].
    """
    n_elements = elements.shape[0]
    # Node triples forming the four triangular faces of a tetrahedron.
    faces = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))

    def _edge_length(pos, u, v):
        # Euclidean distance between nodes u and v of every element.
        return torch.sqrt(torch.sum((pos[:, u] - pos[:, v]) ** 2, dim=1))

    def _volume_and_surface(pos):
        # Return [(E, 1) volume, (E, 1) surface area] for one position set.
        apex = pos[:, 3]
        volume = 1.0 / 6 * torch.abs(torch.sum(torch.mul(
            pos[:, 0] - apex,
            torch.cross(pos[:, 1] - apex, pos[:, 2] - apex, dim=1)), dim=1))

        # A default-dtype (float32) accumulator would silently round float64
        # contributions during the in-place add; promoting keeps float64 input
        # in float64 while leaving float32/half inputs exactly as before.
        acc_dtype = torch.promote_types(pos.dtype, torch.get_default_dtype())
        surface = torch.zeros((n_elements, 1), dtype=acc_dtype, device=device)
        for i, j, k in faces:
            edge_x = _edge_length(pos, i, j)
            edge_y = _edge_length(pos, i, k)
            edge_z = _edge_length(pos, j, k)
            s = (edge_x + edge_y + edge_z) / 2
            # Heron's formula; clamp guards against tiny negative radicands.
            surface[:, 0] += torch.sqrt(torch.maximum(_safe_zero, s * (s - edge_x) * (s - edge_y) * (s - edge_z)))
        return [volume.reshape(-1, 1), surface]

    features = _volume_and_surface(new_elements_pos) + _volume_and_surface(old_elements_pos)
    return torch.concat(features, dim=1)
    
def calc_C3D4_face_features(new_faces_pos, old_faces_pos, _safe_zero):
    """Build per-face geometric features for triangular faces at two configurations.

    Args:
        new_faces_pos: face vertex coordinates indexed as ``pos[:, vertex, xyz]``
            with three vertices per face (presumably shape (F, 3, 3)).
        old_faces_pos: same layout as ``new_faces_pos``.
        _safe_zero: tensor lower clamp for Heron's radicand.

    Returns:
        Tensor of shape (F, 8). For the new positions and then the old ones, it
        holds the face area (Heron's formula) followed by the three edge lengths
        for vertex pairs (0, 1), (0, 2), (1, 2).
    """
    vertex_pairs = ((0, 1), (0, 2), (1, 2))

    def _area_and_edges(pos):
        # Edge lengths as an (F, 3) tensor, one column per vertex pair.
        lengths = []
        for a, b in vertex_pairs:
            diff = pos[:, a, :] - pos[:, b, :]
            lengths.append(torch.sqrt(torch.sum(diff ** 2, dim=-1)).reshape(-1, 1))
        edges = torch.concat(lengths, dim=1)

        half = torch.sum(edges, dim=1) / 2
        radicand = half * (half - edges[:, 0]) * (half - edges[:, 1]) * (half - edges[:, 2])
        area = torch.sqrt(torch.maximum(_safe_zero, radicand).reshape(-1, 1))
        return area, edges

    new_area, new_edges = _area_and_edges(new_faces_pos)
    old_area, old_edges = _area_and_edges(old_faces_pos)
    return torch.concat([new_area, new_edges, old_area, old_edges], dim=1)

def calc_tri_area(a, b, c, _safe_zero):
    """Return the area of each triangle (a, b, c) via Heron's formula.

    Args:
        a, b, c: vertex coordinates; each is reshaped to (N, 3).
        _safe_zero: tensor lower clamp applied to Heron's radicand before sqrt.

    Returns:
        Tensor of shape (N, 1) with one area per triangle.
    """
    p, q, r = (v.reshape(-1, 3) for v in (a, b, c))

    def _length(u, v):
        return torch.sqrt(torch.sum((u - v) ** 2, dim=-1))

    # Columns ordered as edges pq, pr, qr.
    edges = torch.stack([_length(p, q), _length(p, r), _length(q, r)], dim=1)
    half = torch.sum(edges, dim=1) / 2
    radicand = half * (half - edges[:, 0]) * (half - edges[:, 1]) * (half - edges[:, 2])
    return torch.sqrt(torch.maximum(_safe_zero, radicand).reshape(-1, 1))


def calc_C3D8_face_features(new_faces_pos, old_faces_pos, _safe_zero):
    """Compute area and perimeter features for quadrilateral faces at two configurations.

    Args:
        new_faces_pos: face vertex coordinates indexed as ``pos[:, vertex, xyz]``
            with four vertices per face (presumably shape (F, 4, 3)).
        old_faces_pos: same layout as ``new_faces_pos``.
        _safe_zero: tensor lower clamp forwarded to ``calc_tri_area``.

    Returns:
        Tensor of shape (F, 4) with columns
        [new area, new perimeter, old area, old perimeter].
    """
    sides = ((0, 1), (1, 2), (2, 3), (3, 0))
    columns = []
    for pos in (new_faces_pos, old_faces_pos):
        v0, v1, v2, v3 = pos[:, 0], pos[:, 1], pos[:, 2], pos[:, 3]
        # Quad area as two triangles sharing the 0-2 diagonal.
        area = calc_tri_area(v0, v1, v2, _safe_zero) + calc_tri_area(v0, v2, v3, _safe_zero)
        side_lengths = torch.concat(
            [torch.sqrt(torch.sum((pos[:, u, :] - pos[:, w, :]) ** 2, dim=-1)).reshape(-1, 1) for u, w in sides],
            dim=1,
        )
        perimeter = torch.sum(side_lengths, dim=1).reshape(-1, 1)
        columns.extend([area, perimeter])
    return torch.concat(columns, dim=1)

def calc_C3D4_area(a, b, c, d):
    """Return the unsigned volume of each tetrahedron with vertices (a, b, c, d).

    Despite its name, this returns a volume, ``|(a - d) . ((b - d) x (c - d))| / 6``,
    and not a surface area.

    Args:
        a, b, c, d: vertex coordinates; each is reshaped to (N, 3).

    Returns:
        Tensor of shape (N,) with one volume per tetrahedron.
    """
    apex = d.reshape(-1, 3)
    edge_a = a.reshape(-1, 3) - apex
    edge_b = b.reshape(-1, 3) - apex
    edge_c = c.reshape(-1, 3) - apex
    triple = torch.mul(edge_a, torch.cross(edge_b, edge_c, dim=1)).sum(dim=1)
    return 1.0 / 6 * triple.abs()
                        
def calc_C3D8_element_features(elements, new_elements_pos, old_elements_pos, _safe_zero, device):
    """Compute volume and surface-area features for 8-node hexahedral (C3D8) elements.

    For each of the two node-position sets (new first, then old):

    * volume: sum of the volumes (via ``calc_C3D4_area``) of five tetrahedra,
      namely four corner tetrahedra at nodes 0, 2, 5, 7 plus a central one;
    * surface area: sum of twelve triangle areas (via ``calc_tri_area``), two
      per quadrilateral face.

    The index tables below treat nodes 0-3 and 4-7 as two opposite faces, with
    node i joined to node i + 4 by the side faces. Confirm that this matches
    the mesh's node-ordering convention.

    Args:
        elements: element connectivity; only ``elements.shape[0]`` (the element
            count) is used here.
        new_elements_pos: per-element node coordinates indexed as
            ``pos[:, node, xyz]`` with nodes 0..7 (presumably shape (E, 8, 3)).
        old_elements_pos: same layout as ``new_elements_pos``.
        _safe_zero: tensor lower clamp forwarded to ``calc_tri_area``.
        device: device on which the accumulators are allocated.

    Returns:
        Tensor of shape (E, 4) with columns
        [new volume, new surface area, old volume, old surface area].
    """
    n_elements = elements.shape[0]
    # Five-tetrahedron decomposition of the hexahedron.
    tetrahedra = ((0, 1, 3, 4), (1, 2, 3, 6), (1, 4, 5, 6), (3, 4, 6, 7), (1, 3, 4, 6))
    # Each of the six quadrilateral faces is split into two triangles.
    triangles = ((0, 1, 2), (0, 2, 3),
                 (4, 6, 5), (4, 6, 7),
                 (0, 1, 5), (0, 5, 4),
                 (1, 2, 6), (1, 6, 5),
                 (2, 3, 7), (2, 7, 6),
                 (0, 3, 7), (0, 7, 4))

    def _volume_and_surface(pos):
        # Return [(E, 1) volume, (E, 1) surface area] for one position set.
        # A default-dtype (float32) accumulator would silently round float64
        # contributions during the in-place add; promoting keeps float64 input
        # in float64 while leaving float32/half inputs exactly as before.
        acc_dtype = torch.promote_types(pos.dtype, torch.get_default_dtype())

        volume = torch.zeros((n_elements, 1), dtype=acc_dtype, device=device)
        for p, q, r, s in tetrahedra:
            volume[:, 0] += calc_C3D4_area(pos[:, p], pos[:, q], pos[:, r], pos[:, s])

        surface = torch.zeros((n_elements, 1), dtype=acc_dtype, device=device)
        for p, q, r in triangles:
            surface[:, 0] += calc_tri_area(pos[:, p], pos[:, q], pos[:, r], _safe_zero).reshape(-1)
        return [volume, surface]

    features = _volume_and_surface(new_elements_pos) + _volume_and_surface(old_elements_pos)
    return torch.concat(features, dim=1)