import jax
import jax.numpy as jnp


from jax import nn

from jax import vmap
from jax import grad


#jax.config.update("jax_enable_x64", True)

def hat_function(x, x_nodes, h=2):
    """Evaluate the piecewise-linear hat function centred at ``x_nodes``.

    The value is 1 at ``x == x_nodes`` and falls off linearly to 0 at
    ``|x - x_nodes| == |h|``. It is clipped to 0 outside that support.

    Args:
        x: evaluation point(s).
        x_nodes: hat centre(s); broadcast against ``x``.
        h: half-width of the support (default 2).

    Returns:
        ``max(0, 1 - |(x - x_nodes) / h|)`` with broadcasting.
    """
    scaled_offset = (x - x_nodes) / h
    return jnp.maximum(1 - jnp.abs(scaled_offset), 0)


def hat_deriv(x, x_nodes, h=2):
    """Derivative of ``hat_function`` with respect to its first argument ``x``.

    Built with ``jax.grad``, so ``x`` must be a float and the hat evaluation
    must be scalar (i.e. scalar ``x`` and ``x_nodes``). ``x_nodes`` and ``h``
    are held fixed.
    """
    d_dx = grad(hat_function, argnums=0)
    return d_dx(x, x_nodes, h)



def hat_vec(x, x_nodes, h=2):
    """Tabulate ``hat_function`` for every (point, node) pair.

    The inner ``vmap`` runs over ``x_nodes`` and the outer one over ``x``.
    The result therefore has shape ``(len(x), len(x_nodes))``, with entry
    ``[p, n] = hat_function(x[p], x_nodes[n], h)``.
    """
    over_nodes = vmap(hat_function, in_axes=(None, 0, None))
    over_points = vmap(over_nodes, in_axes=(0, None, None))
    return over_points(x, x_nodes, h)


def hat_deriv_vec(x, x_nodes, h=2):
    """Tabulate ``hat_deriv`` for every (point, node) pair.

    Same layout as ``hat_vec``: the result has shape ``(len(x), len(x_nodes))``,
    with entry ``[p, n] = hat_deriv(x[p], x_nodes[n], h)``.
    """
    per_point = vmap(hat_deriv, in_axes=(None, 0, None))
    table = vmap(per_point, in_axes=(0, None, None))
    return table(x, x_nodes, h)


def whitney_0form_vec_1D(W, x, x_nodes, h=2):
    """Evaluate 1D Whitney 0-forms as weighted sums of hat functions.

    Args:
        W: weight matrix whose row ``k`` holds the nodal coefficients of form
           ``k``. Presumably shape ``(num_forms, len(x_nodes))``; W is used
           as given, with no normalisation.
        x: 1D array of evaluation points.
        x_nodes: 1D array of hat centres.
        h: hat half-width.

    Returns:
        ``W @ hat_vec(x, x_nodes, h).T``: one row per form, one column per point.
    """
    basis = hat_vec(x, x_nodes, h)  # (len(x), len(x_nodes))
    return jnp.matmul(W, jnp.transpose(basis))


def whitney_0form_vec_1D_grad(W, x, x_nodes, h=2):
    """Spatial derivative of ``whitney_0form_vec_1D``.

    Computes ``W @ hat_deriv_vec(x, x_nodes, h).T``: one row per form, one
    column per evaluation point.
    """
    slopes = hat_deriv_vec(x, x_nodes, h)  # (len(x), len(x_nodes))
    return jnp.matmul(W, jnp.transpose(slopes))


