
class LearnedPooling(nn.Module):
    """Pool a depth map using weights predicted from a feature map.

    ``forward`` runs ``x`` through an upscaling network with one output
    channel, post-processes the result in ``pool_size x pool_size`` windows
    and returns the window-wise sum of ``depth * weights``.

    Args:
        pool_size: edge length of the square pooling window.
        in_channels: number of channels of the feature map ``x``.
        upscale_factor: factor handed to ``MemoryEfficientUpscaling``; any
            value below 1 (including the default ``-1``) means "use
            ``pool_size``".
    """

    def __init__(self, pool_size, in_channels, upscale_factor = -1):
        # The zero-argument form always refers to this class; the previous
        # call named an unrelated class (`SobelPooling`) and could not succeed.
        super().__init__()
        self.pool_size = pool_size

        if upscale_factor < 1:
            upscale_factor = pool_size

        self.layers = MemoryEfficientUpscaling(in_channels, 1, upscale_factor)

    def forward(self, x, depth):
        """Return ``depth`` pooled over ``pool_size x pool_size`` windows.

        Args:
            x: feature map fed to the upscaling network.
            depth: depth map; its height and width must be divisible by
                ``pool_size`` and must match the upscaled weights
                (presumably the same resolution as ``self.layers(x)`` -
                TODO confirm against callers).
        """
        weights = self.layers(x)
        weights = reduce(weights, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)

        # NOTE(review): at this point `weights` already holds the per-window
        # sums, so the line below divides every sum by itself (clamped from
        # below at 1e-8). The result is exactly 1 wherever the sum is >= 1e-8,
        # which makes the final reduction a plain sum pooling of `depth`.
        # Presumably the un-pooled weights were meant to be divided by their
        # window sum - confirm the intent before changing the math.
        weights = weights / th.maximum(weights, 1e-8 * th.ones_like(weights))
        weights = repeat(weights, 'b c h w -> b c (h h2) (w w2)', h2=self.pool_size, w2=self.pool_size)

        return reduce(depth * weights, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)

class Gaus3D(nn.Module):
    """Evaluate an axis-aligned, unnormalized 3D Gaussian on an image grid.

    The x/y axes are a normalized pixel grid in ``[-1, 1]``; the third axis
    is the value of a depth map at each pixel.

    Args:
        size: optional ``(H, W)`` of the grid. When given, the grid is built
            immediately and ``min_std`` becomes ``1 / min(size)``.
        position_limit: x and y centres are clipped to
            ``[-position_limit, position_limit]``.
    """

    def __init__(self, size = None, position_limit = 1):
        super(Gaus3D, self).__init__()
        self.size = size
        self.position_limit = position_limit
        self.min_std = 0.1
        self.max_std = 0.5

        # Non-persistent buffers: they follow the module across devices but
        # are not stored in the state dict.
        self.register_buffer("grid_x", th.zeros(1,1,1,1), persistent=False)
        self.register_buffer("grid_y", th.zeros(1,1,1,1), persistent=False)

        if size is not None:
            self.min_std = 1.0 / min(size)
            self.update_grid(size)

    def update_grid(self, size):
        """Rebuild the coordinate grids if ``size`` differs from the current one.

        Args:
            size: ``(H, W)`` sequence.
        """
        # Compare as plain tuples so a list-valued `size` is not treated as
        # "different" on every call.
        if tuple(size) == tuple(self.grid_x.shape[2:]):
            return

        self.size    = size
        self.min_std = 1.0 / min(size)
        H, W = size
        device = self.grid_x.device

        # Map pixel indices to [-1, 1]. max(..., 1) avoids 0/0 = NaN for an
        # axis of length 1 (that single pixel is then placed at -1).
        axis_x = (th.arange(W, device=device) / max(W - 1, 1)) * 2 - 1
        axis_y = (th.arange(H, device=device) / max(H - 1, 1)) * 2 - 1

        self.grid_x = axis_x.view(1, 1, 1, -1).expand(1, 1, H, W).clone()
        self.grid_y = axis_y.view(1, 1, -1, 1).expand(1, 1, H, W).clone()

    def forward(self, position: th.Tensor, depth: th.Tensor):
        """Return ``exp(-(dx^2/(2 sx^2) + dy^2/(2 sy^2) + dz^2/(2 sz^2)))``.

        Args:
            position: tensor whose dim 1 holds ``(x, y, z, std_xy, std_z)``.
                Accepted as ``(B, 5)`` or already broadcastable, e.g.
                ``(B, 5, 1, 1)``.
                NOTE(review): the previous assertion checked ``shape[2]``
                while every component was sliced along dim 1; dim 1 is used
                consistently here - confirm against callers.
            depth: depth map broadcastable against the ``(1, 1, H, W)``
                grids (presumably ``(B, 1, H, W)`` - TODO confirm).

        Requires the grid to have been built, i.e. ``self.size`` is set.
        """
        assert position.shape[1] == 5
        H, W = self.size

        # Lift (B, 5) to (B, 5, 1, 1) so each component broadcasts over the
        # spatial grid instead of being aligned with the H/W axes.
        if position.dim() == 2:
            position = position[:, :, None, None]

        x = th.clip(position[:,0:1], -self.position_limit, self.position_limit)
        y = th.clip(position[:,1:2], -self.position_limit, self.position_limit)
        z = position[:,2:3]

        # Each standard deviation is clipped from its own component; the
        # previous code read an undefined name `std` here.
        std_xy = th.clip(position[:,3:4], self.min_std, self.max_std)
        std_z  = th.clip(position[:,4:5], self.min_std, self.max_std)

        # Both grid axes span [-1, 1] regardless of aspect ratio, so the x
        # deviation is rescaled by H / W.
        std_y = std_xy
        std_x = std_xy * (H / W)

        exponent = (
            (self.grid_x - x)**2 / (2 * std_x**2)
            + (self.grid_y - y)**2 / (2 * std_y**2)
            + (depth - z)**2 / (2 * std_z**2)
        )
        return th.exp(-1 * exponent)

