import torch
import os
import numpy as np
from torch import nn
import opt_einsum as oe
from math import factorial, ceil
from torch.autograd import gradcheck

import time

# from pynvml import *
# nvmlInit()



if __name__ == "__main__":
    # Demo mesh: the 8 corners of the unit cube and 12 triangles (two per
    # cube side) indexing into them.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Build the tensors directly on the target device.  Calling `.to(device)`
    # on a tensor that already has requires_grad=True returns a *non-leaf*
    # copy on CUDA, so `v_matrix.grad` would never be populated there.
    v_matrix = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
        dtype=torch.float64,
        device=device,
        requires_grad=True,
    )
    f_matrix = torch.tensor(
        [[0, 2, 4], [4, 2, 6], [0, 1, 3], [0, 3, 2], [0, 5, 1], [0, 4, 5],
         [1, 7, 3], [1, 5, 7], [4, 7, 5], [4, 6, 7], [7, 6, 2], [7, 2, 3]],
        dtype=torch.long,
        device=device,
    )



# Maximum number of faces processed per loop iteration (chunk size) in the
# forward and backward passes; decrease it if running out of memory.  It is
# also the face-axis size used when planning the contraction expressions.
face_limit_loop = 1000

# Number of index values per moment axis: the weight tables and the moment
# tensor are max_m x max_m x max_m, indexed 0..max_m-1 along each axis.  This
# module-level value also fixes the operand shapes used when planning the
# contraction expressions.
max_m = 15
def compute_f_ijk(max_m, device):
    """Build the two factorial weight tables indexed by (i, j, k).

    Args:
        max_m: Number of index values per axis; both tables have shape
            ``(max_m, max_m, max_m)``.
        device: Torch device the resulting tensors are moved to.

    Returns:
        A pair ``(f_ijk_s_p, f_ijk_p_s_2)`` of float64 tensors where, with
        ``s = i + j + k``::

            f_ijk_s_p[i, j, k]   = s! / (i! * j! * k!)
            f_ijk_p_s_2[i, j, k] = 1 / f_ijk_s_p[i, j, k] / (s + 1) / (s + 2)
    """
    shape = (max_m, max_m, max_m)
    multinomial = np.zeros(shape, dtype=np.float64)
    inverse_weight = np.zeros(shape, dtype=np.float64)

    # np.ndindex walks every (i, j, k) triple as plain Python ints, which is
    # what math.factorial requires.
    for idx in np.ndindex(*shape):
        a, b, c = idx
        order = a + b + c
        multinomial[idx] = (factorial(order) / (factorial(a) * factorial(b) * factorial(c)))
        inverse_weight[idx] = 1 / multinomial[idx] / (order + 1) / (order + 2)

    def on_device(table):
        # Convert a NumPy table to a float64 tensor living on `device`.
        return torch.tensor(table).double().to(device)

    return on_device(multinomial), on_device(inverse_weight)
    



# indexer_i = torch.arange(max_m)
# indexer_i = indexer_i.to(device)

# Operand shapes used to plan the contractions below.  The face axis has size
# face_limit_loop, each moment axis has size max_m, and the leading 3 in
# _coord_face_cube is the coordinate axis 'p' of the subscripts.
_cube = (max_m, max_m, max_m)
_face_cube = (face_limit_loop,) + _cube
_coord_face_cube = (3,) + _face_cube
_vertex_coord_face_cube = (3, 3) + _face_cube

# Element-wise product of three per-axis power tables with an (i, j, k) weight table.
C_construct_oe = oe.contract_expression(
    'eijk,eijk,eijk,ijk->eijk',
    _face_cube, _face_cube, _face_cube, _cube,
    optimize='optimal',
)
# Per-face convolution over the (i, j, k) indices, driven by three copies of
# an index-sum kernel supplied by the caller.
conv_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, elmn -> eabc',
    _face_cube, _cube, _cube, _cube, _face_cube,
    optimize='optimal',
)

norm_P_derivative_oe = oe.contract_expression(
    'PFap,FP,F->Fap',
    (3, face_limit_loop, 3, 3), (face_limit_loop, 3), [face_limit_loop],
    optimize='optimal',
)


