import io
import re
import trimesh
import objaverse
import numpy as np
import open3d as o3d
from typing import List
from scipy.spatial import distance
from PIL import Image, ImageDraw, ImageFont



def filter_uids(lvis_class=None, tag_filter=None, name_substring=None):
    """Filter Objaverse UIDs based on LVIS class, tags, or name substring.

    Filters combine with AND semantics; a filter left as ``None`` (or any
    other falsy value) is skipped.

    Args:
        lvis_class: LVIS category name. When given, candidates start from
            that category's UIDs (none if the category is unknown); otherwise
            they start from every UID in the Objaverse annotations.
        tag_filter: Iterable of tag names, or a single tag name as a ``str``.
            An object is kept if ANY of its tags' ``'name'`` matches exactly.
        name_substring: Case-insensitive substring that must occur in the
            object's ``'name'`` annotation.

    Returns:
        Sorted list of matching UIDs.
    """
    wanted_tags = None
    if tag_filter:
        # A bare string would otherwise be turned into a set of characters.
        wanted_tags = {tag_filter} if isinstance(tag_filter, str) else set(tag_filter)
    needle = name_substring.lower() if name_substring else None
    needs_metadata = wanted_tags is not None or needle is not None

    if lvis_class:
        uids = set(objaverse.load_lvis_annotations().get(lvis_class, []))
        # Only the candidates' metadata is consulted, and only if a metadata
        # filter is active, so avoid loading the full annotation set.
        annotations = (
            objaverse.load_annotations(list(uids)) if uids and needs_metadata else {}
        )
    else:
        annotations = objaverse.load_annotations()
        uids = set(annotations.keys())

    # Filter by tag (exact match of tag 'name')
    if wanted_tags is not None:
        uids = {
            uid for uid in uids
            if any(tag['name'] in wanted_tags for tag in annotations[uid].get('tags', []))
        }

    # Filter by name substring
    if needle is not None:
        uids = {uid for uid in uids if needle in annotations[uid]['name'].lower()}

    # Sorted so repeated calls give a reproducible ordering.
    return sorted(uids)

def load_objaverse_mesh(uid: str) -> trimesh.Trimesh:
    """Fetch one Objaverse object and return a single ``trimesh.Trimesh``.

    The file path returned by ``objaverse.load_objects`` is loaded with
    ``trimesh.load``. If that already yields a ``trimesh.Trimesh`` it is
    returned unchanged. Otherwise the first geometry in ``.geometry`` is
    returned; any other geometries in the scene are ignored.

    Raises:
        ValueError: if the loaded scene contains no geometry.
    """
    paths = objaverse.load_objects(uids=[uid])
    loaded = trimesh.load(paths[uid])
    if isinstance(loaded, trimesh.Trimesh):
        return loaded
    geometries = list(loaded.geometry.values())
    if not geometries:
        raise ValueError(f"Objaverse object {uid!r} contains no geometry")
    return geometries[0]

def load_objaverse_scene(uid: str) -> trimesh.Trimesh:
    """Fetch one Objaverse object and load its file with ``trimesh.load_mesh``.

    NOTE(review): the return annotation says ``trimesh.Trimesh``, but
    depending on the trimesh version ``load_mesh`` may hand back a
    ``trimesh.Scene`` for multi-geometry files (as this function's name
    suggests). Confirm which type callers rely on.
    """
    local_path = objaverse.load_objects(uids=[uid])[uid]
    return trimesh.load_mesh(local_path)

def objects_to_prompts(uids: List[str]) -> List[str]:
    """Build one comma-separated text prompt per Objaverse object.

    Each prompt is built from three parts, joined with ``', '``:
    the object's ``'name'`` annotation, the ``'name'`` of each of its tags,
    and the literal ``'3d asset'``.

    Args:
        uids: Objaverse object UIDs.

    Returns:
        Prompts in the same order as ``uids``, so they can be zipped back to
        their UIDs. UIDs for which ``objaverse.load_annotations`` returns no
        entry are skipped, so the list can be shorter than ``uids``.
    """
    # Materialize so the input can be passed to objaverse and iterated again.
    uids = list(uids)
    annotations = objaverse.load_annotations(uids)
    prompts = []
    for uid in uids:
        annotation = annotations.get(uid)
        if annotation is None:
            continue
        parts = [annotation['name']]
        parts.extend(tag['name'] for tag in annotation.get('tags', []))
        parts.append('3d asset')
        prompts.append(', '.join(parts))
    return prompts

def _vertex_diameter(points) -> float:
    """Return the largest pairwise Euclidean distance between rows of ``points`` (0.0 if < 2 rows)."""
    from scipy.spatial import ConvexHull  # local import: the module only imports `distance`

    if len(points) < 2:
        return 0.0
    try:
        # The farthest pair always lies on the convex hull, so comparing hull
        # vertices gives the exact same maximum without pdist's O(n^2) memory
        # over every vertex of a dense mesh.
        candidates = points[ConvexHull(points).vertices]
    except (RuntimeError, ValueError):
        # Qhull rejects degenerate input (e.g. coplanar or collinear points);
        # fall back to comparing all points.
        candidates = points
    return float(np.max(distance.pdist(candidates)))

