import numpy as np
import torch
import smplx
import trimesh
import os
from typing import Union, Optional, List, Dict
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance

class Motion:
    def __init__(
        self,
        joints: Optional[torch.Tensor] = None,
        smpl_params: Optional[Dict[str, torch.Tensor]] = None,
        smpl_model: Optional[smplx.SMPL] = None,
    ):
        """Motion class for handling SMPL-based motion data"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        if smpl_model is None:
            # Initialize SMPL model
            self.smpl_model = smplx.create(
                model_path='path/to/smpl/models',  # SMPL model path
                model_type='smpl',
                gender='neutral',
                use_face_contour=False,
                num_betas=10
            ).to(self.device)
        else:
            self.smpl_model = smpl_model
            
        if joints is not None:
            self.joints = joints.to(self.device)
            # Convert joints to SMPL parameters
            self.smpl_params = self.joints2smpl(joints)
        elif smpl_params is not None:
            self.smpl_params = {k: v.to(self.device) for k, v in smpl_params.items()}
            # Convert SMPL parameters to joints
            self.joints = self.smpl2joints()
        else:
            raise ValueError("Either 'joints' or 'smpl_params' must be provided")

    def joints2smpl(self, joints: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Convert joints to SMPL parameters"""
        # Convert joints to SMPL parameters (to be implemented)
        raise NotImplementedError("joints2smpl needs implementation")

    def smpl2joints(self) -> torch.Tensor:
        """Convert SMPL parameters to joints"""
        output = self.smpl_model(
            betas=self.smpl_params['betas'],
            body_pose=self.smpl_params['poses'][:, 3:],
            global_orient=self.smpl_params['poses'][:, :3]
        )
        return output.joints

    def check_penetration(self, threshold: float = 0.01) -> torch.Tensor:
        """Check for self-penetration in the motion"""
        vertices = self.get_vertices()
        penetration_mask = torch.zeros(len(vertices), dtype=torch.bool)
        
        for i in range(len(vertices)):
            mesh = trimesh.Trimesh(vertices[i].cpu().numpy(), self.smpl_model.faces)
            # Check if mesh is watertight and contains its centroid
            penetration_mask[i] = mesh.is_watertight and mesh.contains(mesh.centroid)
            
        return penetration_mask

    def optimize_self_intersection(self, max_iters: int = 100):
        """Optimize poses to reduce self-intersection"""
        # Initialize pose optimization
        poses = self.smpl_params['poses'].clone().requires_grad_(True)
        optimizer = torch.optim.Adam([poses], lr=0.01)
        
        for _ in range(max_iters):
            optimizer.zero_grad()
            
            # Forward pass through SMPL model
            output = self.smpl_model(
                betas=self.smpl_params['betas'],
                body_pose=poses[:, 3:],
                global_orient=poses[:, :3]
            )
            vertices = output.vertices
            
            # Compute self-intersection loss
            loss = self.compute_self_intersection_loss(vertices)
            
            loss.backward()
            optimizer.step()
            
        self.smpl_params['poses'] = poses.detach()
        self.joints = self.smpl2joints()

    def compute_self_intersection_loss(self, vertices: torch.Tensor) -> torch.Tensor:
        """Compute self-intersection loss using chamfer distance"""
        # Create meshes and sample points for loss computation
        meshes = Meshes(verts=vertices, faces=self.smpl_model.faces.expand(len(vertices), -1, -1))
        points = sample_points_from_meshes(meshes, num_samples=1000)
        loss, _ = chamfer_distance(points, points)
        return loss

    def plot_3d_motion(self, output_file: str = 'motion.gif'):
        """Plot 3D motion animation"""
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        from matplotlib.animation import FuncAnimation
        
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
        
        def update(frame):
            ax.clear()
            joints = self.joints[frame].cpu().numpy()
            ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2])
            
            
            
            ax.set_xlim(-2, 2)
            ax.set_ylim(-2, 2)
            ax.set_zlim(-2, 2)
            
        anim = FuncAnimation(fig, update, frames=len(self.joints), interval=50)
        anim.save(output_file)
        plt.close()

    def render_mesh_frames(self, output_dir: str):
        
        import pyrender
        import trimesh
        
        os.makedirs(output_dir, exist_ok=True)
        vertices = self.get_vertices()
        
        for i in range(len(vertices)):
            mesh = trimesh.Trimesh(vertices[i].cpu().numpy(), self.smpl_model.faces)
            scene = pyrender.Scene()
            scene.add(pyrender.Mesh.from_trimesh(mesh))
            
            
            camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
            light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
            
            scene.add(camera, pose=np.eye(4))
            scene.add(light, pose=np.eye(4))
            
            r = pyrender.OffscreenRenderer(640, 480)
            color, _ = r.render(scene)
            
            import imageio
            imageio.imsave(os.path.join(output_dir, f'frame_{i:04d}.png'), color)

    def export_meshes(self, output_dir: str):
        
        os.makedirs(output_dir, exist_ok=True)
        vertices = self.get_vertices()
        
        for i in range(len(vertices)):
            mesh = trimesh.Trimesh(vertices[i].cpu().numpy(), self.smpl_model.faces)
            mesh.export(os.path.join(output_dir, f'frame_{i:04d}.obj'))

    def get_vertices(self) -> torch.Tensor:
        
        output = self.smpl_model(
            betas=self.smpl_params['betas'],
            body_pose=self.smpl_params['poses'][:, 3:],
            global_orient=self.smpl_params['poses'][:, :3]
        )
        return output.vertices 