import sys, os
sys.path.append('/root/autodl-tmp/PathTracing_LucidDreamer')
import torch
from torchvision.utils import save_image
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, GenerateCamParams, GuidanceParams

import numpy as np
import trimesh
from plyfile import PlyData, PlyElement
from scene import GaussianModel
from scene.gaussian_model import BasicPointCloud
from scene.dataset_readers import GenerateCircleCameras
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, cameraList_from_RcamInfos
from gaussian_renderer import render
import yaml



def build_ply_vertices(xyz, rgb):
    """Build the structured numpy array used as the PLY ``vertex`` element.

    Args:
        xyz: Array-like of shape (N, 3) with point positions.
        rgb: Array-like of shape (N, 3) with colors on a 0-255 scale (floats
            are accepted).

    Returns:
        A structured array of length N with float32 fields x/y/z, float32
        normals nx/ny/nz (all zero), and uint8 color fields red/green/blue.
    """
    xyz = np.asarray(xyz)
    rgb = np.asarray(rgb)
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    elements = np.empty(xyz.shape[0], dtype=dtype)
    for axis, name in enumerate(('x', 'y', 'z')):
        elements[name] = xyz[:, axis]
    for name in ('nx', 'ny', 'nz'):
        elements[name] = 0.0

    # Round before casting. Float colors (e.g. uint8 / 255 * 255) can land
    # just below an integer and would otherwise be truncated down by one. Clip
    # so that out-of-range values saturate instead of wrapping around in uint8.
    colors = np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
    for channel, name in enumerate(('red', 'green', 'blue')):
        elements[name] = colors[:, channel]
    return elements


def storePly(path, xyz, rgb):
    """Write points with zero normals and per-point colors to a PLY file.

    Args:
        path: Output file path.
        xyz: (N, 3) point positions.
        rgb: (N, 3) colors on a 0-255 scale; they are rounded and clipped to uint8.
    """
    vertex_element = PlyElement.describe(build_ply_vertices(xyz, rgb), 'vertex')
    PlyData([vertex_element]).write(path)

def load_mesh_from_file(file_path):
    """Load a mesh with trimesh, flattening a multi-geometry Scene into one mesh.

    If ``trimesh.load`` returns anything other than a ``trimesh.Scene``, that
    object is returned as-is. Otherwise every entry of ``scene.geometry`` is
    rebuilt as a ``trimesh.Trimesh``, carrying its vertices, faces and
    per-vertex colors, and the pieces are concatenated.

    Args:
        file_path: Path to a file readable by ``trimesh.load`` (e.g. a .glb).

    Returns:
        The loaded mesh, or the concatenation of the scene's geometries.
    """
    loaded = trimesh.load(file_path)
    if not isinstance(loaded, trimesh.Scene):
        return loaded

    # NOTE(review): only the raw ``scene.geometry`` entries are used, so the
    # per-node transforms in the scene graph are not applied. Confirm this is
    # intended for the meshes being loaded.
    pieces = []
    for geom in loaded.geometry.values():
        pieces.append(
            trimesh.Trimesh(
                vertices=geom.vertices,
                faces=geom.faces,
                vertex_colors=geom.visual.vertex_colors,
            )
        )
    return trimesh.util.concatenate(pieces)

def mesh_to_points(mesh, skip=1, angle_x_deg=90.0):
    """Convert a vertex-colored mesh into a subsampled, rotated, colored point set.

    Args:
        mesh: Object exposing ``vertices`` of shape (N, 3) and
            ``visual.vertex_colors`` of shape (N, >=3) on a 0-255 scale.
        skip: Keep every ``skip``-th vertex (1 keeps all of them).
        angle_x_deg: Rotation about the x axis, in degrees, applied to the
            positions.

    Returns:
        ``(coords, rgb)``: rotated (M, 3) positions and (M, 3) colors in [0, 1].

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """
    if skip < 1:
        raise ValueError(f"skip must be >= 1, got {skip}")

    coords = np.asarray(mesh.vertices)[::skip]
    # Drop any alpha channel and rescale 0-255 colors to [0, 1].
    rgb = np.asarray(mesh.visual.vertex_colors)[::skip, :3] / 255.0

    angle_x = np.radians(angle_x_deg)
    cos_a, sin_a = np.cos(angle_x), np.sin(angle_x)
    rotation_matrix = np.array([
        [1, 0, 0],
        [0, cos_a, -sin_a],
        [0, sin_a, cos_a],
    ])
    # Row-vector form of rotation_matrix @ p for every point p.
    coords = coords @ rotation_matrix.T
    return coords, rgb


