"""Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V.

(MPG) is holder of all proprietary rights on this computer program. You can
only use this computer program if you have closed a license agreement with MPG
or you get the right to use the computer program from someone who is authorized
to grant you that right. Any use of the computer program without a valid
license is prohibited and liable to prosecution. Copyright 2019 Max-Planck-
Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of
its Max Planck Institute for Intelligent Systems and the Max Planck Institute
for Biological Cybernetics. All rights reserved. More information about VOCA is
available at http://voca.is.tue.mpg.de. For comments or questions, please email
us at voca@tue.mpg.de
"""

import cv2
import numpy as np
import os
import subprocess
import torch
from tqdm import tqdm


def render_meshes(
    mesh_vertices: torch.Tensor,
    faces: torch.Tensor,
    img_size: tuple = (256, 256),
    aa_factor: int = 1,
    vis_bs: int = 100,
):
    """Render a batch of meshes as white, Phong-shaded images with PyTorch3D.

    An orthographic camera looks at the meshes along the z axis. Its (square)
    x/y extent is 1.2x the largest absolute x/y vertex coordinate over the
    whole batch, so every mesh in the batch is framed identically.

    Args:
        mesh_vertices: (N, V, 3) vertex positions.
        faces: (N, F, 3) integer face indices, or (1, F, 3) to share a single
            topology across all N meshes.
        img_size: (W, H) of the output images in pixels.
        aa_factor: supersampling factor; values > 1 render at
            (H * aa_factor, W * aa_factor) and downsample bicubically to (H, W).
        vis_bs: number of meshes rasterized per batch.

    Returns:
        torch.Tensor on the CPU of shape (N, H, W, 4) holding the RGBA output
        of the shader (bicubic downsampling may push values slightly outside
        [0, 1]).
    """
    assert len(mesh_vertices) == len(faces) or len(faces) == 1
    if len(faces) == 1:
        # One shared topology: replicate it so each mesh has its own faces.
        faces = faces.repeat(len(mesh_vertices), 1, 1)

    from pytorch3d.renderer import (
        DirectionalLights,
        FoVOrthographicCameras,
        HardPhongShader,
        Materials,
        MeshRasterizer,
        MeshRenderer,
        RasterizationSettings,
        TexturesVertex,
    )
    from pytorch3d.renderer.blending import BlendParams
    from pytorch3d.structures import Meshes

    aa_factor = int(aa_factor)
    device = mesh_vertices.device
    W, H = img_size

    # Half-extent of the square orthographic view volume: the largest |x| or |y|
    # over every vertex in the batch, plus a 20% margin.
    extent = mesh_vertices[..., 0:2].abs().max() * 1.2

    # R = diag(-1, 1, -1) with T = (0, 0, 10) puts the camera on the world +z
    # axis looking back at the origin (pytorch3d's X_view = X_world @ R + T).
    R = torch.diag(torch.tensor([-1.0, 1.0, -1.0], device=device))[None]
    T = torch.tensor([[0.0, 0.0, 10.0]], device=device)
    cameras = FoVOrthographicCameras(
        device=device,
        R=R,
        T=T,
        znear=0.01,
        zfar=3,
        max_x=extent,
        min_x=-extent,
        max_y=extent,
        min_y=-extent,
    )

    raster_settings = RasterizationSettings(
        # pytorch3d expects (height, width), not (width, height).
        image_size=(H * aa_factor, W * aa_factor),
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    blend_params = BlendParams(sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0))
    lights = DirectionalLights(
        device=device,
        direction=((0, 0, 1),),
        ambient_color=((0.3, 0.3, 0.3),),
        diffuse_color=((0.6, 0.6, 0.6),),
        specular_color=((0.1, 0.1, 0.1),),
    )
    materials = Materials(
        ambient_color=((1, 1, 1),),
        diffuse_color=((1, 1, 1),),
        specular_color=((1, 1, 1),),
        shininess=15,
        device=device,
    )
    shader = HardPhongShader(
        device=device,
        cameras=cameras,
        lights=lights,
        materials=materials,
        blend_params=blend_params,
    )
    renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

    verts_batches = torch.split(mesh_vertices, vis_bs)
    faces_batches = torch.split(faces, vis_bs)
    rendered_batches = []
    with torch.no_grad():
        for verts_batch, faces_batch in tqdm(
            zip(verts_batches, faces_batches), total=len(verts_batches)
        ):
            # Constant white per-vertex color; shading comes from the lights.
            textures = TexturesVertex(verts_features=torch.ones_like(verts_batch))
            meshes = Meshes(verts=verts_batch, faces=faces_batch, textures=textures)
            # Move each batch off the device right away to bound device memory.
            rendered_batches.append(renderer(meshes).cpu())
    rendered_imgs = torch.cat(rendered_batches, dim=0)

    if aa_factor > 1:
        # Explicit target size avoids floor() rounding issues of 1 / aa_factor.
        rendered_imgs = torch.nn.functional.interpolate(
            rendered_imgs.permute(0, 3, 1, 2),  # NHWC -> NCHW
            size=(H, W),
            mode="bicubic",
        ).permute(0, 2, 3, 1)  # NCHW -> NHWC
    return rendered_imgs


