import torch
from scene import Scene
import os
from os import makedirs
from gaussian_renderer import render, deform_render, d_render
from gaussian_renderer.mini_renderer import  vanilla_render, new_render
import random
from tqdm import tqdm
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams, SegParams
from gaussian_renderer import GaussianModel
import numpy as np
import open3d as o3d
import open3d.core as o3c

from crop import crop_mesh

# Object id -> category name; the __main__ block uses it to build the data path
# ``load/<category>/<id>``.
id_mapping = {"103706": "blade", "102255": "foldchair", "10905": "fridge", "10211": "laptop", "101917": "oven",
              "11100": "scissor", "103111": "stapler", "45135": "storage", "100109": "USB", "103776": "washer"}

# Set of all category names covered by ``id_mapping`` (derived so the two never drift apart).
scenes = set(id_mapping.values())

# Joint type per object id: "revolute" selects the revolute joint model, any
# other value selects the prismatic one (see the __main__ block).
id_motion = {"103706": "prismatic", "102255": "revolute", "10905": "revolute", "10211": "revolute", "101917": "revolute",
             "11100": "revolute", "103111": "revolute", "45135": "prismatic", "100109": "revolute", "103776": "revolute"}

# Per-object value forwarded to extract_mesh(set_canonical=...), which stores it
# in ``gaussians.use_canonical``.
id_canonical = {"103706": False, "102255": False, "10905": False, "10211": True, "101917": False,
                "11100": True, "103111": False, "45135": False, "100109": True, "103776": False}

# Per-object ``black_threshold`` for color_filter, given in [0, 1] color units
# (color_filter scales it by 255).
id_color_filter = {"103706": 0.1, "102255": 0.06, "10905": 0.1, "10211": 0.06, "101917": 0.1,
                   "11100": 0.2, "103111": 0.1, "45135": 0.2, "100109": 0.1, "103776": 0.2}


def post_process_mesh(mesh, cluster_to_keep=1000):
    """
    Post-process a mesh to filter out floaters and disconnected parts.

    Connected triangle clusters are computed on a deep copy of ``mesh``. Every
    triangle whose cluster has fewer triangles than the ``cluster_to_keep``-th
    largest cluster is removed. Clusters with fewer than 50 triangles are
    always removed. Unreferenced vertices and degenerate triangles are then
    cleaned up.

    Args:
        mesh: Open3D legacy triangle mesh; it is not modified.
        cluster_to_keep: number of largest clusters to retain. Values larger
            than the actual number of clusters keep all clusters (subject to
            the 50-triangle minimum).

    Returns:
        The filtered copy of ``mesh``.
    """
    import copy
    print("post processing the mesh to have {} clusters".format(cluster_to_keep))
    mesh_0 = copy.deepcopy(mesh)
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        triangle_clusters, cluster_n_triangles, _ = mesh_0.cluster_connected_triangles()

    triangle_clusters = np.asarray(triangle_clusters)
    cluster_n_triangles = np.asarray(cluster_n_triangles)

    # An empty mesh has no clusters; indexing the sorted sizes would fail.
    if cluster_n_triangles.size > 0:
        # Clamp k so meshes with fewer clusters than requested do not raise IndexError.
        k = min(cluster_to_keep, cluster_n_triangles.size)
        n_cluster = np.sort(cluster_n_triangles)[-k]
        n_cluster = max(n_cluster, 50)  # filter meshes smaller than 50
        triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
        mesh_0.remove_triangles_by_mask(triangles_to_remove)

    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    print("num vertices raw {}".format(len(mesh.vertices)))
    print("num vertices post {}".format(len(mesh_0.vertices)))
    return mesh_0


