import numpy as np
import h5py
import os
import sys
import warp as wp
import torch


def save_data_at_frame(mpm_solver, dir_name, frame, save_to_ply=True, save_to_h5=False, save_water=False):
    """Write the solver's particle state for one frame into ``dir_name``.

    Output files are named ``sim_<frame>.ply`` and/or ``sim_<frame>.h5``.

    Args:
        mpm_solver: Solver object. Must provide
            ``particle_position_and_displacement_to_ply(filename)`` for PLY
            output, and ``mpm_state.particle_x/_F/_v/_C`` (each exposing
            ``.numpy()``) plus ``time`` for HDF5 output.
        dir_name: Output directory; created if missing.
        frame: Frame identifier, converted with ``str()`` for the file name.
        save_to_ply: Write a PLY point cloud for this frame.
        save_to_h5: Write an HDF5 snapshot with datasets ``x``, ``time``,
            ``f_tensor``, ``v`` and ``C``.
        save_water: When writing PLY, append the water particles by calling
            ``particle_and_water_position_to_ply`` instead of the solver's
            own PLY writer.
    """
    # Clear the process umask so the directory really gets mode 0o777.
    # NOTE: this is process-wide and is not restored afterwards.
    os.umask(0)
    os.makedirs(dir_name, 0o777, exist_ok=True)

    fullfilename = dir_name + "/sim_" + str(frame) + ".h5"
    # Swap the trailing "h5" for "ply" to get the sibling PLY file name.
    ply_filename = fullfilename[:-2] + "ply"

    if save_to_ply:
        if save_water:
            particle_and_water_position_to_ply(mpm_solver, ply_filename)
        else:
            mpm_solver.particle_position_and_displacement_to_ply(ply_filename)

    if save_to_h5:
        if os.path.exists(fullfilename):
            os.remove(fullfilename)

        state = mpm_solver.mpm_state
        # The context manager guarantees the file is flushed and closed even
        # if one of the dataset writes raises.
        with h5py.File(fullfilename, "w") as h5_file:
            # Particle arrays are stored transposed (components first);
            # for positions that is presumably (3, n_particles).
            h5_file.create_dataset("x", data=state.particle_x.numpy().transpose())

            # Current simulation time, stored as a (1, 1) array.
            current_time = np.array([mpm_solver.time]).reshape(1, 1)
            h5_file.create_dataset("time", data=current_time)

            # Deformation gradient, flattened to shape (9, n_particles).
            h5_file.create_dataset(
                "f_tensor",
                data=state.particle_F.numpy().reshape(-1, 9).transpose(),
            )

            # Particle velocity, stored transposed like the positions.
            h5_file.create_dataset("v", data=state.particle_v.numpy().transpose())

            # Particle C matrix, flattened to shape (9, n_particles).
            h5_file.create_dataset(
                "C",
                data=state.particle_C.numpy().reshape(-1, 9).transpose(),
            )
        print("save siumlation data at frame ", frame, " to ", fullfilename)

def save_data_at_frame_output2ply(mesh_num, mpm_solver, dir_name, frame, save_to_ply=True, save_to_h5=False, save_water=False):
    """Write one frame as PLY files, split into mesh and water particles.

    Creates ``dir_name``, ``dir_name/mesh`` and ``dir_name/water``. With
    ``save_water`` false, the first ``mesh_num`` particles go to
    ``mesh/sim_<frame>.ply`` and the remaining ones to
    ``water/sim_<frame>.ply``. With ``save_water`` true, a single combined
    file ``dir_name/sim_<frame>.ply`` is written by
    ``particle_and_water_position_to_ply``.

    Args:
        mesh_num: Number of leading particles that belong to the mesh.
        mpm_solver: Solver object passed through to the PLY writers.
        dir_name: Output directory; created if missing.
        frame: Frame identifier, converted with ``str()`` for the file name.
        save_to_ply: Write PLY output; when false only the directories are
            created.
        save_to_h5: Accepted for signature parity with
            ``save_data_at_frame``; this function writes no HDF5 output.
        save_water: Select the combined particle+water writer (see above).
    """
    # Clear the process umask so the directories really get mode 0o777.
    # NOTE: this is process-wide and is not restored afterwards.
    os.umask(0)
    os.makedirs(dir_name, 0o777, exist_ok=True)
    os.makedirs(dir_name + "/mesh", 0o777, exist_ok=True)
    os.makedirs(dir_name + "/water", 0o777, exist_ok=True)

    if not save_to_ply:
        return

    ply_name = "sim_" + str(frame) + ".ply"
    if save_water:
        # This branch previously read an undefined ``fullfilename`` and
        # raised NameError; the combined file lives directly in dir_name.
        particle_and_water_position_to_ply(mpm_solver, dir_name + "/" + ply_name)
    else:
        particle_position_to_ply_mesh(mesh_num, mpm_solver, dir_name + "/mesh/" + ply_name)
        particle_position_to_ply_water(mesh_num, mpm_solver, dir_name + "/water/" + ply_name)


