"""NeurHal model.
"""
import logging
from typing import Dict

import pytorch_lightning as pl
import torch

from build import build_attention, build_model, build_positional_encoding
from math_tools import log
from nre import (
    cardinal_omega,
    compute_correlation_maps,
    downsample_keypoints,
    interpolate,
    relative_aspect_ratio,
    softmax,
)


class Net(pl.LightningModule):
    """The NeurHal architecture.

    A shared CNN extracts dense features from a source and a target image. The
    features are positionally encoded, refined by alternating self / cross
    attention layers, and finally correlated to produce the log-correspondence
    maps of the source keypoints in the (optionally padded) target feature map.
    """

    def __init__(self, **kwargs: Dict):
        """Initialize the network.
        Args:
            * kwargs: Hyper-parameters, stored in `self.hparams` by
                `save_hyperparameters()`. This class itself reads `level`,
                `minibatch_size` and (optionally) `outpainting`; the full set is
                also handed to the `build_*` factories.
        """
        super().__init__()
        self.save_hyperparameters()
        self.cnn, self.descriptor_dimension = build_model(self.hparams)
        self.self_attention = build_attention(
            self.descriptor_dimension, self.hparams, "self"
        )
        self.cross_attention = build_attention(
            self.descriptor_dimension, self.hparams, "cross"
        )
        self.positional_encoding = build_positional_encoding(
            self.descriptor_dimension, self.hparams
        )
        # The learnable padding descriptor only exists when outpainting is enabled;
        # `pad` is the only method reading it and applies the same condition.
        if hasattr(self.hparams, "outpainting") and self.hparams.outpainting > 0.0:
            self.register_parameter(
                "padding_vector",
                torch.nn.Parameter(torch.randn((1, self.descriptor_dimension, 1, 1))),
            )
        logging.info(
            f"Initialized {self.hparams.level} model "
            f"with output dimension {self.descriptor_dimension}"
        )

    def compute_cmaps(
        self,
        data: Dict,
    ):
        """Compute dense correspondence maps.
        Args:
            * data: Dictionary containing input data. The keys read (through
                `compute_dense_feature`) are `source_image_tensor`,
                `target_image_tensor` and `source_keypoints`.
        Returns:
            A dictionary with the following entries:
            * log_cmaps: The [N x H x W] log of correspondence maps tensor.
            * target_im2feat: The ratio between the target image and its
                (unpadded) feature map, as given by `relative_aspect_ratio`.
            * cardinal_omega: Card(Omega) of the maps.
            * padding: The padding values in target feature space.
        """
        # Compute dense feature maps
        (
            dense_source_features,
            dense_target_features,
            source_keypoints,
        ) = self.compute_dense_feature(data)

        # Compute the image / feature ratio before padding changes the map size
        target_im2feat = relative_aspect_ratio(
            data["target_image_tensor"], dense_target_features
        )

        # (Optional) Pad target features for outpainting
        dense_target_features, padding = self.pad(dense_target_features)

        # Apply attention
        sparse_source_features, dense_target_features = self.sparse_to_dense_attention(
            dense_source_features, dense_target_features, source_keypoints
        )

        # Compute dense correlation maps
        correlation_maps = compute_correlation_maps(
            source_descriptors=sparse_source_features,
            target_features=dense_target_features,
            num_points=len(source_keypoints),
            minibatch_size=self.hparams.minibatch_size,
        )

        # Normalize to obtain correspondence maps
        correspondence_maps = softmax(correlation_maps)
        card_omega = cardinal_omega(dense_target_features)

        # Return the log of the correspondence maps (leading unit dim dropped)
        log_cmaps = log(correspondence_maps).squeeze(0)

        return {
            "log_cmaps": log_cmaps,
            "target_im2feat": target_im2feat,
            "cardinal_omega": card_omega,
            "padding": padding,
        }

    def pad(self, feature_map: torch.Tensor):
        """Pad feature map to perform outpainting.

        The border is filled with the learnable `padding_vector`; its width is
        the fraction `hparams.outpainting` of the map size, and at least one cell
        on every side.
        Args:
            * feature_map: The [B, C, H, W] feature map to pad.
        Returns:
            * feature_map: The padded feature map (the input itself when
                outpainting is absent or 0).
            * pad: The padding values, in feature map space, ordered
                [left, right, top, bottom].
        """
        if not hasattr(self.hparams, "outpainting"):
            return feature_map, [0, 0, 0, 0]
        p = self.hparams.outpainting
        assert 0.0 <= p <= 1.0, "invalid outpainting value"
        if p == 0.0:
            return feature_map, [0, 0, 0, 0]

        h, w = feature_map.shape[-2:]
        pad_w = max(1, int(w * p))
        pad_h = max(1, int(h * p))
        pad = [pad_w, pad_w, pad_h, pad_h]

        # Build the canvas with the batch size of the input, so that maps with
        # B > 1 can be written into it (a 3-D input keeps the former size of 1).
        batch = feature_map.shape[0] if feature_map.dim() == 4 else 1
        padded_feature_map = self.padding_vector.repeat(
            (batch, 1, h + 2 * pad_h, w + 2 * pad_w)
        )
        # Paste the original features in the centre of the canvas
        padded_feature_map[:, :, pad_h : pad_h + h, pad_w : pad_w + w] = feature_map
        return padded_feature_map, pad

    def compute_dense_feature(self, data: Dict):
        """Compute dense feature maps.
        Args:
            * data: Dictionary providing `source_image_tensor`,
                `target_image_tensor` and `source_keypoints`. It is left
                unmodified.
        Returns:
            * dense_source_features: The scaled CNN features of the source image.
            * dense_target_features: The scaled CNN features of the target image.
            * source_keypoints: The source keypoints in feature coordinate space.
        """
        # Compute dense CNN features
        dense_source_features = self.cnn(data["source_image_tensor"])
        dense_target_features = self.cnn(data["target_image_tensor"])

        # Divide by descriptor dimension for stability. Done out of place: an
        # in-place division would overwrite the CNN output, which autograd may
        # still need for the backward pass.
        dense_source_features = dense_source_features / self.descriptor_dimension
        dense_target_features = dense_target_features / self.descriptor_dimension

        # Compute source keypoints in feature coordinate space. `squeeze` (not
        # `squeeze_`) so the caller's tensor keeps its shape across calls.
        source_keypoints = downsample_keypoints(
            keypoints=data["source_keypoints"].squeeze(0),
            image=data["source_image_tensor"],
            features=dense_source_features,
        )
        return dense_source_features, dense_target_features, source_keypoints

    def sparse_to_dense_attention(
        self,
        dense_source_features: torch.Tensor,
        dense_target_features: torch.Tensor,
        source_keypoints: torch.Tensor,
    ):
        """Apply sparse-to-dense self and cross-attention.

        Layers are consumed two at a time from `self.self_attention` and
        `self.cross_attention` (one layer updates the sparse source descriptors,
        the next one the dense target features), alternating a self pair and a
        cross pair until both lists are exhausted. Each list is therefore
        expected to hold an even number of layers.
        Args:
            * dense_source_features: The dense source feature map.
            * dense_target_features: The dense (possibly padded) target feature map.
            * source_keypoints: The source keypoints, in feature coordinate space.
        Returns:
            * sparse_source_features: The attended source keypoint descriptors.
            * dense_target_features: The attended dense target features.
        """

        # Apply 2D positional encoding
        pe = self.positional_encoding
        dense_source_features = pe(dense_source_features)
        dense_target_features = pe(dense_target_features)

        # Interpolate sparse source descriptors
        sparse_source_features = interpolate(dense_source_features, source_keypoints)

        # Walk through the self / cross layer pairs, alternating between them
        n_self, n_cross = len(self.self_attention), len(self.cross_attention)
        for idx in range(0, max(n_self, n_cross), 2):

            # Apply self-attention
            if idx < n_self:
                # Self (sparse_source, dense_source) -> sparse_source [sparse-to-dense]
                sparse_source_features = self.self_attention[idx](
                    sparse_source_features, dense_source_features
                )
                # Self (dense_target, dense_target) -> dense_target [dense-to-dense]
                dense_target_features = self.self_attention[idx + 1](
                    dense_target_features, dense_target_features
                )

            # Apply cross-attention
            if idx < n_cross:
                # Cross (sparse_source, dense_target) -> sparse_source [sparse-to-dense]
                sparse_source_features = self.cross_attention[idx](
                    sparse_source_features, dense_target_features
                )
                # Cross (dense_target, sparse_source) -> dense_target [dense-to-sparse]
                # (uses the sparse descriptors updated just above)
                dense_target_features = self.cross_attention[idx + 1](
                    dense_target_features, sparse_source_features
                )

        return sparse_source_features, dense_target_features

    def forward(
        self,
        data: Dict,
        target_device: torch.device = "cpu",
    ):
        """Compute dense NRE loss maps.
        Args:
            * data: Input dictionary containing the source and target image tensors,
                as well as the source 2D keypoint coordinates (in image space).
            * target_device: The target device to move the dense NRE loss maps to.
        Returns:
            The dictionary produced by `compute_cmaps`, in which `log_cmaps` now
            holds the negated log-correspondence maps, clamped from above at
            log(Card(Omega)) and moved to `target_device`.
        """
        output = self.compute_cmaps(data)
        # Negate and clamp in place to avoid extra copies of the dense maps
        output["log_cmaps"].neg_().clamp_(max=log(output["cardinal_omega"]))
        output["log_cmaps"] = output["log_cmaps"].to(target_device)
        return output
