import trimesh
import numpy as np
import os
import argparse
import logging
import sys
from pathlib import Path


def _geometry_world_transform(scene, geometry_name):
    """Return the 4x4 scene-frame transform for ``geometry_name``, or None.

    ``scene.graph`` is indexed by *node* name, and a node does not have to
    share its geometry's name (glTF files commonly differ). So the nodes that
    instance the geometry are resolved via ``graph.geometry_nodes`` instead.
    When several nodes instance it, the first one is used so that each
    geometry is still handled exactly once. Returns None when no node
    references the geometry.
    """
    nodes = scene.graph.geometry_nodes.get(geometry_name)
    if not nodes:
        return None
    # Indexing the graph by node name yields (matrix, geometry_name).
    transform, _ = scene.graph[nodes[0]]
    return transform

def calculate_global_normalization_params(scene):
    """Calculate global normalization parameters for the entire scene.

    Every geometry's vertices are moved into the scene frame (using the
    transform of a node that instances the geometry, when one exists). They
    are then pooled, so that all parts share one centroid and one scale.

    Parameters
    ----------
    scene : trimesh.Scene
        Scene whose ``geometry`` mapping and ``graph`` are read (not modified).

    Returns
    -------
    centroid : numpy.ndarray
        Mean of all pooled vertices, shape (3,). All zeros when the scene has
        no vertices at all.
    scale : float
        Largest distance from ``centroid`` to any pooled vertex. It is 1.0
        when there are no vertices or they all coincide, so dividing by it is
        always safe.
    """
    all_vertices = []

    for name, geometry in scene.geometry.items():
        vertices = np.asarray(geometry.vertices, dtype=np.float64)
        if len(vertices) == 0:
            # Contributes nothing, and would break the mean/max reductions
            # below if every geometry were empty.
            continue
        transform = _geometry_world_transform(scene, name)
        if transform is not None:
            vertices = trimesh.transformations.transform_points(vertices, transform)
        all_vertices.append(vertices)

    if not all_vertices:
        return np.zeros(3), 1.0

    combined_vertices = np.vstack(all_vertices)

    centroid = combined_vertices.mean(axis=0)
    scale = float(np.max(np.linalg.norm(combined_vertices - centroid, axis=1)))

    if scale == 0:
        scale = 1.0

    return centroid, scale

def apply_normalization_params(mesh, centroid, scale):
    """Translate and uniformly scale ``mesh`` in place with precomputed values.

    Each vertex ``v`` becomes ``(v - centroid) / scale``. The same mesh
    object is returned for convenience.
    """
    mesh.vertices = (mesh.vertices - centroid) / scale
    return mesh

def apply_rotation(mesh, rotation):
    """Rotate ``mesh`` in place by the 3x3 matrix ``rotation``.

    Vertices are stored as rows, so right-multiplying by ``rotation.T`` maps
    each row ``v`` to ``rotation @ v``. The same mesh object is returned.
    """
    mesh.vertices = np.dot(mesh.vertices, rotation.T)
    return mesh

def process_mesh(input_path, output_dir, rotation_matrix=None):
    """Split a mesh/scene file into per-geometry OBJ files with shared normalization.

    The input is loaded as a scene, and one centroid/scale pair is computed
    over all geometries, so the exported parts keep their relative placement.
    Each geometry is then:

    1. moved into the scene frame with the same node transform used for the
       statistics;
    2. normalized;
    3. optionally rotated;
    4. written to ``output_dir/obj_<i>.obj``.

    Parameters
    ----------
    input_path : str or path-like
        File passed to ``trimesh.load``.
    output_dir : str or path-like
        Destination directory. It is created if missing.
    rotation_matrix : numpy.ndarray or None
        Optional 3x3 rotation applied after normalization.

    Returns
    -------
    bool
        True if every geometry was exported. False if any error occurred; the
        error is logged together with its traceback.
    """
    try:
        print("-" * 100)
        print(f"Processing: {input_path}")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        mtl_path = output_dir / "material.mtl"

        # Load mesh/scene
        scene = trimesh.load(input_path, force='scene')

        # Calculate global normalization parameters BEFORE splitting
        global_centroid, global_scale = calculate_global_normalization_params(scene)
        print(f"Global centroid: {global_centroid}")
        print(f"Global scale: {global_scale}")

        for i, (name, geometry) in enumerate(scene.geometry.items()):
            mesh = geometry.copy()

            # Must mirror calculate_global_normalization_params exactly;
            # otherwise the shared centroid/scale would not fit these vertices.
            transform = _geometry_world_transform(scene, name)
            if transform is not None:
                mesh.vertices = trimesh.transformations.transform_points(mesh.vertices, transform)

            mesh = apply_normalization_params(mesh, global_centroid, global_scale)

            if rotation_matrix is not None:
                mesh = apply_rotation(mesh, rotation_matrix)

            output_path = output_dir / f"obj_{i}.obj"
            mesh.export(output_path)
            # A material.mtl may be left in output_dir by the export; it is
            # deliberately removed so only the OBJ files remain.
            if mtl_path.exists():
                mtl_path.unlink()
            logging.info(f"Exported: {output_path}")

        return True
    except Exception as e:
        # Top-level boundary for a single file: report it (with traceback)
        # and let the caller decide whether to continue.
        logging.exception("Processing failed for %s: %s", input_path, e)
        return False

