import argparse
import os

import rembg
import torch
from torchvision.utils import save_image

import random
import numpy as np
import sys
from PIL import Image
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Optional
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, GenerateCamParams, GuidanceParams
import trimesh
from plyfile import PlyData, PlyElement
from scene import GaussianModel
from scene.gaussian_model import BasicPointCloud
from scene.dataset_readers import GenerateCircleCameras
from utils.camera_utils import cameraList_from_RcamInfos
from gaussian_renderer import render
import yaml

from stable_fast_3d.sf3d.system import SF3D
from stable_fast_3d.sf3d.utils import remove_background, resize_foreground

def SF3D_func(mesh, skip=1):
    """Convert a vertex-coloured mesh into point-cloud arrays for Gaussian init.

    Every ``skip``-th vertex is kept. Positions are rotated by +90 degrees about
    the x axis (y -> -z, z -> y, up to floating-point rounding). Colours are
    divided by 255 and truncated to their first three channels (alpha dropped).

    Args:
        mesh: Object exposing ``vertices`` (N x 3) and ``visual.vertex_colors``
            (N x >=3, presumably 0-255 uint8 -- e.g. a ``trimesh.Trimesh`` with
            colour visuals; confirm for other mesh types).
        skip: Keep one vertex out of every ``skip``; must be >= 1. The default 1
            keeps all vertices.

    Returns:
        Tuple ``(coords, rgb, scale)``: rotated positions (M x 3), colours
        (M x 3) scaled by 1/255, and the constant scale ``0.8``.

    Raises:
        ValueError: If ``skip`` is smaller than 1 (a negative step would silently
            reverse the points, zero is not a valid slice step).
    """
    if skip < 1:
        raise ValueError(f"skip must be >= 1, got {skip}")

    coords = np.asarray(mesh.vertices)[::skip]
    # Drop the alpha channel and normalise 8-bit colours.
    rgb = np.asarray(mesh.visual.vertex_colors)[::skip, :3] / 255.0

    angle_x = np.radians(90)
    rotation_matrix = np.array([
        [1, 0, 0],
        [0, np.cos(angle_x), -np.sin(angle_x)],
        [0, np.sin(angle_x), np.cos(angle_x)],
    ])
    # Points are row vectors, so R is applied as coords @ R^T.
    coords = coords @ rotation_matrix.T
    return coords, rgb, 0.8

def _pick_view():
    """Randomly choose a camera uid and the view label appended to the caption.

    One of four base views is drawn uniformly; with probability 0.5 its paired
    alternative uid is used instead, keeping the base view's label.
    """
    view_selections = [
        (0, "front view"),
        (2, "side view"),
        (4, "back view"),
        (6, "side view"),
    ]
    # Alternative uid for each entry of view_selections (same index, same label).
    alternative_views = [8, 10, 12, 14]

    choice = random.randint(0, 3)
    uid, label_suffix = view_selections[choice]
    if random.random() < 0.5:
        uid = alternative_views[choice]
    return uid, label_suffix


def _read_caption(dataset_dir, file_name):
    """Return ``<dataset_dir>/<file_name>.txt`` stripped, minus one trailing period.

    If the caption file does not exist, its expected path is recorded in
    ``backup/log/dataset_lora2_cap100k/<file_name>.txt`` (relative to the current
    working directory; the directory is created if needed) and ``''`` is returned.
    """
    txt_filepath = os.path.join(dataset_dir, f"{file_name}.txt")
    try:
        with open(txt_filepath, "r", encoding="utf-8") as file:
            prompt = file.read().strip()
    except FileNotFoundError:
        log_dir = 'backup/log/dataset_lora2_cap100k'
        # Without this the logging write itself raises FileNotFoundError.
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, f"{file_name}.txt"), "w", encoding="utf-8") as txt_file:
            txt_file.write(f"{txt_filepath}")
        return ''
    if prompt.endswith('.'):
        prompt = prompt[:-1]
    return prompt


