import warp as wp
from warp_utils import *
import numpy as np
import math


@wp.kernel
def set_leaf_and_branch(
    model: MPMModelStruct,
    state: MPMStateStruct,
    leaf_density: float,
    branch_density: float,
):
    # Per-particle kernel: overwrite the density of particles by material id.
    #   material id 8 -> leaf_density
    #   material id 9 -> branch_density
    # Particles with any other material id are left untouched.
    tid = wp.tid()
    mat_id = model.material[tid]
    if mat_id == 9:
        state.particle_density[tid] = branch_density
    elif mat_id == 8:
        state.particle_density[tid] = leaf_density

# # Determine whether a particle is a leaf particle or a branch particle
# @wp.kernel
# def compute_particle_type(
#         state: MPMStateStruct,
#         model: MPMModelStruct,
# ):
#     p = wp.tid()
#     # above 1.5 are the water particles; below 0.3 is the bottom of the potted plant
#     if state.particle_x[p] <= 1.5 and state.particle_x[p] >= 0.3 and state.particle_selection[p] == 0:
#         if is_green(state.particle_v[p]):
#             model.material[p] = 10
#             state.particle_selection[p] = 1
#         else:
#             model.material[p] = 11
#             state.particle_selection[p] = 1


@wp.kernel
def update_particle_selection(state: MPMStateStruct):
    # Per-particle kernel: decrement every non-zero particle_selection entry
    # by one. Entries that are already zero are left unchanged.
    tid = wp.tid()
    current = state.particle_selection[tid]
    if current != 0:
        state.particle_selection[tid] = current - 1

# compute the deviatoric (mean-free) part of a 3-vector
@wp.func
def deviatoric(input: wp.vec3):
    # The simulation is assumed to be three-dimensional (dim = 3), so the
    # mean is taken over exactly three components and subtracted from each.
    mean = (input[0] + input[1] + input[2]) / 3.0
    return wp.vec3(input[0] - mean, input[1] - mean, input[2] - mean)

# compute stress from F
@wp.func
def kirchoff_stress_FCR(
    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float
):
    # Kirchhoff stress for the FCR model (remember tau = P F^T).
    #
    # F       deformation gradient
    # U, V    presumably the SVD rotation factors of F (confirm at the call
    #         site); only their product U * V^T is used here
    # J       scalar used in the lam * J * (J - 1) volumetric term
    # mu, lam material parameters
    #
    # Returns the 3x3 stress matrix.
    identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    rotation = U * wp.transpose(V)
    shear_part = 2.0 * mu * (F - rotation) * wp.transpose(F)
    volume_part = identity * lam * J * (J - 1.0)
    return shear_part + volume_part


# neohookean variant
@wp.func
def kirchoff_stress_neoHookean_borden(p: int, F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float):
    # Neo-Hookean variant that splits the stress into a deviatoric part built
    # from B = F F^T and a volumetric part driven by J.
    #
    # p, U and V are accepted for signature compatibility but are not read.
    # Returns the 3x3 stress matrix tau = tau_dev + tau_vol.
    identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    dim = 3.0
    bulk = 0.666666 * mu + lam

    # Deviatoric part: mu * J^(-2/dim) * dev(B).
    left_cg = F * wp.transpose(F)
    trace_b = left_cg[0, 0] + left_cg[1, 1] + left_cg[2, 2]
    dev_b = left_cg - (1.0 / dim * trace_b) * identity
    tau_dev = mu * wp.pow(J, -2.0 / dim) * dev_b

    # Volumetric part: J * (bulk / 2) * (J - 1 / J) on the diagonal.
    vol_prime = bulk / 2.0 * (J - 1.0 / J)
    tau_vol = J * vol_prime * identity

    return tau_dev + tau_vol


@wp.func
def kirchoff_stress_neoHookean(
    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float
):
    # Neo-Hookean stress assembled from per-axis values computed on sig.
    #
    # sig  three scalar values that are squared below; presumably the
    #      singular values of F matching U and V (confirm at the call site)
    #
    # Returns U * diag(tau) * V^T * F^T, where
    #   tau = mu * J^(-2/3) * dev(sig^2) + lam / 2 * (J^2 - 1) * (1, 1, 1).
    sig_sq = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])
    mean_sq = (sig_sq[0] + sig_sq[1] + sig_sq[2]) / 3.0
    dev_sq = sig_sq - wp.vec3(mean_sq, mean_sq, mean_sq)

    shear_scale = mu * wp.pow(J, -2.0 / 3.0)
    bulk_scale = lam / 2.0 * (J * J - 1.0)
    tau = shear_scale * dev_sq + bulk_scale * wp.vec3(1.0, 1.0, 1.0)

    return U * wp.diag(tau) * wp.transpose(V) * wp.transpose(F)

@wp.func
def kirchoff_stress_water(
    J: float, model: MPMModelStruct, p: int
):
    # Isotropic stress for a weakly compressible fluid particle p.
    #
    # The stiffness is derived from the particle's E and nu as
    # E / (3 * (1 - 2 * nu)); the pressure follows -stiffness * (J^-gamma - 1).
    # Returns J * pressure on the diagonal of a 3x3 matrix.
    stiffness = model.E[p] / (3.0 * (1.0 - 2.0 * model.nu[p]))
    # gamma is set slightly greater than 1 for weakly compressible fluids
    gamma = 1.1
    pressure = -stiffness * (wp.pow(J, -gamma) - 1.0)
    identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    return J * (identity * pressure)

@wp.func
def kirchoff_stress_StVK_water(J: float, model: MPMModelStruct, p: int):
    # Isotropic stress E[p] * (J - 1) placed on the diagonal of a 3x3 matrix.
    identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    magnitude = model.E[p] * (J - 1.0)
    return magnitude * identity

@wp.func
def kirchoff_stress_StVK(
    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float
):
    # Stress built from the logarithms of sig:
    #   tau = 2 * mu * log(sig) + lam * sum(log(sig)) * (1, 1, 1)
    # Returns U * diag(tau) * V^T * F^T.
    #
    # sig is presumably the singular-value vector matching U and V (confirm
    # at the call site).

    # Clamp from below to prevent NaN in extreme cases.
    clamped = wp.vec3(
        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)
    )
    log_sig = wp.vec3(wp.log(clamped[0]), wp.log(clamped[1]), wp.log(clamped[2]))
    trace_log = log_sig[0] + log_sig[1] + log_sig[2]
    tau = 2.0 * mu * log_sig + lam * trace_log * wp.vec3(1.0, 1.0, 1.0)
    return U * wp.diag(tau) * wp.transpose(V) * wp.transpose(F)


@wp.func
def kirchoff_stress_drucker_prager(
    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float
):
    # Stress with per-axis diagonal entries
    #   (2 * mu * log(sig_i) + lam * sum(log(sig))) / sig_i
    # Returns U * diag(entries) * V^T * F^T.
    #
    # NOTE: sig is used unclamped here, so its components must be positive
    # for the logarithms and reciprocals to be finite.
    log0 = wp.log(sig[0])
    log1 = wp.log(sig[1])
    log2 = wp.log(sig[2])
    trace_log = log0 + log1 + log2

    inv0 = 1.0 / sig[0]
    inv1 = 1.0 / sig[1]
    inv2 = 1.0 / sig[2]

    d0 = 2.0 * mu * log0 * inv0 + lam * trace_log * inv0
    d1 = 2.0 * mu * log1 * inv1 + lam * trace_log * inv1
    d2 = 2.0 * mu * log2 * inv2 + lam * trace_log * inv2

    return U * wp.diag(wp.vec3(d0, d1, d2)) * wp.transpose(V) * wp.transpose(F)