def color_filter(mesh, render_path, fid, black_threshold = 0.1):
    """
    Remove faces touching near-black vertices and export the result as PLY.

    A vertex is "black" when all three RGB channels are at or below
    ``black_threshold`` (given in [0, 1] color units). A face is kept only if
    none of its vertices is black.

    Args:
        mesh: Open3D legacy triangle mesh with per-vertex colors.
        render_path: directory the filtered mesh is written into.
        fid: frame id; used in the output file name.
        black_threshold: per-channel darkness threshold in [0, 1].

    Returns:
        Path of the exported ``filtered_mesh_t{fid}.ply`` file.

    Raises:
        ValueError: if no face survives the filter (nothing to export).
    """
    import trimesh
    tri_mesh = trimesh.Trimesh(
        vertices=np.asarray(mesh.vertices),
        faces=np.asarray(mesh.triangles),
        vertex_colors=np.asarray(mesh.vertex_colors)
    )
    vertex_colors = tri_mesh.visual.vertex_colors[:, :3]  # drop a possible alpha channel

    # Convert the [0, 1] threshold to the 0-255 range trimesh vertex colors use.
    threshold_255 = black_threshold * 255
    is_black = np.all(vertex_colors <= threshold_255, axis=1)

    # Keep a face only if all of its vertices are non-black.
    valid_faces_mask = np.all(~is_black[tri_mesh.faces], axis=1)
    if not valid_faces_mask.any():
        # submesh would return no mesh here; fail with a clear message instead.
        raise ValueError(
            f"color_filter: no faces left for fid={fid} with black_threshold={black_threshold}"
        )

    # Build the sub-mesh directly from the boolean face mask.
    filtered_mesh = tri_mesh.submesh(
        [valid_faces_mask],
        append=True,   # merge all kept faces into one mesh
        repair=False,  # skip automatic repair for speed
    )

    # Drop vertices no longer referenced by any face.
    filtered_mesh.remove_unreferenced_vertices()
    filtered_mesh_path = f"{render_path}/filtered_mesh_t{fid}.ply"
    # Export the result (vertex colors are preserved).
    filtered_mesh.export(filtered_mesh_path)
    return filtered_mesh_path


def skew(w: torch.Tensor) -> torch.Tensor:
    """
    Build the skew-symmetric (cross-product) matrix of each rotation axis.

    For an axis (x, y, z) the matrix is
        [[ 0, -z,  y],
         [ z,  0, -x],
         [-y,  x,  0]].

    Args:
        w: [N, 3] unit rotation axes.
    Returns:
        W: [N, 3, 3] skew-symmetric matrices.
    """
    x, y, z = w[:, 0], w[:, 1], w[:, 2]
    zero = torch.zeros(w.shape[0], device=w.device)
    rows = (
        torch.stack((zero, -z, y), dim=-1),
        torch.stack((z, zero, -x), dim=-1),
        torch.stack((-y, x, zero), dim=-1),
    )
    # Stack the three [N, 3] rows into [N, 3, 3].
    return torch.stack(rows, dim=-2)


