
import os
import signal
import time
from collections import namedtuple
import itertools
import copy
import pprint

from ordered_set import OrderedSet
import scipy
import numpy as np

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
import sympy as sp
from sympy import derive_by_array as ts_grad  #gradient
from sympy import tensorcontraction as ts_contr  #summation
from sympy import tensorproduct as ts_prod 
from sympy import transpose as ts_trans
import pysindy as ps
import argparse

from finitediff import get_diff, FiniteDiffVand

from utils import switch

#---- Load Data ---
def get_raw_data(problem='2d_incomp_viscose_newton_ns', datasource="COMSOL", verbose=False,
                 data_root="./data", n_clip=5):
    '''
    Read a dataset from <data_root>/<problem.lower()>.npz and stack its fields.

    Input:
        problem: str, dataset name. Supported: "2d_incomp_viscose_newton_ns" and
            names starting with "2d_comp_viscose_new" or "2d_heat_comp".
        datasource: str, only "COMSOL" is supported.
        verbose: bool, currently unused (kept for backward compatibility).
        data_root: str, directory containing the .npz files.
        n_clip: int, number of grid points dropped at each spatial boundary
            (0 keeps the full grid).
    Output:
        u: np.ndarray, shape=(nx, ny, nt, u_dim); the last axis stacks
            [ux, uy, p, rho], plus T for "2d_heat_comp*" problems.
        grids: tuple (x, y, t) of 1-dim np.ndarray,
            x.shape=(nx,), y.shape=(ny,), t.shape=(nt,).
    Raises:
        ValueError: unknown problem/datasource, or fields that do not stack
            into 4-dim (nx, ny, nt, u_dim) data.
    '''
    if datasource != "COMSOL":
        raise ValueError(f"Dataset Not Found: {problem}, {datasource}.")

    if problem == "2d_incomp_viscose_newton_ns" or problem.startswith("2d_comp_viscose_new"):
        field_names = ['ux', 'uy', 'p', 'rho']
    elif problem.startswith("2d_heat_comp"):
        field_names = ['ux', 'uy', 'p', 'rho', 'T']
    else:
        raise ValueError(f"Dataset Not Found: {problem},{datasource}.")

    filename = os.path.join(data_root, problem.lower() + '.npz')
    # NpzFile keeps the archive open; the with-block closes it once every array is read.
    with np.load(filename) as data:
        t = data['t']
        x = data['x'][:, 0]  # first column of the stored 2-d x grid
        y = data['y'][0, :]  # first row of the stored 2-d y grid
        u = np.stack([data[name] for name in field_names], axis=-1)

    # Expected layout: 2 spatial dims + 1 temporal dim + channel dim.
    if u.ndim != 4:
        raise ValueError(f"Expected 4-dim data (nx, ny, nt, u_dim), got shape {u.shape}.")

    if n_clip > 0:  # drop boundary points in both spatial directions
        u = u[n_clip:-n_clip, n_clip:-n_clip, ...]
        x = x[n_clip:-n_clip]
        y = y[n_clip:-n_clip]
    return u, (x, y, t)

# --- list utils ----
def list_cat(ll):
    '''
    Join an iterable of lists into a single new list.

    itertools.chain runs in linear time, whereas sum(ll, []) re-copies the
    accumulated list at every step and is quadratic in the total length.
    '''
    return list(itertools.chain.from_iterable(ll))

def list_remove(ll, fun):
    '''
    Return a new list with the elements of ll for which fun(element) is falsy.
    '''
    keep = lambda element: not fun(element)
    return list(filter(keep, ll))
# --- ----


# ---- numpy utilities ----
def np_grad(arr_list, grids, is_time_grad=False):
    """
    First-order finite-difference derivatives of numpy arrays, either spatial or temporal.

    Input:
        arr_list: one np.ndarray or a list of them, each with len(grids) dims,
            e.g. shape=(nx, ny, (nz), nt).
        grids: sequence of 1-dim np.ndarray [x, y, (z), t]; the LAST grid is time.
            Only grid[1] - grid[0] is used as the step, so grids are treated as uniform.
        is_time_grad: bool. True -> derivatives along the time axis only,
            False -> derivatives along every spatial axis only.
    Output:
        list of np.ndarray, ordered axis-major: all arrays differentiated along the
        first selected axis, then all along the next one, and so on.

    e.g. arr_list=[u, v], grids=(x, y, t), is_time_grad=False
        -> [du/dx, dv/dx, du/dy, dv/dy]
    e.g. arr_list=[u, v], grids=(x, y, t), is_time_grad=True
        -> [du/dt, dv/dt]
    """
    arrays = arr_list if isinstance(arr_list, list) else [arr_list]
    time_axis = len(grids) - 1

    derivatives = []
    for axis, grid in enumerate(grids):
        # keep only the axes matching the requested kind (time vs space)
        if (axis == time_axis) ^ is_time_grad:
            continue
        step = grid[1] - grid[0]
        derivatives.extend(FiniteDiffVand(array, dx=step, d=1, axis=axis) for array in arrays)
    return derivatives

def np_grad_all(arr_list, grids):
    """
    Convenience wrapper around np_grad() returning [grad_, grad_grad_, dt_].

    grad_: spatial first derivatives of arr_list,
    grad_grad_: spatial derivatives of every entry of grad_,
    dt_: temporal derivatives of arr_list.
    """
    spatial_first = np_grad(arr_list, grids)
    spatial_second = np_grad(spatial_first, grids)
    temporal = np_grad(arr_list, grids, is_time_grad=True)
    return [spatial_first, spatial_second, temporal]

def pooling(mat, ksize, method='mean', pad=False):
    '''
    Non-overlapping block pooling over the first two axes of a 2D or 3D array.

    <mat>: ndarray, input array to pool; axes beyond the first two are kept as-is.
    <ksize>: tuple of 2, kernel size in (ky, kx).
    <method>: str, 'max' for max-pooling; any other value gives mean-pooling.
    <pad>: bool, pad <mat> with NaN or not. Without padding the output has
           size n//f (n: <mat> size, f: kernel size) and the remainder is
           dropped; with padding it has size ceil(n/f) and the NaN padding is
           ignored by the reduction.

    Return <result>: pooled matrix. Inputs without a `shape` give np.zeros((1,)).
    '''
    if not hasattr(mat, "shape"):
        return np.zeros((1,))

    rows, cols = mat.shape[:2]
    ky, kx = ksize
    trailing_shape = mat.shape[2:]

    if pad:
        n_blocks_y = int(np.ceil(rows / float(ky)))
        n_blocks_x = int(np.ceil(cols / float(kx)))
        padded = np.full((n_blocks_y * ky, n_blocks_x * kx) + trailing_shape, np.nan)
        padded[:rows, :cols, ...] = mat
    else:
        n_blocks_y, n_blocks_x = rows // ky, cols // kx
        padded = mat[:n_blocks_y * ky, :n_blocks_x * kx, ...]

    # axes 1 and 3 index the position inside each (ky, kx) window
    windows = padded.reshape((n_blocks_y, ky, n_blocks_x, kx) + trailing_shape)
    reduce_fn = np.nanmax if method == 'max' else np.nanmean
    return reduce_fn(windows, axis=(1, 3))


def np_ms(a):
    """Mean square: mean(a**2) for numpy arrays, plain a**2 for anything else (e.g. scalars)."""
    squared = a ** 2
    return squared.mean() if isinstance(squared, np.ndarray) else squared
# ----  ----


# ---- sympy utilities ----
class TimeoutError(Exception):
    """Signals that a sympy simplification exceeded its time limit (raised by `handler`).

    NOTE(review): this name shadows the built-in TimeoutError inside this module.
    """

def handler(signum, frame):
    """SIGALRM handler: abort the running simplification by raising TimeoutError."""
    raise TimeoutError("Simplification process took too long")


def _simplify_without_timeout(expr):
    """Simplify expr with no time limit; return expr unchanged on NotImplementedError."""
    try:
        return expr.simplify()
    except NotImplementedError:
        return expr