def process_directory(input_dir, output_dir, rotation_matrix=None,
                      extensions=('.obj', '.ply', '.glb')):
    """Process every supported mesh file directly inside ``input_dir``.

    A file ``name.ext`` is written to its own subdirectory
    ``output_dir/name`` via ``process_mesh``.

    - Files are visited in sorted order, so runs are reproducible.
    - The extension match is case-insensitive, so ``MODEL.OBJ`` is included.
    - Subdirectories are ignored even if their name ends in a mesh extension.
    - A failing file is logged and does not stop the loop.

    NOTE: files that differ only by extension (``a.obj`` and ``a.ply``) map to
    the same output subdirectory and overwrite each other's outputs.

    Parameters
    ----------
    input_dir : str or path-like
        Directory to scan (non-recursively).
    output_dir : str or path-like
        Parent directory for the per-file output subdirectories.
    rotation_matrix : numpy.ndarray or None
        Forwarded unchanged to ``process_mesh``.
    extensions : iterable of str
        Accepted file suffixes, compared case-insensitively.

    Returns
    -------
    bool
        True if every matching file was processed successfully (or there
        were none), otherwise False.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    all_ok = True
    for filename in sorted(os.listdir(input_dir)):
        input_path = os.path.join(input_dir, filename)
        if not filename.lower().endswith(suffixes) or not os.path.isfile(input_path):
            continue
        mesh_output_dir = os.path.join(output_dir, os.path.splitext(filename)[0])
        os.makedirs(mesh_output_dir, exist_ok=True)
        success = process_mesh(input_path, mesh_output_dir, rotation_matrix)
        if not success:
            logging.error(f"Failed to process: {input_path}")
            all_ok = False
    return all_ok

def main():
    """Command-line entry point: parse arguments, then process a file or a directory.

    A single input file exits with status 0 on success and 1 on failure.
    A directory input exits with status 1 only if the directory processor
    explicitly returns False, and 0 otherwise. Invalid or missing arguments
    make argparse print a usage message and exit with status 2.
    """
    # Configure logging
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Setup argument parser. Both paths are mandatory: without them,
    # os.makedirs(None) / os.path.isdir(None) would fail with a TypeError
    # instead of a readable usage error.
    parser = argparse.ArgumentParser(description='Split and normalize 3D meshes')
    parser.add_argument('--input', required=True,
                        help='Input mesh file (obj/ply/glb) or a directory of them')
    parser.add_argument('--output_dir', required=True, help='Output directory')
    parser.add_argument('--rotation', nargs=3, type=float, metavar=('X', 'Y', 'Z'),
                        help='Rotation angles in degrees (XYZ order)')
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare rotation matrix (upper-left 3x3 of the 4x4 homogeneous matrix)
    rot_matrix = None
    if args.rotation is not None:
        angles_rad = np.radians(args.rotation)
        # 'sxyz' (static frame, X then Y then Z) is euler_matrix's default;
        # it is spelled out here to match the "XYZ order" help text.
        rot_matrix = trimesh.transformations.euler_matrix(*angles_rad, axes='sxyz')[:3, :3]

    # Process input (file or directory)
    if os.path.isdir(args.input):
        result = process_directory(args.input, args.output_dir, rot_matrix)
        # Only an explicit False counts as failure; None (no status
        # reported) is treated as success.
        sys.exit(1 if result is False else 0)

    success = process_mesh(args.input, args.output_dir, rot_matrix)
    sys.exit(0 if success else 1)


# Script entry point: the CLI runs only when this file is executed directly,
# never when it is imported as a module.
if __name__ == "__main__":
    main()