import math
import torch
import numpy as np
import torch.nn.functional as F
from arguments import OptimizationParams
from scene.cameras import Camera
from scene.gaussian_model import GaussianModel
from scene.derect_light_sh import DirectLightEnv
from utils.sh_utils import eval_sh
from utils.sh_utils_3_new import ProjectFunction,fibonacci_sphere_sampling,eval_sh_coef,ProjectFunction_diffuse
# from utils.graphics_utils import fibonacci_sphere_sampling
from .r3dg_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from utils.rotation_utils import *
from tqdm import tqdm
from bvh import RayTracer
import time
def read_light_sh_from_file(light_shs_file):
    """Load light SH coefficients from a text file and return them as a tensor.

    The file content is evaluated as a Python expression (e.g. a nested list
    literal) and the result is passed straight to ``torch.tensor``.

    Args:
        light_shs_file: path of the text file to read.

    Returns:
        ``torch.Tensor`` built from the evaluated file content.
    """
    with open(light_shs_file, 'r') as handle:
        source_text = handle.read()
    # NOTE(review): eval() executes arbitrary code contained in the file. Only
    # use this with trusted files; if the files are always plain list literals,
    # ast.literal_eval would be a safe drop-in replacement.
    coefficients = eval(source_text)
    return torch.tensor(coefficients)