class PositionSoftmax(nn.Module):
    """Weight positions by their pairwise (approximate) Bhattacharyya distance.

    Args:
        num_possitions: number of positions per batch element. The spelling
            of the parameter is kept so keyword callers continue to work; it
            is stored as ``self.num_positions``.
    """

    def __init__(self, num_possitions):
        super(PositionSoftmax, self).__init__()
        self.num_positions = num_possitions

        # Learnable scale applied to the distances before exponentiation.
        self.scale = nn.Parameter(th.ones(1) * 0.546)

    @staticmethod
    def bhattacharyya_distance(mu1, mu2, sigma_xy1, sigma_xy2, sigma_z1, sigma_z2):
        """Approximate Bhattacharyya distance for diagonal Gaussians.

        Args:
            mu1, mu2: centres whose last dim holds at least ``(x, y, z)``.
            sigma_xy1, sigma_xy2: xy spread terms, broadcastable against the
                centres without their last dim.
            sigma_z1, sigma_z2: z spread terms, same broadcasting.

        Returns:
            ``(dx^2 + dy^2) / (sigma_xy1 + sigma_xy2) + dz^2 / (sigma_z1 + sigma_z2)``
            with ``1e-8`` added to each denominator.
        """
        diff_sq = (mu1 - mu2)**2

        sum_sigma_xy = sigma_xy1 + sigma_xy2 + 1e-8
        sum_sigma_z = sigma_z1 + sigma_z2 + 1e-8

        return diff_sq[..., :2].sum(-1) / sum_sigma_xy + diff_sq[..., 2] / sum_sigma_z

    @staticmethod
    def distance_weights(positions):
        """Return the ``(b, n, n)`` matrix of pairwise distances.

        Args:
            positions: ``(b, n, c)`` tensor with ``c >= 5`` laid out as
                ``(x, y, z, sigma_xy, sigma_z, ...)`` along the last dim.
        """
        xyz = positions[:, :, :3]
        sigma_xy = positions[:, :, 3]
        sigma_z = positions[:, :, 4]

        # Insert singleton axes so that entry [b, i, j] compares position i
        # with position j.
        return PositionSoftmax.bhattacharyya_distance(
            xyz[:, :, None, :],
            xyz[:, None, :, :],
            sigma_xy[:, :, None],
            sigma_xy[:, None, :],
            sigma_z[:, :, None],
            sigma_z[:, None, :],
        )

    def forward(self, positions):
        """Return one weight per position, shaped ``((b n), 1, 1, 1)``.

        Args:
            positions: ``((b n), c)`` tensor with ``n == self.num_positions``.
        """
        positions = rearrange(positions, '(b n) c -> b n c', n = self.num_positions)

        B_distances = self.distance_weights(positions)

        # NOTE(review): triu zeroes the strict lower triangle, so `1 - triu`
        # is exactly 1 for every j < i and the 'max' below yields 1 for all
        # rows except the first. Presumably a different diagonal offset or a
        # 'min' reduction was intended - confirm before relying on this.
        weights = 1 - th.triu(th.exp(-B_distances * self.scale), diagonal=0)
        weights = reduce(weights, 'b i j -> b i', 'max')

        return rearrange(weights, 'b n -> (b n) 1 1 1')

