import math
import torch
import numpy as np
import torch.nn.functional as F
from arguments import OptimizationParams
from scene.cameras import Camera
from scene.gaussian_model import GaussianModel
from scene.derect_light_sh import DirectLightEnv
from utils.sh_utils import eval_sh
from utils.sh_utils_3 import ProjectFunction,fibonacci_sphere_sampling,eval_sh_coef,ProjectFunction_diffuse
# from utils.graphics_utils import fibonacci_sphere_sampling
from .r3dg_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from utils.rotation_utils import *
import time
def read_light_sh_from_file(light_shs_file):
    """Load spherical-harmonic lighting coefficients from a text file.

    The file must contain a single Python literal, typically a nested list of
    numbers such as ``[[c0, c1, ...], [c0, c1, ...], [c0, c1, ...]]``. The
    literal is parsed and converted with ``torch.tensor``.

    Args:
        light_shs_file: path to the coefficient file.

    Returns:
        torch.Tensor built from the parsed literal.

    Raises:
        ValueError / SyntaxError: if the file content is not a plain literal
            (for example, if it contains function calls or names).
    """
    import ast

    with open(light_shs_file, "r", encoding="utf-8") as f:
        content = f.read()
    # ast.literal_eval only accepts literals, unlike eval(), so a crafted
    # lighting file cannot execute arbitrary code when it is loaded.
    return torch.tensor(ast.literal_eval(content.strip()))


def _rotated_env_light_shs(idx):
    """Return the built-in environment-light SH table rotated for frame ``idx``, on CUDA.

    The table has one row per colour channel and nine SH coefficients per row,
    i.e. ``(light_deg + 1) ** 2`` with ``light_deg == 2``. It is rotated by an
    angle of ``idx * 0.003 * 2 * pi`` (presumably radians -- confirm in
    utils.rotation_utils) about the axis (0, 1, 0), so successive frame
    indices spin the environment.
    """
    rotation = fromRotation(torch.tensor(idx * 0.003 * 2 * torch.pi),
                            torch.tensor([0, 1, 0], dtype=torch.float32))
    rotation = mat4Matrix2mathMatrix(rotation)

    # All three colour channels currently share the same coefficients.
    channel_coeffs = [0.588103, 0.000529265, -0.000265849, -0.848459, -0.00109119,
                      -0.000445126, -0.364183, 0.000479468, 0.629345]
    light_shs = torch.tensor([channel_coeffs, channel_coeffs, channel_coeffs], device="cpu")

    # getRotationPrecomputeL is applied to the transposed (coefficients-first)
    # layout, and the result is transposed back afterwards.
    light_shs = getRotationPrecomputeL(light_shs.transpose(0, 1), rotation).transpose(0, 1)
    return light_shs.cuda()


def _precompute_light_transport(pc, normal, visibility, viewdirs, base_color, metallic,
                                roughness, light_deg, chunk_size=3000):
    """Run ProjectFunction over all Gaussians in chunks of ``chunk_size``.

    Returns ``(light_transport, occ)``: the per-chunk ``H`` and ``out_occ``
    outputs of ProjectFunction, each concatenated along dim 0. The CUDA cache
    is emptied after every chunk to keep peak memory bounded.
    """
    num_points = base_color.shape[0]
    transport_chunks = []
    occ_chunks = []
    for offset in range(0, num_points, chunk_size):
        print("persent", offset / num_points * 100)
        sl = slice(offset, offset + chunk_size)
        H, _, out_occ = ProjectFunction(pc.max_sh_degree + 1, 1, 15, normal[sl], visibility[sl],
                                        viewdirs[sl], base_color[sl],
                                        metallic[sl], roughness[sl], light_deg)
        transport_chunks.append(H)
        occ_chunks.append(out_occ)
        torch.cuda.empty_cache()
    return torch.cat(transport_chunks, dim=0), torch.cat(occ_chunks, dim=0)


