import os
import re
import cv2
import argparse
import numpy as np
import mitsuba as mi

# mi.set_variant("scalar_rgb")
mi.set_variant("cuda_ad_rgb")  # 'scalar_rgb', 'scalar_spectral', 'cuda_ad_rgb', 'llvm_ad_rgb'
from mitsuba import ScalarTransform4f as T
from mitsuba import Transform4f as DT
from mitsuba import Point3f, ScalarPoint3f
from mitsuba.util import write_bitmap, convert_to_bitmap
from glob import glob
import torch

def mitsuba_camera_position(elevations, azimuths, distances):
    """Build look-at camera parameters for every elevation/azimuth/distance combination.

    Angle conventions used by this function:
    - elevation is given in degrees relative to the XY plane
      (0 lies in the XY plane, 90 points along +Z);
    - azimuth is given in degrees inside the XY plane, measured from the +X axis.

    Args:
        elevations: Sequence/array of elevation angles in degrees.
        azimuths: Sequence/array of azimuth angles in degrees.
        distances: Sequence/array of camera distances from the origin.

    Returns:
        Tuple ``(camera_positions, camera_targets, camera_ups)`` of arrays.
        ``camera_positions`` has one ``(x, y, z)`` row per combination, ordered
        with elevation varying slowest and distance varying fastest.
        ``camera_targets`` is all zeros (every camera looks at the origin) and
        every row of ``camera_ups`` is ``(0, 0, 1)``.
    """
    # Cartesian product of the three parameter lists ('ij' keeps the input order).
    elev_grid, azim_grid, dist_grid = np.meshgrid(
        elevations, azimuths, distances, indexing='ij'
    )

    radius = dist_grid.flatten()
    # Convert elevation above the XY plane into the polar angle measured from +Z.
    polar = np.radians(90 - elev_grid.flatten())
    phi = np.radians(azim_grid.flatten())

    # Spherical -> Cartesian conversion, evaluated for all cameras at once.
    camera_positions = np.stack(
        [
            radius * np.sin(polar) * np.cos(phi),
            radius * np.sin(polar) * np.sin(phi),
            radius * np.cos(polar),
        ],
        axis=-1,
    )

    camera_targets = np.zeros_like(camera_positions)
    up_rows = [[0, 0, 1] for _ in range(len(camera_positions))]
    camera_ups = np.array(up_rows, dtype=camera_positions.dtype)

    return camera_positions, camera_targets, camera_ups


def create_object_transform(translation, scale, rotation):
    """Compose an object transform from a uniform scale, per-axis rotations and a translation.

    The transform is built by chaining, in this call order, ``scale`` ->
    ``rotate`` about X -> ``rotate`` about Y -> ``rotate`` about Z ->
    ``translate`` on a fresh ``ScalarTransform4f``.

    NOTE(review): which of these operations acts on a point first depends on
    how the installed Mitsuba version composes chained transform methods --
    confirm against the Mitsuba transform documentation.

    Args:
        translation: Translation vector forwarded to ``translate``.
        scale: Scalar factor applied equally to all three axes.
        rotation: Indexable holding three angles, used for the X, Y and Z
            axis rotations respectively (presumably degrees -- verify).

    Returns:
        The composed transform.
    """
    unit_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])

    xform = T().scale([scale, scale, scale])
    for axis_id, axis in enumerate(unit_axes):
        xform = xform.rotate(axis, rotation[axis_id])

    return xform.translate(translation)


