import warp as wp
from warp_utils import *
import numpy as np
import math
import torch


# compute stress from F
@wp.func
def kirchoff_stress_FCR(
        F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float
):
    # Kirchhoff stress for the FCR model (recall tau = P F^T):
    #   tau = 2 * mu * (F - R) * F^T + lam * J * (J - 1) * I
    # where R = U * V^T is built from the factors supplied by the caller.
    rotation = U * wp.transpose(V)
    identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    shear_part = 2.0 * mu * (F - rotation) * wp.transpose(F)
    volume_part = identity * lam * J * (J - 1.0)
    return shear_part + volume_part


@wp.func
def kirchoff_stress_neoHookean(
        F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float
):
    # Stress for the neo-Hookean model, assembled from principal values:
    #   tau_i = mu * J^(-2/3) * (sig_i^2 - mean(sig^2)) + lam / 2 * (J^2 - 1)
    # The result is returned as U * diag(tau) * V^T * F^T.
    sq = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])
    sq_mean = (sq[0] + sq[1] + sq[2]) / 3.0
    # deviatoric part of the squared singular values
    sq_dev = sq - wp.vec3(sq_mean, sq_mean, sq_mean)
    ones = wp.vec3(1.0, 1.0, 1.0)
    tau = mu * J ** (-2.0 / 3.0) * sq_dev + lam / 2.0 * (J * J - 1.0) * ones
    principal = wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])
    return U * principal * wp.transpose(V) * wp.transpose(F)


@wp.func
def kirchoff_stress_StVK(
        F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float
):
    # StVK stress built from logarithmic principal strains:
    #   tau_i = 2 * mu * log(sig_i) + lam * sum_k log(sig_k)
    # returned as U * diag(tau) * V^T * F^T.
    # NOTE(review): the trailing V^T * F^T uses the unclamped F, so each
    # principal value is additionally scaled by F's own stretch -- confirm
    # this is the intended stress measure.

    # clamp the singular values from below to prevent NaN in extreme cases
    clamped = wp.vec3(
        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)
    )
    log_strain = wp.vec3(
        wp.log(clamped[0]), wp.log(clamped[1]), wp.log(clamped[2])
    )
    log_trace = log_strain[0] + log_strain[1] + log_strain[2]
    ones = wp.vec3(1.0, 1.0, 1.0)
    tau = 2.0 * mu * log_strain + lam * log_trace * ones
    principal = wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])
    return U * principal * wp.transpose(V) * wp.transpose(F)


@wp.func
def kirchoff_stress_drucker_prager(
        F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float
):
    # Stress used for the Drucker-Prager (sand) material. Diagonal entries:
    #   c_i = (2 * mu * log(sig_i) + lam * sum_k log(sig_k)) / sig_i
    # returned as U * diag(c) * V^T * F^T.
    # The singular values are not clamped here, so they must be positive.
    log0 = wp.log(sig[0])
    log1 = wp.log(sig[1])
    log2 = wp.log(sig[2])
    log_trace = log0 + log1 + log2

    inv0 = 1.0 / sig[0]
    inv1 = 1.0 / sig[1]
    inv2 = 1.0 / sig[2]

    c0 = 2.0 * mu * log0 * inv0 + lam * log_trace * inv0
    c1 = 2.0 * mu * log1 * inv1 + lam * log_trace * inv1
    c2 = 2.0 * mu * log2 * inv2 + lam * log_trace * inv2

    principal = wp.mat33(c0, 0.0, 0.0, 0.0, c1, 0.0, 0.0, 0.0, c2)
    return U * principal * wp.transpose(V) * wp.transpose(F)


@wp.func
def von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):
    # Von Mises return mapping for particle p.
    # Returns F_trial unchanged when the deviatoric stress norm does not exceed
    # model.yield_stress[p]; otherwise returns the deformation gradient whose
    # logarithmic strain has been projected back towards the yield surface.
    # Side effect: when model.hardening == 1 the particle's yield stress is
    # increased by 2 * mu * xi * delta_gamma.
    U = wp.mat33(0.0)
    V = wp.mat33(0.0)
    sig_raw = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_raw, V)

    mu = model.mu[p]
    yield_stress = model.yield_stress[p]
    ones = wp.vec3(1.0, 1.0, 1.0)

    # singular values are clamped from below to prevent NaN in extreme cases
    eps = wp.vec3(
        wp.log(wp.max(sig_raw[0], 0.01)),
        wp.log(wp.max(sig_raw[1], 0.01)),
        wp.log(wp.max(sig_raw[2], 0.01)),
    )
    eps_trace = eps[0] + eps[1] + eps[2]
    eps_mean = eps_trace / 3.0

    tau = 2.0 * mu * eps + model.lam[p] * eps_trace * ones
    tau_mean = (tau[0] + tau[1] + tau[2]) / 3.0
    tau_dev = wp.vec3(tau[0] - tau_mean, tau[1] - tau_mean, tau[2] - tau_mean)

    if wp.length(tau_dev) > yield_stress:
        eps_dev = eps - wp.vec3(eps_mean, eps_mean, eps_mean)
        # small offset keeps the division below well defined
        eps_dev_norm = wp.length(eps_dev) + 1e-6
        delta_gamma = eps_dev_norm - yield_stress / (2.0 * mu)
        eps_proj = eps - (delta_gamma / eps_dev_norm) * eps_dev
        stretch = wp.diag(
            wp.vec3(wp.exp(eps_proj[0]), wp.exp(eps_proj[1]), wp.exp(eps_proj[2]))
        )
        F_elastic = U * stretch * wp.transpose(V)
        if model.hardening == 1:
            model.yield_stress[p] = yield_stress + 2.0 * mu * model.xi * delta_gamma
        return F_elastic

    return F_trial


