import sys

sys.path.append("gaussian-splatting")
import random
import argparse
import math
import cv2
import torch
import os
import numpy as np
import json
from tqdm import tqdm
import itertools
import time

os.environ["CUDA_VISIBLE_DEVICES"] = '1'

from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import pyvista as pv

# Gaussian splatting dependencies
from utils.sh_utils import eval_sh
from scene.gaussian_model import GaussianModel
from diff_gaussian_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from scene.cameras import Camera as GSCamera
from gaussian_renderer import render, GaussianModel
from utils.system_utils import searchForMaxIteration
from utils.graphics_utils import focal2fov

# MPM dependencies
from mpm_solver_warp.engine_utils import *
from mpm_solver_warp.mpm_solver_warp import MPM_Simulator_WARP
import warp as wp
# from mpm_utils import get_5_element, get_10_element

# Particle filling dependencies
from particle_filling.filling import *

# Utils
from utils.decode_param import *
from utils.transformation_utils import *
from utils.camera_view_utils import *
from utils.render_utils import *

# Initialise the Warp runtime used by MPM_Simulator_WARP. verify_cuda makes
# Warp check for CUDA errors after launches (slower, but a failure is
# reported at the offending launch).
wp.init()
wp.config.verify_cuda = True

# Taichi runtime on CUDA with an 8 GB device memory pool.
# NOTE(review): `ti` is not imported explicitly in this file; presumably it
# arrives through one of the `from ... import *` lines above -- confirm.
ti.init(arch=ti.cuda, device_memory_GB=8.0)


# ti.init(arch=ti.cuda, device_memory_fraction=1.0)

def convert_one_gs_SH(pos_all, shs_all, rot):
    """Debug helper: print the SH colour of a single Gaussian seen from four sides.

    Colours are evaluated with ``convert_SH_wzy2`` (SH degree 3) for four
    camera positions at distance 5 along +y, -y, +x and -x. The origin is
    passed as the position argument. Each result is clamped to [0, 1], and
    the entry at index 10000 is printed scaled by 255, after
    ``pos_all[10000]``.

    ``rot`` is accepted for call compatibility but is not used.
    """
    probe = 10000
    sh_degree = 3
    origin = torch.tensor([0.0, 0.0, 0.0]).to(device="cuda:0")
    # (label printed for the view, camera position)
    eye_setups = (
        (" view1: ", (0.0, 5.0, 0.0)),
        (" view2: ", (0.0, -5.0, 0.0)),
        (" view3: ", (5.0, 0.0, 0.0)),
        (" view4: ", (-5.0, 0.0, 0.0)),
    )

    clamped_views = []
    for label, eye in eye_setups:
        eye_tensor = torch.tensor(eye).to(device="cuda:0")
        raw = convert_SH_wzy2(shs_all, eye_tensor, sh_degree, origin)
        clamped_views.append((label, torch.clamp(raw, 0.0, 1.0)))

    print("index: ", probe, " pos: ", pos_all[probe])
    for label, colours in clamped_views:
        print("index: ", probe, label, colours[probe] * 255)


class PipelineParamsNoparse:
    """Pipeline switches with the same names as PipelineParams, built without argparse.

    Every switch starts disabled; callers flip individual flags after
    construction (e.g. ``compute_cov3D_python = True``).
    """

    def __init__(self):
        # All three switches default to off.
        self.convert_SHs_python = self.compute_cov3D_python = self.debug = False


def load_checkpoint(model_path, sh_degree=3, iteration=-1):
    """Load a trained Gaussian model from disk.

    Reads ``<model_path>/point_cloud/iteration_<N>/point_cloud.ply``.
    ``N`` is ``iteration``, or, when ``iteration == -1``, the value returned
    by ``searchForMaxIteration`` for the ``point_cloud`` directory.

    Args:
        model_path: root directory of the trained scene.
        sh_degree: spherical-harmonics degree passed to ``GaussianModel``.
        iteration: checkpoint iteration to load; -1 selects the latest one.

    Returns:
        The ``GaussianModel`` populated via ``load_ply``.
    """
    point_cloud_root = os.path.join(model_path, "point_cloud")
    chosen = searchForMaxIteration(point_cloud_root) if iteration == -1 else iteration
    ply_file = os.path.join(point_cloud_root, f"iteration_{chosen}", "point_cloud.ply")

    model = GaussianModel(sh_degree)
    model.load_ply(ply_file)
    return model


def get_water_volume(particle_num, n_grid, grid_lim):
    """Return an array of ``particle_num`` identical water-particle volumes.

    Each particle gets the volume of a cube whose side is half the grid
    spacing ``grid_lim / n_grid``. The first value is printed for
    inspection.
    """
    half_cell = float(grid_lim / n_grid) / 2.0
    volumes = np.full(particle_num, half_cell ** 3)
    print("water volume: ", volumes[0])
    return volumes


