"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp 
from jax import jacfwd, grad, jit, vmap
from utils import div,sin_thres,tan_thres
import numpy as np
import jax
# This is the standard PDE loss for variable-density Euler, formulated in terms of
# velocity - u: R^{n+1} -> R^n
# density  - ρ: R^{n+1} -> R
# pressure - p: R^{n+1} -> R
# Convention: inputs are ordered (t, x), with time always first.
# u, ρ and p are separate outputs of the network, so this works with a regular PINN
# or a div-free model, but not with NCL.
class PDE(object):
    """Pointwise residuals of the variable-density Euler equations.

    Every residual takes a network callable ``pinn`` whose output is laid out
    as ``[rho, u_1, ..., u_n, p]`` (density first, pressure last, velocity in
    between) and an input point ``x = (t, x_1, ..., x_n)`` with time first.
    """

    def __init__(self):
        # Boundary-normal callable ``x -> n(x)``; set with ``setNormal``.
        self.normal = None
        # Initial-condition callable ``x -> [rho_0, *u_0]``; set with ``setInitial``.
        self.initial = None

    def setInitial(self, init):
        """Set the initial-condition callable used by ``init``."""
        self.initial = init

    def setNormal(self, nm):
        """Set the boundary-normal callable used by ``bd``."""
        self.normal = nm

    def unpackPINN(self, pinn):
        """Split ``pinn`` into velocity, density and pressure callables.

        Returns:
            ``(u, rho, p)`` with ``u(x) = pinn(x)[1:-1]``,
            ``rho(x) = pinn(x)[0]`` and ``p(x) = pinn(x)[-1]``.
        """
        u = lambda y: pinn(y)[1:-1]
        rho = lambda y: pinn(y)[0]
        p = lambda y: pinn(y)[-1]
        return u, rho, p

    def mom(self, pinn, x):
        """Momentum residual ``rho*u_t + rho*[Du]u + grad(p)`` at ``x``.

        ``Du`` is the spatial Jacobian of ``u`` (time column excluded) and
        ``[Du]u`` is its matrix-vector product with ``u(x)``, i.e. the
        convective term ``(u . grad) u``.

        Returns:
            Array of shape ``(n,)``.
        """
        u, rho, p = self.unpackPINN(pinn)
        Du = jacfwd(u)(x)  # (n, n+1): column 0 is the time derivative
        u_t = Du[:, 0]
        # Contract the spatial Jacobian with u; adding the bare (n, n) Jacobian
        # would silently broadcast the whole residual to shape (n, n).
        Du_u = Du[:, 1:] @ u(x)
        nabla_p = grad(p)(x)[1:]  # chop off the time derivative
        rho_x = rho(x)
        return rho_x * u_t + rho_x * Du_u + nabla_p

    def cont(self, pinn, x):
        """Continuity residual ``rho_t + div(rho * u)`` at ``x``.

        NOTE(review): relies on ``utils.div``; presumably it returns the
        spatial divergence (time excluded) of an R^{n+1} -> R^n field -- confirm.
        """
        u, rho, _ = self.unpackPINN(pinn)
        rho_t = grad(rho)(x)[0]  # time is the first input coordinate
        div_rhou = div(lambda y: rho(y) * u(y))(x)
        return rho_t + div_rhou

    def incp(self, pinn, x):
        """Incompressibility residual ``div(u)`` at ``x`` (via ``utils.div``)."""
        u, _, _ = self.unpackPINN(pinn)
        return div(u)(x)

    def bd(self, pinn, x):
        """Free-slip boundary residual ``u(x) . normal(x)``.

        Requires ``setNormal`` to have been called.
        """
        u, _, _ = self.unpackPINN(pinn)
        return jnp.dot(u(x), self.normal(x))

    def init(self, pinn, x):
        """Initial-condition residual ``[rho, *u] - initial(x)`` at ``x``.

        Requires ``setInitial`` to have been called.
        """
        u, rho, _ = self.unpackPINN(pinn)
        return jnp.array([rho(x), *u(x)]) - self.initial(x)

