import numpy as np
import torch
import smplx
import trimesh
import os
import pickle
from typing import Union, Optional, List, Dict
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from mld.transforms.joints2rots import config
import warnings
from human_body_prior.models.ik_engine import IK_Engine
from drag_dev.fit.ik_engine_utils import SourceKeyPoints
from drag_dev.shape_optimization.coap_selfpene_loss import COAPSelfPenetrationLoss
from drag_dev.shape_optimization.smpl_pose_optimizer import SMPLPoseOptimizer

# Lazily-built cache used by get_ik_engine(): the IK engine and the source
# keypoint model it fits against. Both stay None until the first call.
_ik_engine = _source_pts = None

def get_ik_engine(device):
    """
    Get IK engine instance for inverse kinematics computation.

    The IK engine and its source keypoint model are built on the first call
    and cached in the module globals ``_ik_engine`` / ``_source_pts``; later
    calls only move the cached objects to ``device``.

    Args:
        device: torch device (or device string) to place both modules on.

    Returns:
        Tuple ``(ik_engine, source_pts)``.

    Note:
        The VPoser directory and body-model file below are relative to the
        current working directory, so this only works when run from the
        expected location.
    """
    global _ik_engine, _source_pts

    if _ik_engine is None or _source_pts is None:
        # NOTE(review): 'dowloads' spelling presumably matches the on-disk
        # folder name -- confirm before "fixing" it.
        vposer_expr_dir = '../human_body_prior/support_data/dowloads/V02_05'
        num_betas = 10
        data_loss = torch.nn.MSELoss(reduction='sum')

        # A single optimisation stage weighting the keypoint data term against
        # the VPoser latent ('poZ_body') and shape ('betas') regularisers.
        stepwise_weights = [
            {'data': 10., 'poZ_body': .01, 'betas': .5},
        ]

        optimizer_args = {
            'type': 'LBFGS',
            'max_iter': 150,
            'lr': 1,
            'tolerance_change': 1e-4,
            'history_size': 200
        }

        ik_engine = IK_Engine(
            vposer_expr_dir=vposer_expr_dir,
            verbosity=0,
            display_rc=(2, 2),
            data_loss=data_loss,
            num_betas=num_betas,
            stepwise_weights=stepwise_weights,
            optimizer_args=optimizer_args
        ).to(device)

        bm_fname = './smplh/neutral/model.npz'
        source_pts = SourceKeyPoints(
            bm=bm_fname,
            n_joints=22,
            kpts_colors=None,
            num_betas=num_betas
        ).to(device)

        # Publish to the cache only after BOTH objects were built. Assigning
        # the globals one at a time meant a failure in SourceKeyPoints left
        # _ik_engine set and _source_pts None, and every later call crashed
        # on None.to(device).
        _ik_engine, _source_pts = ik_engine, source_pts
    else:
        _ik_engine = _ik_engine.to(device)
        _source_pts = _source_pts.to(device)

    return _ik_engine, _source_pts