@wp.func
def von_mises_return_mapping_with_damage(
        F_trial: wp.mat33, model: MPMModelStruct, p: int
):
    # Von Mises return mapping with softening for particle p.
    # Returns F_trial unchanged when the particle is not yielding or when its
    # yield stress is already <= 0; otherwise returns the projected elastic
    # deformation gradient.
    # Side effects on the model: the particle's yield stress is reduced by
    # model.softening times the size of the plastic strain correction, and once
    # it drops to <= 0 the particle's mu and lam are set to zero.
    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    sig_old = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_old, V)

    sig = wp.vec3(
        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)
    )  # clamp from below to prevent NaN in extreme cases
    # logarithmic principal strains and their mean
    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))
    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0

    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (
            epsilon[0] + epsilon[1] + epsilon[2]
    ) * wp.vec3(1.0, 1.0, 1.0)
    sum_tau = tau[0] + tau[1] + tau[2]
    # deviatoric part of tau; its length is compared against the yield stress
    cond = wp.vec3(
        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0
    )
    if wp.length(cond) > model.yield_stress[p]:
        # a non-positive yield stress skips the projection entirely; this also
        # guards the division by 2 * mu below once mu has been zeroed
        if model.yield_stress[p] <= 0:
            return F_trial
        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)
        # small offset keeps the divisions by the norm well defined
        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6
        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])
        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat
        # softening: lower the yield stress by the size of the plastic correction
        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(
            (delta_gamma / epsilon_hat_norm) * epsilon_hat
        )
        # zero the elastic parameters once the yield stress is exhausted
        if model.yield_stress[p] <= 0:
            model.mu[p] = 0.0
            model.lam[p] = 0.0
        sig_elastic = wp.mat33(
            wp.exp(epsilon[0]),
            0.0,
            0.0,
            0.0,
            wp.exp(epsilon[1]),
            0.0,
            0.0,
            0.0,
            wp.exp(epsilon[2]),
        )
        F_elastic = U * sig_elastic * wp.transpose(V)
        # hardening reads the (possibly just modified) yield stress and mu
        if model.hardening == 1:
            model.yield_stress[p] = (
                    model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma
            )
        return F_elastic
    else:
        return F_trial


# for toothpaste
@wp.func
def viscoplasticity_return_mapping_with_StVK(
        F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float
):
    # Viscoplastic return mapping for particle p over a step of length dt.
    # Returns F_trial unchanged when the trial deviatoric stress norm does not
    # exceed sqrt(2/3) * yield_stress; otherwise the deviatoric stress norm is
    # reduced by y / (1 + plastic_viscosity / (2 * mu_hat * dt)) and the
    # deformation gradient is rebuilt from the corrected logarithmic strain.
    U = wp.mat33(0.0)
    V = wp.mat33(0.0)
    sig_raw = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig_raw, V)

    # clamp from below to prevent NaN in extreme cases
    s = wp.vec3(
        wp.max(sig_raw[0], 0.01), wp.max(sig_raw[1], 0.01), wp.max(sig_raw[2], 0.01)
    )
    eps = wp.vec3(wp.log(s[0]), wp.log(s[1]), wp.log(s[2]))
    eps_mean = (eps[0] + eps[1] + eps[2]) / 3.0
    eps_mean_vec = wp.vec3(eps_mean, eps_mean, eps_mean)
    eps_dev = eps - eps_mean_vec

    mu = model.mu[p]
    s_trial = 2.0 * mu * eps_dev
    s_trial_norm = wp.length(s_trial)
    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]
    if y > 0:
        # trace of the squared (clamped) singular values
        b_trace = s[0] * s[0] + s[1] * s[1] + s[2] * s[2]
        mu_hat = mu * b_trace / 3.0
        damping = 1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)
        s_new_norm = s_trial_norm - y / damping
        s_new = (s_new_norm / s_trial_norm) * s_trial
        eps_new = 1.0 / (2.0 * mu) * s_new + eps_mean_vec
        stretch = wp.diag(
            wp.vec3(wp.exp(eps_new[0]), wp.exp(eps_new[1]), wp.exp(eps_new[2]))
        )
        return U * stretch * wp.transpose(V)

    return F_trial