# NACC plasticity model (3D-only version)
@wp.func
def NACC(
        F_trial: wp.mat33, model: MPMModelStruct, p: int
):
    # Return mapping for particle p, carried out on the singular values of
    # the trial deformation gradient F_trial.
    #
    # Returns a wp.vec3 of singular values:
    #   * the isotropic projected values in case 1 and case 2,
    #   * the projected values in case 3,
    #   * the raw (unclamped) SVD values sig_old otherwise.
    # Side effect: in the plastic branches model.logJp[p] is updated in place.

    # Read the per-particle parameters.
    E = model.E[p]
    nu = model.nu[p]
    mu = model.mu[p]
    kappa = E / (3.0 * (1.0 - 2.0 * nu))
    logJp = model.logJp[p]
    kesai = model.kesai[p]
    beta = model.beta[p]

    dim = wp.float(3)
    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    sig_old = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_old, V)

    sig = wp.vec3(
        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)
    )  # clamp from below to prevent NaN in extreme cases

    # Compute p0 from the hardening state alpha (= logJp).
    alpha = logJp
    p0 = kappa * (0.00001 + wp.sinh(kesai * max(-alpha, 0.0)))

    # Shorthand coefficients reused by the yield function and the projections.
    a = ((1.0 + 2.0 * beta) * (6.0 - dim)) / 2.0
    b = beta * p0
    c = p0

    # Compute p_trial from the product of the clamped singular values.
    Je_tr = sig[0] * sig[1] * sig[2]
    Uprime = kappa / 2.0 * (Je_tr - 1.0 / Je_tr)
    p_trial = -Uprime * Je_tr

    # case 1: p_trial is above p0 -> return isotropic singular values
    if p_trial > p0:
        Je_new = wp.sqrt(-2.0 * c / kappa + 1.0)
        sig_new = wp.pow(Je_new, 1.0 / dim) * wp.vec3(1.0, 1.0, 1.0)
        # update the hardening tracking parameter alpha
        model.logJp[p] = model.logJp[p] + wp.log(Je_tr / Je_new)
        return sig_new
    # case 2: p_trial is below -beta * p0 -> return isotropic singular values
    if p_trial < -b:
        Je_new = wp.sqrt(2.0 * b / kappa + 1.0)
        sig_new = wp.pow(Je_new, 1.0 / dim) * wp.vec3(1.0, 1.0, 1.0)
        # update the hardening tracking parameter alpha
        model.logJp[p] = model.logJp[p] + wp.log(Je_tr / Je_new)
        return sig_new

    # Squared singular values and the scaled deviatoric part derived from them.
    b_hat_trial = wp.vec3(0.0, 0.0, 0.0)
    for i in range(wp.int(dim)):
        b_hat_trial[i] = sig[i] * sig[i]
    sig2 = b_hat_trial
    s_hat_trial = mu * wp.pow(Je_tr, -2.0 / dim) * deviatoric(b_hat_trial)

    # compute M in the NACC formula
    friction_angle = wp.float(45)   # friction angle defaults to 45 degrees here, same as the setting in cd mpm
    sin_phi = wp.sin(friction_angle / 180.0 * 3.141592653)
    mohr_columb_friction = wp.sqrt((2.0/3.0) * 2.0 * sin_phi / (3.0 - sin_phi))
    M = mohr_columb_friction * dim / wp.sqrt(2.0 / (6.0 - dim))
    M2 = M * M

    # compute the yield surface value y
    s_hat_trial_norm = wp.sqrt(wp.pow(s_hat_trial[0], 2.0) + wp.pow(s_hat_trial[1], 2.0) + wp.pow(s_hat_trial[2], 2.0))
    y = a * wp.pow(s_hat_trial_norm, 2.0) + M2 * (p_trial + b) * (p_trial - p0)

    # y below the tolerance: accept the trial state and return the raw SVD values.
    if y < 0.0001:
        sig_new = sig_old
        return sig_new
    # case 3: y above the tolerance with p_trial strictly inside (-beta * p0, p0)
    pc = (1.0 - beta) * p0 / 2.0
    if y > 0.0001 and p_trial < p0 - 0.0001 and p_trial > -beta * p0 + 0.0001:
        # Solve the quadratic dd * x^2 + ff * x + gg = 0 for the two candidate
        # pressures p1 and p2.
        aa = M2 * wp.pow((p_trial - pc), 2.0) / (a * wp.pow(s_hat_trial_norm, 2.0))
        dd = 1.0 + aa
        ff = aa * beta * p0 - aa * p0 - 2.0 * pc
        gg = wp.pow(pc, 2.0) - aa * beta * wp.pow(p0, 2.0)
        zz = wp.sqrt(wp.abs(wp.pow(ff, 2.0) - 4.0 * dd * gg))
        p1 = (-ff + zz) / (2.0 * dd)
        p2 = (-ff - zz) / (2.0 * dd)
        # Pick the root that lies on the same side of pc as p_trial.
        if (p_trial - pc) * (p1 - pc) > 0.0:
            p_fake = p1
        else:
            p_fake = p2

        Je_new_fake = wp.sqrt(wp.abs(-2.0 * p_fake / kappa + 1.0))
        # update the hardening tracking parameter alpha (skipped when
        # Je_new_fake is too small for the log to be meaningful)
        if (Je_new_fake > 0.0001):
            model.logJp[p] = model.logJp[p] + wp.log(Je_tr / Je_new_fake)

        # compute b_{n+1} and take component-wise square roots to obtain the
        # returned singular values
        k = wp.sqrt(-M2 * (p_trial + b) * (p_trial - c) / a)
        be_new = k / mu * wp.pow(Je_tr, 2.0/dim) * s_hat_trial / s_hat_trial_norm + 1.0 / dim * (sig2[0] + sig2[1] + sig2[2]) * wp.vec3(1.0, 1.0, 1.0)
        sig_new = wp.vec3(1.0, 1.0, 1.0)
        for i in range(wp.int(dim)):
            sig_new[i] = wp.sqrt(be_new[i])

        return sig_new
    # None of the cases applied: return the raw SVD values unchanged.
    return sig_old
    

@wp.func
def NACC_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):
    # Apply the NACC return mapping to F_trial for particle p and rebuild the
    # deformation gradient from the mapped singular values.
    #
    # Returns U * diag(NACC(F_trial, model, p)) * V^T, where U and V come from
    # the SVD of F_trial. NACC may update model.logJp[p] as a side effect.
    rot_u = wp.mat33(0.0)
    rot_v = wp.mat33(0.0)
    sig_trial = wp.vec3(0.0)
    wp.svd3(F_trial, rot_u, sig_trial, rot_v)

    sig_mapped = NACC(F_trial, model, p)
    return rot_u * wp.diag(sig_mapped) * wp.transpose(rot_v)


