"""PyTorch layer for extracting image features for the film_net interpolator.

The feature extractor implemented here converts an image pyramid into a pyramid
of deep features. The feature pyramid serves a similar purpose as U-Net
architecture's encoder, but we use a special cascaded architecture described in
Multi-view Image Fusion [1].

For comprehensiveness, below is a short description of the idea. While the
description is a bit involved, the cascaded feature pyramid can be used just
like any image feature pyramid.

Why cascaded architecture?
==========================
To understand the concept it is worth reviewing a traditional feature pyramid
first: *A traditional feature pyramid* as in U-net or in many optical flow
networks is built by alternating between convolutions and pooling, starting
from the input image.

It is well known that early features of such architecture correspond to low
level concepts such as edges in the image whereas later layers extract
semantically higher level concepts such as object classes etc. In other words,
the meaning of the filters in each resolution level is different. For problems
such as semantic segmentation and many others this is a desirable property.

However, the asymmetric features preclude sharing weights across resolution
levels in the feature extractor itself and in any subsequent neural networks
that follow. This can be a downside, since optical flow prediction, for
instance, is symmetric across resolution levels. The cascaded feature
architecture addresses this shortcoming.

How is it built?
================
The *cascaded* feature pyramid contains feature vectors that have constant
length and meaning on each resolution level, except for a few of the finest
ones. The advantage of this is that the subsequent optical flow layer can
learn synergistically from many resolutions. This means that coarse level
prediction can benefit from finer resolution training examples, which can be
useful with moderately sized datasets to avoid overfitting.

The cascaded feature pyramid is built by extracting shallower subtree pyramids,
each one of them similar to the traditional architecture. Each subtree
pyramid S_i is extracted starting from each resolution level:

image resolution 0 -> S_0
image resolution 1 -> S_1
image resolution 2 -> S_2
...

If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
is constructed by concatenating features as follows (assuming subtree depth=3):

lvl
feat_0 = concat(                               S_0_0 )
feat_1 = concat(                         S_1_0 S_0_1 )
feat_2 = concat(                   S_2_0 S_1_1 S_0_2 )
feat_3 = concat(             S_3_0 S_2_1 S_1_2       )
feat_4 = concat(       S_4_0 S_3_1 S_2_2             )
feat_5 = concat( S_5_0 S_4_1 S_3_2                   )
   ....

In the above, all levels except feat_0 and feat_1 have the same number of
features with similar semantic meaning. This enables training a single optical
flow predictor module shared by levels 2, 3, 4, 5, ... . For more details and
evaluation see [1].

[1] Multi-view Image Fusion, Trinidad et al. 2019
"""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from util import Conv2d


class SubTreeExtractor(nn.Module):
    """Extracts a hierarchical set of features from an image.

    This is a conventional, hierarchical image feature extractor that uses
    [k, k*2, k*4, ...] filters on successive pyramid levels, where k=channels
    and the maximum number of levels is n_layers. Consecutive levels are
    separated by 2x2 average pooling.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4):
        """Builds one pair of `Conv2d` blocks per pyramid level.

        Args:
          in_channels: number of channels of the input image.
          channels: number of filters on the finest level; doubled on every
            subsequent level.
          n_layers: maximum number of pyramid levels that can be extracted.
        """
        super().__init__()
        convs = []
        for i in range(n_layers):
            # The filter count doubles on every level: channels * 2**i.
            out_channels = channels << i
            convs.append(nn.Sequential(
                Conv2d(in_channels, out_channels, 3),
                Conv2d(out_channels, out_channels, 3)
            ))
            # The next level consumes the output of this one.
            in_channels = out_channels
        self.convs = nn.ModuleList(convs)

    def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
        """Extracts a pyramid of features from the image.

        Args:
          image: torch.Tensor in channels-first layout, i.e.
            BATCH_SIZE x CHANNELS x HEIGHT x WIDTH, as expected by
            F.avg_pool2d.
          n: number of pyramid levels to extract. This can be less or equal to
            n_layers given in the __init__; larger values are capped at
            n_layers.
        Returns:
          The pyramid of features, starting from the finest level. The list
          holds exactly min(n, n_layers) elements; each element contains the
          output after the last convolution on the corresponding pyramid
          level.
        """
        head = image
        pyramid: List[torch.Tensor] = []
        for i, layer in enumerate(self.convs):
            # Only the first n levels are requested. A guard is used instead
            # of `break` because TorchScript unrolls loops over a ModuleList
            # and does not accept `break` inside them.
            if i < n:
                head = layer(head)
                pyramid.append(head)
                # Halve the resolution between levels, but not after the last
                # requested one.
                if i < n - 1:
                    head = F.avg_pool2d(head, kernel_size=2, stride=2)
        return pyramid


class FeatureExtractor(nn.Module):
    """Builds a cascaded feature pyramid from an image pyramid.

    A single SubTreeExtractor (and hence a single set of weights) is applied
    to every level of the image pyramid, and the resulting sub-pyramids are
    concatenated diagonally as described in the file documentation.
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4):
        super().__init__()
        self.sub_levels = sub_levels
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)

    def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
        """Extracts a cascaded feature pyramid.

        Args:
          image_pyramid: Image pyramid as a list, starting from the finest level.
        Returns:
          A pyramid of cascaded features, one tensor per input level, with the
          contributing sub-pyramid features concatenated along dim 1.
        """
        num_levels = len(image_pyramid)

        # Step 1: run the shared extractor on every image level. The requested
        # depth shrinks near the coarse end so that no features are produced
        # below the coarsest level of the output pyramid.
        sub_pyramids: List[List[torch.Tensor]] = []
        for level, image in enumerate(image_pyramid):
            depth = min(num_levels - level, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image, depth))

        # Step 2: for output level `level`, gather S_level_0, S_(level-1)_1,
        # S_(level-2)_2, ... for as far as both the cascade depth and the
        # number of finer levels allow.
        feature_pyramid: List[torch.Tensor] = []
        for level in range(num_levels):
            reach = min(level, self.sub_levels - 1)
            parts = [sub_pyramids[level][0]]
            for offset in range(1, reach + 1):
                parts.append(sub_pyramids[level - offset][offset])
            if len(parts) == 1:
                # Nothing to concatenate: pass the tensor through untouched.
                feature_pyramid.append(parts[0])
            else:
                feature_pyramid.append(torch.cat(parts, dim=1))
        return feature_pyramid
