"""Tools to compute and manipulate correlation maps.
"""

import math
import torch

from decorators import to_minibatch


@to_minibatch(expected_outputs=1)
def compute_correlation_maps(
    source_descriptors: torch.Tensor,
    target_features: torch.Tensor,
):
    """Correlate source keypoint descriptors with a dense target feature map.
    Args:
        * source_descriptors: The interpolated source keypoint descriptors.
        * target_features: The dense feature map of the target image.
    Returns:
        * correlation_map: The dense correlation map, as produced by `correlate`.
    """
    # NOTE(review): `to_minibatch` is imported from `decorators` and is not
    # visible in this file; how it wraps the inputs/outputs should be confirmed
    # there.
    return correlate(source_descriptors, target_features)


def cardinal_omega(features: torch.Tensor):
    """Compute Card(Omega) from the two trailing dimensions of the target features.
    Args:
        * features: The [..., H, W] feature tensor.
    Returns:
        * The integer H * W + 1.
    """
    width = features.shape[-1]
    height = features.shape[-2]
    spatial_locations = width * height
    return spatial_locations + 1


def relative_aspect_ratio(input_tensor: torch.Tensor, output_tensor: torch.Tensor):
    """Compute the per-axis size ratio of the output tensor over the input tensor.
    Args:
        * input_tensor: The [..., Hi, Wi] input tensor.
        * output_tensor: The [..., Hj, Wj] output tensor.
    Returns:
        * ratio: The [1 x 2] float32 tensor [[Wj / Wi, Hj / Hi]], i.e. the
            horizontal ratio first and the vertical ratio second.
    """
    in_rows, in_cols = input_tensor.shape[-2:]
    out_rows, out_cols = output_tensor.shape[-2:]
    # Width ratio goes first so the result lines up with (x, y) coordinates.
    x_ratio = float(out_cols) / float(in_cols)
    y_ratio = float(out_rows) / float(in_rows)
    return torch.tensor([[x_ratio, y_ratio]], dtype=torch.float32)


def downsample_keypoints(
    keypoints: torch.Tensor, image: torch.Tensor, features: torch.Tensor
):
    """Rescale keypoints from image-space to feature-space.
    Args:
        * keypoints: The [N x 2] keypoints in image space, in the (x, y)
            order where x points right and y points down.
        * image: The image tensor.
        * features: The feature tensor.
    Returns:
        * The keypoints cast to float and multiplied by the feature/image
            size ratio given by `relative_aspect_ratio`.
    """
    float_keypoints = keypoints.float()
    ratio = relative_aspect_ratio(image, features)
    # The ratio is created on the default device; move it next to the keypoints.
    ratio = ratio.to(keypoints.device)
    return float_keypoints * ratio


def interpolate(dense_descriptors: torch.Tensor, keypoints: torch.Tensor):
    """Bilinearly interpolate sparse descriptors.
    Args:
        * dense_descriptors: The dense descriptor maps, of size [1 x C x H x W]
        * keypoints: The [N x 2] keypoint coordinates, in (x, y) order and
            expressed in pixels of the dense descriptor map.
    Returns:
        * sparse_descriptors: The [N x C] sparse descriptors.
    """
    batch, channels, height, width = dense_descriptors.shape
    assert batch == 1
    # With align_corners=True, pixel 0 maps to -1 and pixel (size - 1) maps to
    # +1, so the divisor is (size - 1). For an axis that is a single pixel wide
    # that extent is 0 and the division would yield NaN/inf coordinates; clamp
    # the divisor to 1 so that pixel 0 still maps to -1 on such an axis.
    scale = torch.tensor([max(width - 1, 1), max(height - 1, 1)]).to(keypoints)
    keypoints = (keypoints / scale) * 2 - 1
    # Bound far out-of-range coordinates; they remain outside [-1, 1].
    keypoints = keypoints.clamp(min=-2, max=2)
    # grid_sample expects a [1 x N x 1 x 2] grid and returns [1 x C x N x 1].
    sparse_descriptors = torch.nn.functional.grid_sample(
        dense_descriptors, keypoints[None, :, None], mode="bilinear", align_corners=True
    )
    return sparse_descriptors.view(channels, -1).transpose(-1, -2)


@torch.jit.script
def correlate(sparse_descriptors: torch.Tensor, dense_descriptors: torch.Tensor):
    """Compute dense correlation maps for every sparse descriptor.
    Args:
        * sparse_descriptors: A tensor of size [N x C]
        * dense_descriptors: A tensor of size [C x H x W]
    Returns:
        * correlation_map: A dense, unnormalized similarity map of size [N x H x W].
            When there are no sparse descriptors (N == 0) the result is an
            empty [0 x H x W] tensor.
    """
    channels, height, width = dense_descriptors.shape[-3:]
    # Dot product of every sparse descriptor with every spatial location.
    correlation_map = sparse_descriptors @ dense_descriptors.reshape(channels, -1)
    if correlation_map.numel() == 0:
        # A -1 dimension cannot be inferred for a tensor with no elements, so
        # the leading size is given explicitly for the empty case.
        correlation_map = correlation_map.reshape(0, height, width)
    else:
        correlation_map = correlation_map.reshape(-1, height, width)
    return correlation_map.contiguous()


@torch.jit.script
def softmax(correlation_maps: torch.Tensor):
    """Applies a 2D spatial softmax operation on the dense correlation maps.
    Args:
        * correlation_maps: The batch of the dense correlation maps, of size [N x H x W]
    Returns:
        * The correspondence maps, of size [N x H x W], each normalized over
            its H x W locations. An empty batch (N == 0) yields an empty
            [0 x H x W] tensor.
    """
    batch, height, width = correlation_maps.shape
    # Flatten with an explicit size: a -1 dimension cannot be inferred when the
    # batch is empty.
    flat_maps = correlation_maps.view(batch, height * width)
    return flat_maps.softmax(dim=1).view(batch, height, width)