@wp.func
def von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):
    # Return mapping in logarithmic-strain space for particle p.
    #
    # When the norm of the deviatoric part of tau exceeds
    # model.yield_stress[p], the log-strain is pulled back along its
    # deviatoric direction and the corrected deformation gradient is
    # returned; otherwise F_trial is returned unchanged.
    # Side effect: when model.hardening == 1, model.yield_stress[p] is
    # increased by 2 * mu * xi * delta_gamma in the yielding branch.
    U = wp.mat33(0.0)
    V = wp.mat33(0.0)
    sig_raw = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_raw, V)

    # Clamp from below to prevent NaN in extreme cases.
    sig = wp.vec3(
        wp.max(sig_raw[0], 0.01), wp.max(sig_raw[1], 0.01), wp.max(sig_raw[2], 0.01)
    )
    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))
    trace_eps = epsilon[0] + epsilon[1] + epsilon[2]
    mean_eps = trace_eps / 3.0

    mu = model.mu[p]
    tau = 2.0 * mu * epsilon + model.lam[p] * trace_eps * wp.vec3(1.0, 1.0, 1.0)
    mean_tau = (tau[0] + tau[1] + tau[2]) / 3.0
    tau_dev = tau - wp.vec3(mean_tau, mean_tau, mean_tau)

    if wp.length(tau_dev) > model.yield_stress[p]:
        eps_dev = epsilon - wp.vec3(mean_eps, mean_eps, mean_eps)
        # Small offset keeps the division below well defined.
        eps_dev_norm = wp.length(eps_dev) + 1e-6
        delta_gamma = eps_dev_norm - model.yield_stress[p] / (2.0 * mu)
        eps_new = epsilon - (delta_gamma / eps_dev_norm) * eps_dev
        sig_new = wp.vec3(wp.exp(eps_new[0]), wp.exp(eps_new[1]), wp.exp(eps_new[2]))
        F_elastic = U * wp.diag(sig_new) * wp.transpose(V)
        if model.hardening == 1:
            model.yield_stress[p] = (
                model.yield_stress[p] + 2.0 * mu * model.xi * delta_gamma
            )
        return F_elastic
    else:
        return F_trial


@wp.func
def von_mises_return_mapping_with_damage(
    F_trial: wp.mat33, model: MPMModelStruct, p: int
):
    # Return mapping in logarithmic-strain space for particle p, with a
    # softening yield stress.
    #
    # Returns F_trial unchanged when the particle is not yielding or when its
    # yield stress is already <= 0; otherwise returns the corrected
    # deformation gradient U * diag(exp(epsilon)) * V^T.
    # Side effects in the yielding branch:
    #   * model.yield_stress[p] is reduced by model.softening times the norm
    #     of the strain correction;
    #   * if it then drops to <= 0, model.mu[p] and model.lam[p] are set to 0;
    #   * if model.hardening == 1, model.yield_stress[p] is afterwards
    #     increased by 2 * mu * xi * delta_gamma, using the possibly zeroed mu.
    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    sig_old = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_old, V)

    sig = wp.vec3(
        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)
    )  # clamp from below to prevent NaN in extreme cases
    # Log-strain per axis and its mean.
    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))
    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0

    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (
        epsilon[0] + epsilon[1] + epsilon[2]
    ) * wp.vec3(1.0, 1.0, 1.0)
    sum_tau = tau[0] + tau[1] + tau[2]
    # cond is tau with its mean removed; its norm is compared to the yield stress.
    cond = wp.vec3(
        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0
    )
    if wp.length(cond) > model.yield_stress[p]:
        # Yield stress already exhausted: leave the deformation untouched.
        if model.yield_stress[p] <= 0:
            return F_trial
        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)
        # Small offset keeps the divisions below well defined.
        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6
        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])
        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat
        # Softening: lower the yield stress in proportion to the correction size.
        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(
            (delta_gamma / epsilon_hat_norm) * epsilon_hat
        )
        # Yield stress used up: zero this particle's mu and lam.
        if model.yield_stress[p] <= 0:
            model.mu[p] = 0.0
            model.lam[p] = 0.0
        sig_elastic = wp.mat33(
            wp.exp(epsilon[0]),
            0.0,
            0.0,
            0.0,
            wp.exp(epsilon[1]),
            0.0,
            0.0,
            0.0,
            wp.exp(epsilon[2]),
        )
        F_elastic = U * sig_elastic * wp.transpose(V)
        if model.hardening == 1:
            # Runs after the softening update above and reads the current mu.
            model.yield_stress[p] = (
                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma
            )
        return F_elastic
    else:
        # Not yielding: keep the trial deformation gradient.
        return F_trial


# for toothpaste
@wp.func
def viscoplasticity_return_mapping_with_StVK(
    F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float
):
    # Viscoplastic return mapping in logarithmic-strain space for particle p.
    #
    # When the deviatoric trial stress norm exceeds
    # sqrt(2/3) * model.yield_stress[p], its norm is reduced by an amount that
    # depends on model.plastic_viscosity and dt, and the corrected deformation
    # gradient is returned; otherwise F_trial is returned unchanged.
    U = wp.mat33(0.0)
    V = wp.mat33(0.0)
    sig_raw = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_raw, V)

    # Clamp from below to prevent NaN in extreme cases.
    sig = wp.vec3(
        wp.max(sig_raw[0], 0.01), wp.max(sig_raw[1], 0.01), wp.max(sig_raw[2], 0.01)
    )
    sig_sq_sum = sig[0] * sig[0] + sig[1] * sig[1] + sig[2] * sig[2]

    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))
    trace_eps = epsilon[0] + epsilon[1] + epsilon[2]
    mean_eps = trace_eps / 3.0
    mean_vec = wp.vec3(mean_eps, mean_eps, mean_eps)
    eps_dev = epsilon - mean_vec

    mu = model.mu[p]
    s_trial = 2.0 * mu * eps_dev
    s_trial_norm = wp.length(s_trial)
    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]
    if y > 0:
        mu_hat = mu * sig_sq_sum / 3.0
        relaxation = 1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)
        s_new_norm = s_trial_norm - y / relaxation
        s_new = (s_new_norm / s_trial_norm) * s_trial
        eps_new = 1.0 / (2.0 * mu) * s_new + mean_vec
        sig_new = wp.vec3(wp.exp(eps_new[0]), wp.exp(eps_new[1]), wp.exp(eps_new[2]))
        return U * wp.diag(sig_new) * wp.transpose(V)
    else:
        return F_trial


@wp.func
def sand_return_mapping(
    F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int
):
    # Return mapping in logarithmic-strain space for particle p.
    #
    # Three mutually exclusive outcomes, selected by delta_gamma and the
    # log-strain trace tr:
    #   * delta_gamma <= 0             -> F_trial is returned unchanged
    #   * delta_gamma > 0 and tr > 0   -> U * V^T (all singular values are 1)
    #   * delta_gamma > 0 and tr <= 0  -> the deviatoric log-strain is reduced
    #                                     by delta_gamma and F is rebuilt
    # `state` is accepted but not read by the current implementation.
    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    sig = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig, V)

    # Log of the absolute singular values, floored at 1e-14 to avoid log(0).
    epsilon = wp.vec3(
        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),
        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),
        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),
    )
    # NOTE: sigma_out is never read below.
    sigma_out = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]
    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)
    epsilon_hat_norm = wp.length(epsilon_hat)
    # Deviatoric norm plus a trace term scaled by (3 * lam + 2 * mu) / (2 * mu)
    # and model.alpha.
    delta_gamma = (
        epsilon_hat_norm
        + (3.0 * model.lam[p] + 2.0 * model.mu[p])
        / (2.0 * model.mu[p])
        * tr
        * model.alpha
    )

    # F_elastic is assigned only inside the three branches below.
    if delta_gamma <= 0:
        F_elastic = F_trial

    if delta_gamma > 0 and tr > 0:
        F_elastic = U * wp.transpose(V)

    if delta_gamma > 0 and tr <= 0:
        # Shrink the deviatoric part of the log-strain by delta_gamma and
        # exponentiate back to singular values.
        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)
        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))

        F_elastic = U * wp.diag(s_new) * wp.transpose(V)
    return F_elastic