class Motion:
    """A motion sequence held both as joint positions and as SMPL parameters.

    Construct from joint positions (SMPL parameters are recovered by inverse
    kinematics) or from SMPL parameters (joints come from a forward pass of
    ``smpl_model``). Methods that change the pose refresh ``joints`` so the
    two representations stay in sync.

    Attributes:
        device: Torch device for the body model and all parameter tensors.
        smpl_model: SMPL body model used for forward passes.
        smpl_params: Dict with ``'poses'`` (global orient in ``[:, :3]``, body
            pose in ``[:, 3:]``), ``'betas'`` and ``'trans'``.
        joints: First 22 joints of every frame, ``joints[frame, joint, xyz]``.
    """

    # Number of body joints the IK pipeline works with.
    _NUM_BODY_JOINTS = 22

    # SMPL output fields that COAP is handed for each frame/batch slice.
    _SMPL_OUTPUT_FIELDS = ('vertices', 'joints', 'full_pose', 'global_orient', 'body_pose', 'betas')

    class _AttrDict(dict):
        """dict whose items can also be read and written as attributes.

        COAP is passed these instead of a full SMPL output object; it
        presumably reads fields such as ``.vertices`` by attribute.
        """

        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name) from None

        def __setattr__(self, name, value):
            self[name] = value

    def __init__(
        self,
        joints: Optional[torch.Tensor] = None,
        smpl_params: Optional[Dict[str, torch.Tensor]] = None,
        smpl_model: Optional["smplx.SMPL"] = None,
        device: Optional[torch.device] = None
    ):
        """
        Initialize Motion object with joints or SMPL parameters.

        Args:
            joints: Joint positions ``joints[frame, joint, xyz]``. Only the
                first 22 joints are kept (a warning is emitted when more are
                given). Takes precedence over ``smpl_params``.
            smpl_params: Dict with ``'poses'``, ``'betas'`` and ``'trans'``.
            smpl_model: Body model to use; a neutral 10-beta SMPL model is
                created from ``config.SMPL_MODEL_DIR`` when omitted.
            device: Target device. Defaults to ``cuda:0`` when CUDA is
                available and to CPU otherwise.

        Raises:
            ValueError: If neither ``joints`` nor ``smpl_params`` is given.
        """
        if device is None:
            # Hard-coding cuda:0 made the class unusable on CPU-only machines.
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device

        if smpl_model is None:
            self.smpl_model = smplx.create(
                model_path=config.SMPL_MODEL_DIR,
                model_type='smpl',
                gender='neutral',
                use_face_contour=False,
                num_betas=10
            ).to(self.device)
        else:
            self.smpl_model = smpl_model.to(self.device)

        if joints is not None:
            self.joints = joints.to(self.device)
            num_joints = self.joints.shape[1]
            if num_joints > self._NUM_BODY_JOINTS:
                # The IK source model fits exactly 22 joints; extra ones (e.g.
                # the 45-joint SMPL output) are dropped.
                warnings.warn(
                    f"joints contains {num_joints} joints, "
                    f"only first {self._NUM_BODY_JOINTS} will be used"
                )
                self.joints = self.joints[:, :self._NUM_BODY_JOINTS, :]
            self.smpl_params = self.joints2smpl(self.joints)
        elif smpl_params is not None:
            self.smpl_params = {k: v.to(self.device) for k, v in smpl_params.items()}
            self.joints = self.smpl2joints()
        else:
            raise ValueError("Either 'joints' or 'smpl_params' must be provided")

    @classmethod
    def from_pkl(cls, pkl_path: str, device: Optional[torch.device] = None, save_back: bool = True):
        """
        Load Motion object from pickle file.

        If the pickle has ``'pose'`` and ``'betas'`` they are used as SMPL
        parameters; ``'trans'`` falls back to the root joint
        ``joints[:, 0, :]`` when it is absent. Otherwise SMPL parameters are
        fitted to ``'joints'`` by IK and, if ``save_back`` is true, written
        back into the same file under ``'pose'``, ``'betas'`` and ``'trans'``.

        Security: ``pickle.load`` can execute arbitrary code -- only load
        trusted files.
        """
        print(f"Loading Motion from pkl: {pkl_path}")

        # NOTE: untrusted pickles can run arbitrary code on load.
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)

        if 'pose' in data and 'betas' in data:
            print("Using existing smpl_params in pkl")
            # Only touch 'joints' when 'trans' is missing: dict.get() would
            # evaluate the fallback eagerly and raise KeyError for files that
            # have 'trans' but no 'joints'.
            trans = data['trans'] if 'trans' in data else data['joints'][:, 0, :]
            # Cast to float32 (like the joints branch below) so float64 arrays
            # do not clash with the model's float32 parameters.
            smpl_params = {
                'poses': torch.tensor(data['pose']).float(),
                'betas': torch.tensor(data['betas']).float(),
                'trans': torch.tensor(trans).float()
            }
            motion = cls(smpl_params=smpl_params, device=device)
        else:
            print("No smpl_params in pkl, running IK ...")
            joints = torch.tensor(data['joints']).float()
            motion = cls(joints=joints, device=device)

            if save_back:
                print("Saving computed smpl_params back to pkl")
                data['pose'] = motion.smpl_params['poses'].detach().cpu().numpy()
                data['betas'] = motion.smpl_params['betas'].detach().cpu().numpy()
                data['trans'] = motion.smpl_params['trans'].detach().cpu().numpy()
                with open(pkl_path, 'wb') as f:
                    pickle.dump(data, f)
                print(f"Saved smpl_params to: {pkl_path}")

        return motion

    def joints2smpl(self, joints: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Convert joint positions to SMPL parameters using inverse kinematics.

        Joints are made root-relative (joint 0) before fitting, and the root
        position is returned as the translation.

        Returns:
            ``{'poses', 'betas', 'trans'}``. ``poses`` concatenates the IK
            root orientation with a 69-value SMPL body pose whose first 63
            values come from the IK body pose and whose remaining 6 are zero.
        """
        ik_engine, source_pts = get_ik_engine(self.device)

        root_joint = joints[:, 0:1, :]
        rel_joints = joints - root_joint
        root = root_joint.squeeze(1)

        ik_res = ik_engine(source_pts, rel_joints, {})

        body_pose = ik_res['pose_body']
        betas = ik_res['betas']
        global_orient = ik_res['root_orient']

        # Pad the 63-value IK body pose to SMPL's 69-value body pose; the
        # extra joints stay at zero rotation. new_zeros keeps dtype/device.
        padded_body_pose = body_pose.new_zeros((*body_pose.shape[:-1], 69))
        padded_body_pose[..., :63] = body_pose

        poses = torch.cat([global_orient, padded_body_pose], dim=-1)

        # NOTE(review): SMPL's `transl` is added to a model whose root joint is
        # not at the origin in the template, so using the raw root position
        # here may offset the body -- confirm against the data pipeline.
        return {
            'poses': poses,
            'betas': betas,
            'trans': root
        }

    def smpl2joints(self) -> torch.Tensor:
        """Run the SMPL model on ``self.smpl_params`` and return its first 22 joints."""
        output = self.smpl_model(
            betas=self.smpl_params['betas'],
            body_pose=self.smpl_params['poses'][:, 3:],
            global_orient=self.smpl_params['poses'][:, :3],
            transl=self.smpl_params['trans']
        )

        return output.joints[:, :self._NUM_BODY_JOINTS, :]

    def _forward_smpl(self, smpl_params: Dict[str, torch.Tensor], return_full_pose: bool = False):
        """Run ``self.smpl_model`` on ``smpl_params`` without tracking gradients."""
        with torch.no_grad():
            return self.smpl_model(
                betas=smpl_params['betas'],
                body_pose=smpl_params['poses'][:, 3:],
                global_orient=smpl_params['poses'][:, :3],
                transl=smpl_params['trans'],
                return_verts=True,
                return_full_pose=return_full_pose
            )

    @classmethod
    def _slice_output(cls, output, start: int, end: int) -> dict:
        """Return frames ``[start:end)`` of an SMPL output as an attribute dict."""
        return cls._AttrDict(
            {field: getattr(output, field)[start:end] for field in cls._SMPL_OUTPUT_FIELDS}
        )

    def check_penetration(self, threshold: float = 0.01, batch_size: int = 8) -> Dict[str, float]:
        """Measure COAP self-penetration for every frame.

        Args:
            threshold: A frame counts as penetrating when its loss exceeds this.
            batch_size: Unused; frames are scored one at a time. Kept for
                interface compatibility.

        Returns:
            Dict with ``avg_penetration``, ``max_penetration``,
            ``penetration_rate``, ``penetration_frames``, ``total_frames`` and
            the flattened per-frame losses ``all_penetration_losses``. With
            zero frames all statistics are 0.
        """
        output = self._forward_smpl(self.smpl_params, return_full_pose=True)

        # One frame per COAP call keeps peak GPU memory low.
        loss_computer = COAPSelfPenetrationLoss(
            smpl_model=self.smpl_model,
            device=self.device,
            batch_size=1
        )

        total_frames = output.vertices.shape[0]
        frame_losses = []
        penetration_count = 0

        for i in range(total_frames):
            frame_output = self._slice_output(output, i, i + 1)
            penetration_loss = loss_computer.model.coap.self_collision_loss(frame_output, ret_samples=False)
            # Flatten so both 0-d and 1-element losses can be concatenated.
            frame_loss = penetration_loss.detach().cpu().reshape(-1)
            frame_losses.append(frame_loss)

            if bool((frame_loss > threshold).any()):
                penetration_count += 1

            torch.cuda.empty_cache()

        del loss_computer
        torch.cuda.empty_cache()

        if not frame_losses:
            all_penetration_losses = torch.zeros(0)
            return {
                'avg_penetration': 0.0,
                'max_penetration': 0.0,
                'penetration_rate': 0.0,
                'penetration_frames': 0,
                'total_frames': 0,
                'all_penetration_losses': all_penetration_losses
            }

        all_penetration_losses = torch.cat(frame_losses)
        return {
            'avg_penetration': all_penetration_losses.mean().item(),
            'max_penetration': all_penetration_losses.max().item(),
            'penetration_rate': penetration_count / total_frames,
            'penetration_frames': penetration_count,
            'total_frames': total_frames,
            'all_penetration_losses': all_penetration_losses
        }

    @staticmethod
    def _timing_stats(total_time: float, seq_len: int, **hyperparams) -> Dict[str, float]:
        """Build the timing summary returned by the ``optimize_*`` methods."""
        return {
            'total_optimization_time': total_time,
            'sequence_length': seq_len,
            'avg_time_per_frame': total_time / seq_len if seq_len > 0 else 0,
            'optimization_fps': seq_len / total_time if total_time > 0 else 0,
            **hyperparams
        }

    @staticmethod
    def _to_pose_tensor(poses, like: torch.Tensor, device) -> torch.Tensor:
        """Convert optimiser output to a detached tensor with ``like``'s dtype on ``device``."""
        # torch.tensor() on a tensor warns and copies, and numpy float64 output
        # would leak a dtype the SMPL model does not accept.
        return torch.as_tensor(poses, dtype=like.dtype, device=device).detach()

    def optimize_self_intersection(self, max_iters: int = 100, learning_rate: float = 0.001,
                                 lambda_pen: float = 1.0, lambda_pose: float = 0.01,
                                 batch_size: int = 32):
        """Reduce self-penetration by optimising the pose parameters.

        Replaces ``self.smpl_params['poses']`` with the optimised poses
        (``betas`` and ``trans`` are passed in but left unchanged) and
        refreshes ``self.joints``.

        Returns:
            Timing statistics plus the hyper-parameters used.
        """
        import time

        start_time = time.perf_counter()

        optimizer = SMPLPoseOptimizer(self.smpl_model, device=self.device)

        optimized_poses = optimizer.optimize_poses(
            poses=self.smpl_params['poses'],
            trans=self.smpl_params['trans'],
            betas=self.smpl_params['betas'],
            num_iterations=max_iters,
            learning_rate=learning_rate,
            lambda_pen=lambda_pen,
            lambda_pose=lambda_pose,
            batch_size=batch_size
        )

        total_time = time.perf_counter() - start_time

        self.smpl_params['poses'] = self._to_pose_tensor(
            optimized_poses, self.smpl_params['poses'], self.device
        )
        self.joints = self.smpl2joints()

        del optimizer
        torch.cuda.empty_cache()

        return self._timing_stats(
            total_time,
            self.smpl_params['poses'].shape[0],
            max_iterations=max_iters,
            learning_rate=learning_rate,
            lambda_pen=lambda_pen,
            lambda_pose=lambda_pose,
            batch_size=batch_size
        )

    def optimize_interaction(self, other_motion, max_iters: int = 100,
                            learning_rate: float = 0.001, 
                            lambda_pen: float = 1.0,
                            lambda_inter: float = 1.0,
                            lambda_pose: float = 0.01,
                            batch_size: int = 32):
        """Jointly optimise this motion and ``other_motion`` against collisions.

        Both motions' ``'poses'`` are replaced with the optimised values and
        their ``joints`` refreshed; ``betas`` and ``trans`` are unchanged.

        Returns:
            Timing statistics plus the hyper-parameters used.
        """
        import time

        start_time = time.perf_counter()

        optimizer = SMPLPoseOptimizer(self.smpl_model, device=self.device)

        optimized_poses1, optimized_poses2 = optimizer.optimize_poses_with_interaction(
            poses1=self.smpl_params['poses'],
            poses2=other_motion.smpl_params['poses'],
            trans1=self.smpl_params['trans'],
            trans2=other_motion.smpl_params['trans'],
            betas1=self.smpl_params['betas'],
            betas2=other_motion.smpl_params['betas'],
            num_iterations=max_iters,
            learning_rate=learning_rate,
            lambda_pen=lambda_pen,
            lambda_inter=lambda_inter,
            lambda_pose=lambda_pose,
            batch_size=batch_size
        )

        total_time = time.perf_counter() - start_time

        self.smpl_params['poses'] = self._to_pose_tensor(
            optimized_poses1, self.smpl_params['poses'], self.device
        )
        # other_motion.smpl2joints() runs on other_motion's own model/device.
        other_motion.smpl_params['poses'] = self._to_pose_tensor(
            optimized_poses2, other_motion.smpl_params['poses'], other_motion.device
        )

        self.joints = self.smpl2joints()
        other_motion.joints = other_motion.smpl2joints()

        del optimizer
        torch.cuda.empty_cache()

        return self._timing_stats(
            total_time,
            self.smpl_params['poses'].shape[0],
            max_iterations=max_iters,
            learning_rate=learning_rate,
            lambda_pen=lambda_pen,
            lambda_inter=lambda_inter,
            lambda_pose=lambda_pose,
            batch_size=batch_size
        )

    def plot_3d_motion(self, output_file: str = 'motion.gif'):
        """Animate the 22-joint skeleton with matplotlib and save it to ``output_file``."""
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older matplotlib)
        from matplotlib.animation import FuncAnimation

        # Parent -> child bone pairs over the 22 joint indices.
        connections = [
            (0, 1), (0, 2), (0, 3),  
            (1, 4), (4, 7), (7, 10),  
            (2, 5), (5, 8), (8, 11),  
            (3, 6), (6, 9), (9, 12),  
            (12, 13), (12, 14),  
            (13, 16), (16, 18), (18, 20),  
            (14, 17), (17, 19), (19, 21),  
        ]

        # +90 deg about X maps (x, y, z) -> (x, -z, y), so the data's Y axis
        # becomes matplotlib's vertical Z axis.
        theta = np.pi / 2
        rot_mat = np.array([
            [1, 0, 0],
            [0, np.cos(theta), -np.sin(theta)],
            [0, np.sin(theta), np.cos(theta)]
        ])
        # Convert once (detach: joints may carry autograd history).
        all_joints = self.joints.detach().cpu().numpy()

        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')

        def update(frame):
            ax.clear()

            joints_transformed = np.dot(all_joints[frame], rot_mat.T)

            ax.scatter(joints_transformed[:, 0], 
                      joints_transformed[:, 1], 
                      joints_transformed[:, 2], 
                      c='b', marker='o')

            for parent, child in connections:
                start = joints_transformed[parent]
                end = joints_transformed[child]
                ax.plot([start[0], end[0]], 
                       [start[1], end[1]], 
                       [start[2], end[2]], 'r-')

            ax.set_xlim(-2, 2)
            ax.set_ylim(-2, 2)
            ax.set_zlim(-2, 2)

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

            ax.set_title(f'Frame {frame}')

        anim = FuncAnimation(fig, update, frames=len(all_joints), interval=50)
        anim.save(output_file)
        plt.close()

    def _convert_zup_to_yup(self, vertices: torch.Tensor) -> torch.Tensor:
        """Rotate points -90 deg about X: ``(x, y, z) -> (x, z, -y)`` (Z-up to Y-up).

        Computed in torch, so the input's dtype, device and autograd graph
        are preserved.
        """
        theta = -np.pi / 2
        rot_mat = torch.tensor(
            [
                [1.0, 0.0, 0.0],
                [0.0, float(np.cos(theta)), float(-np.sin(theta))],
                [0.0, float(np.sin(theta)), float(np.cos(theta))]
            ],
            dtype=vertices.dtype,
            device=vertices.device
        )
        return vertices @ rot_mat.T

    def _is_zup(self, vertices: torch.Tensor) -> bool:
        """Heuristic: Z is 'up' when Z varies over 1.5x more than Y."""
        y_var = torch.var(vertices[..., 1])
        z_var = torch.var(vertices[..., 2])

        # bool(): honour the annotation instead of returning a 0-d tensor.
        return bool(z_var > y_var * 1.5)

    def _setup_rendering_imports(self):
        """Import the rendering dependencies lazily and return them as a tuple."""
        import pyrender
        import trimesh
        import imageio
        import numpy as np
        import os
        from tqdm import tqdm
        from collections import UserDict
        return pyrender, trimesh, imageio, np, os, tqdm, UserDict

    def _create_temp_directory(self, output_file):
        """Create (if needed) ``temp_frames`` next to ``output_file`` and return its path."""
        import os
        temp_dir = os.path.join(os.path.dirname(output_file), 'temp_frames')
        os.makedirs(temp_dir, exist_ok=True)
        return temp_dir

    def _get_smpl_output(self, return_full_pose=False):
        """Forward the SMPL model without gradients.

        Returns:
            ``(output, vertices, faces)`` where ``vertices`` is
            ``output.vertices`` as a NumPy array and ``faces`` is the model's
            face array.
        """
        output = self._forward_smpl(self.smpl_params, return_full_pose=return_full_pose)
        vertices = output.vertices.cpu().numpy()  # [N, 6890, 3]
        faces = self.smpl_model.faces
        return output, vertices, faces

    def _setup_ground(self, vertices):
        """Compute a ground slab under the motion's footprint.

        The ground sits 0.05 below the lowest vertex Y and extends 1.0 beyond
        the X/Z extent of all frames. Empty input gets a fixed 6x6 patch at
        Y = -0.5.

        Returns:
            ``(ground_y, x_min, x_max, z_min, z_max)``.
        """
        if len(vertices) == 0:
            return -0.5, -3.0, 3.0, -3.0, 3.0

        all_vertices = vertices.reshape(-1, 3)

        min_x = np.min(all_vertices[:, 0])
        max_x = np.max(all_vertices[:, 0])
        min_z = np.min(all_vertices[:, 2])
        max_z = np.max(all_vertices[:, 2])
        min_y = np.min(all_vertices[:, 1])

        ground_y = min_y - 0.05

        margin = 1.0
        ground_x_min = min_x - margin
        ground_x_max = max_x + margin
        ground_z_min = min_z - margin
        ground_z_max = max_z + margin

        print(f"Activity range analysis:")
        print(f"  X range: [{min_x:.3f}, {max_x:.3f}]")
        print(f"  Z range: [{min_z:.3f}, {max_z:.3f}]")
        print(f"  Y range: [{min_y:.3f}, {np.max(all_vertices[:, 1]):.3f}]")
        print(f"Ground:")
        print(f"  Y: {ground_y:.3f}")
        print(f"  XZ range: [{ground_x_min:.3f}, {ground_x_max:.3f}] x [{ground_z_min:.3f}, {ground_z_max:.3f}]")

        return ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max

    def _create_ground_mesh(self, ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max):
        """Build a flat grey 0.1-thick box covering the given X/Z range, centred at ``ground_y``."""
        import pyrender
        import trimesh

        ground_width = ground_x_max - ground_x_min
        ground_length = ground_z_max - ground_z_min
        ground_center_x = (ground_x_min + ground_x_max) / 2
        ground_center_z = (ground_z_min + ground_z_max) / 2

        plane = trimesh.creation.box(extents=(ground_width, 0.1, ground_length))
        plane.apply_translation((ground_center_x, ground_y, ground_center_z))

        plane.visual.vertex_colors = np.ones((len(plane.vertices), 4)) * [0.7, 0.7, 0.7, 1.0]

        return pyrender.Mesh.from_trimesh(plane, smooth=False)

    def _setup_scene_and_camera(self, ground_pyrender, camera_distance=3.0):
        """Create a scene with the ground, a camera on +Z and three lights.

        Returns:
            ``(scene, camera_pose)``.
        """
        import pyrender

        scene = pyrender.Scene(ambient_light=[0.4]*3, bg_color=[1.0]*3)
        scene.add(ground_pyrender)

        # Camera at height 0.9, `camera_distance` along +Z.
        camera = pyrender.PerspectiveCamera(yfov=np.pi / 4.0)
        camera_pose = np.eye(4)
        camera_pose[1, 3] = 0.9
        camera_pose[2, 3] = camera_distance
        scene.add(camera, pose=camera_pose)

        # Key light co-located with the camera.
        light = pyrender.DirectionalLight(color=[1.0]*3, intensity=1.5)
        scene.add(light, pose=camera_pose)

        # Weak overhead point light.
        point_light = pyrender.PointLight(color=[0.3]*3, intensity=0.3)
        point_pose = np.eye(4)
        point_pose[0, 3] = 0.0
        point_pose[1, 3] = 2.0
        point_pose[2, 3] = 0.0
        scene.add(point_light, pose=point_pose)

        # Slightly blue fill light from the side.
        fill_light = pyrender.DirectionalLight(color=[0.8, 0.8, 1.0], intensity=0.8)
        fill_pose = np.eye(4)
        fill_pose[0, 3] = 2.0
        fill_pose[1, 3] = 1.0
        fill_pose[2, 3] = 2.0
        scene.add(fill_light, pose=fill_pose)

        return scene, camera_pose

    def _clear_scene_meshes(self, scene):
        """Remove body meshes from ``scene`` while keeping the ground.

        A mesh whose first primitive has fewer than 100 vertex positions (the
        ground box) is kept; every other mesh node is removed.
        """
        import pyrender

        # Snapshot the nodes: the scene is modified inside the loop.
        for node in list(scene.get_nodes()):
            if not isinstance(node.mesh, pyrender.Mesh):
                continue
            primitives = getattr(node.mesh, 'primitives', None)
            if primitives:
                positions = getattr(primitives[0], 'positions', None)
                if positions is not None and len(positions) < 100:
                    continue
            scene.remove_node(node)

    def render_mesh_frames(self, output_file: str = 'motion.mp4'):
        """Render the SMPL mesh sequence at 1280x960 and save it as a 30 fps video."""
        pyrender, trimesh, imageio, np, os, tqdm, _ = self._setup_rendering_imports()

        temp_dir = self._create_temp_directory(output_file)

        _, vertices, faces = self._get_smpl_output()

        ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max = self._setup_ground(vertices)
        ground_pyrender = self._create_ground_mesh(ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max)
        scene, camera_pose = self._setup_scene_and_camera(ground_pyrender, camera_distance=3.0)

        # Identical for every frame, so build it once.
        material = pyrender.MetallicRoughnessMaterial(
            baseColorFactor=[0.3, 0.5, 0.8, 1.0],
            metallicFactor=0.2,
            roughnessFactor=0.8
        )

        r = pyrender.OffscreenRenderer(1280, 960)
        frames = []
        try:
            for i in tqdm(range(len(vertices)), desc="Rendering frames"):
                body = trimesh.Trimesh(vertices=vertices[i], faces=faces)
                body_mesh = pyrender.Mesh.from_trimesh(body, material=material)

                self._clear_scene_meshes(scene)
                scene.add(body_mesh)

                color, _ = r.render(scene)
                frames.append(color)
        finally:
            # Release the offscreen GL context even if rendering fails.
            r.delete()

        imageio.mimsave(output_file, frames, fps=30, quality=9)

        import shutil
        shutil.rmtree(temp_dir)
        print(f"Video saved to: {output_file}")

    def render_mesh_frames_with_collision(self, output_file: str = 'motion.mp4', collision_radius: float = 0.03, batch_size: int = 8, quality: str = 'high'):
        """Render the mesh sequence with COAP self-collision regions painted red.

        Args:
            output_file: Video written with ``imageio.mimsave`` at 30 fps.
            collision_radius: Vertices within this distance of a collision
                sample are colored red.
            batch_size: Passed to ``COAPSelfPenetrationLoss``; frames are
                still queried one at a time.
            quality: ``'low'`` (640x480), ``'medium'`` (960x720) or ``'high'``
                (1280x960); unknown values fall back to 1280x960.
        """
        pyrender, trimesh, imageio, np, os, tqdm, _ = self._setup_rendering_imports()

        quality_settings = {
            'low': (640, 480),
            'medium': (960, 720),
            'high': (1280, 960)
        }
        width, height = quality_settings.get(quality, (1280, 960))

        temp_dir = self._create_temp_directory(output_file)

        output, vertices, faces = self._get_smpl_output(return_full_pose=True)

        loss_computer = COAPSelfPenetrationLoss(
            smpl_model=self.smpl_model,
            device=self.device,
            batch_size=batch_size
        )

        ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max = self._setup_ground(vertices)
        ground_pyrender = self._create_ground_mesh(ground_y, ground_x_min, ground_x_max, ground_z_min, ground_z_max)
        scene, camera_pose = self._setup_scene_and_camera(ground_pyrender, camera_distance=4.0)

        r = pyrender.OffscreenRenderer(width, height)

        collision_frames = 0
        total_collision_points = 0
        frames = []
        try:
            for i in tqdm(range(len(vertices)), desc="Rendering frames"):
                frame_output = self._slice_output(output, i, i + 1)
                _, collision_samples = loss_computer.model.coap.self_collision_loss(frame_output, ret_samples=True)
                collision_samples = collision_samples[0]

                if collision_samples is not None and len(collision_samples) > 0:
                    collision_frames += 1
                    total_collision_points += len(collision_samples)
                    print(f"frame {i}: detected {len(collision_samples)} collision points")

                vertex_colors = self._get_collision_colors(vertices[i], collision_samples, collision_radius)

                mesh = trimesh.Trimesh(vertices=vertices[i], faces=faces)
                mesh.visual.vertex_colors = vertex_colors

                self._clear_scene_meshes(scene)
                scene.add(pyrender.Mesh.from_trimesh(mesh))

                color, _ = r.render(scene)
                frames.append(color)
        finally:
            # Release the offscreen GL context even if rendering fails.
            r.delete()

        num_frames = len(vertices)
        # Guard against division by zero for an empty sequence.
        collision_rate = collision_frames / num_frames * 100 if num_frames else 0.0

        print(f"Collision stats:")
        print(f"  total frames: {num_frames}")
        print(f"  frames with collisions: {collision_frames}")
        print(f"  total collision points: {total_collision_points}")
        print(f"  collision frame rate: {collision_rate:.1f}%")
        print(f"Render settings:")
        print(f"  resolution: {width}x{height}")
        print(f"  quality: {quality}")

        imageio.mimsave(output_file, frames, fps=30, quality=9)

        del loss_computer
        torch.cuda.empty_cache()
        import shutil
        shutil.rmtree(temp_dir)
        print(f"Video saved to: {output_file}")

    def _get_collision_colors(self, vertices, collision_samples, collision_radius=0.03):
        """Per-vertex RGBA colors: blue body, red within ``collision_radius`` of any sample.

        Args:
            vertices: ``(V, 3)`` NumPy array of one frame's vertices.
            collision_samples: ``(P, 3)`` tensor or array of collision points,
                or ``None``.
            collision_radius: Distance threshold for painting a vertex red.

        Returns:
            ``(V, 4)`` float array with alpha 1.
        """
        vertex_colors = np.ones((vertices.shape[0], 4))
        vertex_colors[:, :3] = [0.3, 0.5, 0.8]
        vertex_colors[:, 3] = 1.0

        if collision_samples is None or len(collision_samples) == 0:
            return vertex_colors

        if torch.is_tensor(collision_samples):
            # detach(): samples may carry autograd history from COAP.
            collision_points = collision_samples.detach().cpu().numpy()
        else:
            collision_points = np.asarray(collision_samples)

        for point in collision_points:
            dists = np.linalg.norm(vertices - point, axis=1)
            vertex_colors[dists < collision_radius, :3] = [1.0, 0.0, 0.0]

        return vertex_colors

    def export_meshes(self, output_dir: str):
        """Write one ``frame_XXXX.obj`` mesh per frame into ``output_dir``."""
        os.makedirs(output_dir, exist_ok=True)
        vertices = self.get_vertices().detach().cpu().numpy()
        faces = self.smpl_model.faces

        for i, frame_vertices in enumerate(vertices):
            mesh = trimesh.Trimesh(frame_vertices, faces)
            mesh.export(os.path.join(output_dir, f'frame_{i:04d}.obj'))

    def get_vertices(self, with_grad: bool = False) -> torch.Tensor:
        """Return SMPL vertices for all frames, tracking gradients only if ``with_grad``."""
        with torch.set_grad_enabled(with_grad):
            output = self.smpl_model(
                betas=self.smpl_params['betas'],
                body_pose=self.smpl_params['poses'][:, 3:],
                global_orient=self.smpl_params['poses'][:, :3],
                transl=self.smpl_params['trans']
            )
        return output.vertices

    def reset_smpl_params(self, smpl_params: Dict[str, torch.Tensor]) -> None:
        """Replace the SMPL parameters (moved to ``self.device``) and recompute ``joints``."""
        self.smpl_params = {k: v.to(self.device) for k, v in smpl_params.items()}
        self.joints = self.smpl2joints()

    def check_interaction(self, other_motion, threshold: float = 0.01, batch_size: int = 8) -> Dict[str, float]:
        """Measure COAP inter-person penetration between this motion and ``other_motion``.

        Both motions are posed with ``self.smpl_model``. For each batch, the
        collision loss of each body's vertices against the other body is
        averaged. Frame counts assume ``other_motion`` has at least as many
        frames (not checked).

        Args:
            other_motion: The second Motion.
            threshold: Loss entries above this count as interacting.
            batch_size: Frames per COAP call.

        Returns:
            Dict with ``avg_interaction``, ``max_interaction``,
            ``interaction_rate``, ``interaction_frames`` and ``total_frames``
            (all statistics 0 when there are no frames).
        """
        output1 = self._forward_smpl(self.smpl_params, return_full_pose=True)
        output2 = self._forward_smpl(other_motion.smpl_params, return_full_pose=True)

        loss_computer = COAPSelfPenetrationLoss(
            smpl_model=self.smpl_model,
            device=self.device,
            batch_size=batch_size
        )
        coap = loss_computer.model.coap

        total_frames = output1.vertices.shape[0]
        batch_losses = []
        interaction_count = 0

        for start in range(0, total_frames, batch_size):
            end = min(start + batch_size, total_frames)
            batch_output1 = self._slice_output(output1, start, end)
            batch_output2 = self._slice_output(output2, start, end)

            # Symmetrise: body 2's vertices vs body 1, and vice versa.
            inter_loss1, _ = coap.collision_loss(batch_output2.vertices, batch_output1)
            inter_loss2, _ = coap.collision_loss(batch_output1.vertices, batch_output2)
            # Flatten so 0-d losses can also be concatenated.
            inter_loss = ((inter_loss1 + inter_loss2) / 2).detach().cpu().reshape(-1)

            batch_losses.append(inter_loss)
            interaction_count += int((inter_loss > threshold).sum().item())

            torch.cuda.empty_cache()

        del loss_computer
        torch.cuda.empty_cache()

        if not batch_losses:
            return {
                'avg_interaction': 0.0,
                'max_interaction': 0.0,
                'interaction_rate': 0.0,
                'interaction_frames': 0,
                'total_frames': 0
            }

        all_interaction_losses = torch.cat(batch_losses)
        return {
            'avg_interaction': all_interaction_losses.mean().item(),
            'max_interaction': all_interaction_losses.max().item(),
            'interaction_rate': interaction_count / total_frames,
            'interaction_frames': interaction_count,
            'total_frames': total_frames
        }

    def export_meshes_with_collision_color(self, output_dir: str, collision_radius: float = 0.03, batch_size: int = 8):
        """Write one ASCII ``frame_XXXX.ply`` per frame with collision regions colored red."""
        os.makedirs(output_dir, exist_ok=True)

        output, vertices, faces = self._get_smpl_output(return_full_pose=True)

        loss_computer = COAPSelfPenetrationLoss(
            smpl_model=self.smpl_model,
            device=self.device,
            batch_size=batch_size
        )

        for i in range(vertices.shape[0]):
            frame_output = self._slice_output(output, i, i + 1)
            _, collision_samples = loss_computer.model.coap.self_collision_loss(frame_output, ret_samples=True)
            collision_samples = collision_samples[0]

            # Same coloring as the rendering path.
            vertex_colors = self._get_collision_colors(vertices[i], collision_samples, collision_radius)

            mesh = trimesh.Trimesh(vertices=vertices[i], faces=faces)
            mesh.visual.vertex_colors = vertex_colors
            mesh.export(os.path.join(output_dir, f'frame_{i:04d}.ply'), encoding='ascii')

        del loss_computer
        torch.cuda.empty_cache()