def render_view(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                scaling_modifier=1.0, override_color=None, is_training=False, 
                dict_params=None, bake=False,light_transport=None,precompute=False,idx=0,light_shs_file='',light_deg=3,occ=None,indir=None):
    """Render one view of ``pc`` with precomputed-radiance-transfer (PRT) shading.

    Each Gaussian's colour is obtained by contracting ``light_transport`` with
    the SH basis evaluated at its view direction and then with a built-in
    environment light rotated according to ``idx``. These colours and the
    material buffers are splatted by the rasterizer as extra features.

    Args:
        viewpoint_camera: camera to render from.
        pc: Gaussian model providing geometry and material attributes.
        pipe: pipeline options (``debug``, ``compute_cov3D_python`` and
            ``compute_SHs_python`` are read).
        bg_color: background colour. It must be on the GPU and is broadcast
            as ``bg_color[:, None, None]`` over the image.
        scaling_modifier: forwarded to the rasterizer / covariance computation.
        override_color: if given, used as the rasterizer's precomputed colours.
        is_training: when False and an env light is available, the environment
            is composited behind the object into ``"render"`` and ``"pbr_env"``.
        dict_params: mapping read for ``"env_light"`` and ``"gamma"``. ``None``
            is treated as an empty mapping.
        bake: accepted for interface compatibility; unused by the PRT path.
        light_transport: transport tensor, e.g. ``"precompute_trans"`` from an
            earlier call. Required unless ``precompute`` is True.
        precompute: compute ``light_transport`` and ``occ`` now with
            ``ProjectFunction`` instead of taking them from the arguments.
        idx: frame index selecting the light rotation. It is returned as
            ``results["idx"]``, incremented by one.
        light_shs_file: currently unused; the light is the built-in table in
            ``_rotated_env_light_shs``.
        light_deg: SH degree of the light basis. The built-in light has 9
            coefficients per channel, so this should be 2 for
            ``(light_deg + 1) ** 2`` to match in the colour contraction.
        occ, indir: passed through to the results. ``occ`` is replaced when
            ``precompute`` is True.

    Returns:
        dict with the rasterizer outputs (``"render"``, ``"pbr"``, ``"normal"``,
        ``"depth"``, ...), the material buffers, and the PRT state
        (``"precompute_trans"``, ``"idx"``, ``"occ"``, ``"indir"``) to thread
        into the next call.

    Raises:
        ValueError: if ``precompute`` is False and ``light_transport`` is None.
    """
    dict_params = {} if dict_params is None else dict_params
    direct_light_env_light = dict_params.get("env_light")
    gamma_transform = dict_params.get("gamma")

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises when the tensor does not require grad (e.g. when
        # rendering under torch.no_grad()); screen-space gradients are not needed then.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
    intrinsic = viewpoint_camera.intrinsics

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        cx=float(intrinsic[0, 2]),
        cy=float(intrinsic[1, 2]),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        backward_geometry=True,
        computer_pseudo_normal=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.compute_SHs_python:
            dir_pp_normalized = F.normalize(viewpoint_camera.camera_center.repeat(means3D.shape[0], 1) - means3D,
                                            dim=-1)
            shs_view = pc.get_shs.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_shs
    else:
        colors_precomp = override_color

    base_color = pc.get_base_color
    roughness = pc.get_roughness
    metallic = pc.get_metallic
    normal = pc.get_normal
    visibility = pc.get_visibility
    viewdirs = F.normalize(viewpoint_camera.camera_center - means3D, dim=-1)
    num_points = base_color.shape[0]

    # Light for this frame. The caller threads ``idx`` through successive calls
    # via results["idx"], so the environment rotates from frame to frame.
    light_shs = _rotated_env_light_shs(idx)
    idx += 1

    torch.cuda.synchronize()
    start = time.time()

    # Every Gaussian sees the same light, so a broadcast view is sufficient.
    light = light_shs.unsqueeze(0).expand(num_points, -1, -1)

    if precompute:
        light_transport, occ = _precompute_light_transport(
            pc, normal, visibility, viewdirs, base_color, metallic, roughness, light_deg)

    diffuse = False  # developer toggle for the ProjectFunction_diffuse path
    if diffuse:
        print("Memory cached on GPU1:", torch.cuda.memory_reserved() / 1024 / 1024, "mb")
        H, _ = ProjectFunction_diffuse(pc.max_sh_degree + 1, 1, 10, normal, visibility,
                                       viewdirs, base_color, metallic, roughness,
                                       light_deg, pc)
        brdf_color = (H * light).sum(-1)
    else:
        if light_transport is None:
            raise ValueError("light_transport is required when precompute=False; "
                             "pass the 'precompute_trans' result of an earlier call")
        # SH basis at each view direction, broadcast to (N, 3, (light_deg + 1) ** 2, K)
        # and contracted with the transport's last axis.
        view_coef = eval_sh_coef(light_deg, viewdirs)
        view_coef = view_coef.unsqueeze(1).unsqueeze(1).expand(-1, 3, (light_deg + 1) ** 2, -1)
        light_transport_v = (view_coef * light_transport).sum(-1)
        # NOTE(review): fixed 2.0 gain on the output colour; its origin is undocumented.
        brdf_color = (light_transport_v * light).sum(-1) * 2.0
    torch.cuda.synchronize()
    print("render time:", time.time() - start)

    # The PRT path has no separate incident-light or visibility estimates, so the
    # feature slots the rasterizer carries for them are filled with the shaded
    # colour (its first channel for the single-channel visibility slot).
    features = torch.cat([brdf_color, normal, base_color, roughness, metallic,
                          brdf_color, brdf_color, brdf_color, brdf_color[:, :1]], dim=-1)

    (num_rendered, num_contrib, rendered_image, rendered_opacity, rendered_depth,
     rendered_feature, rendered_pseudo_normal, rendered_surface_xyz, radii) = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
        features=features,
    )
    feature_dict = {}
    rendered_pbr, rendered_normal, rendered_base_color, rendered_roughness, rendered_metallic, \
        rendered_light, rendered_local_light, rendered_global_light, rendered_visibility \
        = rendered_feature.split([3, 3, 3, 1, 1, 3, 3, 3, 1], dim=0)

    feature_dict.update({"base_color": rendered_base_color,
                         "roughness": rendered_roughness,
                         "metallic": rendered_metallic,
                         "lights": rendered_light,
                         "local lights": rendered_local_light,
                         "global lights": rendered_global_light,
                         "visibility": rendered_visibility,
                         })

    pbr = rendered_pbr
    rendered_pbr = pbr + (1 - rendered_opacity) * bg_color[:, None, None]

    # HDR out radiance to LDR
    val_gamma = 0
    if gamma_transform is not None:
        rendered_pbr = gamma_transform.hdr2ldr(rendered_pbr)
        val_gamma = gamma_transform.gamma.item()

    results = {"render": rendered_image,
               "pbr": rendered_pbr,
               "normal": rendered_normal,
               "pseudo_normal": rendered_pseudo_normal,
               "surface_xyz": rendered_surface_xyz,
               "opacity": rendered_opacity,
               "depth": rendered_depth,
               "viewspace_points": screenspace_points,
               "visibility_filter": radii > 0,
               "radii": radii,
               "num_rendered": num_rendered,
               "num_contrib": num_contrib,
               "precompute_trans": light_transport,
               "idx": idx,
               "occ": occ,
               "indir": indir, }
    results.update(feature_dict)
    results["hdr"] = viewpoint_camera.hdr
    results["val_gamma"] = val_gamma

    if not is_training and direct_light_env_light is not None:
        directions = viewpoint_camera.get_world_directions()
        if isinstance(direct_light_env_light, DirectLightEnv):
            shs_view_direct = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
            env = torch.clamp_min(eval_sh(direct_light_env_light.sh_degree, shs_view_direct, directions.permute(1, 2, 0)) + 0.5, 0).permute(2, 0, 1)
        else:
            env = direct_light_env_light.direct_light(directions.permute(1, 2, 0)).permute(2, 0, 1)
        results["render"] = rendered_image + (1 - rendered_opacity) * env
        results["pbr_env"] = pbr + (1 - rendered_opacity) * env

    return results