def SF3D(file_path, skip=1, angle_x_deg=90.0):
    """Load a mesh file and return its vertices as a colored point cloud.

    Args:
        file_path: Mesh file passed to ``load_mesh_from_file``.
        skip: Vertex subsampling stride (default 1 keeps every vertex).
        angle_x_deg: Rotation about the x axis in degrees (default 90).

    Returns:
        ``(coords, rgb, 0.8)``. The third element is a fixed value kept for
        compatibility; ``process_prompt`` in this file does not use it.
    """
    mesh = load_mesh_from_file(file_path)
    coords, rgb = mesh_to_points(mesh, skip=skip, angle_x_deg=angle_x_deg)
    return coords, rgb, 0.8

def process_prompt(prompt, mesh_dir, save_dir, lp, pp, gcp, gp):
    """Render one prompt's mesh as a Gaussian point cloud from circle cameras.

    Steps:
        1. Load ``<mesh_dir>/<prompt>/mesh.glb`` via ``SF3D``.
        2. Save the points to ``<save_dir>/<prompt>/point_cloud.ply``.
        3. Initialise a ``GaussianModel`` from the points.
        4. For every camera from ``GenerateCircleCameras``, save a clamped
           color render and, when the renderer returns depth, a depth image
           normalised by its maximum.

    Args:
        prompt: Sub-directory name under ``mesh_dir``; also the output sub-dir name.
        mesh_dir: Root directory with one sub-directory per prompt.
        save_dir: Root output directory.
        lp: ModelParams (reads ``sh_degree``, ``_white_background``, ``data_device``).
        pp: PipelineParams forwarded to ``render``.
        gcp: GenerateCamParams used to build the cameras (reads ``render_45``).
        gp: GuidanceParams; not used here, kept for interface compatibility.
    """
    file_path = os.path.join(mesh_dir, prompt, 'mesh.glb')
    prompt_save_dir = os.path.join(save_dir, prompt)
    os.makedirs(prompt_save_dir, exist_ok=True)

    xyz, rgb, _scale = SF3D(file_path)
    num_pts = xyz.shape[0]

    print(f"Updated number of points for prompt '{prompt}':", num_pts)
    print(f"Shape of xyz for prompt '{prompt}':", xyz.shape)
    print(f"Shape of rgb for prompt '{prompt}':", rgb.shape)

    gaussians = GaussianModel(lp.sh_degree)
    pcd = BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((num_pts, 3)))
    ply_path = os.path.join(prompt_save_dir, 'point_cloud.ply')
    # storePly expects colors on a 0-255 scale; SF3D returns them in [0, 1].
    storePly(ply_path, xyz, rgb * 255)

    # NOTE(review): 3.5 is the second positional argument of create_from_pcd
    # (presumably spatial_lr_scale). Confirm before changing it.
    gaussians.create_from_pcd(pcd, 3.5)
    bg_color = [1, 1, 1] if lp._white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device=lp.data_device)

    test_cam_infos = GenerateCircleCameras(gcp, render45=gcp.render_45)
    camera_list = cameraList_from_RcamInfos(test_cam_infos, 1.0, gcp)

    # Rendering only: disable autograd so no graph is built for each view.
    with torch.no_grad():
        for viewpoint in camera_list:
            render_out = render(viewpoint, gaussians, pp, background, test=True)
            rgb_render, depth = render_out["render"], render_out["depth"]
            if depth is not None:
                depth_max = depth.max()
                # An all-zero depth map would divide by zero and give NaNs.
                if depth_max > 0:
                    depth_norm = depth / depth_max
                else:
                    depth_norm = torch.zeros_like(depth)
                save_image(depth_norm, os.path.join(prompt_save_dir, f"render_depth_{viewpoint.uid}.png"))
            image = torch.clamp(rgb_render, 0.0, 1.0)
            save_image(image, os.path.join(prompt_save_dir, f"render_view_{viewpoint.uid}.png"))