def get_pumpkin_volume(particle_num, n_grid, grid_lim, per_particle_volume=6.4e-6):
    """Return an array of ``particle_num`` identical pumpkin-particle volumes.

    Unlike ``get_water_volume``, the volume does not depend on the grid.
    ``n_grid`` and ``grid_lim`` are unused and kept only for call-site
    compatibility, so ``n_grid == 0`` is accepted.

    Args:
        particle_num: number of particles.
        n_grid: unused.
        grid_lim: unused.
        per_particle_volume: volume assigned to every particle
            (defaults to the previously hard-coded 6.4e-6).

    Returns:
        1-D numpy array of length ``particle_num``. The first value is
        printed for inspection.
    """
    volume = np.full(particle_num, per_particle_volume)
    print("pumpkin volume: ", volume[0])
    return volume


def rgb_to_hsv(rgb):
    """Convert an ``(r, g, b)`` triple on the 0-255 scale to ``(h, s, v)``.

    Hue is in degrees in [0, 360). For inputs within 0-255, saturation and
    value fall in [0, 1]. Colours with all channels equal get hue 0, and
    black gets saturation 0.
    """
    red8, green8, blue8 = rgb
    red, green, blue = red8 / 255.0, green8 / 255.0, blue8 / 255.0
    peak = max(red, green, blue)
    trough = min(red, green, blue)
    spread = peak - trough

    if spread == 0:
        hue = 0
    else:
        # The dominant channel selects the hue sector; red wins ties, then green.
        if peak == red:
            base, diff = 360, green - blue
        elif peak == green:
            base, diff = 120, blue - red
        else:
            base, diff = 240, red - green
        hue = (60 * (diff / spread) + base) % 360

    saturation = 0 if peak == 0 else spread / peak
    return hue, saturation, peak


def is_green(rgb):
    """Return whether a colour is predominantly green.

    ``rgb`` is unpacked as ``(r, g, b)``; channels are assumed to be on the
    0-255 scale. A colour counts as green when the green channel exceeds
    both red and blue by more than ``margin`` (0.5) and is above
    ``threshold`` (50).

    The tests are combined with ``&`` (as in ``is_grey``) rather than
    ``and``. Scalars therefore behave as before, and channel-first arrays
    (e.g. shape ``(3, N)``) also work, returning an element-wise mask.
    """
    r, g, b = rgb
    margin = 0.5
    threshold = 50.0
    return (g > r + margin) & (g > b + margin) & (g > threshold)


def is_grey(rgb):
    """Heuristic colour test based on channel ordering and brightness.

    ``rgb`` is unpacked as ``(r, g, b)``; channels are assumed to be on the
    0-255 scale. All four conditions must hold:

    1. red is highest, green second, blue lowest (``r > g > b``);
    2. red and green differ by less than 30, which rules out pure reds;
    3. blue is below the mean of red and green;
    4. mean brightness lies strictly between 100 and 200.

    Note that a perfectly neutral grey (``r == g == b``) fails condition 1.
    The conditions are combined with ``&``, so channel-first arrays
    (e.g. shape ``(3, N)``) yield an element-wise mask.
    """
    r, g, b = rgb
    # Promote to floating point so unsigned-integer input (e.g. uint8 pixels)
    # cannot wrap around in ``r - g`` or overflow in ``r + g`` / ``r + g + b``.
    r, g, b = (channel * 1.0 for channel in (r, g, b))
    threshold = 30

    # Condition 1: red highest, green next, blue lowest.
    condition1 = (r > g) & (g > b)

    # Condition 2: red and green close together, to avoid pure red.
    condition2 = np.abs(r - g) < threshold

    # Condition 3: blue noticeably lower.
    condition3 = b < (r + g) / 2

    # Condition 4: brightness within a limited range.
    brightness = (r + g + b) / 3
    condition4 = (brightness > 100) & (brightness < 200)
    return condition1 & condition2 & condition3 & condition4


def get_branch_bool(gs_path):
    points_pv = pv.read(gs_path)
    points = points_pv.points

    # DBSCAN 聚类
    db = DBSCAN(eps=0.007, min_samples=10).fit(points)
    labels = db.labels_

    class_z_ranges = {}
    for label in np.unique(labels):
        # 获取当前类的所有点
        class_points = points[labels == label]

        # 获取该类的所有 z 值
        z_values = class_points[:, 2]

        # 计算该类的 z 值范围
        z_range = z_values.max() - z_values.min()

        # 存储该类的 z 值范围
        class_z_ranges[label] = z_range
    # 对类的 z 值范围进行排序，获取第二大的 z 值范围
    sorted_classes = sorted(class_z_ranges.items(), key=lambda item: item[1], reverse=True)

    # 获取第二大的类
    second_max_z_class = sorted_classes[1][0]
    print(f"The class with the second maximum z-value range is: {second_max_z_class}")
    is_label1 = labels == second_max_z_class
    return is_label1