def sp_simplify_with_timeout(expr, timeout=3):
    """
    Simplify a sympy expression/array, giving up after `timeout` seconds.

    Objects without a `simplify` method (e.g. plain ints) are returned unchanged.
    On timeout, or on NotImplementedError (e.g. sympy's "Improve MV Derivative
    support in collect"), the unsimplified expr is returned.

    The limit is enforced with SIGALRM, so it only applies where SIGALRM exists
    and when called from the main thread; otherwise expr is simplified without
    a limit. The alarm is always cancelled and the previous SIGALRM handler is
    restored, even if simplify() raises some other error, so no stray alarm
    can fire later inside unrelated code.

    Input: expr, any object; timeout: int, seconds (as accepted by signal.alarm).
    Output: the simplified expression, or expr itself.
    """
    if not hasattr(expr, "simplify"):
        return expr
    if not hasattr(signal, "SIGALRM"):  # e.g. Windows
        return _simplify_without_timeout(expr)
    try:
        previous_handler = signal.signal(signal.SIGALRM, handler)
    except ValueError:  # signal handlers can only be installed from the main thread
        return _simplify_without_timeout(expr)
    if previous_handler is None:  # previous handler was not installed from Python
        previous_handler = signal.SIG_DFL

    signal.alarm(timeout)
    try:
        simplified = expr.simplify()
        # disarm inside the try, so an alarm arriving just before this line is still caught below
        signal.alarm(0)
        return simplified
    except (NotImplementedError, TimeoutError):
        return expr
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
# ---- ----

# ---- scipy utilities ----
class EarlyStopException(Exception):
    """Raised from an optimizer callback to abort scipy.optimize.minimize early."""

def optimize_with_timeout(mse_func, init_params, constr_dict_list, dataset_name, prev_sol_best=None, verbose=False):
    """
    Minimize mse_func with SLSQP. Stop early on a wall-clock budget, or when the
    loss is clearly worse than a previous best solution.

    Input:
        mse_func: callable, params -> scalar loss.
        init_params: 1-d sequence of initial parameters (may be empty, in which
            case mse_func is evaluated once and no optimization is run).
        constr_dict_list: list of scipy constraint dicts.
        dataset_name: str; names containing "_new_" get a longer time budget.
        prev_sol_best: dict with "fun" and "nit" of a previous solution, or None.
            From iteration check_time_nit on, the run stops when
            loss > bound_f_coef * (prev_nit / nit)**2 * prev_fun.
        verbose: bool, forwarded to SLSQP's "disp" option.
    Output:
        sol: scipy OptimizeResult or dict with at least "fun", "x", "nit",
            "status" ("Success" or "EarlyStop") and "time" (elapsed seconds).
    """
    # `import scipy` alone does not load the optimize submodule on older SciPy versions.
    import scipy.optimize

    # ---- hyper params ----
    time_limit = 1500 if "_new_" in dataset_name else 500  # seconds: 25 min for "_new_" datasets, ~8 min otherwise
    bound_f_coef = 500
    check_time_nit = 9

    # Progress shared with the callback through the closure instead of module
    # globals, so separate calls cannot overwrite each other's state.
    state = {"nit": 0, "x": init_params}
    start_time = time.time()

    def callback(xk):
        state["nit"] += 1
        state["x"] = xk
        if time.time() - start_time > time_limit:
            raise EarlyStopException
        if prev_sol_best and state["nit"] >= check_time_nit:
            prev_f, prev_nit = prev_sol_best["fun"], prev_sol_best["nit"]
            bound_f = bound_f_coef * (prev_nit / state["nit"]) ** 2 * prev_f
            # the loss is only evaluated when it is actually compared
            if mse_func(xk) > bound_f:
                raise EarlyStopException

    try:
        if len(init_params) > 0:
            sol = scipy.optimize.minimize(mse_func, init_params, constraints=constr_dict_list, method="SLSQP",
                                          options={"disp": verbose}, callback=callback)
        else:
            sol = {"fun": mse_func(init_params), "x": init_params, "nit": 1}
        sol["status"] = "Success"
    except EarlyStopException:
        sol = {"fun": mse_func(state["x"]), "x": state["x"], "nit": state["nit"], "status": "EarlyStop"}
    sol['time'] = time.time() - start_time
    return sol
# ---- ----

# ---- data structures ---
# DBNode holds the content of one decision branch:
# DBNode.name = the name of the branch, DBNode.children = the names of its child DBNodes.
class DBNode(namedtuple("DBNode", ["name", "children"])):
    """(name, children) record of a decision branch.

    Both fields are optional: name defaults to "" and children to a NEW empty
    list for every instance. A namedtuple default of `list()` would be a single
    list object shared by every node created without explicit children, so
    appending to one node's children would change all of them.
    """
    __slots__ = ()

    def __new__(cls, name="", children=None):
        if children is None:
            children = []
        return super().__new__(cls, name, children)
# ---- ----