def trace_visi(gaussians, sample_num, ray_d):
    """Ray-trace incident visibility for every Gaussian and cache it on the model.

    A ``RayTracer`` is built from the Gaussians' centres, scales and rotations.
    Rays start at each Gaussian centre and follow that Gaussian's directions in
    ``ray_d``. The Gaussians are processed in chunks (see below) and the
    per-chunk ``"visibility"`` results are concatenated along dim 0.

    Args:
        gaussians: Gaussian model exposing ``get_xyz``, ``get_scaling``,
            ``get_rotation``, ``get_inverse_covariance()``, ``get_opacity`` and
            ``get_normal``.
        sample_num: number of directions per Gaussian. It is only used to pick
            how many chunks to split the Gaussians into: about
            ``ceil(sample_num / 24)`` chunks.
        ray_d: per-Gaussian ray directions, indexed by Gaussian along dim 0 and
            broadcast against ``xyz[:, None]`` (presumably ``[N, S, 3]`` -- confirm).

    Side effects:
        Sets ``gaussians._visibility_tracing`` to the concatenated visibility,
        with NaN entries replaced by 0.
    """
    raytracer = RayTracer(gaussians.get_xyz, gaussians.get_scaling,
                          gaussians.get_rotation)
    gaussians_xyz = gaussians.get_xyz
    gaussians_inverse_covariance = gaussians.get_inverse_covariance()
    gaussians_opacity = gaussians.get_opacity[:, 0]
    gaussians_normal = gaussians.get_normal
    num_points = gaussians_xyz.shape[0]

    # Split the work into ~ceil(sample_num / 24) chunks, presumably to bound
    # tracer memory. Both counts are clamped to >= 1: otherwise a tiny point
    # cloud yields chunk_size == 0 (range() rejects a zero step) and
    # sample_num <= 0 divides by zero.
    num_chunks = max(1, (sample_num - 1) // 24 + 1)
    chunk_size = max(1, num_points // num_chunks)

    incident_visibility_results = []
    for offset in tqdm(range(0, num_points, chunk_size),
                       "Precompute raytracing visibility"):
        incident_dirs = ray_d[offset:offset + chunk_size]
        trace_results = raytracer.trace_visibility(
            gaussians_xyz[offset:offset + chunk_size, None].expand_as(incident_dirs),
            incident_dirs,
            gaussians_xyz,
            gaussians_inverse_covariance,
            gaussians_opacity,
            gaussians_normal)
        incident_visibility = trace_results["visibility"]
        # The tracer can emit NaNs; treat those rays as fully occluded (0).
        incident_visibility[torch.isnan(incident_visibility)] = 0.0
        incident_visibility_results.append(incident_visibility)

    gaussians._visibility_tracing = torch.cat(incident_visibility_results, dim=0)
def render_view(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                scaling_modifier=1.0, override_color=None, is_training=False,
                dict_params=None, bake=False, light_transport=None, precompute=False, idx=0,
                light_shs_file='', light_deg=3, occ=None, indir=None):
    """Render one view of ``pc`` with precomputed-radiance-transfer (PRT) shading.

    Light SH coefficients are read from ``light_shs_file`` and rotated by
    ``getRotationPrecomputeL``. The rotation is built from the angle
    ``idx * 0.003 * 2 * pi`` about axis ``(0, 1, 0)``. The result is combined
    with per-Gaussian light-transport coefficients to give a per-Gaussian
    colour, which is rasterized together with the material attributes as
    extra feature channels.

    Args:
        viewpoint_camera: camera to render from.
        pc: Gaussian model supplying geometry and material attributes.
        pipe: pipeline options; ``debug``, ``compute_cov3D_python`` and
            ``compute_SHs_python`` are read.
        bg_color: background colour tensor.
        scaling_modifier: scale factor forwarded to the rasterizer.
        override_color: optional per-Gaussian colours that bypass SH evaluation.
        is_training: when False, an environment background is also composited
            into ``results["render"]`` and ``results["pbr_env"]`` is produced.
        dict_params: optional dict; the keys ``"env_light"`` and ``"gamma"``
            are read (missing keys / ``None`` are treated as absent).
        bake: not used by this function; kept for interface compatibility.
        light_transport: cached transport coefficients, as returned in
            ``results["precompute_trans"]`` by an earlier ``precompute=True`` call.
        precompute: when True, trace visibility and recompute
            ``light_transport``, ``occ`` and ``indir`` for all Gaussians.
        idx: frame counter driving the light rotation. The returned
            ``results["idx"]`` is ``idx + 1``.
        light_shs_file: path of the text file holding the light SH coefficients.
        light_deg: SH degree of the light / transport basis.
        occ: cached occlusion values (``results["occ"]``).
        indir: cached indirect transport (``results["indir"]``).

    Returns:
        dict with the rasterized maps (``render``, ``pbr``, ``normal``, ...),
        the material feature maps, and the cached ``precompute_trans``,
        ``occ``, ``indir`` and incremented ``idx`` so the caller can feed
        them back in on the next frame.

    Raises:
        ValueError: if the cached data needed for shading is missing while
            ``precompute`` is False.
    """
    if dict_params is None:
        dict_params = {}
    direct_light_env_light = dict_params.get("env_light")
    gamma_transform = dict_params.get("gamma")

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
    intrinsic = viewpoint_camera.intrinsics

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        cx=float(intrinsic[0, 2]),
        cy=float(intrinsic[1, 2]),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        backward_geometry=True,
        computer_pseudo_normal=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.compute_SHs_python:
            dir_pp_normalized = F.normalize(viewpoint_camera.camera_center.repeat(means3D.shape[0], 1) - means3D,
                                            dim=-1)
            shs_view = pc.get_shs.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_shs
    else:
        colors_precomp = override_color

    base_color = pc.get_base_color
    roughness = pc.get_roughness
    metallic = pc.get_metallic
    normal = pc.get_normal
    visibility = pc.get_visibility
    viewdirs = F.normalize(viewpoint_camera.camera_center - means3D, dim=-1)

    # Build a rotation from the angle idx * 0.006 * pi about (0, 1, 0) and apply it
    # to the light SH, so successive frames see a rotating environment light.
    light_rot = fromRotation(torch.tensor(idx * 0.003 * 2 * torch.pi),
                             torch.tensor([0, 1, 0], dtype=torch.float32))
    idx += 1
    print(idx)
    light_rot = mat4Matrix2mathMatrix(light_rot)
    light_shs = read_light_sh_from_file(light_shs_file)
    light_shs = getRotationPrecomputeL(light_shs.transpose(0, 1), light_rot).transpose(0, 1)
    light_shs = light_shs.cuda()

    # The same (rotated) light coefficients, replicated for every Gaussian.
    shs_view_direct = light_shs.unsqueeze(0).repeat(base_color.shape[0], 1, 1)
    light = shs_view_direct
    chunk_size2 = 3000  # Gaussians per ProjectFunction call during precompute
    diffuse = False  # debug switch: shade with the diffuse-only projection instead
    print("idx", idx)

    if precompute:
        n_dirs = 15 ** 2
        transport_chunks = []
        occ_chunks = []
        d, _ = fibonacci_sphere_sampling(normal, n_dirs, random_rotate=True)
        theta = torch.acos(d[..., 2])
        phi = torch.atan2(d[..., 1], d[..., 0])
        trace_visi(pc, n_dirs, d)
        print(pc._visibility_tracing.shape)
        visibility = pc._visibility_tracing
        num_points = base_color.shape[0]
        for offset in range(0, num_points, chunk_size2):
            print("persent", offset / num_points * 100)
            sl = slice(offset, offset + chunk_size2)
            H, _, out_occ = ProjectFunction(pc.max_sh_degree + 1, 1, 15, normal[sl], visibility[sl],
                                            viewdirs[sl], base_color[sl],
                                            metallic[sl], roughness[sl],
                                            light_deg, d[sl], theta[sl], phi[sl])
            transport_chunks.append(H)
            occ_chunks.append(out_occ)
            torch.cuda.empty_cache()
        light_transport = torch.cat(transport_chunks, dim=0)
        occ = torch.cat(occ_chunks, dim=0)
        indir, _ = ProjectFunction_diffuse(pc.max_sh_degree + 1, 1, 10, normal, visibility,
                                           viewdirs, base_color, metallic, roughness,
                                           light_deg, pc)

    # Fail with a clear message instead of an opaque TypeError/AttributeError
    # when the cached PRT data was not passed back in.
    if occ is None:
        raise ValueError("render_view: `occ` is required when precompute=False; "
                         "pass results['occ'] from a precompute=True call.")
    if not diffuse and (light_transport is None or indir is None):
        raise ValueError("render_view: `light_transport` and `indir` are required when "
                         "precompute=False; pass results['precompute_trans'] and results['indir'].")

    # Start the timer before branching so `start` is bound on every path.
    torch.cuda.synchronize()
    start = time.time()
    if diffuse:
        memory_reserved = torch.cuda.memory_reserved()
        print("Memory cached on GPU1:", memory_reserved / 1024 / 1024, "mb")
        H, _ = ProjectFunction_diffuse(pc.max_sh_degree + 1, 1, 10, normal, visibility,
                                       viewdirs, base_color, metallic, roughness,
                                       light_deg, pc)
        brdf_color = (H * shs_view_direct).sum(-1)
        print(brdf_color.mean())
    else:
        # Project the view direction onto the SH basis and contract it with
        # the transport coefficients to get view-dependent transport.
        view_coef = eval_sh_coef(light_deg, viewdirs)
        view_coef = view_coef.unsqueeze(1).unsqueeze(1).repeat(1, 3, (light_deg + 1) ** 2, 1)
        print(view_coef.shape)
        torch.cuda.synchronize()
        print("dir time:", time.time() - start)
        light_transport_v = (view_coef * light_transport).sum(-1)
        torch.cuda.synchronize()
        print("v time:", time.time() - start)
        print(light_transport_v.shape)
        print("occ", occ.shape)
        print(light_deg)
        brdf_color = (light_transport_v * light).sum(-1) + (indir * light).sum(-1)
        brdf_color = brdf_color * 2.0
    torch.cuda.synchronize()
    print("render time:", time.time() - start)
    extra_results = {
        "incident_lights": brdf_color,
        "local_incident_lights": brdf_color,
        "global_incident_lights": brdf_color,
        "incident_visibility": occ.unsqueeze(1),
    }

    # Channel layout must match the split sizes below: 3,3,3,1,1,3,3,3,1.
    features = torch.cat([brdf_color, normal, base_color, roughness, metallic,
                          extra_results["incident_lights"],
                          extra_results["local_incident_lights"],
                          extra_results["global_incident_lights"],
                          extra_results["incident_visibility"]], dim=-1)

    (num_rendered, num_contrib, rendered_image, rendered_opacity, rendered_depth,
     rendered_feature, rendered_pseudo_normal, rendered_surface_xyz, radii) = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
        features=features,
    )
    feature_dict = {}
    rendered_pbr, rendered_normal, rendered_base_color, rendered_roughness, rendered_metallic, \
        rendered_light, rendered_local_light, rendered_global_light, rendered_visibility \
        = rendered_feature.split([3, 3, 3, 1, 1, 3, 3, 3, 1], dim=0)

    feature_dict.update({"base_color": rendered_base_color,
                         "roughness": rendered_roughness,
                         "metallic": rendered_metallic,
                         "lights": rendered_light,
                         "local lights": rendered_local_light,
                         "global lights": rendered_global_light,
                         "visibility": rendered_visibility,
                         })

    pbr = rendered_pbr
    rendered_pbr = pbr + (1 - rendered_opacity) * bg_color[:, None, None]

    # HDR out radiance to LDR
    val_gamma = 0
    if gamma_transform is not None:
        rendered_pbr = gamma_transform.hdr2ldr(rendered_pbr)
        val_gamma = gamma_transform.gamma.item()
    results = {"render": rendered_image,
               "pbr": rendered_pbr,
               "normal": rendered_normal,
               "pseudo_normal": rendered_pseudo_normal,
               "surface_xyz": rendered_surface_xyz,
               "opacity": rendered_opacity,
               "depth": rendered_depth,
               "viewspace_points": screenspace_points,
               "visibility_filter": radii > 0,
               "radii": radii,
               "num_rendered": num_rendered,
               "num_contrib": num_contrib,
               "precompute_trans": light_transport,
               "idx": idx,
               "occ": occ,
               "indir": indir, }
    results.update(feature_dict)
    results["hdr"] = viewpoint_camera.hdr
    results["val_gamma"] = val_gamma

    if not is_training:
        directions = viewpoint_camera.get_world_directions()
        if isinstance(direct_light_env_light, DirectLightEnv):
            shs_view_env = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
            env = torch.clamp_min(eval_sh(direct_light_env_light.sh_degree, shs_view_env, directions.permute(1, 2, 0)) + 0.5, 0).permute(2, 0, 1)
        else:
            env = direct_light_env_light.direct_light(directions.permute(1, 2, 0)).permute(2, 0, 1)
        results["render"] = rendered_image + (1 - rendered_opacity) * env
        # pbr_env is composited over a constant background of ones, not the env map.
        env = torch.ones_like(env)
        results["pbr_env"] = pbr + (1 - rendered_opacity) * env

    return results


def render_neilf_composite_prt(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                 scaling_modifier=1.0, override_color=None, opt: OptimizationParams = False,
                 is_training=False, dict_params=None, bake=False,light_transport=None,precompute=False,
                 idx=0,light_shs_file='',light_deg=3,occ=None,indir=None):
    """
    Render the scene.
    Background tensor (bg_color) must be on GPU!

    Public entry point that forwards every argument except ``opt`` (accepted
    for interface compatibility, unused) to ``render_view`` and returns its
    result dict unchanged.
    """
    return render_view(
        viewpoint_camera, pc, pipe, bg_color,
        scaling_modifier=scaling_modifier,
        override_color=override_color,
        is_training=is_training,
        dict_params=dict_params,
        bake=bake,
        light_transport=light_transport,
        precompute=precompute,
        idx=idx,
        light_shs_file=light_shs_file,
        light_deg=light_deg,
        occ=occ,
        indir=indir,
    )


def rendering_equation_python(base_color, roughness, metallic, normals, viewdirs, incidents,
                              is_training=False, direct_light_env_light=None, visibility=None,
                              sample_num=24, bake=False, visibility_precompute=None):
    """Monte-Carlo shade points with a diffuse + specular BRDF.

    Incident directions are sampled around each normal. Incident light is the
    sum of a local term and a global term:

    * local: SH ``incidents`` evaluated per direction, clamped at 0;
    * global: environment light, multiplied by incident visibility.

    Incident visibility is reconstructed from the ``visibility`` SH when
    ``bake`` is set, and otherwise taken from ``visibility_precompute``.

    Args:
        base_color, roughness, metallic, normals, viewdirs: per-point material
            and geometry tensors (leading dim = number of points).
        incidents: per-point SH coefficients of the local incident light.
        is_training: forwarded to ``sample_incident_rays`` (random rotation).
        direct_light_env_light: ``DirectLightEnv`` (SH env light), another
            object exposing ``direct_light(dirs)``, or None for no global light.
        visibility: per-point visibility SH; its ``shape[1]`` also fixes the SH
            degree used for all SH evaluations here, so it is required.
        sample_num: number of incident directions per point.
        bake: reconstruct visibility from ``visibility`` SH instead of using
            ``visibility_precompute``.
        visibility_precompute: precomputed incident visibility (used when not baking).

    Returns:
        ``(rgb, extra_results)``: ``rgb`` is the diffuse + specular estimate
        averaged over samples, and ``extra_results`` holds sample-averaged
        light/visibility terms.

    Raises:
        ValueError: if ``visibility`` is None, or if ``bake`` is False and
            ``visibility_precompute`` is None.
    """
    # Validate before the sampling work below so misuse fails fast.
    if visibility is None:
        raise ValueError("visibility SH coefficients are required (they define the SH degree).")
    if not bake and visibility_precompute is None:
        raise ValueError("visibility should be pre-computed.")

    incident_dirs, incident_areas = sample_incident_rays(normals, is_training, sample_num)

    # Add a sample axis so per-point tensors broadcast against per-sample ones.
    base_color = base_color.unsqueeze(-2).contiguous()
    roughness = roughness.unsqueeze(-2).contiguous()
    metallic = metallic.unsqueeze(-2).contiguous()
    normals = normals.unsqueeze(-2).contiguous()
    viewdirs = viewdirs.unsqueeze(-2).contiguous()

    # visibility has (deg + 1)^2 SH coefficients along dim 1.
    deg = int(np.sqrt(visibility.shape[1]) - 1)
    incident_dirs_coef = eval_sh_coef(deg, incident_dirs).unsqueeze(2)
    shs_view = incidents.transpose(1, 2).view(base_color.shape[0], 1, 3, -1)

    shs_visibility = visibility.transpose(1, 2).view(base_color.shape[0], 1, 1, -1)
    local_incident_lights = torch.clamp_min((incident_dirs_coef[..., :shs_view.shape[-1]] * shs_view).sum(-1), 0)
    if direct_light_env_light is not None:
        if isinstance(direct_light_env_light, DirectLightEnv):
            shs_view_direct = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
            global_incident_lights = torch.clamp_min(
                (incident_dirs_coef[..., :shs_view_direct.shape[-1]] * shs_view_direct).sum(-1) + 0.5, 0)
        else:
            global_incident_lights = direct_light_env_light.direct_light(incident_dirs)
    else:
        global_incident_lights = torch.zeros_like(local_incident_lights, requires_grad=False)

    if bake:
        incident_visibility = torch.clamp(
            (incident_dirs_coef[..., :shs_visibility.shape[-1]] * shs_visibility).sum(-1) + 0.5, 0, 1)
    else:
        incident_visibility = visibility_precompute

    global_incident_lights = global_incident_lights * incident_visibility
    incident_lights = local_incident_lights + global_incident_lights

    def _dot(a, b):
        # Dot product over the last axis, keeping it as size 1.
        return (a * b).sum(dim=-1, keepdim=True)

    def _f_diffuse(base_color, metallic):
        # Lambertian term, attenuated for metals.
        return (1 - metallic) * base_color / np.pi

    def _f_specular(h_d_n, h_d_o, n_d_i, n_d_o, base_color, roughness, metallic):
        # Spherical-Gaussian approximation of the NDF; the original author notes
        # its normalization is wrong ("used in SG, wrongly normalized").
        def _d_sg(r, cos):
            r2 = (r * r).clamp(min=1e-7)
            amp = 1 / (r2 * np.pi)
            sharp = 2 / r2
            return amp * torch.exp(sharp * (cos - 1))

        D = _d_sg(roughness, h_d_n)

        # Schlick Fresnel term. Named `fresnel` so it does not shadow the
        # module-level `F` (torch.nn.functional) alias.
        F_0 = 0.04 * (1 - metallic) + base_color * metallic
        fresnel = F_0 + (1.0 - F_0) * ((1.0 - h_d_o) ** 5)

        # geometry term V, we use V = G / (4 * cos * cos) here
        def _v_schlick_ggx(r, cos):
            r2 = ((1 + r) ** 2) / 8
            return 0.5 / (cos * (1 - r2) + r2).clamp(min=1e-7)

        V = _v_schlick_ggx(roughness, n_d_i) * _v_schlick_ggx(roughness, n_d_o)

        return D * fresnel * V

    # half vector and all cosines
    half_dirs = incident_dirs + viewdirs
    half_dirs = F.normalize(half_dirs, dim=-1)

    h_d_n = _dot(half_dirs, normals).clamp(min=0)
    h_d_o = _dot(half_dirs, viewdirs).clamp(min=0)
    n_d_i = _dot(normals, incident_dirs).clamp(min=0)
    n_d_o = _dot(normals, viewdirs).clamp(min=0)

    f_d = _f_diffuse(base_color, metallic)
    f_s = _f_specular(h_d_n, h_d_o, n_d_i, n_d_o, base_color, roughness, metallic)

    # Monte-Carlo estimate: weight each sample by its area and cosine, then
    # average over the sample axis.
    transport = incident_lights * incident_areas * n_d_i
    rgb_d = (f_d * transport).mean(dim=-2)
    rgb_s = (f_s * transport).mean(dim=-2)
    rgb = rgb_d + rgb_s

    extra_results = {
        "incident_lights": incident_lights.mean(dim=-2),
        "local_incident_lights": local_incident_lights.mean(dim=-2),
        "global_incident_lights": global_incident_lights.mean(dim=-2),
        "incident_visibility": incident_visibility.mean(dim=-2),
    }

    return rgb, extra_results


def sample_incident_rays(normals, is_training=False, sample_num=24):
    """Draw ``sample_num`` Fibonacci-sphere directions around each normal.

    The sampling pattern is randomly rotated only when ``is_training`` is
    truthy, so evaluation uses a fixed pattern.

    Returns:
        ``(incident_dirs, incident_areas)`` as produced by
        ``fibonacci_sphere_sampling`` (documented by the original as
        ``[N, S, 3]`` and ``[N, S, 1]``).
    """
    dirs, areas = fibonacci_sphere_sampling(
        normals, sample_num, random_rotate=bool(is_training))
    return dirs, areas