def whitney_1form_vec_1D(W, x, x_0, h=2):
    """Evaluate the 1D Whitney 1-forms ``w_i dw_j - w_j dw_i`` for all ``i < j``.

    Args:
        W: 0-form weights with one row per form (``N = W.shape[0]``).
        x: evaluation points.
        x_0: hat centres.
        h: hat half-width.

    Returns:
        Array with one row per pair ``(i, j)``, ``i < j``, in
        ``jnp.triu_indices(N, 1)`` order, and one column per point.
    """
    values = whitney_0form_vec_1D(W, x, x_0, h)
    slopes = whitney_0form_vec_1D_grad(W, x, x_0, h)
    # outer[i, j] = w_i * dw_j
    outer = jnp.expand_dims(values, 1) * jnp.expand_dims(slopes, 0)
    antisym = outer - jnp.swapaxes(outer, 0, 1)
    upper_i, upper_j = jnp.triu_indices(W.shape[0], 1)
    return antisym[upper_i, upper_j]




def construct_oriented_1D(W):
    """Squared 1D oriented areas for every pair of nodal fields ``i < j``.

    For consecutive nodes ``k, k+1`` the term is
    ``W_i[k] W_j[k+1] - W_j[k] W_i[k+1]``, summed over ``k``. For linear hats
    this equals the element integral of ``w_i w_j' - w_j w_i'``, which does
    not depend on the grid spacing, so no ``h`` is needed. The zero column
    padded at each end only contributes terms that are identically zero.

    Args:
        W: nodal coefficients with one row per field (``N = W.shape[0]``).

    Returns:
        1D array of length ``N(N-1)/2`` (``jnp.triu_indices(N, 1)`` order)
        containing the squared oriented areas.
    """
    padded = jnp.pad(W, ((0, 0), (1, 1)))
    left = padded[:, :-1]
    right = padded[:, 1:]
    products = left[:, None, :] * right[None, :, :]
    antisym = products - jnp.swapaxes(products, 0, 1)
    areas = antisym.sum(axis=-1)
    rows, cols = jnp.triu_indices(padded.shape[0], 1)
    return areas[rows, cols] ** 2
   



def vec_helper_M0_1D(A, B, h):
    """Pairwise P1 mass-matrix inner products between rows of ``A`` and ``B``.

    For nodal vectors ``a = A[p]`` and ``b = B[q]`` on a uniform 1D grid with
    spacing ``h``, this returns ``sum_e a_e^T M_e b_e``. Here ``a_e`` and
    ``b_e`` are the two nodal values of element ``e``, and ``M_e`` is
    ``helper_tensor_mass(h)``.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0])``.
    """
    element_mass = helper_tensor_mass(h)

    def pair_mass(a, b):
        def element_contribution(e):
            a_loc = jax.lax.dynamic_slice(a, (e,), (2,))
            b_loc = jax.lax.dynamic_slice(b, (e,), (2,))
            return jnp.einsum("i,j,ij->", a_loc, b_loc, element_mass)

        elements = jnp.arange(0, a.shape[0] - 1)
        return jnp.sum(vmap(element_contribution)(elements))

    rows_vs_all_b = vmap(pair_mass, in_axes=(None, 0))
    return vmap(rows_vs_all_b, in_axes=(0, None))(A, B)
                        



def construct_M1_1D(W, h):
    """Assemble the 1D Whitney 1-form mass matrix.

    ``vec_helper_M1_1D`` yields partial integrals ``P[i, j, k, l]`` of
    ``w_i w_j' w_k w_l'``. The form ``w_i dw_j - w_j dw_i`` is antisymmetric
    in ``(i, j)``, so the entry for forms ``(i, j)`` and ``(k, l)`` is
    ``P[i,j,k,l] - P[j,i,k,l] - P[i,j,l,k] + P[j,i,l,k]``. Only pairs with
    ``i < j`` and ``k < l`` are kept, in ``jnp.triu_indices`` order.

    Returns:
        Square array of side ``N(N-1)/2`` with ``N = W.shape[0]``.
    """
    P = vec_helper_M1_1D(W, W, W, W, h)
    full = (P
            - jnp.swapaxes(P, 0, 1)
            - jnp.swapaxes(P, 2, 3)
            + jnp.swapaxes(jnp.swapaxes(P, 0, 1), 2, 3))
    form_rows = full[jnp.triu_indices(W.shape[0], 1)]
    return vmap(lambda block: block[jnp.triu_indices(block.shape[0], 1)])(form_rows)



