import math
import torch
import numpy as np
import torch.nn.functional as F
from arguments import OptimizationParams
from scene.cameras import Camera
from scene.gaussian_model import GaussianModel
from scene.derect_light_sh import DirectLightEnv
from utils.sh_utils import eval_sh
from utils.sh_utils_3_cube import ProjectFunction,fibonacci_sphere_sampling,eval_sh_coef,ToVector
# from utils.graphics_utils import fibonacci_sphere_sampling
from .r3dg_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from utils.rotation_utils import *
from tqdm import tqdm
from bvh import RayTracer
import cv2
import time
def read_light_sh_from_file(light_shs_file):
    """Load lighting SH coefficients stored as a Python literal in a text file.

    The file is expected to contain a (nested) list of numbers, e.g.
    ``[[r0, r1, ...], [g0, g1, ...], [b0, b1, ...]]``.

    Args:
        light_shs_file: Path of the text file to read.

    Returns:
        A ``torch.Tensor`` built from the parsed literal (default dtype/device
        chosen by ``torch.tensor``).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file content is not a plain Python literal.
    """
    # Local import so this function stays self-contained.
    import ast

    with open(light_shs_file, 'r') as f:
        content = f.read()

    # SECURITY: the previous implementation used eval(), which executes any
    # code found in the file. literal_eval only accepts literals (numbers,
    # lists, tuples, ...), which is all a coefficient table needs.
    try:
        coefficients = ast.literal_eval(content.strip())
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(
            f"{light_shs_file!r} does not contain a valid literal list of SH coefficients"
        ) from err

    return torch.tensor(coefficients)