def create_scene_dict_all_meshes(obj_files, obj_idx_list, cameras, resolution, translation, scale, rotation):
    """Assemble a Mitsuba scene dictionary holding every mesh and every camera.

    The scene contains a constant emitter (hidden from the camera by the
    integrator setting), an ``aov`` integrator outputting the shading normal
    and the shape index, one orthographic sensor per camera and one diffuse
    grey ``obj`` shape per mesh file.

    Args:
        obj_files: Paths of the OBJ files to load.
        obj_idx_list: Identifier for each file, used to name the shape entry
            ``obj_<idx>``; paired with ``obj_files`` element by element.
        cameras: Tuple ``(positions, targets, ups)`` of equally long sequences.
        resolution: ``[width, height]`` of the film in pixels.
        translation: Translation handed to ``create_object_transform``.
        scale: Uniform scale handed to ``create_object_transform``.
        rotation: Rotation angles handed to ``create_object_transform``.

    Returns:
        The scene description as a nested ``dict``.
    """
    scene_dict = {
        'type': 'scene',
        'emitter': {
            'type': 'constant',
            'radiance': {'type': 'rgb', 'value': 1.0},
        },
        'integrator': {
            'type': 'aov',
            'aovs': 'nn:sh_normal,si:shape_index',
            'hide_emitters': True,
        },
    }

    # One orthographic sensor per camera, registered as sensor_0, sensor_1, ...
    sensor_count = len(cameras[0])
    for sensor_id, origin, target, up in zip(range(sensor_count), *cameras):
        camera_pose = T().look_at(
            origin=ScalarPoint3f(origin),
            target=ScalarPoint3f(target),
            up=ScalarPoint3f(up)
        ) @ T().scale(ScalarPoint3f([1, 1, 1]))
        film = {
            'type': 'hdrfilm',
            'pixel_format': 'rgb',
            'width': resolution[0],
            'height': resolution[1],
            'filter': {'type': 'box'},
        }
        scene_dict[f'sensor_{sensor_id}'] = {
            'type': 'orthographic',
            'near_clip': 0.01,
            'far_clip': 100,
            'to_world': camera_pose,
            'film': film,
        }

    # Every mesh receives the same placement and a plain diffuse material.
    for shape_id, mesh_file in zip(obj_idx_list, obj_files):
        scene_dict[f'obj_{shape_id}'] = {
            'type': 'obj',
            'filename': mesh_file,
            'to_world': create_object_transform(translation, scale, rotation),
            'bsdf': {
                'type': 'diffuse',
                'reflectance': {'type': 'rgb', 'value': [0.5, 0.5, 0.5]},
            },
        }

    return scene_dict


def _normal_diversity_score(normals_by_object):
    """Return a 0-1 score for how spread out the visible normals of one view are."""
    normal_groups = [normals for normals in normals_by_object.values() if len(normals) > 0]
    if not normal_groups:
        return 0

    all_normals = np.vstack(normal_groups)

    try:
        from sklearn.decomposition import PCA

        # Subsample to bound the cost of the PCA fit on large views.
        max_samples = 10000
        if len(all_normals) > max_samples:
            indices = np.random.choice(len(all_normals), max_samples, replace=False)
            sampled_normals = all_normals[indices]
        else:
            sampled_normals = all_normals

        pca = PCA(n_components=3)
        pca.fit(sampled_normals)
        # Sum of the explained variances: larger means more scattered normals.
        return min(1.0, sum(pca.explained_variance_) / 3.0)
    except Exception as exc:
        # Deliberate best effort: if scikit-learn is unavailable or the fit
        # fails, fall back to the variance of the angle to the mean normal.
        print(f"PCA failed: {exc}, using fallback method")
        mean_normal = np.mean(all_normals, axis=0)
        mean_length = np.linalg.norm(mean_normal)
        if mean_length == 0:
            # The normals cancel out exactly, so no mean direction exists.
            # Score it as maximally diverse instead of dividing by zero.
            return 1.0
        mean_normal = mean_normal / mean_length
        dots = np.sum(all_normals * mean_normal, axis=1)
        dots = np.clip(dots, -1.0, 1.0)  # guard arccos against rounding error
        angles = np.arccos(dots)
        return min(1.0, np.var(angles) / (np.pi / 2))


def _area_distribution_score(area_by_object):
    """Return a 0-1 score that is highest when all visible objects cover equal area."""
    if not area_by_object:
        return 0

    areas = np.array(list(area_by_object.values()))
    total = np.sum(areas)
    if total <= 0:
        return 0

    shares = areas / total
    ideal_share = 1.0 / len(areas)
    # Mean absolute deviation from the perfectly even split.
    deviation = np.mean(np.abs(shares - ideal_share))
    return 1.0 - min(1.0, deviation * len(areas))