def vec_helper_M1_1D(A, B, C, D, h):
    """Quadruple-wise element sums of ``a b' c d'`` on a uniform 1D grid.

    For rows ``a=A[p]``, ``b=B[q]``, ``c=C[r]`` and ``d=D[s]`` (nodal values,
    spacing ``h``), each element's four 2-vectors are contracted with
    ``helper_tensor_C(h)``. That tensor holds the element integrals of
    ``phi_i phi_j' phi_k phi_l'`` for linear hats. The contributions are then
    summed over elements.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0], C.shape[0], D.shape[0])``.
    """
    element_tensor = helper_tensor_C(h)

    def quad_integral(a, b, c, d):
        def element_contribution(e):
            local = [jax.lax.dynamic_slice(v, (e,), (2,)) for v in (a, b, c, d)]
            return jnp.einsum("i,j,k,l,ijkl->", *local, element_tensor)

        elements = jnp.arange(0, a.shape[0] - 1)
        return jnp.sum(vmap(element_contribution)(elements))

    # Map over D innermost, then C, B and finally A outermost.
    mapped = quad_integral
    for axis_pos in (3, 2, 1, 0):
        in_axes = tuple(0 if pos == axis_pos else None for pos in range(4))
        mapped = vmap(mapped, in_axes=in_axes)
    return mapped(A, B, C, D)



def construct_delta0(Npou=5):
    """Build the discrete exterior derivative (incidence matrix) from 0- to 1-forms.

    There is one row per pair ``(i, j)`` with ``i < j``, ordered like
    ``jnp.triu_indices(Npou, 1)``, i.e. lexicographically. Each row has
    ``-1`` in column ``i`` and ``+1`` in column ``j``.

    The matrix is filled with two vectorised scatters instead of one
    ``.at[].set`` per entry. The per-entry loop copied the whole array
    ``Npou*(Npou-1)`` times, and unrolled into that many ops under ``jit``.

    Args:
        Npou: number of 0-forms (a Python int).

    Returns:
        Array of shape ``(Npou*(Npou-1)//2, Npou)``.
    """
    tails, heads = jnp.triu_indices(Npou, 1)
    num_1forms = tails.shape[0]  # == Npou * (Npou - 1) // 2
    edges = jnp.arange(num_1forms)
    delta0 = jnp.zeros((num_1forms, Npou))
    delta0 = delta0.at[edges, tails].set(-1.0)
    delta0 = delta0.at[edges, heads].set(1.0)
    return delta0







def boundary_indices_fn(
        Npou,
        actual_Npou):
    """Boolean mask over ``actual_Npou`` entries that flags the trailing ones.

    The first ``Npou`` entries are ``False``; the remaining
    ``actual_Npou - Npou`` entries are ``True``.

    Args:
        Npou: number of leading (unflagged) entries.
        actual_Npou: total length of the mask.

    Returns:
        1D boolean array of length ``actual_Npou``.
    """
    leading = jnp.zeros(Npou, dtype=jnp.bool_)
    trailing = jnp.ones(actual_Npou - Npou, dtype=jnp.bool_)
    return jnp.concatenate([leading, trailing])





def lagrange_functions_2D_vec(points, x_nodes, y_nodes, h1, h2):
    """Tensor-product (bilinear) hat functions evaluated at 2D points.

    Args:
        points: array whose columns 0 and 1 are the x and y coordinates.
        x_nodes, y_nodes: 1D hat centres along each axis.
        h1, h2: hat half-widths along x and y.

    Returns:
        Array with entry ``[p, m, n] = hat(x_p; x_nodes[m], h1) *
        hat(y_p; y_nodes[n], h2)``.
    """
    x_hats = hat_vec(points[:, 0], x_nodes, h1)  # (P, nx)
    y_hats = hat_vec(points[:, 1], y_nodes, h2)  # (P, ny)
    return x_hats[..., None] * y_hats[:, None]