def main():
    """Parse CLI/YAML options and render every prompt directory under ``--mesh_dir``.

    Each sub-directory of ``--mesh_dir`` is one prompt. Prompts are processed
    in sorted order by ``process_prompt``, which writes to
    ``<--save_dir>/<prompt>/``.

    When ``--opt`` points to a YAML file:
        * Its ``ModelParams`` / ``OptimizationParams`` / ``PipelineParams`` /
          ``GenerateCamParams`` / ``GuidanceParams`` sections are loaded into
          the matching parameter groups.
        * Top-level ``port`` (when present), ``save_video`` (default True),
          ``seed`` (default 0) and ``device`` (default ``'cuda'``) override
          the CLI values.
    """
    default_mesh_dir = '/root/autodl-tmp/PathTracing_LucidDreamer/output/exp7/all_prompts_mesh'
    default_save_dir = '/root/autodl-tmp/PathTracing_LucidDreamer/output/exp7/all_prompts_render_imgs'

    def _str2bool(value):
        # argparse's ``type=bool`` treats any non-empty string (even "False")
        # as True, so parse the text explicitly.
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise ValueError(f"expected a boolean, got {value!r}")

    parser = ArgumentParser(description="Training script parameters")
    parser.add_argument('--opt', type=str, default=None)
    parser.add_argument('--ip', type=str, default="127.0.0.1")
    parser.add_argument('--port', type=int, default=6009)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_ratio", type=int, default=5)
    parser.add_argument("--save_ratio", type=int, default=2)
    parser.add_argument("--save_video", type=_str2bool, default=False)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--mesh_dir", type=str, default=default_mesh_dir,
                        help="directory containing one <prompt>/mesh.glb per prompt")
    parser.add_argument("--save_dir", type=str, default=default_save_dir,
                        help="output root; renders go to <save_dir>/<prompt>/")

    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    gcp = GenerateCamParams(parser)
    gp = GuidanceParams(parser)

    args = parser.parse_args(sys.argv[1:])
    if args.opt is not None:
        with open(args.opt) as f:
            # An empty YAML file loads as None; treat it as "no overrides".
            opts = yaml.load(f, Loader=yaml.FullLoader) or {}
        lp.load_yaml(opts.get('ModelParams', None))
        op.load_yaml(opts.get('OptimizationParams', None))
        pp.load_yaml(opts.get('PipelineParams', None))
        gcp.load_yaml(opts.get('GenerateCamParams', None))
        gp.load_yaml(opts.get('GuidanceParams', None))

        lp.opt_path = args.opt
        args.port = opts.get('port', args.port)
        args.save_video = opts.get('save_video', True)
        args.seed = opts.get('seed', 0)
        args.device = opts.get('device', 'cuda')

        gp.g_device = args.device
        lp.data_device = args.device
        gcp.device = args.device

    mesh_dir = args.mesh_dir
    save_dir = args.save_dir
    # Sort so that the processing order is deterministic across runs.
    prompts = sorted(
        d for d in os.listdir(mesh_dir) if os.path.isdir(os.path.join(mesh_dir, d))
    )

    for prompt in prompts:
        print(f"Processing prompt: {prompt}")
        process_prompt(prompt, mesh_dir, save_dir, lp, pp, gcp, gp)

# Run the batch renderer only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