class PositionAttention(nn.Module):
    """Mask feature maps with a 3D Gaussian and a per-position weight.

    Args:
        size: grid size handed to ``Gaus3D``.
        num_possitions: number of positions handed to ``PositionSoftmax``.
        in_channels: currently unused by this module.
        out_channels: channel count of the feature maps fed to ``forward``;
            also used for the shapes of ``base_std_z`` and ``alpha``.
    """

    def __init__(self, size, num_possitions, in_channels, out_channels):
        super(PositionAttention, self).__init__()
        self.gaus3d = Gaus3D(size)

        self.base_std_z = nn.Parameter(th.zeros(1, out_channels, 1, 1))
        # Small head: spatial mean of the features -> one scalar per sample.
        self.residual_std_z = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(out_channels, out_channels*4),
            nn.SiLU(),
            nn.Linear(out_channels*4, 1),
        )
        # Starts at ~0 so the residual contributes almost nothing initially.
        self.alpha = nn.Parameter(th.zeros(1, out_channels, 1, 1)+1e-16)

        self.possition_softmax = PositionSoftmax(num_possitions)

    def forward(self, feature_maps, depth_map, position):
        """Return ``feature_maps`` multiplied by the combined mask.

        Args:
            feature_maps: feature tensor (``b c h w`` per the ``Reduce``
                pattern in ``residual_std_z``).
            depth_map: depth tensor passed through to ``Gaus3D``.
            position: position tensor; a z standard deviation is appended
                along dim 2 before use.
        """
        # NOTE(review): residual_std_z ends in Linear(..., 1) and therefore
        # yields (b, 1), while alpha / base_std_z are (1, out_channels, 1, 1).
        # Confirm that the resulting broadcast is the intended shape.
        std_z         = th.sigmoid(self.residual_std_z(feature_maps) * self.alpha + self.base_std_z)
        # NOTE(review): confirm that dim 2 is the component axis expected by
        # Gaus3D.forward and PositionSoftmax.forward for this tensor.
        position      = th.cat((position, std_z), dim=2)
        gaus_mask     = self.gaus3d(position, depth_map)

        position_mask = self.possition_softmax(position)
        mask          = gaus_mask * position_mask

        return mask * feature_maps