def grad_lagrange_functions_2D_vec(points, x_nodes, y_nodes, h1, h2):
    """Gradients of the bilinear hat functions at 2D points.

    Returns:
        Array with entry ``[p, m, n, :]``: the (d/dx, d/dy) gradient, taken
        with ``jax.grad`` w.r.t. the point, of
        ``hat(x; x_nodes[m], h1) * hat(y; y_nodes[n], h2)`` at ``points[p]``.
    """
    def bilinear_hat(point, x_node, y_node):
        return hat_function(point[0], x_node, h1) * hat_function(point[1], y_node, h2)

    grad_at = grad(bilinear_hat)  # differentiates w.r.t. ``point``
    per_y = vmap(grad_at, in_axes=(None, None, 0))
    per_xy = vmap(per_y, in_axes=(None, 0, None))
    per_point = vmap(per_xy, in_axes=(0, None, None))
    return per_point(points, x_nodes, y_nodes)

  

def whitney_0form_vec_2D(W, points, x_nodes, y_nodes, h1, h2):
    """Evaluate 2D Whitney 0-forms at a set of points.

    Args:
        W: nodal coefficients, presumably shape ``(N, nx, ny)``.
        points, x_nodes, y_nodes, h1, h2: see ``lagrange_functions_2D_vec``.

    Returns:
        ``(N, P)`` array, ``sum_{m,n} W[k, m, n] * phi_mn(points[p])``.
    """
    basis = lagrange_functions_2D_vec(points, x_nodes, y_nodes, h1, h2)
    weighted = jnp.expand_dims(W, 1) * jnp.expand_dims(basis, 0)
    return weighted.sum(axis=(-2, -1))



def whitney_0form_vec_2D_grad(W, points, x_nodes, y_nodes, h1, h2):
    """Spatial gradients of the 2D Whitney 0-forms at a set of points.

    Args:
        W: nodal coefficients, presumably shape ``(N, nx, ny)``.
        points, x_nodes, y_nodes, h1, h2: see ``grad_lagrange_functions_2D_vec``.

    Returns:
        ``(N, P, 2)`` array, ``sum_{m,n} W[k, m, n] * grad phi_mn(points[p])``.
    """
    basis_grad = grad_lagrange_functions_2D_vec(points, x_nodes, y_nodes, h1, h2)
    weighted = jnp.expand_dims(W, (1, 4)) * jnp.expand_dims(basis_grad, 0)
    return weighted.sum(axis=(-3, -2))



def whitney_1form_vec_2D(W, points, x_nodes, y_nodes, h1, h2):
    """Evaluate 2D Whitney 1-forms ``w_i grad(w_j) - w_j grad(w_i)`` for ``i < j``.

    Returns:
        Array with one entry per pair ``i < j`` (``jnp.triu_indices`` order),
        then per point, then the two vector components.
    """
    values = whitney_0form_vec_2D(W, points, x_nodes, y_nodes, h1, h2)
    gradients = whitney_0form_vec_2D_grad(W, points, x_nodes, y_nodes, h1, h2)
    # outer[i, j, p, :] = w_i(p) * grad w_j(p)
    outer = jnp.expand_dims(values, (1, 3)) * jnp.expand_dims(gradients, 0)
    antisym = outer - jnp.swapaxes(outer, 0, 1)
    return antisym[jnp.triu_indices(W.shape[0], 1)]


