from typing import List
from jaxtyping import Float
import os
import torch
from torch import nn
from einops import rearrange, repeat, reduce
import torch.nn.functional as F
from torch import Tensor
from src.utils.project import ray_sample, get_ray_limits_box

class TriplaneNERF(nn.Module):
    """Triplane NeRF: refine a learnable triplane from input images, then volume-render it.

    A learnable triplane of shape ``(3, plane_size, plane_size, plane_dim)`` is
    turned into per-sample plane features by ``query_transformer`` (which is given
    the plane tokens and the flattened image tokens), optionally post-processed by
    ``resnet``, and rendered from target cameras: points sampled along camera rays
    are projected onto the xy / yz / xz planes, features are bilinearly sampled,
    and the decoders turn them into densities and colours that are alpha-composited.

    Args:
        plane_size: Spatial side length of each of the three feature planes.
        plane_dim: Feature channels of the learnable planes.
        resolution: Rendered image side length (``resolution**2`` rays per view).
        N_samples: Number of samples per ray.
        query_transformer: Called as ``query_transformer(plane, images, time_embeddings)``;
            its result must expose a ``.sample`` attribute.
        resnet: Optional module called as ``resnet(volume, time_embeddings)``.
        init_scale: Multiplier on the standard-normal triplane initialisation.
        reduce_method: How the three per-plane features are combined:
            ``'concat'`` (channel concatenation) or ``'mean'``.
        implicit_network: Must provide
            ``get_outputs(xyzs, features) -> (sdf, feature_vectors, normals)``.
        rendering_network: Called as ``rendering_network(ray_dirs, feature_vectors)``
            to produce per-sample colours.
        density_network: Maps the implicit network's first output to densities.
        perturb: If True, jitter sample depths along each ray.
        density_scale: Multiplier applied to densities before alpha compositing.
        bg_color: Stored on the instance; NOTE: not currently applied when rendering.
    """

    def __init__(
        self,
        plane_size,
        plane_dim,
        resolution,
        N_samples,
        query_transformer = None,
        resnet = None,
        init_scale=0.001,
        reduce_method='concat',
        implicit_network=None,
        rendering_network=None,
        density_network=None,
        perturb = False,
        density_scale = 1.,
        bg_color = "white"
    ):
        super().__init__()

        self.plane_size = plane_size
        self.plane_dim = plane_dim
        # Learnable triplane, (3, plane_size, plane_size, plane_dim), small random init.
        self.plane = nn.Parameter(torch.randn(3, plane_size, plane_size, plane_dim)*init_scale, requires_grad=True)
        self.resolution = resolution
        self.N_samples = N_samples
        self.query_transformer = query_transformer
        self.resnet = resnet
        self.reduce_method = reduce_method

        self.implicit_network = implicit_network
        self.rendering_network = rendering_network
        self.density_network = density_network
        self.perturb = perturb
        self.density_scale = density_scale
        self.bg_color = bg_color
        # TODO: more efficient sampling, e.g. sampling guided by a previous stage's density.

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward an xformers attention toggle to every submodule that supports it.

        Walks all descendants of this module (not ``self`` itself) and calls
        ``set_use_memory_efficient_attention_xformers(valid, attention_op)`` on each
        one that defines that method.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # children() only yields nn.Module instances, so no type check is needed.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def query_triplane(self,
        images: Float[Tensor, "B F C H W"],
        time_embeddings: Float[Tensor, "B F D"] = None,
    ):
        """Build a per-batch-element triplane by querying the learnable planes against ``images``.

        Args:
            images: Input views ``(B, F, C, H, W)``; flattened into ``F*H*W`` tokens of size ``C``.
            time_embeddings: Optional conditioning forwarded to ``query_transformer``
                and, if present, ``resnet``.

        Returns:
            Triplane features of shape ``(B, 3, D, plane_size, plane_size)``, where ``D``
            is the channel count produced by ``query_transformer`` / ``resnet``.
        """
        b = images.shape[0]
        # The shared learnable planes become one token sequence per batch element.
        plane = repeat(self.plane, 'n h w d -> b (n h w) d', b=b)
        image_tokens = rearrange(images, 'b f c h w -> b (f h w) c')
        volume = self.query_transformer(plane, image_tokens, time_embeddings).sample  # b (n v v) d
        volume = rearrange(volume, 'b (n h w) d -> (b n) d h w', n=3, h=self.plane_size, w=self.plane_size)
        if self.resnet is not None:
            volume = self.resnet(volume, time_embeddings)
        volume = rearrange(volume, '(b n) d h w -> b n d h w', n=3, h=self.plane_size, w=self.plane_size)

        return volume

    def volume_rendering(self,
        density: Float[Tensor, "B ... S 1"],
        z_vals: Float[Tensor, "B ... S 1"],
        rgbs: Float[Tensor, "B ... S 3"],
        deltas: Float[Tensor, "B ... S 1"],
        ):
        """Alpha-composite per-sample colours and depths along each ray.

        Args:
            density: Raw per-sample values (the implicit network's first output);
                converted to densities by ``density_network``.
            z_vals: Sample depths along each ray.
            rgbs: Per-sample colours.
            deltas: Distance between consecutive samples; may have a size-1 sample
                axis, in which case it broadcasts over all samples.

        Returns:
            Dict with ``images`` (composited colour), ``depths`` (weight-normalised
            depth) and ``acc`` (accumulated opacity), each with the sample axis reduced.
        """
        sigma = self.density_network(density)
        # alpha_i = 1 - exp(-delta_i * scale * sigma_i); the trailing size-1 axis is dropped -> (..., S)
        alphas = 1 - torch.exp(-deltas * self.density_scale * sigma).squeeze(-1)

        # Transmittance T_i = prod_{j<i} (1 - alpha_j); the tiny epsilon avoids an exactly-zero factor.
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)
        weights = (alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]).unsqueeze(-1)  # (..., S, 1)
        # TODO: maybe mask negligible weights (e.g. weights > 1e-4) to skip work.
        acc = weights.sum(dim=-2)  # (..., 1)
        depths = torch.sum(weights * z_vals, dim=-2) / (acc + 1e-8)  # (..., 1)
        images = torch.sum(weights * rgbs, dim=-2)  # (..., C) with C = rgbs.shape[-1]

        # NOTE: self.bg_color is not applied; compositing onto white would be
        # `images + (1 - acc)`.

        return {
            "images": images,
            "depths": depths,
            "acc": acc
        }

    def render(self,
        volume: Float[Tensor, "B N D V V"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Volume-render the triplane ``volume`` from every target camera.

        Args:
            volume: Triplane features ``(B, 3, C, V, V)``. A volume with batch size 1
                is shared across all ``B`` camera batches.
            target_cameras: Per-view cameras ``(B, F, 32)``; values ``[:16]`` are a
                flattened 4x4 intrinsics matrix, values ``[16:]`` a flattened 4x4
                camera-to-world matrix.
            time_embeddings: Currently unused (timestep support is a TODO).

        Returns:
            Dict with ``images``, ``depths`` and ``acc``, each shaped ``(B, F, C, R, R)``
            with ``R = self.resolution``.

        Raises:
            ValueError: If ``reduce_method`` is not ``'concat'``/``'mean'``, or the
                volume batch size is neither 1 nor ``B``.
        """
        b, f, _ = target_cameras.shape
        num_rays = self.resolution ** 2
        num_steps = self.N_samples

        intrinsics = target_cameras[..., :16].reshape(-1, 4, 4)
        c2w = target_cameras[..., 16:].reshape(-1, 4, 4)
        ray_origins, ray_dirs = ray_sample(c2w, intrinsics[:, :3, :3], self.resolution)
        ray_origins = rearrange(ray_origins, "(b f) n c -> b f n c", f=f)
        ray_dirs = rearrange(ray_dirs, "(b f) n c -> b f n c", f=f)
        device = ray_origins.device

        # Ray/box intersection. The `2` is presumably the box side length, i.e. the
        # [-1, 1] cube that grid_sample addresses; confirm in get_ray_limits_box.
        nears, fars = get_ray_limits_box(
            ray_origins, ray_dirs, 2
        )
        # A ray has a usable interval only when it exits after it enters.
        is_ray_valid = fars > nears
        if torch.any(is_ray_valid).item():
            # Give rays without a usable interval a finite placeholder range derived
            # from the valid rays so the sampling below stays well-defined.
            nears[~is_ray_valid] = nears[is_ray_valid].min()
            fars[~is_ray_valid] = nears[is_ray_valid].max()

        # Evenly spaced depths in [near, far] for every ray: (b, f, n, s, 1).
        t = torch.linspace(0.0, 1.0, num_steps, device=device)
        t = repeat(t, 's -> b f 1 s 1', b=b, f=f)
        nears = rearrange(nears, 'b f n 1 -> b f n 1 1')
        fars = rearrange(fars, 'b f n 1 -> b f n 1 1')
        z_vals = nears + (fars - nears) * t

        # Per-ray step size (b, f, n, 1, 1); also used as the compositing delta.
        sample_dist = (fars - nears) / num_steps
        if self.perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

        xyzs = ray_origins.unsqueeze(-2) + ray_dirs.unsqueeze(-2) * z_vals  # b f n s 3

        # Project every sample onto the xy, yz and xz planes -> grid_sample coords ((b f 3), n, s, 2).
        coords = torch.stack([xyzs[..., [0, 1]], xyzs[..., [1, 2]], xyzs[..., [0, 2]]], dim=-1)
        coords = rearrange(coords, 'b f n s c i -> (b f i) n s c')

        # Lay out the planes in the same (b f i) order as coords; each batch element
        # samples its own triplane, and a batch-1 volume is shared by all of them.
        if volume.shape[0] not in (1, b):
            raise ValueError(
                f"volume batch size {volume.shape[0]} must be 1 or match cameras batch size {b}"
            )
        volume = volume.expand(b, -1, -1, -1, -1)
        volume = repeat(volume, "b i c h w -> (b f i) c h w", f=f)
        sampled_features = F.grid_sample(volume, coords, align_corners=True, mode='bilinear', padding_mode='zeros')  # (b f i) c n s

        if self.reduce_method == 'concat':
            sampled_features = rearrange(sampled_features, '(bf i) c n s -> (bf n) s (i c)', n=num_rays, s=num_steps, i=3)
        elif self.reduce_method == 'mean':
            sampled_features = reduce(sampled_features, '(bf i) c n s -> (bf n) s c', 'mean', i=3)
        else:
            raise ValueError(
                f"Unknown reduce_method {self.reduce_method!r}; expected 'concat' or 'mean'"
            )
        sampled_features = rearrange(sampled_features, "(b f n) s c -> b f n s c", b=b, f=f, n=num_rays)

        sdf, feature_vectors, _ = self.implicit_network.get_outputs(xyzs, sampled_features)  # TODO: add timesteps support
        ray_dirs = repeat(ray_dirs, "b f n c -> b f n s c", s=num_steps)
        rgb = self.rendering_network(ray_dirs, feature_vectors)  # TODO: add timesteps support
        outputs = self.volume_rendering(sdf, z_vals, rgb, sample_dist)

        # Fold the flat ray axis back into an image: (b, f, R*R, c) -> (b, f, c, R, R).
        return {
            k: rearrange(v, "b f (h w) c -> b f c h w", h=self.resolution, w=self.resolution)
            for k, v in outputs.items()
        }

    def forward(self,
        images: Float[Tensor, "B F C H W"],
        target_cameras: Float[Tensor, "B F 32"],
        time_embeddings: Float[Tensor, "B F D"] = None
    ):
        """Query a triplane from ``images`` and render it from ``target_cameras``.

        Returns:
            The dict produced by ``render``, with keys ``images``, ``depths`` and ``acc``.
        """
        volume = self.query_triplane(images, time_embeddings)
        return self.render(volume, target_cameras, time_embeddings)
        