import torch
from threestudio.utils.ops import get_cam_info_gaussian
from torch.cuda.amp import autocast

from ..geometry.gaussian_base import BasicPointCloud, Camera


class GaussianBatchRenderer:
    """Mixin that renders a batch of cameras one view at a time and collates the results.

    The host class must provide ``self.forward(viewpoint_cam, background, **kwargs)``
    returning a per-view dict of render outputs, and ``self.background_tensor``;
    neither is defined here.
    """

    def batch_forward(self, batch):
        """Render every view in ``batch`` and collate the per-view outputs.

        Args:
            batch: Mapping with at least ``"c2w"`` (per-view camera-to-world
                matrices, indexed along dim 0), ``"fovy"`` (per-view field of
                view, indexed along dim 0), ``"width"`` and ``"height"``. The
                whole mapping is also forwarded to ``self.forward`` as keyword
                arguments.

                Side effect: ``batch["batch_idx"]`` is overwritten in place with
                the index of the view being rendered so that ``forward`` can see
                it; the last index is left in ``batch`` after the call returns.

        Returns:
            dict with
              - ``"comp_rgb"``: per-view ``render`` tensors stacked on a new
                dim 0 and permuted with ``(0, 2, 3, 1)`` (channel-last,
                assuming each render is ``[C, H, W]`` -- TODO confirm);
              - ``"viewspace_points"``, ``"visibility_filter"``, ``"radii"``:
                plain Python lists with one entry per view (not stacked);
              - optionally ``"comp_normal_vis"``, ``"comp_normal"``,
                ``"comp_pred_normal"``, ``"comp_depth"``, ``"comp_alpha"`` and
                ``"comp_mask"``, each present only when ``forward`` returned the
                corresponding key (``pred_normal`` must also be non-None).

            NOTE: optional outputs are collected per view independently, so if
            ``forward`` returns an optional key for only some views, the stacked
            tensor has fewer rows than there are views.

        Raises:
            RuntimeError: from ``torch.stack`` when the batch has no views or
                the per-view tensors of one output have mismatched shapes.
        """
        # (key in the per-view render_pkg, key in the collated outputs,
        #  whether to permute the stacked tensor with (0, 2, 3, 1)).
        # Permuted entries are presumably channel-first per view; the others are
        # stacked as-is (author notes describe them as already channel-last).
        # Order here fixes the insertion order of the returned dict.
        optional_outputs = (
            ("normal", "comp_normal_vis", True),
            ("normal_raw", "comp_normal", False),
            ("pred_normal", "comp_pred_normal", True),
            ("depth", "comp_depth", False),
            ("alpha", "comp_alpha", False),
            ("mask", "comp_mask", True),
        )

        num_views = batch["c2w"].shape[0]
        renders = []
        viewspace_points = []
        visibility_filters = []
        radiis = []
        collected = {src_key: [] for src_key, _, _ in optional_outputs}

        for batch_idx in range(num_views):
            # forward() receives **batch, so publish the current view index there.
            batch["batch_idx"] = batch_idx
            fovy = batch["fovy"][batch_idx]
            # NOTE(review): the horizontal FoV is set equal to the vertical one,
            # which looks correct only when width == height -- confirm for
            # non-square renders.
            w2c, proj, cam_p = get_cam_info_gaussian(
                c2w=batch["c2w"][batch_idx], fovx=fovy, fovy=fovy, znear=0.1, zfar=100
            )

            viewpoint_cam = Camera(
                FoVx=fovy,
                FoVy=fovy,
                image_width=batch["width"],
                image_height=batch["height"],
                world_view_transform=w2c,
                full_proj_transform=proj,
                camera_center=cam_p,
            )

            # Call forward with CUDA autocast disabled, even if the caller is
            # running inside a mixed-precision region.
            with autocast(enabled=False):
                render_pkg = self.forward(
                    viewpoint_cam, self.background_tensor, **batch
                )

            renders.append(render_pkg["render"])
            viewspace_points.append(render_pkg["viewspace_points"])
            visibility_filters.append(render_pkg["visibility_filter"])
            radiis.append(render_pkg["radii"])

            for src_key, bucket in collected.items():
                if src_key not in render_pkg:
                    continue
                # pred_normal may be present but None; only collect real values.
                if src_key == "pred_normal" and render_pkg[src_key] is None:
                    continue
                bucket.append(render_pkg[src_key])

        outputs = {
            "comp_rgb": torch.stack(renders, dim=0).permute(0, 2, 3, 1),
            "viewspace_points": viewspace_points,
            "visibility_filter": visibility_filters,
            "radii": radiis,
        }
        for src_key, out_key, permute_to_last in optional_outputs:
            tensors = collected[src_key]
            if not tensors:
                continue
            stacked = torch.stack(tensors, dim=0)
            outputs[out_key] = stacked.permute(0, 2, 3, 1) if permute_to_last else stacked

        return outputs