@wp.kernel
def compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):
    # Per-particle kernel: fill model.mu and model.lam from model.E and
    # model.nu using
    #   mu  = E / (2 * (1 + nu))
    #   lam = E * nu / ((1 + nu) * (1 - 2 * nu))
    # `state` is accepted but not read.
    tid = wp.tid()
    e_val = model.E[tid]
    nu_val = model.nu[tid]
    model.mu[tid] = e_val / (2.0 * (1.0 + nu_val))
    model.lam[tid] = e_val * nu_val / ((1.0 + nu_val) * (1.0 - 2.0 * nu_val))



@wp.kernel
def zero_grid(couple: CouplingSystem, state: MPMStateStruct, model: MPMModelStruct):
    # Per-grid-node kernel (3D launch): reset every accumulated grid quantity
    # of the node to zero in state, couple and model.
    i, j, k = wp.tid()
    zero3 = wp.vec3(0.0, 0.0, 0.0)

    # Combined grid mass and velocities.
    state.grid_m[i, j, k] = 0.0
    state.grid_v_in[i, j, k] = zero3
    state.grid_v_out[i, j, k] = zero3

    # Fluid-side coupling fields.
    couple.grid_fluid_m[i, j, k] = 0.0
    couple.grid_fluid_n[i, j, k] = zero3
    couple.grid_fluid_v[i, j, k] = zero3

    # Solid-side coupling fields.
    couple.grid_solid_m[i, j, k] = 0.0
    couple.grid_solid_n[i, j, k] = zero3
    couple.grid_solid_v[i, j, k] = zero3

    # Per-phase grid forces.
    model.grid_fluid_f[i, j, k] = zero3
    model.grid_solid_f[i, j, k] = zero3


@wp.func
def compute_dweight(
    model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int
):
    # Combine the entries of w and dw for indices (i, j, k) into a 3-vector:
    # each component uses the dw entry for its own row and the w entries for
    # the other two rows. The result is scaled by model.inv_dx.
    #
    # Presumably w[a, n] / dw[a, n] hold the 1D weight / weight derivative
    # along axis a for stencil offset n (confirm against the caller).
    gx = dw[0, i] * w[1, j] * w[2, k]
    gy = w[0, i] * dw[1, j] * w[2, k]
    gz = w[0, i] * w[1, j] * dw[2, k]
    return wp.vec3(gx, gy, gz) * model.inv_dx


@wp.func
def update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):
    # Advance the symmetric 3x3 matrix stored for particle p in
    # state.particle_cov by one explicit step:
    #   cov <- cov + dt * (grad_v * cov + cov * grad_v^T)
    #
    # Storage layout: six consecutive values per particle starting at p * 6,
    # in the order (0,0), (0,1), (0,2), (1,1), (1,2), (2,2). Only this upper
    # triangle is read and written back.
    base = p * 6
    c00 = state.particle_cov[base]
    c01 = state.particle_cov[base + 1]
    c02 = state.particle_cov[base + 2]
    c11 = state.particle_cov[base + 3]
    c12 = state.particle_cov[base + 4]
    c22 = state.particle_cov[base + 5]

    # Expand the packed upper triangle into a full symmetric matrix.
    cov_n = wp.mat33(c00, c01, c02, c01, c11, c12, c02, c12, c22)

    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))

    state.particle_cov[base] = cov_np1[0, 0]
    state.particle_cov[base + 1] = cov_np1[0, 1]
    state.particle_cov[base + 2] = cov_np1[0, 2]
    state.particle_cov[base + 3] = cov_np1[1, 1]
    state.particle_cov[base + 4] = cov_np1[1, 2]
    state.particle_cov[base + 5] = cov_np1[2, 2]


@wp.kernel
def p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Particle-to-grid transfer: scatter each active particle's mass and
    # momentum (including the stress and affine contributions) onto the
    # 3x3x3 block of grid nodes around it.
    #
    # input given to p2g:   particle_stress
    #                       particle_x
    #                       particle_v
    #                       particle_C
    # output (atomic adds): state.grid_v_in  (m * v, not just v; the grid
    #                                         update must divide by mass)
    #                       state.grid_m
    #
    # Particles with particle_selection != 0 are skipped.
    #
    # NOTE(review): model.rpic_damping has no effect in this kernel. The
    # transfer uses the undamped state.particle_C[p] through `affine`; an
    # earlier in-loop damped copy of C was computed but never used and has
    # been removed. Confirm whether damping should be applied here.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        # Particle position in grid units and the base node of its stencil.
        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Per-axis interpolation weights for the three stencil offsets
        # (quadratic in fx).
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )

        identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        sigma_stress = state.particle_stress[p]
        # Stress contribution scaled by -dt * 4 * vol / dx^2.
        stress = -dt * 4.0 * state.particle_vol[p] * sigma_stress / wp.pow(model.dx, 2.0)
        affine = stress * identity + state.particle_mass[p] * state.particle_C[p]
        # Loop-invariant particle momentum.
        momentum = state.particle_mass[p] * state.particle_v[p]

        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    # Offset from the particle to this node, in world units.
                    dpos = (
                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                    ) * model.dx
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    # Product of the three per-axis weights.
                    weight = w[0, i] * w[1, j] * w[2, k]

                    v_in_add = weight * (momentum + affine @ dpos)
                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)
                    wp.atomic_add(
                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]
                    )


@wp.kernel
def p2g_apic_with_stress_water(couple: CouplingSystem, state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Particle-to-grid transfer that, besides the combined grid fields in
    # `state`, accumulates separate per-phase fields in `couple`.
    #
    # input given to p2g:   particle_stress
    #                       particle_x
    #                       particle_v
    #                       particle_C
    # output (atomic adds): state.grid_v_in, state.grid_m (all particles)
    #                       couple.grid_fluid_m / _v / _n (material id 7)
    #                       couple.grid_solid_m / _v / _n (any other id)
    # grid_*_v and grid_v_in receive momentum (m * v, not just v); the grid
    # update must divide by mass. grid_*_n receives mass * weight gradient.
    #
    # Particles with particle_selection != 0 are skipped.
    #
    # NOTE(review): an earlier comment on the material == 7 branch read
    # "water  now is solid"; confirm which material id 7 denotes.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        sigma_stress = state.particle_stress[p]

        # Particle position in grid units and the base node of its stencil.
        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Per-axis interpolation weights (quadratic in fx) and their
        # derivatives with respect to fx.
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        # Stress contribution scaled by -dt * 4 * vol / dx^2; used by the
        # material == 7 branch through `affine` (with the undamped C).
        stress = -dt * 4.0 * state.particle_vol[p] * sigma_stress / wp.pow(model.dx, 2.0)
        affine = stress * identity + state.particle_mass[p] * state.particle_C[p]

        # Damped affine matrix used by the other branch. It does not depend
        # on the node, so it is computed once per particle.
        # if model.rpic_damping = 0, standard apic
        C = state.particle_C[p]
        C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (
            C - wp.transpose(C)
        )
        if model.rpic_damping < -0.001:
            # standard pic
            C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    # Offset from the particle to this node, in world units.
                    dpos = (
                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                    ) * model.dx
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    # Product of the three per-axis weights, and its gradient.
                    weight = w[0, i] * w[1, j] * w[2, k]
                    dweight = compute_dweight(model, w, dw, i, j, k)
                    gridm = weight * state.particle_mass[p]

                    if model.material[p] == 7:
                        # Momentum with the stress folded into `affine`.
                        v_in_add = weight * (state.particle_mass[p] * state.particle_v[p] + affine @ dpos)

                        wp.atomic_add(
                            couple.grid_fluid_m, ix, iy, iz, gridm
                        )
                        wp.atomic_add(
                            couple.grid_fluid_v, ix, iy, iz, v_in_add
                        )
                        wp.atomic_add(
                            couple.grid_fluid_n, ix, iy, iz, state.particle_mass[p] * dweight
                        )

                        wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)
                    else:
                        # Momentum with the damped C plus an explicit
                        # dt * force term built from the weight gradient.
                        elastic_force = -state.particle_vol[p] * sigma_stress * dweight
                        v_in_add = (
                                weight
                                * state.particle_mass[p]
                                * (state.particle_v[p] + C * dpos)
                                + dt * elastic_force
                        )

                        wp.atomic_add(
                            couple.grid_solid_m, ix, iy, iz, gridm
                        )
                        wp.atomic_add(
                            couple.grid_solid_v, ix, iy, iz, v_in_add
                        )
                        wp.atomic_add(
                            couple.grid_solid_n, ix, iy, iz, state.particle_mass[p] * dweight
                        )
                        wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)

                    wp.atomic_add(
                        state.grid_m, ix, iy, iz, gridm
                    )