def trace_visi(gaussians,sample_num,ray_d):
    """Ray-trace visibility for every Gaussian and cache it on the model.

    The Gaussians are processed in chunks along dim 0 to bound memory use.
    NaN visibility values returned by the tracer are replaced with 0. The
    concatenated result is stored in ``gaussians._visibility_tracing``;
    nothing is returned.

    Args:
        gaussians: Model exposing ``get_xyz``, ``get_scaling``,
            ``get_rotation``, ``get_inverse_covariance()``, ``get_opacity``
            and ``get_normal``.
        sample_num: Number of directions per Gaussian; only used to pick the
            chunk count (one chunk per 24 samples).
        ray_d: Ray directions, indexed per Gaussian along dim 0
            (presumably ``[N, sample_num, 3]`` -- confirm against caller).

    Raises:
        ValueError: If the model contains no Gaussians.
    """
    raytracer = RayTracer(gaussians.get_xyz, gaussians.get_scaling,
                          gaussians.get_rotation)
    gaussians_xyz = gaussians.get_xyz
    gaussians_inverse_covariance = gaussians.get_inverse_covariance()
    gaussians_opacity = gaussians.get_opacity[:, 0]
    gaussians_normal = gaussians.get_normal

    num_gaussians = gaussians_xyz.shape[0]
    if num_gaussians == 0:
        raise ValueError("cannot trace visibility for an empty Gaussian model")

    num_chunks = (sample_num - 1) // 24 + 1
    # Clamp to 1: with fewer Gaussians than chunks the integer division
    # yields 0, which would make range() fail on a zero step.
    chunk_size = max(1, num_gaussians // num_chunks)

    incident_visibility_results = []
    for offset in tqdm(range(0, num_gaussians, chunk_size),
                       "Precompute raytracing visibility"):
        incident_dirs = ray_d[offset:offset + chunk_size]
        trace_results = raytracer.trace_visibility(
            gaussians_xyz[offset:offset + chunk_size, None].expand_as(incident_dirs),
            incident_dirs,
            gaussians_xyz,
            gaussians_inverse_covariance,
            gaussians_opacity,
            gaussians_normal)
        incident_visibility = trace_results["visibility"]
        # Treat NaNs produced by the tracer as fully occluded.
        incident_visibility[torch.isnan(incident_visibility)] = 0.0
        incident_visibility_results.append(incident_visibility)

    gaussians._visibility_tracing = torch.cat(incident_visibility_results, dim=0)
def trace_visi_cube(gaussians,sample_num,ray_d,offests,numm):
    """Ray-trace visibility for one window of Gaussians and cache it.

    Traces the Gaussians ``[offests, offests + numm)`` along the first
    ``numm`` rows of ``ray_d`` and stores the result (NaNs replaced by 0) in
    ``gaussians._visibility_tracing``. Nothing is returned.

    The previous implementation wrapped this in a chunk loop that always
    stopped after its first iteration; the single trace below is equivalent.

    Args:
        gaussians: Model exposing ``get_xyz``, ``get_scaling``,
            ``get_rotation``, ``get_inverse_covariance()``, ``get_opacity``
            and ``get_normal``.
        sample_num: Unused; kept for interface compatibility.
        ray_d: Ray directions; only ``ray_d[0:numm]`` is used
            (presumably ``[numm, S, 3]`` -- confirm against caller).
        offests: Index of the first Gaussian to trace.
        numm: Number of Gaussians (and rows of ``ray_d``) to trace.

    Raises:
        ValueError: If ``numm`` is not positive or ``offests`` is past the
            last Gaussian.
    """
    gaussians_xyz = gaussians.get_xyz
    num_gaussians = gaussians_xyz.shape[0]
    if numm <= 0:
        raise ValueError(f"numm must be positive, got {numm}")
    if offests >= num_gaussians:
        raise ValueError(
            f"offests={offests} is out of range for {num_gaussians} Gaussians")

    raytracer = RayTracer(gaussians_xyz, gaussians.get_scaling,
                          gaussians.get_rotation)

    incident_dirs = ray_d[0:numm]
    trace_results = raytracer.trace_visibility(
        gaussians_xyz[offests:offests + numm, None].expand_as(incident_dirs),
        incident_dirs,
        gaussians_xyz,
        gaussians.get_inverse_covariance(),
        gaussians.get_opacity[:, 0],
        gaussians.get_normal)

    incident_visibility = trace_results["visibility"]
    # Treat NaNs produced by the tracer as fully occluded.
    incident_visibility[torch.isnan(incident_visibility)] = 0.0
    gaussians._visibility_tracing = incident_visibility
def render_view(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                scaling_modifier=1.0, override_color=None, is_training=False, 
                dict_params=None, bake=False,light_transport=None,precompute=False,idx=0,light_shs_file='',light_deg=3,occ=None,indir=None):
    """Rasterize one view, shading each Gaussian with precomputed radiance transfer.

    Per-Gaussian colour is ``sum(T_v * L)`` where ``L`` are the lighting SH
    coefficients and ``T_v`` is the light-transport tensor contracted with the
    SH basis evaluated at the view direction. The shaded colour and material
    attributes are rasterized as extra feature channels.

    Args:
        viewpoint_camera: Camera providing FoV, intrinsics, transforms,
            ``camera_center``, ``hdr`` and ``get_world_directions()``.
        pc: Gaussian model to render.
        pipe: Pipeline flags (``compute_cov3D_python``, ``compute_SHs_python``,
            ``debug``).
        bg_color: Background colour tensor (indexed as ``bg_color[:, None, None]``).
        scaling_modifier: Scale multiplier forwarded to the rasterizer.
        override_color: Optional precomputed per-Gaussian colours.
        is_training: When False, the environment light is composited behind
            the render; ``dict_params["env_light"]`` is required in that case.
        dict_params: Optional dict; keys read are ``"env_light"`` and ``"gamma"``.
        bake: Unused; kept for interface compatibility.
        light_transport: Per-Gaussian transport tensor. Required unless
            ``precompute`` is True.
        precompute: When True, (re)compute ``light_transport`` and ``occ``
            with ``ProjectFunction`` before shading.
        idx: Frame counter; returned incremented by one under ``"idx"``.
        light_shs_file: Currently unused, see the NOTE(review) in the body.
        light_deg: SH degree used for the view-direction basis.
        occ: Per-Gaussian occlusion values. Required unless ``precompute``
            is True.
        indir: Passed through unchanged to the result dict.

    Returns:
        Dict of rendered buffers and bookkeeping values (``"render"``,
        ``"pbr"``, ``"normal"``, ..., ``"precompute_trans"``, ``"idx"``,
        ``"occ"``, ``"indir"``, plus ``"pbr_env"`` when not training).

    Raises:
        ValueError: If ``light_transport`` or ``occ`` is missing while
            ``precompute`` is False.
    """
    if dict_params is None:
        dict_params = {}
    direct_light_env_light = dict_params.get("env_light")
    gamma_transform = dict_params.get("gamma")

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort: retain_grad() fails when autograd is disabled
        # (e.g. under torch.no_grad()), where the gradient is not needed.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
    intrinsic = viewpoint_camera.intrinsics

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        cx=float(intrinsic[0, 2]),
        cy=float(intrinsic[1, 2]),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        backward_geometry=True,
        computer_pseudo_normal=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.compute_SHs_python:
            dir_pp_normalized = F.normalize(viewpoint_camera.camera_center.repeat(means3D.shape[0], 1) - means3D,
                                            dim=-1)
            shs_view = pc.get_shs.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_shs
    else:
        colors_precomp = override_color

    base_color = pc.get_base_color
    roughness = pc.get_roughness
    metallic = pc.get_metallic
    normal = pc.get_normal
    viewdirs = F.normalize(viewpoint_camera.camera_center - means3D, dim=-1)
    num_gaussians = base_color.shape[0]

    idx += 1

    # NOTE(review): the lighting is hard-coded here (3 colour channels x 9 SH
    # coefficients). The previous code also loaded `light_shs_file` via
    # read_light_sh_from_file() and then discarded the result by overwriting
    # it with this tensor; that dead read has been removed so a missing or
    # empty path no longer crashes the renderer. If the file is meant to drive
    # the lighting, replace this literal with the loaded tensor.
    light_shs = torch.tensor(
        [[1.64739, -0.2088, -0.0567535, 0.20922, -0.0111866, 0.174331, 0.0229471, -0.0535251, -0.0440266],
         [1.25846, -0.24987, -0.0364914, 0.214639, -0.0305843, 0.169005, 0.0226703, -0.0697917, 0.00861447],
         [0.774436, -0.227027, -0.0158676, 0.208651, -0.0436586, 0.126428, -0.0234041, -0.0766489, 0.0326816]],
        device="cuda")
    # NOTE(review): 0.28 presumably approximates the band-0 SH constant
    # (~0.282) -- confirm the intended normalization.
    light_shs = light_shs / 0.28
    light = light_shs.unsqueeze(0).repeat(num_gaussians, 1, 1)

    if precompute:
        visibility = pc.get_visibility
        sample_dirs, _ = fibonacci_sphere_sampling(normal, 15**2, random_rotate=True)
        theta = torch.acos(sample_dirs[..., 2])
        phi = torch.atan2(sample_dirs[..., 1], sample_dirs[..., 0])

        # One Gaussian per call keeps peak GPU memory low.
        precompute_chunk = 1
        transport_chunks = []
        occ_chunks = []
        for offset in tqdm(range(0, num_gaussians, precompute_chunk),
                           "Precompute light transport"):
            window = slice(offset, offset + precompute_chunk)
            transport, _, out_occ = ProjectFunction(
                pc.max_sh_degree + 1, 1, 15, normal[window], visibility[window],
                viewdirs[window], base_color[window],
                metallic[window], roughness[window],
                light_deg, sample_dirs[window], theta[window], phi[window])
            transport_chunks.append(transport)
            occ_chunks.append(out_occ)
            torch.cuda.empty_cache()
        light_transport = torch.cat(transport_chunks, dim=0)
        occ = torch.cat(occ_chunks, dim=0)

    if light_transport is None or occ is None:
        raise ValueError(
            "light_transport and occ must be provided when precompute is False")

    # Contract the transport tensor with the SH basis at the view direction.
    # expand() yields the same shape as the former repeat() without copying.
    sh_basis = eval_sh_coef(light_deg, viewdirs)
    sh_basis = sh_basis.unsqueeze(1).unsqueeze(1).expand(-1, 3, (light_deg + 1) ** 2, -1)
    light_transport_v = (sh_basis * light_transport).sum(-1)
    brdf_color = (light_transport_v * light).sum(-1)

    # The three incident-light feature slots carry the shaded colour; no
    # separate local/global light estimates are produced in this path.
    features = torch.cat([brdf_color, normal, base_color, roughness, metallic,
                          brdf_color,
                          brdf_color,
                          brdf_color,
                          occ.unsqueeze(1)], dim=-1)

    (num_rendered, num_contrib, rendered_image, rendered_opacity, rendered_depth,
     rendered_feature, rendered_pseudo_normal, rendered_surface_xyz, radii) = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
        features=features,
    )

    # Channel layout must mirror the order of `features` above.
    rendered_pbr, rendered_normal, rendered_base_color, rendered_roughness, rendered_metallic, \
        rendered_light, rendered_local_light, rendered_global_light, rendered_visibility \
        = rendered_feature.split([3, 3, 3, 1, 1, 3, 3, 3, 1], dim=0)

    feature_dict = {"base_color": rendered_base_color,
                    "roughness": rendered_roughness,
                    "metallic": rendered_metallic,
                    "lights": rendered_light,
                    "local lights": rendered_local_light,
                    "global lights": rendered_global_light,
                    "visibility": rendered_visibility,
                    }

    pbr = rendered_pbr
    rendered_pbr = pbr + (1 - rendered_opacity) * bg_color[:, None, None]

    # HDR out radiance to LDR
    val_gamma = 0
    if gamma_transform is not None:
        rendered_pbr = gamma_transform.hdr2ldr(rendered_pbr)
        val_gamma = gamma_transform.gamma.item()

    results = {"render": rendered_image,
               "pbr": rendered_pbr,
               "normal": rendered_normal,
               "pseudo_normal": rendered_pseudo_normal,
               "surface_xyz": rendered_surface_xyz,
               "opacity": rendered_opacity,
               "depth": rendered_depth,
               "viewspace_points": screenspace_points,
               "visibility_filter": radii > 0,
               "radii": radii,
               "num_rendered": num_rendered,
               "num_contrib": num_contrib,
               "precompute_trans": light_transport,
               "idx": idx,
               "occ": occ,
               "indir": indir,}
    results.update(feature_dict)
    results["hdr"] = viewpoint_camera.hdr
    results["val_gamma"] = val_gamma

    if not is_training:
        # Composite the environment light behind the partially transparent render.
        directions = viewpoint_camera.get_world_directions()
        if isinstance(direct_light_env_light, DirectLightEnv):
            shs_view_direct = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
            env = torch.clamp_min(eval_sh(direct_light_env_light.sh_degree, shs_view_direct, directions.permute(1, 2, 0)) + 0.5, 0).permute(2, 0, 1)
        else:
            env = direct_light_env_light.direct_light(directions.permute(1, 2, 0)).permute(2, 0, 1)
        results["render"] = rendered_image + (1 - rendered_opacity) * env
        results["pbr_env"] = pbr + (1 - rendered_opacity) * env

    return results