def read_obj(in_path):
    """Parse vertex positions and face indices from a Wavefront OBJ file.

    Only ``v`` and ``f`` records are read. Every other record (``vt``, ``vn``,
    comments, groups, ...) is ignored. For a face token such as ``v/vt/vn``
    only the leading vertex index is kept.

    Args:
        in_path: path to the ``.obj`` file.

    Returns:
        Tuple ``(verts, faces)``. ``verts`` is a float array of shape (V, 3).
        ``faces`` is an int array of shape (F, K) holding vertex indices as
        numbered in the file (OBJ numbering is 1-based). Negative, i.e.
        relative, indices are resolved to their absolute 1-based value. All
        faces must have the same vertex count K for ``faces`` to be a regular
        2-D array.
    """
    verts = []
    faces = []
    # Only ASCII numeric fields are parsed, so undecodable bytes (e.g. inside
    # comments) must not abort the read.
    with open(in_path, "r", encoding="utf-8", errors="replace") as obj_file:
        for raw_line in obj_file:
            elements = raw_line.split()
            if not elements:
                continue  # blank line

            tag = elements[0]
            if tag == "v":
                # Extra trailing components (w, vertex colors) are ignored.
                x, y, z = map(float, elements[1:4])
                verts.append((x, y, z))
            elif tag == "f":
                face = []
                for token in elements[1:]:
                    index = int(token.split("/")[0])
                    if index < 0:
                        # OBJ relative index: -1 is the most recently defined vertex.
                        index += len(verts) + 1
                    face.append(index)
                faces.append(face)
    return np.array(verts), np.array(faces)


def pad_for_libx264(image_array):
    """Zero-pad the bottom/right edge so height and width are both even.

    libx264 rejects odd frame sizes with an error like
    "[libx264 @ 0x1b1d560] width not divisible by 2".

    The layout is guessed from the shape: a 2-D array, or a 3-D array whose
    last axis has length 3, is treated as a single image; any other 3-D array,
    and every 4-D array, is treated as a stack of images.

    Args:
        image_array (np.ndarray): image(s), e.g. loaded by cv2.imread().
            Supported shapes:
            1. [height, width]
            2. [height, width, channels]
            3. [images, height, width]
            4. [images, height, width, channels]

    Returns:
        np.ndarray: the padded array, or the input object itself when both
        edges are already even or the rank is unsupported.
    """
    ndim = image_array.ndim
    single_image = ndim == 2 or (ndim == 3 and image_array.shape[2] == 3)
    image_stack = ndim == 4 or (ndim == 3 and image_array.shape[2] != 3)

    if single_image:
        h_axis, w_axis = 0, 1
    elif image_stack:
        h_axis, w_axis = 1, 2
    else:
        return image_array

    extra_rows = image_array.shape[h_axis] % 2
    extra_cols = image_array.shape[w_axis] % 2
    if not (extra_rows or extra_cols):
        return image_array

    # Pad only after the last row/column of the spatial axes.
    widths = [(0, 0)] * ndim
    widths[h_axis] = (0, extra_rows)
    widths[w_axis] = (0, extra_cols)
    return np.pad(image_array, widths, mode="constant", constant_values=0)