## Derivatives
# Same product as C_construct_oe with an extra 1-D factor along i, j or k.
C_grad_conv_i_oe = oe.contract_expression(
    'i,eijk,eijk,eijk,ijk->eijk',
    [max_m], _face_cube, _face_cube, _face_cube, _cube,
    optimize='optimal',
)
C_grad_conv_j_oe = oe.contract_expression(
    'j,eijk,eijk,eijk,ijk->eijk',
    [max_m], _face_cube, _face_cube, _face_cube, _cube,
    optimize='optimal',
)
C_grad_conv_k_oe = oe.contract_expression(
    'k,eijk,eijk,eijk,ijk->eijk',
    [max_m], _face_cube, _face_cube, _face_cube, _cube,
    optimize='optimal',
)

# Convolutions where one operand carries the extra coordinate axis 'p'.
S_deriv_conv_1_oe = oe.contract_expression(
    'peijk, ail, bjm, ckn, elmn -> peabc',
    _coord_face_cube, _cube, _cube, _cube, _face_cube,
    optimize='optimal',
)
S_deriv_conv_2_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, elmn -> eabc',
    _face_cube, _cube, _cube, _cube, _face_cube,
    optimize='optimal',
)
S_deriv_conv_3_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, pelmn -> peabc',
    _face_cube, _cube, _cube, _cube, _coord_face_cube,
    optimize='optimal',
)
S_deriv_conv_4_oe = oe.contract_expression(
    'peijk, ail, bjm, ckn, elmn -> peabc',
    _coord_face_cube, _cube, _cube, _cube, _face_cube,
    optimize='optimal',
)
S_deriv_conv_5_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, pelmn -> peabc',
    _face_cube, _cube, _cube, _cube, _coord_face_cube,
    optimize='optimal',
)
S_deriv_conv_6_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, pelmn -> peabc',
    _face_cube, _cube, _cube, _cube, _coord_face_cube,
    optimize='optimal',
)

S_deriv_conv_7_oe = oe.contract_expression(
    'ijk,apeijk->apeijk',
    _cube, _vertex_coord_face_cube,
    optimize='optimal',
)
M_ijk_F_derivative_1_oe = oe.contract_expression(
    'F,apFijk->Fapijk',
    [face_limit_loop], _vertex_coord_face_cube,
    optimize='optimal',
)
M_ijk_F_derivative_2_oe = oe.contract_expression(
    'Fap,Fijk->Fapijk',
    (face_limit_loop, 3, 3), _face_cube,
    optimize='optimal',
)
# loss_gradients_phi_oe = oe.contract_expression('apFijk,ijk->Fap',(3,3,face_limit_loop,max_m,max_m,max_m),(max_m,max_m,max_m),optimize='optimal')
# collect_grads_oe = oe.contract_expression('vf,fpijk->vpijk',(len(v_matrix),face_limit_loop),(face_limit_loop,3,max_m,max_m,max_m),optimize='optimal')
grad_prod_oe = oe.contract_expression(
    'abijk,ijk->ab',
    (3*face_limit_loop, 3) + _cube, _cube,
    optimize='optimal',
)




def _contraction_kernel(max_m, device):
    """Return the float64 kernel K[a, i, l] that is 1 where i + l == a, else 0."""
    inds = torch.arange(max_m, device=device)
    pair_sum = (inds.unsqueeze(0) + inds.unsqueeze(1)).unsqueeze(0)
    return (pair_sum == inds.unsqueeze(-1).unsqueeze(-1)).double()


def _axis_powers(points, exp_x, exp_y, exp_z):
    """Return x**exp_x, y**exp_y, z**exp_z laid out to broadcast as [F, i, j, k]."""
    return (
        torch.pow(points[:, 0].view(-1, 1, 1, 1), exp_x.view(1, -1, 1, 1)),
        torch.pow(points[:, 1].view(-1, 1, 1, 1), exp_y.view(1, 1, -1, 1)),
        torch.pow(points[:, 2].view(-1, 1, 1, 1), exp_z.view(1, 1, 1, -1)),
    )


def _monomial_tensor(points, exps, f_ijk_s_p):
    """Return C[F, i, j, k] = f_ijk_s_p[i, j, k] * x**i * y**j * z**k for one vertex per face."""
    return C_construct_oe(*_axis_powers(points, exps, exps, exps), f_ijk_s_p)


def _monomial_tensor_grads(points, exps, exps_minus_one, f_ijk_s_p):
    """Return the x, y and z derivatives of the monomial tensor, stacked as [p, F, i, j, k]."""
    # d/dx x**i = i * x**(i-1); `exps_minus_one` is clamped at 0 so that the
    # i == 0 entry evaluates 0 * x**0 instead of 0 * x**-1.
    return torch.stack([
        C_grad_conv_i_oe(exps, *_axis_powers(points, exps_minus_one, exps, exps), f_ijk_s_p),
        C_grad_conv_j_oe(exps, *_axis_powers(points, exps, exps_minus_one, exps), f_ijk_s_p),
        C_grad_conv_k_oe(exps, *_axis_powers(points, exps, exps, exps_minus_one), f_ijk_s_p),
    ], dim=0)


