#https://github.com/DaLi-Jack/SSR-code/blob/main/ssr/model/mlp.py
import torch
from torch import nn
import numpy as np

from einops import rearrange, repeat, reduce
import torch.nn.functional as F
from src.utils.typing import *

class HarmonicEmbedding(torch.nn.Module):
    """Sinusoidal (harmonic) encoding of the last dimension of a tensor.

    Every input channel is multiplied by a fixed bank of frequencies and fed
    through ``sin`` and ``cos``. The raw input may optionally be concatenated
    after the sinusoidal features.
    """

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Args:
            n_harmonic_functions: number of frequencies used per input channel
            omega_0: factor every frequency is multiplied by
            logspace: if True the frequencies are ``2**0 ... 2**(n-1)``;
                otherwise they are linearly spaced between ``1`` and ``2**(n-1)``
            append_input: whether `forward` appends the raw input after the
                sine / cosine features
        """
        super().__init__()

        if not logspace:
            base_frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            exponents = torch.arange(n_harmonic_functions, dtype=torch.float32)
            base_frequencies = 2.0 ** exponents

        # Non-persistent buffer: follows .to(device) but is not written to the state dict.
        self.register_buffer("_frequencies", base_frequencies * omega_0, persistent=False)
        self.append_input = append_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: the harmonic embedding of `x`, laid out as
                ``[sin features, cos features, x (only if append_input)]``,
                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_freq] flattened so the frequencies of one channel are contiguous.
        scaled = x[..., None] * self._frequencies
        phases = scaled.reshape(*x.shape[:-1], -1)

        pieces = [phases.sin(), phases.cos()]
        if self.append_input:
            pieces.append(x)
        return torch.cat(pieces, dim=-1)

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Predict the size of the last dimension produced by `forward`.
        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether the original input is concatenated
                to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        features_per_channel = 2 * n_harmonic_functions + int(append_input)
        return input_dims * features_per_channel

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Instance version of `get_output_dim_static` that uses this module's
        own frequency count and `append_input` flag. `input_dims` defaults
        to 3, the size of an xyz input.
        """
        n_frequencies = len(self._frequencies)
        return self.get_output_dim_static(input_dims, n_frequencies, self.append_input)
    
