from typing import List
from jaxtyping import Float
import os
import torch
from torch import nn
from einops import rearrange, repeat, reduce
import torch.nn.functional as F
from torch import Tensor
import nerfacc
import numpy as np
from src.utils.project import ray_sample


class VolumeNERF(nn.Module):
    """Volume renderer over a learnable feature grid, ray-marched with nerfacc.

    A learnable parameter ``self.plane`` of shape
    ``(3, plane_size, plane_size, plane_dim)`` stores features. ``forward``
    reinterprets it as a ``(1, C, 3, P, P)`` volume. ``render`` then does the
    following:

    * samples rays for the target cameras inside the ``[-1, 1]^3`` box;
    * trilinearly interpolates features at the sample midpoints;
    * decodes them with ``implicit_network`` / ``rendering_network``;
    * alpha-composites the per-sample results in ``volume_rendering``.
    """

    def __init__(
        self,
        plane_size,
        plane_dim,
        resolution,
        N_samples,
        query_transformer = None,
        resnet = None,
        init_scale=0.001,
        reduce_method='concat',
        implicit_network=None,
        rendering_network=None,
        density_network=None,
        perturb = False,
        density_scale = 1.,
        bg_color = "white"
    ):
        """Build the learnable feature grid and the nerfacc sampler.

        Args:
            plane_size: Spatial side length ``P`` of the learnable features.
            plane_dim: Channel count ``C`` of the learnable features.
            resolution: Side length of the rendered (square) images. The same
                value is also used as the nerfacc occupancy-grid resolution.
            N_samples: Sets the marching step to ``2*sqrt(3) / N_samples``.
                ``2*sqrt(3)`` is the diagonal of the ``[-1, 1]^3`` box.
            query_transformer: Used only by ``query_triplane``. It is called as
                ``query_transformer(plane, images, time_embeddings).sample``.
            resnet: Optional module applied in ``query_triplane`` as
                ``resnet(volume, time_embeddings)``.
            init_scale: Multiplier for the standard-normal init of ``self.plane``.
            reduce_method: Stored only; not read by any method of this class.
            implicit_network: Must provide ``get_outputs(xyzs, features)``
                returning a 3-tuple; see ``render``.
            rendering_network: Called as
                ``rendering_network(dirs, feature_vectors, normals)``.
            density_network: Maps the implicit network's first output to density.
            perturb: Stored only; not read by any method of this class.
            density_scale: Stored only; not read by any method of this class.
            bg_color: Background composited behind the volume. Supported values
                are ``"white"`` and ``"black"``.
        """
        super().__init__()

        self.plane_size = plane_size
        self.plane_dim = plane_dim
        self.plane = nn.Parameter(torch.randn(3, plane_size, plane_size, plane_dim)*init_scale, requires_grad=True)
        self.resolution = resolution
        self.N_samples = N_samples
        self.query_transformer = query_transformer
        self.resnet = resnet
        self.reduce_method = reduce_method

        self.implicit_network = implicit_network
        self.rendering_network = rendering_network
        self.density_network = density_network
        self.perturb = perturb
        self.density_scale = density_scale
        self.bg_color = bg_color
        # Scene bounds: the axis-aligned box [-1, 1]^3 as (xmin, ymin, zmin, xmax, ymax, zmax).
        self.aabb_box = torch.as_tensor([-1, -1, -1, 1, 1, 1], dtype=torch.float32)
        # NOTE(review): the occupancy-grid resolution reuses the *image*
        # resolution; confirm that coupling is intended.
        self.estimator = nerfacc.OccGridEstimator(
            resolution=resolution,
            roi_aabb=self.aabb_box,
            levels=1
        )
        # Mark every grid cell as occupied so sampling never skips space:
        # rays are marched uniformly through the whole box.
        self.estimator.occs.fill_(True)
        self.estimator.binaries.fill_(True)
        self.render_step_size = np.sqrt(3) * 2 / self.N_samples
        # TODO: more efficient sampling, e.g. sample using the previous stage's density.

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward the xformers memory-efficient-attention toggle to submodules.

        Every descendant module that defines
        ``set_use_memory_efficient_attention_xformers`` is called with
        ``(valid, attention_op)``. Recursion also continues into the children
        of modules that implement it.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # ``children()`` only yields registered (non-None) nn.Module instances.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def query_triplane(self,
        images: Float[Tensor, "B F C H W"],
        time_embeddings: Float[Tensor, "B F D"] = None,
    ):
        """Condition the learnable features on input images via ``query_transformer``.

        Args:
            images: Conditioning images, ``(B, F, C, H, W)``.
            time_embeddings: Optional; forwarded to ``query_transformer`` and
                ``resnet``.

        Returns:
            A channel-last volume ``(B, P, P, P, C)`` (``P = plane_size``).

        NOTE(review): ``self.plane`` holds ``3 * P * P`` tokens, but the
        rearrange below assumes ``P ** 3`` tokens. That is only consistent if
        ``query_transformer`` changes the token count. The channel-last output
        also does not match the ``(B, C, D, H, W)`` layout ``render`` expects.
        This path is currently unused by ``forward``; confirm before enabling it.
        """
        b, f, c, h, w = images.shape
        plane = repeat(self.plane, 'd h w c -> b (d h w) c', b=b)
        images = rearrange(images, 'b f c h w -> b (f h w) c')
        volume = self.query_transformer(plane, images, time_embeddings).sample # b (n v v) d
        volume = rearrange(volume, 'b (d h w) c -> b c d h w', d=self.plane_size, h=self.plane_size, w=self.plane_size)
        if self.resnet:
            volume = self.resnet(volume, time_embeddings)
        volume = rearrange(volume, 'b c d h w -> b d h w c')
        return volume

    def _composite_background(self, color, acc):
        """Blend accumulated ray colors over the configured background color.

        Args:
            color: Per-ray accumulated color, ``(n_rays, C)``.
            acc: Per-ray accumulated weight (opacity), ``(n_rays, 1)``.

        Returns:
            ``color + (1 - acc) * background``.

        Raises:
            ValueError: If ``self.bg_color`` is neither ``"white"`` nor ``"black"``.
        """
        if self.bg_color == "white":
            return color + (1.0 - acc) * torch.ones_like(color)
        if self.bg_color == "black":
            # A zero background contributes nothing to the composite.
            return color
        raise ValueError(
            f"Unsupported bg_color {self.bg_color!r}; expected 'white' or 'black'."
        )

    def volume_rendering(self,
        sdf,
        normals,
        rgbs,
        t_starts,
        t_ends,
        t_positions,
        ray_indices,
        n_rays
        ):
        """Alpha-composite per-sample quantities into per-ray outputs.

        Args:
            sdf: First output of ``implicit_network.get_outputs``. It is
                converted to density by ``self.density_network``, whose output
                is indexed as ``density[..., 0]``.
            normals: Per-sample normals (presumably ``(N, 3)``; TODO confirm).
            rgbs: Per-sample colors from ``rendering_network``, ``(N, C)``.
            t_starts: Interval start distances along each ray, ``(N, 1)``.
            t_ends: Interval end distances along each ray, ``(N, 1)``.
            t_positions: Interval midpoints, ``(N, 1)``.
            ray_indices: Index of the ray each sample belongs to.
            n_rays: Total number of rays.

        Returns:
            Dict of per-ray tensors with leading dim ``n_rays``:

            * ``"rgb"``: background-composited color;
            * ``"depth"``: weighted midpoint distance;
            * ``"mask"``: accumulated opacity;
            * ``"normal"``: accumulated normal, renormalized.

        Raises:
            ValueError: If ``self.bg_color`` is unsupported.
        """
        density = self.density_network(sdf)
        # Only the per-sample weights are needed; transmittance and alphas are dropped.
        weights, _, _ = nerfacc.render_weight_from_density(
            t_starts[..., 0],
            t_ends[..., 0],
            density[..., 0],
            ray_indices=ray_indices,
            n_rays=n_rays,
        )

        # With ``values=None`` nerfacc accumulates the weights themselves -> (n_rays, 1).
        acc = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )
        depth = nerfacc.accumulate_along_rays(
            weights, values=t_positions, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        normals = nerfacc.accumulate_along_rays(
            weights, values=normals, ray_indices=ray_indices, n_rays=n_rays
        )
        # Re-normalize the blended normal; F.normalize's eps keeps empty rays at zero.
        normals = F.normalize(normals, p=2, dim=-1)
        color = self._composite_background(color, acc)
        # Colors are not remapped here (an earlier variant mapped them to
        # [-1, 1] via ``color * 2 - 1``).
        return {
            "rgb": color,
            "depth": depth,
            "mask": acc,
            "normal": normals
        }

    def render(self,
        volume: Float[Tensor, "B C D H W"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Render the feature volume from every target camera.

        Args:
            volume: Channel-first feature volume. The sample grid built below
                has batch size 1, so ``volume`` must have batch size 1 as well.
            target_cameras: Per-view camera vectors ``(B, F, 32)``. The first 16
                values are a row-major 4x4 intrinsics matrix (only its top-left
                3x3 is used). The last 16 values are a row-major 4x4
                camera-to-world matrix.
            time_embeddings: Currently unused (see TODOs below).

        Returns:
            Dict containing:

            * ``"rgb"``, ``"depth"``, ``"mask"``, ``"normal"``: each reshaped to
              ``(B, F, c, resolution, resolution)``;
            * ``"sdf_grad"``: the per-sample gradients from
              ``implicit_network.get_outputs``.

            NOTE(review): the reshape assumes ``ray_sample`` returns rays ordered
            by camera, then pixel row, then column. TODO confirm.
        """
        f = target_cameras.shape[1]
        intrinsics = target_cameras[..., :16].reshape(-1, 4, 4)
        c2w = target_cameras[..., 16:].reshape(-1, 4, 4)
        ray_origins, ray_dirs = ray_sample(c2w, intrinsics[:, :3, :3], self.resolution)
        ray_origins_flattened = ray_origins.reshape(-1, 3)
        ray_dirs_flattened = ray_dirs.reshape(-1, 3)
        n_rays = ray_origins_flattened.shape[0]
        # Interval selection is pure geometry; no gradient is needed through it.
        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                ray_origins_flattened, ray_dirs_flattened,
                render_step_size=self.render_step_size,
            )

        t_starts, t_ends = t_starts[..., None], t_ends[..., None]
        t_origins = ray_origins_flattened[ray_indices]
        t_dirs = ray_dirs_flattened[ray_indices]
        t_positions = (t_starts + t_ends) / 2.0

        # World-space midpoint of every sampled interval, [N, 3].
        xyzs = t_origins + t_dirs * t_positions

        # grid_sample wants a 5-D grid (batch, D_out, H_out, W_out, 3). All
        # samples are packed along W_out, and the coordinates are used directly
        # as normalized [-1, 1] grid positions.
        coords = rearrange(xyzs, 'n c -> 1 1 1 n c')

        sampled_features = F.grid_sample(volume, coords, align_corners=True, mode='bilinear', padding_mode='zeros')

        sampled_features = rearrange(sampled_features, '1 c 1 1 n -> n c')

        # First output is treated as an SDF (volume_rendering feeds it to density_network).
        density, feature_vectors, grads = self.implicit_network.get_outputs(xyzs, sampled_features) # TODO: add timesteps support

        normals = grads / (grads.norm(2, -1, keepdim=True) + 1e-6)

        rgb = self.rendering_network(t_dirs, feature_vectors, normals) # TODO: add timesteps support
        outputs = self.volume_rendering(density, normals, rgb, t_starts, t_ends, t_positions, ray_indices, n_rays)

        for k, v in outputs.items():
            outputs[k] = rearrange(v, "(b f h w) c -> b f c h w", h=self.resolution, w=self.resolution, f=f)
        outputs["sdf_grad"] = grads
        return outputs

    def forward(self,
        images: Float[Tensor, "B F C H W"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Render the learnable feature grid from ``target_cameras``.

        ``images`` is currently unused: the image-conditioned path
        (``self.query_triplane(images, time_embeddings)``) is disabled. The
        learnable ``self.plane`` is therefore rendered directly, reinterpreted
        as a batch-1 ``(1, C, 3, P, P)`` volume.

        Returns:
            The dict produced by ``render``: ``"rgb"``, ``"depth"``, ``"mask"``,
            ``"normal"`` and ``"sdf_grad"``.
        """
        volume = rearrange(self.plane, 'd h w c -> 1 c d h w')
        return self.render(volume, target_cameras, time_embeddings)
        