def process_prompt(file_name, dataset_dir, mesh, save_dir, lp, pp, gcp, gp):
    """Render one randomly chosen view of *mesh* and write an image/caption pair.

    The mesh is converted by ``SF3D_func`` into a point cloud that initialises a
    ``GaussianModel``. The camera whose uid is chosen by ``_pick_view`` (among those
    produced by ``GenerateCircleCameras``) is rendered. The outputs are
    ``<save_dir>/<file_name>.png`` (render clamped to [0, 1]) and
    ``<save_dir>/<file_name>.txt`` containing ``"<caption>, <view label>"``, or
    just the view label when no caption is available.

    Args:
        file_name: Stem used for the output files and for the caption lookup.
        dataset_dir: Directory searched for ``<file_name>.txt`` holding the caption.
        mesh: Mesh accepted by ``SF3D_func``.
        save_dir: Directory receiving the PNG and the caption file.
        lp: Model parameters (reads ``sh_degree``, ``_white_background``,
            ``data_device``).
        pp: Pipeline parameters forwarded to ``render``.
        gcp: Camera-generation parameters (reads ``render_45``).
        gp: Guidance parameters; unused, kept for interface compatibility.

    Returns:
        None. Nothing is written when the mesh has fewer than 10 vertices or no
        generated camera has the chosen uid.
    """
    xyz, rgb, _scale = SF3D_func(mesh)
    num_pts = xyz.shape[0]
    if num_pts < 10:
        return

    gaussians = GaussianModel(lp.sh_degree)
    pcd = BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((num_pts, 3)))
    gaussians.create_from_pcd(pcd, 3.5)
    bg_color = [1, 1, 1] if lp._white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device=lp.data_device)

    cam_infos = GenerateCircleCameras(gcp, render45=gcp.render_45)
    camera_list = cameraList_from_RcamInfos(cam_infos, 1.0, gcp)

    target_uid, label_suffix = _pick_view()
    viewpoint = next((cam for cam in camera_list if cam.uid == target_uid), None)
    if viewpoint is None:
        return

    # Inference-only render: skip building an autograd graph over the Gaussians.
    with torch.no_grad():
        render_out = render(viewpoint, gaussians, pp, background, test=True)
    image = torch.clamp(render_out["render"], 0.0, 1.0)
    save_image(image, os.path.join(save_dir, f"{file_name}.png"))

    prompt = _read_caption(dataset_dir, file_name)
    # Avoid a malformed ", <label>" caption when the source caption is missing.
    caption = f"{prompt}, {label_suffix}" if prompt else label_suffix
    with open(os.path.join(save_dir, f"{file_name}.txt"), "w", encoding="utf-8") as txt_file:
        txt_file.write(caption)

def func_main(mesh, file_name, save_dir, dataset_dir):
    """Set up the CLI parameter groups, validate argv, then render *mesh*.

    Every call builds a fresh ``ArgumentParser``, registers a fixed set of
    script options plus the options of each project parameter group, and parses
    ``sys.argv[1:]`` (argparse exits on unrecognised options). Rendering and
    caption writing are delegated to ``process_prompt``.

    NOTE(review): the parsed namespace is discarded and never applied to the
    parameter groups, so command-line values do not reach ``process_prompt``;
    presumably the groups keep their built-in defaults -- confirm intent.
    """
    script_options = [
        ('--opt', dict(type=str, default=None)),
        ('--ip', dict(type=str, default="127.0.0.1")),
        ('--port', dict(type=int, default=6009)),
        ('--debug_from', dict(type=int, default=-1)),
        ('--seed', dict(type=int, default=0)),
        ('--detect_anomaly', dict(action='store_true', default=False)),
        ("--test_ratio", dict(type=int, default=5)),
        ("--save_ratio", dict(type=int, default=2)),
        # NOTE: argparse's type=bool turns any non-empty string into True.
        ("--save_video", dict(type=bool, default=False)),
        ("--quiet", dict(action="store_true")),
        ("--checkpoint_iterations", dict(nargs="+", type=int, default=[])),
        ("--start_checkpoint", dict(type=str, default=None)),
    ]

    cli = ArgumentParser(description="Training script parameters")
    for flag, options in script_options:
        cli.add_argument(flag, **options)

    # Registration order matters for the parser; keep it unchanged.
    model_params = ModelParams(cli)
    OptimizationParams(cli)  # only its argparse registration is needed here
    pipeline_params = PipelineParams(cli)
    cam_params = GenerateCamParams(cli)
    guidance_params = GuidanceParams(cli)

    # Parsed purely for validation; the result is not used.
    cli.parse_args(sys.argv[1:])

    process_prompt(
        file_name,
        dataset_dir,
        mesh,
        save_dir,
        model_params,
        pipeline_params,
        cam_params,
        guidance_params,
    )

