# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# Example Sim Rigid Contact
#
# Shows how to set up free rigid bodies with triangle-mesh shapes (letters
# loaded from OBJ files) falling and colliding against each other and the
# ground using wp.sim.ModelBuilder().
#
###########################################################################

import math
import os
import trimesh

import numpy as np
from pxr import Usd, UsdGeom

import warp as wp
import warp.examples
import warp.sim
import warp.sim.render

# Compute the world-space transform of every collision shape (one thread per shape).
#
# - Shapes attached to a body (shape_body[tid] != -1) get the body's current
#   transform composed with the shape's local transform: body_q[b] * shape_transform[tid].
# - Shapes whose body index is -1 are copied through unchanged, i.e. their
#   shape_transform is treated as already being expressed in world space.
#
# Launch with dim == number of shapes; shape_world_transform must have that length.
@wp.kernel
def compute_shape_world_transforms(
    body_q: wp.array(dtype=wp.transform),  # current per-body transforms (indexed by body)
    shape_body: wp.array(dtype=int),  # owning body of each shape, -1 for none
    shape_transform: wp.array(dtype=wp.transform),  # per-shape local transform
    shape_world_transform: wp.array(dtype=wp.transform),  # output, one entry per shape
):
    tid = wp.tid()
    b = shape_body[tid]
    if b == -1:
        # No parent body: the stored transform is used as-is.
        shape_world_transform[tid] = shape_transform[tid]
    else:
        # Parent body transform applied on the left of the shape's local offset.
        shape_world_transform[tid] = wp.transform_multiply(body_q[b], shape_transform[tid])