def particle_position_to_ply_mesh(mesh_num, mpm_solver, filename):
    """Write the first ``mesh_num`` particle positions to a binary PLY file.

    Args:
        mesh_num: Number of leading rows of the position array to export.
        mpm_solver: Solver whose ``mpm_state.particle_x.numpy()`` returns the
            particle positions (expected to be (n, 3); not validated here).
        filename: Destination path. An existing file is removed first and
            the parent directory is created if it does not exist.
    """
    # NOTE: no function-local ``import os`` here. Such an import makes ``os``
    # a local name for the whole function, so the ``os.path.exists`` call
    # below would raise UnboundLocalError before the import ever ran.
    if os.path.exists(filename):
        os.remove(filename)

    position = mpm_solver.mpm_state.particle_x.numpy()[:mesh_num]
    num_particles = position.shape[0]
    # The header declares little-endian floats, so force that byte order
    # instead of relying on the host's native float32 layout.
    position = position.astype("<f4")

    parent_dir = os.path.dirname(filename)
    if parent_dir:  # os.makedirs("") raises for bare file names
        os.makedirs(parent_dir, exist_ok=True)

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_particles}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(filename, "wb") as f:  # write binary
        f.write(str.encode(header))
        f.write(position.tobytes())
    print("write", filename)

def particle_position_to_ply_water(mesh_num, mpm_solver, filename, max_z=1.5):
    """Write the non-mesh particles below a height cutoff to a binary PLY file.

    Rows ``mesh_num:`` of the solver's position array are exported, keeping
    only those whose third coordinate is ``<= max_z``.

    Args:
        mesh_num: Number of leading rows to skip (the mesh particles).
        mpm_solver: Solver whose ``mpm_state.particle_x.numpy()`` returns the
            particle positions (expected to be (n, 3); not validated here).
        filename: Destination path. An existing file is removed first.
        max_z: Upper bound (inclusive) on the third coordinate. Defaults to
            the previously hard-coded 1.5; pass ``None`` to keep every row.
    """
    if os.path.exists(filename):
        os.remove(filename)

    position = mpm_solver.mpm_state.particle_x.numpy()[mesh_num:]
    if max_z is not None:
        position = position[position[:, 2] <= max_z]
    num_particles = position.shape[0]
    # The header declares little-endian floats, so force that byte order
    # instead of relying on the host's native float32 layout.
    position = position.astype("<f4")

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_particles}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(filename, "wb") as f:  # write binary
        f.write(str.encode(header))
        f.write(position.tobytes())
    print("write", filename)



def particle_position_to_ply(mpm_solver, filename):
    """Write every particle position of the solver to a binary PLY file.

    Args:
        mpm_solver: Solver whose ``mpm_state.particle_x.numpy()`` returns the
            particle positions (expected to be (n, 3); not validated here).
        filename: Destination path. An existing file is removed first.
    """
    if os.path.exists(filename):
        os.remove(filename)

    position = mpm_solver.mpm_state.particle_x.numpy()
    num_particles = position.shape[0]
    # The header declares little-endian floats, so force that byte order
    # instead of relying on the host's native float32 layout.
    position = position.astype("<f4")

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_particles}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(filename, "wb") as f:  # write binary
        f.write(str.encode(header))
        f.write(position.tobytes())
    print("write", filename)

def particle_and_water_position_to_ply(mpm_solver, filename):
    """Write solver particles followed by water particles to one binary PLY file.

    The two position arrays are stacked row-wise, solver particles first.

    Args:
        mpm_solver: Solver providing ``mpm_state.particle_x.numpy()`` and
            ``mpm_watersys.particle_water_x.numpy()``. Both arrays must have
            the same number of columns (expected 3) for the stack to succeed.
        filename: Destination path. An existing file is removed first.
    """
    if os.path.exists(filename):
        os.remove(filename)

    solid_position = mpm_solver.mpm_state.particle_x.numpy()
    water_position = mpm_solver.mpm_watersys.particle_water_x.numpy()
    position = np.vstack((solid_position, water_position))
    num_particles = position.shape[0]
    # The header declares little-endian floats, so force that byte order
    # instead of relying on the host's native float32 layout.
    position = position.astype("<f4")

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_particles}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(filename, "wb") as f:  # write binary
        f.write(str.encode(header))
        f.write(position.tobytes())
    print("write", filename)


def particle_position_tensor_to_ply(position_tensor, filename):
    """Write a tensor of particle positions to a binary PLY file.

    Args:
        position_tensor: Tensor supporting ``clone().detach().cpu().numpy()``
            (e.g. a ``torch.Tensor``); expected to be (n, 3), not validated
            here. The tensor itself is left untouched.
        filename: Destination path. An existing file is removed first.
    """
    if os.path.exists(filename):
        os.remove(filename)

    # Copy off the autograd graph and onto the host before converting.
    position = position_tensor.clone().detach().cpu().numpy()
    num_particles = position.shape[0]
    # The header declares little-endian floats, so force that byte order
    # instead of relying on the host's native float32 layout.
    position = position.astype("<f4")

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_particles}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )
    with open(filename, "wb") as f:  # write binary
        f.write(str.encode(header))
        f.write(position.tobytes())
    print("write", filename)