def render_neilf_composite_prt(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor,
                 scaling_modifier=1.0, override_color=None, opt: OptimizationParams = False,
                 is_training=False, dict_params=None, bake=False,light_transport=None,precompute=False,
                 idx=0,light_shs_file='',light_deg=3,occ=None,indir=None):
    """Public entry point for PRT rendering of a Gaussian scene.

    A thin wrapper around ``render_view``. Every argument except ``opt`` is
    forwarded unchanged; ``opt`` is accepted for interface compatibility and
    is not used. The result dictionary from ``render_view`` is returned as-is.

    Note: ``bg_color`` is expected to already reside on the GPU.
    """
    return render_view(
        viewpoint_camera, pc, pipe, bg_color,
        scaling_modifier=scaling_modifier,
        override_color=override_color,
        is_training=is_training,
        dict_params=dict_params,
        bake=bake,
        light_transport=light_transport,
        precompute=precompute,
        idx=idx,
        light_shs_file=light_shs_file,
        light_deg=light_deg,
        occ=occ,
        indir=indir,
    )


def rendering_equation_python(base_color, roughness, metallic, normals, viewdirs, incidents, 
                              is_training=False, direct_light_env_light=None, visibility=None, 
                              sample_num=24, bake=False, visibility_precompute=None):
    """Shade points by Monte-Carlo integration of a PBR rendering equation in PyTorch.

    For every point, ``sample_num`` incident directions are drawn with
    ``sample_incident_rays``. Incident radiance is the sum of two terms:
    - a local term (``incidents`` SH, clamped at 0);
    - a global term from ``direct_light_env_light``, multiplied by per-sample visibility.

    It is weighted by a Lambertian diffuse lobe plus a specular lobe (SG-style
    D, Schlick Fresnel, Schlick-GGX V), then by the sample area and the cosine
    term, and finally averaged over the samples.

    Args:
        base_color, roughness, metallic, normals, viewdirs: per-point
            attributes; each gets a sample axis inserted at dim -2.
        incidents: SH coefficients of the local incident light. After
            ``transpose(1, 2)`` they are viewed as (N, 1, 3, K), so presumably
            (N, K, 3) on input -- confirm with callers.
        is_training: randomly rotate the sample pattern when True.
        direct_light_env_light: a ``DirectLightEnv`` (SH environment,
            +0.5 offset), any object with ``direct_light(dirs)``, or None for
            no global light.
        visibility: SH coefficients of visibility. Its dim-1 length fixes the
            SH degree used for every basis evaluation. Required.
        sample_num: number of incident samples per point.
        bake: derive per-sample visibility from the ``visibility`` SH instead
            of using ``visibility_precompute``.
        visibility_precompute: per-sample visibility, used when ``bake`` is False.

    Returns:
        ``(rgb, extra_results)``. ``rgb`` is the sample-averaged shaded colour.
        ``extra_results`` holds the sample-averaged ``incident_lights``,
        ``local_incident_lights``, ``global_incident_lights`` and
        ``incident_visibility``.

    Raises:
        ValueError: if ``visibility`` is None, or if ``bake`` is False and
            ``visibility_precompute`` is None.
    """
    # Validate inputs before any sampling work. Without the first check, a
    # missing ``visibility`` surfaced as an AttributeError on ``.shape``.
    if visibility is None:
        raise ValueError("visibility SH coefficients must be provided.")
    if not bake and visibility_precompute is None:
        raise ValueError("visibility should be pre-computed.")

    incident_dirs, incident_areas = sample_incident_rays(normals, is_training, sample_num)  # [N, S, 3], [N, S, 1]

    base_color = base_color.unsqueeze(-2).contiguous()
    roughness = roughness.unsqueeze(-2).contiguous()
    metallic = metallic.unsqueeze(-2).contiguous()
    normals = normals.unsqueeze(-2).contiguous()
    viewdirs = viewdirs.unsqueeze(-2).contiguous()

    # SH degree implied by the coefficient count (deg + 1) ** 2, computed exactly in integers.
    deg = math.isqrt(visibility.shape[1]) - 1
    incident_dirs_coef = eval_sh_coef(deg, incident_dirs).unsqueeze(2)
    shs_view = incidents.transpose(1, 2).view(base_color.shape[0], 1, 3, -1)

    shs_visibility = visibility.transpose(1, 2).view(base_color.shape[0], 1, 1, -1)
    local_incident_lights = torch.clamp_min((incident_dirs_coef[..., :shs_view.shape[-1]] * shs_view).sum(-1), 0)
    if direct_light_env_light is not None:
        if isinstance(direct_light_env_light, DirectLightEnv):
            shs_view_direct = direct_light_env_light.get_env_shs.transpose(1, 2).unsqueeze(1)
            global_incident_lights = torch.clamp_min(
                (incident_dirs_coef[..., :shs_view_direct.shape[-1]] * shs_view_direct).sum(-1) + 0.5, 0)
        else:
            global_incident_lights = direct_light_env_light.direct_light(incident_dirs)
    else:
        global_incident_lights = torch.zeros_like(local_incident_lights, requires_grad=False)

    if bake:
        incident_visibility = torch.clamp(
            (incident_dirs_coef[..., :shs_visibility.shape[-1]] * shs_visibility).sum(-1) + 0.5, 0, 1)
    else:
        incident_visibility = visibility_precompute

    global_incident_lights = global_incident_lights * incident_visibility
    incident_lights = local_incident_lights + global_incident_lights

    def _dot(a, b):
        # Dot product over the last axis, keeping it as a size-1 dimension.
        return (a * b).sum(dim=-1, keepdim=True)

    def _f_diffuse(base_color, metallic):
        # Lambertian lobe; metallic surfaces have no diffuse component.
        return (1 - metallic) * base_color / np.pi

    def _f_specular(h_d_n, h_d_o, n_d_i, n_d_o, base_color, roughness, metallic):
        # used in SG, wrongly normalized
        def _d_sg(r, cos):
            r2 = (r * r).clamp(min=1e-7)
            amp = 1 / (r2 * np.pi)
            sharp = 2 / r2
            return amp * torch.exp(sharp * (cos - 1))

        D = _d_sg(roughness, h_d_n)

        # Fresnel term (Schlick); named `fresnel` so it does not shadow the
        # module-level `F` (torch.nn.functional).
        F_0 = 0.04 * (1 - metallic) + base_color * metallic
        fresnel = F_0 + (1.0 - F_0) * ((1.0 - h_d_o) ** 5)

        # geometry term V, we use V = G / (4 * cos * cos) here
        def _v_schlick_ggx(r, cos):
            r2 = ((1 + r) ** 2) / 8
            return 0.5 / (cos * (1 - r2) + r2).clamp(min=1e-7)

        V = _v_schlick_ggx(roughness, n_d_i) * _v_schlick_ggx(roughness, n_d_o)

        return D * fresnel * V

    # half vector and all cosines
    half_dirs = incident_dirs + viewdirs
    half_dirs = F.normalize(half_dirs, dim=-1)

    h_d_n = _dot(half_dirs, normals).clamp(min=0)
    h_d_o = _dot(half_dirs, viewdirs).clamp(min=0)
    n_d_i = _dot(normals, incident_dirs).clamp(min=0)
    n_d_o = _dot(normals, viewdirs).clamp(min=0)

    f_d = _f_diffuse(base_color, metallic)
    f_s = _f_specular(h_d_n, h_d_o, n_d_i, n_d_o, base_color, roughness, metallic)

    # Per-sample incoming energy, averaged over the sample axis (dim -2).
    transport = incident_lights * incident_areas * n_d_i
    rgb_d = (f_d * transport).mean(dim=-2)
    rgb_s = (f_s * transport).mean(dim=-2)
    rgb = rgb_d + rgb_s

    extra_results = {
        "incident_lights": incident_lights.mean(dim=-2),
        "local_incident_lights": local_incident_lights.mean(dim=-2),
        "global_incident_lights": global_incident_lights.mean(dim=-2),
        "incident_visibility": incident_visibility.mean(dim=-2),
    }

    return rgb, extra_results


def sample_incident_rays(normals, is_training=False, sample_num=24):
    """Draw ``sample_num`` incident directions per point by Fibonacci sphere sampling.

    The sample pattern is randomly rotated only while training; otherwise the
    deterministic pattern is used.

    Returns:
        ``(incident_dirs, incident_areas)`` shaped [N, S, 3] and [N, S, 1].
    """
    randomize = True if is_training else False
    dirs, areas = fibonacci_sphere_sampling(normals, sample_num, random_rotate=randomize)
    return dirs, areas