class ResidualBottleneck(nn.Module):
    """Channel-reducing bottleneck with a parameter-free skip connection.

    Args:
        in_channels: input channel count; must be a multiple of
            ``out_channels``.
        out_channels: output channel count.
    """

    def __init__(self, in_channels, out_channels):
        assert in_channels % out_channels == 0
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        # Skip path: collapse each group of `in_channels // out_channels`
        # consecutive channels into one output channel.
        # NOTE(review): the original call omitted the reduction argument that
        # einops' Reduce requires; 'mean' is assumed because it keeps the
        # activation scale of the skip path - confirm.
        self.skip     = Reduce('b (n c) h w -> b n h w', 'mean', c = in_channels // out_channels)
        self.residual = MemoryEfficientBottleneck(in_channels, out_channels)

    def forward(self, x):
        """Return the sum of the skip path and the residual path."""
        return self.skip(x) + self.residual(x)

class ResidualPatchDownscale(nn.Module):
    """Spatial downscaling block with a parameter-free skip connection.

    Args:
        in_channels: input channel count.
        out_channels: output channel count; must be a multiple of
            ``in_channels``.
        scale_factor: edge length of the square patch pooled by the skip
            path.
    """

    def __init__(self, in_channels, out_channels, scale_factor):
        assert out_channels % in_channels == 0
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()

        # A single Reduce cannot both pool patches and create new channels
        # (the original pattern introduced an axis `n` that does not exist on
        # the input side and passed no reduction). The skip is therefore
        # split: pool here, repeat channels in forward().
        # NOTE(review): 'mean' pooling is an assumption - confirm.
        self.channel_repeat = out_channels // in_channels
        self.skip     = Reduce('b c (h h2) (w w2) -> b c h w', 'mean', h2 = scale_factor, w2 = scale_factor)

        # NOTE(review): scale_factor is not forwarded to the residual path;
        # confirm MemoryEfficientPatchDownScale's default matches it.
        self.residual = MemoryEfficientPatchDownScale(in_channels, out_channels)

    def forward(self, x):
        """Return the sum of the pooled, channel-repeated skip and the residual path."""
        skip = repeat(self.skip(x), 'b c h w -> b (c n) h w', n = self.channel_repeat)
        return skip + self.residual(x)


class PixelToPosition(nn.Module):
    """Turn a feature map into an (x, y) position via a weighted grid average.

    Args:
        size: ``(H, W)`` of the feature maps passed to ``forward``.
        channels: channel count of the feature maps.
    """

    def __init__(self, size, channels): # FIXME add update grid !!!
        super(PixelToPosition, self).__init__()

        H, W = size

        # Pixel indices mapped to [-1, 1]. max(..., 1) avoids 0/0 = NaN for
        # an axis of length 1 (that single pixel is then placed at -1).
        axis_y = (th.arange(H) / max(H - 1, 1)) * 2 - 1
        axis_x = (th.arange(W) / max(W - 1, 1)) * 2 - 1

        # Non-persistent buffers: moved with the module, not saved in the
        # state dict.
        self.register_buffer("grid_y", axis_y.view(1, 1, -1, 1).expand(1, 1, H, W).clone(), persistent=False)
        self.register_buffer("grid_x", axis_x.view(1, 1, 1, -1).expand(1, 1, H, W).clone(), persistent=False)

        self.size = size

        # Softplus keeps the per-pixel weights positive.
        self.layers = nn.Sequential(
            ResidualBottleneck(channels, 1),
            nn.Softplus(),
        )

    def forward(self, input: th.Tensor):
        """Return the expected grid coordinate as a ``(B, 2)`` tensor ``(x, y)``.

        Args:
            input: feature map with spatial size ``self.size``.
        """
        activation = self.layers(input)

        # Normalise the weights so they sum to 1 over the spatial dims; the
        # lower bound guards against division by zero.
        total      = th.sum(activation, dim=(2,3), keepdim=True)
        normalized = activation / th.maximum(total, 1e-8 * th.ones_like(total))

        # The normalised weights (previously computed but never used) turn
        # these sums into weighted averages that stay inside the grid range.
        x = th.sum(normalized * self.grid_x, dim=(2,3))
        y = th.sum(normalized * self.grid_y, dim=(2,3))

        return th.cat((x, y), dim=1)

class PixelToSTD(nn.Module):
    """Map a feature map to one non-negative standard deviation per sample.

    Args:
        channels: channel count of the feature maps passed to ``forward``.
    """

    def __init__(self, channels):
        super().__init__()

        # Bottleneck down to a single channel, then ReLU so the spatial mean
        # taken in forward() cannot be negative.
        bottleneck = ResidualBottleneck(channels, 1)
        self.layers = nn.Sequential(bottleneck, nn.ReLU())

    def forward(self, x: th.Tensor):
        """Return ``sqrt(spatial_mean(layers(x)) + 1e-8)`` with shape ``(b, c)``."""
        activation = self.layers(x)
        variance = reduce(activation, 'b c h w -> b c', 'mean')

        # The epsilon keeps the square root away from an exact zero.
        return (variance + 1e-8)**0.5

class PixelToDepth(nn.Module):
    """Estimate a depth value z per sample from features, depth map and position.

    Args:
        size: grid size handed to ``Gaus3D``.
        channels: channel count of the feature maps passed to ``forward``.
    """

    def __init__(self, size, channels):
        super(PixelToDepth, self).__init__()
        self.gaus3d = Gaus3D(size)

        # The original referenced an undefined `out_channels` for the three
        # parameters below. Both heads end in Linear(..., 1) and produce
        # (b, 1), so the parameters are shaped (1, 1) to broadcast against
        # that output.
        # NOTE(review): confirm a single shared scalar is the intended
        # parameterisation.
        self.base_std_z = nn.Parameter(th.zeros(1, 1))
        self.residual_std_z = self._scalar_head(channels)
        # Starts at ~0 so the residual contributes almost nothing initially.
        self.alpha_std = nn.Parameter(th.zeros(1, 1)+1e-16)

        self.residual_z = self._scalar_head(channels)
        self.alpha_z = nn.Parameter(th.zeros(1, 1)+1e-16)

    @staticmethod
    def _scalar_head(channels):
        """Build a head mapping ``b c h w`` features to one scalar per sample."""
        return nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels, channels*4),
            nn.SiLU(),
            nn.Linear(channels*4, 1),
        )

    def forward(self, feature_maps, depth_map, position):
        """Return the Gaussian-weighted mean depth plus a learned residual.

        Args:
            feature_maps: ``b c h w`` features.
            depth_map: depth tensor evaluated by ``Gaus3D`` and averaged
                under the resulting mask.
            position: position tensor; a z standard deviation is appended
                along dim 1 (presumably ``(b, 4)`` holding x, y, z, std_xy -
                TODO confirm against callers).
        """
        std_z    = th.sigmoid(self.residual_std_z(feature_maps) * self.alpha_std + self.base_std_z)
        position = th.cat((position, std_z), dim=1)

        # Normalise the mask to sum to 1 over the spatial dims so the sum
        # below is a weighted mean; the lower bound avoids division by zero.
        mask     = self.gaus3d(position, depth_map)
        mask_sum = th.sum(mask, dim=(2,3), keepdim=True)
        mask     = mask / th.maximum(mask_sum, 1e-8 * th.ones_like(mask_sum))

        base_z     = th.sum(mask * depth_map, dim=(2,3))
        residual_z = self.residual_z(feature_maps) * self.alpha_z

        return base_z + residual_z