def whitney_1form_vec_2D_midpoints(W, h1, h2):
    """Approximate 2D Whitney 1-forms at the cell centres of a uniform grid.

    Each field's value at a cell centre is the mean of its four corner nodes.
    Its gradient there is the forward difference along each axis, averaged
    over the two parallel cell edges. The result is
    ``w_i * grad(w_j) - w_j * grad(w_i)`` for every pair ``i < j``.

    Args:
        W: nodal coefficients, presumably shape ``(N, nx, ny)``.
        h1, h2: grid spacing along the first and second spatial axis.

    Returns:
        Array of shape ``(N(N-1)/2, nx-1, ny-1, 2)``, pairs in
        ``jnp.triu_indices(N, 1)`` order.
    """
    corner_sum = W[:, :-1, :-1] + W[:, 1:, :-1] + W[:, :-1, 1:] + W[:, 1:, 1:]
    centre = corner_sum / 4.

    dx_edges = (W[:, 1:] - W[:, :-1]) / h1        # (N, nx-1, ny)
    dy_edges = (W[:, :, 1:] - W[:, :, :-1]) / h2  # (N, nx, ny-1)
    dx_centre = (dx_edges[:, :, 1:] + dx_edges[:, :, :-1]) / 2.
    dy_centre = (dy_edges[:, 1:] + dy_edges[:, :-1]) / 2.
    gradient = jnp.stack([dx_centre, dy_centre], axis=-1)

    products = centre[:, None, :, :, None] * gradient[None]
    antisym = products - jnp.swapaxes(products, 0, 1)
    pair_i, pair_j = jnp.triu_indices(W.shape[0], 1)
    return antisym[pair_i, pair_j]


        
    


def construct_M0_2D(A, B, h1, h2):
    """Pairwise bilinear (Q1) mass-matrix inner products between 2D fields.

    For nodal grids ``a = A[p]`` and ``b = B[q]`` (first axis x, second axis y),
    each 2x2 cell contributes
    ``sum a[i,a] b[j,b] Mx[i,j] My[a,b]``, where ``Mx`` and ``My`` are
    ``helper_tensor_mass(h1)`` and ``helper_tensor_mass(h2)``. The
    contributions are summed over all cells.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0])``.
    """
    mass_x = helper_tensor_mass(h1)
    mass_y = helper_tensor_mass(h2)

    def pair_mass(a, b):
        def cell_contribution(cx, cy):
            a_cell = jax.lax.dynamic_slice(a, (cx, cy), (2, 2))
            b_cell = jax.lax.dynamic_slice(b, (cx, cy), (2, 2))
            return jnp.einsum("ia,jb,ij,ab->", a_cell, b_cell, mass_x, mass_y)

        cells_x = jnp.arange(0, a.shape[0] - 1)
        cells_y = jnp.arange(0, a.shape[1] - 1)
        over_y = vmap(cell_contribution, in_axes=(None, 0))
        return jnp.sum(vmap(over_y, in_axes=(0, None))(cells_x, cells_y))

    over_b = vmap(pair_mass, in_axes=(None, 0))
    return vmap(over_b, in_axes=(0, None))(A, B)


def vec_helper_oriented_2D(A, B, h1, h2):
    """Magnitude of the pairwise oriented-area vector between 2D fields.

    For nodal grids ``a = A[p]`` and ``b = B[q]``, each 2x2 cell gives two
    components. The x component contracts the cell with
    ``helper_tensor_oriented()`` along x and ``helper_tensor_mass(h2)`` along
    y; the y component swaps the roles. Both are summed over cells, and the
    Euclidean norm of the resulting 2-vector is returned.

    The norm uses a "safe" square root. Where the squared norm is exactly 0
    (e.g. fields with disjoint support), ``jnp.sqrt`` has an infinite
    derivative, and reverse-mode AD would give ``0 * inf = NaN``. That NaN
    appears even for entries the caller discards (such as the diagonal
    dropped by ``construct_oriented_2D``) and poisons the whole gradient.
    Forward values are unchanged; the gradient at those entries becomes 0.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0])``.
    """
    mass_x = helper_tensor_mass(h1)
    mass_y = helper_tensor_mass(h2)
    oriented_x = helper_tensor_oriented()
    oriented_y = helper_tensor_oriented()

    def pair_components(a, b):
        def cell_components(cx, cy):
            a_cell = jax.lax.dynamic_slice(a, (cx, cy), (2, 2))
            b_cell = jax.lax.dynamic_slice(b, (cx, cy), (2, 2))
            comp_x = jnp.einsum("ia,jb,ij,ab->", a_cell, b_cell, oriented_x, mass_y)
            comp_y = jnp.einsum("ia,jb,ij,ab->", a_cell, b_cell, mass_x, oriented_y)
            return jnp.array([comp_x, comp_y])

        cells_x = jnp.arange(0, a.shape[0] - 1)
        cells_y = jnp.arange(0, a.shape[1] - 1)
        per_cell = vmap(vmap(cell_components, in_axes=(None, 0)),
                        in_axes=(0, None))(cells_x, cells_y)
        return jnp.sum(per_cell, axis=(0, 1))

    components = vmap(vmap(pair_components, in_axes=(None, 0)),
                      in_axes=(0, None))(A, B)
    sq_norm = (components ** 2).sum(axis=-1)
    # Double-where: feed sqrt a harmless 1 where sq_norm is exactly 0, so its
    # derivative stays finite. Then select 0 there. NaN inputs still
    # propagate, because NaN == 0 is False.
    is_zero = sq_norm == 0
    safe_sq = jnp.where(is_zero, 1.0, sq_norm)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))