def array_to_video(
    image_array: np.ndarray,
    output_path: str,
    fps=30,
    resolution=None,
    disable_log: bool = False,
) -> None:
    """Encode an image array into an H.264 video with ffmpeg (gif not supported).

    Frames are streamed to ffmpeg as raw ``bgr24`` bytes, so the array must be
    ``uint8`` in BGR channel order (as produced by ``cv2.imread``).

    Args:
        image_array (np.ndarray): frames of shape (frames, height, width, 3),
            dtype uint8, BGR.
        output_path (str): output video file path; its directory must exist.
        fps (Union[int, float], optional): frames per second. Defaults to 30.
        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
            optional): (height, width) of the output video. ffmpeg scales the
            frames to this size, rounded up to even numbers as libx264
            requires. Defaults to None, meaning the frames' own size, padded
            to even numbers.
        disable_log (bool, optional): whether to suppress printing the ffmpeg
            command line. Defaults to False.

    Raises:
        TypeError: if image_array is not an np.ndarray of dtype uint8.
        FileNotFoundError: if the directory of output_path does not exist.
        RuntimeError: if ffmpeg exits with a non-zero status.

    Returns:
        None.
    """
    import shutil

    if not isinstance(image_array, np.ndarray):
        raise TypeError("Input should be np.ndarray.")
    assert image_array.ndim == 4
    assert image_array.shape[-1] == 3
    if image_array.dtype != np.uint8:
        # bgr24 means one byte per channel; any other dtype would be
        # serialized byte-for-byte and decoded as a garbled video.
        raise TypeError(
            f"Input dtype should be np.uint8 (bgr24 frames), got {image_array.dtype}."
        )
    output_dir = os.path.dirname(os.path.abspath(output_path))
    if not os.path.isdir(output_dir):
        raise FileNotFoundError(f"Output directory does not exist: {output_dir}")

    if resolution:
        out_height, out_width = (int(v) for v in resolution)
        # libx264 requires even frame dimensions.
        out_width += out_width % 2
        out_height += out_height % 2
        # The raw input keeps its real frame size; ffmpeg rescales the frames
        # to the requested output resolution.
        output_args = ["-vf", f"scale={out_width}:{out_height}"]
    else:
        image_array = pad_for_libx264(image_array)
        output_args = []
    frame_height, frame_width = image_array.shape[1], image_array.shape[2]

    command = [
        # Prefer ffmpeg from PATH; keep the historical location as fallback.
        shutil.which("ffmpeg") or "/usr/bin/ffmpeg",
        "-y",  # overwrite output file if it exists
        "-f",
        "rawvideo",
        "-s",
        f"{int(frame_width)}x{int(frame_height)}",  # size of one input frame
        "-pix_fmt",
        "bgr24",
        "-r",
        f"{fps}",  # frames per second
        "-loglevel",
        "error",
        "-threads",
        "4",
        "-i",
        "-",  # the input comes from a pipe
        *output_args,
        "-vcodec",
        "libx264",
        "-an",  # no audio stream
        output_path,
    ]
    if not disable_log:
        print(f'Running "{" ".join(command)}"')
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stdin is None or process.stderr is None:
        raise BrokenPipeError("No buffer received.")
    try:
        for frame in image_array:
            process.stdin.write(frame.tobytes())
    except BrokenPipeError:
        # ffmpeg exited early; its stderr (collected below) explains why.
        pass
    # communicate() closes stdin, drains stderr and waits for ffmpeg to exit.
    _, stderr_data = process.communicate()
    if process.returncode != 0:
        raise RuntimeError(
            f"ffmpeg exited with code {process.returncode}: "
            f"{stderr_data.decode(errors='replace').strip()}"
        )