@wp.kernel
def p2g_apic_with_stress_2solid(couple: CouplingSystem, state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Particle-to-grid (P2G) transfer for the two-body coupled solver.
    #
    # Read per particle (only when particle_selection == 0):
    #   particle_stress, particle_x, particle_v, particle_C,
    #   particle_mass, particle_vol
    #
    # Atomically accumulated per grid node:
    #   state.grid_m, state.grid_v_in   combined mass and momentum
    #   couple.grid_fluid_{m,v,n}       particles with material == 0
    #   couple.grid_solid_{m,v,n}       particles with any other material
    # The *_v buffers receive momentum (m * v plus the affine/stress term),
    # not velocity; the *_n buffers receive mass * dweight.
    #
    # NOTE(review): model.rpic_damping is not applied in this kernel. The
    # affine term uses state.particle_C[p] exactly as stored. An RPIC-damped C
    # used to be computed inside the stencil loop but was never read; confirm
    # whether damping is meant to be active for this variant.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        mass = state.particle_mass[p]
        momentum = mass * state.particle_v[p]
        material = model.material[p]

        # Lower corner of the 3x3x3 stencil and the particle offset from it,
        # both in grid units.
        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Quadratic B-spline weights (rows of w) and their derivatives (rows
        # of dw) for the three stencil nodes along each axis.
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        # Affine matrix: stress contribution (-dt * 4 * vol * sigma / dx^2)
        # plus mass * C. It is constant over the stencil, so build it once.
        stress = -dt * 4.0 * state.particle_vol[p] * state.particle_stress[p] / wp.pow(model.dx, 2.0)
        affine = stress + mass * state.particle_C[p]

        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    # Node position relative to the particle, in world units.
                    dpos = (
                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                    ) * model.dx
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    weight = w[0, i] * w[1, j] * w[2, k]  # tensor-product quadratic B-spline
                    dweight = compute_dweight(model, w, dw, i, j, k)

                    gridm = weight * mass
                    v_in_add = weight * (momentum + affine @ dpos)
                    mass_gradient = mass * dweight

                    if material == 0:
                        # Material 0 is scattered into the grid_fluid_* buffers;
                        # the original author notes "jelly now is solid".
                        wp.atomic_add(couple.grid_fluid_m, ix, iy, iz, gridm)
                        wp.atomic_add(couple.grid_fluid_v, ix, iy, iz, v_in_add)
                        wp.atomic_add(couple.grid_fluid_n, ix, iy, iz, mass_gradient)
                    else:
                        wp.atomic_add(couple.grid_solid_m, ix, iy, iz, gridm)
                        wp.atomic_add(couple.grid_solid_v, ix, iy, iz, v_in_add)
                        wp.atomic_add(couple.grid_solid_n, ix, iy, iz, mass_gradient)

                    # Combined field: momentum is stored here, so the grid
                    # update has to multiply by 1 / grid_m.
                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)
                    wp.atomic_add(state.grid_m, ix, iy, iz, gridm)

@wp.kernel
def grid_normalization_and_gravity(
    state: MPMStateStruct, model: MPMModelStruct, dt: float
):
    # Per grid node: turn the accumulated momentum in grid_v_in into a velocity
    # (divide by grid_m), add dt * gravity, and store it in grid_v_out.
    # Nodes whose mass is not above 1e-15 are left untouched.
    gx, gy, gz = wp.tid()
    node_mass = state.grid_m[gx, gy, gz]
    if node_mass > 1e-15:
        inv_mass = 1.0 / node_mass
        velocity = state.grid_v_in[gx, gy, gz] * inv_mass
        state.grid_v_out[gx, gy, gz] = velocity + dt * model.gravitational_accelaration


# add gravity
@wp.kernel
def grid_normalization_and_gravity_water(
    use_collision_force: int, couple: CouplingSystem, state: MPMStateStruct, model: MPMModelStruct, dt: float
):
    # Grid update for the coupled solver, one thread per grid node.
    #
    # For nodes with combined mass above 1e-15:
    #   * state.grid_v_out   = grid_v_in / grid_m + dt * gravity
    #   * couple.grid_fluid_v / couple.grid_solid_v are converted in place from
    #     momentum to velocity (plus dt * gravity) for whichever phase is
    #     present at the node.
    #   * if both phases are present and use_collision_force is non-zero, a
    #     contact force pair is written to model.grid_fluid_f /
    #     model.grid_solid_f when the phases approach each other along the
    #     estimated contact normal.
    # This kernel never clears the force buffers; presumably they are reset
    # elsewhere each step -- TODO confirm.
    grid_x, grid_y, grid_z = wp.tid()
    if state.grid_m[grid_x, grid_y, grid_z] > 1e-15:
        gravity_dv = dt * model.gravitational_accelaration

        # Combined (single-field) velocity.
        v_out = state.grid_v_in[grid_x, grid_y, grid_z] * (
            1.0 / state.grid_m[grid_x, grid_y, grid_z]
        )
        state.grid_v_out[grid_x, grid_y, grid_z] = v_out + gravity_dv

        fluid_m = couple.grid_fluid_m[grid_x, grid_y, grid_z]
        solid_m = couple.grid_solid_m[grid_x, grid_y, grid_z]

        if fluid_m > 1e-15 and solid_m > 1e-15:
            # Both fluid and solid particles contribute to this node.
            v_fluid_out = couple.grid_fluid_v[grid_x, grid_y, grid_z] * (1.0 / fluid_m) + gravity_dv
            v_solid_out = couple.grid_solid_v[grid_x, grid_y, grid_z] * (1.0 / solid_m) + gravity_dv
            couple.grid_fluid_v[grid_x, grid_y, grid_z] = v_fluid_out
            couple.grid_solid_v[grid_x, grid_y, grid_z] = v_solid_out

            if use_collision_force:
                # Mass-weighted weight gradients accumulated during P2G; their
                # directions are used as per-phase surface normals.
                fluid_grad = couple.grid_fluid_n[grid_x, grid_y, grid_z]
                solid_grad = couple.grid_solid_n[grid_x, grid_y, grid_z]
                fluid_len = wp.length(fluid_grad)
                solid_len = wp.length(solid_grad)

                # A zero-length gradient has no direction. Skip explicitly
                # instead of dividing by zero and relying on NaN comparisons
                # to suppress the force.
                if fluid_len > 0.0 and solid_len > 0.0:
                    delta_n = fluid_grad / fluid_len - solid_grad / solid_len
                    delta_len = wp.length(delta_n)
                    if delta_len > 0.0:
                        # Contact normal from the difference of the two unit normals.
                        normal_fluid = delta_n / delta_len
                        normal_solid = -normal_fluid
                        penetration = wp.dot(v_fluid_out - v_solid_out, normal_fluid)

                        if penetration > 0.0:  # phases are approaching along the normal
                            fluid_p = v_fluid_out * fluid_m
                            solid_p = v_solid_out * solid_m
                            # NOTE(review): each momentum is multiplied by its
                            # own mass here. A momentum-conserving inelastic
                            # contact would use the cross terms
                            # (solid_p * fluid_m - fluid_p * solid_m).
                            # Left unchanged because it alters the physics;
                            # please confirm the intended formula.
                            f_collision = (solid_p * solid_m - fluid_p * fluid_m) / ((fluid_m + solid_m) * dt)
                            f_ap_fluid = couple.beta * wp.abs(wp.dot(f_collision, normal_solid)) * normal_solid
                            model.grid_fluid_f[grid_x, grid_y, grid_z] = f_ap_fluid
                            model.grid_solid_f[grid_x, grid_y, grid_z] = -f_ap_fluid
        elif fluid_m > 1e-15:
            # Only fluid particles near this node: update the fluid field only.
            couple.grid_fluid_v[grid_x, grid_y, grid_z] = (
                couple.grid_fluid_v[grid_x, grid_y, grid_z] * (1.0 / fluid_m) + gravity_dv
            )
        elif solid_m > 0.0:
            # Only solid particles near this node. The explicit mass check
            # avoids writing inf/NaN when the solid mass is exactly zero.
            couple.grid_solid_v[grid_x, grid_y, grid_z] = (
                couple.grid_solid_v[grid_x, grid_y, grid_z] * (1.0 / solid_m) + gravity_dv
            )

@wp.kernel
def g2p(state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Grid-to-particle (G2P) transfer for the single-field solver.
    #
    # For every selected particle (particle_selection == 0) this gathers
    # state.grid_v_out over the 3x3x3 quadratic B-spline stencil and writes:
    #   particle_v        interpolated velocity
    #   particle_x        advected by dt * velocity
    #   particle_C        APIC affine matrix
    #   particle_F_trial  (I + dt * C) * particle_F
    #   model.particle_J  scaled by (1 + dt * trace(C))
    # The weight-gradient velocity gradient is only passed to update_cov when
    # model.update_cov_with_F is set.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        x_p = state.particle_x[p]
        grid_pos = x_p * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Quadratic B-spline weights (w) and derivatives (dw), one row per
        # stencil node along each axis.
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        v_gathered = wp.vec3(0.0, 0.0, 0.0)
        C_gathered = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        grad_v = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    node_v = state.grid_v_out[base_pos_x + i, base_pos_y + j, base_pos_z + k]
                    # Offset to the node in grid units (not scaled by dx); the
                    # affine scale below uses inv_dx accordingly.
                    offset = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                    weight = w[0, i] * w[1, j] * w[2, k]
                    v_gathered = v_gathered + node_v * weight
                    C_gathered = C_gathered + wp.outer(node_v, offset) * (
                        weight * model.inv_dx * 4.0
                    )
                    dweight = compute_dweight(model, w, dw, i, j, k)
                    grad_v = grad_v + wp.outer(node_v, dweight)

        state.particle_v[p] = v_gathered
        state.particle_x[p] = x_p + dt * v_gathered
        state.particle_C[p] = C_gathered

        # Trial deformation gradient and volume ratio, both driven by C.
        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        state.particle_F_trial[p] = (I33 + C_gathered * dt) * state.particle_F[p]
        model.particle_J[p] = (1.0 + dt * wp.trace(C_gathered)) * model.particle_J[p]

        if model.update_cov_with_F:
            update_cov(state, p, grad_v, dt)


@wp.kernel
def g2p_water(couple: CouplingSystem, use_collision_force: int, state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Grid-to-particle (G2P) transfer for the fluid/solid coupled solver.
    #
    # Particles with material == 7 gather from couple.grid_fluid_v, all other
    # materials from couple.grid_solid_v. When use_collision_force is non-zero
    # and a node carries a non-zero contact force, that node's velocity is
    # corrected by f / m * dt before it is interpolated into particle_v.
    #
    # Writes particle_v, particle_x, particle_C, particle_F_trial and
    # model.particle_J. The trial F uses C for material 7 and the
    # weight-gradient velocity gradient for everything else.
    #
    # NOTE(review): particle_C and the velocity gradient are built from the
    # uncorrected grid velocity (without the contact force); confirm that this
    # is intended.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        material = model.material[p]

        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Quadratic B-spline weights (w) and derivatives (dw).
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        new_v = wp.vec3(0.0, 0.0, 0.0)
        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    dpos = (wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx) * model.dx
                    weight = w[0, i] * w[1, j] * w[2, k]  # tensor-product quadratic B-spline

                    grid_v_new = wp.vec3(0.0, 0.0, 0.0)
                    if material == 7:
                        grid_v = couple.grid_fluid_v[ix, iy, iz]
                        grid_v_new = grid_v
                        if use_collision_force != 0:
                            f_ap = model.grid_fluid_f[ix, iy, iz]
                            # Test the vector norm. The former
                            # abs(max(component)) test skipped forces whose
                            # largest component is zero, e.g. (0, -f, 0).
                            if wp.length(f_ap) > 1e-10:
                                grid_v_new = grid_v + f_ap / couple.grid_fluid_m[ix, iy, iz] * dt
                    else:
                        grid_v = couple.grid_solid_v[ix, iy, iz]
                        grid_v_new = grid_v
                        if use_collision_force != 0:
                            f_ap = model.grid_solid_f[ix, iy, iz]
                            if wp.length(f_ap) > 1e-10:
                                grid_v_new = grid_v + f_ap / couple.grid_solid_m[ix, iy, iz] * dt

                    new_v = new_v + grid_v_new * weight
                    new_C = new_C + wp.outer(grid_v, dpos) * (
                            weight * 4.0 / wp.pow(model.dx, 2.0)
                    )
                    dweight = compute_dweight(model, w, dw, i, j, k)
                    new_F = new_F + wp.outer(grid_v, dweight)

        state.particle_v[p] = new_v
        state.particle_x[p] = state.particle_x[p] + dt * new_v
        state.particle_C[p] = new_C
        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        model.particle_J[p] = (1.0 + dt * wp.trace(new_C)) * model.particle_J[p]
        if material == 7:
            # Water: advance F with the APIC matrix C.
            F_tmp = (I33 + new_C * dt) * state.particle_F[p]
        else:
            # Everything else: advance F with the weight-gradient velocity gradient.
            F_tmp = (I33 + new_F * dt) * state.particle_F[p]
        state.particle_F_trial[p] = F_tmp
        if model.update_cov_with_F:
            update_cov(state, p, new_F, dt)


@wp.kernel
def g2p_2solid(couple: CouplingSystem, use_collision_force: int, state: MPMStateStruct, model: MPMModelStruct,
              dt: float):
    # Grid-to-particle (G2P) transfer for the two-body coupled solver.
    #
    # Particles with material == 0 gather from couple.grid_fluid_v, all other
    # materials from couple.grid_solid_v. When use_collision_force is non-zero
    # and a node carries a non-zero contact force, that node's velocity is
    # corrected by f / m * dt before it is interpolated into particle_v.
    #
    # Writes particle_v, particle_x, particle_C, particle_F_trial and
    # model.particle_J. The trial F is (I + dt * C) * F for every material.
    #
    # NOTE(review): particle_C and the velocity gradient are built from the
    # uncorrected grid velocity (without the contact force); confirm that this
    # is intended.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        material = model.material[p]

        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )

        # Quadratic B-spline weights (w) and derivatives (dw).
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        new_v = wp.vec3(0.0, 0.0, 0.0)
        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    dpos = (wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx) * model.dx
                    weight = w[0, i] * w[1, j] * w[2, k]  # tensor-product quadratic B-spline

                    grid_v_new = wp.vec3(0.0, 0.0, 0.0)
                    if material == 0:
                        grid_v = couple.grid_fluid_v[ix, iy, iz]
                        grid_v_new = grid_v
                        if use_collision_force != 0:
                            f_ap = model.grid_fluid_f[ix, iy, iz]
                            # Test the vector norm. The former
                            # abs(max(component)) test skipped forces whose
                            # largest component is zero, e.g. (0, -f, 0).
                            if wp.length(f_ap) > 1e-10:
                                grid_v_new = grid_v + f_ap / couple.grid_fluid_m[ix, iy, iz] * dt
                    else:
                        grid_v = couple.grid_solid_v[ix, iy, iz]
                        grid_v_new = grid_v
                        if use_collision_force != 0:
                            f_ap = model.grid_solid_f[ix, iy, iz]
                            if wp.length(f_ap) > 1e-10:
                                grid_v_new = grid_v + f_ap / couple.grid_solid_m[ix, iy, iz] * dt

                    new_v = new_v + grid_v_new * weight
                    new_C = new_C + wp.outer(grid_v, dpos) * (
                            weight * 4.0 / wp.pow(model.dx, 2.0)
                    )
                    dweight = compute_dweight(model, w, dw, i, j, k)
                    new_F = new_F + wp.outer(grid_v, dweight)

        state.particle_v[p] = new_v
        state.particle_x[p] = state.particle_x[p] + dt * new_v
        state.particle_C[p] = new_C
        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        model.particle_J[p] = (1.0 + dt * wp.trace(new_C)) * model.particle_J[p]
        # Both bodies advance F with the APIC matrix C. The previous per-material
        # branches were identical, so a single update is used.
        state.particle_F_trial[p] = (I33 + new_C * dt) * state.particle_F[p]
        if model.update_cov_with_F:
            update_cov(state, p, new_F, dt)


# compute (Kirchhoff) stress = stress(returnMap(F_trial))
@wp.kernel
def compute_stress_from_F_trial(
    state: MPMStateStruct, model: MPMModelStruct, dt: float
):
    # For every selected particle (particle_selection == 0):
    #   1. project particle_F_trial through the material's return mapping and
    #      store the result in particle_F (materials without a mapping copy the
    #      trial value unchanged);
    #   2. evaluate the Kirchhoff stress of the mapped F with the material's
    #      constitutive model, symmetrise it and store it in particle_stress.
    # Materials with no stress model listed below end up with zero stress.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        # ---- 1. return mapping ------------------------------------------
        F_trial = state.particle_F_trial[p]
        mapping_material = model.material[p]
        if mapping_material == 1:  # metal
            state.particle_F[p] = von_mises_return_mapping(F_trial, model, p)
        elif mapping_material == 2:  # sand
            state.particle_F[p] = sand_return_mapping(F_trial, state, model, p)
        elif mapping_material == 3:  # visplas, with StVk+VM, no thickening
            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(
                F_trial, model, p, dt
            )
        elif mapping_material == 5:
            state.particle_F[p] = von_mises_return_mapping_with_damage(
                F_trial, model, p
            )
        elif mapping_material == 6:
            state.particle_F[p] = NACC_return_mapping(F_trial, model, p)
        else:  # elastic
            state.particle_F[p] = F_trial

        # ---- 2. stress ----------------------------------------------------
        # Re-read after the mapping, which receives `model` and could in
        # principle have touched these arrays.
        F = state.particle_F[p]
        material = model.material[p]

        J = wp.determinant(F)
        U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        sig = wp.vec3(0.0)
        wp.svd3(F, U, sig, V)

        stress = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        if material == 0 or material == 5 or material == 1:
            # Fixed-corotated model.
            stress = kirchoff_stress_FCR(F, U, V, J, model.mu[p], model.lam[p])
        elif material == 2:
            stress = kirchoff_stress_drucker_prager(
                F, U, V, sig, model.mu[p], model.lam[p]
            )
        elif material == 3:
            # temporarily use stvk, subject to change
            stress = kirchoff_stress_StVK(F, U, V, sig, model.mu[p], model.lam[p])
        elif material == 6:
            stress = kirchoff_stress_neoHookean_borden(
                p, F, U, V, J, model.mu[p], model.lam[p]
            )
        elif material == 7:
            # NOTE(review): only the final assignment (kirchoff_stress_water)
            # reaches particle_stress. The earlier calls are kept in their
            # original order in case any of them has side effects; they look
            # like leftovers from experiments -- confirm and remove.
            stress = kirchoff_stress_StVK_water(model.particle_J[p], model, p)
            stress = kirchoff_stress_neoHookean_borden(
                p, F, U, V, J, model.mu[p], model.lam[p]
            )
            stress = kirchoff_stress_neoHookean_borden(
                p, F, U, V, J, model.mu[p], model.lam[p]
            )
            stress = kirchoff_stress_water(J, model, p)

        # Enforce symmetry before storing.
        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0


@wp.kernel
def compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):
    # Push each particle's initial covariance through its trial deformation
    # gradient: cov = F * init_cov * F^T.
    #
    # particle_init_cov and particle_cov are flat arrays with six entries per
    # particle holding the upper triangle of a symmetric 3x3 matrix in the
    # order (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).
    p = wp.tid()
    base = p * 6

    F = state.particle_F_trial[p]

    c00 = state.particle_init_cov[base]
    c01 = state.particle_init_cov[base + 1]
    c02 = state.particle_init_cov[base + 2]
    c11 = state.particle_init_cov[base + 3]
    c12 = state.particle_init_cov[base + 4]
    c22 = state.particle_init_cov[base + 5]

    # Rebuild the full symmetric matrix (row-major constructor).
    init_cov = wp.mat33(
        c00, c01, c02,
        c01, c11, c12,
        c02, c12, c22,
    )

    cov = F * init_cov * wp.transpose(F)

    # Store the upper triangle back in the same packed order.
    state.particle_cov[base] = cov[0, 0]
    state.particle_cov[base + 1] = cov[0, 1]
    state.particle_cov[base + 2] = cov[0, 2]
    state.particle_cov[base + 3] = cov[1, 1]
    state.particle_cov[base + 4] = cov[1, 2]
    state.particle_cov[base + 5] = cov[2, 2]


@wp.kernel
def compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):
    # Extract the rotation of each particle's trial deformation gradient via
    # SVD (F = U * diag(sig) * V^T, R = U * V^T) and store its transpose in
    # particle_R.
    p = wp.tid()

    F_trial = state.particle_F_trial[p]

    # Output arguments of svd3.
    U = wp.mat33(0.0)
    V = wp.mat33(0.0)
    sigma = wp.vec3(0.0)
    wp.svd3(F_trial, U, sigma, V)

    # Flip the sign of the last column of any factor with negative
    # determinant, so that both factors have positive determinant.
    if wp.determinant(U) < 0.0:
        for row in range(3):
            U[row, 2] = -U[row, 2]

    if wp.determinant(V) < 0.0:
        for row in range(3):
            V[row, 2] = -V[row, 2]

    # Rotation matrix; the stored value is its transpose.
    rotation = U * wp.transpose(V)
    state.particle_R[p] = wp.transpose(rotation)