def construct_oriented_2D(W, h1, h2):
    """Oriented-area magnitudes for every pair ``i < j`` of 2D fields in ``W``.

    Returns:
        1D array of length ``N(N-1)/2`` with ``N = W.shape[0]``, ordered like
        ``jnp.triu_indices(N, 1)``.
    """
    all_pairs = vec_helper_oriented_2D(W, W, h1, h2)
    upper_rows, upper_cols = jnp.triu_indices(W.shape[0], 1)
    return all_pairs[upper_rows, upper_cols]


def construct_M1_2D(W, h1, h2):
    """Assemble the 2D Whitney 1-form mass matrix.

    Partial integrals ``P[i, j, k, l]`` from ``vec_helper_M1_2D`` are
    antisymmetrised in ``(i, j)`` and in ``(k, l)``:
    ``P - P[j,i,k,l] - P[i,j,l,k] + P[j,i,l,k]``. Only pairs ``i < j`` and
    ``k < l`` are kept, in ``jnp.triu_indices`` order.

    Returns:
        Square array of side ``N(N-1)/2`` with ``N = W.shape[0]``.
    """
    P = vec_helper_M1_2D(W, W, W, W, h1, h2)
    swap_ij = jnp.swapaxes(P, 0, 1)
    swap_kl = jnp.swapaxes(P, 2, 3)
    swap_both = jnp.swapaxes(swap_ij, 2, 3)
    full = P - swap_ij - swap_kl + swap_both
    form_rows = full[jnp.triu_indices(W.shape[0], 1)]
    return vmap(lambda block: block[jnp.triu_indices(block.shape[0], 1)])(form_rows)



def vec_helper_M1_2D(A, B, C, D, h1, h2):
    """Quadruple-wise cell integrals for the 2D Whitney 1-form mass matrix.

    For nodal grids ``a, b, c, d`` (rows of ``A, B, C, D``), each 2x2 cell
    gives two terms:

    * the derivative tensor ``helper_tensor_C`` along x combined with the
      value tensor ``helper_tensor_A`` along y;
    * the value tensor along x combined with the derivative tensor along y.

    All terms of all cells are summed into one scalar.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0], C.shape[0], D.shape[0])``.
    """
    val_x = helper_tensor_A(h1)
    val_y = helper_tensor_A(h2)
    der_x = helper_tensor_C(h1)
    der_y = helper_tensor_C(h2)

    def quad_integral(a, b, c, d):
        def cell_terms(cx, cy):
            cells = [jax.lax.dynamic_slice(v, (cx, cy), (2, 2)) for v in (a, b, c, d)]
            from_dx = jnp.einsum("ia,jb,kc,ld,ijkl,abcd->", *cells, der_x, val_y)
            from_dy = jnp.einsum("ia,jb,kc,ld,ijkl,abcd->", *cells, val_x, der_y)
            return jnp.array([from_dx, from_dy])

        cells_x = jnp.arange(0, a.shape[0] - 1)
        cells_y = jnp.arange(0, a.shape[1] - 1)
        per_cell = vmap(vmap(cell_terms, in_axes=(None, 0)),
                        in_axes=(0, None))(cells_x, cells_y)
        return jnp.sum(per_cell)

    # Map over D innermost, then C, B and finally A outermost.
    mapped = quad_integral
    for axis_pos in (3, 2, 1, 0):
        in_axes = tuple(0 if pos == axis_pos else None for pos in range(4))
        mapped = vmap(mapped, in_axes=in_axes)
    return mapped(A, B, C, D)




