import torch
from torch import nn
import itertools
import numpy as np
from typing import List, Tuple, Union, Optional, Dict
from torch import Tensor
from einops import rearrange, repeat


class ImplicitNetwork(nn.Module):
    """MLP that decodes embedded points plus a latent feature into an SDF and features.

    The network is a stack of ``nn.Linear`` layers (optionally weight-normalised)
    with ``Softplus(beta=100)`` between them. The last dimension of the output is
    ``d_out + feature_vector_size``; ``get_outputs`` treats channel 0 as the SDF
    and the remaining channels as a feature vector.

    Args:
        feature_vector_size: Width of the first layer's input and of the feature
            part of the output.
        d_in: Number of coordinates per input point (3 for xyz).
        d_out: Number of non-feature output channels.
        dims: Hidden layer widths.
        geometric_init: Apply the hand-crafted weight/bias initialisation below.
        bias: Magnitude of the last layer's constant bias under ``geometric_init``.
        skip_in: Indices of layers whose input gets the network input added to it.
            Each index must satisfy ``1 <= index <= len(dims) + 0`` (i.e. refer to
            a linear layer other than the first).
        weight_norm: Wrap every linear layer with ``nn.utils.weight_norm``.
        sphere_scale: Stored on the module; not used by the active code paths.
        inside_outside: Flips the sign of the last layer's init mean and bias.
        multires: Number of frequency bands used by ``embed_fn``; ``0`` disables
            the encoding.

    Raises:
        ValueError: If an entry of ``skip_in`` does not refer to a valid layer.

    NOTE(review): the first layer is sized to ``feature_vector_size`` while
    ``forward`` feeds it ``cat([embed_fn(input), latent_feature])``. The latent
    width therefore has to be ``feature_vector_size - embed_dim``; confirm this
    is the intended meaning of ``feature_vector_size``.
    """

    def __init__(
            self,
            feature_vector_size,
            d_in,
            d_out,
            dims,
            geometric_init=True,
            bias=1.0,
            skip_in=(),
            weight_norm=True,
            sphere_scale=1.0,
            inside_outside=False,
            multires=6,
    ):
        super().__init__()

        self.sphere_scale = sphere_scale
        self.d_in = d_in
        self.multires = multires
        # Width produced by embed_fn: raw coordinates plus sin/cos per band.
        self.embed_dim = d_in * (1 + 2 * multires) if multires > 0 else d_in

        # list(dims) avoids mutating the caller's list and accepts tuples.
        dims = [d_in] + list(dims) + [d_out + feature_vector_size]

        # use cat architecture
        dims[0] = feature_vector_size
        print("implicit network architecture:", dims)

        self.num_layers = len(dims)
        self.skip_in = skip_in

        # The skip connection is additive, so the activation entering every
        # skip layer must have the same width as the network input.
        for skip_dim in skip_in:
            if not 0 < skip_dim < self.num_layers - 1:
                raise ValueError(
                    f"skip_in entries must be in [1, {self.num_layers - 2}], "
                    f"got {skip_dim}"
                )
            dims[skip_dim] = dims[0]

        for l in range(0, self.num_layers - 1):

            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    # Last layer: near-constant weights and a constant bias whose
                    # signs depend on the inside/outside convention.
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)

                elif multires > 0 and l == 0:
                    # First layer: only the raw-coordinate columns start non-zero;
                    # the encoded/latent columns start at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, d_in:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :d_in], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - d_in):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.softplus = nn.Softplus(beta=100)

    def embed_fn(self, points: Tensor) -> Tensor:
        """Frequency-encode ``points`` along the last dimension.

        Returns ``points`` unchanged when ``multires <= 0``. Otherwise returns
        ``[points, sin(2^0 p), cos(2^0 p), ..., sin(2^(L-1) p), cos(2^(L-1) p)]``
        concatenated on the last axis, with ``L = multires``. The raw coordinates
        come first because the first-layer initialisation relies on that layout.
        """
        if self.multires <= 0:
            return points
        parts = [points]
        for band in range(self.multires):
            scaled = points * (2.0 ** band)
            parts.append(torch.sin(scaled))
            parts.append(torch.cos(scaled))
        return torch.cat(parts, dim=-1)

    # x y z
    def forward(self,
        input: Tensor,
        latent_feature: Tensor,
        ):
        """Run the MLP on points and their latent features.

        Args:
            input: Points of shape ``(..., d_in)``.
            latent_feature: Features of shape ``(..., latent_size)`` with the same
                leading dimensions as ``input``.

        Returns:
            Tensor of shape ``(..., d_out + feature_vector_size)``.

        Raises:
            ValueError: If the concatenated input width does not match the first
                layer's input width.
        """
        input = self.embed_fn(input)
        x = torch.cat([input, latent_feature], dim=-1)

        expected = self.lin0.in_features
        if x.shape[-1] != expected:
            raise ValueError(
                f"embedded input ({input.shape[-1]}) + latent feature "
                f"({latent_feature.shape[-1]}) must have {expected} channels in total"
            )

        # The network input is what gets added back at every skip layer.
        skip_feature = x
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                x = x + skip_feature
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.softplus(x)
        return x

    def get_outputs(self, x, latent_feature):
        """Return the SDF, the feature vector and the SDF gradient w.r.t. ``x``.

        ``x`` is switched to ``requires_grad=True`` in place. The gradient is
        built with ``create_graph=True`` so it can itself be differentiated.

        Returns:
            Tuple ``(sdf, feature_vectors, gradients)`` where ``sdf`` is output
            channel 0 (kept as a trailing dimension of size 1), ``feature_vectors``
            is the remaining channels and ``gradients`` has the shape of ``x``.
        """
        x.requires_grad_(True)
        # Force autograd on even when the caller is inside torch.no_grad().
        with torch.set_grad_enabled(True):
            output = self.forward(x, latent_feature)
            sdf = output[..., :1]
            # NOTE: clamping the SDF against a scene bounding sphere / box is
            # currently disabled; no such bound is stored on this module.
            feature_vectors = output[..., 1:]
            d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            gradients = torch.autograd.grad(
                outputs=sdf,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]

        return sdf, feature_vectors, gradients
    