@wp.kernel
def add_damping_via_grid(state: MPMStateStruct, scale: float):
    # Multiply the output velocity of every grid node by `scale`
    # (one thread per node).
    gx, gy, gz = wp.tid()
    damped = state.grid_v_out[gx, gy, gz] * scale
    state.grid_v_out[gx, gy, gz] = damped

# @wp.kernel
# def reset_water_selection(
#     state: MPMStateStruct,
#     model: MPMModelStruct,
#     particle_selection: np.array,
# ):
#     p = wp.tid()
#     if model.material[p] == 7:
#         state.particle_selection[p] = 2


@wp.kernel
def apply_additional_params(
    state: MPMStateStruct,
    model: MPMModelStruct,
    params_modifier: MaterialParamsModifier,
):
    # Overwrite material parameters, density and velocity of every particle
    # that lies strictly inside the axis-aligned box centred at
    # params_modifier.point with half-extents params_modifier.size.
    # particle_vol is left untouched by this kernel.
    p = wp.tid()
    pos = state.particle_x[p]

    # Box bounds per axis.
    lo_x = params_modifier.point[0] - params_modifier.size[0]
    hi_x = params_modifier.point[0] + params_modifier.size[0]
    lo_y = params_modifier.point[1] - params_modifier.size[1]
    hi_y = params_modifier.point[1] + params_modifier.size[1]
    lo_z = params_modifier.point[2] - params_modifier.size[2]
    hi_z = params_modifier.point[2] + params_modifier.size[2]

    if (
        pos[0] > lo_x and pos[0] < hi_x
        and pos[1] > lo_y and pos[1] < hi_y
        and pos[2] > lo_z and pos[2] < hi_z
    ):
        # Constitutive parameters.
        model.material[p] = params_modifier.material
        model.E[p] = params_modifier.E
        model.nu[p] = params_modifier.nu
        model.logJp[p] = params_modifier.logJp
        model.beta[p] = params_modifier.beta
        model.kesai[p] = params_modifier.kesai
        # Particle state.
        state.particle_density[p] = params_modifier.density
        state.particle_v[p] = params_modifier.v