class ImplicitNetwork(nn.Module):
    """SDF MLP conditioned on a per-point latent feature.

    The query point is harmonic-embedded, concatenated with the latent
    feature and passed through a stack of (optionally weight-normalised)
    linear layers with ``Softplus(beta=100)`` between them. The last layer has
    ``d_out + feature_vector_size`` outputs; `get_outputs` reads channel 0 as
    the SDF value and the remaining channels as a feature vector.
    """

    def __init__(
            self,
            feature_vector_size,
            d_in,
            d_out,
            dims,
            geometric_init=True,
            bias=1.0,
            skip_in=(),
            weight_norm=True,
            sphere_scale=1.0,
            inside_outside=False,
            multires=6,
    ):
        """
        Args:
            feature_vector_size: size of the last dimension of the latent
                feature given to `forward`; the same number of channels is
                appended to the network output
            d_in: size of the last dimension of the query points
            d_out: number of output channels before the feature vector
            dims: widths of the hidden layers (any sequence of ints)
            geometric_init: apply the hand-crafted initialisation below
                instead of the default ``nn.Linear`` initialisation
            bias: magnitude of the last-layer bias used by the geometric
                initialisation (its sign depends on `inside_outside`)
            skip_in: indices of the layers whose input gets the embedded
                network input added to it
            weight_norm: wrap every linear layer in ``nn.utils.weight_norm``
            sphere_scale: stored on the module; not used by this class
            inside_outside: flips the sign of the last-layer initialisation
            multires: number of harmonic frequencies for the point embedding
        """
        super().__init__()

        self.sphere_scale = sphere_scale
        # list(...) so tuples / config sequences are accepted as well as lists.
        dims = [d_in] + list(dims) + [d_out + feature_vector_size]

        self.embed_fn = HarmonicEmbedding(multires, append_input=True)
        input_ch = HarmonicEmbedding.get_output_dim_static(d_in, multires, True)
        # use cat architecture
        dims[0] = input_ch + feature_vector_size
        # `forward` ADDS the network input at every skip layer, so the input of
        # each of those layers must be exactly as wide as the network input.
        for skip_dim in skip_in:
            dims[skip_dim] = dims[0]
        # Printed after the skip adjustment so it shows the real layer widths.
        print("implicit network architecture:", dims)

        self.num_layers = len(dims)
        self.skip_in = skip_in

        # HarmonicEmbedding lays its output out as [sin..., cos..., raw input],
        # so inside the network input the raw coordinates sit right after the
        # 2 * multires * d_in sinusoidal channels (the latent feature follows).
        xyz_start = 2 * multires * d_in
        xyz_stop = xyz_start + d_in

        for l in range(0, self.num_layers - 1):

            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    # Last layer: near-constant weights and a bias of -/+ `bias`.
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)

                elif multires > 0 and l == 0:
                    # First layer: only the raw coordinates get non-zero weights;
                    # the sinusoidal channels and the latent feature start at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight, 0.0)
                    torch.nn.init.normal_(lin.weight[:, xyz_start:xyz_stop], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # NOTE(review): this zeroes every input column except the first
                    # three. It looks inherited from a concatenating skip connection
                    # with the raw coordinates first; with the additive skip used in
                    # `forward` the intended pattern is unclear - confirm before changing.
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.softplus = nn.Softplus(beta=100)

    # x y z
    def forward(self,
        input: Float[Tensor, "B H W S 3"],
        latent_feature: Float[Tensor, "B H W S latent_size"],
        ):
        """Run the MLP.

        Args:
            input: query points; embedded with `self.embed_fn`
            latent_feature: per-point latent, concatenated after the embedding
        Returns:
            tensor whose last dimension is ``d_out + feature_vector_size``
        """
        input = self.embed_fn(input)
        x = torch.cat([input, latent_feature], dim=-1)
        # Captured before the loop so it is always bound when a skip layer needs it.
        skip_feature = x

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                x = x + skip_feature
            x = lin(x)
            # No activation after the final layer.
            if l < self.num_layers - 2:
                x = self.softplus(x)
        return x

    # def gradient(self, x, latent_feature):
    #     x.requires_grad_(True)
    #     y = self.forward(x, latent_feature)[:,:1]
    #     d_output = torch.ones_like(y, requires_grad=False, device=y.device)
    #     gradients = torch.autograd.grad(
    #         outputs=y,
    #         inputs=x,
    #         grad_outputs=d_output,
    #         create_graph=True,
    #         retain_graph=True,
    #         only_inputs=True)[0]
    #     return gradients

    def get_outputs(self, x, latent_feature):
        """Evaluate the network and the gradient of the SDF w.r.t. the points.

        Side effect: `x` is switched to ``requires_grad=True`` in place.
        Gradient tracking is force-enabled for the evaluation, so this also
        works when called inside ``torch.no_grad()``.

        Returns:
            sdf: output channel 0, shape [..., 1]
            feature_vectors: the remaining output channels
            gradients: d(sdf)/d(x), same shape as `x`, built with
                ``create_graph=True`` so it can itself be differentiated
        """
        x.requires_grad_(True)
        with torch.set_grad_enabled(True):
            output = self.forward(x, latent_feature)
            sdf = output[...,:1]
            # Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded
            # if self.sdf_bounding_sphere > 0.0:
            #     sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
            #     sdf = torch.minimum(sdf, sphere_sdf)
            # if self.sdf_bounding_box > 0.0:
            #     # compute the distance from point x to each face
            #     distance_to_faces = self.sdf_bounding_box - torch.abs(x)
            #     # take the largest distance as the SDF to the cube surface
            #     box_sdf = torch.min(distance_to_faces, dim=1, keepdim=True)[0]
            #     # use torch.minimum to compare the original SDF with the cube SDF and keep the smaller one
            #     sdf = torch.minimum(sdf, box_sdf)
            feature_vectors = output[..., 1:]
            d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            gradients = torch.autograd.grad(
                outputs=sdf,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]

        return sdf, feature_vectors, gradients

    # def get_sdf_vals(self, x, latent_feature):
    #     sdf = self.forward(x, latent_feature)[:,:1]
    #     ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
    #     if self.sdf_bounding_sphere > 0.0:
    #         sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
    #         sdf = torch.minimum(sdf, sphere_sdf)
    #     return sdf


class RenderingNetwork(nn.Module):
    """MLP that maps embedded view directions plus per-point features
    (and optionally normals) to `d_out` channels, with ReLU between layers
    and an optional final sigmoid.
    """

    def __init__(
            self,
            feature_vector_size,
            d_in,
            d_out,
            dims,
            weight_norm=True,
            multires_view=0,
            add_normals = False,
            squeeze_out=True
    ):
        """
        Args:
            feature_vector_size: width of the feature vectors given to `forward`
            d_in: size of the last dimension of the view directions
            d_out: number of output channels
            dims: list of hidden layer widths
            weight_norm: wrap every linear layer in ``nn.utils.weight_norm``
            multires_view: number of harmonic frequencies for the view embedding
            add_normals: reserve 3 extra input channels for normals
            squeeze_out: apply a sigmoid to the final output
        """
        super().__init__()

        self.embedview_fn = HarmonicEmbedding(multires_view, append_input=True)
        view_ch = HarmonicEmbedding.get_output_dim_static(d_in, multires_view, True)

        # First-layer width = features + embedded view dirs (+ 3 for normals).
        first_width = feature_vector_size + view_ch
        if add_normals:
            first_width += 3
        dims = [first_width] + dims + [d_out]

        print("rendering network architecture:")
        print(dims)

        self.num_layers = len(dims)

        for idx, (in_width, out_width) in enumerate(zip(dims[:-1], dims[1:])):
            layer = nn.Linear(in_width, out_width)
            if weight_norm:
                layer = nn.utils.weight_norm(layer)
            setattr(self, f"lin{idx}", layer)

        self.relu = nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.squeeze_out= squeeze_out

    def forward(self, view_dirs, feature_vectors, normals=None):
        """
        Args:
            view_dirs: view directions; embedded with `self.embedview_fn`
            feature_vectors: per-point features, concatenated after the embedding
            normals: optional tensor concatenated last when given
        Returns:
            network output, passed through a sigmoid when `squeeze_out` is set
        """
        pieces = [self.embedview_fn(view_dirs), feature_vectors]
        if normals is not None:
            pieces.append(normals)
        x = torch.cat(pieces, dim=-1)

        last_idx = self.num_layers - 2
        for idx in range(last_idx + 1):
            x = getattr(self, f"lin{idx}")(x)
            # ReLU after every layer except the final one.
            if idx != last_idx:
                x = self.relu(x)

        return self.sigmoid(x) if self.squeeze_out else x

class Triplane(nn.Module):
    """Learnable tri-plane feature grid.

    Three axis-aligned feature planes (xy, yz, xz) are stored channels-last in
    a single parameter of shape ``(3, plane_size, plane_size, plane_dim)``.
    """

    def __init__(self, plane_size, plane_dim, mode='concat', init_scale=0.001):
        """
        Args:
            plane_size: side length of each square plane
            plane_dim: number of feature channels per plane
            mode: how the three per-plane features are combined in
                `sample_features`: ``'concat'`` or ``'mean'``
            init_scale: factor applied to the standard-normal initial values
        """
        super().__init__()
        self.plane_size = plane_size
        self.plane = nn.Parameter(torch.randn(3, plane_size, plane_size, plane_dim)*init_scale, requires_grad=True)
        self.mode = mode

    def sample_features(self, xyzs, plane = None):
        """Bilinearly sample the three planes at the given points.

        Args:
            xyzs: points of shape ``(n, 3)``; the coordinates are passed to
                ``F.grid_sample`` unchanged with ``align_corners=True`` and
                zero padding
            plane: optional planes of shape ``3 x h x w x c`` to sample
                instead of the module's own parameter
        Returns:
            ``(n, 3 * c)`` for mode ``'concat'`` (xy, yz, xz features in that
            order) or ``(n, c)`` for mode ``'mean'``
        Raises:
            ValueError: if `self.mode` is neither ``'concat'`` nor ``'mean'``
        """
        # plane: 3 x h x w x c
        volume = self.plane if plane is None else plane
        # grid_sample expects channels-first input (N, C, H, W).
        volume = volume.permute(0, 3, 1, 2)

        xy, yz, xz = xyzs[..., [0, 1]], xyzs[..., [1, 2]], xyzs[..., [0, 2]]
        # One sampling grid per plane: 3 x 1 x n x 2.
        coords = torch.stack([xy, yz, xz], dim=0).unsqueeze(1)

        # 3 x c x 1 x n
        sampled_features = F.grid_sample(volume, coords, align_corners=True, mode='bilinear', padding_mode='zeros')
        sampled_features = sampled_features.squeeze(2)  # 3 x c x n
        n_planes, n_channels, n_points = sampled_features.shape

        if self.mode == 'concat':
            # n x (3 * c), plane-major
            return sampled_features.permute(2, 0, 1).reshape(n_points, n_planes * n_channels)
        if self.mode == 'mean':
            # average over the three planes -> n x c
            return sampled_features.mean(dim=0).transpose(0, 1)
        raise ValueError(f"Unknown Triplane mode {self.mode!r}; expected 'concat' or 'mean'")
        
class Volume(nn.Module):
    """Learnable dense 3D feature grid, stored channels-last as
    ``(volume_size, volume_size, volume_size, volume_dim)``.
    """

    def __init__(self, volume_size, volume_dim, init_scale = 0.001):
        """
        Args:
            volume_size: side length of the cubic grid
            volume_dim: number of feature channels per voxel
            init_scale: factor applied to the standard-normal initial values
        """
        super().__init__()
        self.volume_size = volume_size
        initial_values = torch.randn(volume_size, volume_size, volume_size, volume_dim) * init_scale
        self.volume = nn.Parameter(initial_values, requires_grad=True)

    def sample_features(self, xyzs, volume = None):
        """Sample the grid at the given points with ``F.grid_sample``.

        Args:
            xyzs: points of shape ``(n, 3)``, passed to ``grid_sample``
                unchanged (``align_corners=True``, zero padding)
            volume: optional grid to sample instead of the module's own
                parameter; a 4-D ``d h w c`` grid is converted to
                ``1 c d h w``, anything else is used as given
        Returns:
            sampled features of shape ``(n, c)``
        """
        grid = volume if volume is not None else self.volume
        if grid.ndim == 4:
            # channels-last grid -> batched channels-first layout for grid_sample
            grid = rearrange(grid, 'd h w c -> 1 c d h w')

        query = rearrange(xyzs, 'n c -> 1 1 1 n c')
        sampled = F.grid_sample(
            grid,
            query,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True,
        )
        return rearrange(sampled, '1 c 1 1 n -> n c')
        

if __name__ == "__main__":
    # Smoke test: push random points through the implicit network, then feed
    # the resulting feature vectors to the rendering network.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    implicit_network = ImplicitNetwork(
        feature_vector_size=64,
        d_in=3,
        d_out=1,
        dims=[ 256, 256, 256, 256, 256, 256, 256 ],
        geometric_init=True,
        skip_in=[4],
        weight_norm=True,
        multires=6
    ).to(device)
    rendering_network = RenderingNetwork(
        feature_vector_size=64,
        d_in=3,
        d_out=3,
        dims=[ 256, 256],
        weight_norm=True,
        multires_view=6
    ).to(device)
    ray_origins = torch.randn(1, 64, 64, 64, 3).to(device)
    ray_dirs = torch.randn(1, 64, 64, 64, 3).to(device)
    latent_feature = torch.randn(1, 64, 64, 64, 64).to(device)
    sdf, feature_vectors, gradients = implicit_network.get_outputs(ray_origins, latent_feature)
    print(sdf.shape, feature_vectors.shape, gradients.shape)
    # RenderingNetwork.forward takes (view_dirs, feature_vectors, normals=None).
    # The network above was built without add_normals, so its first layer is
    # sized for embedded view directions + feature vectors only.
    output = rendering_network(ray_dirs, feature_vectors)
    print(output.shape)
    