def compute_4x4_rotation_matrix(axis: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    """
    Compute 4x4 homogeneous rotation matrices with Rodrigues' formula.

    R = I + sin(theta) * K + (1 - cos(theta)) * K @ K, where K = skew(axis).
    The translation part of the result is zero.

    Args:
        axis: [N, 3] unit rotation axes.
        angle: rotation angles in radians, shape [N, 1] or [N]. A single
            angle (shape [], [1] or [1, 1]) is broadcast to all N axes.
    Returns:
        T: [N, 4, 4] homogeneous rotation matrices.
    """
    n = axis.shape[0]

    # Skew-symmetric matrices and their squares, [N, 3, 3].
    K = skew(axis)
    K_sqr = torch.bmm(K, K)

    # Reshape to [N, 1, 1] so each angle scales its own 3x3 matrix. With a flat
    # [N] angle, unsqueeze(-1) alone would broadcast along the wrong axis.
    theta = angle.reshape(-1, 1, 1)
    identity = torch.eye(3, device=axis.device).unsqueeze(0)  # [1, 3, 3], broadcast over N
    R = identity + torch.sin(theta) * K + (1 - torch.cos(theta)) * K_sqr  # [N, 3, 3]

    # Embed the rotation into 4x4 homogeneous matrices.
    T = torch.eye(4, device=axis.device).unsqueeze(0).repeat(n, 1, 1)  # [N, 4, 4]
    T[:, :3, :3] = R
    return T


def tsdf_fusion(model_path, name, iteration, views, test_views, gaussians, pipeline, background, kernel_size,fid, black_threshold=0.1):
    """
    Fuse rendered RGB-D of ``views`` at frame ``fid`` into a TSDF and export meshes.

    Each view is rendered with ``d_render``. Channels 0-2 of the render are used
    as RGB, channel 6 as depth and channel 7 as alpha. Depth is zeroed where the
    view's ``gt_alpha_mask`` (if present) is below 0.5, where the rendered alpha
    is below ``alpha_thres``, and where depth exceeds ``max_depth``. Each frame
    is then integrated into an Open3D ``VoxelBlockGrid`` on the CPU.

    Outputs, written under ``<model_path>/<name>/ours_<iteration>/tsdf``:
      * ``tsdf_t<fid>.ply``: the raw extracted mesh;
      * ``filtered_mesh_t<fid>.ply``: written by ``color_filter`` and then cropped
        in place by ``crop_mesh``.

    Args:
        model_path: experiment directory.
        name: sub-directory name (e.g. "test").
        iteration: iteration number used in the output directory name.
        views: cameras to render and fuse.
        test_views: unused; kept for call-site compatibility.
        gaussians, pipeline, background, kernel_size: forwarded to ``d_render``.
        fid: frame id forwarded to ``d_render`` and used in file names.
        black_threshold: forwarded to ``color_filter``.
    """
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "tsdf")

    makedirs(render_path, exist_ok=True)
    # o3d_device = o3d.core.Device("CUDA:0")
    o3d_device = o3d.core.Device("cpu:0")

    voxel_size = 0.004
    alpha_thres = 0.5
    max_depth = 5.0
    # Depth scale / cutoff passed to Open3D. Depth beyond max_depth has already
    # been zeroed below, so the 6.0 cutoff is only an upper bound.
    depth_scale = 1.0
    o3d_depth_max = 6.0

    vbg = o3d.t.geometry.VoxelBlockGrid(
        attr_names=('tsdf', 'weight', 'color'),
        attr_dtypes=(o3c.float32, o3c.float32, o3c.float32),
        attr_channels=((1), (1), (3)),
        voxel_size=voxel_size,
        block_resolution=16,
        block_count=50000,
        device=o3d_device)

    with torch.no_grad():
        for view in tqdm(views, desc="Rendering progress"):

            rendering = d_render(view, gaussians, pipeline, background, kernel_size=kernel_size, fid=fid)["render"]

            rgb = rendering[:3, :, :]
            depth = rendering[6:7, :, :]
            alpha = rendering[7:8, :, :]

            # Invalidate pixels by zeroing their depth (these are views, so
            # this writes into `rendering` in place).
            if view.gt_alpha_mask is not None:
                depth[(view.gt_alpha_mask < 0.5)] = 0
            depth[(alpha < alpha_thres)] = 0
            depth[depth > max_depth] = 0

            # Recover pinhole intrinsics from the projection matrix by mapping
            # NDC to pixel coordinates.
            W = view.image_width
            H = view.image_height
            ndc2pix = torch.tensor([
                [W / 2, 0, 0, (W - 1) / 2],
                [0, H / 2, 0, (H - 1) / 2],
                [0, 0, 0, 1]]).float().cuda().T
            intrins = (view.projection_matrix @ ndc2pix)[:3, :3].T
            pinhole = o3d.camera.PinholeCameraIntrinsic(
                width=W,
                height=H,
                cx=intrins[0, 2].item(),
                cy=intrins[1, 2].item(),
                fx=intrins[0, 0].item(),
                fy=intrins[1, 1].item()
            )
            intrinsic = o3d.core.Tensor(pinhole.intrinsic_matrix, o3d.core.Dtype.Float64)  # .to(o3d_device)
            extrinsic = o3d.core.Tensor(view.world_view_transform.T.cpu().numpy(), o3d.core.Dtype.Float64)  # .to(o3d_device)

            o3d_color = o3d.t.geometry.Image(np.asarray(rgb.permute(1, 2, 0).cpu().numpy(), order="C"))
            o3d_depth = o3d.t.geometry.Image(np.asarray(depth.permute(1, 2, 0).cpu().numpy(), order="C"))
            o3d_color = o3d_color.to(o3d_device)
            o3d_depth = o3d_depth.to(o3d_device)

            frustum_block_coords = vbg.compute_unique_block_coordinates(
                o3d_depth, intrinsic, extrinsic, depth_scale, o3d_depth_max)

            # The same intrinsics are used for the depth and color images.
            vbg.integrate(frustum_block_coords, o3d_depth, o3d_color, intrinsic,
                          intrinsic, extrinsic, depth_scale, o3d_depth_max)

        mesh = vbg.extract_triangle_mesh().to_legacy()

        # write mesh
        o3d.io.write_triangle_mesh(f"{render_path}/tsdf_t{fid}.ply", mesh)

        filter_mesh_path = color_filter(mesh, render_path, fid, black_threshold=black_threshold)
        # Crop bounds are tuned per object: axis_max=4 for stapler, storage: min_y=-2, AKB_cutter: min_y=-4.
        crop_mesh(filter_mesh_path, filter_mesh_path, axis_min=-2, axis_max=4, min_x=-2, min_y=-4, max_y=2, max_x=2)