@wp.kernel
def apply_additional_params_with_vol(
    state: MPMStateStruct,
    model: MPMModelStruct,
    params_modifier: MaterialParamsModifier,
):
    # Overwrite material parameters, density, velocity and volume of every
    # particle that lies strictly inside the axis-aligned box centred at
    # params_modifier.point with half-extents params_modifier.size.
    p = wp.tid()
    pos = state.particle_x[p]

    # Box bounds per axis.
    lo_x = params_modifier.point[0] - params_modifier.size[0]
    hi_x = params_modifier.point[0] + params_modifier.size[0]
    lo_y = params_modifier.point[1] - params_modifier.size[1]
    hi_y = params_modifier.point[1] + params_modifier.size[1]
    lo_z = params_modifier.point[2] - params_modifier.size[2]
    hi_z = params_modifier.point[2] + params_modifier.size[2]

    if (
        pos[0] > lo_x and pos[0] < hi_x
        and pos[1] > lo_y and pos[1] < hi_y
        and pos[2] > lo_z and pos[2] < hi_z
    ):
        # Constitutive parameters.
        model.material[p] = params_modifier.material
        model.E[p] = params_modifier.E
        model.nu[p] = params_modifier.nu
        model.logJp[p] = params_modifier.logJp
        model.beta[p] = params_modifier.beta
        model.kesai[p] = params_modifier.kesai
        # Particle state, including the particle volume.
        state.particle_density[p] = params_modifier.density
        state.particle_v[p] = params_modifier.v
        state.particle_vol[p] = params_modifier.volume