@wp.func
def sand_return_mapping(
        F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int
):
    # Return mapping for the sand material of particle p.
    # With tr the trace of the logarithmic strain and delta_gamma the amount by
    # which the trial state violates the yield condition:
    #   delta_gamma <= 0            -> F_trial is returned unchanged
    #   delta_gamma > 0 and tr > 0  -> all stretch is removed (U * V^T)
    #   delta_gamma > 0 and tr <= 0 -> the deviatoric strain is scaled back
    # `state` is not read here; it is kept for interface compatibility.
    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    sig = wp.vec3(0.0)
    wp.svd3(F_trial, U, sig, V)

    # |sig| is floored at 1e-14 so the logarithm stays finite
    epsilon = wp.vec3(
        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),
        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),
        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),
    )
    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]
    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)
    epsilon_hat_norm = wp.length(epsilon_hat)
    delta_gamma = (
            epsilon_hat_norm
            + (3.0 * model.lam[p] + 2.0 * model.mu[p])
            / (2.0 * model.mu[p])
            * tr
            * model.alpha
    )

    # Start from the elastic answer so the result is defined on every path,
    # including the one where delta_gamma is NaN and no branch below fires.
    F_elastic = F_trial

    if delta_gamma > 0 and tr > 0:
        F_elastic = U * wp.transpose(V)

    if delta_gamma > 0 and tr <= 0:
        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)
        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))

        F_elastic = U * wp.diag(s_new) * wp.transpose(V)
    return F_elastic


@wp.kernel
def compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):
    # Per-particle conversion of (E, nu) into the Lame parameters:
    #   mu  = E' / (2 * (1 + nu))
    #   lam = E' * nu / ((1 + nu) * (1 - 2 * nu))
    # with E' = 1e7 * model.E[p].
    # NOTE(review): the 1e7 factor suggests model.E is stored in scaled units;
    # confirm against the code that fills model.E.
    idx = wp.tid()
    nu = model.nu[idx]
    scaled_E = 1e7 * model.E[idx]
    model.mu[idx] = scaled_E / (2.0 * (1.0 + nu))
    model.lam[idx] = scaled_E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))


@wp.kernel
def zero_grid(state: MPMStateStruct, model: MPMModelStruct):
    # Reset mass and both velocity fields of one grid node (one thread per node).
    i, j, k = wp.tid()
    zero_vec = wp.vec3(0.0, 0.0, 0.0)
    state.grid_m[i, j, k] = 0.0
    state.grid_v_in[i, j, k] = zero_vec
    state.grid_v_out[i, j, k] = zero_vec


@wp.func
def compute_dweight(
        model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int
):
    # Gradient of the interpolation weight for stencil node (i, j, k):
    # along each axis the weight of that axis is replaced by its derivative,
    # and the result is scaled by model.inv_dx.
    grad_x = dw[0, i] * w[1, j] * w[2, k]
    grad_y = w[0, i] * dw[1, j] * w[2, k]
    grad_z = w[0, i] * w[1, j] * dw[2, k]
    return wp.vec3(grad_x, grad_y, grad_z) * model.inv_dx


@wp.func
def update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):
    # Advance the symmetric covariance of particle p by one explicit step:
    #   cov <- cov + dt * (grad_v * cov + cov * grad_v^T)
    # state.particle_cov stores six values per particle in the order
    # (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).
    base = p * 6
    c00 = state.particle_cov[base]
    c01 = state.particle_cov[base + 1]
    c02 = state.particle_cov[base + 2]
    c11 = state.particle_cov[base + 3]
    c12 = state.particle_cov[base + 4]
    c22 = state.particle_cov[base + 5]

    # expand the packed upper triangle into a full symmetric matrix
    cov_old = wp.mat33(c00, c01, c02, c01, c11, c12, c02, c12, c22)

    cov_new = cov_old + dt * (grad_v * cov_old + cov_old * wp.transpose(grad_v))

    # write back only the upper triangle
    state.particle_cov[base] = cov_new[0, 0]
    state.particle_cov[base + 1] = cov_new[0, 1]
    state.particle_cov[base + 2] = cov_new[0, 2]
    state.particle_cov[base + 3] = cov_new[1, 1]
    state.particle_cov[base + 4] = cov_new[1, 2]
    state.particle_cov[base + 5] = cov_new[2, 2]


@wp.kernel
def p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Particle-to-grid transfer (one thread per particle).
    # input given to p2g:   particle_stress
    #                       particle_x
    #                       particle_v
    #                       particle_C
    # Each selected particle (particle_selection == 0) scatters momentum,
    # the stress impulse and mass to its 3x3x3 grid stencil via atomic adds
    # into state.grid_v_in and state.grid_m.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        stress = state.particle_stress[p]
        grid_pos = state.particle_x[p] * model.inv_dx
        base_pos_x = wp.int(grid_pos[0] - 0.5)
        base_pos_y = wp.int(grid_pos[1] - 0.5)
        base_pos_z = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(
            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)
        )
        # quadratic B-spline weights (w) and their derivatives (dw);
        # w[d, i] is used below as the weight of stencil node i along axis d
        wa = wp.vec3(1.5) - fx
        wb = fx - wp.vec3(1.0)
        wc = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),
            wp.cw_mul(wc, wc) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))

        # Per-particle quantities do not depend on the stencil node, so they
        # are computed once here instead of inside the innermost loop.
        mass = state.particle_mass[p]
        velocity = state.particle_v[p]
        C = state.particle_C[p]
        # rpic_damping == 0 gives standard APIC; larger values blend C
        # towards its skew-symmetric part
        C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (
                C - wp.transpose(C)
        )
        if model.rpic_damping < -0.001:
            # a negative rpic_damping selects standard PIC (no affine term)
            C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        neg_vol_stress = -state.particle_vol[p] * stress

        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    dpos = (
                                   wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                           ) * model.dx
                    ix = base_pos_x + i
                    iy = base_pos_y + j
                    iz = base_pos_z + k
                    weight = w[0, i] * w[1, j] * w[2, k]  # product of the per-axis weights
                    dweight = compute_dweight(model, w, dw, i, j, k)

                    elastic_force = neg_vol_stress * dweight
                    v_in_add = (
                            weight
                            * mass
                            * (velocity + C * dpos)
                            + dt * elastic_force
                    )
                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)
                    wp.atomic_add(
                        state.grid_m, ix, iy, iz, weight * mass
                    )