def eval_best_viewpoint(camera_render_list, sensor_num, min_area_threshold=10):
    """Pick the best viewpoint using a weighted combination of visibility scores.

    Each sensor is scored as
    ``0.4 * object coverage + 0.3 * normal diversity +
    0.2 * area-distribution evenness + 0.1 * total visible area``.
    The visible-area term is normalised by the pixel count of the largest
    mask in ``camera_render_list`` (``1024 * 1024`` is used only when the
    list is empty), so the score is not tied to one render resolution.

    Args:
        camera_render_list: List of ``(sensor_idx, obj_idx, mesh_mask, normal)``
            tuples. ``mesh_mask`` must squeeze to a 2-D array whose positive
            entries mark the object's pixels; ``normal`` is indexed with the
            row/column coordinates of those pixels.
        sensor_num: Number of sensors (viewpoints); sensor indices are
            ``0 .. sensor_num - 1``.
        min_area_threshold: An object counts as visible in a view only when
            it covers strictly more pixels than this value.

    Returns:
        Index of the sensor with the highest score (the lowest index wins ties).

    Raises:
        ValueError: If there is no sensor to choose from.
    """
    viewpoint_info = {
        sensor_idx: {
            'visible_objects': 0,
            'area_by_object': {},
            'total_visible_area': 0,
            'visibility_score': 0,
            'normal_by_object': {},  # visible normals of each object
        }
        for sensor_idx in range(sensor_num)
    }

    # Number of distinct objects that appear anywhere in the render list.
    total_objects = len({obj_idx for _, obj_idx, _, _ in camera_render_list})

    # Gather per-view statistics from every (sensor, object) render.
    max_possible_area = 0
    for sensor_idx, obj_idx, mesh_mask, normal in camera_render_list:
        mask_2d = mesh_mask.squeeze()  # e.g. (H, W, 1) -> (H, W)
        max_possible_area = max(max_possible_area, int(np.prod(mask_2d.shape)))

        visible = mask_2d > 0
        visible_area = np.sum(visible)
        if visible_area <= min_area_threshold:
            continue

        info = viewpoint_info[sensor_idx]
        info['area_by_object'][obj_idx] = visible_area
        info['total_visible_area'] += visible_area

        y_coords, x_coords = np.where(visible)
        if len(y_coords) > 0:
            info['normal_by_object'][obj_idx] = normal[y_coords, x_coords]

        info['visible_objects'] = len(info['area_by_object'])

    if max_possible_area == 0:
        # No renders were supplied; keep the historical default resolution.
        max_possible_area = 1024 * 1024

    w1, w2, w3, w4 = 0.4, 0.3, 0.2, 0.1  # score weights
    for info in viewpoint_info.values():
        object_coverage_score = info['visible_objects'] / max(1, total_objects)
        normal_diversity_score = _normal_diversity_score(info['normal_by_object'])
        area_distribution_score = _area_distribution_score(info['area_by_object'])
        area_score = min(1.0, info['total_visible_area'] / max_possible_area)

        info['visibility_score'] = (
                w1 * object_coverage_score +
                w2 * normal_diversity_score +
                w3 * area_distribution_score +
                w4 * area_score
        )

    if not viewpoint_info:
        raise ValueError("sensor_num must be positive to select a best viewpoint")

    best_sensor_idx = max(viewpoint_info,
                          key=lambda idx: viewpoint_info[idx]['visibility_score'])

    return best_sensor_idx

def visualize_normal(normal):
    """Remap normal components for display: -1 maps to 0 and +1 maps to 1."""
    shifted = normal + 1.0
    return shifted / 2.0


def set_seed(seed):
    """Seed NumPy's global random state and PyTorch's RNG with ``seed``."""
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)