@wp.kernel
def selection_add_impulse_on_particles(
    state: MPMStateStruct, impulse_modifier: Impulse_modifier
):
    # Build the impulse mask: mask[p] = 1 for particles strictly inside the
    # axis-aligned box centred at impulse_modifier.point with half-extents
    # impulse_modifier.size, and 0 for all others.
    p = wp.tid()
    rel = state.particle_x[p] - impulse_modifier.point
    dist_x = wp.abs(rel[0])
    dist_y = wp.abs(rel[1])
    dist_z = wp.abs(rel[2])
    if (
        dist_x < impulse_modifier.size[0]
        and dist_y < impulse_modifier.size[1]
        and dist_z < impulse_modifier.size[2]
    ):
        impulse_modifier.mask[p] = 1
    else:
        impulse_modifier.mask[p] = 0


@wp.kernel
def selection_enforce_particle_velocity_translation(
    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier
):
    # Build the velocity-modifier mask: mask[p] = 1 for particles strictly
    # inside the axis-aligned box centred at velocity_modifier.point with
    # half-extents velocity_modifier.size, and 0 for all others.
    p = wp.tid()
    rel = state.particle_x[p] - velocity_modifier.point
    dist_x = wp.abs(rel[0])
    dist_y = wp.abs(rel[1])
    dist_z = wp.abs(rel[2])
    if (
        dist_x < velocity_modifier.size[0]
        and dist_y < velocity_modifier.size[1]
        and dist_z < velocity_modifier.size[2]
    ):
        velocity_modifier.mask[p] = 1
    else:
        velocity_modifier.mask[p] = 0


@wp.kernel
def selection_enforce_particle_velocity_cylinder(
    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier
):
    # Build the velocity-modifier mask for a cylindrical region:
    # mask[p] = 1 when the particle's distance along velocity_modifier.normal
    # is below half_height_and_radius[0] and its distance from the axis is
    # below half_height_and_radius[1]; 0 otherwise.
    # Assumes velocity_modifier.normal is unit length -- TODO confirm.
    p = wp.tid()
    axis = velocity_modifier.normal
    rel = state.particle_x[p] - velocity_modifier.point

    # Signed projection of the offset onto the cylinder axis.
    along_axis = wp.dot(rel, axis)
    vertical_distance = wp.abs(along_axis)

    # Length of what remains after removing the axial component.
    horizontal_distance = wp.length(rel - along_axis * axis)

    if (
        vertical_distance < velocity_modifier.half_height_and_radius[0]
        and horizontal_distance < velocity_modifier.half_height_and_radius[1]
    ):
        velocity_modifier.mask[p] = 1
    else:
        velocity_modifier.mask[p] = 0