def syn(x):
    """Return a copy of ``x`` whose entry at index 2 is replaced by its absolute value.

    Uses JAX's functional ``.at[...].set`` update, so ``x`` itself is not modified.
    """
    mirrored = jnp.abs(x[2])
    return x.at[2].set(mirrored)

def deal_sign(v, x):
    """Scale the first two entries of the last row of ``v`` by ``sign(x[2])``.

    Returns a new array (functional JAX update); ``v`` is left untouched.
    Because ``jnp.sign(0) == 0``, those two entries become zero when
    ``x[2] == 0``.
    """
    flip = jnp.sign(x[2])
    return v.at[-1, :2].multiply(flip)



# This is the modified PDE loss for variable-density Euler, formulated in terms of
# vedens   - v: R^{n+1} -> R^{n+1}
# pressure - p: R^{n+1} -> R
# Convention: v = v(t, x), with time always first.
# It evaluates PDEs scaled by the density [see Appendix A of the paper].
class PDEDivForm(object):
    """Time-stepped advection residual plus boundary and initial-condition terms.

    All methods take network callables ``f(x) -> array`` and a single input
    point ``x``; each method documents which output entries it reads.
    """

    def __init__(self, time_step, spatial_m=None):
        """
        Args:
            time_step: step size ``dt`` used by ``advect_loss``.
            spatial_m: stored as ``self.spatial_m``; not read by any method of
                this class.
        """
        # Boundary-normal callable ``x -> n(x)``; set with ``setNormal``.
        self.normal = None
        # Initial-condition callable used by ``init_w``; set with ``setInitial``.
        self.initial = None
        self.spatial_m = spatial_m
        self.time_step = time_step

    def setInitial(self, init):
        """Set the initial-condition callable used by ``init_w``."""
        self.initial = init

    def setNormal(self, nm):
        """Set the boundary-normal callable used by ``bdry``."""
        self.normal = nm

    def advect_loss(self, f_0, f_1, f_init, x):
        """Trapezoidal residual linking two time levels of an advected scalar.

        Each ``f_k(x)`` is read as ``[v_k..., w_k]``: the last entry is the
        scalar ``w_k`` and the remaining entries are the velocity ``v_k``.
        Returns::

            w_0 + dt/2 * (v_0 . grad w_0) - w_1 + dt/2 * (v_1 . grad w_1)

        with ``dt = self.time_step``. ``grad`` is taken with respect to every
        entry of ``x``, so the velocity and ``x`` must have matching lengths.

        NOTE(review): this residual vanishes when
        ``w_1 - w_0 = +dt/2 * (v_0.grad w_0 + v_1.grad w_1)``, i.e. it
        discretizes ``w_t = +v.grad w``. Forward transport
        ``w_t + v.grad w = 0`` would subtract the dt terms -- confirm the
        intended sign convention (e.g. backward-in-time advection).

        Args:
            f_0, f_1: network callables at the two time levels (presumably
                previous and next).
            f_init: unused; kept for call-site compatibility.
            x: evaluation point.

        Returns:
            Scalar residual.
        """
        b_0 = f_0(x)
        b_1 = f_1(x)
        v_0, w_0 = b_0[:-1], b_0[-1]
        v_1, w_1 = b_1[:-1], b_1[-1]

        # Gradient of the scalar (last output row of the Jacobian) w.r.t. x.
        d_0 = jacfwd(f_0)(x)[-1, :]
        d_1 = jacfwd(f_1)(x)[-1, :]

        prod_0 = jnp.sum(d_0 * v_0)
        prod_1 = jnp.sum(d_1 * v_1)
        return w_0 + 0.5 * self.time_step * prod_0 - w_1 + 0.5 * self.time_step * prod_1

    def bdry(self, v, x):
        """Free-slip residual ``v(x)[1:-1] . normal(x)``.

        Output entries ``1:-1`` are taken as the velocity (the first and last
        entries are excluded). Requires ``setNormal`` to have been called.
        """
        return jnp.dot(v(x)[1:-1], self.normal(x))

    @staticmethod
    def _with_coord1(x, value):
        """Return ``x`` as a JAX array with entry 1 replaced by ``value``."""
        # jnp.asarray also accepts NumPy input, which has no ``.at`` property.
        return jnp.asarray(x).at[1].set(value)

    def bdry_cycle_r(self, v, x):
        """Periodicity residual ``v(x') - v(x)`` with ``x'[1] = -pi``.

        Presumably ``x`` lies on the ``x[1] = +pi`` edge -- confirm with caller.
        """
        return v(self._with_coord1(x, -jnp.pi)) - v(x)

    def bdry_cycle_l(self, v, x):
        """Periodicity residual ``v(x') - v(x)`` with ``x'[1] = +pi``.

        Presumably ``x`` lies on the ``x[1] = -pi`` edge -- confirm with caller.
        """
        return v(self._with_coord1(x, jnp.pi)) - v(x)

    def _scalar_pole_residual(self, v, x):
        """Residual ``v(x) - v(x')`` with ``x'[1] = 0``.

        Presumably enforces that ``v`` does not depend on coordinate 1 at a
        pole, where that coordinate is degenerate -- confirm.
        """
        return v(x) - v(self._with_coord1(x, 0))

    def bdry_scalar_singular_n(self, v, x):
        """Scalar regularity residual at the ``n`` singular boundary."""
        return self._scalar_pole_residual(v, x)

    def bdry_scalar_singular_s(self, v, x):
        """Scalar regularity residual at the ``s`` singular boundary.

        Currently identical to ``bdry_scalar_singular_n``.
        """
        return self._scalar_pole_residual(v, x)

    def bdry_vel_singular_n(self, v, x):
        """Velocity residual at the ``n`` singular boundary.

        With ``V = v(x)[:-1]`` and ``d/dx1`` the derivative w.r.t. input
        entry 1, returns ``[dv_1/dx1 + V[0], dv_0/dx1 - V[1]]``.
        """
        v_x = v(x)[:-1]
        Dv = jacfwd(v)(x)
        return jnp.stack([Dv[1, 1] + v_x[0], Dv[0, 1] - v_x[1]])

    def bdry_vel_singular_s(self, v, x):
        """Velocity residual at the ``s`` singular boundary.

        With ``V = v(x)[:-1]`` and ``d/dx1`` the derivative w.r.t. input
        entry 1, returns ``[dv_1/dx1 - V[0], dv_0/dx1 - V[1]]``.
        """
        v_x = v(x)[:-1]
        Dv = jacfwd(v)(x)
        return jnp.stack([Dv[1, 1] - v_x[0], Dv[0, 1] - v_x[1]])

    def kl_divergence(self, mean, logvar):
        """KL( N(mean, exp(logvar)) || N(0, I) ), summed over all entries."""
        return -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar))

    def init_w(self, v, x):
        """Initial-condition residual for ``w`` plus a scaled KL penalty.

        Reads ``b = v(x)`` as ``w = b[3]``, ``mean = b[4:7]`` and
        ``logvar = b[7:10]``, and returns
        ``stack([w - initial(x), 0.01 * KL(mean, logvar)])``.
        ``jnp.stack`` needs equal shapes and the KL term is a scalar, so
        ``w - initial(x)`` must be a scalar too. Requires ``setInitial`` to
        have been called.
        """
        in_ev = self.initial(x)
        b = v(x)
        w = b[3]
        mean = b[4:7]
        logvar = b[7:10]
        kl = self.kl_divergence(mean, logvar)
        return jnp.stack([(w - in_ev), kl * 0.01], axis=0)