def normalize_mesh(mesh: trimesh.Trimesh):
    """Return a new mesh whose largest vertex-to-vertex distance is 1.

    Only scaling is applied. Vertices are divided by the mesh diameter (the
    maximum pairwise distance between vertices) and are not re-centred. If
    the diameter is 0 (fewer than two vertices, or all vertices coincide),
    vertices are left unscaled instead of becoming NaN/inf.

    The result is a fresh ``trimesh.Trimesh`` built from vertices and faces
    only; no other attributes of ``mesh`` are copied.
    """
    vertices = mesh.vertices
    max_dist = _vertex_diameter(vertices)
    scale = max_dist if max_dist > 0 else 1.0
    return trimesh.Trimesh(vertices=vertices / scale, faces=mesh.faces)

def sample_to_point_cloud(mesh: trimesh.Trimesh, num_points: int=1000):
    """Return approximately evenly spaced points sampled from the surface of ``mesh``.

    This is a thin wrapper over ``trimesh.sample.sample_surface_even``. Only
    the point coordinates (first element of its result) are returned; the
    second element is discarded.
    NOTE(review): that sampler rejects points that are too close, so it may
    return fewer than ``num_points``. Confirm that callers tolerate this.
    """
    return trimesh.sample.sample_surface_even(mesh, num_points)[0]

def trimesh_to_pcd_vertices(trimesh_mesh: trimesh.Trimesh, nsamples: int=5000) -> o3d.geometry.PointCloud:
    """Wrap the vertices of ``trimesh_mesh`` in an Open3D point cloud.

    No sampling is performed: each mesh vertex becomes one point.
    ``nsamples`` is accepted only so the signature matches
    ``trimesh_to_pcd``; it is ignored.
    """
    vertex_points = o3d.utility.Vector3dVector(trimesh_mesh.vertices)
    cloud = o3d.geometry.PointCloud()
    cloud.points = vertex_points
    return cloud

def trimesh_to_pcd(mesh: trimesh.Trimesh, nsamples: int=5000) -> o3d.geometry.PointCloud:
    """Sample ``nsamples`` surface points from a mesh into an Open3D point cloud.

    Args:
        mesh: Either a ``trimesh.Trimesh`` (sampled with ``Trimesh.sample``)
            or an ``o3d.geometry.TriangleMesh`` (sampled with
            ``sample_points_uniformly``).
        nsamples: Number of points to sample.

    Raises:
        TypeError: if ``mesh`` is neither supported type. Previously this
            case silently produced an empty point cloud.
    """
    if isinstance(mesh, trimesh.Trimesh):
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(mesh.sample(nsamples))
        return pcd
    if isinstance(mesh, o3d.geometry.TriangleMesh):
        return mesh.sample_points_uniformly(number_of_points=nsamples)
    raise TypeError(
        "trimesh_to_pcd expects a trimesh.Trimesh or o3d.geometry.TriangleMesh, "
        f"got {type(mesh).__name__}"
    )

def trimesh_to_o3d(trimesh_mesh: trimesh.Trimesh) -> o3d.geometry.TriangleMesh:
    """Convert a trimesh mesh into an Open3D ``TriangleMesh`` (geometry only).

    Only vertex positions and triangle indices are copied. Normals, colours,
    UVs and materials are not transferred.
    """
    converted = o3d.geometry.TriangleMesh()
    positions = o3d.utility.Vector3dVector(trimesh_mesh.vertices)
    converted.vertices = positions
    triangle_indices = o3d.utility.Vector3iVector(trimesh_mesh.faces)
    converted.triangles = triangle_indices
    return converted

def preprocess(mesh: trimesh.Trimesh, nsamples: int=5000, verbose=False):
    """Scale-normalize ``mesh`` and return its vertices as an Open3D point cloud.

    Runs ``normalize_mesh`` and then ``trimesh_to_pcd_vertices``, which
    ignores ``nsamples``. The cloud therefore holds exactly the normalized
    mesh's vertices rather than surface samples.
    NOTE(review): the verbose label says "Sampled points", but no sampling
    happens here. Confirm whether ``trimesh_to_pcd`` was intended.

    When ``verbose`` is truthy, the normalized vertices and the resulting
    point cloud are printed.
    """
    normalized = normalize_mesh(mesh)
    cloud = trimesh_to_pcd_vertices(normalized, nsamples)
    if not verbose:
        return cloud
    print("Normalized mesh vertices:\n", normalized.vertices)
    print("Sampled points from normalized mesh:\n", cloud)
    return cloud

