import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from einops import rearrange
import math
from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection
import nerfacc
from src.utils.project import ray_sample
import numpy as np

class RenderModule(nn.Module):
    """Decode a 3D feature volume into density/colour grids and volume-render them.

    Two convolutional heads turn the input feature volume into a density grid
    and a ``render_feat_dim``-channel colour grid (each upsampled 2x by a
    transposed convolution).  Rays produced by ``ray_sample`` are marched
    through the cube ``[-1, 1]^3`` with a fully occupied nerfacc occupancy
    grid, the grids are sampled trilinearly at the sample midpoints and the
    result is composited along each ray.

    Args:
        resolution: Used both as the occupancy-grid resolution and as the
            side length (in pixels) of the rendered images.
        n_samples: Sets the marching step to ``2 * sqrt(3) / n_samples``
            (the cube diagonal divided by ``n_samples``).
        bg_color: Background blended behind the rendered colour; ``"white"``
            or ``"black"``.
    """

    def __init__(self, resolution: int = 64, n_samples: int = 64, bg_color: str = "white") -> None:
        super().__init__()
        self.render_feat_dim = 3

        # Density head: 128-channel volume -> 1-channel grid at twice the
        # spatial size.  ReLU followed by Softplus(beta=100) keeps the output
        # strictly positive (softplus(0) = ln(2) / 100).
        self.density_head = nn.Sequential(
            nn.ConvTranspose3d(128, 32, 4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(32, 8, 3, padding=1),
            nn.BatchNorm3d(8),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(8, 1, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Softplus(beta=100),
        )
        # Colour head: same upsampling, ``render_feat_dim`` channels squashed
        # into (0, 1) by the final sigmoid.
        self.features_head = nn.Sequential(
            nn.ConvTranspose3d(128, 32, 4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(32, self.render_feat_dim, 3, padding=1),
            nn.BatchNorm3d(self.render_feat_dim),
            nn.LeakyReLU(inplace=True),
            nn.Sigmoid(),
        )
        # NOTE(review): not referenced by forward() in this class; kept so the
        # registered parameters (and therefore state_dict keys) stay the same.
        self.upsample = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose3d(self.render_feat_dim, self.render_feat_dim, 4, stride=2, padding=1),
                    nn.BatchNorm3d(self.render_feat_dim),
                    nn.LeakyReLU(inplace=True),
                )
                for _ in range(2)
            ]
        )

        self.bg_color = bg_color
        self.aabb_box = torch.as_tensor([-1, -1, -1, 1, 1, 1], dtype=torch.float32)
        self.estimator = nerfacc.OccGridEstimator(
            resolution=resolution,
            roi_aabb=self.aabb_box,
            levels=1,
        )
        # Mark every cell as occupied so ray marching never skips space.
        self.estimator.occs.fill_(True)
        self.estimator.binaries.fill_(True)
        self.N_samples = n_samples
        self.render_step_size = np.sqrt(3) * 2 / self.N_samples
        self.resolution = resolution

    def _background_like(self, reference: torch.Tensor) -> torch.Tensor:
        """Return a background tensor shaped like ``reference`` for ``self.bg_color``.

        Raises:
            ValueError: If ``self.bg_color`` is neither ``"white"`` nor ``"black"``.
        """
        if self.bg_color == "white":
            return torch.ones_like(reference)
        if self.bg_color == "black":
            return torch.zeros_like(reference)
        raise ValueError(
            f"Unsupported bg_color {self.bg_color!r}; expected 'white' or 'black'."
        )

    @staticmethod
    def _sample_volume(volume: torch.Tensor, xyzs: torch.Tensor, volume_ids=None) -> torch.Tensor:
        """Trilinearly sample ``volume`` at ``xyzs`` and return ``[N, C]`` values.

        Args:
            volume: Grid of shape ``[V, C, D, H, W]``.
            xyzs: Sample positions of shape ``[N, 3]`` in ``grid_sample``'s
                normalised ``[-1, 1]`` coordinates.
            volume_ids: Optional integer tensor of shape ``[N]`` giving, for
                every sample, the index along ``V`` to read from.  When
                ``None`` the grid must hold a single volume (``V == 1``).
                Every id is expected to lie in ``[0, V)``.
        """

        def sample_one(single: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
            grid = points.reshape(1, 1, 1, -1, 3)
            sampled = F.grid_sample(
                single, grid, align_corners=True, mode='bilinear', padding_mode='zeros'
            )  # [1, C, 1, 1, n]
            return sampled[0, :, 0, 0].transpose(0, 1)  # [n, C]

        if volume_ids is None:
            if volume.shape[0] != 1:
                raise ValueError(
                    "volume_ids is required when sampling more than one volume."
                )
            return sample_one(volume, xyzs)

        chunks, owners = [], []
        for vol_idx in range(volume.shape[0]):
            owned = torch.nonzero(volume_ids == vol_idx, as_tuple=True)[0]
            if owned.numel() == 0:
                continue
            chunks.append(sample_one(volume[vol_idx:vol_idx + 1], xyzs[owned]))
            owners.append(owned)
        if not chunks:
            return volume.new_zeros((0, volume.shape[1]))
        # Undo the per-volume grouping so rows line up with ``xyzs`` again.
        restore = torch.argsort(torch.cat(owners))
        return torch.cat(chunks, dim=0)[restore]

    def volume_rendering(self, density, rgbs, t_starts, t_ends, t_positions, ray_indices, n_rays):
        """Composite per-sample density and colour along each ray.

        Args:
            density: ``[N, 1]`` densities at the samples.
            rgbs: ``[N, C]`` colours at the samples.
            t_starts: ``[N, 1]`` start distance of each sample interval.
            t_ends: ``[N, 1]`` end distance of each sample interval.
            t_positions: ``[N, 1]`` distance accumulated into the depth output.
            ray_indices: ``[N]`` index of the ray each sample belongs to.
            n_rays: Total number of rays.

        Returns:
            Dict with ``"rgb"`` (colour blended over the background),
            ``"depth"`` (weight-accumulated ``t_positions``) and ``"mask"``
            (accumulated weights), each with ``n_rays`` rows.

        Raises:
            ValueError: If ``self.bg_color`` is not a supported colour name.
        """
        weights, _, _ = nerfacc.render_weight_from_density(
            t_starts[..., 0],
            t_ends[..., 0],
            density[..., 0],
            ray_indices=ray_indices,
            n_rays=n_rays,
        )

        acc = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )
        depth = nerfacc.accumulate_along_rays(
            weights, values=t_positions, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        # Fill whatever opacity the ray did not accumulate with the background.
        color = color + (1.0 - acc) * self._background_like(color)

        return {
            "rgb": color,
            "depth": depth,
            "mask": acc,
        }

    def forward(self, features, target_cameras):
        '''
        Render the feature volume(s) from the given cameras.

        features: feature volume in shape [b,128,D,H,W]
        target_cameras: in shape [b,f,32]; the first 16 values of the last
            axis are reshaped to a 4x4 intrinsics matrix (its upper-left 3x3
            block is handed to ``ray_sample``) and the remaining 16 to a 4x4
            matrix named ``c2w`` (presumably camera-to-world -- confirm
            against ``ray_sample``).

        Returns a dict with "rgb", "depth" and "mask", each rearranged to
        [b,f,c,resolution,resolution].

        ``features`` may hold either one volume shared by every camera batch
        or exactly one volume per camera batch.
        '''
        n_volumes = features.shape[0]
        n_cam_batches, f, _ = target_cameras.shape
        if n_volumes not in (1, n_cam_batches):
            raise ValueError(
                f"features batch size ({n_volumes}) must be 1 or match the "
                f"camera batch size ({n_cam_batches})."
            )

        # get neural volume for NeRF rendering
        densities = self.density_head(features).clip(min=0.0, max=1.0 - 1e-5)  # [b,1,D2,H2,W2]
        colors = self.features_head(features)  # [b,render_feat_dim,D2,H2,W2]

        intrinsics = target_cameras[..., :16].reshape(-1, 4, 4)
        c2w = target_cameras[..., 16:].reshape(-1, 4, 4)
        ray_origins, ray_dirs = ray_sample(c2w, intrinsics[:, :3, :3], self.resolution)
        ray_origins_flattened = ray_origins.reshape(-1, 3)
        ray_dirs_flattened = ray_dirs.reshape(-1, 3)
        n_rays = ray_origins_flattened.shape[0]
        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                ray_origins_flattened, ray_dirs_flattened,
                render_step_size=self.render_step_size,
            )

        t_starts, t_ends = t_starts[..., None], t_ends[..., None]
        t_origins = ray_origins_flattened[ray_indices]
        t_dirs = ray_dirs_flattened[ray_indices]
        t_positions = (t_starts + t_ends) / 2.0

        xyzs = t_origins + t_dirs * t_positions  # [N, 3]

        # Rays are flattened batch-major (see the "(b f h w)" rearrange
        # below), so the owning volume of a sample follows from its ray index.
        volume_ids = None
        if n_volumes > 1:
            rays_per_volume = f * self.resolution * self.resolution
            volume_ids = torch.div(ray_indices, rays_per_volume, rounding_mode="floor")

        density = self._sample_volume(densities, xyzs, volume_ids)  # [N, 1]
        rgb = self._sample_volume(colors, xyzs, volume_ids)  # [N, render_feat_dim]

        outputs = self.volume_rendering(density, rgb, t_starts, t_ends, t_positions, ray_indices, n_rays)
        for k, v in outputs.items():
            outputs[k] = rearrange(v, "(b f h w) c -> b f c h w", h=self.resolution, w=self.resolution, f=f)

        return outputs




        