def get_obj_faces(
    annot_type: str,
    template_dir: str = "checkpoints/unitalker/resources/obj_template",
):
    """Load the face indices of the template mesh for an annotation type.

    Args:
        annot_type: name of a supported template, e.g. "FLAME_5023_vertices";
            the file ``<template_dir>/<annot_type>.obj`` is read.
        template_dir: directory holding the template ``.obj`` files. A relative
            path is resolved against the current working directory.

    Returns:
        np.ndarray of face indices as parsed by ``read_obj``, shifted to
        0-based when the file's smallest index is 1 (standard OBJ numbering).

    Raises:
        KeyError: if ``annot_type`` is not a supported template name.
    """
    supported_types = (
        "3DETF_blendshape_weight",
        "FLAME_5023_vertices",
        "BIWI_23370_vertices",
        "flame_params_from_dadhead",
        "inhouse_blendshape_weight",
        "meshtalk_6172_vertices",
    )
    if annot_type not in supported_types:
        raise KeyError(
            f"Unsupported annot_type {annot_type!r}; expected one of {supported_types}"
        )

    template_path = os.path.join(template_dir, f"{annot_type}.obj")
    faces = np.asarray(read_obj(template_path)[1])
    # `faces.size` guard: min() of an empty array would raise.
    if faces.size and faces.min() == 1:
        faces = faces - 1
    return faces


def render_vertices_video(
    vertices,
    annot_type: str,
    img_size=(512, 512),
    aa_factor: int = 1,
    vis_bs: int = 100,
    device: str = None,
    return_float: bool = False,
):
    """Render a vertex sequence into an RGB frame array of shape (T, H, W, 3).

    Args:
        vertices: np.ndarray or torch.Tensor of shape (V, 3) for a single
            frame, or (T, V, 3) for T frames.
        annot_type: template name passed to ``get_obj_faces``; selects the
            face topology shared by all frames.
        img_size: (W, H) output resolution in pixels.
        aa_factor: supersampling factor; values > 1 render at a larger size
            and downsample bicubically.
        vis_bs: number of frames rendered per batch, bounding peak memory.
        device: "cuda" or "cpu"; defaults to CUDA when available.
        return_float: if True, return float32 values in [0, 1]; otherwise
            uint8 values in [0, 255].

    Returns:
        np.ndarray of shape (T, H, W, 3) in RGB channel order.

    Raises:
        TypeError: if ``vertices`` is neither an np.ndarray nor a torch.Tensor.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Normalize vertices to a float32 (T, V, 3) tensor on the target device.
    if isinstance(vertices, np.ndarray):
        # torch.from_numpy rejects negative strides (e.g. arr[::-1]); copy then.
        verts = torch.from_numpy(np.ascontiguousarray(vertices))
    elif isinstance(vertices, torch.Tensor):
        verts = vertices
    else:
        raise TypeError(f"Unsupported type for vertices: {type(vertices)}")

    if verts.ndim == 2 and verts.shape[-1] == 3:
        verts = verts.unsqueeze(0)  # (V, 3) -> (1, V, 3)
    assert verts.ndim == 3 and verts.shape[-1] == 3, "vertices must be (T,V,3) or (V,3)"
    verts = verts.to(device=device, dtype=torch.float32)

    # Template faces on the same device.
    faces_np = get_obj_faces(annot_type)
    faces = torch.from_numpy(faces_np).long().to(device)
    if faces.ndim == 2:
        # (F, K) -> (1, F, K); render_meshes repeats it for every frame.
        faces = faces.unsqueeze(0)

    imgs_rgba = render_meshes(
        verts, faces, img_size=img_size, aa_factor=int(aa_factor), vis_bs=int(vis_bs)
    )

    # Drop alpha and clamp: bicubic downsampling can overshoot [0, 1].
    imgs_rgb = imgs_rgba[..., :3].clamp(0.0, 1.0).detach().cpu().numpy()

    if return_float:
        return imgs_rgb.astype(np.float32)
    # Round rather than truncate so intensities are not biased downward.
    return np.rint(imgs_rgb * 255.0).astype(np.uint8)