# add gravity
@wp.kernel
def grid_normalization_and_gravity(
        state: MPMStateStruct, model: MPMModelStruct, dt: float
):
    # Turn accumulated grid momentum into velocity and apply gravity
    # (one thread per grid node). Nodes whose mass is not above 1e-15 are
    # skipped and keep their current grid_v_out value.
    gx, gy, gz = wp.tid()
    node_mass = state.grid_m[gx, gy, gz]
    if node_mass > 1e-15:
        velocity = state.grid_v_in[gx, gy, gz] * (1.0 / node_mass)
        velocity = velocity + dt * model.gravitational_accelaration
        state.grid_v_out[gx, gy, gz] = velocity


@wp.kernel
def g2p(state: MPMStateStruct, model: MPMModelStruct, dt: float):
    # Grid-to-particle transfer (one thread per particle).
    # For each selected particle (particle_selection == 0) this gathers the
    # grid velocity over the 3x3x3 stencil, then updates particle_v,
    # particle_x, particle_C and particle_F_trial, and optionally the
    # particle covariance.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        x_old = state.particle_x[p]
        grid_pos = x_old * model.inv_dx
        bx = wp.int(grid_pos[0] - 0.5)
        by = wp.int(grid_pos[1] - 0.5)
        bz = wp.int(grid_pos[2] - 0.5)
        fx = grid_pos - wp.vec3(wp.float(bx), wp.float(by), wp.float(bz))

        # quadratic B-spline weights (w) and their derivatives (dw);
        # w[d, i] is used below as the weight of stencil node i along axis d
        d_lo = wp.vec3(1.5) - fx
        d_mid = fx - wp.vec3(1.0)
        d_hi = fx - wp.vec3(0.5)
        w = wp.mat33(
            wp.cw_mul(d_lo, d_lo) * 0.5,
            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(d_mid, d_mid) + wp.vec3(0.75),
            wp.cw_mul(d_hi, d_hi) * 0.5,
        )
        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * d_mid, d_hi)

        vel = wp.vec3(0.0, 0.0, 0.0)
        affine = wp.mat33(0.0)
        grad_v = wp.mat33(0.0)
        for i in range(0, 3):
            for j in range(0, 3):
                for k in range(0, 3):
                    node_v = state.grid_v_out[bx + i, by + j, bz + k]
                    # node offset from the particle, in grid-cell units
                    offset = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx
                    weight = w[0, i] * w[1, j] * w[2, k]  # product of the per-axis weights
                    vel = vel + node_v * weight
                    affine = affine + wp.outer(node_v, offset) * (
                            weight * model.inv_dx * 4.0
                    )
                    grad_v = grad_v + wp.outer(
                        node_v, compute_dweight(model, w, dw, i, j, k)
                    )

        state.particle_v[p] = vel
        state.particle_x[p] = x_old + dt * vel
        state.particle_C[p] = affine
        # trial deformation gradient: (I + dt * grad_v) * F
        identity = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        state.particle_F_trial[p] = (identity + grad_v * dt) * state.particle_F[p]

        if model.update_cov_with_F:
            update_cov(state, p, grad_v, dt)


# compute (Kirchhoff) stress = stress(returnMap(F_trial))
@wp.kernel
def compute_stress_from_F_trial(
        state: MPMStateStruct, model: MPMModelStruct, dt: float
):
    # Two steps per selected particle (particle_selection == 0):
    #   1. particle_F <- returnMap(particle_F_trial), chosen by model.material
    #   2. particle_stress <- symmetrised stress evaluated at the new particle_F
    # Material ids handled by the return mapping: 1 metal, 2 sand,
    # 3 viscoplastic, 5 von Mises with damage; any other id is treated as
    # purely elastic. Materials other than 0, 1, 2, 3, 5 get zero stress.
    p = wp.tid()
    if state.particle_selection[p] == 0:
        F_trial = state.particle_F_trial[p]

        # step 1: apply the return mapping
        if model.material == 1:  # metal
            state.particle_F[p] = von_mises_return_mapping(F_trial, model, p)
        elif model.material == 2:  # sand
            state.particle_F[p] = sand_return_mapping(F_trial, state, model, p)
        elif model.material == 3:  # visplas, with StVk+VM, no thickening
            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(
                F_trial, model, p, dt
            )
        elif model.material == 5:
            state.particle_F[p] = von_mises_return_mapping_with_damage(
                F_trial, model, p
            )
        else:  # elastic
            state.particle_F[p] = F_trial

        # step 2: evaluate the stress at the mapped deformation gradient
        F_new = state.particle_F[p]
        J = wp.determinant(F_new)
        U = wp.mat33(0.0)
        V = wp.mat33(0.0)
        sig = wp.vec3(0.0)
        wp.svd3(F_new, U, sig, V)

        mu = model.mu[p]
        lam = model.lam[p]
        stress = wp.mat33(0.0)
        if model.material == 0 or model.material == 5:
            stress = kirchoff_stress_FCR(F_new, U, V, J, mu, lam)
        elif model.material == 1 or model.material == 3:
            # material 3 temporarily uses StVK as well, subject to change
            stress = kirchoff_stress_StVK(F_new, U, V, sig, mu, lam)
        elif model.material == 2:
            stress = kirchoff_stress_drucker_prager(F_new, U, V, sig, mu, lam)

        # enforce symmetry
        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0


