from typing import *
from torch import Tensor

import torch
from einops import rearrange, repeat

from src.models.gs_render.deferred_bp import deferred_bp
from src.models.gs_render.gs_util import GaussianModel, render
from src.options import Options
from src.utils import unproject_depth


class GaussianRenderer:
    """Turn predicted Gaussian attributes into rendered images and geometry maps.

    The renderer takes the raw network outputs (colour, scale, rotation, opacity
    and either a depth map or explicit positions), converts them to valid
    Gaussian parameters, and rasterizes them for every requested camera using
    either the per-view `render` routine or the patch-based `deferred_bp` path.
    """

    def __init__(self, opt: Options):
        # `opt` is read for: scale_min, scale_max, render_type,
        # deferred_bp_patch_size, znear, zfar, coord_weight, normal_weight,
        # vis_coords and vis_normals.
        self.opt = opt

    def scale_activation(self, x: Tensor) -> Tensor:
        """Linearly blend between the configured scale bounds.

        NOTE: the blend runs from `scale_max` at `x == 0` down to `scale_min`
        at `x == 1`, i.e. the range is [s_min, s_max] but the direction is
        reversed with respect to `x`. This is kept as-is on purpose: changing
        it would alter the meaning of already-predicted scale values.
        """
        return self.opt.scale_min * x + self.opt.scale_max * (1. - x)

    # Under autocast the inputs are cast to float32 before rasterization.
    # NOTE(review): newer PyTorch releases expose this decorator as
    # `torch.amp.custom_fwd(device_type="cuda", ...)` — confirm against the
    # PyTorch version this project pins before migrating.
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def render(self,
        model_outputs: Dict[str, Tensor],
        input_C2W: Tensor, input_fxfycxcy: Tensor,
        C2W: Tensor, fxfycxcy: Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        bg_color: Tuple[float, float, float] = (1., 1., 1.),
        scaling_modifier: float = 1.,
        opacity_threshold: float = 0.,
        input_normalized: bool = True,
        in_image_format: bool = True,
    ) -> Dict[str, Any]:
        """Rasterize the predicted Gaussians for every target camera.

        Args:
            model_outputs: Must contain "rgb", "scale", "rotation", "opacity"
                and exactly one of "depth" / "xyz"; may contain "offset",
                which is added to the positions. In image format every entry
                is laid out as `(b, v, c, h, w)`; otherwise as `(b, n, c)`.
            input_C2W: Camera-to-world matrices of the input views, indexed as
                `[:, :, :3, 3]` (presumably `(B, V_in, 4, 4)` — TODO confirm).
            input_fxfycxcy: Intrinsics of the input views, forwarded to
                `unproject_depth`.
            C2W: Camera-to-world matrices of the views to render; its first
                two dimensions define `(B, V)`.
            fxfycxcy: Intrinsics of the views to render, indexed as `[i, j]`.
            height: Output height; defaults to `rgb.shape[-2]`. Required when
                `in_image_format` is False.
            width: Output width; defaults to `rgb.shape[-1]`. Required when
                `in_image_format` is False.
            bg_color: Background colour, either a 3-tuple or a 3-element tensor.
            scaling_modifier: Forwarded unchanged to the rasterizer.
            opacity_threshold: Gaussians whose opacity is not strictly above
                this value get their opacity set to zero.
            input_normalized: If True, inputs are assumed to live in [-1, 1]
                and are mapped to valid ranges before rendering.
            in_image_format: Whether the attributes are per-pixel maps that
                must be flattened to point lists.

        Returns:
            A dict with "image", "alpha", "coord", "normal" (the last two are
            remapped by `* 0.5 + 0.5` and composited over the background),
            "raw_depth", "raw_normal" (as returned by the rasterizer) and
            "pc" (one `GaussianModel` per batch element).

        Raises:
            AssertionError: If the depth/xyz or height/width requirements
                described above are violated.
        """
        if not in_image_format:
            assert height is not None and width is not None
            assert "xyz" in model_outputs  # depth must be in image format

        rgb = model_outputs["rgb"]
        scale = model_outputs["scale"]
        rotation = model_outputs["rotation"]
        opacity = model_outputs["opacity"]
        depth = model_outputs.get("depth")
        xyz = model_outputs.get("xyz")
        # Exactly one of `depth` and `xyz` must be provided
        assert (depth is None) != (xyz is None)

        # Rendering resolution could be different from input resolution
        H = height if height is not None else rgb.shape[-2]
        W = width if width is not None else rgb.shape[-1]

        # Flatten per-pixel maps into one point list per batch element
        if in_image_format:
            flatten = "b v c h w -> b (v h w) c"
            rgb = rearrange(rgb, flatten)
            scale = rearrange(scale, flatten)
            rotation = rearrange(rotation, flatten)
            opacity = rearrange(opacity, flatten)

        # Prepare XYZ for rendering
        if xyz is None:
            if input_normalized:
                # Shift the normalized depth by each input camera's distance to
                # the origin: [-1, 1] -> image plane + [-1, 1]
                camera_distance = torch.norm(input_C2W[:, :, :3, 3], p=2, dim=2, keepdim=True)
                depth = depth + camera_distance[..., None, None]
            xyz = unproject_depth(depth.squeeze(2), input_C2W, input_fxfycxcy)  # [-1, 1]
        # Apply the optional offset without materializing a zero tensor when absent
        offset = model_outputs.get("offset")
        if offset is not None:
            xyz = xyz + offset
        if in_image_format:
            xyz = rearrange(xyz, "b v c h w -> b (v h w) c")

        # From [-1, 1] to valid values; rotation is left untouched because it
        # is expected to be L2-normalized already
        if input_normalized:
            rgb = rgb * 0.5 + 0.5  # [-1, 1] -> [0, 1]
            scale = self.scale_activation(scale * 0.5 + 0.5)  # [-1, 1] -> [0, 1] -> scale range
            opacity = opacity * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Filter by opacity: zero out everything not strictly above the threshold
        opacity = (opacity > opacity_threshold) * opacity

        # Output views come from the target cameras, so `H`/`W` might differ
        # from the resolution of the input maps
        (B, V), device = C2W.shape[:2], C2W.device

        # Relies on `set_data` returning the populated model
        pcs = [
            GaussianModel().set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
            for i in range(B)
        ]

        # Whether to render depth & normal
        render_depth_normal = (
            self.opt.coord_weight > 0. or self.opt.normal_weight > 0.
            or self.opt.vis_coords or self.opt.vis_normals
        )

        if self.opt.render_type == "defered":
            images, alphas, depths, normals = deferred_bp(
                xyz, rgb, scale, rotation, opacity,
                H, W, C2W, fxfycxcy,
                self.opt.deferred_bp_patch_size, GaussianModel(),
                self.opt.znear, self.opt.zfar,
                bg_color,
                scaling_modifier,
                render_depth_normal,
            )
        else:  # default
            # Output buffers are only needed on this path; they are always
            # float32 regardless of the dtype of `C2W`
            dtype_ = torch.float32
            images = torch.zeros(B, V, 3, H, W, dtype=dtype_, device=device)
            alphas = torch.zeros(B, V, 1, H, W, dtype=dtype_, device=device)
            depths = torch.zeros(B, V, 1, H, W, dtype=dtype_, device=device)
            normals = torch.zeros(B, V, 3, H, W, dtype=dtype_, device=device)

            for i, pc in enumerate(pcs):
                for j in range(V):
                    render_results = render(
                        pc, H, W, C2W[i, j], fxfycxcy[i, j],
                        self.opt.znear, self.opt.zfar,
                        bg_color,
                        scaling_modifier,
                        render_depth_normal,
                    )
                    images[i, j] = render_results["image"]
                    alphas[i, j] = render_results["alpha"]
                    depths[i, j] = render_results["depth"]
                    normals[i, j] = render_results["normal"]

        # Broadcast the background colour to a full image for compositing
        if not isinstance(bg_color, Tensor):
            bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
        bg_color = repeat(bg_color, "c -> b v c h w", b=B, v=V, h=H, w=W)
        background = (1. - alphas) * bg_color

        # Coordinate map: unproject the rendered depth, remap to [0, 1] and
        # composite over the background
        coords = (unproject_depth(depths.squeeze(2), C2W, fxfycxcy)
            * 0.5 + 0.5) * alphas + background

        # Normal map: rotate the rendered normals with the 3x3 rotation part of
        # `C2W`, then remap and composite the same way
        C2W = C2W.to(torch.float32)
        normals_ = (torch.einsum("bvrc,bvchw->bvrhw", C2W[:, :, :3, :3], normals)
            * 0.5 + 0.5) * alphas + background

        return {
            "image": images,
            "alpha": alphas,
            "coord": coords,
            "normal": normals_,
            "raw_depth": depths,
            "raw_normal": normals,
            "pc": pcs,
        }