class TriplaneSynthesizer(nn.Module):
    """Decode a triplane into per-point outputs on a regular 3D grid.

    ``forward`` samples the three feature planes at every point of a
    ``grid_size``^3 lattice spanning ``[-1, 1]``^3, concatenates the three
    sampled features per point and feeds them, together with the point
    coordinates, to an ``ImplicitNetwork``.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19

    NOTE(review): the implicit network sizes its first layer to
    ``feature_vector_size`` while it receives the embedded points concatenated
    with the ``3 * C`` plane features; confirm the channel counts line up.
    """

    def __init__(self, voxel_size: int, feature_vector_size: int):
        super().__init__()
        self.voxel_size = voxel_size
        # modules
        self.implicit_network = ImplicitNetwork(
            feature_vector_size=feature_vector_size,
            d_in=3,
            d_out=1,
            dims=[256, 256, 256, 256],
            skip_in=[],
            weight_norm=True,
            sphere_scale=1.0,
            inside_outside=False,
            multires=6,
        )

    def forward(self, planes, grid_size: int, chunk_size: int = 2**20):
        """Evaluate the implicit network on a dense grid.

        Args:
            planes: Triplane features of shape ``(B, 3, C, H, W)``.
            grid_size: Number of samples per axis; the grid has ``grid_size**3``
                points.
            chunk_size: Maximum number of grid points processed per network call.

        Returns:
            Dict with ``'sdf'`` of shape ``(B, grid_size**3, 1)`` (output channel
            0) and ``'feature_vectors'`` holding the remaining output channels.

        Raises:
            ValueError: If ``grid_size`` or ``chunk_size`` is smaller than 1.
        """
        if grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {grid_size}")
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        # planes: (N, 3, C, H, W)
        B = planes.shape[0]
        axis = torch.linspace(-1, 1, grid_size, device=planes.device)
        grid_points = torch.stack(
            torch.meshgrid(axis, axis, axis, indexing='ij'),
            dim=-1,
        ).reshape(-1, 3)  # (grid_size^3, 3)
        # Number of grid points (dimension 0; dimension 1 is the xyz axis).
        N = grid_points.shape[0]

        # repeat grid points for each batch element
        xyzs = repeat(grid_points, 'n d -> b n d', b=B)
        xy, yz, xz = xyzs[..., [0, 1]], xyzs[..., [1, 2]], xyzs[..., [0, 2]]
        coordinates = torch.stack([xy, yz, xz], dim=-1)  # (B, grid_size^3, 2, 3)
        # Fold the plane axis into the batch so one grid_sample call handles
        # all three planes; planes are folded the same way below.
        coordinates = rearrange(coordinates, 'b n c i -> (b i) n 1 c')
        planes = rearrange(planes, 'b i c h w -> (b i) c h w')

        # query triplane in chunks
        outs = []
        for start in range(0, N, chunk_size):
            stop = start + chunk_size
            chunk_coords = coordinates[:, start:stop]  # 2D sampling locations
            chunk_xyz = xyzs[:, start:stop]            # matching 3D points
            # query triplane
            query_features = nn.functional.grid_sample(
                planes, chunk_coords, align_corners=False, mode='bilinear', padding_mode='border')  # (b i) x c x chunk x 1
            # i=3 is required: einops cannot split the composite axis otherwise.
            query_features = rearrange(query_features, '(b i) c n 1 -> b n (i c)', i=3)  # b x chunk x 3C
            # The network consumes the 3D points, not the per-plane 2D coordinates.
            chunk_out = self.implicit_network(chunk_xyz, query_features)
            outs.append({
                'sdf': chunk_out[..., :1],
                'feature_vectors': chunk_out[..., 1:],
            })

        # concatenate the outputs along the point axis
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0]
        }
        # TODO: extracting the sdf == 0 level set is not implemented here.

        return point_features