class ProposalDownscale(nn.Module):
    """Downscale features and depth, then re-estimate and apply a position.

    NOTE(review): this block does not match the constructors and forward
    signatures of its collaborators as they appear in this file; see the
    inline notes. It needs design input (at least a grid size and a position
    count) before it can run.
    """

    def __init__(self, in_channels, out_channels, scale_factor):
        super(ProposalDownscale, self).__init__()

        self.downscale     = ResidualPatchDownscale(in_channels, out_channels, scale_factor)
        self.depth_pooling = LearnedPooling(scale_factor, out_channels)

        # NOTE(review): `size` is neither a parameter nor a local here; unless
        # a module-level `size` exists outside the visible source, these two
        # lines raise NameError.
        self.to_xy  = PixelToPosition(size, out_channels)
        self.to_z   = PixelToDepth(size, out_channels)
        self.to_std = PixelToSTD(out_channels)

        # NOTE(review): PositionAttention.__init__ in this file takes
        # (size, num_possitions, in_channels, out_channels) and
        # PositionSoftmax.__init__ takes (num_possitions); both calls below
        # pass too few arguments.
        self.position_attention = PositionAttention(out_channels)
        self.possition_softmax  = PositionSoftmax()

    def forward(self, x, depth, position):
        # Downscale the features first; the pooled depth is derived from them.
        x = self.downscale(x)
        depth = self.depth_pooling(x, depth)

        # Re-estimate each position component from the downscaled features.
        xy  = self.to_xy(x)
        z   = self.to_z(x, depth, position)
        std = self.to_std(x)

        position = th.cat((xy, z, std), dim=1)

        # NOTE(review): PositionAttention.forward in this file takes
        # (feature_maps, depth_map, position) and PositionSoftmax.forward
        # takes a single positions argument; neither call below matches.
        x = self.position_attention(x, position)
        x = self.possition_softmax(x, position)

        return x, depth, position

class ProposalLayer(nn.Module):
    """Apply a ConvNeXt block, then re-estimate and apply a position.

    NOTE(review): apart from the constructor fix below, this block does not
    match the constructors and forward signatures of its collaborators as
    they appear in this file; see the inline notes. It needs design input
    (at least a grid size and a position count) before it can run.

    Args:
        in_channels: input channel count of the ConvNeXt block.
        out_channels: output channel count of the ConvNeXt block.
    """

    def __init__(self, in_channels, out_channels):
        # The zero-argument form always refers to this class; the previous
        # call named ProposalDownscale, which this class does not inherit
        # from.
        super().__init__()

        self.layer  = ConvNeXtBlock(in_channels, out_channels)

        # NOTE(review): `size` is neither a parameter nor a local here; unless
        # a module-level `size` exists outside the visible source, these two
        # lines raise NameError.
        self.to_xy  = PixelToPosition(size, out_channels)
        self.to_z   = PixelToDepth(size, out_channels)
        self.to_std = PixelToSTD(out_channels)

        # NOTE(review): PositionAttention.__init__ in this file takes
        # (size, num_possitions, in_channels, out_channels) and
        # PositionSoftmax.__init__ takes (num_possitions); both calls below
        # pass too few arguments.
        self.position_attention = PositionAttention(out_channels)
        self.possition_softmax  = PositionSoftmax()

    def forward(self, x, depth, position):
        """Return ``(x, depth, position)``; ``depth`` is passed through unchanged."""
        x   = self.layer(x)

        # Re-estimate each position component from the new features.
        xy  = self.to_xy(x)
        z   = self.to_z(x, depth, position)
        std = self.to_std(x)

        position = th.cat((xy, z, std), dim=1)

        # NOTE(review): PositionAttention.forward in this file takes
        # (feature_maps, depth_map, position) and PositionSoftmax.forward
        # takes a single positions argument; neither call below matches.
        x = self.position_attention(x, position)
        x = self.possition_softmax(x, position)

        return x, depth, position