class MeshRenderer:
    """Render 3D meshes to 2D images for visualization.

    ``render_mesh_to_image`` tries three approaches in order:
    matplotlib's 3D toolkit, then trimesh's ``Scene.save_image``, then a
    generated placeholder image. Rendering errors are printed rather than
    raised.
    """

    def __init__(self, image_size=(200, 200), background_color=(1, 1, 1, 1)):
        # (width, height) in pixels: the matplotlib resize target, the trimesh
        # render resolution and the placeholder canvas size.
        self.image_size = image_size
        # Stored for callers; no method of this class reads it. The matplotlib
        # and placeholder paths hard-code a white background.
        self.background_color = background_color

    def render_mesh_to_image(self, mesh: trimesh.Trimesh, uid: str = None) -> Image.Image:
        """Render ``mesh`` to an RGB PIL image, degrading gracefully on failure.

        The mesh is copied, shifted by its centroid and scaled so that its
        farthest vertex lies at distance 1. The caller's mesh is not modified.

        Args:
            mesh: Mesh to render.
            uid: Optional identifier used in log messages and drawn on the
                placeholder image.

        Returns:
            The first of these that succeeds: the matplotlib rendering, the
            trimesh offscreen rendering, or ``create_fallback_image(uid)``.
        """
        try:
            # Normalize a copy so the caller's mesh is left untouched.
            mesh = mesh.copy()
            mesh.vertices -= mesh.centroid
            scale = np.max(np.linalg.norm(mesh.vertices, axis=1))
            if scale > 0:
                mesh.vertices /= scale

            try:
                return self.render_with_matplotlib(mesh, uid)
            except Exception as e:
                print(f"Matplotlib rendering failed for {uid}: {e}")

            # Fallback: trimesh scene rendering with default camera and lighting.
            try:
                scene = trimesh.Scene([mesh])
                png_data = scene.save_image(resolution=self.image_size)
                if png_data is not None:
                    return Image.open(io.BytesIO(png_data)).convert('RGB')
            except Exception as e:
                print(f"Trimesh rendering failed for {uid}: {e}")

            return self.create_fallback_image(uid)

        except Exception as e:
            print(f"Error rendering mesh {uid}: {e}")
            return self.create_fallback_image(uid)

    def render_with_matplotlib(self, mesh: trimesh.Trimesh, uid: str = None) -> Image.Image:
        """Render ``mesh`` with matplotlib's 3D toolkit and return an RGB image.

        Each face is drawn as a translucent light-blue polygon with thin black
        edges. The view uses elevation 30 and azimuth 45, with ticks, grid and
        pane fills hidden. The PNG is resized to ``self.image_size``.

        ``uid`` is accepted for signature parity and is not used. Errors from
        matplotlib or PIL propagate; ``render_mesh_to_image`` catches them.
        """
        # Imported lazily so matplotlib is only needed when this path runs.
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection

        vertices = mesh.vertices
        faces = mesh.faces

        fig = plt.figure(figsize=(4, 4), dpi=50)
        try:
            ax = fig.add_subplot(111, projection='3d')

            # Fancy indexing gathers every triangle's corners at once:
            # one (3, 3) array per face.
            poly3d = Poly3DCollection(list(vertices[faces]), alpha=0.8, facecolor='lightblue',
                                      edgecolor='black', linewidth=0.1)
            ax.add_collection3d(poly3d)

            # A cube-shaped view volume around the bounding-box centre keeps
            # the aspect ratio equal on all three axes.
            lo = vertices.min(axis=0)
            hi = vertices.max(axis=0)
            half_range = (hi - lo).max() / 2.0
            mid = (hi + lo) * 0.5
            ax.set_xlim(mid[0] - half_range, mid[0] + half_range)
            ax.set_ylim(mid[1] - half_range, mid[1] + half_range)
            ax.set_zlim(mid[2] - half_range, mid[2] + half_range)

            # Remove axes decorations for a cleaner look.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_zticks([])
            ax.grid(False)

            ax.view_init(elev=30, azim=45)  # isometric-like viewpoint

            for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
                axis.pane.fill = False
                axis.pane.set_edgecolor('white')

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0,
                        facecolor='white', dpi=50)
        finally:
            # Release the figure even if drawing or saving fails, so repeated
            # renders do not accumulate open pyplot figures.
            plt.close(fig)

        buf.seek(0)
        image = Image.open(buf)
        image = image.resize(self.image_size, Image.Resampling.LANCZOS)
        return image.convert('RGB')

    def create_fallback_image(self, uid: str = None) -> Image.Image:
        """Return a placeholder image for when every renderer fails.

        The image is a white ``image_size`` RGB canvas with a light-grey
        circle centred on it. The circle's radius is a quarter of the shorter
        side. If ``uid`` is given, it is written in the top-left corner,
        truncated to 8 characters plus ``"..."`` when longer.
        """
        img = Image.new('RGB', self.image_size, color='white')
        draw = ImageDraw.Draw(img)

        center_x, center_y = self.image_size[0] // 2, self.image_size[1] // 2
        radius = min(self.image_size) // 4
        draw.ellipse([
            center_x - radius, center_y - radius,
            center_x + radius, center_y + radius
        ], fill='lightgray', outline='gray')

        if uid:
            try:
                font = ImageFont.load_default()
                text = uid[:8] + "..." if len(uid) > 8 else uid
                draw.text((10, 10), text, fill='black', font=font)
            except Exception:
                # The label is decorative: a font or drawing failure must not
                # prevent returning the placeholder.
                pass

        return img