def render_neilf_composite_prt(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                 scaling_modifier=1.0, override_color=None, opt: OptimizationParams = False,
                 is_training=False, dict_params=None, bake=False,light_transport=None,precompute=False,
                 idx=0,light_shs_file='',light_deg=3,occ=None,indir=None):
    """
    Render the scene by delegating to ``render_view``.

    Every argument except ``opt`` (which is accepted but not used) is
    forwarded unchanged; the result dict of ``render_view`` is returned as-is.

    Background tensor (bg_color) must be on GPU!
    """
    return render_view(
        viewpoint_camera=viewpoint_camera,
        pc=pc,
        pipe=pipe,
        bg_color=bg_color,
        scaling_modifier=scaling_modifier,
        override_color=override_color,
        is_training=is_training,
        dict_params=dict_params,
        bake=bake,
        light_transport=light_transport,
        precompute=precompute,
        idx=idx,
        light_shs_file=light_shs_file,
        light_deg=light_deg,
        occ=occ,
        indir=indir,
    )


def rendering_equation_python(base_color, roughness, metallic, normals, viewdirs, incidents, 
                              is_training=False, direct_light_env_light=None, visibility=None, 
                              sample_num=24, bake=False, visibility_precompute=None):
    """Evaluate the rendering equation per point by sampling incident rays.

    Incident light is the sum of a local term (from the ``incidents`` SH
    coefficients) and a global term (from ``direct_light_env_light``, zero if
    it is None) attenuated by visibility. Visibility is evaluated from the
    ``visibility`` SH coefficients when ``bake`` is True, otherwise taken from
    ``visibility_precompute``.

    Returns:
        ``(rgb, extra_results)`` where ``rgb`` is diffuse + specular radiance
        averaged over the sample axis, and ``extra_results`` holds the
        sample-averaged ``incident_lights``, ``local_incident_lights``,
        ``global_incident_lights`` and ``incident_visibility``.

    Raises:
        ValueError: If ``bake`` is False and ``visibility_precompute`` is None.
    """
    incident_dirs, incident_areas = sample_incident_rays(normals, is_training, sample_num)

    # Insert a singleton sample axis so per-point attributes broadcast
    # against the per-sample tensors.
    base_color, roughness, metallic, normals, viewdirs = [
        attr.unsqueeze(-2).contiguous()
        for attr in (base_color, roughness, metallic, normals, viewdirs)
    ]
    num_points = base_color.shape[0]

    # The SH degree is derived from the number of visibility coefficients.
    sh_degree = int(np.sqrt(visibility.shape[1]) - 1)
    sh_basis = eval_sh_coef(sh_degree, incident_dirs).unsqueeze(2)
    incident_shs = incidents.transpose(1, 2).view(num_points, 1, 3, -1)
    visibility_shs = visibility.transpose(1, 2).view(num_points, 1, 1, -1)

    local_incident_lights = torch.clamp_min(
        (sh_basis[..., :incident_shs.shape[-1]] * incident_shs).sum(-1), 0)

    if direct_light_env_light is None:
        global_incident_lights = torch.zeros_like(local_incident_lights, requires_grad=False)
    elif isinstance(direct_light_env_light, DirectLightEnv):
        env_shs = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
        global_incident_lights = torch.clamp_min(
            (sh_basis[..., :env_shs.shape[-1]] * env_shs).sum(-1) + 0.5, 0)
    else:
        global_incident_lights = direct_light_env_light.direct_light(incident_dirs)

    if bake:
        incident_visibility = torch.clamp(
            (sh_basis[..., :visibility_shs.shape[-1]] * visibility_shs).sum(-1) + 0.5, 0, 1)
    elif visibility_precompute is None:
        raise ValueError("visibility should be pre-computed.")
    else:
        incident_visibility = visibility_precompute

    global_incident_lights = global_incident_lights * incident_visibility
    incident_lights = local_incident_lights + global_incident_lights

    # Half vector and clamped cosines between the shading vectors.
    half_dirs = F.normalize(incident_dirs + viewdirs, dim=-1)
    h_d_n = (half_dirs * normals).sum(dim=-1, keepdim=True).clamp(min=0)
    h_d_o = (half_dirs * viewdirs).sum(dim=-1, keepdim=True).clamp(min=0)
    n_d_i = (normals * incident_dirs).sum(dim=-1, keepdim=True).clamp(min=0)
    n_d_o = (normals * viewdirs).sum(dim=-1, keepdim=True).clamp(min=0)

    # Diffuse lobe.
    f_d = (1 - metallic) * base_color / np.pi

    # Specular lobe: distribution term D (spherical-Gaussian form; the
    # original author notes it is "wrongly normalized").
    rough_sq = (roughness * roughness).clamp(min=1e-7)
    sg_amplitude = 1 / (rough_sq * np.pi)
    sg_sharpness = 2 / rough_sq
    distribution = sg_amplitude * torch.exp(sg_sharpness * (h_d_n - 1))

    # Fresnel term (Schlick approximation).
    f_0 = 0.04 * (1 - metallic) + base_color * metallic
    fresnel = f_0 + (1.0 - f_0) * ((1.0 - h_d_o) ** 5)

    # Geometry term, using V = G / (4 * cos * cos).
    k = ((1 + roughness) ** 2) / 8
    v_incident = 0.5 / (n_d_i * (1 - k) + k).clamp(min=1e-7)
    v_outgoing = 0.5 / (n_d_o * (1 - k) + k).clamp(min=1e-7)
    geometry = v_incident * v_outgoing

    f_s = distribution * fresnel * geometry

    transport = incident_lights * incident_areas * n_d_i
    rgb = (f_d * transport).mean(dim=-2) + (f_s * transport).mean(dim=-2)

    extra_results = {
        "incident_lights": incident_lights.mean(dim=-2),
        "local_incident_lights": local_incident_lights.mean(dim=-2),
        "global_incident_lights": global_incident_lights.mean(dim=-2),
        "incident_visibility": incident_visibility.mean(dim=-2),
    }

    return rgb, extra_results


def sample_incident_rays(normals, is_training=False, sample_num=24):
    """Sample incident directions around ``normals`` with Fibonacci sampling.

    The sample pattern is randomly rotated only when ``is_training`` is
    truthy.

    Returns:
        ``(incident_dirs, incident_areas)`` as produced by
        ``fibonacci_sphere_sampling``.
    """
    random_rotate = True if is_training else False
    incident_dirs, incident_areas = fibonacci_sphere_sampling(
        normals, sample_num, random_rotate=random_rotate)

    return incident_dirs, incident_areas  # [N, S, 3], [N, S, 1]