# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# Example Sim Cloth
#
# Shows a simulation of an FEM cloth model colliding against a static
# rigid body SDF (a koala volume loaded from grid_sdf/koala.npz) built
# with wp.sim.ModelBuilder().
#
###########################################################################

import math
import os
from enum import Enum

import numpy as np
from pxr import Usd, UsdGeom

import warp as wp
import warp.examples
import warp.sim
import warp.sim.render
import trimesh
# Per-frame collision-detection totals (printed as ms) appended by print_custom_report.
# A ``global`` statement at module scope is a no-op, so bind the list here to make the
# name exist even when this file is imported instead of being run as a script.
detection_time = []

class IntegratorType(Enum):
    """Integrator choices selectable on the command line via ``--integrator``."""

    EULER = "euler"
    XPBD = "xpbd"
    VBD = "vbd"

    def __str__(self):
        # argparse renders choices and defaults with str(); show the plain value
        # (e.g. "euler") rather than "IntegratorType.EULER".
        return f"{self.value}"

def print_custom_report(results, indent=""):
    """Summarize collision-detection kernel timings collected by a ``wp.ScopedTimer``.

    Intended as the timer's ``report_func``. Sums the ``elapsed`` time of records
    whose name starts with ``"forward kernel"`` and contains either
    ``"prefilter_particles_pim"`` or ``"create_soft_contacts"``. It prints both sums
    and their total, then appends the total to the module-level ``detection_time``
    list, creating that list if it does not exist yet.

    Args:
        results: Timing records; each must expose ``name`` and ``elapsed``
            (reported below as milliseconds).
        indent: Prefix prepended to every printed line.
    """
    prefilter_time = 0
    soft_contact_time = 0

    for r in results:
        # Only forward launches of the kernels of interest are counted.
        if not r.name.startswith("forward kernel"):
            continue
        if "prefilter_particles_pim" in r.name:
            prefilter_time += r.elapsed
        if "create_soft_contacts" in r.name:
            soft_contact_time += r.elapsed

    total = prefilter_time + soft_contact_time
    print(f"{indent}prefilter kernels: {prefilter_time:.6f} ms")
    print(f"{indent}create soft contacts kernels: {soft_contact_time:.6f} ms")
    print(f"{indent}total: {total:.6f} ms")
    # The script's __main__ section binds detection_time; create it on demand so the
    # report also works when Example is driven from another module.
    globals().setdefault("detection_time", []).append(total)