def get_vasedeck_bool(points):
    """Return a boolean mask selecting the points DBSCAN assigns to cluster 0.

    Clustering uses ``eps=0.025`` and ``min_samples=10``. Points in any other
    cluster, and noise points (label -1), map to False.
    """
    clustering = DBSCAN(eps=0.025, min_samples=10)
    cluster_ids = clustering.fit(points).labels_
    return cluster_ids == 0


if __name__ == "__main__":
    # Script entry point:
    #   1. Load a trained Gaussian-splatting scene.
    #   2. Label its particles (0 = other, 1 = upper leaves, 2 = branches).
    #   3. For every combination of the material parameters listed below, run
    #      an MPM simulation with optional per-frame export and rendering.
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default=None)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--output_ply", action="store_true")
    parser.add_argument("--output_h5", action="store_true")
    parser.add_argument("--render_img", action="store_true")
    parser.add_argument("--compile_video", action="store_true")
    parser.add_argument("--white_bg", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # Point cloud used for DBSCAN branch clustering.
    # NOTE(review): this is pinned to iteration_60000, whereas load_checkpoint()
    # below loads the latest saved iteration. The per-point masks only line up
    # if both are the same cloud.
    gs_path = os.path.join(args.model_path, "point_cloud", "iteration_60000", "point_cloud.ply")
    # gs_path = r"/home/swu/wzy/PhysGaussian/model/ficus_gic/point_cloud/iteration_40000/point_cloud.ply"
    # gs_path = r"/home/swu/wzy/PhysGaussian/model/ficus_gic_dizuo_zhongkong/point_cloud/iteration_1000/point_cloud.ply"
    # gs_path = r"/home/swu/wzy/PhysGaussian/model/ficus_gic_not_zhongkong_interpolated/point_cloud/iteration_1000/point_cloud.ply"
    # gs_path = r"/home/swu/wzy/PhysGaussian/model/ficus-bottom-surface/point_cloud/iteration_40000/point_cloud.ply"

    pumpkin_path = r"/home/swu/wzy/PhysGaussian/model/pumpkin"
    # Material-parameter sweep: one simulation is run per element of the
    # Cartesian product of E, nu, beta, Jp, kesai and density (see
    # `combinations` below). `scale` is not part of the sweep.
    scale = [1]
    E = [1e6]
    nu = [0.39]
    # beta = [0.5, 0.8, 1.0, 1.5, 2.0]
    beta = [2]
    Jp = [0.94]
    kesai = [3]
    density = [5]

    # Particles whose mpm_crack_points value exceeds maxlogJp count as cracked.
    maxlogJp = 0
    # Experiment toggles. Several of these are not read anywhere in this script.
    remove_big_logJp = False
    only_branch = False
    low_height = True
    ficus_fill = False
    # create_water_box_gaussian()
    remove_chair_points = False
    vasedeck_clip = False
    use_water = False
    additional_particle = False
    distinguish = False
    get_branch = True
    add_pumpkin = False
    remove_internal_points = False

    if not os.path.exists(args.model_path):
        raise AssertionError("Scene model does not exist!")
    if not os.path.exists(args.config):
        raise AssertionError("Scene config does not exist!")
    # Every result is written below output_path (see result_path), so fail here
    # rather than with a TypeError after the expensive loading and clustering.
    if args.output_path is None:
        raise AssertionError("Output path is required (pass --output_path)!")
    if args.output_path and not os.path.exists(args.output_path):
        os.makedirs(args.output_path, exist_ok=True)

    # load scene config
    print("Loading scene config...")
    (
        material_params,
        bc_params,
        time_params,
        preprocessing_params,
        camera_params,
        layer_num,
        layer_points
    ) = decode_param_json_water(args.config)

    # load gaussians
    print("Loading gaussians...")
    model_path = args.model_path
    gaussians = load_checkpoint(model_path)
    radius = gaussians.get_scaling.squeeze()
    pipeline = PipelineParamsNoparse()
    pipeline.compute_cov3D_python = True

    pipeline_water = PipelineParamsNoparse()
    pipeline_water.compute_cov3D_python = True

    background = (
        torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
        if args.white_bg
        else torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda")
    )

    # init the scene
    print("Initializing scene and pre-processing...")
    # use_gic_gaussian = True
    # if use_gic_gaussian:
    # gaussians._scaling = gaussians._scaling.repeat(1, 3)

    # gaussians.save_ply("model/pumpkin/pumpkin-gs-pic.ply")
    params = load_params_from_gs(gaussians, pipeline)

    init_pos = params["pos"]
    init_cov = params["cov3D_precomp"]
    init_screen_points = params["screen_points"]
    init_opacity = params["opacity"]
    init_shs = params["shs"]
    # print("max opacity: ", max(init_opacity))
    # print("min opacity: ", min(init_opacity))

    solid_num = init_pos.shape[0]

    use_throw_opacity = False
    # throw away low opacity kernels
    if use_throw_opacity:
        mask = init_opacity[:, 0] > preprocessing_params["opacity_threshold"]
        init_pos = init_pos[mask, :]
        init_cov = init_cov[mask, :]
        init_opacity = init_opacity[mask, :]
        init_screen_points = init_screen_points[mask, :]
        init_shs = init_shs[mask, :]

    # rotate and translate object
    if args.debug:
        if not os.path.exists("./log"):
            os.makedirs("./log")
        particle_position_tensor_to_ply(
            init_pos,
            "./log/init_particles.ply",
        )
    rotation_matrices = generate_rotation_matrices(
        torch.tensor(preprocessing_params["rotation_degree"]),
        preprocessing_params["rotation_axis"],
    )
    rotated_pos = apply_rotations(init_pos, rotation_matrices)
    # print(rotated_pos[999])
    if args.debug:
        particle_position_tensor_to_ply(rotated_pos, "./log/rotated_particles.ply")

    # select a sim area and save params of unselected particles
    unselected_pos, unselected_cov, unselected_opacity, unselected_shs = (
        None,
        None,
        None,
        None,
    )

    if preprocessing_params["sim_area"] is not None:
        # sim_area is [xmin, xmax, ymin, ymax, zmin, zmax] in rotated space.
        boundary = preprocessing_params["sim_area"]
        assert len(boundary) == 6
        mask = torch.ones(rotated_pos.shape[0], dtype=torch.bool).to(device="cuda")
        for i in range(3):
            mask = torch.logical_and(mask, rotated_pos[:, i] > boundary[2 * i])
            mask = torch.logical_and(mask, rotated_pos[:, i] < boundary[2 * i + 1])

        unselected_pos = init_pos[~mask, :]
        unselected_cov = init_cov[~mask, :]
        unselected_opacity = init_opacity[~mask, :]
        unselected_shs = init_shs[~mask, :]

        rotated_pos = rotated_pos[mask, :]
        init_cov = init_cov[mask, :]
        init_opacity = init_opacity[mask, :]
        init_shs = init_shs[mask, :]
        init_screen_points = init_screen_points[mask, :]

    transformed_pos, scale_origin, original_mean_pos = transform2origin(rotated_pos)
    # particle_position_tensor_to_ply(
    #     transformed_pos,
    #     "./log/2.13-ficus_normalized.ply",
    # )
    norm_pos = transformed_pos
    # Per-particle label: 0 = other, 1 = upper leaves, 2 = branches.
    particle_type = np.zeros((norm_pos.shape[0]))
    if get_branch:
        print("begin cluster")
        # NOTE(review): branch_bool has one entry per point of the file at
        # gs_path and is combined element-wise with norm_pos, which may have
        # been filtered by sim_area above. The lengths must match.
        branch_bool = get_branch_bool(gs_path)
        # leaf_and_branch = mpm_init_pos[mpm_init_pos[:, 2] > 0.3]
        # leaf_and_branch = leaf_and_branch[leaf_and_branch[:, 2] < 1.5]
        branch_bool = torch.from_numpy(branch_bool).to(device="cuda:0")
        if additional_particle:
            # NOTE(review): pumpkin_pos is not defined anywhere in this script;
            # this branch only works while additional_particle stays False.
            False_rows = torch.zeros(pumpkin_pos.shape[0], dtype=torch.bool).to(device="cuda:0")
            branch_bool = torch.cat([branch_bool, False_rows]).to(device="cuda:0")
        # leaf_and_branch_bool = (mpm_init_pos[:, 2] > 0.75) & (mpm_init_pos[:, 2] < 1.5)
        # branch_up_bool = leaf_and_branch_bool & branch_bool

        # leaf_bool = leaf_and_branch_bool ^ branch_up_bool
        # # leaf_bool = leaf_and_branch_bool ^ branch_bool
        # up_bool = (mpm_init_pos[:, 2] > 0.95) & (mpm_init_pos[:, 2] < 1.5)
        # leaf_up_bool = leaf_bool & up_bool
        # leaf_down_bool = leaf_bool ^ up_bool
        # branch_all_bool = branch_bool | leaf_down_bool
        # Height bands below use z in the normalised coordinates from transform2origin.
        leaf_and_branch_bool = (norm_pos[:, 2] > -0.256) & (norm_pos[:, 2] < 0.5)
        branch_up_bool = leaf_and_branch_bool & branch_bool
        leaf_bool = leaf_and_branch_bool ^ branch_up_bool
        # leaf_bool = leaf_and_branch_bool ^ branch_bool
        up_bool = (norm_pos[:, 2] > -0.056) & (norm_pos[:, 2] < 0.5)
        leaf_up_bool = leaf_bool & up_bool
        leaf_down_bool = leaf_bool ^ up_bool
        # Lower "leaf" points are simulated with the branch material.
        branch_all_bool = branch_bool | leaf_down_bool

        # leaf_bool = torch.any((mpm_init_pos[:, None, :] == leaf).all(dim=2), dim=1)
        # branch_bool = torch.any((mpm_init_pos[:, None, :] == branch).all(dim=2), dim=1)
        # print(leaf_bool)
        leaf_test = norm_pos[leaf_up_bool]
        branch_test = norm_pos[branch_all_bool]
        leaf_up_bool = leaf_up_bool.detach().cpu().numpy()
        branch_all_bool = branch_all_bool.detach().cpu().numpy()
        particle_type[leaf_up_bool] = 1
        particle_type[branch_all_bool] = 2
        type0_test = norm_pos[particle_type == 0]
        # particle_position_tensor_to_ply(
        #     leaf_test,
        #     "./log/2.13-gic_leaf_particle.ply",
        # )
        # particle_position_tensor_to_ply(
        #     branch_test,
        #     "./log/2.13-gic_branch_particle.ply",
        # )
        # particle_position_tensor_to_ply(
        #     type0_test,
        #     "./log/1.22_type0_gic_particle.ply",
        # )
        # print("done")

    # Shift into the MPM domain. The active shift targets [1, 1, 0.6];
    # alternatives centre at [1, 1, 1] (shift2center111) or [2, 2, 1] (shift2center221).
    # transformed_pos = shift2center111(transformed_pos)
    transformed_pos = shift2center1106(transformed_pos)  # 1 1 0.6
    # transformed_pos = shift2center221(transformed_pos)

    # modify covariance matrix accordingly
    init_cov = apply_cov_rotations(init_cov, rotation_matrices)
    init_cov = scale_origin * scale_origin * init_cov
    # print(init_cov)

    if args.debug:
        particle_position_tensor_to_ply(
            transformed_pos,
            "./log/transformed_particles.ply",
        )

    # fill particles if needed
    gs_num = transformed_pos.shape[0]
    device = "cuda:0"
    filling_params = preprocessing_params["particle_filling"]

    mpm_init_pos = transformed_pos.to(device=device)

    n_gaussians = mpm_init_pos.shape[0]
    print("n_gaussians = ", str(n_gaussians))

    # init the mpm solver
    print("Initializing MPM solver and setting up boundary conditions...")
    # volume compute
    # mpm_init_vol = ((4 / 3) * torch.pi * radius ** 3).to(device=device)
    mpm_init_vol = get_particle_volume(
        mpm_init_pos,
        material_params["n_grid"],
        material_params["grid_lim"] / material_params["n_grid"],
        unifrom=material_params["material"] == "sand",
    ).to(device=device)
    print("max volume: ", torch.max(mpm_init_vol))
    print("min volume: ", torch.min(mpm_init_vol))

    if filling_params is not None and filling_params["visualize"] == True:
        shs, opacity, mpm_init_cov = init_filled_particles(
            mpm_init_pos[:gs_num],
            init_shs,
            init_cov,
            init_opacity,
            mpm_init_pos[gs_num:],
        )
        gs_num = mpm_init_pos.shape[0]
    else:
        mpm_init_cov = torch.zeros((mpm_init_pos.shape[0], 6), device=device)
        mpm_init_cov[:gs_num] = init_cov
        shs = init_shs
        opacity = init_opacity

    if args.debug:
        print("check *.ply files to see if it's ready for simulation")

    print("gs_num: ", str(gs_num))

    # The dump below (and the crack dump after each simulation) writes into
    # ./log, which is otherwise only created when --debug is given.
    os.makedirs("./log", exist_ok=True)
    particle_position_tensor_to_ply(
        mpm_init_pos,
        "./log/2.18-mpm_init_pos.ply",
    )

    # set up the mpm solver
    mpm_solver = MPM_Simulator_WARP(10)
    water_num = 0
    # if additional_particle:
    #     mpm_solver.add_initial_data_from_torch(water_pos, water_vol, water_cov)
    combinations = list(itertools.product(E, nu, beta, Jp, kesai, density))
    if only_branch:
        # Drop the upper-leaf particles (type 1) from every per-particle array.
        mpm_init_pos = mpm_init_pos[particle_type != 1]
        mpm_init_vol = mpm_init_vol[particle_type != 1]
        mpm_init_cov = mpm_init_cov[particle_type != 1]
        shs = shs[particle_type != 1]
        opacity = opacity[particle_type != 1]
        init_screen_points = init_screen_points[particle_type != 1]
        particle_type = particle_type[particle_type != 1]
        gs_num = mpm_init_pos.shape[0]
        print("gs_num: ", str(gs_num))
    for i, combo in enumerate(tqdm(combinations)):
        print(combo)
        result_path = args.output_path + "/branch_E_{}_nu_{}_beta_{}_Jp_{}_kesai_{}_density_{}/".format(
            combo[0], combo[1], combo[2], combo[3], combo[4], combo[5]
        )
        if not os.path.exists(result_path+"crack_bool"):
            os.makedirs(result_path+"crack_bool")
        params = dict()
        params["E"] = combo[0]
        params["nu"] = combo[1]
        params["beta"] = combo[2]
        params["Jp"] = combo[3]
        params["kesai"] = combo[4]
        params["density"] = combo[5]

        if not os.path.exists(result_path):
            os.makedirs(result_path)
        with open(result_path + "parameters.txt", "w") as f:
            f.write("E: {}\n".format(params["E"]))
            f.write("nu: {}\n".format(params["nu"]))
            f.write("beta: {}\n".format(params["beta"]))
            f.write("Jp: {}\n".format(params["Jp"]))
            f.write("kesai: {}\n".format(params["kesai"]))
            f.write("density: {}\n".format(params["density"]))
        mpm_solver.load_initial_data_from_torch(
            mpm_init_pos,
            mpm_init_vol,
            mpm_init_cov,
            n_grid=material_params["n_grid"],
            grid_lim=material_params["grid_lim"],
        )
        mpm_solver.set_parameters_dict(material_params)
        branch_dict = params
        leaf_params = dict()
        mpm_solver.set_ficus_type(particle_type, branch_dict, leaf_params)
        # mpm_solver.reset_phys_params(params)

        # mpm_solver.set_particle_type(particle_type)
        # branch_dict = dict()

        # branch_type / branch_pos are computed but not used further (debug leftovers).
        branch_type = mpm_init_pos[:, 2] > 0.255
        branch_type = branch_type.detach().cpu().numpy()
        # particle_type = np.zeros((mpm_init_pos.shape[0]))
        # particle_type[branch_type] = 2
        branch_pos = mpm_init_pos[branch_type]
        # particle_position_tensor_to_ply(
        #     mpm_init_pos,
        #     "./log/2.11-remove_leaf.ply",
        # )
        #
        # particle_position_tensor_to_ply(
        #     branch_pos,
        #     "./log/2.11-only-branch.ply",
        # )
        # print("done")i
        # mpm_solver.set_branch_type(particle_type, branch_dict)

        # if only_branch:
        #     mpm_solver.set_branch_type(particle_type, branch_dict)
        # mpm_solver.set_ficus_type(particle_type, branch_dict, leaf_params)
        # mpm_solver.set_ficus_type(particle_type, branch_dict, leaf_params)

        print("setting parameters complete")

        # mpm_solver.set_interpolation(preprocessing_params)

        # Note: boundary conditions may depend on mass, so the order cannot be changed!
        set_boundary_conditions(mpm_solver, bc_params, time_params)

        mpm_solver.finalize_mu_lam()

        # camera setting
        mpm_space_viewpoint_center = (
            torch.tensor(camera_params["mpm_space_viewpoint_center"]).reshape((1, 3)).cuda()
        )
        mpm_space_vertical_upward_axis = (
            torch.tensor(camera_params["mpm_space_vertical_upward_axis"])
            .reshape((1, 3))
            .cuda()
        )
        (
            viewpoint_center_worldspace,
            observant_coordinates,
        ) = get_center_view_worldspace_and_observant_coordinate(
            mpm_space_viewpoint_center,
            mpm_space_vertical_upward_axis,
            rotation_matrices,
            scale_origin,
            original_mean_pos,
        )

        # run the simulation
        if args.output_ply or args.output_h5:
            directory_to_save = os.path.join(result_path, "simulation_ply")
            if not os.path.exists(directory_to_save):
                os.makedirs(directory_to_save)
            save_data_at_frame(
                mpm_solver,
                directory_to_save,
                0,
                save_to_ply=args.output_ply,
                save_to_h5=args.output_h5,
            )

        substep_dt = time_params["substep_dt"]
        frame_dt = time_params["frame_dt"]
        frame_num = time_params["frame_num"]
        step_per_frame = int(frame_dt / substep_dt)
        opacity_render = opacity
        shs_render = shs
        height = None
        width = None
        crack_bool = None
        for frame in tqdm(range(frame_num)):
            current_camera = get_camera_view(
                model_path,
                default_camera_index=camera_params["default_camera_index"],
                center_view_world_space=viewpoint_center_worldspace,
                observant_coordinates=observant_coordinates,
                show_hint=camera_params["show_hint"],
                init_azimuthm=camera_params["init_azimuthm"],
                init_elevation=camera_params["init_elevation"],
                init_radius=camera_params["init_radius"],
                move_camera=camera_params["move_camera"],
                current_frame=frame,
                delta_a=camera_params["delta_a"],
                delta_e=camera_params["delta_e"],
                delta_r=camera_params["delta_r"],
            )
            rasterize = initialize_resterize(
                current_camera, gaussians, pipeline, background
            )

            # mpm_solver.get_5_and_10_mu()

            # mpm_solver.get_5_and_10_kappa()
            # print("Begin p2g2p")
            try:
                mass = mpm_solver.mpm_state.particle_mass.numpy()
                v = mpm_solver.mpm_state.particle_v.numpy()
                # Per-particle copies for the debug prints below. They are named
                # so they do not shadow the sweep lists E / density / beta / kesai.
                particle_E = mpm_solver.mpm_model.E.numpy()
                particle_density = mpm_solver.mpm_state.particle_density.numpy()
                particle_beta = mpm_solver.mpm_model.beta.numpy()
                particle_kesai = mpm_solver.mpm_model.kesai.numpy()
                logJp = mpm_solver.mpm_model.logJp.numpy()

                print("E: ", particle_E[particle_type == 2])
                print("density: ", particle_density[particle_type == 2])
                print("beta: ", particle_beta[particle_type == 2])
                print("kesai: ", particle_kesai[particle_type == 2])
                print("logJp: ", logJp[particle_type == 2])
                # print("solid mass: ", mass[100])
                # print("v0: ", v[0])
                for step in range(step_per_frame):
                    # mpm_solver.get_5_and_10_v()
                    mpm_solver.p2g2p(frame, step, substep_dt, device=device)
                    # mpm_solver.get_5_and_10_F()

                # output ply including gs and interpolation
                if args.output_ply or args.output_h5:
                    save_data_at_frame(
                        mpm_solver,
                        directory_to_save,
                        frame + 1,
                        save_to_ply=args.output_ply,
                        save_to_h5=args.output_h5,
                    )

                logJp = mpm_solver.mpm_model.logJp.numpy()
                crack_points = mpm_solver.mpm_crack_points.numpy()
                crack_bool = crack_points[:] > maxlogJp
                no_crack_bool = crack_points[:] <= maxlogJp
                print("logJp: ", logJp)
                print("max logJp: ", np.max(logJp))

                np.savetxt(result_path + 'crack_bool/bool_array_{}.txt'.format(frame), crack_bool, fmt='%i')

                # Mask of the first gs_num particles (the renderable Gaussians).
                gs_bool = np.full(mpm_init_pos.shape[0], False, dtype=bool)
                gs_bool[:gs_num] = True

                logJp_bool = np.logical_and(gs_bool, no_crack_bool)
                if args.render_img:
                    if remove_big_logJp:
                        # Render only the uncracked Gaussians.
                        pos = mpm_solver.export_particle_x_to_torch()[logJp_bool].to(device)
                        cov3D = mpm_solver.export_particle_cov_to_torch()
                        rot = mpm_solver.export_particle_R_to_torch()
                        cov3D = cov3D.view(-1, 6)[logJp_bool].to(device)
                        rot = rot.view(-1, 3, 3)[logJp_bool].to(device)
                        init_screen_points_render = init_screen_points[logJp_bool]
                        opacity_render = opacity[logJp_bool]
                        shs_render = shs[logJp_bool]
                    else:

                        if preprocessing_params["particle_interpolation"]:
                            pos = mpm_solver.export_particle_x_gs_to_torch()[:gs_num].to(device)  # export only x_gs
                        else:
                            pos = mpm_solver.export_particle_x_to_torch()[:gs_num].to(device)
                        cov3D = mpm_solver.export_particle_cov_to_torch()
                        rot = mpm_solver.export_particle_R_to_torch()
                        cov3D = cov3D.view(-1, 6)[:gs_num].to(device)
                        rot = rot.view(-1, 3, 3)[:gs_num].to(device)
                        init_screen_points_render = init_screen_points[:gs_num]
                        opacity_render = opacity[:gs_num]
                        shs_render = shs[:gs_num]

                    # water_cov3D = water_cov
                    # water_rot = water_rot
                    # water_pos = mpm_solver.export_particle_x_to_torch()[gs_num-water_num:gs_num].to(device)

                    # Map simulated positions back to the original world frame.
                    # NOTE(review): positions were shifted with shift2center1106
                    # (target 1, 1, 0.6) but are un-shifted here with
                    # undoshift2center111. Confirm this asymmetry is intended.
                    pos = apply_inverse_rotations(
                        undotransform2origin(
                            undoshift2center111(pos), scale_origin, original_mean_pos
                        ),
                        rotation_matrices,
                    )
                    # water_pos = apply_inverse_rotations(
                    #     undotransform2origin(
                    #         undoshift2center111(water_pos), scale_origin, original_mean_pos
                    #     ),
                    #     rotation_matrices,
                    # )
                    # pos = apply_inverse_rotations(
                    #     undotransform2origin(
                    #         undoshift2center221(pos), scale_origin, original_mean_pos
                    #     ),
                    #     rotation_matrices,
                    # )

                    cov3D = cov3D / (scale_origin * scale_origin)
                    cov3D = apply_inverse_cov_rotations(cov3D, rotation_matrices)
                    # opacity = opacity_render
                    # shs = shs_render
                    # if preprocessing_params["sim_area"] is not None:
                    #     pos = torch.cat([pos, unselected_pos], dim=0)
                    #     cov3D = torch.cat([cov3D, unselected_cov], dim=0)
                    #     opacity = torch.cat([opacity_render, unselected_opacity], dim=0)
                    #     shs = torch.cat([shs_render, unselected_shs], dim=0)

                    colors_precomp = convert_SH(shs_render, current_camera, gaussians, pos, rot)
                    # colors_precomp_water = convert_SH(water_shs, current_camera, gaussians, water_pos, water_rot)
                    # print(colors_precomp.min())
                    # print(colors_precomp.max())
                    # colors_precomp = torch.clamp(colors_precomp, 0.0, 1.0)  # keep values within [0, 1]

                    # colors_208920 = convert_SH(shs[208920], current_camera, gaussians, pos[208920], rot[208920])
                    # convert_one_gs_SH(init_pos, shs, rot)

                    # cov3D[gs_num-water_num:] = water_cov3D
                    # print("pos shape: ", pos.shape)
                    # print("cov3D shape: ", cov3D.shape)
                    # print("init_screen_points shape: ", init_screen_points.shape)
                    # print("colors_precomp shape: ", colors_precomp.shape)
                    # print("opacity shape: ", opacity.shape)

                    start_time = time.time()
                    rendering, raddi = rasterize(
                        means3D=pos,
                        means2D=init_screen_points_render,
                        shs=None,
                        colors_precomp=colors_precomp,
                        opacities=opacity_render,
                        scales=None,
                        rotations=None,
                        cov3D_precomp=cov3D,
                    )
                    end_time = time.time()
                    print("render time: ", end_time - start_time)
                    cv2_img = rendering.permute(1, 2, 0).detach().cpu().numpy()
                    # Swap R and B: cv2.imwrite expects BGR channel order.
                    cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
                    if height is None or width is None:
                        # Even dimensions, as required by yuv420p in the ffmpeg call below.
                        height = cv2_img.shape[0] // 2 * 2
                        width = cv2_img.shape[1] // 2 * 2
                    assert args.output_path is not None
                    # Zero-padded names such as 0005.png match ffmpeg's %04d.png pattern.
                    cv2.imwrite(
                        os.path.join(result_path, f"{frame}.png".rjust(8, "0")),
                        255 * cv2_img,
                    )

            except Exception:
                # Best effort: report the failure with its full traceback and move
                # on to the next frame instead of aborting the whole sweep.
                traceback.print_exc()

        # crack_bool is still None if every frame failed. Indexing a tensor with
        # None would add a dimension instead of selecting points.
        if crack_bool is not None:
            crack_mpm_points = mpm_init_pos[crack_bool]
            particle_position_tensor_to_ply(
                crack_mpm_points,
                "./log/2.18-ficus-crack_big_than_{}_init_points.ply".format(str(maxlogJp)),
            )
        if args.render_img and args.compile_video:
            # Fixed playback rate for the compiled video.
            fps = 15
            os.system(
                f"ffmpeg -framerate {fps} -i {result_path}/%04d.png -c:v libx264 -s {width}x{height} -y -pix_fmt yuv420p {result_path}/output.mp4"
            )