@wp.kernel
def compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):
    # Per particle: cov = F * cov0 * F^T, where F is particle_F_trial[p] and
    # cov0 is the symmetric matrix packed in particle_init_cov.
    # Both covariance arrays store six values per particle in the order
    # (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).
    p = wp.tid()

    F = state.particle_F_trial[p]
    base = p * 6
    s00 = state.particle_init_cov[base]
    s01 = state.particle_init_cov[base + 1]
    s02 = state.particle_init_cov[base + 2]
    s11 = state.particle_init_cov[base + 3]
    s12 = state.particle_init_cov[base + 4]
    s22 = state.particle_init_cov[base + 5]

    # t_r = cov0 * (row r of F), shared by every output entry that uses row r
    t0 = wp.vec3(
        s00 * F[0, 0] + s01 * F[0, 1] + s02 * F[0, 2],
        s01 * F[0, 0] + s11 * F[0, 1] + s12 * F[0, 2],
        s02 * F[0, 0] + s12 * F[0, 1] + s22 * F[0, 2],
    )
    t1 = wp.vec3(
        s00 * F[1, 0] + s01 * F[1, 1] + s02 * F[1, 2],
        s01 * F[1, 0] + s11 * F[1, 1] + s12 * F[1, 2],
        s02 * F[1, 0] + s12 * F[1, 1] + s22 * F[1, 2],
    )
    t2 = wp.vec3(
        s00 * F[2, 0] + s01 * F[2, 1] + s02 * F[2, 2],
        s01 * F[2, 0] + s11 * F[2, 1] + s12 * F[2, 2],
        s02 * F[2, 0] + s12 * F[2, 1] + s22 * F[2, 2],
    )

    # each entry is (row a of F) . t_b
    state.particle_cov[base] = F[0, 0] * t0[0] + F[0, 1] * t0[1] + F[0, 2] * t0[2]
    state.particle_cov[base + 1] = F[1, 0] * t0[0] + F[1, 1] * t0[1] + F[1, 2] * t0[2]
    state.particle_cov[base + 2] = F[2, 0] * t0[0] + F[2, 1] * t0[1] + F[2, 2] * t0[2]
    state.particle_cov[base + 3] = F[1, 0] * t1[0] + F[1, 1] * t1[1] + F[1, 2] * t1[2]
    state.particle_cov[base + 4] = F[2, 0] * t1[0] + F[2, 1] * t1[1] + F[2, 2] * t1[2]
    state.particle_cov[base + 5] = F[2, 0] * t2[0] + F[2, 1] * t2[1] + F[2, 2] * t2[2]


@wp.kernel
def compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):
    # Per particle: extract the rotation R = U * V^T from the SVD of
    # particle_F_trial[p] and store its transpose in particle_R[p].
    p = wp.tid()

    F_trial = state.particle_F_trial[p]

    left = wp.mat33(0.0)
    right = wp.mat33(0.0)
    singular = wp.vec3(0.0)
    wp.svd3(F_trial, left, singular, right)

    # negate the third column of a factor with negative determinant so that
    # the product below is a proper rotation
    if wp.determinant(left) < 0.0:
        left[0, 2] = -left[0, 2]
        left[1, 2] = -left[1, 2]
        left[2, 2] = -left[2, 2]

    if wp.determinant(right) < 0.0:
        right[0, 2] = -right[0, 2]
        right[1, 2] = -right[1, 2]
        right[2, 2] = -right[2, 2]

    rotation = left * wp.transpose(right)
    state.particle_R[p] = wp.transpose(rotation)


@wp.kernel
def add_damping_via_grid(state: MPMStateStruct, scale: float):
    # Multiply the output velocity of one grid node by `scale`.
    i, j, k = wp.tid()
    damped = state.grid_v_out[i, j, k] * scale
    state.grid_v_out[i, j, k] = damped


@wp.kernel
def apply_additional_params(
        state: MPMStateStruct,
        model: MPMModelStruct,
        params_modifier: MaterialParamsModifier,
):
    # Override E, nu and density of every particle strictly inside the
    # axis-aligned box centred at params_modifier.point with half-extents
    # params_modifier.size.
    p = wp.tid()
    pos = state.particle_x[p]
    center = params_modifier.point
    half = params_modifier.size
    if (
            pos[0] > center[0] - half[0]
            and pos[0] < center[0] + half[0]
            and pos[1] > center[1] - half[1]
            and pos[1] < center[1] + half[1]
            and pos[2] > center[2] - half[2]
            and pos[2] < center[2] + half[2]
    ):
        model.E[p] = params_modifier.E
        model.nu[p] = params_modifier.nu
        state.particle_density[p] = params_modifier.density