class Example:
    """FEM cloth dropped onto a static koala SDF collider.

    Builds a cloth grid whose parameters depend on the chosen integrator, adds a
    static SDF shape loaded from ``grid_sdf/koala.npz`` and steps the simulation
    ``sim_substeps`` times per frame. Asset paths are relative to the current
    working directory.
    """

    def __init__(
        self, stage_path="example_cloth.usd", integrator: IntegratorType = IntegratorType.EULER, height=32, width=64
    ):
        """Build the model, integrator, simulation states and optional renderer.

        Args:
            stage_path: Output USD path for the renderer; a falsy value disables rendering.
            integrator: Integrator to use; also selects the matching cloth parameters.
            height: Cloth resolution in y (passed as ``dim_y``).
            width: Cloth resolution in x (passed as ``dim_x``).
        """
        self.integrator_type = integrator

        self.sim_height = height
        self.sim_width = width
        self.radius = 0.01

        fps = 60
        self.sim_substeps = 64
        self.frame_dt = 1.0 / fps
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0
        # Filled by wp.ScopedTimer(dict=...) in step(), keyed by timer name.
        self.profiler = {}

        builder = wp.sim.ModelBuilder()
        self._add_cloth(builder)

        # Only the wp.Volume is needed; the raw arrays are returned for inspection.
        vdb = self.load_np_sdf("grid_sdf/koala.npz")[0]

        # The triangle mesh is stored on the instance but not used by the collision
        # setup below, which relies on the SDF.
        mesh = self._load_collision_mesh("meshes/mesh_koala_collision_outer.obj")
        vertices_np = np.asarray(mesh.vertices, dtype=np.float32)
        indices_np = np.asarray(mesh.faces, dtype=np.int32).flatten()
        # Debug output: smallest coordinate value over all vertices.
        print(vertices_np.min())
        self.vertices = vertices_np
        self.indices = indices_np

        # Experiment variant: give the SDF a collision shell, e.g.
        #   shell = trimesh.creation.box(bounds=mesh.bounds)
        #   wp.sim.SDF(volume=vdb, shell_vertices=shell.vertices, shell_indices=shell.faces)
        # or shell_vertices=mesh.vertices, shell_indices=mesh.faces.
        sdf = wp.sim.SDF(volume=vdb)

        # body=-1: the shape is not attached to any body, i.e. a static collider.
        builder.add_shape_sdf(
            ke=1.0e2,
            kd=1.0e2,
            kf=1.0e1,
            sdf=sdf,
            body=-1,
            pos=wp.vec3(0.0, 0.56473401, 0.0),
            rot=wp.quat(0.0, 0.0, 0.0, 1.0),
            scale=wp.vec3(1.0, 1.0, 1.0),
        )
        # Alternative: collide against the triangle mesh directly with
        # wp.sim.Mesh(vertices_np, indices_np) and builder.add_shape_mesh(...).

        if self.integrator_type == IntegratorType.VBD:
            # builder.color() is only needed on the VBD path (presumably the particle
            # coloring the VBD solver relies on).
            builder.color()

        self.model = builder.finalize()
        self.model.ground = True
        self.model.soft_contact_ke = 1.0e4
        self.model.soft_contact_kd = 1.0e2

        if self.integrator_type == IntegratorType.EULER:
            self.integrator = wp.sim.SemiImplicitIntegrator()
        elif self.integrator_type == IntegratorType.XPBD:
            self.integrator = wp.sim.XPBDIntegrator(iterations=1)
        else:
            self.integrator = wp.sim.VBDIntegrator(self.model, iterations=1)

        # Two states, swapped after every substep in simulate().
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=1.0)
        else:
            self.renderer = None

        # CUDA graph capture is disabled; set this to wp.get_device().is_cuda to
        # re-enable it.
        self.use_cuda_graph = False
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    def _add_cloth(self, builder):
        """Add the cloth grid to ``builder`` using parameters tuned per integrator."""
        rot = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), math.pi * 0.5)
        vel = wp.vec3(0.0, 0.0, 0.0)

        if self.integrator_type == IntegratorType.EULER:
            # Half-size cells (0.1 / 2) and quarter in-plane stiffness (1e3 / 4);
            # tri_kd is left unscaled.
            builder.add_cloth_grid(
                pos=wp.vec3(-self.sim_width * 0.05 / 2, 2.0, -self.sim_height * 0.05 / 2),
                rot=rot,
                vel=vel,
                dim_x=self.sim_width,
                dim_y=self.sim_height,
                cell_x=0.1 / 2,
                cell_y=0.1 / 2,
                particle_radius=self.radius,
                mass=0.1,
                fix_left=False,
                tri_ke=1.0e3 / 4,
                tri_ka=1.0e3 / 4,
                tri_kd=1.0e1,
            )
        elif self.integrator_type == IntegratorType.XPBD:
            # fix_left is left at the builder's default.
            builder.add_cloth_grid(
                pos=wp.vec3(-self.sim_width * 0.05 / 2, 2.0, -self.sim_height * 0.05 / 2),
                rot=rot,
                vel=vel,
                dim_x=self.sim_width,
                dim_y=self.sim_height,
                cell_x=0.05,
                cell_y=0.05,
                particle_radius=self.radius,
                mass=0.1,
                edge_ke=1.0e2,
                add_springs=True,
                spring_ke=1.0e3,
                spring_kd=1.0e1,
            )
        else:
            # VBD: cell size tied to the particle radius.
            builder.add_cloth_grid(
                pos=wp.vec3(-self.sim_width * self.radius, self.sim_width * 2.0 * self.radius, -self.sim_height * self.radius),
                rot=rot,
                vel=vel,
                dim_x=self.sim_width,
                dim_y=self.sim_height,
                cell_x=self.radius * 2.0,
                cell_y=self.radius * 2.0,
                particle_radius=self.radius,
                mass=0.1,
                fix_left=True,
                tri_ke=1e4,
                tri_ka=1e4,
                tri_kd=1e-5,
                edge_ke=100,
            )

    def _load_collision_mesh(self, filename):
        """Load ``filename`` with trimesh; if not watertight, keep its largest watertight part.

        Falls back to the mesh as loaded when no watertight component exists, instead
        of failing with ``max()`` on an empty sequence.
        """
        mesh = trimesh.load(filename, force="mesh")
        if not mesh.is_volume:
            print("Warning: Mesh is not watertight. Physics collisions may behave oddly.")
            components = mesh.split(only_watertight=True)
            # len() rather than truthiness: split() may return a numpy array.
            if len(components) > 0:
                mesh = max(components, key=lambda m: len(m.faces))
            else:
                print("Warning: No watertight component found; using the mesh as loaded.")
        return mesh

    def load_obj_mesh(self, filename):
        """Load an OBJ file with trimesh and return it as a ``wp.Mesh`` (no watertight repair)."""
        mesh = trimesh.load(filename, force="mesh")
        if not mesh.is_volume:
            print("Warning: Mesh is not watertight. Physics collisions may behave oddly.")
        points = wp.array(np.asarray(mesh.vertices, dtype=np.float32), dtype=wp.vec3)
        indices = wp.array(np.asarray(mesh.faces, dtype=np.int32).flatten(), dtype=int)
        return wp.Mesh(points, indices)

    def load_np_sdf(self, filename):
        """Load a dense SDF grid from an ``.npz`` archive into a ``wp.Volume``.

        The archive must contain the keys ``sdf``, ``mins``, ``voxel_size``,
        ``bg_value`` and ``coords``.

        Returns:
            ``(vdb, sdf_np, mins, voxel_size, bg_value, coords)``: the ``wp.Volume``
            followed by the raw values read from the archive (``mins`` converted to a
            list, ``voxel_size`` and ``bg_value`` to floats).
        """
        # np.load on an .npz keeps the file open until closed, so use it as a context
        # manager. Indexing reads each array into memory, so the values stay valid.
        with np.load(filename) as data:
            sdf_np = data["sdf"]
            mins = data["mins"].tolist()
            voxel_size = float(data["voxel_size"])
            bg_value = float(data["bg_value"])
            coords = data["coords"]

        vdb = wp.Volume.load_from_numpy(sdf_np, min_world=mins, voxel_size=voxel_size, bg_value=bg_value)
        return vdb, sdf_np, mins, voxel_size, bg_value, coords

    def simulate(self):
        """Advance one frame: ``sim_substeps`` substeps of ``sim_dt`` each."""
        # Alternative (particle prefilter): wp.sim.prefilter_particles(...) and
        # wp.sim.allocate_compact_soft_contacts(...) with apply_particle_filter=True.
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            # Contacts are regenerated every substep.
            wp.sim.collide(self.model, self.state_0, apply_particle_filter=False)
            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)
            # Swap so the freshly written state feeds the next substep.
            self.state_0, self.state_1 = self.state_1, self.state_0

    def step(self):
        """Simulate one frame under a timer that reports collision-kernel timings."""
        with wp.ScopedTimer("step", dict=self.profiler, cuda_filter=wp.TIMING_ALL, report_func=print_custom_report):
            if self.use_cuda_graph:
                wp.capture_launch(self.graph)
            else:
                self.simulate()
        self.sim_time += self.frame_dt

    def render(self):
        """Render the koala reference asset and the current state; no-op without a renderer."""
        if self.renderer is None:
            return

        with wp.ScopedTimer("render"):
            self.renderer.begin_frame(self.sim_time)
            # Rotates the referenced USD asset -90 degrees about x (presumably to match
            # the simulation frame -- confirm against the asset).
            correction = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), wp.radians(-90.0))

            self.renderer.render_ref(
                name="collision",
                path="meshes/koala.usd",
                pos=wp.vec3(0.0, 0.56473401, 0.0),
                rot=correction,
                scale=wp.vec3(1.0, 1.0, 1.0),
            )

            self.renderer.render(self.state_0)
            self.renderer.end_frame()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=lambda x: None if x == "None" else str(x),
        default="cloth_sdf_xpbd.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--num_frames", type=int, default=200, help="Total number of frames.")
    parser.add_argument(
        "--integrator",
        help="Type of integrator",
        type=IntegratorType,
        choices=list(IntegratorType),
        default=IntegratorType.EULER,
    )
    parser.add_argument("--width", type=int, default=64, help="Cloth resolution in x.")
    parser.add_argument("--height", type=int, default=32, help="Cloth resolution in y.")

    # Unknown command-line arguments are ignored rather than rejected.
    args = parser.parse_known_args()[0]
    # Per-frame collision-detection totals, appended by print_custom_report during step().
    detection_time = []
    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, integrator=args.integrator, height=args.height, width=args.width)

        for _i in range(args.num_frames):
            example.step()
            example.render()

        # The profiler only gains a "step" entry once step() has run, so guard
        # against --num_frames 0 (KeyError / division by zero otherwise).
        frame_times = example.profiler.get("step", [])
        if frame_times:
            print(f"\nAverage frame sim time: {sum(frame_times) / len(frame_times):.2f} ms")

        if example.renderer:
            example.renderer.save()

    # np.savez_compressed('exp_results/cloth/time_koala_with_shell.npz', time=detection_time)