def _partition_bounds(num_items, partition, num_partitions):
    """Return ``(start, end)`` of the 1-based shard *partition* of *num_partitions*.

    Each shard spans ``ceil(num_items / num_partitions)`` items; trailing shards
    may be shorter or empty (``start`` may exceed ``end``, which slices to ``[]``).

    Raises:
        ValueError: If *partition* is not in ``[1, num_partitions]``.
    """
    if not 1 <= partition <= num_partitions:
        raise ValueError(
            f"partition must be in [1, {num_partitions}], got {partition}"
        )
    size = (num_items + num_partitions - 1) // num_partitions
    start = (partition - 1) * size
    end = min(partition * size, num_items)
    return start, end


def SF3D_img2mesh(dataset_dir, output_dir, partition=4, num_partitions=4):
    """Reconstruct a mesh per input image with SF3D and render a captioned view.

    For each selected image: remove the background, resize the foreground,
    run SF3D to obtain a mesh, then call ``func_main`` which writes
    ``<output_dir>/<name>.png`` and ``<output_dir>/<name>.txt``. Images whose
    ``<output_dir>/<name>.png`` already exists are skipped, so runs can resume.

    Args:
        dataset_dir: Directory of images (``.png``/``.jpg``/``.jpeg``) plus their
            ``<name>.txt`` captions, or the path of a single image whose caption
            is looked up next to it.
        output_dir: Destination directory; created if missing.
        partition: 1-based shard of the sorted directory listing to process.
        num_partitions: Number of shards the listing is split into. The defaults
            (4 of 4) process the last quarter of the listing.

    Raises:
        ValueError: If *partition* is outside ``[1, num_partitions]`` (directory
            input only).
    """
    from contextlib import nullcontext

    pretrained_model = "stabilityai/stable-fast-3d"
    foreground_ratio = 0.85
    texture_resolution = 1024
    remesh_option = "none"

    if os.path.isdir(dataset_dir):
        all_files = sorted(os.listdir(dataset_dir))
        # Shard the full sorted listing first, then keep only images, so shards
        # are balanced by file count rather than image count.
        start, end = _partition_bounds(len(all_files), partition, num_partitions)
        image_paths = [
            os.path.join(dataset_dir, f)
            for f in all_files[start:end]
            if f.endswith((".png", ".jpg", ".jpeg"))
        ]
        caption_dir = dataset_dir
    else:
        image_paths = [dataset_dir]
        # Captions live next to the image; "" means the current directory.
        caption_dir = os.path.dirname(dataset_dir)

    if not image_paths:
        return

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    use_cuda = device.startswith("cuda")

    model = SF3D.from_pretrained(
        pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
    )
    model.to(device)
    model.eval()
    rembg_session = rembg.new_session()
    os.makedirs(output_dir, exist_ok=True)

    for image_path in image_paths:
        file_name = os.path.splitext(os.path.basename(image_path))[0]
        done_path = os.path.join(output_dir, f"{file_name}.png")
        if os.path.isfile(done_path):
            print(f"This file has been generated {done_path}")
            continue

        print(f"handle image_path: {image_path}")
        # Close the file handle once the pixels are converted.
        with Image.open(image_path) as img:
            rgba = img.convert("RGBA")
        image = resize_foreground(
            remove_background(rgba, rembg_session), foreground_ratio
        )

        # fp16 CUDA autocast only when actually running on CUDA.
        amp = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if use_cuda
            else nullcontext()
        )
        with torch.no_grad(), amp:
            mesh, _ = model.run_image(
                [image],
                bake_resolution=texture_resolution,
                remesh=remesh_option,
            )
        func_main(mesh, file_name, output_dir, caption_dir)