import torch
from torch.cuda.amp import autocast
from threestudio.utils.ops import convert_pose,get_projection_matrix_gaussian
from typing import NamedTuple


class Camera(NamedTuple):
    """Per-view camera parameters passed to the renderer's ``forward``.

    In this file it is built once per view by
    ``FaceMeshBatchRender.batch_forward`` from the outputs of
    ``get_cam_info_gaussian_2``.
    """

    # Horizontal field of view; batch_forward fills it with the fovy value.
    # NOTE(review): units (presumably radians) are not visible here — confirm.
    FoVx: torch.Tensor
    # Vertical field of view, taken from batch["fovy"][i].
    FoVy: torch.Tensor
    # Camera position: row 3, cols 0-2 of inverse(world_view_transform).
    camera_center: torch.Tensor
    # Image size; batch_forward passes batch["width"] / batch["height"]
    # unchanged — these may be tensors rather than Python ints, TODO confirm.
    image_width: int
    image_height: int
    # Transposed world-to-camera matrix (inverse of the converted c2w).
    world_view_transform: torch.Tensor
    # world_view_transform @ transposed projection matrix.
    full_proj_transform: torch.Tensor
    # Converted camera-to-world pose (output of convert_pose), moved to CUDA.
    c2w: torch.Tensor

def get_cam_info_gaussian_2(c2w, fovx, fovy, znear, zfar):
    """Compute the rasterizer matrices for a single camera pose.

    Args:
        c2w: camera-to-world pose. It is first passed through
            ``convert_pose`` (presumably a coordinate-convention change —
            confirm in threestudio.utils.ops).
        fovx, fovy: fields of view forwarded to
            ``get_projection_matrix_gaussian``.
        znear, zfar: clipping distances forwarded to the same helper.

    Returns:
        Tuple ``(c2w, world_view_transform, full_proj_transform,
        projection_matrix, camera_center)``:

        * ``c2w`` is the converted pose; its device is left unchanged.
        * ``world_view_transform`` is the inverse of the converted pose,
          transposed, moved to CUDA and cast to float32.
        * ``projection_matrix`` is the transposed projection matrix on CUDA.
        * ``full_proj_transform`` is
          ``world_view_transform @ projection_matrix``.
        * ``camera_center`` is row 3, columns 0-2, of
          ``inverse(world_view_transform)``.
    """
    pose = convert_pose(c2w)

    # Both matrices are kept transposed for the downstream rasterizer.
    view_t = pose.inverse().transpose(0, 1).cuda().float()
    proj_t = get_projection_matrix_gaussian(
        znear=znear, zfar=zfar, fovX=fovx, fovY=fovy
    ).transpose(0, 1).cuda()

    # bmm needs a batch dimension: add a singleton one, then strip it again.
    full_t = torch.bmm(view_t[None], proj_t[None])[0]

    # inverse(view_t) == pose^T, so its last row carries the pose translation.
    center = view_t.inverse()[3, :3]

    return pose, view_t, full_t, proj_t, center


class FaceMeshBatchRender:
    """Mixin that renders a batch of cameras one view at a time.

    The host class must provide ``self.forward(viewpoint_cam, background,
    **batch)`` returning a dict of per-view outputs, and
    ``self.background_tensor``. Neither is defined here.
    """

    @staticmethod
    def _stack_nhwc(tensors):
        """Stack per-view tensors on a new dim 0, then ``permute(0, 2, 3, 1)``.

        Presumably converts channels-first ``(C, H, W)`` views into a
        channels-last ``(B, H, W, C)`` batch — confirm against ``forward``.
        """
        return torch.stack(tensors, dim=0).permute(0, 2, 3, 1)

    def batch_forward(self, batch):
        """Render every view in ``batch`` and collect the results.

        Args:
            batch: dict with at least:
                * ``"c2w"``: per-view poses; ``shape[0]`` is the batch size.
                * ``"fovy"``: indexed per view.
                * ``"width"`` and ``"height"``.
                The whole dict is forwarded to ``self.forward`` as keyword
                arguments. ``batch["batch_idx"]`` is overwritten with the
                index of the view being rendered, and is left at the last
                index afterwards.

        Returns:
            dict with:
                * ``"comp_rgb"``: stacked ``render`` outputs.
                * ``"full_proj_transforms"``: list of per-view full
                  projection matrices.
                * ``"comp_rgb_kd"``: stacked ``render_kd`` outputs, or
                  ``None`` if no view produced one.
                * ``"comp_normal"``, ``"comp_pred_normal"``,
                  ``"comp_depth"``, ``"comp_mask"``: present only when the
                  matching outputs were produced.
            Stacked entries are built with ``_stack_nhwc``.

        Raises:
            KeyError: if ``self.forward`` does not return ``"render"``.
        """
        bs = batch["c2w"].shape[0]
        renders = []
        renders_kd = []
        full_proj_transforms = []
        normals = []
        pred_normals = []
        depths = []
        masks = []
        for batch_idx in range(bs):
            # self.forward receives the current view index through **batch.
            batch["batch_idx"] = batch_idx
            fovy = batch["fovy"][batch_idx]
            # NOTE(review): fovy is reused as fovx — this assumes equal
            # horizontal/vertical FoV; confirm for non-square renders.
            c2w, w2c, full_proj, _, cam_p = get_cam_info_gaussian_2(
                c2w=batch["c2w"][batch_idx], fovx=fovy, fovy=fovy, znear=0.1, zfar=100
            )

            viewpoint_cam = Camera(
                FoVx=fovy,
                FoVy=fovy,
                image_width=batch["width"],
                image_height=batch["height"],
                world_view_transform=w2c,
                full_proj_transform=full_proj,
                camera_center=cam_p,
                c2w=c2w.cuda(),
            )

            # Render with autocast disabled, i.e. without mixed precision.
            with autocast(enabled=False):
                render_pkg = self.forward(
                    viewpoint_cam, self.background_tensor, **batch
                )

            renders.append(render_pkg["render"])
            full_proj_transforms.append(full_proj)

            # render_kd is optional: tolerate both a missing key and None
            # (previously a missing key raised KeyError).
            render_kd = render_pkg.get("render_kd")
            if render_kd is not None:
                renders_kd.append(render_kd)

            if "normal" in render_pkg:
                normals.append(render_pkg["normal"])
            if render_pkg.get("pred_normal") is not None:
                pred_normals.append(render_pkg["pred_normal"])
            if "depth" in render_pkg:
                depths.append(render_pkg["depth"])
            if "mask" in render_pkg:
                masks.append(render_pkg["mask"])

        outputs = {
            "comp_rgb": self._stack_nhwc(renders),
            "full_proj_transforms": full_proj_transforms,
            "comp_rgb_kd": self._stack_nhwc(renders_kd) if renders_kd else None,
        }
        # Optional outputs appear only when at least one view produced them.
        for key, values in (
            ("comp_normal", normals),
            ("comp_pred_normal", pred_normals),
            ("comp_depth", depths),
            ("comp_mask", masks),
        ):
            if values:
                outputs[key] = self._stack_nhwc(values)
        return outputs