@wp.kernel
def selection_add_impulse_on_particles(
        state: MPMStateStruct, impulse_modifier: Impulse_modifier
):
    # Fill impulse_modifier.mask: 1 for particles strictly inside the
    # axis-aligned box (centre `point`, half-extents `size`), 0 otherwise.
    p = wp.tid()
    rel = state.particle_x[p] - impulse_modifier.point
    half = impulse_modifier.size
    if (
            wp.abs(rel[0]) < half[0]
            and wp.abs(rel[1]) < half[1]
            and wp.abs(rel[2]) < half[2]
    ):
        impulse_modifier.mask[p] = 1
    else:
        impulse_modifier.mask[p] = 0


@wp.kernel
def selection_enforce_particle_velocity_translation(
        state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier
):
    # Fill velocity_modifier.mask: 1 for particles strictly inside the
    # axis-aligned box (centre `point`, half-extents `size`), 0 otherwise.
    p = wp.tid()
    rel = state.particle_x[p] - velocity_modifier.point
    half = velocity_modifier.size
    if (
            wp.abs(rel[0]) < half[0]
            and wp.abs(rel[1]) < half[1]
            and wp.abs(rel[2]) < half[2]
    ):
        velocity_modifier.mask[p] = 1
    else:
        velocity_modifier.mask[p] = 0


@wp.kernel
def selection_enforce_particle_velocity_cylinder(
        state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier
):
    # Fill velocity_modifier.mask: 1 for particles strictly inside the
    # cylinder centred at `point` with axis `normal`, 0 otherwise.
    # half_height_and_radius[0] bounds the distance along the axis and
    # half_height_and_radius[1] bounds the distance from the axis.
    p = wp.tid()
    rel = state.particle_x[p] - velocity_modifier.point
    axis = velocity_modifier.normal

    along_axis = wp.dot(rel, axis)
    vertical_distance = wp.abs(along_axis)
    # remove the axial component to get the radial offset
    horizontal_distance = wp.length(rel - along_axis * axis)

    limits = velocity_modifier.half_height_and_radius
    if vertical_distance < limits[0] and horizontal_distance < limits[1]:
        velocity_modifier.mask[p] = 1
    else:
        velocity_modifier.mask[p] = 0


@wp.kernel()
def sum_array(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    # Accumulate the element-wise product x[i] * y[i] into z[0].
    i = wp.tid()
    product = x[i] * y[i]
    wp.atomic_add(z, 0, product)


@wp.kernel()
def sum_vec3(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=wp.vec3), z: wp.array(dtype=float)):
    # Accumulate the component-wise product sum of x[i] and y[i] into z[0].
    i = wp.tid()
    a = x[i]
    b = y[i]
    wp.atomic_add(z, 0, a[0] * b[0] + a[1] * b[1] + a[2] * b[2])


@wp.kernel()
def sum_mat33(x: wp.array(dtype=wp.mat33), y: wp.array(dtype=wp.mat33), z: wp.array(dtype=float)):
    # Accumulate the entry-wise product sum of x[i] and y[i] into z[0].
    i = wp.tid()
    a = x[i]
    b = y[i]
    total = (
        a[0, 0] * b[0, 0] + a[0, 1] * b[0, 1] + a[0, 2] * b[0, 2]
        + a[1, 0] * b[1, 0] + a[1, 1] * b[1, 1] + a[1, 2] * b[1, 2]
        + a[2, 0] * b[2, 0] + a[2, 1] * b[2, 1] + a[2, 2] * b[2, 2]
    )
    wp.atomic_add(z, 0, total)


