import os
import trimesh
import numpy as np
import re
import json
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

class PointCloudProcessor:
    def __init__(self, main_partnet_dir, output_base_dir, total_points=2048, precision=5):
        self.main_partnet_dir = main_partnet_dir
        self.output_base_dir = output_base_dir
        self.total_points_to_sample = total_points
        self.json_float_precision = precision
        
        self.output_gt_dir = os.path.join(output_base_dir, "point_clouds")
        self.output_bbox_dir_base = os.path.join(output_base_dir, "bboxes")
        self.output_json_dir = os.path.join(output_base_dir, "json_questions")

        os.makedirs(self.output_gt_dir, exist_ok=True)
        os.makedirs(self.output_bbox_dir_base, exist_ok=True)
        os.makedirs(self.output_json_dir, exist_ok=True)
        
        self.global_semantic_to_id, self.global_semantic_names = self._get_global_semantic_mapping(main_partnet_dir)
        self.num_semantic_classes = len(self.global_semantic_names) if self.global_semantic_names else 0

    def _get_global_semantic_mapping(self, partnet_mobility_dataset_path):
        semantics_files = []
        for item in os.listdir(partnet_mobility_dataset_path):
            item_path = os.path.join(partnet_mobility_dataset_path, item)
            if os.path.isdir(item_path) and item.isdigit():
                semantics_file = os.path.join(item_path, "semantics.txt")
                if os.path.isfile(semantics_file):
                    semantics_files.append(semantics_file)
        
        all_semantic_names = set()
        for file_path in semantics_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 3:
                        semantic_name = parts[2]
                        if semantic_name and semantic_name.lower() != 'none':
                            all_semantic_names.add(semantic_name)

        if not all_semantic_names:
            return None, None
            
        unique_semantic_names = sorted(list(all_semantic_names))
        semantic_to_global_id = {name: i for i, name in enumerate(unique_semantic_names)}
        return semantic_to_global_id, unique_semantic_names

    def _create_object_link_to_global_id_mapping(self, semantics_path):
        if not os.path.exists(semantics_path) or self.global_semantic_to_id is None:
            return None
        link_to_global_id = {}
        with open(semantics_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 3:
                    link_name, _, semantic_name = parts[0], parts[1], parts[2]
                    if semantic_name in self.global_semantic_to_id:
                        link_to_global_id[link_name] = self.global_semantic_to_id[semantic_name]
        return link_to_global_id

    def _get_link_semantic_descriptions(self, semantics_path):
        if not os.path.exists(semantics_path):
            return None, None
        link_id_to_semantic_name = {}
        link_name_to_semantic_name = {}
        pattern = r'^link_(\d+)$'
        with open(semantics_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 3:
                    link_name, _, semantic_name = parts[0], parts[1], parts[2]
                    match = re.match(pattern, link_name)
                    if match:
                        numeric_link_id = int(match.group(1))
                        link_id_to_semantic_name[numeric_link_id] = semantic_name
                        link_name_to_semantic_name[link_name] = semantic_name
                    elif link_name == 'base':
                        link_name_to_semantic_name[link_name] = semantic_name
        return link_id_to_semantic_name, link_name_to_semantic_name

    def _parse_urdf_for_joints(self, urdf_path):
        if not os.path.exists(urdf_path):
            return None
        joints_data = []
        tree = ET.parse(urdf_path)
        root = tree.getroot()
        for joint in root.findall('joint'):
            joint_info = {
                'id': joint.get('name'),
                'type': joint.get('type'),
                'parent': joint.find('parent').get('link') if joint.find('parent') is not None else None,
                'child': joint.find('child').get('link') if joint.find('child') is not None else None
            }
            origin = joint.find('origin')
            if origin is not None:
                xyz_str = origin.get('xyz', "0.0 0.0 0.0")
                rpy_str = origin.get('rpy', "0.0 0.0 0.0")
                joint_info['origin'] = {'xyz': [float(x) for x in xyz_str.split()], 'rpy': [float(r) for r in rpy_str.split()]}
            else:
                joint_info['origin'] = {'xyz': [0.0, 0.0, 0.0], 'rpy': [0.0, 0.0, 0.0]}
            
            axis = joint.find('axis')
            if axis is not None:
                joint_info['axis'] = [float(a) for a in axis.get('xyz', "1.0 0.0 0.0").split()]
            elif joint_info['type'] != 'fixed':
                joint_info['axis'] = [1.0, 0.0, 0.0]
            else:
                joint_info['axis'] = [0.0, 0.0, 0.0]
            joints_data.append(joint_info)
        return joints_data

    def _get_object_name_from_result_json(self, result_json_path):
        if not os.path.exists(result_json_path):
            return "UnknownObject"
        with open(result_json_path, 'r') as f:
            data = json.load(f)
        obj_data = data[0] if isinstance(data, list) and data else data
        if isinstance(obj_data, dict):
            return obj_data.get('text', obj_data.get('name', "UnknownObject"))
        return "UnknownObject"

    def _round_floats(self, o, precision):
        if isinstance(o, float): return round(o, precision)
        if isinstance(o, dict): return {k: self._round_floats(v, precision) for k, v in o.items()}
        if isinstance(o, (list, tuple)): return [self._round_floats(x, precision) for x in o]
        if isinstance(o, (np.float32, np.float64)): return round(float(o), precision)
        return o

    def _map_labels_to_colors(self, labels, cmap_name='viridis'):
        if not isinstance(labels, np.ndarray): labels = np.array(labels)
        if labels.ndim != 1 or self.num_semantic_classes <= 0 or len(labels) == 0:
            return np.empty((0, 4), dtype=np.uint8)
        
        cmap = plt.get_cmap(cmap_name)
        if self.num_semantic_classes == 1:
            class_colors_float = cmap(0.5) 
        else:
            class_colors_float = cmap(np.linspace(0, 1, self.num_semantic_classes))
        class_colors_uint8 = (np.array(class_colors_float) * 255).astype(np.uint8)
        if class_colors_uint8.ndim == 1: class_colors_uint8 = class_colors_uint8.reshape(1, 4)
            
        mapped_colors = np.full((len(labels), 4), [128, 128, 128, 255], dtype=np.uint8)
        valid_mask = (labels >= 0) & (labels < self.num_semantic_classes)
        valid_labels = labels[valid_mask].astype(int)
        mapped_colors[valid_mask] = class_colors_uint8[valid_labels]
        return mapped_colors

    def sample_points_and_labels(self, obj_subdirectory):
        target_dir = os.path.join(self.main_partnet_dir, obj_subdirectory)
        semantics_path = os.path.join(target_dir, "semantics.txt")
        link_to_global_id_mapping = self._create_object_link_to_global_id_mapping(semantics_path)
        
        if not os.path.isdir(target_dir) or link_to_global_id_mapping is None:
            return np.empty((0, 3)), np.empty((0,), dtype=int), {}

        mesh_data, total_area = [], 0.0
        pattern = r'^(link_\d+)_combined_mesh\.obj$'
        for file in os.listdir(target_dir):
            match = re.fullmatch(pattern, file)
            if match:
                link_name = match.group(1)
                if link_name in link_to_global_id_mapping:
                    file_path = os.path.join(target_dir, file)
                    mesh = trimesh.load(file_path, force='mesh', process=True)
                    if isinstance(mesh, trimesh.Scene):
                         mesh = trimesh.util.concatenate([m for m in mesh.geometry.values() if isinstance(m, trimesh.Trimesh)])
                    if isinstance(mesh, trimesh.Trimesh) and mesh.vertices.size > 0 and mesh.faces.size > 0 and mesh.area > 1e-6:
                         mesh_data.append({'mesh': mesh, 'area': mesh.area, 'global_semantic_id': link_to_global_id_mapping[link_name], 'link_name': link_name})
                         total_area += mesh.area

        if not mesh_data: return np.empty((0, 3)), np.empty((0,), dtype=int), {}
        
        if total_area <= 1e-6:
            points_per_mesh = self.total_points_to_sample // len(mesh_data)
            for i, data in enumerate(mesh_data): data['points_to_sample'] = points_per_mesh + (1 if i < self.total_points_to_sample % len(mesh_data) else 0)
        else:
            running_total = 0
            for data in mesh_data:
                points = int(round((data['area'] / total_area) * self.total_points_to_sample))
                data['points_to_sample'] = points
                running_total += points
            
            difference = self.total_points_to_sample - running_total
            mesh_data.sort(key=lambda x: x['area'], reverse=(difference > 0))
            for i in range(abs(difference)):
                mesh_data[i % len(mesh_data)]['points_to_sample'] += 1 if difference > 0 else -1

        all_sampled_points, all_labels = [], []
        bbox_info = {}
        for data in mesh_data:
            if data['points_to_sample'] > 0:
                points, _ = trimesh.sample.sample_surface_even(data['mesh'], data['points_to_sample'])
                if points.shape[0] > 0:
                    all_sampled_points.append(points)
                    all_labels.append(np.full((points.shape[0],), data['global_semantic_id'], dtype=int))
                    link_id = int(data['link_name'].split('_')[1])
                    bbox_info[link_id] = {"bounds": data['mesh'].bounds.tolist()}

        if not all_sampled_points: return np.empty((0, 3)), np.empty((0,), dtype=int), {}
            
        final_points = np.vstack(all_sampled_points)
        final_labels = np.concatenate(all_labels)
        
        current_total = final_points.shape[0]
        if current_total != self.total_points_to_sample and current_total > 0:
            if current_total > self.total_points_to_sample:
                indices = np.random.choice(current_total, self.total_points_to_sample, replace=False)
                final_points, final_labels = final_points[indices], final_labels[indices]
            else:
                num_needed = self.total_points_to_sample - current_total
                indices_to_duplicate = np.random.choice(current_total, num_needed, replace=True)
                final_points = np.vstack((final_points, final_points[indices_to_duplicate]))
                final_labels = np.concatenate((final_labels, final_labels[indices_to_duplicate]))
                shuffle = np.arange(self.total_points_to_sample); np.random.shuffle(shuffle)
                final_points, final_labels = final_points[shuffle], final_labels[shuffle]

        return final_points, final_labels, bbox_info

    def generate_and_save_json_variations(self, obj_subdirectory, object_name, link_name_to_semantic, joints_data, bbox_info):
        num_parts = len([k for k in bbox_info.keys() if k != 'base'])
        joints_data_rounded = self._round_floats(joints_data, self.json_float_precision)
        
        links_answer_seg = {}
        sorted_link_names = sorted(link_name_to_semantic.keys(), key=lambda x: (-1 if x == 'base' else int(x.split('_')[-1])))
        for link_name in sorted_link_names:
            if link_name.startswith("link_"):
                links_answer_seg[link_name] = f"{link_name_to_semantic.get(link_name, link_name)}[SEG]"

        def save_json(data, dirname, filename_suffix=""):
            dirpath = os.path.join(self.output_json_dir, dirname)
            os.makedirs(dirpath, exist_ok=True)
            filename = f"{obj_subdirectory}{'_' + filename_suffix if filename_suffix else ''}.json"
            filepath = os.path.join(dirpath, filename)
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(self._round_floats(data, self.json_float_precision), f, indent=4, ensure_ascii=False)

        point_cloud_all = {name: link_name_to_semantic.get(name, name) for name in sorted_link_names}
        bbox_question_parts = [f"{f'link_{lid}'} with bbox {self._round_floats(bbox_info[lid].get('bounds'), 2)}" for lid in sorted(bbox_info.keys())]
        bbox_result_string = ", ".join(bbox_question_parts) + "."
        
        q_bbox = f"This articulated object {object_name} consists of {num_parts} parts. {bbox_result_string} Predict all joint parameters in JSON format, including type, origin, axis, parent, and child."
        q_nobbox = f"This articulated object {object_name} consists of {num_parts} parts. Predict all joint parameters in JSON format, including type, origin, axis, parent, and child."
        q_all_seg_bbox = f"{q_bbox} Segment each link in JSON format."
        q_all_seg_nobbox = f"{q_nobbox} Segment each link in JSON format."

        save_json({"point_cloud": point_cloud_all, "question": q_bbox, "answer": {"joints": joints_data_rounded}}, "joint_all_parameters_bbox")
        save_json({"point_cloud": point_cloud_all, "question": q_nobbox, "answer": {"joints": joints_data_rounded}}, "joint_all_parameters_nobbox")
        save_json({"point_cloud": point_cloud_all, "question": q_all_seg_bbox, "answer": {"joints": joints_data_rounded, "links": links_answer_seg}}, "joint_all_parameters_all_seg_bbox")
        save_json({"point_cloud": point_cloud_all, "question": q_all_seg_nobbox, "answer": {"joints": joints_data_rounded, "links": links_answer_seg}}, "joint_all_parameters_all_seg_nobbox")

    def process_object(self, obj_subdirectory):
        current_object_dir = os.path.join(self.main_partnet_dir, obj_subdirectory)
        semantics_path = os.path.join(current_object_dir, "semantics.txt")
        urdf_path = os.path.join(current_object_dir, "new.urdf")
        result_json_path = os.path.join(current_object_dir, "result.json")

        if not all(map(os.path.exists, [semantics_path, urdf_path])):
            return

        sampled_points, sampled_labels, bbox_info = self.sample_points_and_labels(obj_subdirectory)
        
        if sampled_points.size == 0 or sampled_labels.size == 0:
            return

        object_name = self._get_object_name_from_result_json(result_json_path)
        _, link_name_to_semantic = self._get_link_semantic_descriptions(semantics_path)
        joints_data = self._parse_urdf_for_joints(urdf_path)

        if not joints_data or not link_name_to_semantic:
             return
        
        self.generate_and_save_json_variations(obj_subdirectory, object_name, link_name_to_semantic, joints_data, bbox_info)

        colors = self._map_labels_to_colors(sampled_labels)
        if colors.size > 0:
            ply_output_path = os.path.join(self.output_gt_dir, f"{obj_subdirectory}.ply")
            point_cloud_obj = trimesh.PointCloud(vertices=sampled_points, colors=colors)
            point_cloud_obj.export(ply_output_path)

    def run(self, urdf_list_path):
        with open(urdf_list_path, 'r') as f:
            subdirs_to_process = [line.strip() for line in f if line.strip().isdigit()]
        
        for subdir in sorted(subdirs_to_process, key=int):
            self.process_object(subdir)
