import torch
from math import factorial, ceil
import time
import numpy as np
import opt_einsum as oe

# Side length of the precomputed lookup tables below; also the moment order
# used by the self-check in the __main__ section.
max_m = 15

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Test mesh: the eight corners of the unit cube ...
v_matrix = torch.tensor(
    [
        [0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [1, 0, 0],
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1],
    ],
    dtype=torch.float64,
    device=device,
)

# ... and its twelve triangles (two per cube face), given as row indices
# into v_matrix.
f_matrix = torch.tensor(
    [
        [0, 2, 4], [4, 2, 6],
        [0, 1, 3], [0, 3, 2],
        [0, 5, 1], [0, 4, 5],
        [1, 7, 3], [1, 5, 7],
        [4, 7, 5], [4, 6, 7],
        [7, 6, 2], [7, 2, 3],
    ],
    dtype=torch.long,
    device=device,
)




# Global lookup tables used by both moment routines, indexed by the
# exponent triple (i, j, k) with 0 <= i, j, k < max_m:
#   f_ijk_s_p[i, j, k]   = (i+j+k)! / (i! * j! * k!)
#   f_ijk_p_s_2[i, j, k] = 1 / f_ijk_s_p[i, j, k] / (i+j+k+1) / (i+j+k+2)
f_ijk_s_p = np.zeros((max_m, max_m, max_m), dtype=np.float64)
f_ijk_p_s_2 = np.zeros((max_m, max_m, max_m), dtype=np.float64)

# np.ndindex walks the index triples in the same row-major order as three
# nested range() loops.
for i, j, k in np.ndindex(max_m, max_m, max_m):
    f_ijk_s_p[i, j, k] = factorial(i + j + k) / (factorial(i) * factorial(j) * factorial(k))
    f_ijk_p_s_2[i, j, k] = 1 / f_ijk_s_p[i, j, k] / (i + j + k + 1) / (i + j + k + 2)

# Move the finished tables to the compute device as float64 tensors.
f_ijk_s_p = torch.tensor(f_ijk_s_p, dtype=torch.float64, device=device)
f_ijk_p_s_2 = torch.tensor(f_ijk_p_s_2, dtype=torch.float64, device=device)

# Upper bound on the number of faces processed per batch by the moment
# routines; also the face extent used when planning the contractions below.
face_limit_loop = 1000

# Element-wise product of three per-face (e, i, j, k) tensors with a shared
# (i, j, k) coefficient table.
C_construct_oe = oe.contract_expression(
    'eijk,eijk,eijk,ijk->eijk',
    *([(face_limit_loop, max_m, max_m, max_m)] * 3),
    (max_m,) * 3,
    optimize='optimal',
)

# Contraction of two per-face (e, ., ., .) tensors through three (a, i, l)
# index kernels, producing a per-face (e, a, b, c) tensor.
conv_oe = oe.contract_expression(
    'eijk, ail, bjm, ckn, elmn -> eabc',
    (face_limit_loop, max_m, max_m, max_m),
    *([(max_m,) * 3] * 3),
    (face_limit_loop, max_m, max_m, max_m),
    optimize='optimal',
)





