#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the CC-BY-NC license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.nn as nn


def create_compression_layer(
    embed_dim, final_spatial, after_compression_flat_size=2048,
    num_downsample_layers=0,
):
    """Build a conv head that compresses a feature map into a flat vector.

    The head applies ``num_downsample_layers`` blocks of
    (3x3 stride-2 conv -> ``GroupNorm(1, embed_dim)`` -> ReLU), then one
    3x3 conv that changes the channel count so that the flattened output has
    approximately ``after_compression_flat_size`` elements, followed by
    GroupNorm, ReLU and ``nn.Flatten()``.

    Args:
        embed_dim: Number of channels of the incoming feature map.
        final_spatial: Height/width of the incoming feature map. The same
            value is used for both spatial dimensions, so the map is
            presumably square (TODO confirm with callers).
        after_compression_flat_size: Target number of flattened features.
            The channel count is rounded to the nearest integer, so the
            actual size is ``channels * spatial**2`` and only approximates
            this target.
        num_downsample_layers: Number of stride-2 conv blocks applied before
            the compression conv. Each one halves the spatial size, rounding
            up. Must be non-negative.

    Returns:
        ``(compression, output_shape, output_size)``:
        ``compression`` is the ``nn.Sequential`` head;
        ``output_shape`` is ``(channels, spatial, spatial)`` before flattening;
        ``output_size`` is the number of flattened features per sample.

    Raises:
        ValueError: If ``num_downsample_layers`` is negative, if the
            downsampled spatial size is not positive, or if
            ``after_compression_flat_size`` rounds to fewer than one
            compression channel.
    """
    if num_downsample_layers < 0:
        raise ValueError(
            "num_downsample_layers must be non-negative, got "
            f"{num_downsample_layers}"
        )

    # A 3x3 conv with stride 2 and padding 1 maps size H to ceil(H / 2).
    # Applying it n times gives ceil(H / 2**n), so compute the result directly.
    out_spatial = int(np.ceil(final_spatial / 2**num_downsample_layers))
    if out_spatial < 1:
        raise ValueError(
            f"final_spatial must be positive, got {final_spatial}"
        )

    num_compression_channels = int(
        round(after_compression_flat_size / out_spatial**2)
    )
    # Without this check, a zero channel count would still build a valid-
    # looking module whose flattened output has 0 features.
    if num_compression_channels < 1:
        raise ValueError(
            f"after_compression_flat_size={after_compression_flat_size} "
            f"rounds to {num_compression_channels} compression channels for "
            f"a {out_spatial}x{out_spatial} feature map; increase it or use "
            "more downsample layers"
        )

    nets = []
    for _ in range(num_downsample_layers):
        nets.extend([
            nn.Conv2d(
                embed_dim,
                embed_dim,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.GroupNorm(1, embed_dim),
            nn.ReLU(True),
        ])
    nets.extend([
        nn.Conv2d(
            embed_dim,
            num_compression_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        ),
        nn.GroupNorm(1, num_compression_channels),
        nn.ReLU(True),
        nn.Flatten(),
    ])
    compression = nn.Sequential(*nets)

    output_shape = (
        num_compression_channels,
        out_spatial,
        out_spatial,
    )
    output_size = int(np.prod(output_shape))

    return compression, output_shape, output_size