class GeometricMoment(torch.autograd.Function):
    """Surface moment tensor of a triangle mesh with a hand-written backward pass.

    For every face with vertices A, B, C the forward pass builds the monomial
    tensors C1, C2, C3 of the three vertices, convolves them over the
    (i, j, k) indices into ``S``, and accumulates
    ``S * f_ijk_p_s_2 * |(A - C) x (B - C)|`` over all faces.  Faces are
    processed in chunks of at most ``face_limit_loop``.

    Use as ``GeometricMoment.apply(f_matrix, v_matrix, max_m, device)``.
    """

    @staticmethod
    def forward(ctx, f_matrix, v_matrix, max_m, device):
        """Compute the moment tensor.

        Args:
            ctx: Autograd context; receives ``max_m``, ``device`` and the
                tensors needed by ``backward``.
            f_matrix: Integer tensor of per-face vertex indices, indexed as
                ``[face, corner]`` with three corners per face.
            v_matrix: Vertex coordinates, indexed as ``[vertex, xyz]``.
            max_m: Number of index values per moment axis.
            device: Torch device on which the computation runs.

        Returns:
            float64 tensor of shape ``[max_m, max_m, max_m]``.  An empty
            ``f_matrix`` yields all zeros.
        """
        f_ijk_s_p, f_ijk_p_s_2 = compute_f_ijk(max_m, device)

        ctx.max_m = max_m
        ctx.device = device

        M_ijk_surf = torch.zeros([max_m, max_m, max_m], dtype=torch.float64, device=device)

        total_num_faces = len(f_matrix)

        # Gather the three corner coordinates of every face once: [F, corner, xyz].
        ABC_all = v_matrix[f_matrix].to(device).double()

        exps = torch.arange(max_m, device=device)
        kernel = _contraction_kernel(max_m, device)

        # First Calculate the Geometric Moments. This is the FORWARD calculation!
        # Stepping by face_limit_loop also copes with an empty mesh (no iterations).
        for start in range(0, total_num_faces, face_limit_loop):
            end = min(start + face_limit_loop, total_num_faces)
            ABC = ABC_all[start:end]

            # Norm of the edge cross product.  `dim=1` is explicit because the
            # default picks the first size-3 axis, which is the face axis when
            # a chunk happens to hold exactly three faces.
            P = torch.norm(torch.cross(ABC[:, 0] - ABC[:, 2], ABC[:, 1] - ABC[:, 2], dim=1), p=2, dim=1)

            # Calculate C tensors
            C1_tensor = _monomial_tensor(ABC[:, 0], exps, f_ijk_s_p)
            C2_tensor = _monomial_tensor(ABC[:, 1], exps, f_ijk_s_p)
            C3_tensor = _monomial_tensor(ABC[:, 2], exps, f_ijk_s_p)

            assert torch.isfinite(C2_tensor).all(), 'The C2_tensor contains nans'

            D_tensor = conv_oe(C2_tensor, kernel, kernel, kernel, C3_tensor)

            assert torch.isfinite(D_tensor).all(), 'The D_tensor contains nans'
            S_tensor = conv_oe(C1_tensor, kernel, kernel, kernel, D_tensor)

            # Compute the contributions to the Surface moments
            M_ijk_surf = M_ijk_surf + (S_tensor * f_ijk_p_s_2 * P.view(-1, 1, 1, 1)).sum(0)

        ctx.save_for_backward(v_matrix, f_matrix, f_ijk_s_p, f_ijk_p_s_2)

        return M_ijk_surf

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to the vertex coordinates.

        Args:
            ctx: Autograd context filled by ``forward``.
            grad_output: Gradient w.r.t. the moment tensor, shape
                ``[max_m, max_m, max_m]``.

        Returns:
            ``(None, grads, None, None)`` where ``grads`` has the shape of
            ``v_matrix``; the other forward inputs receive no gradient.

        Raises:
            AssertionError: If the accumulated gradient contains NaN or Inf.
        """
        v_matrix, f_matrix, f_ijk_s_p, f_ijk_p_s_2 = ctx.saved_tensors

        device = ctx.device
        max_m = ctx.max_m

        # The contracted axis enumerates (face, corner) pairs, hence 3 * faces.
        collect_grads_oe = oe.contract_expression(
            'vf,fp->vp',
            (len(v_matrix), 3 * face_limit_loop), (3 * face_limit_loop, 3),
            optimize='optimal',
        )

        total_num_faces = len(f_matrix)
        # float64 accumulator: every per-chunk contribution is float64.
        grads = torch.zeros(v_matrix.shape, dtype=torch.float64, device=device)

        # Same gather and cast as in forward so both passes see identical inputs.
        ABC_all = v_matrix[f_matrix].to(device).double()

        # For computing the derivatives of the cross product, we need the unit vectors
        basis = torch.eye(3, dtype=torch.float64, device=device)

        exps = torch.arange(max_m, device=device)
        # Exponents of the differentiated monomials: max(i - 1, 0).
        exps_minus_one = (exps - 1).clamp(min=0)

        kernel = _contraction_kernel(max_m, device)

        # Vertex ids, compared against face indices to route face gradients to vertices.
        vertex_ids = torch.arange(v_matrix.shape[0], device=device)

        # Now Calculate the derivatives w.r.t. the vertices. This is the BACKWARD calculation!
        for start in range(0, total_num_faces, face_limit_loop):
            end = min(start + face_limit_loop, total_num_faces)
            num_faces = end - start

            ABC = ABC_all[start:end]

            # Compute the triangle vectors
            p13 = ABC[:, 0] - ABC[:, 2]
            p23 = ABC[:, 1] - ABC[:, 2]

            # Cross product and its norm (2 x area of the triangle).  `dim=1`
            # is explicit: the default would pick the face axis for a chunk
            # of exactly three faces.
            P = torch.cross(p13, p23, dim=1)
            norm_P = torch.norm(P, p=2, dim=1)

            # Derivatives of P = p13 x p23 w.r.t. each coordinate of each corner:
            #   corner 1: e x p23,   corner 2: -e x p13,   corner 3: minus their sum.
            dP_d1 = [torch.cross(e.broadcast_to(p23.shape), p23, dim=1) for e in basis]
            dP_d2 = [torch.cross(-e.broadcast_to(p13.shape), p13, dim=1) for e in basis]
            dP_d3 = [-d1 - d2 for d1, d2 in zip(dP_d1, dP_d2)]

            cross_grads = torch.stack(dP_d1 + dP_d2 + dP_d3)
            cross_grads = cross_grads.reshape([3, 3, num_faces, 3])
            cross_grads = cross_grads.permute(3, 2, 0, 1)
            # Indexing is [P, F, a, phi]

            # d|P| = (P . dP) / |P|
            norm_P_derivative = norm_P_derivative_oe(cross_grads, P, torch.reciprocal(norm_P))
            # Indexing: [F, a, p]

            # Compute C tensors, and eventually S_tensor
            C1_tensor = _monomial_tensor(ABC[:, 0], exps, f_ijk_s_p)
            C2_tensor = _monomial_tensor(ABC[:, 1], exps, f_ijk_s_p)
            C3_tensor = _monomial_tensor(ABC[:, 2], exps, f_ijk_s_p)

            D_tensor = conv_oe(C2_tensor, kernel, kernel, kernel, C3_tensor)
            S_tensor = conv_oe(C1_tensor, kernel, kernel, kernel, D_tensor) * f_ijk_p_s_2

            # Calculate gradients of C tensors: each is [p, F, i, j, k]
            C1_grads = _monomial_tensor_grads(ABC[:, 0], exps, exps_minus_one, f_ijk_s_p)
            C2_grads = _monomial_tensor_grads(ABC[:, 1], exps, exps_minus_one, f_ijk_s_p)
            C3_grads = _monomial_tensor_grads(ABC[:, 2], exps, exps_minus_one, f_ijk_s_p)

            # Compute the 9 derivative tensors for S_ijk (product rule over the
            # three convolved factors).  The first term reuses D_tensor, which
            # is the same C2 * C3 convolution the original recomputed here.
            S_derivatives = torch.stack([
                S_deriv_conv_1_oe(C1_grads, kernel, kernel, kernel, D_tensor),
                S_deriv_conv_3_oe(C1_tensor, kernel, kernel, kernel,
                                  S_deriv_conv_4_oe(C2_grads, kernel, kernel, kernel, C3_tensor)),
                S_deriv_conv_5_oe(C1_tensor, kernel, kernel, kernel,
                                  S_deriv_conv_6_oe(C2_tensor, kernel, kernel, kernel, C3_grads)),
            ], dim=0)
            S_derivatives = S_deriv_conv_7_oe(f_ijk_p_s_2, S_derivatives)
            # Indexing: [a, phi, F, i, j, k]

            # Compute the derivatives of the M_ijk w.r.t. vertices of each face
            M_ijk_F_derivative = (
                M_ijk_F_derivative_1_oe(norm_P, S_derivatives)
                + M_ijk_F_derivative_2_oe(norm_P_derivative, S_tensor)
            )

            # 'Collection kernel' [vertex, face*corner]: 1 where that corner is
            # that vertex.  Broadcasting replaces the explicit repeat.
            corner_ids = f_matrix[start:end, :].flatten(0, 1)
            collect_kernel = (vertex_ids.unsqueeze(1) == corner_ids).double().to(device)

            # Collect the gradients into the gradient vector
            grad_prod = grad_prod_oe(M_ijk_F_derivative.flatten(0, 1), grad_output)
            grads = grads + collect_grads_oe(collect_kernel, grad_prod)

        assert torch.isfinite(grads).all(), f'grads contains NaNs {torch.isnan(grads).any()} or Inf {torch.isinf(grads).any()}'
        return None, grads.to(v_matrix.device), None, None

    

class MomentLoss(nn.Module):
    """Half the sum of squared element-wise differences of two 4-D tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        """Return ``0.5 * sum((y_pred - y_true) ** 2)`` as a scalar tensor.

        Both inputs must be 4-dimensional, as required by the ``'ijkl'``
        einsum subscripts.
        """
        residual = y_pred - y_true
        squared_norm = torch.einsum('ijkl,ijkl->', residual, residual)
        return 0.5 * squared_norm


