import open3d as o3d
import sys
import os
import numpy as np
import json
import math
import torch
import shutil
import cv2

def get_mask_rendering(rendering_path, img_path, mask_path):
    """Write a masked copy of a rendered image into ``rendering_path``.

    The mask is read as grayscale and resized to the image size. Every pixel
    whose mask value is exactly 0 is painted white (255 in all channels). The
    result is saved as ``masked_<image file name>`` inside ``rendering_path``.

    Args:
        rendering_path: output directory for the masked image.
        img_path: path of the rendered image to mask.
        mask_path: path of the mask image (read as single-channel grayscale).

    Raises:
        FileNotFoundError: if OpenCV cannot read the image or the mask
            (``cv2.imread`` signals failure by returning ``None``).
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError(f"could not read mask: {mask_path}")

    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
    background = mask == 0
    img[background] = 255

    # os.path.basename honours the platform separator (os.path.join may have
    # produced '\\' on Windows, which a split on '/' would not handle).
    img_name = "masked_" + os.path.basename(img_path)
    cv2.imwrite(os.path.join(rendering_path, img_name), img)


def getWorld2View2(R, t, translate=None, scale=1.0):
    """Build a 4x4 world-to-view matrix, optionally shifting and scaling the camera centre.

    Args:
        R: 3x3 rotation. Its transpose is placed in the upper-left block of the
            world-to-view matrix.
        t: length-3 translation of the world-to-view matrix.
        translate: length-3 offset added to the camera centre before scaling.
            ``None`` (the default) means no offset.
        scale: factor applied to the (offset) camera centre.

    Returns:
        np.float32 array of shape (4, 4).
    """
    # Avoid a shared ndarray as a default argument; None means "no offset".
    if translate is None:
        translate = np.zeros(3)

    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Move to camera-to-world to adjust the camera centre, then invert back.
    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)



def getProjectionMatrix(fovX, fovY, znear=0.01, zfar=100.0):
    """Return a 4x4 perspective projection matrix as a float32 torch tensor.

    The view frustum is symmetric: its half-extents at the near plane are
    ``tan(fov / 2) * znear`` horizontally and vertically. Depth is mapped with
    ``P[2, 2] = zfar / (zfar - znear)`` and ``P[2, 3] = -zfar * znear / (zfar - znear)``.
    ``P[3, 2] = 1`` copies view-space z into the homogeneous w.

    Args:
        fovX: horizontal field of view in radians.
        fovY: vertical field of view in radians.
        znear: near clipping distance.
        zfar: far clipping distance.
    """
    half_height = math.tan(fovY / 2) * znear
    half_width = math.tan(fovX / 2) * znear
    top, bottom = half_height, -half_height
    right, left = half_width, -half_width

    width_span = right - left
    height_span = top - bottom
    depth_span = zfar - znear
    z_sign = 1.0

    proj = torch.zeros(4, 4)
    proj[0, 0] = 2.0 * znear / width_span
    proj[1, 1] = 2.0 * znear / height_span
    proj[0, 2] = (right + left) / width_span
    proj[1, 2] = (top + bottom) / height_span
    proj[2, 2] = z_sign * zfar / depth_span
    proj[2, 3] = -(zfar * znear) / depth_span
    proj[3, 2] = z_sign
    return proj






def fov2focal(fov, pixels):
    """Convert a field of view (radians) spanning ``pixels`` into a focal length in pixels."""
    half_angle = fov / 2
    return pixels / (2 * math.tan(half_angle))

def focal2fov(focal, pixels):
    """Convert a focal length in pixels into the field of view (radians) covering ``pixels``."""
    half_tangent = pixels / (2 * focal)
    return 2 * math.atan(half_tangent)


def read_json(json_file):
    """Load per-image cameras from a JSON camera file.

    The file must contain a list of records. Each record provides the keys read
    below: ``img_name``, ``rotation`` (3x3), ``position`` (3), ``width``,
    ``height``, ``fx`` and ``fy``. Records are processed in ``img_name`` order.

    Returns:
        list[dict]: one dict per record with keys
            ``extrinsic`` (4x4 ndarray, see the note in the loop),
            ``fy``, ``fx``,
            ``cx`` / ``cy`` (``width/2 - 0.5`` / ``height/2 - 0.5``),
            ``id`` (``int(img_name)``), ``width`` and ``height``.
    """
    with open(json_file) as f:
        records = json.load(f)

    cams = []
    for record in sorted(records, key=lambda x: x['img_name']):
        # Extrinsics: assemble a 4x4 pose from rotation/position and invert it.
        # NOTE(review): presumably rotation/position describe the camera-to-world
        # pose (3DGS cameras.json convention), making Rt world-to-camera - confirm.
        pose = np.zeros((4, 4))
        pose[:3, :3] = np.array(record['rotation'])
        pose[:3, 3] = np.array(record['position'])
        pose[3, 3] = 1
        Rt = np.linalg.inv(pose)

        # Intrinsics.
        width = record['width']
        height = record['height']
        fy = record['fy']
        fx = record['fx']
        fov_y = focal2fov(fy, height)
        fov_x = focal2fov(fx, width)
        P = np.array(getProjectionMatrix(fov_x, fov_y))

        # NOTE(review): the projection matrix is folded into the "extrinsic", so
        # this is not a rigid transform. Callers hand it to Open3D's
        # PinholeCameraParameters.extrinsic - confirm this is intended rather
        # than using Rt directly.
        extrinsic = P @ Rt

        cams.append({"extrinsic": extrinsic,
                     "fy": fy,
                     "fx": fx,
                     "cx": width/2-0.5,
                     "cy": height/2-0.5,
                     "id": int(record['img_name']),
                     "width": width,
                     "height": height})

    return cams

def load_and_show_ply2(filepath):
    """Load ``filepath`` as a triangle mesh and display it with Open3D.

    Vertex normals are computed before drawing so the mesh is shaded.
    """
    tri_mesh = o3d.io.read_triangle_mesh(filepath)
    tri_mesh.compute_vertex_normals()
    o3d.visualization.draw_geometries([tri_mesh])

def is_triangle_mesh(filepath):
    """Return True if ``filepath`` looks like a triangle mesh rather than a point cloud.

    ``.obj`` files (any letter case) are always treated as meshes. Any other file
    is inspected as a PLY: it counts as a mesh when a header line declares an
    ``element face``.

    Scanning stops at ``end_header``. This means a large binary point cloud is
    not read in full, and bytes in the binary payload cannot produce a false
    match.

    Read errors are printed and treated as "not a mesh".
    """
    if filepath.lower().endswith(".obj"):
        return True

    try:
        # ISO-8859-1 maps every byte to a character, so binary payloads never
        # raise UnicodeDecodeError while the ASCII header is being read.
        with open(filepath, 'r', encoding='ISO-8859-1') as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith('element face'):
                    return True
                if stripped == 'end_header':
                    break
    except OSError as e:
        print(f"Error reading file: {e}")
    return False


def load_and_show_ply(filepath, json_path, rendering_path, render_mesh = False, dtu_mask = None):
    """Open a mesh or point cloud in an Open3D window, then show or render it.

    Args:
        filepath: geometry file. It is loaded as a triangle mesh (with vertex
            normals computed) when ``is_triangle_mesh`` returns True, otherwise
            as a point cloud.
        json_path: camera file read with ``read_json``. Only used when
            ``render_mesh`` is True.
        rendering_path: output directory for ``ours_mesh_rendering_{id}.png``
            images (created if missing).
        render_mesh: if True, render from the selected cameras and save
            screenshots. If False, open an interactive window with a coordinate
            frame.
        dtu_mask: optional directory of ``{id:03d}.png`` masks. When given, a
            masked copy of each rendering is written via ``get_mask_rendering``.

    Raises:
        ValueError: if ``render_mesh`` is True and the camera file is empty.
    """
    if is_triangle_mesh(filepath):
        mesh = o3d.io.read_triangle_mesh(filepath)
        # Normals are needed for shaded mesh rendering.
        mesh.compute_vertex_normals()
        print("read triangle")
    else:
        mesh = o3d.io.read_point_cloud(filepath)
        print("read point cloud")
    # Prints Open3D's textual summary of the loaded geometry.
    print(f"vertices number is {mesh}")

    vis = o3d.visualization.Visualizer()
    vis.create_window(window_name=filepath, width=1920, height=1080)
    # try/finally so the window is destroyed even if loading cameras or rendering fails.
    try:
        vis.add_geometry(mesh)
        ctr = vis.get_view_control()
        ctr.set_constant_z_near(0.001)
        ctr.set_constant_z_far(1000)

        if not render_mesh:
            # Interactive mode: add a unit coordinate frame (red x, green y, blue z).
            axis_pcd = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0])
            vis.add_geometry(axis_pcd)
            vis.run()
            return

        eval_split_interval = 4
        os.makedirs(rendering_path, exist_ok=True)
        opt = vis.get_render_option()
        opt.background_color = np.asarray([1, 1, 1])  # white background

        cams = sorted(read_json(json_path), key=lambda x: x['id'])
        if not cams:
            raise ValueError(f"no cameras found in {json_path}")

        # Every camera is rendered with the intrinsics of the first one.
        first = cams[0]
        intrinsic = o3d.camera.PinholeCameraIntrinsic()
        intrinsic.set_intrinsics(width=first["width"], height=first["height"], fx=first["fx"],
                                 fy=first["fy"], cx=first["cx"], cy=first["cy"])
        camera_params = o3d.camera.PinholeCameraParameters()
        camera_params.intrinsic = intrinsic

        for cam in cams:
            # Keep ids 1, 1 + k, 1 + 2k, ... where k = eval_split_interval.
            if (cam["id"] - 1) % eval_split_interval != 0:
                continue
            camera_params.extrinsic = cam["extrinsic"]
            # Second argument is Open3D's allow_arbitrary flag.
            ctr.convert_from_pinhole_camera_parameters(camera_params, True)

            vis.poll_events()
            vis.update_renderer()

            out_path = os.path.join(rendering_path, "ours_mesh_rendering_{}.png".format(cam["id"]))
            vis.capture_screen_image(out_path)
            if dtu_mask:
                get_mask_rendering(rendering_path, out_path,
                                   os.path.join(dtu_mask, f"{cam['id']:03d}.png"))
    finally:
        vis.destroy_window()

def show_gt_mesh(gt_mesh_path, json_path, rendering_path, tran_file, custom = False):
    """Load a ground-truth mesh or point cloud, align it, then show or render it.

    Args:
        gt_mesh_path: geometry file. It is loaded as a triangle mesh (with vertex
            normals) when ``is_triangle_mesh`` returns True, otherwise as a point
            cloud.
        json_path: camera file read with ``read_json``. Only used when ``custom``
            is False.
        rendering_path: output directory for ``gt_rendering_{id}.png`` images
            (created if missing).
        tran_file: text file loaded with ``np.loadtxt``. Presumably a 4x4
            alignment matrix; its inverse is applied to the geometry.
        custom: if True, open an interactive window. Otherwise render from every
            8th camera and save screenshots.

    Raises:
        ValueError: if ``custom`` is False and the camera file is empty.
    """
    if is_triangle_mesh(gt_mesh_path):
        mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
        mesh.compute_vertex_normals()
        print("read triangle")
    else:
        mesh = o3d.io.read_point_cloud(gt_mesh_path)
        print("read point cloud")

    # The stored transform is inverted before being applied to the geometry.
    gt_trans = np.linalg.inv(np.loadtxt(tran_file))
    mesh = mesh.transform(gt_trans)

    vis = o3d.visualization.Visualizer()
    # Title the window with the file being shown (this function has no `filepath`).
    vis.create_window(window_name=gt_mesh_path, width=1920, height=1080)
    # try/finally so the window is destroyed even if loading cameras or rendering fails.
    try:
        vis.add_geometry(mesh)
        ctr = vis.get_view_control()
        ctr.set_constant_z_near(0.001)
        ctr.set_constant_z_far(1000)

        if custom:
            vis.run()
            return

        eval_split_interval = 8
        os.makedirs(rendering_path, exist_ok=True)
        opt = vis.get_render_option()
        opt.background_color = np.asarray([1, 1, 1])  # white background

        cams = sorted(read_json(json_path), key=lambda x: x['id'])
        if not cams:
            raise ValueError(f"no cameras found in {json_path}")

        # Every camera is rendered with the intrinsics of the first one.
        first = cams[0]
        intrinsic = o3d.camera.PinholeCameraIntrinsic()
        intrinsic.set_intrinsics(width=first["width"], height=first["height"], fx=first["fx"],
                                 fy=first["fy"], cx=first["cx"], cy=first["cy"])
        camera_params = o3d.camera.PinholeCameraParameters()
        camera_params.intrinsic = intrinsic

        for cam in cams:
            # Keep ids 1, 1 + k, 1 + 2k, ... where k = eval_split_interval.
            if (cam["id"] - 1) % eval_split_interval != 0:
                continue
            camera_params.extrinsic = cam["extrinsic"]
            # Second argument is Open3D's allow_arbitrary flag.
            ctr.convert_from_pinhole_camera_parameters(camera_params, True)

            vis.poll_events()
            vis.update_renderer()

            vis.capture_screen_image(os.path.join(rendering_path, "gt_rendering_{}.png".format(cam["id"])))
    finally:
        vis.destroy_window()

def get_gt_img(img_path, json_path, save_path, eval_split_interval = 8, filename_fmt = "{:06d}.jpg"):
    """Copy the ground-truth images of the selected evaluation cameras into ``save_path``.

    A camera is selected when ``(id - 1) % eval_split_interval == 0``. This is
    the same selection rule the rendering functions use.

    Args:
        img_path: directory containing the source images.
        json_path: camera file read with ``read_json``.
        save_path: copy destination, normally an existing directory.
        eval_split_interval: keep camera ids 1, 1 + k, 1 + 2k, ... (k = this value).
        filename_fmt: ``str.format`` pattern that turns a camera id into the
            source file name. The default gives ``000001.jpg``-style names; use
            e.g. ``"{:04d}.png"`` for four-digit PNG names.
    """
    cams = sorted(read_json(json_path), key=lambda x: x['id'])
    for cam in cams:
        if (cam["id"] - 1) % eval_split_interval != 0:
            continue

        source_file_path = os.path.join(img_path, filename_fmt.format(cam['id']))

        # shutil.copy returns the real destination (save_path/<name> when
        # save_path is a directory, save_path itself otherwise).
        copied_file_path = shutil.copy(source_file_path, save_path)

        # Messages: "file copied successfully to ..." / "file copy failed".
        if os.path.exists(copied_file_path):
            print(f"文件已成功复制到 {copied_file_path}")
        else:
            print("文件复制失败")



if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python script_name.py path_to_ply_file", file=sys.stderr)
        # Use sys.exit rather than the interactive-shell helper `exit` (added by
        # the `site` module), and a non-zero status so callers see the error.
        sys.exit(1)

    # Stage switches.
    show_ply = True          # show / render the geometry given on the command line
    render_mesh = False      # True: render it from the JSON cameras instead of an interactive window
    render_gt_mesh = False   # also show / render the ground-truth mesh
    need_gt_img = False      # copy the ground-truth images of the evaluation cameras
    gt_custom = True         # True: interactive GT window; False: render GT from the cameras

    json_path = "../paris/load/sapien/fridge/10905/start/camera_test.json"
    #json_path = "../PGSR/exp/45135_end/cameras.json"
    rendering_path = "exp/10905/ours_s2_mesh_rendering"


    gt_mesh_path = "../TT/trainingdata/Ignatius/Ignatius.ply"
    #gt_mesh_tran_path = "../TT/trainingdata/Courthouse/Courthouse_trans.txt"
    gt_mesh_tran_path = "../TT/trainingdata/Ignatius/Ignatius_trans.txt"
    #gt_mesh_tran_path = "trans.txt"
    #img_path = "../TT/Caterpillar/images"
    img_path = "../TT/Ignatius/images"
    #json_path = "metrics/Barn/Barn/cameras.json"

    gt_img_save_path = "exp/Ignatius/gt_img"
    #dtu_mask = "../dtu/DTU/scan24/mask"
    dtu_mask = None

    # Kept at module level: other code in this file may read `filepath` as a global.
    filepath = sys.argv[1]
    if show_ply:
        load_and_show_ply(filepath, json_path, rendering_path, render_mesh, dtu_mask)
    if render_gt_mesh:
        show_gt_mesh(gt_mesh_path, json_path, rendering_path, gt_mesh_tran_path, gt_custom)
    if need_gt_img:
        os.makedirs(gt_img_save_path, exist_ok=True)
        get_gt_img(img_path, json_path, gt_img_save_path, eval_split_interval=4)