def find_best_view(mesh_path, mesh_translation, mesh_scale, mesh_rotation, save_images=False, output_dir=None):
    """Render the meshes from a fixed grid of cameras and return the best-scoring view.

    All ``obj_*.obj`` files found in ``mesh_path`` are loaded into one scene.
    The scene is rendered once per camera at 1024x1024 with 1 sample per
    pixel; each render is split into per-object masks using the shape-index
    channel, and ``eval_best_viewpoint`` picks the winning camera.

    Args:
        mesh_path: Directory containing the ``obj_*.obj`` files.
        mesh_translation: Translation applied to every mesh.
        mesh_scale: Uniform scale applied to every mesh.
        mesh_rotation: Rotation angles applied to every mesh.
        save_images: When True, write a masked normal image and a mask image
            per (sensor, object) pair for debugging.
        output_dir: Directory for the debug images. Only used when
            ``save_images`` is True; ``None`` selects the legacy default path.

    Returns:
        Dict with keys ``camera_idx``, ``camera_position``, ``camera_target``
        and ``camera_up`` describing the selected camera.
    """
    resolution = [1024, 1024]
    set_seed(0)
    spp = 1

    elevations = np.array([15, 30, 45, 60], dtype=np.float32)
    azimuths = np.array([0, 15, 30, 45, 315, 330, 345], dtype=np.float32)
    distances = np.array([2], dtype=np.float32)

    cameras = mitsuba_camera_position(elevations, azimuths, distances)
    objs = sorted(glob(f"{mesh_path}/obj_*.obj"))
    # The object id is the first run of digits in the file name (obj_<id>.obj).
    obj_idx_list = [int(re.search(r'\d+', os.path.basename(obj_file)).group(0)) for obj_file in objs]

    # Build and load a single scene that contains every mesh and every sensor.
    scene_dict = create_scene_dict_all_meshes(objs, obj_idx_list, cameras, resolution, mesh_translation, mesh_scale,
                                              mesh_rotation)
    scene = mi.load_dict(scene_dict)

    if save_images:
        if output_dir is None:
            # Legacy default, kept so existing callers behave as before.
            output_dir = "/home/swu/cyh/MasterGraduationProject/application/work2/output/test_sensor"
        os.makedirs(output_dir, exist_ok=True)

    # Render once per camera, then separate the objects via the shape index.
    camera_render_list = []
    sensor_num = len(cameras[0])
    for sensor_idx in range(sensor_num):
        image = mi.render(scene, sensor=sensor_idx, spp=spp, seed=0)

        # Channels 0-2 are read as the shading normal, channel 3 as the shape index.
        shading_normal = np.nan_to_num(np.array(image[..., 0:3], dtype=np.float32), nan=0.0)
        shading_normal = visualize_normal(shading_normal)
        shape_indices = (np.array(image[..., 3:4], dtype=np.float32)) - 1

        # NOTE(review): this matches (shape index - 1) against the number parsed
        # from each file name. That presumably requires the files to be numbered
        # 0..N-1 in the same order Mitsuba enumerates the shapes -- confirm,
        # especially for ten or more objects where sorted() is lexicographic.
        for obj_idx in obj_idx_list:
            mesh_mask = np.isclose(shape_indices, obj_idx).astype(np.float32)
            camera_render_list.append((sensor_idx, obj_idx, mesh_mask, shading_normal))

            if save_images:
                # The masked normal image is only needed for the debug dump.
                obj_shading_normal = shading_normal * mesh_mask
                save_path = os.path.join(output_dir, f"sensor_{sensor_idx}_obj_{obj_idx}_normal.png")
                cv2.imwrite(save_path, (obj_shading_normal * 255).astype(np.uint8))

                save_path = os.path.join(output_dir, f"sensor_{sensor_idx}_obj_{obj_idx}_mask.png")
                cv2.imwrite(save_path, mesh_mask.astype(np.uint8) * 255)

    best_sensor_idx = eval_best_viewpoint(camera_render_list, sensor_num)
    best_sensor_dict = {
        'camera_idx': int(best_sensor_idx) if hasattr(best_sensor_idx, 'dtype') else best_sensor_idx,
        'camera_position': cameras[0][best_sensor_idx].tolist(),
        'camera_target': cameras[1][best_sensor_idx].tolist(),
        'camera_up': cameras[2][best_sensor_idx].tolist()
    }

    return best_sensor_dict


def get_random_view(mesh_path, mesh_translation, mesh_scale, mesh_rotation, save_images=False):
    """Return a uniformly random camera from a fixed elevation/azimuth grid.

    The mesh arguments and ``save_images`` are accepted so the signature
    parallels ``find_best_view``; this function does not read them.

    Returns:
        Dict with keys ``camera_idx``, ``camera_position``, ``camera_target``
        and ``camera_up`` describing the drawn camera.
    """
    elevations = np.array([-75, -60, -45, -30, 0, 30, 45, 60, 75], dtype=np.float32)
    azimuths = np.array([0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330], dtype=np.float32)
    distances = np.array([2], dtype=np.float32)

    positions, targets, ups = mitsuba_camera_position(elevations, azimuths, distances)

    # Draw one camera index from NumPy's global random state.
    picked = np.random.randint(0, len(positions))
    camera_idx = int(picked) if hasattr(picked, 'dtype') else picked

    return {
        'camera_idx': camera_idx,
        'camera_position': positions[picked].tolist(),
        'camera_target': targets[picked].tolist(),
        'camera_up': ups[picked].tolist(),
    }