def gpu_report(device_index):
    """Print total, free and used memory of a CUDA device.

    Each line shows the device-wide figure first, followed by the matching
    figures of this process' PyTorch caching allocator.

    The device-wide numbers come from ``torch.cuda.mem_get_info`` rather than
    NVML: the ``pynvml`` import in this file is commented out, so the NVML
    calls previously used here raised ``NameError``.

    Args:
        device_index: Index of the CUDA device to report on.
    """
    free_memory, driver_total = torch.cuda.mem_get_info(device_index)
    total_memory = torch.cuda.get_device_properties(device_index).total_memory
    r = torch.cuda.memory_reserved(device_index)
    a = torch.cuda.memory_allocated(device_index)

    print(f'total    : {driver_total} {total_memory}')
    # r - a is memory held by the caching allocator but not currently in use.
    print(f'free     : {free_memory} {r - a}')
    print(f'used     : {driver_total - free_memory} {a} {r}')


if __name__ == "__main__":
    # Benchmark: time the forward and backward passes for several moment
    # orders and mesh sizes (the demo mesh tiled F times).
    moment_func = GeometricMoment.apply
    # test = gradcheck(moment_func, (f_matrix, v_matrix, max_m, device), eps=1e-6, atol=1e-4)
    # print(f'Gradients match: {test}')

    use_cuda = torch.cuda.is_available()

    # warm up
    torch.norm(moment_func(f_matrix, v_matrix, max_m, device)).backward()

    # Starting tests for time
    m_values = [5,8,10,12,15]
    F_values = [10,100,1000,10000]
    f_times = np.zeros((len(m_values),len(F_values)))
    b_times = np.zeros((len(m_values),len(F_values)))
    # `order` rather than `max_m` as loop variable, so the module-level
    # constant is not overwritten by the benchmark.
    for row, order in enumerate(m_values):
        for col, F in enumerate(F_values):
            mesh = f_matrix.repeat(F,1)
            if use_cuda:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            forward_start = time.perf_counter()
            moments = moment_func(mesh, v_matrix, order, device)
            loss = torch.norm(moments)
            # CUDA kernels run asynchronously: wait for them before reading
            # the clock, otherwise only the launch overhead is measured.
            if use_cuda:
                torch.cuda.synchronize()
            forward_end = time.perf_counter()
            if use_cuda:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            backward_start = time.perf_counter()
            loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            backward_end = time.perf_counter()
            f_times[row, col] = forward_end - forward_start
            b_times[row, col] = backward_end - backward_start

            # Printed after every measurement so partial results survive an
            # out-of-memory abort on the larger configurations.
            print('f_times = ',np.array2string(f_times, separator=', '))
            print('b_times = ',np.array2string(b_times, separator=', '))