def helper_tensor_oriented():
    """Return the antisymmetric 2x2 element tensor ``[[0, 1], [-1, 0]]``.

    For linear hats on one element, this is ``T - T^T`` with
    ``T[i, j] = ∫ phi_i phi_j'``. So ``a^T O b`` equals the element integral
    of ``a b' - b a'``, independent of the element length.
    """
    return jnp.zeros((2, 2)).at[0, 1].set(1).at[1, 0].set(-1)



def helper_tensor_mass(h):
    """2x2 element mass matrix of linear hats on an interval of length ``h``.

    Entry ``[i, j]`` is ``∫ phi_i phi_j``: ``h/3`` on the diagonal and ``h/6``
    off it.
    """
    tensor = jnp.zeros((2, 2))
    for i in range(2):
        for j in range(2):
            weight = 1/3 if i == j else 1/6
            tensor = tensor.at[i, j].set(weight * h)
    return tensor



def helper_tensor_C(h):
    """Element tensor ``C[i, j, k, l] = ∫ phi_i phi_j' phi_k phi_l'`` for linear hats.

    On an element of length ``h``, ``phi_0' = -1/h`` and ``phi_1' = +1/h``.
    Also ``∫ phi_i phi_k`` is ``h/3`` when ``i == k`` and ``h/6`` otherwise.
    Hence ``C[i, j, k, l] = s_j * s_l * (1/3 or 1/6) / h`` with
    ``s = (-1, +1)``.
    """
    slope_sign = (-1, 1)
    tensor = jnp.zeros((2, 2, 2, 2))
    for flat in range(16):
        i, j, k, m = (flat >> 3) & 1, (flat >> 2) & 1, (flat >> 1) & 1, flat & 1
        base = 1/3 if i == k else 1/6
        tensor = tensor.at[i, j, k, m].set(slope_sign[j] * slope_sign[m] * base * 1/h)
    return tensor
        

def helper_tensor_A(h):
    """Element tensor ``A[i, j, k, l] = ∫ phi_i phi_j phi_k phi_l`` for linear hats.

    The integral depends only on how many of the four indices equal 1. With
    ``m`` ones it is ``h * m! (4-m)! / 5!``, i.e.
    ``h * (1/5, 1/20, 1/30, 1/20, 1/5)[m]``.
    """
    by_count = (1/5, 1/20, 1/30, 1/20, 1/5)
    tensor = jnp.zeros((2, 2, 2, 2))
    for flat in range(16):
        idx = ((flat >> 3) & 1, (flat >> 2) & 1, (flat >> 1) & 1, flat & 1)
        tensor = tensor.at[idx].set(by_count[sum(idx)] * h)
    return tensor



def helper_tensor_B():
    """Element tensor ``B[i, j, k, l] = ∫ phi_i phi_j phi_k phi_l'`` for linear hats.

    The derivative on the last index is ``±1/h`` and ``∫ phi_i phi_j phi_k``
    scales with ``h``, so the element length cancels. With ``m = i + j + k``
    the entry is ``s_l * (1/4, 1/12, 1/12, 1/4)[m]``, where ``s = (-1, +1)``.
    """
    slope_sign = (-1, 1)
    by_count = (1/4, 1/12, 1/12, 1/4)
    tensor = jnp.zeros((2, 2, 2, 2))
    for flat in range(16):
        i, j, k, m = (flat >> 3) & 1, (flat >> 2) & 1, (flat >> 1) & 1, flat & 1
        tensor = tensor.at[i, j, k, m].set(slope_sign[m] * by_count[i + j + k])
    return tensor