@wp.kernel()
def set_value(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Element-wise copy of y into x (one thread per element).
    i = wp.tid()
    source = y[i]
    x[i] = source

@wp.kernel()
def set_value_vec(x: wp.array(dtype=wp.vec3f), y: wp.array(dtype=wp.vec3f)):
    # Element-wise copy of the vec3f array y into x (one thread per element).
    i = wp.tid()
    source = y[i]
    x[i] = source


@wp.kernel()
def set_value_E(E: wp.array(dtype=float), target: float, ee_mask: wp.array(dtype=int)):
    # Set E[i] to `target` for every element whose ee_mask entry is 0;
    # other elements are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        E[i] = target

@wp.kernel
def set_value_E_cloth_mask(E_cloth: wp.array(dtype=wp.vec3), target: wp.vec3, ee_mask: wp.array(dtype=int)):
    # Set E_cloth[i] to `target` for every element whose ee_mask entry is 0;
    # other elements are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        E_cloth[i] = target

@wp.kernel()
def set_value_friction(friction: wp.array(dtype=float), target: float):
    # Store `target` in friction[0]. The thread index is not used, so every
    # launched thread writes the same value to the same slot.
    friction[0] = target

@wp.kernel()
def set_value_friction_array(friction: wp.array(dtype=float), target: wp.array(dtype=float)):
    # Copy target[0] into friction[0]. The thread index is not used, so every
    # launched thread performs the same write.
    value = target[0]
    friction[0] = value

def wp_clamp(x, min, max):
    """Clamp the Warp float array ``x`` in place to ``[min, max]``.

    The clamped values are computed with ``torch.clamp`` on a torch view of
    ``x`` and copied back into ``x`` with the ``set_value`` kernel, launched
    with one thread per element of the first dimension on ``x.device``.

    Args:
        x: Warp array of floats; it is overwritten.
        min: Lower bound forwarded to ``torch.clamp``.
        max: Upper bound forwarded to ``torch.clamp``.
    """
    clamped = torch.clamp(wp.to_torch(x), min, max)
    wp.launch(
        kernel=set_value,
        dim=x.shape[0],
        inputs=[x, wp.from_torch(clamped)],
        device=x.device,
    )


@wp.kernel()
def update_param_single(param: wp.array(dtype=float), grad: float, lr: float):
    # Gradient-descent step with one scalar gradient shared by all elements:
    #   param[i] <- param[i] - lr * grad
    i = wp.tid()
    step = lr * grad
    param[i] = param[i] - step

@wp.kernel()
def set_value(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # NOTE(review): second definition of `set_value` in this file. Its body is
    # identical to the earlier one, so this rebinding looks redundant.
    # Element-wise copy of y into x (one thread per element).
    tid = wp.tid()
    x[tid] = y[tid]

@wp.kernel()
def set_value_mask(x: wp.array(dtype=float), y: wp.array(dtype=float), ee_mask: wp.array(dtype=int)):
    # Copy y[i] into x[i] for every element whose ee_mask entry is 0;
    # other elements of x are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        source = y[i]
        x[i] = source

@wp.kernel()
def set_value_mask_logE(x: wp.array(dtype=float), y: wp.array(dtype=float), ee_mask: wp.array(dtype=int)):
    # Write exp(y[i]) into x[i] for every element whose ee_mask entry is 0;
    # other elements of x are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        log_value = y[i]
        x[i] = wp.exp(log_value)

@wp.kernel()
def set_value_mask_logE_single(E: wp.array(dtype=float), target: float, ee_mask: wp.array(dtype=int)):
    # Write exp(target) into E[i] for every element whose ee_mask entry is 0;
    # other elements of E are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        value = wp.exp(target)
        E[i] = value

@wp.kernel
def set_E_cloth_from_log_stiffness(
    E_cloth: wp.array(dtype=wp.vec3),
    log_warp_stiffness: wp.array(dtype=float),
    log_weft_stiffness: wp.array(dtype=float),
    log_shear_stiffness: wp.array(dtype=float),
    ee_mask: wp.array(dtype=int)
):
    # For every element whose ee_mask entry is 0, set
    #   E_cloth[i] = (exp(log_warp), exp(log_weft), exp(log_shear))
    # from the per-element log-stiffness arrays.
    i = wp.tid()
    if ee_mask[i] == 0:
        E_cloth[i] = wp.vec3(
            wp.exp(log_warp_stiffness[i]),
            wp.exp(log_weft_stiffness[i]),
            wp.exp(log_shear_stiffness[i]),
        )

@wp.kernel
def set_E_cloth_from_log_stiffness_single(
    E_cloth: wp.array(dtype=wp.vec3),
    log_warp_stiffness: float,
    log_weft_stiffness: float,
    log_shear_stiffness: float,
    ee_mask: wp.array(dtype=int)
):
    # For every element whose ee_mask entry is 0, set
    #   E_cloth[i] = (exp(log_warp), exp(log_weft), exp(log_shear))
    # using one scalar log-stiffness triple shared by all elements.
    i = wp.tid()
    if ee_mask[i] == 0:
        E_cloth[i] = wp.vec3(
            wp.exp(log_warp_stiffness),
            wp.exp(log_weft_stiffness),
            wp.exp(log_shear_stiffness),
        )


@wp.kernel()
def set_value_mask_density(x: wp.array(dtype=float), y: wp.array(dtype=float), ee_mask: wp.array(dtype=int)):
    # Copy y[i] into x[i] for every element whose ee_mask entry is 0;
    # other elements of x are left untouched.
    i = wp.tid()
    if ee_mask[i] == 0:
        source = y[i]
        x[i] = source

@wp.kernel()
def set_value_mask_logE_fps(x: wp.array(dtype=float),
                            y: wp.array(dtype=float),
                            ee_mask: wp.array(dtype=int),
                            nearest_indices: wp.array(dtype=int)):
    # For every element whose ee_mask entry is 0, look up y at
    # nearest_indices[i] and write its exponential into x[i].
    i = wp.tid()
    if ee_mask[i] == 0:
        source_index = nearest_indices[i]
        x[i] = wp.exp(y[source_index])

@wp.kernel()
def update_param(param: wp.array(dtype=float), grad: wp.array(dtype=float), ee_mask: wp.array(dtype=wp.int32), lr: float):
    # Gradient step taken on log10(param):
    #   log10(param[i]) <- clamp(log10(param[i]) - lr * grad[i], 2, 6)
    # Elements whose ee_mask entry is 1 are skipped. The clamp keeps param
    # within [1e2, 1e6]; an alternative range of [-3, -0.4] was used earlier.
    i = wp.tid()
    if ee_mask[i] == 1:
        return
    exponent = wp.log10(param[i])
    exponent = exponent - grad[i] * lr
    exponent = wp.clamp(exponent, 2., 6.)
    param[i] = wp.pow(10., exponent)

@wp.kernel()
def update_param_normal(param: wp.array(dtype=float), grad: wp.array(dtype=float), lr: float):
    # Plain gradient-descent step with a per-element gradient:
    #   param[i] <- param[i] - grad[i] * lr
    i = wp.tid()
    param[i] = param[i] - grad[i] * lr

@wp.kernel()
def update_param_normal_global(param: wp.array(dtype=float), grad: float, lr: float):
    # Plain gradient-descent step with one scalar gradient for all elements:
    #   param[i] <- param[i] - grad * lr
    i = wp.tid()
    param[i] = param[i] - grad * lr


@wp.kernel
def assign_particle_to_grid(
    pos: wp.array(dtype=wp.vec3),
    grid: wp.array(dtype=int),
    grid_n: int,
    grid_dx: float
):
    # Count particles per cell of a grid_n^3 grid with cell size grid_dx.
    # `grid` is a flat array indexed as i + j * grid_n + k * grid_n^2 and is
    # incremented atomically. Particles outside the grid are ignored.
    pi = wp.tid()
    p = pos[pi]
    i = int(wp.floor(p[0] / grid_dx))
    j = int(wp.floor(p[1] / grid_dx))
    k = int(wp.floor(p[2] / grid_dx))
    # The bounds are written as explicit pairwise comparisons so the check
    # does not depend on how the Warp code generator lowers chained
    # comparisons such as `0 <= i < grid_n`.
    inside = (
        i >= 0 and i < grid_n
        and j >= 0 and j < grid_n
        and k >= 0 and k < grid_n
    )
    if inside:
        index = i + j * grid_n + k * grid_n * grid_n
        wp.atomic_add(grid, index, 1)




@wp.kernel
def compute_particle_volume(
    pos: wp.array(dtype=wp.vec3),
    grid: wp.array(dtype=int),
    particle_vol: wp.array(dtype=float),
    grid_n: int,
    grid_dx: float
):
    # Give each particle an equal share of its cell's volume:
    #   particle_vol[pi] = grid_dx^3 / (number of particles in the cell)
    # `grid` holds the per-cell counts, indexed as i + j * grid_n + k * grid_n^2.
    # Particles outside the grid, or in a cell with a zero count, keep their
    # current particle_vol value.
    pi = wp.tid()
    p = pos[pi]
    i = int(wp.floor(p[0] / grid_dx))
    j = int(wp.floor(p[1] / grid_dx))
    k = int(wp.floor(p[2] / grid_dx))
    # The bounds are written as explicit pairwise comparisons so the check
    # does not depend on how the Warp code generator lowers chained
    # comparisons such as `0 <= i < grid_n`.
    inside = (
        i >= 0 and i < grid_n
        and j >= 0 and j < grid_n
        and k >= 0 and k < grid_n
    )
    if inside:
        index = i + j * grid_n + k * grid_n * grid_n
        if grid[index] > 0:
            particle_vol[pi] = (grid_dx * grid_dx * grid_dx) / float(grid[index])

def get_particle_volume(pos, grid_n: int, grid_dx: float, uniform: bool = False,device = "cuda:0"):
    """Estimate a volume for every particle from its grid-cell occupancy.

    Particles are binned into a ``grid_n``^3 grid with cell size ``grid_dx``;
    each particle gets the cell volume divided by the number of particles in
    its cell. Particles outside the grid get a volume of 0.

    Args:
        pos: Torch tensor of particle positions. It is detached and viewed as
            ``(-1, 3)``, so any shape whose element count is a multiple of 3
            is accepted.
        grid_n: Number of cells along each axis.
        grid_dx: Edge length of one cell.
        uniform: If True, return the mean volume repeated once per particle.
            The mean includes the zeros of particles outside the grid.
        device: Warp device for the temporary arrays and kernel launches.
            NOTE(review): ``pos`` presumably has to live on this same device;
            confirm against the callers.

    Returns:
        A 1-D torch tensor with one volume per particle.
    """
    # Flatten to (N, 3) first and take N from the flattened tensor, so the
    # launch size always matches the number of points handed to the kernels.
    points = pos.detach().reshape(-1, 3).contiguous()
    num_points = points.shape[0]

    pos_warp = wp.from_torch(points, dtype=wp.vec3)
    grid = wp.zeros(shape=(grid_n * grid_n * grid_n), dtype=int, device=device)
    particle_vol = wp.zeros(shape=num_points, dtype=float, device=device)

    # pass 1: count particles per cell
    wp.launch(
        kernel=assign_particle_to_grid,
        dim=num_points,
        inputs=[pos_warp, grid, grid_n, grid_dx],
        device=device
    )
    # pass 2: split each cell's volume among its particles
    wp.launch(
        kernel=compute_particle_volume,
        dim=num_points,
        inputs=[pos_warp, grid, particle_vol, grid_n, grid_dx],
        device=device
    )
    particle_vol_torch = wp.to_torch(particle_vol)

    if uniform:
        return torch.mean(particle_vol_torch).repeat(num_points)
    return particle_vol_torch


