import math
from einops.einops import rearrange
import torch
from torch import nn

class PositionalEncoding3D(nn.Module):
        """Learned 3D positional encoding of camera-frustum points.

        A [W, H, D] grid of (pixel, depth) samples is precomputed for ``max_shape``.
        At forward time the grid is cropped to the coarse feature size, rescaled to
        image pixels, unprojected with the inverse camera intrinsics (optionally
        transformed by ``K_rel``), normalised by ``position_range``, passed through
        an inverse sigmoid and embedded by a two-layer 1x1-conv MLP.

        Config keys:
                max_shape: (H, W) of the largest coarse feature map supported.
                depth_num: number of depth samples D per pixel.
                depth_min, depth_max: control the placement of the depth samples
                        (see the formula in ``__init__``).
                embed_dims: number of output channels.
                position_range: [x_min, y_min, z_min, x_max, y_max, z_max] used to
                        map the unprojected points to [0, 1] per axis.
        """

        def __init__(self, config):
                super().__init__()
                # Hyper-parameters taken from the config dict.
                self.max_shape = config['max_shape']
                H, W = self.max_shape
                self.depth_num = config['depth_num']
                self.depth_max = config['depth_max']
                self.depth_min = config['depth_min']
                self.position_dim = self.depth_num * 3  # (x, y, z) for every depth sample
                self.embed_dims = config['embed_dims']
                self.position_range = config['position_range']
                self.eps = 1e-5

                # 1x1-conv MLP mapping the 3 * D coordinate channels to embed_dims channels.
                self.position_encoder = nn.Sequential(
                        nn.Conv2d(self.position_dim, self.embed_dims * 3, kernel_size=1, stride=1, padding=0),
                        nn.ReLU(),
                        nn.Conv2d(self.embed_dims * 3, self.embed_dims, kernel_size=1, stride=1, padding=0),
                )

                # Grid indices in coarse-feature units; forward() rescales them to image pixels.
                coords_h = torch.arange(H).float()
                coords_w = torch.arange(W).float()

                # Depth samples d_i = depth_min + (depth_max - depth_min) * i * (i + 1) / (D * (D + 1))
                # for i = 0..D-1, so the gap between consecutive samples grows linearly with i.
                index = torch.arange(self.depth_num).float()
                bin_size = (self.depth_max - self.depth_min) / (self.depth_num * (1 + self.depth_num))
                coords_d = self.depth_min + bin_size * index * (index + 1)

                # indexing='ij' keeps the [W, H, D] axis order explicitly (torch warns when
                # the argument is omitted).
                coords = torch.stack(torch.meshgrid(coords_w, coords_h, coords_d, indexing='ij'))
                coords = rearrange(coords, 'c w h d -> w h d c')  # W H D 3
                # Build [u * d, v * d, d] so that K^-1 @ p unprojects the pixel at depth d;
                # u and v are multiplied by max(d, eps), the depth channel is kept as is.
                depth = coords[..., 2:3].clamp(min=self.eps)
                coords = torch.cat([coords[..., :2] * depth, coords[..., 2:3]], dim=-1)
                self.register_buffer('coords', coords.unsqueeze(0), persistent=False)  # 1 W H D 3
                self._reset_parameters()

        def _reset_parameters(self):
                """Xavier-uniform init for every weight tensor of the MLP (biases untouched)."""
                for p in self.position_encoder.parameters():
                        if p.dim() > 1:
                                nn.init.xavier_uniform_(p)

        def forward(self, hw_in, hw_c, K_cam, K_rel=None):
                """Compute the 3D position embedding for a coarse feature map.

                Args:
                        hw_in: (height, width) of the input image; coarse grid indices are
                                rescaled by hw_in / hw_c to image-pixel units.
                        hw_c: (height, width) of the coarse feature map. Must not exceed the
                                grid precomputed from ``max_shape``.
                        K_cam: camera intrinsics of shape [N, 3, 3]; their inverse unprojects
                                the grid points.
                        K_rel: optional [N, >=3, >=4] tensor; ``K_rel[:, :3, :3]`` is applied to
                                the unprojected points and ``K_rel[:, :3, 3]`` added to them
                                (presumably a relative camera pose — confirm with callers).

                Returns:
                        pos (Tensor): position embedding with shape [N, embed_dims, h_c, w_c].

                Raises:
                        ValueError: if ``hw_c`` is larger than the precomputed grid.
                """
                w_max, h_max = self.coords.shape[1], self.coords.shape[2]
                h_c, w_c = hw_c[0], hw_c[1]
                if h_c > h_max or w_c > w_max:
                        raise ValueError(
                                f"hw_c=({h_c}, {w_c}) exceeds the precomputed grid "
                                f"(H={h_max}, W={w_max}); increase config['max_shape']."
                        )

                # Work in at least float32: torch's matrix inverse does not support float16,
                # and float64 intrinsics keep their precision.
                work_dtype = torch.promote_types(
                        torch.promote_types(self.coords.dtype, K_cam.dtype), torch.float32)

                # Crop the 1 W H D 3 grid to the coarse size and rescale (u*d, v*d) from
                # coarse-grid to image-pixel units; the depth channel is left unchanged.
                coords = self.coords[:, :w_c, :h_c].to(work_dtype)
                _, W, H, _, _ = coords.shape
                scale_num = coords.new_tensor([hw_in[1], hw_in[0], 1.0])
                scale_den = coords.new_tensor([w_c, h_c, 1.0])
                coords = coords * scale_num / scale_den
                points = rearrange(coords, 'n w h d c -> n c (w h d)')  # 1 3 L

                # Unproject; batched matmul broadcasts the single grid over all N cameras
                # instead of materialising N copies of it.
                kpts_cam = K_cam.to(work_dtype).inverse() @ points  # N 3 L
                if K_rel is not None:
                        K_rel = K_rel.to(work_dtype)
                        kpts_cam = K_rel[:, :3, :3] @ kpts_cam + K_rel[:, :3, 3:4]  # N 3 L

                # Normalise x, y, z per axis with position_range = [min xyz..., max xyz...].
                pos_range = torch.as_tensor(self.position_range, dtype=work_dtype, device=kpts_cam.device)
                lo = pos_range[:3].view(3, 1)
                hi = pos_range[3:6].view(3, 1)
                kpts_cam = (kpts_cam - lo) / (hi - lo)

                # N 3 (W H D) -> N (D 3) H W: one channel per (depth sample, axis) pair.
                kpts_cam = rearrange(kpts_cam, 'n c (w h d) -> n (d c) h w', w=W, h=H)
                # Inverse sigmoid; eps clamps the input to [eps, 1 - eps] so the result is finite.
                kpts_cam = torch.logit(kpts_cam.float(), eps=1e-5)
                return self.position_encoder(kpts_cam)