class Example:
    """Drop letter-shaped rigid meshes onto the ground and record them to a USD stage.

    Each body is a triangle mesh loaded from ``os.path.join(mesh_dir, f"{letter}_3d.obj")``
    for the first ``num_bodies`` characters of ``letters``. The meshes are marked
    invisible to the sim renderer (``is_visible=False``) and drawn explicitly in
    :meth:`render` from the loaded trimesh geometry.

    Simulation runs at 60 frames per second with 64 semi-implicit substeps per frame.
    When the current device is CUDA, one frame of substeps is captured into a CUDA
    graph in ``__init__`` and replayed by :meth:`step`.
    """

    def __init__(
        self,
        stage_path="example_rigid_contact.usd",
        mesh_dir="/home/ /Downloads/letters_3d",  # machine-specific default; override it
        letters="SIGGRAPH",
        num_bodies=2,
    ):
        """Build the model, the integrator, the optional USD renderer and the initial state.

        Args:
            stage_path: Output USD path, or a falsy value to disable rendering.
            mesh_dir: Directory containing ``<letter>_3d.obj`` files.
            letters: Sequence of letters; character ``i`` selects the mesh of body ``i``.
            num_bodies: How many of ``letters`` to instantiate as rigid bodies.

        Raises:
            ValueError: If ``num_bodies`` is not in ``[1, len(letters)]``.
        """
        if not 1 <= num_bodies <= len(letters):
            raise ValueError(f"num_bodies must be in [1, {len(letters)}], got {num_bodies}")

        builder = wp.sim.ModelBuilder()

        self.sim_time = 0.0
        fps = 60
        self.frame_dt = 1.0 / fps

        self.sim_substeps = 64
        self.sim_dt = self.frame_dt / self.sim_substeps

        self.num_bodies = num_bodies
        self.scale = 1.0
        self.ke = 1.0e5
        self.kd = 250.0
        self.kf = 500.0

        # Load only the letters that become bodies. Each OBJ file is read once and
        # shared between the render geometry and the simulation mesh; the sim mesh
        # gets its own float32/int32 copies of the vertex and face arrays.
        body_letters = letters[:num_bodies]
        self.render_meshes = [trimesh.load(os.path.join(mesh_dir, f"{c}_3d.obj")) for c in body_letters]
        sim_meshes = [self._to_sim_mesh(m) for m in self.render_meshes]
        # USD prim names must be unique, so repeated letters get a numeric suffix ("G", "G2").
        self.body_names = self._unique_names(body_letters)

        for i, sim_mesh in enumerate(sim_meshes):
            b = builder.add_body(
                origin=wp.transform(
                    (i * 0.8 * self.scale, 1.5 + i * 2.0 * self.scale, 4.5 + i * 0.8 * self.scale),
                    wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), math.pi * 0.1 * i),
                )
            )

            builder.add_shape_mesh(
                body=b,
                mesh=sim_mesh,
                pos=wp.vec3(0.0, 0.0, 0.0),
                scale=wp.vec3(self.scale, self.scale, self.scale),
                ke=self.ke,
                kd=self.kd,
                kf=self.kf,
                density=1e3,
                is_visible=False,  # drawn manually in render() from the trimesh geometry
            )

        # finalize model
        self.model = builder.finalize()
        self.model.ground = True

        self.shape_world_transforms = wp.empty(len(self.model.shape_body), dtype=wp.transform, device=self.model.device)
        # Shape-to-body mapping does not change after finalize(); copy it to the host once.
        self._shape_body_host = self.model.shape_body.numpy()

        self.integrator = wp.sim.SemiImplicitIntegrator()

        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=1.0)
        else:
            self.renderer = None

        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        wp.sim.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, None, self.state_0)

        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    @staticmethod
    def _unique_names(letters):
        """Return one name per letter, suffixing repeats with their occurrence count.

        The first occurrence keeps the bare letter; later ones become ``letter + str(n)``.
        """
        seen = {}
        names = []
        for c in letters:
            count = seen.get(c, 0) + 1
            seen[c] = count
            names.append(c if count == 1 else f"{c}{count}")
        return names

    @staticmethod
    def _to_sim_mesh(mesh):
        """Convert a loaded trimesh object into a ``wp.sim.Mesh`` (float32 vertices, int32 faces)."""
        vertices = np.array(mesh.vertices, dtype=np.float32)
        faces = np.array(mesh.faces, dtype=np.int32)
        return wp.sim.Mesh(vertices, faces)

    def load_mesh(self, path):
        """Load a mesh file with ``trimesh.load`` and return it as a ``wp.sim.Mesh``."""
        return self._to_sim_mesh(trimesh.load(path))

    def load_np_sdf(self, filename):
        """Load a signed-distance field from a NumPy archive into a ``wp.Volume``.

        The archive must contain the keys ``sdf``, ``mins``, ``voxel_size``,
        ``bg_value`` and ``coords``. Not called within this class.

        Returns:
            ``(volume, sdf_np, mins, voxel_size, bg_value, coords)``.
        """
        data = np.load(filename)
        sdf_np = data['sdf']
        mins = (data['mins']).tolist()
        voxel_size = float(data['voxel_size'])
        bg_value = float(data['bg_value'])
        coords = data['coords']

        vdb = wp.Volume.load_from_numpy(sdf_np, min_world=mins, voxel_size=voxel_size, bg_value=bg_value)

        return vdb, sdf_np, mins, voxel_size, bg_value, coords

    def simulate(self):
        """Advance ``state_0`` by one frame.

        Each substep clears forces, runs collision detection, integrates and swaps
        the two state buffers.
        """
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            wp.sim.collide(self.model, self.state_0)
            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)
            self.state_0, self.state_1 = self.state_1, self.state_0

    def step(self):
        """Advance one frame, replaying the captured CUDA graph when one exists, and time it."""
        with wp.ScopedTimer("step", active=True):
            if self.use_cuda_graph:
                wp.capture_launch(self.graph)
            else:
                self.simulate()
        self.sim_time += self.frame_dt

    def render(self):
        """Draw every body's mesh at its current world transform into the USD stage.

        Does nothing when the example was created without a stage path.
        """
        if self.renderer is None:
            return

        with wp.ScopedTimer("render", active=True):
            self.renderer.begin_frame(self.sim_time)

            # Compute world-space shape transforms using a Warp kernel
            wp.launch(
                compute_shape_world_transforms,
                dim=len(self.model.shape_body),
                inputs=[
                    self.state_0.body_q,
                    self.model.shape_body,
                    self.model.shape_transform,
                    self.shape_world_transforms,
                ],
                device=self.model.device,
            )

            # Transfer transform data to host
            shape_xf_host = self.shape_world_transforms.numpy()
            for shape_index, body in enumerate(self._shape_body_host):
                if body == -1:
                    continue  # static shape: there is no per-body render mesh for it
                xf = shape_xf_host[shape_index]
                # Look up geometry and name by the owning body, not the shape index.
                mesh = self.render_meshes[body]
                self.renderer.render_mesh(
                    name=self.body_names[body],
                    points=mesh.vertices,
                    indices=mesh.faces,
                    pos=wp.transform_get_translation(xf),
                    rot=wp.transform_get_rotation(xf),
                )

            self.renderer.end_frame()


if __name__ == "__main__":
    # Command-line entry point: simulate num_frames frames, render each one,
    # and save the USD stage when rendering is enabled.
    import argparse

    def _stage_path_arg(value):
        # The literal string "None" disables USD output.
        if value == "None":
            return None
        return str(value)

    arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    arg_parser.add_argument(
        "--stage_path",
        type=_stage_path_arg,
        default="letters_3d.usd",
        help="Path to the output USD file.",
    )
    arg_parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames.")

    # Unrecognized arguments are tolerated and ignored.
    args, _unknown_args = arg_parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path)

        for _frame in range(args.num_frames):
            example.step()
            example.render()

        if example.renderer:
            example.renderer.save()