def vec_helper_M1_2D_with_K(A, B, C, D, K, h1, h2):
    """Quadruple-wise cell integrals for the K-weighted and plain 2D 1-form mass.

    Each 2x2 cell of the nodal grids ``a, b, c, d`` is contracted four ways:

    0. derivative tensor along x, value tensor along y (``helper_tensor_C`` x ``helper_tensor_A``);
    1. value tensor along x, derivative tensor along y;
    2. ``swapaxes(helper_tensor_B(), 1, 3)`` along x with ``helper_tensor_B()`` along y;
    3. ``helper_tensor_B()`` along x with ``swapaxes(helper_tensor_B(), 1, 3)`` along y.

    Terms 2 and 3 are presumably the off-diagonal (cross-derivative)
    couplings. After summing over cells, two numbers are produced:
    ``K[0]*t0 + K[1]*t1 + K[2]*(t2 + t3)`` and the unweighted ``t0 + t1``.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0], C.shape[0], D.shape[0], 2)``;
        the last axis is ``[weighted, unweighted]``.
    """
    val_x = helper_tensor_A(h1)
    val_y = helper_tensor_A(h2)
    der_x = helper_tensor_C(h1)
    der_y = helper_tensor_C(h2)
    mixed = helper_tensor_B()
    mixed_swapped = jnp.swapaxes(mixed, 1, 3)

    def quad_integral(a, b, c, d):
        def cell_terms(cx, cy):
            cells = [jax.lax.dynamic_slice(v, (cx, cy), (2, 2)) for v in (a, b, c, d)]

            def contract(tensor_x, tensor_y):
                return jnp.einsum("ia,jb,kc,ld,ijkl,abcd->", *cells, tensor_x, tensor_y)

            return jnp.array([
                contract(der_x, val_y),
                contract(val_x, der_y),
                contract(mixed_swapped, mixed),
                contract(mixed, mixed_swapped),
            ])

        cells_x = jnp.arange(0, a.shape[0] - 1)
        cells_y = jnp.arange(0, a.shape[1] - 1)
        per_cell = vmap(vmap(cell_terms, in_axes=(None, 0)),
                        in_axes=(0, None))(cells_x, cells_y)
        totals = jnp.sum(per_cell, axis=(0, 1))

        weighted = K[0] * totals[0] + K[1] * totals[1] + K[2] * (totals[2] + totals[3])
        unweighted = totals[0] + totals[1]
        return jnp.array([weighted, unweighted])

    # Map over D innermost, then C, B and finally A outermost.
    mapped = quad_integral
    for axis_pos in (3, 2, 1, 0):
        in_axes = tuple(0 if pos == axis_pos else None for pos in range(4))
        mapped = vmap(mapped, in_axes=in_axes)
    return mapped(A, B, C, D)



def construct_M1_2D_with_K(W, K, h1, h2):
    """Assemble the K-weighted and plain 2D Whitney 1-form mass matrices.

    Partials from ``vec_helper_M1_2D_with_K`` (trailing axis
    ``[weighted, unweighted]``) are antisymmetrised in the first and second
    index pairs. They are then restricted to ``i < j`` and ``k < l`` in
    ``jnp.triu_indices`` order.

    Returns:
        Array of shape ``(n1, n1, 2)`` with ``n1 = N(N-1)/2`` and
        ``N = W.shape[0]``.
    """
    partials = vec_helper_M1_2D_with_K(W, W, W, W, K, h1, h2)
    swap_ij = jnp.swapaxes(partials, 0, 1)
    swap_kl = jnp.swapaxes(partials, 2, 3)
    swap_both = jnp.swapaxes(swap_ij, 2, 3)
    full = partials - swap_ij - swap_kl + swap_both
    form_rows = full[jnp.triu_indices(W.shape[0], 1)]
    return vmap(lambda block: block[jnp.triu_indices(block.shape[0], 1)])(form_rows)










