"""The final fusion stage for the film_net frame interpolator.

The inputs to this module are the warped input images, image features and
flow fields, all aligned to the target frame (often midway point between the
two original inputs). The output is the final image. FILM has no explicit
occlusion handling -- instead, using the above-mentioned information, this
module automatically decides how to best blend the inputs together to produce
content in areas where the pixels can only be borrowed from one of the inputs.

Similarly, this module also decides how much to blend in each input in the
case of a fractional timestep that is not at the halfway point. For example, if
the two input images are at t=0 and t=1, and we were to synthesize a frame at
t=0.1, it often makes most sense to favor the first input. However, this is not
always the case -- in particular in occluded pixels.

The architecture of the Fusion module follows U-net [1] architecture's decoder
side, e.g. each pyramid level consists of concatenation with upsampled coarser
level output, and two 3x3 convolutions.

The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
uses max-pooling which has a tendency to create checkerboard artifacts.

[1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
    Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
[2] https://distill.pub/2016/deconv-checkerboard/
"""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from util import Conv2d

# Number of color channels carried by each input image.
_NUMBER_OF_COLOR_CHANNELS = 3


def get_channels_at_level(level, filters):
    """Returns the channel count of a fusion input made of `level` sub-levels.

    Per image, the channels are the feature channels (`filters << i` summed
    over `i` in `range(level)`), the color channels and two flow channels. The
    total is that per-image count multiplied by the two images.

    Args:
      level: Number of feature sub-levels contributing channels.
      filters: Filter count of the first sub-level; each further sub-level
        doubles it.

    Returns:
      The total number of channels as an int.
    """
    flow_channels = 2
    image_count = 2

    # Accumulate filters * (1 + 2 + 4 + ...) over the contributing sub-levels.
    feature_channels = 0
    for sub_level in range(level):
        feature_channels += filters << sub_level

    per_image = feature_channels + _NUMBER_OF_COLOR_CHANNELS + flow_channels
    return per_image * image_count


class Fusion(nn.Module):
    """The decoder that fuses an aligned feature pyramid into a single image.

    Walking from the coarsest level to the finest, the running tensor is
    resized to the level's spatial size, passed through a `size=2` convolution,
    concatenated with the level's own pyramid tensor and refined by two
    `size=3` convolutions. A final 1x1 convolution produces the color output.
    """

    def __init__(self, n_layers=4, specialized_layers=3, filters=64):
        """Creates the convolutions for every decoder level.

        Args:
          n_layers: Number of decoder levels to build (the length of
            `self.convs`).
          specialized_layers: Level index at which the filter count stops
            doubling; every level at or above it uses
            `filters << specialized_layers` filters.
          filters: Number of filters used at the finest level (level 0).
        """
        super().__init__()

        # The final convolution that outputs RGB:
        self.output_conv = nn.Conv2d(
            filters, _NUMBER_OF_COLOR_CHANNELS, kernel_size=1)

        # Each item of 'self.convs' holds the three convolutions of one pyramid
        # level. The levels are appended coarse-to-fine, so 'self.convs[k]'
        # belongs to pyramid level 'n_layers - 1 - k'; 'forward' relies on this.
        self.convs = nn.ModuleList()

        # Roughly following the feature extractor, the number of filters
        # doubles when the resolution halves, but only up to
        # 'specialized_layers', after which all levels use the same number.
        #
        # With the default arguments the input channels of the first and second
        # convolution are, from coarsest to finest level:
        # (1930, 2442), (512, 1162), (256, 522), (128, 202).

        # The first convolution of the coarsest level consumes the coarsest
        # pyramid tensor directly.
        in_channels = get_channels_at_level(n_layers, filters)
        for i in reversed(range(n_layers)):
            num_filters = filters << min(i, specialized_layers)
            # The second convolution sees the pyramid tensor of level 'i'
            # concatenated with the 'num_filters' channels produced by the first
            # convolution. Computing this directly (rather than from the
            # previous level's filter count) stays correct when neighbouring
            # levels share the same number of filters.
            skip_channels = get_channels_at_level(i + 1, filters)
            convs = nn.ModuleList([
                Conv2d(in_channels, num_filters, size=2, activation=None),
                Conv2d(skip_channels + num_filters, num_filters, size=3),
                Conv2d(num_filters, num_filters, size=3),
            ])
            self.convs.append(convs)
            # The next (finer) level starts from this level's output.
            in_channels = num_filters

    def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
        """Runs the fusion module.

        Args:
          pyramid: The input feature pyramid as list of tensors, with the finest
            level tensor first. Tensors are indexed as (B x C x H x W): dims 2
            and 3 are used as the spatial size and concatenation happens along
            dim 1. The last entry is the starting point of the decoder, and
            entries `0 .. len(self.convs) - 1` are read as skip connections.

        Returns:
          The output of the final RGB convolution at the finest level.
        """

        # As a slight difference to a conventional decoder (e.g. U-net), we don't
        # apply any extra convolutions to the coarsest level, but just pass it
        # to finer levels for concatenation. This choice has not been thoroughly
        # evaluated, but is motivated by the educated guess that the fusion part
        # probably does not need large spatial context, because at this point the
        # features are spatially aligned by the preceding warp.
        net = pyramid[-1]

        # 'self.convs' is stored coarse-to-fine, so map the position 'k' in the
        # module list back to the pyramid level index 'i' (0 = finest).
        for k, layers in enumerate(self.convs):
            i = len(self.convs) - 1 - k
            # Resize the tensor from coarser level to match for concatenation.
            level_size = pyramid[i].shape[2:4]
            net = F.interpolate(net, size=level_size, mode='nearest')
            net = layers[0](net)
            net = torch.cat([pyramid[i], net], dim=1)
            net = layers[1](net)
            net = layers[2](net)
        net = self.output_conv(net)
        return net