def M_ijk_Koehl(f_matrix, v_matrix, m, device):
    """Compute surface moments of a triangle mesh with a per-face recurrence.

    For every face with vertices A, B, C (rows 0, 1, 2 of ``v_matrix[face]``)
    three tensors indexed by the exponent triple (i, j, k) are built:

    * ``C[i,j,k] = f_ijk_s_p[i,j,k] * C.x**i * C.y**j * C.z**k``
    * ``D[i,j,k] = B.x*D[i-1,j,k] + B.y*D[i,j-1,k] + B.z*D[i,j,k-1] + C[i,j,k]``
    * ``S[i,j,k] = A.x*S[i-1,j,k] + A.y*S[i,j-1,k] + A.z*S[i,j,k-1] + D[i,j,k]``

    where any term with a negative index is zero.  The returned moment is
    ``sum over faces of f_ijk_p_s_2[i,j,k] * |(A-C) x (B-C)| * S[i,j,k]``.
    Faces are processed in batches of at most ``face_limit_loop``.

    Args:
        f_matrix: Integer tensor of vertex indices, one row of three per face.
        v_matrix: Tensor of vertex coordinates, one row of three per vertex.
        m: Number of exponents per axis (exponents run over ``0 .. m-1``).
            Must not exceed the size of the precomputed lookup tables.
        device: Torch device for all intermediate tensors.  The module-level
            lookup tables are used as-is, so this presumably has to be the
            device they were created on.

    Returns:
        Float64 tensor of shape ``(m, m, m)`` with ``requires_grad=True``.
        An empty mesh yields all zeros.

    Raises:
        ValueError: If ``m`` is larger than the precomputed lookup tables.
    """
    table_size = f_ijk_s_p.shape[0]
    if m > table_size:
        raise ValueError(
            f"m={m} exceeds the precomputed table size {table_size}"
        )

    total_num_faces = len(f_matrix)
    # max(1, ...) keeps the batch size non-zero, so an empty mesh simply
    # produces zero batches instead of a ZeroDivisionError.
    loop_batch_size = max(1, min(face_limit_loop, total_num_faces))

    # One (1, m, m) accumulator per leading exponent i.
    m_s = [
        torch.zeros([1, m, m], dtype=torch.float64, requires_grad=True, device=device)
        for _ in range(m)
    ]

    # (faces, 3 vertices, 3 coordinates)
    ABC_all = v_matrix[f_matrix].to(device)
    exps = torch.arange(m, device=device)

    for batch_num in range(ceil(total_num_faces / loop_batch_size)):
        start = batch_num * loop_batch_size
        end = min(start + loop_batch_size, total_num_faces)
        num_faces = end - start

        ABC = ABC_all[start:end]
        A, B, C = ABC[:, 0], ABC[:, 1], ABC[:, 2]

        # dim=1 is required: without it torch.cross picks the first axis of
        # size 3, which is the face axis whenever a batch has exactly 3 faces.
        cross = torch.norm(torch.cross(A - C, B - C, dim=1), p=2, dim=1)

        M_S_tensor = torch.zeros([num_faces, m, m, m], dtype=torch.float64, device=device)
        D_tensor = torch.zeros([num_faces, m, m, m], dtype=torch.float64, device=device)
        S_tensor = torch.zeros([num_faces, m, m, m], dtype=torch.float64, device=device)
        # Stand-in for recurrence terms whose index would be negative.
        zero = torch.zeros(num_faces, dtype=torch.float64, device=device)

        for i in range(m):
            # C tensor for this value of i only: shape (num_faces, 1, m, m).
            C_tensor = C_construct_oe(
                torch.pow(C[:, 0].view(-1, 1, 1, 1), exps[i].view(1, -1, 1, 1)),
                torch.pow(C[:, 1].view(-1, 1, 1, 1), exps.view(1, 1, -1, 1)),
                torch.pow(C[:, 2].view(-1, 1, 1, 1), exps.view(1, 1, 1, -1)),
                f_ijk_s_p[i:i + 1, :m, :m],
            ).double()

            for j in range(m):
                for k in range(m):
                    # Explicit boundary handling instead of relying on
                    # negative indices wrapping onto not-yet-written entries
                    # (which breaks for m == 1).
                    d_i = D_tensor[:, i - 1, j, k] if i > 0 else zero
                    d_j = D_tensor[:, i, j - 1, k] if j > 0 else zero
                    d_k = D_tensor[:, i, j, k - 1] if k > 0 else zero
                    D_tensor[:, i, j, k] = (
                        B[:, 0] * d_i + B[:, 1] * d_j + B[:, 2] * d_k + C_tensor[:, 0, j, k]
                    )

                    s_i = S_tensor[:, i - 1, j, k] if i > 0 else zero
                    s_j = S_tensor[:, i, j - 1, k] if j > 0 else zero
                    s_k = S_tensor[:, i, j, k - 1] if k > 0 else zero
                    S_tensor[:, i, j, k] = (
                        A[:, 0] * s_i + A[:, 1] * s_j + A[:, 2] * s_k + D_tensor[:, i, j, k]
                    )

                    M_S_tensor[:, i, j, k] = f_ijk_p_s_2[i, j, k] * (cross * S_tensor[:, i, j, k])

            # Sum this batch's contribution over its faces.
            m_s[i] = m_s[i] + torch.sum(M_S_tensor[:, i, :, :], 0)

    # (m, 1, m, m) -> (m, m, m).  Only the singleton axis is squeezed so the
    # result keeps three dimensions even when m == 1.
    return torch.stack(m_s).squeeze(1)



def _vertex_monomial_tensor(vertex, exps, coeff):
    """Return ``coeff[i,j,k] * x**i * y**j * z**k`` for every face.

    ``vertex`` holds one row per face with the x, y, z coordinates in columns
    0, 1, 2; ``exps`` is the 1-D tensor of exponents; ``coeff`` is the
    ``(len(exps),) * 3`` coefficient table.  The result has one leading face
    axis followed by the three exponent axes.
    """
    return C_construct_oe(
        torch.pow(vertex[:, 0].view(-1, 1, 1, 1), exps.view(1, -1, 1, 1)),
        torch.pow(vertex[:, 1].view(-1, 1, 1, 1), exps.view(1, 1, -1, 1)),
        torch.pow(vertex[:, 2].view(-1, 1, 1, 1), exps.view(1, 1, 1, -1)),
        coeff,
    ).double()


