from typing import List
from jaxtyping import Float
import os
import torch
from torch import nn
from einops import rearrange, repeat, reduce
import torch.nn.functional as F
from torch import Tensor
import nerfacc
import numpy as np
from src.utils.project import ray_sample


class TriplaneNERF(nn.Module):
    """Triplane NeRF: a learnable triplane rendered with nerfacc ray marching.

    The model owns a learnable triplane ``self.plane`` of shape
    ``(3, plane_size, plane_size, plane_dim)``. ``query_triplane`` can refine
    it with a transformer conditioned on input images. ``render`` does the
    following for the given cameras:

    1. Casts rays and marches them through a single-level occupancy grid over
       the ``[-1, 1]^3`` box.
    2. Bilinearly samples the xy / yz / xz planes at every ray sample.
    3. Decodes density / features / normals with the injected networks.
    4. Alpha-composites the results per ray.
    """

    def __init__(
        self,
        plane_size,
        plane_dim,
        resolution,
        N_samples,
        query_transformer = None,
        resnet = None,
        init_scale=0.001,
        reduce_method='concat',
        implicit_network=None,
        rendering_network=None,
        density_network=None,
        perturb = False,
        density_scale = 1.,
        bg_color = "white"
    ):
        """
        Args:
            plane_size: Side length of each (square) feature plane.
            plane_dim: Number of feature channels per plane.
            resolution: Used both as the occupancy-grid resolution and as the
                side length of the rendered images.
            N_samples: Sets the ray-marching step to
                ``sqrt(3) * 2 / N_samples``, i.e. the diagonal of the
                ``[-1, 1]^3`` box split into ``N_samples`` steps.
            query_transformer: Module called as
                ``query_transformer(plane, images, time_embeddings)`` whose
                result exposes ``.sample``. Required by ``query_triplane``.
            resnet: Optional module called as
                ``resnet(volume, time_embeddings)`` after the transformer.
            init_scale: Standard deviation of the Gaussian triplane init.
            reduce_method: How the three per-plane features are merged:
                ``'concat'`` (3 * plane_dim channels) or ``'mean'``
                (plane_dim channels).
            implicit_network: Module exposing ``get_outputs(xyzs, features)``
                that returns ``(density, feature_vectors, normals)``.
            rendering_network: Module called as
                ``rendering_network(dirs, feature_vectors)`` returning colours.
            density_network: Module applied to the raw density before
                compositing. Required by ``volume_rendering``.
            perturb: Stored only; not used by this class.
            density_scale: Stored only; not used by this class.
            bg_color: ``"white"``, ``"black"``, or a number / per-channel
                sequence / tensor broadcastable against the rendered colour.
        """
        super().__init__()

        self.plane_size = plane_size
        self.plane_dim = plane_dim
        self.plane = nn.Parameter(torch.randn(3, plane_size, plane_size, plane_dim)*init_scale, requires_grad=True)
        self.resolution = resolution
        self.N_samples = N_samples
        self.query_transformer = query_transformer
        self.resnet = resnet
        self.reduce_method = reduce_method

        self.implicit_network = implicit_network
        self.rendering_network = rendering_network
        self.density_network = density_network
        self.perturb = perturb
        self.density_scale = density_scale
        self.bg_color = bg_color
        # Axis-aligned bounding box [xmin, ymin, zmin, xmax, ymax, zmax].
        self.aabb_box = torch.as_tensor([-1, -1, -1, 1, 1, 1], dtype=torch.float32)
        self.estimator = nerfacc.OccGridEstimator(
            resolution=resolution,
            roi_aabb=self.aabb_box,
            levels=1
        )
        # Mark every grid cell as occupied so sampling covers the whole box
        # uniformly instead of skipping "empty" space.
        self.estimator.occs.fill_(True)
        self.estimator.binaries.fill_(True)
        self.render_step_size = np.sqrt(3) * 2 / self.N_samples
        # TODO: more efficient sampling, e.g. guided by the previous stage's density.

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Toggle xformers memory-efficient attention on every submodule that supports it.

        Walks all descendant modules and forwards ``(valid, attention_op)``
        to each one that defines
        ``set_use_memory_efficient_attention_xformers``.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # children() only ever yields nn.Module instances.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def query_triplane(self,
        images: Float[Tensor, "B F C H W"],
        time_embeddings: Float[Tensor, "B F D"] = None,
    ):
        """Refine the learnable triplane with the query transformer.

        The triplane is flattened into ``3 * plane_size**2`` query tokens per
        batch item. The images are flattened into ``F * H * W`` context
        tokens. Both are passed to ``self.query_transformer``, and the result
        is optionally post-processed by ``self.resnet``.

        Returns:
            Tensor of shape ``(B, 3, D, plane_size, plane_size)``.
        """
        b, f, c, h, w = images.shape
        plane = repeat(self.plane, 'n h w d -> b (n h w) d', b=b)
        images = rearrange(images, 'b f c h w -> b (f h w) c')
        volume = self.query_transformer(plane, images, time_embeddings).sample  # b (n h w) d
        # Fold tokens back into per-plane feature maps, one image per plane.
        volume = rearrange(volume, 'b (n h w) d -> (b n) d h w', n=3, h=self.plane_size, w=self.plane_size)
        if self.resnet:
            volume = self.resnet(volume, time_embeddings)
        volume = rearrange(volume, '(b n) d h w -> b n d h w', n=3, h=self.plane_size, w=self.plane_size)

        return volume

    def _background_like(self, color: Tensor) -> Tensor:
        """Return the configured background colour, broadcastable against ``color``.

        Raises:
            ValueError: if ``self.bg_color`` is an unsupported string.
        """
        bg = self.bg_color
        if isinstance(bg, str):
            if bg == "white":
                return torch.ones_like(color)
            if bg == "black":
                return torch.zeros_like(color)
            raise ValueError(f"Unsupported bg_color: {bg!r}")
        # Numeric / sequence / tensor colour; relies on broadcasting.
        return torch.as_tensor(bg, dtype=color.dtype, device=color.device)

    def volume_rendering(self,
        density,
        normals,
        rgbs,
        t_starts,
        t_ends,
        t_positions,
        ray_indices,
        n_rays
        ):
        """Alpha-composite per-sample quantities into per-ray outputs.

        All per-sample inputs are packed (one row per sample), with
        ``ray_indices`` giving the owning ray of each sample. ``density``,
        ``t_starts`` and ``t_ends`` carry a trailing singleton dimension.

        Returns:
            Dict with per-ray tensors: ``"rgb"`` (composited colour plus
            background), ``"depth"`` (weighted ``t_positions``), ``"mask"``
            (accumulated opacity) and ``"normal"`` (weighted normals).
        """
        density = self.density_network(density)
        weights, _, _ = nerfacc.render_weight_from_density(
            t_starts[..., 0],
            t_ends[..., 0],
            density[..., 0],
            ray_indices=ray_indices,
            n_rays=n_rays,
        )

        def accumulate(values):
            # Weighted sum over each ray's samples; values=None sums the weights.
            return nerfacc.accumulate_along_rays(
                weights, values=values, ray_indices=ray_indices, n_rays=n_rays
            )

        acc = accumulate(None)
        depth = accumulate(t_positions)
        color = accumulate(rgbs)
        normals = accumulate(normals)

        # Blend in the background wherever the ray is not fully opaque.
        color = color + (1.0 - acc) * self._background_like(color)

        return {
            "rgb": color,
            "depth": depth,
            "mask": acc,
            "normal": normals
        }

    def render(self,
        volume: Float[Tensor, "B N D V V"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Render ``F`` views per batch item from a triplane volume.

        Args:
            volume: Triplane features. Only a leading batch dimension of 1
                is supported (see the ``"1 i c h w"`` rearrange below).
            target_cameras: Each row packs a flattened 4x4 intrinsics matrix
                (first 16 values) and a flattened 4x4 camera-to-world matrix
                (last 16 values).
            time_embeddings: Currently unused (see TODOs).

        Returns:
            Dict of ``"rgb"``, ``"depth"``, ``"mask"`` and ``"normal"``,
            each of shape ``(B, F, C, resolution, resolution)``.

        Raises:
            ValueError: if ``self.reduce_method`` is not ``'concat'`` or
                ``'mean'``.
        """
        _, f, _ = target_cameras.shape
        intrinsics = target_cameras[..., :16].reshape(-1, 4, 4)
        c2w = target_cameras[..., 16:].reshape(-1, 4, 4)
        ray_origins, ray_dirs = ray_sample(c2w, intrinsics[:, :3, :3], self.resolution)
        ray_origins_flattened = ray_origins.reshape(-1, 3)
        ray_dirs_flattened = ray_dirs.reshape(-1, 3)
        n_rays = ray_origins_flattened.shape[0]
        # Ray sampling runs without gradients; only the sample intervals are used.
        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                ray_origins_flattened, ray_dirs_flattened,
                render_step_size=self.render_step_size,
            )

        t_starts, t_ends = t_starts[..., None], t_ends[..., None]
        t_origins = ray_origins_flattened[ray_indices]
        t_dirs = ray_dirs_flattened[ray_indices]
        # Interval midpoints along each ray, one row per sample.
        t_positions = (t_starts + t_ends) / 2.0

        xyzs = t_origins + t_dirs * t_positions

        # Project every sample onto the xy, yz and xz planes (same order as self.plane).
        xy, yz, xz = xyzs[..., [0, 1]], xyzs[..., [1, 2]], xyzs[..., [0, 2]]

        coords = torch.stack([xy, yz, xz], dim=-1)
        # grid_sample grid layout: (planes, 1, n_samples, 2).
        coords = rearrange(coords, 'n c i -> i 1 n c')
        volume = rearrange(volume, "1 i c h w -> i c h w")

        # Points outside [-1, 1] read zeros (padding_mode='zeros').
        sampled_features = F.grid_sample(volume, coords, align_corners=True, mode='bilinear', padding_mode='zeros')
        if self.reduce_method == 'concat':
            sampled_features = rearrange(sampled_features, 'i c 1 n -> n (i c)')
        elif self.reduce_method == 'mean':
            sampled_features = reduce(sampled_features, 'i c 1 n -> n c', 'mean')
        else:
            raise ValueError(
                f"Unsupported reduce_method: {self.reduce_method!r} (expected 'concat' or 'mean')"
            )

        density, feature_vectors, normals = self.implicit_network.get_outputs(xyzs, sampled_features) # TODO: add timesteps support

        rgb = self.rendering_network(t_dirs, feature_vectors) # TODO: add timesteps support
        # rgb: per-sample colours, one row per ray sample (composited below).
        outputs = self.volume_rendering(density, normals, rgb, t_starts, t_ends, t_positions, ray_indices, n_rays)
        # NOTE(review): assumes ray_sample orders rays as (b, f, h, w); confirm.
        for k, v in outputs.items():
            outputs[k] = rearrange(v, "(b f h w) c -> b f c h w", h=self.resolution, w=self.resolution, f=f)
        return outputs

    def forward(self,
        images: Float[Tensor, "B F C H W"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Render ``target_cameras`` views.

        Returns:
            Dict with keys ``"rgb"``, ``"depth"``, ``"mask"`` and
            ``"normal"`` (see ``render``).
        """
        volume = self.query_triplane(images, time_embeddings)
        # NOTE(review): the queried volume above is discarded and the raw
        # learnable triplane is rendered instead. This looks like a debugging
        # override; confirm intent. render() also only accepts a batch of one
        # volume, so switching to the queried volume needs batch support there.
        volume = rearrange(self.plane, 'i h w c -> 1 i c h w')
        output = self.render(volume, target_cameras, time_embeddings)
        return output
        