class Fl:
    '''2-dimensional fluid mechanics model in sympy.

    Class attributes hold the symbolic fields (velocity V = [u, v], pressure p,
    temperature T, density rho, turbulence variables turb1/turb2), each a
    function of (x, y, t). Static methods provide tensor operators (dot, ddot,
    div, conserve, DDt, ...) and the decision functions get_*() that assemble
    constitutive laws and equation residuals from a decision dict.
    '''
    # ---- Symbol declarations and definitions (static class attributes) ----
    u, v = sp.symbols('u v', cls=sp.Function)
    x, y, t = sp.symbols('x, y, t')
    space_axis = [x, y]
    X = sp.Array(space_axis)  # spatial (Euler) coordinate vector [x, y]
    X_dim = len(X)  # number of spatial dimensions (2)
    u = u(*space_axis, t) # velocity components as functions of (x, y, t)
    v = v(*space_axis, t)
    V = sp.Array([u, v]) # velocity vector field
    # pressure, temperature and density fields, functions of (x, y, t)
    p, T, rho = sp.symbols('p T rho', cls=sp.Function)
    p = p(*space_axis, t)
    T = T(*space_axis, t)
    rho = rho(*space_axis, t)

    # viscosity (mu) and volume viscosity (lambda) symbols
    mu, lambda_ = sp.symbols('mu lambda')
    # placeholder symbols for stress, shear stress, strain rate and shear rate
    # (S is rebound further down in the class to the strain-rate tensor)
    Sigma, tau, S, gamma = sp.symbols('Sigma tau S gamma')
    # body force components and placeholder for the force vector
    fx, fy, F = sp.symbols('fx, fy, F')

    # variables introduced by the turbulence model: k (turb1) and epsilon (turb2) for the
    # k-epsilon models; for models with only one variable, turb2 = 0
    turb1, turb2 = sp.symbols('tb_1, tb_2', cls=sp.Function)
    turb1 = turb1(*space_axis, t)
    turb2 = turb2(*space_axis, t)

    # heat flux components; Q is rebound to the vector [qx, qy]
    qx, qy, Q = sp.symbols('qx, qy, Q')
    Q = sp.Array([qx, qy])

    # constants
    I = sp.Array(sp.eye(X_dim)) # identity tensor, shape=(X_dim, X_dim) = (2, 2)
    tol = 1e-3 # small tolerance used in positivity constraints (param - tol >= 0)
    T_ref = 300 # reference temperature (used in mu(T) models and to scale the energy residual)

    # ---- ---- ---- ---- ----

    # ---- Tensor Utilities (static functions)----
    @staticmethod
    def dot(tensor1, tensor2):
        '''
        Single contraction (dot product) of two sympy tensors:
        sums over the last index of tensor1 and the first index of tensor2.
        Input:
            tensor1: rank(n) tensor.
            tensor2: rank(m) tensor.
        Output:
            rank(n+m-2) tensor.
        '''
        last_idx = len(tensor1.shape) - 1
        outer = ts_prod(tensor1, tensor2)
        # in the outer product, tensor2's first index directly follows tensor1's last one
        return ts_contr(outer, (last_idx, last_idx + 1))
    
    @staticmethod
    def ddot(tensor1, tensor2):
        '''
        Double contraction (double dot product) of two sympy tensors.
        Pairs tensor1's last two indices, in order, with tensor2's first two;
        for two rank(2) tensors this is A:B = a_{ij}b_{ij}.
        Input:
            tensor1: rank(n) tensor, n>=2
            tensor2: rank(m) tensor, m>=2
        Output:
            rank(n+m-4) tensor.
        '''
        rank1 = len(tensor1.shape)
        outer = ts_prod(tensor1, tensor2)
        # contract tensor1's second-to-last index with tensor2's first index ...
        partial = ts_contr(outer, (rank1 - 2, rank1))
        # ... which leaves tensor1's last index next to tensor2's (former) second index
        return ts_contr(partial, (rank1 - 2, rank1 - 1))
    
    @staticmethod
    def div(f, x=X):
        '''
        Divergence grad_x . f: the derivative index is contracted with the first index of f.
        Input:
            f: sympy.Array, any field, rank(n) tensor
            x: sympy.Array, Euler coordinates (default: Fl.X)
        Output: divergence, rank(n-1) tensor
        '''
        df_dx = ts_grad(f, x)  # derivative index comes first, followed by f's indices
        return ts_contr(df_dx, (0, 1))
    
    @staticmethod
    def conserve(f, vel=V, x=X, t=t):
        '''
        Conservation-form operator: df/dt + div(f (x) vel, x).
        Input:
            f: sympy.Array or scalar expression, rank(n) field
            vel: sympy.Array, velocity field, rank(1) tensor (default: Fl.V)
            x: sympy.Array, Euler coordinates (default: Fl.X)
            t: sympy.Symbol, time (default: Fl.t)
        Output:
            rank(n) tensor
        '''
        flux = ts_prod(f, vel)  # f transported by the velocity field
        return f.diff(t) + Fl.div(flux, x)
    
    @staticmethod
    def DDt(f, vel=V, x=X, t=t):
        '''
        Material (total) derivative Df/Dt := df/dt + vel . grad(f).
        Input:
            f: sympy.Array or scalar expression, rank(n) field
            vel: sympy.Array, velocity field, rank(1) tensor (default: Fl.V)
            x: sympy.Array, Euler coordinates (default: Fl.X)
            t: sympy.Symbol, time (default: Fl.t)
        Output:
            total derivative, rank(n) tensor
        '''
        grad_f = ts_grad(f, x)  # derivative index first; contracted with vel below
        convective = Fl.dot(vel, grad_f)
        return f.diff(t) + convective

    @staticmethod
    def ts_1d_list(tensor):
        """
        Flatten a sympy tensor (or scalar expression) into a flat Python list.

        Objects without `tolist` (e.g. scalar sympy expressions) become a
        one-element list. Nested lists from `tolist()` are unpacked one level at
        a time in row-major order. Empty tensors give an empty list.
        """
        if not hasattr(tensor, 'tolist'):  # rank(0) tensor / scalar expression
            return [tensor]
        lst = tensor.tolist()
        # chain is linear, where sum(lst, []) was quadratic; the emptiness
        # check avoids IndexError on empty tensors
        while lst and isinstance(lst[0], list):
            lst = list(itertools.chain.from_iterable(lst))
        return lst
    
    @staticmethod
    def ts_flatten(tensor):
        """
        Flatten a sympy tensor into a rank(1) sp.Array (row-major order).
        """
        flat_items = Fl.ts_1d_list(tensor)
        return sp.Array(flat_items)
    
    @staticmethod
    def ts_grad_all(tensor):
        """
        Wrapper of ts_grad() returning [grad_, grad_grad_, dt_] in one call:
        the spatial gradient, the spatial gradient of the flattened spatial
        gradient, and the time derivative.
        """
        first = ts_grad(tensor, Fl.X)
        second = ts_grad(Fl.ts_flatten(first), Fl.X)
        in_time = ts_grad(tensor, Fl.t)
        return [first, second, in_time]

    # Defined as staticmethods rather than lambdas assigned to class attributes:
    # a plain lambda would become a bound method when accessed through an
    # instance (Fl().grad(a) would pass the instance as `a`).
    @staticmethod
    def grad(a):
        """Spatial gradient of a, i.e. ts_grad(a, Fl.X) (derivative index first)."""
        return ts_grad(a, Fl.X)

    @staticmethod
    def norm(tensor):
        """sqrt(tensor : tensor), the Frobenius norm for rank-2 tensors."""
        return sp.sqrt(Fl.ddot(tensor, tensor))
    # ---- ---- ---- ---- ----

    # ---- Sympy utilities ----
    @staticmethod
    def sp_maximum(a, b):
        """
        Symbolic maximum of a and b, written as a Piecewise expression
        (presumably so that sympy.lambdify produces numpy-friendly code; see Ref).
        Ref: https://stackoverflow.com/questions/60723841/how-do-i-use-sympy-lambdify-with-max-function-to-substitute-numpy-maximum-instea
        """
        b_is_larger = a < b
        return sp.Piecewise((b, b_is_larger), (a, True))

    # ---- Symbols definition, continued (static variables) ----
    # Strain-rate tensor S = (grad(V) + grad(V)^T) / 2; rebinds the placeholder symbol S declared earlier.
    S = sp.Rational(1,2) * (ts_grad(V, X) + ts_trans(ts_grad(V, X)))

    def __init__(self):
        """Fl carries no per-instance state; its symbols and operators are class-level attributes and static methods."""
    
    # ---- decision functions ----
    # Decision functions get_A() receive a deci_dict, 
    # output the expr of A, and optionally add params into params_constr.
    # params_constr is a dict, key = param, value=a list of constraints dict, 
    # each dict is {"type":"ineq", "fun":func}, where func()>=0
    @staticmethod
    def get_tau(deci_dict, params_constr, stridge_terms=None):
        '''
        Viscous (shear) stress tensor tau = 2 * mu_app * S.
        For compressible flow, lambda * div(V) * I is added with
        lambda = -(2 / X_dim) * mu_app, which makes tau trace-free.
        Parameters of mu_app are registered in params_constr by get_mu_app().
        '''
        mu_app = Fl.get_mu_app(deci_dict, params_constr, stridge_terms)
        shear_part = 2 * mu_app * Fl.S

        if not deci_dict["is_compressible"]:
            return sp_simplify_with_timeout(shear_part)

        lambda_ = sp.Rational(-2, Fl.X_dim) * mu_app
        bulk_part = lambda_ * Fl.I * Fl.div(Fl.V)
        return sp_simplify_with_timeout(shear_part + bulk_part)

        
    @staticmethod
    def get_mu_app(deci_dict, params_constr, stridge_terms=None):
        '''
        Apparent viscosity mu_app, assembled from the decisions in deci_dict.

        Shear-rate dependence, with gamma = max(sqrt(2 S:S), 0.01):
            is_newtonian:     type_newtonian 0 -> inviscid (mu_app = 0),
                              1 -> constant mu (mu >= tol).
            not is_newtonian: type_non_newtonian 0 -> power law mu_k * gamma**(n-1),
                              1 -> Carreau model,
                              2 -> mu_inf + polynomial and/or Fourier series in gamma,
                              3 -> mu_inf + STRidge library of elementary functions of gamma.
        Temperature dependence (only if not deci_dict["is_isothermal"]):
            type_mu_temperature 0 -> none, 1 -> power law (T/T_ref)**s,
            2 -> Sutherland, 3 -> exponential exp(B (1/T - 1/T_ref)),
            4 -> mu_app + sum_{i=1..3} coef_i * T**i.

        Every free parameter is registered in params_constr as
        {param: [constraint dicts]}, each {"type": "ineq", "fun": expr} meaning expr >= 0.
        If stridge_terms is a dict (an empty one included), each STRidge term,
        evaluated at gamma, is stored under its coefficient name "<term>_STR_coef".

        Raises:
            ValueError: if type_newtonian / type_non_newtonian selects no known model.
        '''
        # shear rate, bounded below by 0.01 to avoid division-by-zero errors
        gamma = Fl.sp_maximum(sp.sqrt(2* Fl.ddot(Fl.S, Fl.S)), 0.01)

        # ---- dependence on shear rate ----
        if deci_dict["is_newtonian"]:
            if deci_dict["type_newtonian"]==0: # inviscid flow
                mu_app = Fl.mu.subs({Fl.mu:0})
            elif deci_dict["type_newtonian"]==1: # newtonian fluid
                mu_app = Fl.mu
                params_constr[mu_app] = [{"type":"ineq", "fun":mu_app-Fl.tol}] # mu - tol >= 0
            else:
                raise ValueError(f"Unsupported type_newtonian: {deci_dict['type_newtonian']}")
        else: # non-Newtonian
            type_non_newtonian = switch(deci_dict["type_non_newtonian"])
            if type_non_newtonian(0): # power law
                mu_k, n = sp.symbols("mu_k, n")  # n = flow behaviour index
                mu_app = mu_k * (gamma)**(n-1)
                params_constr[mu_k] = [{"type":"ineq", "fun":mu_k-Fl.tol}] # mu_k - tol >= 0

                if deci_dict["is_dilatant"]:
                    params_constr[n] = [{"type":"ineq", "fun":n-1-Fl.tol}] # n > 1
                else:
                    params_constr[n] = [{"type":"ineq", "fun":1-n-Fl.tol}] # n < 1

            elif type_non_newtonian(1): # Carreau model
                # mu_inf: infinite-shear-rate viscosity; mu_k: zero-shear-rate viscosity - mu_inf;
                # lambda_rt: relaxation time; n: power-law exponent (n < 1)
                mu_inf, mu_k, lambda_rt, n = sp.symbols("mu_inf, mu_k, lambda_rt, n")
                mu_app = mu_inf + mu_k*(1+(lambda_rt*gamma)**2)**((n-1)/2)
                params_constr[mu_inf] = [{"type":"ineq", "fun":mu_inf-Fl.tol}] # mu_inf > 0
                params_constr[mu_k] = [{"type":"ineq", "fun":mu_k+mu_inf-Fl.tol}] # mu_k + mu_inf > 0
                params_constr[lambda_rt] = [{"type":"ineq", "fun":lambda_rt-Fl.tol}] # lambda_rt > 0
                params_constr[n] = [{"type":"ineq", "fun":1-Fl.tol-n}]  # n < 1, i.e. 1 - tol - n >= 0

            elif type_non_newtonian(2): # new non-Newtonian (polynomial / Fourier type)
                mu_inf = sp.symbols("mu_inf")
                params_constr[mu_inf] = []
                mu_app = mu_inf

                if deci_dict["poly_order"] > 0: # polynomial in gamma
                    for i in range(1, deci_dict["poly_order"]+1):
                        coef = sp.symbols("poly_coef_{}".format(i))
                        params_constr[coef] = []
                        mu_app += coef * gamma**i

                if deci_dict["Fourier_order"] > 0: # Fourier basis in gamma: sin(i*gamma), cos(i*gamma)
                    for i in range(1, deci_dict["Fourier_order"]+1):
                        sin_coef = sp.symbols("Fourier(mu_app)_sin_coef_{}".format(i))
                        cos_coef = sp.symbols("Fourier(mu_app)_cos_coef_{}".format(i))
                        params_constr[sin_coef] = []
                        params_constr[cos_coef] = []
                        mu_app += sin_coef*sp.sin(i*gamma) + cos_coef*sp.cos(i*gamma)

            elif type_non_newtonian(3): # new non-Newtonian (STRidge type)
                # mu_app = mu_inf + sum coef * f(gamma), with f drawn from elementary functions
                # and their products; coefficients are subject to STRidge selection.
                # NOTE: Symbol('x') equals the spatial coordinate Fl.x; it is replaced by gamma
                # below, and the coefficient names are derived from str(f), so it is kept as is.
                x = sp.symbols('x')
                ele_func_list_0 = [sp.log(x), 1/x]
                ele_func_list_1 = [sp.sin(x), sp.cos(x), sp.log(x)]
                ele_func_list_2 = [1/x, ]

                ele_func_1_2_prod = [f[0]*f[1] for f in itertools.product(ele_func_list_1, ele_func_list_2)]
                lib = ele_func_list_0 + ele_func_1_2_prod

                # bias term
                mu_inf = sp.symbols("mu_inf")
                params_constr[mu_inf] = []
                mu_app = mu_inf
                # library terms
                for f in lib:
                    name = str(f)+"_STR_coef"
                    if "Deleted_STR_coef" in deci_dict and name in deci_dict["Deleted_STR_coef"]:
                        continue
                    # `is not None`: an empty dict passed in by the caller must still be filled
                    if stridge_terms is not None:  # record term for normalization in optimization
                        stridge_terms[name] = f.subs(x, gamma)

                    coef = sp.symbols(name)
                    params_constr[coef] = []
                    mu_app += coef * f

                mu_app = mu_app.subs(x, gamma)

            else:
                raise ValueError(f"Unsupported type_non_newtonian: {deci_dict['type_non_newtonian']}")

        # ---- dependence on temperature ----
        # ref https://en.wikipedia.org/wiki/Temperature_dependence_of_viscosity
        if not deci_dict["is_isothermal"]:
            type_mu_temperature = switch(deci_dict["type_mu_temperature"])

            if type_mu_temperature(0): # constant
                pass

            elif type_mu_temperature(1): # gas: power law
                s = sp.symbols("s")
                mu_app *= (Fl.T/Fl.T_ref)**s
                params_constr[s] = [{"type":"ineq", "fun":s-0.5-Fl.tol}] # s > 0.5

            elif type_mu_temperature(2): # gas: Sutherland
                S_suther = sp.symbols("S_suther")
                mu_app *= (Fl.T/Fl.T_ref)**sp.Rational(3,2) * (Fl.T_ref + S_suther)/(Fl.T + S_suther)
                params_constr[S_suther] = [{"type":"ineq", "fun":S_suther-Fl.tol}] # S > 0

            elif type_mu_temperature(3): # liquid: exponential model (Andrade equation), mu = A exp(B/T)
                B = sp.symbols("B")
                mu_app *= sp.exp(B* (1/Fl.T - 1/Fl.T_ref))
                params_constr[B] = [{"type":"ineq", "fun":B-Fl.tol}] # B > 0

            elif type_mu_temperature(4): # liquid: mu = mu_app + Poly(T)
                poly_order = 3
                for i in range(1, poly_order+1):
                    coef = sp.symbols("poly(mu_T)_coef_{}".format(i))
                    params_constr[coef] = [{"type":"ineq", "fun":sp.Abs(coef) - 0.005**i}] # |coef| > 0.005**i
                    mu_app += coef * Fl.T**i

        return sp_simplify_with_timeout(mu_app)
    
    @staticmethod
    def get_mass_res(deci_dict, params_constr):
        '''
        Mass-conservation residual (equation for the density).
        Compressible: conserve(rho) = d(rho)/dt + div(rho V).
        Incompressible: rho * div(V).
        params_constr is not modified (the equation has no free parameters).
        '''
        if not deci_dict["is_compressible"]:
            return Fl.rho * Fl.div(Fl.V)
        return Fl.conserve(Fl.rho)
    
    @staticmethod
    def get_incomp_res(deci_dict, params_constr):
        '''
        Incompressibility residual D(rho)/Dt for incompressible flow; 0 for compressible flow.
        params_constr is not modified.
        '''
        if not deci_dict["is_compressible"]:
            return Fl.DDt(Fl.rho)
        return 0
    
    @staticmethod
    def get_mmt_res(deci_dict, params_constr, stridge_terms=None):
        '''
        Momentum-equation residual (vector equation for the velocity).

        Conservation of momentum: conserve(rho*V) = div(Sigma) + F, with
        Sigma = -p I + tau (+ turbulent contributions) and F the body force.
        Using mass conservation, conserve(rho*V) = rho * DV/Dt, which is the LHS here.
        (rho * conserve(V) equals rho * DV/Dt + rho * V * div(V), so it would add a
        spurious term for compressible flow.) For creeping flow the convective
        term is dropped: LHS = rho * dV/dt.
        '''
        # ancestor variables
        tau = Fl.get_tau(deci_dict, params_constr, stridge_terms)
        F = Fl.get_F(deci_dict, params_constr)

        # LHS
        if (not deci_dict["is_turbulent"]) and deci_dict["type_non_turbulent"]==1: # creeping flow
            lhs = Fl.rho * (Fl.V.diff(Fl.t))
        else:
            lhs = Fl.rho * Fl.DDt(Fl.V)  # = conserve(rho*V) when conserve(rho) = 0

        # RHS
        Sigma = - Fl.p*Fl.I + tau
        if deci_dict["is_turbulent"]:  # turbulent viscosity, for every turbulence model
            mu_app = Fl.get_mu_app(deci_dict, params_constr)
            mu_T = Fl.get_mu_T(deci_dict, params_constr)
            # NOTE(review): divides by mu_app, which is 0 for the inviscid choice
            Sigma += (tau / mu_app) * mu_T

            type_turbulent = switch(deci_dict["type_turbulent"])
            if type_turbulent(0) or type_turbulent(1): # k-epsilon / realizable k-epsilon
                # exact 2/3 (as in get_P_k) instead of the float 0.666...
                Sigma += - sp.Rational(2, 3) * Fl.rho * Fl.turb1 * Fl.I

        rhs = Fl.div(Sigma) + F

        return sp_simplify_with_timeout(lhs - rhs)

    @staticmethod
    def get_F(deci_dict, params_constr):
        """
        Body force F = [fx, fy] (vector).

        type_body_force 0 -> no body force, F = [0, 0];
        type_body_force 1 -> constant gravity along y, F = [0, rho * gy], with the
            free parameter gy (|gy| >= 1, "init" -10) registered in params_constr.

        Raises:
            ValueError: for an unsupported type_body_force.
        """
        F = sp.Array([Fl.fx, Fl.fy])

        # decision
        type_body_force = switch(deci_dict["type_body_force"])

        if type_body_force(0): # zero force
            fx, fy = 0, 0
        elif type_body_force(1): # constant gravity force
            gy = sp.symbols("gy") # gravitational acceleration
            fx = 0
            fy = Fl.rho * gy
            # |gy| must be large enough; "init" is presumably the optimizer's starting value
            params_constr[gy] = [{"type":"ineq", "fun":abs(gy) - 1, "init":-10}]
        else:
            raise ValueError(f"Unsupported type_body_force: {deci_dict['type_body_force']}")

        F = F.subs({Fl.fx:fx, Fl.fy:fy})
        return sp_simplify_with_timeout(F)
    
    @staticmethod
    def get_int_eng_res(deci_dict, params_constr):
        """
        Internal-energy residual (equation for the temperature T), divided by
        T_ref to balance its magnitude against the other residuals.

        Isothermal: sqrt(grad(T).grad(T) + (dT/dt)**2) / T_ref, which vanishes
        only for a temperature that is constant in space and time.
        Otherwise: (rho * DT/Dt - rhs / Cp) / T_ref, where rhs sums the selected sources:
            is_thermal_conductive  -> div(k * grad(T))   (Fourier conduction, k >= tol)
            is_pressure_work       -> beta * T * Dp/Dt   (beta >= tol)
            is_viscosity_diffusion -> tau : grad(V), plus rho * turb2 for k-epsilon models
        Cp (heat capacity at constant pressure, Cp >= tol) is only introduced when rhs != 0.
        """
        if deci_dict["is_isothermal"]:
            dT_dx = ts_grad(Fl.T, Fl.X)
            dT_dt = Fl.T.diff(Fl.t)
            variation = sp.sqrt(Fl.dot(dT_dx, dT_dx) + dT_dt**2)
            return variation / Fl.T_ref

        lhs = Fl.rho * Fl.DDt(Fl.T)  # rho times the material derivative of T

        heat_sources = []
        if deci_dict["is_thermal_conductive"]:
            k = sp.symbols("k") # thermal conductivity
            heat_sources.append(Fl.div(k * Fl.grad(Fl.T)))
            params_constr[k] = [{"type":"ineq", "fun":k-Fl.tol}] # k>0

        if deci_dict["is_pressure_work"]:
            beta = sp.symbols("beta") # bulk expansion coefficient
            heat_sources.append(beta * Fl.T * Fl.DDt(Fl.p)) # beta*T*(dp/dt + V.grad(p))
            params_constr[beta] = [{"type":"ineq", "fun":beta-Fl.tol}] # beta>0

        if deci_dict["is_viscosity_diffusion"]:
            tau = Fl.get_tau(deci_dict, params_constr)
            Qvd = Fl.ddot(tau, Fl.grad(Fl.V))  # tau:grad(V)
            if deci_dict["is_turbulent"]:
                type_turbulent = switch(deci_dict["type_turbulent"])
                if type_turbulent(0) or type_turbulent(1):
                    Qvd += Fl.rho * Fl.turb2  # k-epsilon models: + rho*eps
            heat_sources.append(Qvd)

        rhs = sum(heat_sources)  # 0 when no source is selected
        if rhs != 0:
            Cp = sp.symbols("Cp") # heat capacity at constant pressure
            params_constr[Cp] = [{"type":"ineq", "fun":Cp-Fl.tol, "init":4200}] # Cp>0
            rhs = rhs / Cp

        return sp_simplify_with_timeout(lhs-rhs)/Fl.T_ref
    
    @staticmethod
    def get_turb1_res(deci_dict, params_constr):
        """
        Residual of the transport equation for turb1 (k, turbulent kinetic energy, in k-epsilon models).

        For type_turbulent 0/1 (standard / realizable k-epsilon):
            rho * Dk/Dt - [div((mu_app + mu_T/sigma_k) grad k) + P_k - rho * eps],
        where rho * Dk/Dt equals conserve(rho*k) once mass conservation holds.
        (rho * conserve(k) would add a spurious rho * k * div(V) term for compressible flow.)
        Returns 0 for laminar flow and for any other turbulence type.
        """
        if not deci_dict["is_turbulent"]:
            return 0

        type_turbulent = switch(deci_dict["type_turbulent"])
        if not (type_turbulent(0) or type_turbulent(1)):
            # this model has no turb1 equation: 0, consistent with get_mu_T / get_P_k
            return 0

        # LHS: conserve(rho*k) = rho * Dk/Dt (using conservation of mass)
        lhs = Fl.rho * Fl.DDt(Fl.turb1)

        # RHS variables
        mu_app = Fl.get_mu_app(deci_dict, params_constr)
        mu_T = Fl.get_mu_T(deci_dict, params_constr)
        P_k = Fl.get_P_k(deci_dict, params_constr)
        sigma_k = sp.symbols("sigma_k")
        params_constr[sigma_k] = [{"type":"ineq", "fun":sigma_k-Fl.tol}] # sigma_k>0

        rhs = Fl.div((mu_app + mu_T/sigma_k) * Fl.grad(Fl.turb1)) + P_k - Fl.rho*Fl.turb2

        return sp_simplify_with_timeout(lhs - rhs)
        
        
    @staticmethod
    def get_turb2_res(deci_dict, params_constr):
        """
        Residual of the transport equation for turb2 (epsilon, dissipation rate, in k-epsilon models).

        LHS: rho * D(eps)/Dt, which equals conserve(rho*eps) once mass conservation holds.
        type_turbulent 0 (standard k-epsilon):
            RHS = div((mu_app + mu_T/sigma_eps) grad eps) + C_eps1 * eps/k * P_k
                  - C_eps2 * rho * eps**2 / k
        type_turbulent 1 (realizable k-epsilon):
            RHS = div((mu_app + mu_T/sigma_eps) grad eps) + C_1 * rho * s * eps
                  - C_eps2 * rho * eps**2 / (k + sqrt(nu * eps)),
            with s = sqrt(2 S:S), ita = s*k/eps, C_1 = max(0.43, ita/(5+ita)), nu = mu_app/rho.
        Returns 0 for laminar flow and for any other turbulence type.
        """
        if not deci_dict["is_turbulent"]:
            return 0

        type_turbulent = switch(deci_dict["type_turbulent"])

        if type_turbulent(0): # standard k-epsilon model
            # LHS: conserve(rho*eps) = rho * D(eps)/Dt (using conservation of mass)
            lhs = Fl.rho * Fl.DDt(Fl.turb2)

            # RHS variables
            mu_app = Fl.get_mu_app(deci_dict, params_constr)
            mu_T = Fl.get_mu_T(deci_dict, params_constr)
            P_k = Fl.get_P_k(deci_dict, params_constr)
            sigma_eps, C_eps1, C_eps2 = sp.symbols("sigma_epsilon, C_epsilon1, C_epsilon2")
            params_constr[sigma_eps] = [{"type":"ineq", "fun":sigma_eps-Fl.tol}] # sigma_eps>0
            params_constr[C_eps1] = [{"type":"ineq", "fun":C_eps1-Fl.tol}] # C_eps1>0
            params_constr[C_eps2] = [{"type":"ineq", "fun":C_eps2-Fl.tol}] # C_eps2>0

            rhs = Fl.div((mu_app + mu_T/sigma_eps) * Fl.grad(Fl.turb2)) \
                 + C_eps1 * Fl.turb2 / Fl.turb1 * P_k \
                 - C_eps2 * Fl.rho * Fl.turb2**2 / Fl.turb1

        elif type_turbulent(1): # realizable k-epsilon
            # LHS: conserve(rho*eps) = rho * D(eps)/Dt (using conservation of mass)
            lhs = Fl.rho * Fl.DDt(Fl.turb2)

            # RHS coefficient C_1 = max(0.43, ita/(5+ita)), ita = s*k/eps, s = sqrt(2 S:S)
            s = sp.sqrt(2 * Fl.ddot(Fl.S, Fl.S))
            ita = s * Fl.turb1 / Fl.turb2
            C_1 = Fl.sp_maximum(0.43, ita/(5+ita))

            # RHS variables
            mu_app = Fl.get_mu_app(deci_dict, params_constr)
            mu_T = Fl.get_mu_T(deci_dict, params_constr)
            sigma_eps, C_eps2 = sp.symbols("sigma_epsilon, C_epsilon2")
            params_constr[sigma_eps] = [{"type":"ineq", "fun":sigma_eps-Fl.tol}] # sigma_eps>0
            params_constr[C_eps2] = [{"type":"ineq", "fun":C_eps2-Fl.tol}] # C_eps2>0
            nu = mu_app / Fl.rho

            rhs = Fl.div((mu_app + mu_T/sigma_eps) * Fl.grad(Fl.turb2)) \
                 + C_1 * Fl.rho * s * Fl.turb2 \
                 - C_eps2 * Fl.rho * Fl.turb2**2 / (Fl.turb1 + sp.sqrt(nu*Fl.turb2))

        else:
            # no turb2 equation for this model: 0, consistent with get_mu_T / get_P_k
            # (previously lhs/rhs were unbound here)
            return 0

        return sp_simplify_with_timeout(lhs - rhs)
            
    @staticmethod
    def get_mu_T(deci_dict, params_constr):
        """
        Eddy (turbulent) viscosity mu_T = rho * C_mu * turb1**2 / turb2.

        type_turbulent 0 (standard k-epsilon): C_mu is a free parameter (C_mu >= tol).
        type_turbulent 1 (realizable k-epsilon):
            C_mu = 1 / (A_0 + A_s * U_ast * k / eps), with
            A_s = sqrt(6) * cos(acos(sqrt(6) * W) / 3),
            W = 2*sqrt(2) * S:(S.S) / |S|**3,
            U_ast = sqrt(S:S + Omega:Omega), Omega = (grad(V) - grad(V)^T) / 2,
            and A_0 a free parameter (A_0 >= tol).
            Ref: https://doc.comsol.com/6.1/doc/com.comsol.help.cfd/cfd_ug_fluidflow_single.06.092.html
        Returns 0 for laminar flow and for any other turbulence type.
        """
        if not deci_dict["is_turbulent"]:
            return 0

        type_turbulent = switch(deci_dict["type_turbulent"])
        if type_turbulent(0): # standard k-epsilon: constant C_mu
            C_mu = sp.symbols("C_mu")
            params_constr[C_mu] = [{"type":"ineq", "fun":C_mu-Fl.tol}] # C_mu>0
        elif type_turbulent(1): # realizable k-epsilon: C_mu depends on the flow
            Omega = sp.Rational(1,2) * (Fl.grad(Fl.V) - ts_trans(Fl.grad(Fl.V)))
            U_ast = sp.sqrt(Fl.ddot(Fl.S, Fl.S) + Fl.ddot(Omega, Omega))
            # NOTE(review): the original author questioned the 2*sqrt(2) factor in W
            W = 2* sp.sqrt(2) * Fl.ddot(Fl.S, Fl.dot(Fl.S, Fl.S)) / Fl.norm(Fl.S)**3
            A_s = sp.sqrt(6) * sp.cos(sp.Rational(1,3) * sp.acos(sp.sqrt(6)*W))
            A_0 = sp.symbols("A_0")
            params_constr[A_0] = [{"type":"ineq", "fun":A_0-Fl.tol}] # A_0>0
            C_mu = 1/(A_0 + A_s*U_ast*Fl.turb1/Fl.turb2)
        else:
            return 0

        mu_T = Fl.rho * C_mu * Fl.turb1**2 / Fl.turb2
        return sp_simplify_with_timeout(mu_T)

    @staticmethod
    def get_P_k(deci_dict, params_constr):
        """
        Production term of turbulent kinetic energy for the (standard / realizable) k-epsilon models:
            P_k = mu_T * (grad(V) : 2S - 2/3 * div(V)**2) - 2/3 * rho * turb1 * div(V).
        Returns 0 for laminar flow and for any other turbulence type.
        """
        if not deci_dict["is_turbulent"]:
            return 0

        type_turbulent = switch(deci_dict["type_turbulent"])
        if not (type_turbulent(0) or type_turbulent(1)):
            return 0

        mu_T = Fl.get_mu_T(deci_dict, params_constr)
        two_thirds = sp.Rational(2, 3)
        div_V = Fl.div(Fl.V)
        shear_production = mu_T * (Fl.ddot(Fl.grad(Fl.V), 2*Fl.S) - two_thirds*div_V**2)
        P_k = shear_production - two_thirds * Fl.rho * Fl.turb1 * div_V
        return sp_simplify_with_timeout(P_k)
        
    @staticmethod
    def gen_np_func(params_constr, sp_res_func_list, verbose=False):
        '''
        Compile sympy residuals and parameter constraints into numpy-callable functions.

        Each compiled residual has the signature f(args, params):
          - `args` is the flat list of field values and derivatives, in this order
            (each tensor flattened row-major):
            V, grad V, grad grad V, dV/dt, grad p, dp/dt, rho, grad rho, drho/dt,
            T, grad T, grad grad T, dT/dt, [turb1, turb2] and their grad, grad grad, d/dt;
          - `params` follows the key order of params_constr.

        Input:
            params_constr: dict {param symbol: [constraint dicts with "type" and "fun"]}.
            sp_res_func_list: list of sympy residuals (scalars or arrays).
            verbose: unused.
        Output:
            res_func_list: flat list of compiled residuals, one per array entry.
            constr_dict_list: list of {"type", "fun", "name"} dicts; "fun" takes params.
            res_idx_list: for each entry of sp_res_func_list, the indices of its
                compiled functions in res_func_list.
        '''
        # argument groups (order matters: it defines the layout of `args`)
        velocity = [Fl.V] + Fl.ts_grad_all(Fl.V)
        grad_p, _, dt_p = Fl.ts_grad_all(Fl.p)
        grad_rho, _, dt_rho = Fl.ts_grad_all(Fl.rho)
        pressure_density = [grad_p, dt_p, Fl.rho, grad_rho, dt_rho]
        temperature = [Fl.T] + Fl.ts_grad_all(Fl.T)
        Turb = sp.Array([Fl.turb1, Fl.turb2])
        turbulence = [Turb] + Fl.ts_grad_all(Turb)

        # NOTE: d^2f/dydx equals d^2f/dxdy by default, so mixed second derivatives are redundant.
        args = [leaf
                for group in (velocity, pressure_density, temperature, turbulence)
                for tensor in group
                for leaf in Fl.ts_1d_list(tensor)]

        params = list(params_constr)

        def to_np_func(sp_func):
            # compile one sympy expression via sympy.lambdify
            return sp.lambdify([args, params], sp_func)

        res_func_list = []
        res_idx_list = []
        offset = 0
        for sp_res in sp_res_func_list:
            if hasattr(sp_res, "shape"): # array residual: one function per entry
                n_entries = len(sp_res)
                res_func_list.extend([to_np_func(entry) for entry in sp_res])
                res_idx_list.append(list(range(offset, offset + n_entries)))
            else: # scalar residual
                n_entries = 1
                res_func_list.append(to_np_func(sp_res))
                res_idx_list.append([offset])
            offset += n_entries

        constr_dict_list = [
            {"type": cd["type"], "fun": sp.lambdify([params], cd["fun"]), "name": str(pa)}
            for pa, constr_dicts in params_constr.items()
            for cd in constr_dicts
        ]

        return res_func_list, constr_dict_list, res_idx_list

    @staticmethod
    def load_np_data(dataname_tuple, verbose=False):
        """
        Load one dataset as the flat list of numeric fields used by the residuals.

        dataname_tuple = (problem, datasource), forwarded to get_raw_data().
        verbose is accepted for API symmetry but not used here.

        Returns (args, gt_params):
            args: list of np.ndarray data tensors (nx, ny, nt), each with n_clip=5
                points removed from both ends of its first three axes. The order MUST
                stay in sync with the symbolic args built in Fl.gen_np_func().
            gt_params: list, currently always empty.
        """
        U, grids = get_raw_data(*dataname_tuple)

        u, v, p, rho = (U[..., ch] for ch in range(4))
        # a 5th channel, if present, is temperature; otherwise use constant Fl.T_ref
        T = U[..., 4] if U.shape[-1] >= 5 else np.ones_like(p) * Fl.T_ref

        # velocity: [V, grad_V, grad_grad_V, dt_V], len = 2+4+8+2 = 16
        V = [u, v]
        vel_group = [V] + np_grad_all(V, grids)

        # pressure & density: len = 2+1+1+2+1 = 7
        grad_p, _, dt_p = np_grad_all(p, grids)
        grad_rho, _, dt_rho = np_grad_all(rho, grids)
        p_rho_group = [grad_p, dt_p, [rho], grad_rho, dt_rho]

        # temperature: [T, grad_T, grad_grad_T, dt_T], len = 1+2+4+1 = 8
        temp_group = [[T]] + np_grad_all(T, grids)

        # turbulence: dummy 1s for non-turbulent data (zeros lead to DivZeroErr)
        # [Turb, grad_Turb, grad_grad_Turb, dt_Turb], len = 2+4+8+2 = 16
        Turb = [np.ones_like(u), np.ones_like(u)]
        turb_group = [Turb] + np_grad_all(Turb, grids)

        n_clip = 5
        interior = (slice(n_clip, -n_clip),) * 3  # clip boundary on the first three axes
        args = [
            field[interior]
            for group in (vel_group, p_rho_group, temp_group, turb_group)
            for sub_list in group
            for field in sub_list
        ]
        return args, []

    
    @staticmethod
    def test(deci_dict, dataname_tuple=('2d_comp_viscose_newton_ns',"COMSOL"), datafold_tuple=(0,1), verbose=False, 
             prev_sol_best={"fun":1e-3, "nit":5}, init_params=None, STR_iter_max=4, do_optimize=True):
        """Run test of a deci_dict: build residuals, optimize params, report losses.

        Input:
            deci_dict: dict(deci, val), decision tree dict.
            dataname_tuple: str tuple, (problem, datasource).
            datafold_tuple: int tuple, (k-th fold, num of total folds). Folds are drawn
                over the time axis (last axis of every data tensor). With num of total
                folds <= 1, all data is used for both train and valid.
            verbose: bool.
            prev_sol_best: dict. Must have keys ["fun", "nit"]; forwarded to
                optimize_with_timeout and used for early stop.
                NOTE(review): mutable default argument. Safe only while
                optimize_with_timeout does not mutate it; confirm.
            init_params: np.array. Initial guess of parameters (random if None).
            STR_iter_max: int, recursion depth limit of STRidge; only used by the
                STRidge solver.
            do_optimize: bool. If False, only evaluate the loss at init_params (debug).
        Output:
            ret_sol: dict with keys 'train_loss', 'valid_loss', 'deci_dict', 'params',
                'time', 'nit', 'status' and, when status == "Success", 'losses'.
        """
        #---load data---
        args, _ = Fl.load_np_data(dataname_tuple, verbose)

        #---split args into train and valid, if k>1---
        if datafold_tuple[1] > 1:
            k_th_fold, tot_folds = datafold_tuple
            kf = KFold(n_splits=tot_folds, shuffle=True, random_state=0)
            # Data tensors are (nx, ny, nt) (see Fl.load_np_data), so time is the LAST
            # axis. Folding over it keeps x and y contiguous for the spatial pooling
            # applied in the loss below.
            nt = args[0].shape[-1]
            train_idx, valid_idx = list(kf.split(np.arange(nt)))[k_th_fold]
            train_args = [arg[..., train_idx] for arg in args]
            valid_args = [arg[..., valid_idx] for arg in args]
        else:
            train_args, valid_args = args, args

        #--- define residuals---
        # Each getter may register free parameters (with constraints) in params_constr.
        params_constr = dict()
        stridge_terms = dict()
        mass_res = Fl.get_mass_res(deci_dict, params_constr)
        mmt_res = Fl.get_mmt_res(deci_dict, params_constr, stridge_terms)
        int_eng_res = Fl.get_int_eng_res(deci_dict, params_constr)
        incomp_res = Fl.get_incomp_res(deci_dict, params_constr)
        turb1_res = Fl.get_turb1_res(deci_dict, params_constr)
        turb2_res = Fl.get_turb2_res(deci_dict, params_constr)

        sp_res_func_list = [mass_res, mmt_res, int_eng_res, incomp_res, turb1_res, turb2_res]
        res_name_list = ["mass_res", "mmt_res", "int_eng_res", "incomp_res", "turb1_res", "turb2_res"]
        # symbolic size of the model, penalized through reg_coefs below
        tot_count_ops = sum(r.count_ops() for r in sp_res_func_list if hasattr(r, "count_ops"))

        #convert sp func to np func.
        res_func_list, constr_dict_list, res_idx_list = Fl.gen_np_func(params_constr, sp_res_func_list, verbose=verbose)

        #---loss funcs---
        pool_size = 5
        # scaling coefficient of regularization: 1 for compNS, New, Heat_v2; 10 for Heat.
        reg_scale = 10 if dataname_tuple[0] == '2d_heat_comp' else 1
        reg_coefs = reg_scale * np.array([1e-5, 1e-5, 1e-7])  # weights of reg_list(params)
        reg_name = "[len(deci_dict), len(params), tot_count_ops]"

        def mse_list(data_args, params):
            # pooling is applied to the first two axes of each residual, i.e. x and y
            return [(pooling(res(data_args, params), (pool_size, pool_size))**2).mean()
                    for res in res_func_list]

        def mse_func(data_args, params):
            return sum(mse_list(data_args, params))

        def reg_list(params):
            return [len(deci_dict), len(params), tot_count_ops]

        def reg_func(params):
            return reg_coefs.dot(np.array(reg_list(params)))

        def train_loss_func(params):
            return mse_func(train_args, params) + reg_func(params)

        def valid_loss_func(params):
            return mse_func(valid_args, params) + reg_func(params)

        # --- optimization---
        #initial guess: random in [0, 1), shifted by any "init" offset given in the constraints
        if len(params_constr) > 0:
            if init_params is None:
                init_params = np.random.rand(len(params_constr))
                for i, constr_list in enumerate(params_constr.values()):
                    for c in constr_list:
                        if "init" in c:
                            init_params[i] += c["init"]
        else:
            init_params = []

        is_STR_coefs = np.array([str(p).endswith("_STR_coef") for p in params_constr.keys()]) #boolean array
        if not do_optimize:
            sol = {"x":init_params, "fun":train_loss_func(init_params), "nit":1, "time":1, 'status':"Success" }
        elif is_STR_coefs.sum()==0 or STR_iter_max<=0: #ordinary solver
            sol = optimize_with_timeout(train_loss_func, init_params, constr_dict_list, dataname_tuple[0], prev_sol_best, verbose)
        else: # STRidge solver
            params_name = list(map(str, params_constr.keys()))
            #--- hyper-params---
            l2_reg_coef = 1e-2
            tol_w = 0.005

            def pre_process(params):
                # bound STRidge coefs to [-0.05, +0.05]; other params are left untouched
                params = params.copy()
                params[is_STR_coefs] = 0.05*np.tanh(params[is_STR_coefs])
                return params

            def STR_loss_func(params):
                return mse_func(train_args, pre_process(params)) + l2_reg_coef*(params[is_STR_coefs]**2).sum()

            if verbose:
                print(f"Start stridge_loop with {STR_iter_max=}", flush=1)

            sol = optimize_with_timeout(STR_loss_func, init_params, constr_dict_list, dataname_tuple[0], prev_sol_best, verbose)
            is_small_p = (pre_process(np.abs(sol['x'])) < tol_w) * is_STR_coefs

            if verbose:
                print(f"{tol_w=}, {is_small_p=}")
                pprint.pprint(sol)
                pprint.pprint(dict(zip(params_name, pre_process(sol['x']))))

            # ---threshold small p: drop them and record their names for the next round
            next_init_params = sol['x'].copy()[~is_small_p]
            next_deci_dict = copy.deepcopy(deci_dict) # do not modify input deci_dict.
            for name, is_small in zip(params_name, is_small_p):
                if is_small:
                    next_deci_dict.setdefault("Deleted_STR_coef", []).append(name)

            if verbose:
                print("next_deci_dict=")
                pprint.pprint(next_deci_dict)
                print("next_init_params=", next_init_params)
            # --- recursive call
            if STR_iter_max > 0 and is_small_p.sum() > 0:
                return Fl.test(next_deci_dict, dataname_tuple, datafold_tuple, verbose, prev_sol_best,
                               next_init_params, STR_iter_max-1)
            else: # convert back to normal loss
                sol['x'] = pre_process(sol['x'])
                sol['fun'] = train_loss_func(sol['x'])
                if verbose:
                    print("train_loss=",sol['fun'])
                    print("STR_coefs_name=",np.array(params_name)[is_STR_coefs].tolist())
                    print("STR_coefs(after pre_process)=",sol['x'][is_STR_coefs].tolist())

        # ---return infos---
        ret_sol = dict()
        ret_sol['train_loss'] = sol['fun']
        ret_sol['valid_loss'] = valid_loss_func(sol['x'])
        ret_sol['deci_dict'] = deci_dict
        params_name = list(map(str, params_constr.keys())) #must return str (instead of sympy symbol), for pprint.
        ret_sol['params'] = dict(zip(params_name, sol['x']))
        ret_sol['time'] = sol['time']
        ret_sol['nit'] = sol['nit']
        ret_sol['status'] = sol['status']

        if ret_sol['status'] == "Success": # record detailed infos of train loss.
            losses = dict()
            mse_arr = np.array(mse_list(train_args, sol['x']))
            res_list = [mse_arr[idx].tolist() for idx in res_idx_list]
            losses.update(zip(res_name_list, res_list))  #dict{res: [res_values]}
            losses["tot_res"] = sum(list_cat(res_list))
            losses["reg_list"] = reg_list(sol['x'])
            losses["reg_name"] = reg_name
            losses["reg_coefs"] = reg_coefs
            ret_sol["losses"] = losses

        return ret_sol

    # DBNode denotes the content of a decision branch.
    # DBNode.name = the name of the branch, DBNode.children = the names of the child
    # decisions that must also be decided when this branch is chosen (Rule 1.3 of is_valid_deci).
    deci_info_dict = {
        "is_fluid": [DBNode('fluid')], #a place holder choice
        # --- turbulent flow ---
        "is_turbulent": [DBNode("non_turbulent", ["type_non_turbulent"]), 
                        DBNode("turbulent", ["type_turbulent"])],
        "type_non_turbulent":[DBNode("laminar"), DBNode("creeping")],
        "type_turbulent":[DBNode("k-eps"), DBNode("realizable-k-eps")],

        # --- constitutive equation---
        "is_newtonian": [DBNode("non_newtonian", ["type_non_newtonian"]), 
                        DBNode("newtonian", ["type_newtonian"])],
        "type_non_newtonian": [DBNode("powerlaw(mu_app)",["is_dilatant"]), 
                               DBNode("carreau"),
                               DBNode("new_non_newtonian_1", ["poly_order", "Fourier_order"]),
                               DBNode("new_non_newtonian_2"),],
        "poly_order":[DBNode("poly_order_0"), DBNode("poly_order_1"), DBNode("poly_order_2")],
        "Fourier_order":[DBNode("Fourier_order_0"), DBNode("Fourier_order_1"),DBNode("Fourier_order_2")],
        "is_dilatant": [DBNode("pseudoplastic"), DBNode("dilatant")],
        "type_newtonian": [DBNode("inviscid"), DBNode("newtonian")],
        

        # --- non_isothermal flow ---
        "is_isothermal": [
            DBNode("nonisothermal", ["type_mu_temperature","is_thermal_conductive","is_pressure_work","is_viscosity_diffusion"]),
            DBNode("isothermal")], 
        "type_mu_temperature":[DBNode("mu_temperature_independent"), DBNode("powerlaw(mu_T)"), DBNode("Sutherland"), DBNode("Andrade")],
        "is_thermal_conductive":[DBNode("not_thermal_conductive"), DBNode("Fourier")],
        "is_pressure_work":[DBNode("no_pressure_work"), DBNode("pressure_work")],
        "is_viscosity_diffusion":[DBNode("no_viscosity_diffusion"), DBNode("viscosity_diffusion")],
        #"type_heat_source":[radiation...]


        # --- other independent decisions ---
        "is_compressible":[DBNode("incompressible"), DBNode("compressible")],
        "type_body_force":[DBNode("no_body_force"), DBNode("gravity"), ],
        }
    
    # binary contradictions. dict(Name: NameSet) means Name contradicts every name in NameSet.
    # Names must match DBNode names in deci_info_dict exactly, or the rule never fires.
    node_contra_dict = {
        "inviscid": {"viscosity_diffusion", "turbulent", "creeping"},
        "isothermal":{"powerlaw(mu_T)","Sutherland","Andrade"},
        "NA":{"non_newtonian", "newtonian"},  #a place holder
        "poly_order_0":{"Fourier_order_0"}
    }
    
    # deci name -> list of admissible integer values (indices into deci_info_dict[deci])
    deci_range_dict = {name: list(range(len(decis))) for name, decis in deci_info_dict.items()}
    all_deci_set = OrderedSet(deci_info_dict.keys())
    # decisions that appear as a child of some branch
    dep_deci_set = OrderedSet(list_cat([n.children for n in list_cat(deci_info_dict.values())]))
    indep_deci_set = all_deci_set - dep_deci_set # independent nodes = all nodes except dependent nodes.
    parent_set = OrderedSet()
    #child_parent_dict = None 
    
    @staticmethod
    def init_nodes(dataset_name):
        """Prune modules the dataset does not need by editing Fl.node_contra_dict in place.

        The names are added to the contradiction set of 'fluid' (the only option of
        the placeholder decision "is_fluid"), so any deci_dict choosing them fails
        Fl.is_phys_valid_deci:
          - "new_non_newtonian_2" (eqn_discovery module) unless '_new_' is in dataset_name;
          - "turbulent" (turbulence module) unless '_turb_' is in dataset_name.
        The edit is cumulative on class state: names are only ever added, never removed.
        """
        pruned_nodes = []
        if '_new_' not in dataset_name:
            pruned_nodes.append("new_non_newtonian_2")
        if '_turb_' not in dataset_name:
            pruned_nodes.append("turbulent")
        for node_name in pruned_nodes:
            Fl.node_contra_dict.setdefault('fluid', set()).add(node_name)

    @staticmethod
    def is_phys_valid_deci(deci_dict):
        """Check the chosen branches against the binary contradiction rules.

        Every (deci, val) pair is mapped to its DBNode name via Fl.deci_info_dict.
        Returns False if some key of Fl.node_contra_dict is chosen together with any
        name in its contradiction set, True otherwise.
        """
        chosen_names = {Fl.deci_info_dict[deci][val].name for deci, val in deci_dict.items()}
        return not any(
            name in chosen_names and not chosen_names.isdisjoint(conflicts)
            for name, conflicts in Fl.node_contra_dict.items()
        )
    
    @staticmethod
    def is_valid_deci(deci_dict):
        '''
        Check whether a deci_dict is valid.
        A deci_dict is valid iff it satisfies both the logic rules (1.x) and the
        physics rules (Fl.is_phys_valid_deci). Returns a bool.
        Side effect: the first call that reaches Rule 1.4 caches a child -> (parent, val)
        map as Fl.child_parent_dict.
        '''
        # ----Rules Type 1: Logic----
        # Rule 1.1: every deci name is known and every value is in its range.
        all_in_range = all(
            deci in Fl.deci_range_dict and deci_dict[deci] in Fl.deci_range_dict[deci]
            for deci in deci_dict
        )
        if not all_in_range:
            return False

        # Rule 1.2: every independent deci is present.
        if any(deci not in deci_dict for deci in Fl.indep_deci_set):
            return False

        # Rule 1.3: every child required by a chosen branch is present.
        for deci, val in deci_dict.items():
            if any(child not in deci_dict for child in Fl.deci_info_dict[deci][val].children):
                return False

        # Rule 1.4: a child deci is only allowed when its parent picked the branch owning it.
        if not hasattr(Fl, 'child_parent_dict'):
            Fl.child_parent_dict = {
                child: (parent, val)
                for parent, range_ in Fl.deci_range_dict.items()
                for val in range_
                for child in Fl.deci_info_dict[parent][val].children
            }
        for deci in deci_dict:
            if deci not in Fl.child_parent_dict:
                continue
            parent, val = Fl.child_parent_dict[deci]
            if deci_dict.get(parent) != val:
                return False

        # ----Rules Type 2: Physics----
        return Fl.is_phys_valid_deci(deci_dict)