def M_ijk_Pozo(f_matrix, v_matrix, m, device):
    """Compute surface moments of a triangle mesh with tensor contractions.

    For every face with vertices A, B, C (rows 0, 1, 2 of ``v_matrix[face]``)
    the tensors ``C1``, ``C2``, ``C3`` with entries
    ``f_ijk_s_p[i,j,k] * x**i * y**j * z**k`` are built for A, B and C.  They
    are combined with ``conv_oe`` and a 0/1 kernel that is 1 exactly where
    two indices sum to the output index:

        D = conv(C2, C3),   S = conv(C1, D)

    The returned moment is
    ``sum over faces of S[i,j,k] * f_ijk_p_s_2[i,j,k] * |(A-C) x (B-C)|``.
    Faces are processed in batches of at most ``face_limit_loop``.

    Args:
        f_matrix: Integer tensor of vertex indices, one row of three per face.
        v_matrix: Tensor of vertex coordinates, one row of three per vertex.
        m: Number of exponents per axis (exponents run over ``0 .. m-1``).
            Must not exceed the size of the precomputed lookup tables.
        device: Torch device for all intermediate tensors.  The module-level
            lookup tables are used as-is, so this presumably has to be the
            device they were created on.

    Returns:
        Float64 tensor of shape ``(m, m, m)`` that requires grad.  An empty
        mesh yields all zeros.

    Raises:
        ValueError: If ``m`` is larger than the precomputed lookup tables.
    """
    table_size = f_ijk_s_p.shape[0]
    if m > table_size:
        # Slicing the tables with [:m] would silently truncate and only fail
        # later with a shape mismatch inside the contraction.
        raise ValueError(
            f"m={m} exceeds the precomputed table size {table_size}"
        )

    total_num_faces = len(f_matrix)
    # max(1, ...) keeps the batch size non-zero, so an empty mesh simply
    # produces zero batches instead of a ZeroDivisionError.
    loop_batch_size = max(1, min(face_limit_loop, total_num_faces))

    m_s = torch.zeros([m, m, m], dtype=torch.float64, requires_grad=True, device=device)

    # (faces, 3 vertices, 3 coordinates)
    ABC_all = v_matrix[f_matrix].to(device)

    exps = torch.arange(m, device=device)

    # Contraction kernel: kernel[a, i, l] == 1 where i + l == a, else 0.
    kernel = (
        (exps.unsqueeze(0) + exps.unsqueeze(1)).unsqueeze(0)
        == exps.unsqueeze(-1).unsqueeze(-1)
    ).double()

    coeff = f_ijk_s_p[:m, :m, :m]
    weight = f_ijk_p_s_2[:m, :m, :m]

    for batch_num in range(ceil(total_num_faces / loop_batch_size)):
        start = batch_num * loop_batch_size
        end = min(start + loop_batch_size, total_num_faces)

        ABC = ABC_all[start:end]
        A, B, C = ABC[:, 0], ABC[:, 1], ABC[:, 2]

        # dim=1 is required: without it torch.cross picks the first axis of
        # size 3, which is the face axis whenever a batch has exactly 3 faces.
        cross = torch.norm(torch.cross(A - C, B - C, dim=1), p=2, dim=1)

        C1_tensor = _vertex_monomial_tensor(A, exps, coeff)
        C2_tensor = _vertex_monomial_tensor(B, exps, coeff)
        C3_tensor = _vertex_monomial_tensor(C, exps, coeff)

        D_tensor = conv_oe(C2_tensor, kernel, kernel, kernel, C3_tensor).double()
        S_tensor = conv_oe(C1_tensor, kernel, kernel, kernel, D_tensor).double()

        # Weight each face by its cross-product norm and sum over the batch.
        m_s = m_s + (S_tensor * weight * cross.view(-1, 1, 1, 1)).sum(0)

    return m_s




# Check that the two implementations agree on the unit cube, then time both
# of them for various moment orders (m_values) and mesh sizes (F_values).
if __name__ == "__main__":

    def _timed_call(moment_fn, mesh, order):
        """Return the wall-clock seconds of one ``moment_fn`` call on ``mesh``."""
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        start = time.perf_counter()
        moment_fn(mesh, v_matrix, order, device)
        if use_cuda:
            # CUDA kernels are launched asynchronously; wait for them to
            # finish so the measurement covers the actual computation.
            torch.cuda.synchronize()
        return time.perf_counter() - start

    # This check is also the first invocation of both routines, so it runs
    # before any of the timed calls below.
    print('Surface moments match:',torch.allclose(M_ijk_Koehl(f_matrix,v_matrix,max_m,device).double(),M_ijk_Pozo(f_matrix,v_matrix,max_m,device)))

    # Benchmark grid: moment orders x number of repetitions of the cube mesh.
    m_values = [5,8,10,12,15]
    F_values = [10,100,1000,10000,100000]

    Pozo_times = np.zeros((len(m_values),len(F_values)))
    Koehl_times = np.zeros((len(m_values),len(F_values)))

    # The loop variable is deliberately not called max_m: this runs at module
    # level, so that name would rebind the global constant.
    for i, m_order in enumerate(m_values):
        for j, F in enumerate(F_values):
            # Build a larger mesh by repeating the 12 cube faces F times.
            mesh = f_matrix.repeat(F,1)

            Pozo_times[i,j] = _timed_call(M_ijk_Pozo, mesh, m_order)
            print("Pozo_times:")
            print(np.array2string(Pozo_times, separator=', '))

            Koehl_times[i,j] = _timed_call(M_ijk_Koehl, mesh, m_order)
            print("Koehl_times:")
            print(np.array2string(Koehl_times, separator=', '))

    # Print out the results
    print("Pozo_times = ",np.array2string(Pozo_times, separator=', '))
    print("Koehl_times = ", np.array2string(Koehl_times, separator=', '))