def extract_mesh(dataset: ModelParams, opt: OptimizationParams, iteration: int, pipeline: PipelineParams, sp:SegParams, id, set_canonical = False,
                 only_extract_mesh = False, black_threshold= 0.1, revolute = False, prismatic = False, *, param_iteration: int = 40000):
    """
    Load a trained articulated Gaussian model and run TSDF mesh extraction.

    Several ``dataset`` fields are overridden in place (time cameras on, SH
    degree 3, black background, full resolution, CUDA data). The Gaussians are
    loaded from ``point_cloud/iteration_<iteration>/point_cloud.ply``. The joint
    parameters (``quaternions``/``axis_o`` for revolute, ``dir``/``dist``
    otherwise) and ``dynamic_part_mask.npy`` are loaded from
    ``point_cloud/iteration_<param_iteration>``. Training cameras at frame
    ``dataset.extract_fid`` are then fused with ``tsdf_fusion``.

    Args:
        dataset, opt, pipeline: parsed parameter groups.
        iteration: checkpoint iteration of the Gaussian point cloud.
        sp, id, only_extract_mesh: unused; kept for call-site compatibility.
        set_canonical: stored in ``gaussians.use_canonical``.
        black_threshold: forwarded to ``tsdf_fusion`` / ``color_filter``.
        revolute, prismatic: joint type flags passed to ``GaussianModel``.
        param_iteration: checkpoint iteration holding the joint parameters and
            the dynamic-part mask (defaults to the previously hard-coded 40000).
    """
    with torch.no_grad():

        dataset.load_time_camera = True
        dataset.sh_degree = 3
        dataset.white_background = False
        dataset.eval = False
        dataset.resolution = -1
        dataset.data_device = "cuda"

        gaussians = GaussianModel(dataset.sh_degree, revolute= revolute, prismatic= prismatic)
        scene = Scene(dataset, gaussians, load_iteration=iteration, load_root = True, shuffle=False)
        gaussians.old_deform_training_setup(opt)

        gaussians.use_canonical = set_canonical
        train_cameras = scene.getTrainCameras()

        # Joint parameters and the dynamic mask live in a fixed checkpoint,
        # independent of the point-cloud `iteration`.
        param_dir = os.path.join(dataset.model_path, "point_cloud", f"iteration_{param_iteration}")
        if gaussians.revolute:
            gaussians.quaternions = torch.load(os.path.join(param_dir, "quaternions.pt")).cuda()
            gaussians.axis_o = torch.load(os.path.join(param_dir, "axis_o.pt")).cuda()
        else:
            gaussians.dir = torch.load(os.path.join(param_dir, "dir.pt")).cuda()
            gaussians.dist = torch.load(os.path.join(param_dir, "dist.pt")).cuda()

        gaussians.load_ply(os.path.join(dataset.model_path, "point_cloud", f"iteration_{iteration}", "point_cloud.ply"))  # TODO

        fid = dataset.extract_fid

        gaussians.dynamic_part_mask = np.load(os.path.join(param_dir, "dynamic_part_mask.npy"))

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        # kernel_size = dataset.kernel_size
        kernel_size = 0.0
        test_views = scene.getTestCameras()

        tsdf_fusion(dataset.model_path, "test", iteration, train_cameras, test_views, gaussians, pipeline, background, kernel_size, fid, black_threshold=black_threshold)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    opt = OptimizationParams(parser)
    sp = SegParams(parser)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=40000, type=int)
    parser.add_argument("--quiet", action="store_true")

    parser.add_argument("--id", required= True, type=str, help='test category ID')
    args = get_combined_args(parser)

    # Every lookup table below is keyed by id; reject unknown ids with a usage
    # message rather than a bare KeyError.
    if args.id not in id_mapping:
        parser.error(f"unknown --id {args.id!r}; expected one of: {', '.join(sorted(id_mapping))}")

    args.train_ours = True
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))

    canonical = id_canonical[args.id]
    black_threshold = id_color_filter[args.id]
    # Any non-revolute joint is treated as prismatic.
    revolute = id_motion[args.id] == 'revolute'
    prismatic = not revolute

    cat = id_mapping[args.id]
    args.source_path = os.path.abspath(f"load/{cat}/{args.id}")
    args.model_path = f"exp/{args.id}"
    print("Rendering " + args.model_path)

    extract_mesh(model.extract(args), opt.extract(args), args.iteration, pipeline.extract(args), sp.extract(args), args.id, set_canonical = canonical,
                 black_threshold = black_threshold, revolute = revolute, prismatic= prismatic)
