from typing import Dict

import torch
import torch.nn as nn

from attention import TransformerBlock
from convolutional import CoarseModel
from positional import PositionalEncoding


def build_model(cfg: Dict):
    """Assemble the backbone: a truncated coarse CNN followed by a 1x1 adaptation conv.

    Args:
        * cfg: Configuration object read via attribute access. It must expose
          ``extraction_layer`` and may expose ``no_maxpool``.
    Returns:
        * model: ``nn.Sequential`` holding the coarse layers plus the 1x1 conv.
        * The channel size of the model output (half of the coarse output size).
    """
    # A missing ``no_maxpool`` attribute means the second max-pool is kept.
    use_second_maxpool = not getattr(cfg, "no_maxpool", False)
    coarse_layers, coarse_dim = build_coarse_model(
        cfg.extraction_layer, use_second_maxpool
    )
    adapted_dim = coarse_dim // 2
    # 1x1 convolution that halves the channel count of the coarse features.
    adaptation = nn.Conv2d(coarse_dim, adapted_dim, kernel_size=1, bias=True)
    return nn.Sequential(*coarse_layers, adaptation), adapted_dim


def build_attention(channel_dim: int, config: Dict, mode: str, kernel: str = "dot"):
    """Build the stack of self- or cross-attention layers.

    Args:
        * channel_dim: Feature channel size passed to every ``TransformerBlock``.
        * config: Configuration object exposing a ``<mode>_attention`` attribute
          (``self_attention`` or ``cross_attention``), read via attribute access.
          The selected value must support ``["num_layers"]`` indexing when truthy.
        * mode: Either ``"self"`` or ``"cross"``.
        * kernel: Attention kernel name forwarded to each ``TransformerBlock``.
    Returns:
        * ``nn.Sequential`` with two ``TransformerBlock``s per configured layer, or
          ``nn.Identity()`` when the selected attention config is falsy.
    Raises:
        * ValueError: if ``mode`` is not ``"self"`` or ``"cross"``.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if mode not in ("self", "cross"):
        raise ValueError(f"Invalid mode: {mode!r} (expected 'self' or 'cross').")
    # Built-in getattr works on any config object; calling ``__getattr__``
    # directly only works on classes that happen to define that hook.
    attn_cfg = getattr(config, f"{mode}_attention")
    if not attn_cfg:
        return nn.Identity()
    # The second block of each pair is "dense_to_sparse" in cross mode and
    # "dense_to_dense" in self mode; the first is always "sparse_to_dense".
    second_type = "dense_to_sparse" if mode == "cross" else "dense_to_dense"
    layers = []
    for i in range(attn_cfg["num_layers"]):
        # The extra positional flag is True only for the very first block;
        # NOTE(review): presumably marks the first layer — confirm against
        # TransformerBlock's signature.
        layers.append(
            TransformerBlock(channel_dim, attn_cfg, "sparse_to_dense", kernel, i == 0)
        )
        layers.append(TransformerBlock(channel_dim, attn_cfg, second_type, kernel))
    return nn.Sequential(*layers)


def build_positional_encoding(channel_dim: int, config: Dict):
    """Create a positional encoding only when some attention stage is enabled.

    Args:
        * channel_dim: Channel size forwarded to ``PositionalEncoding``.
        * config: Configuration object exposing ``cross_attention`` and
          ``self_attention`` attributes.
    Returns:
        * A ``PositionalEncoding`` instance if either attention config is
          truthy, otherwise ``None``.
    """
    attention_enabled = config.cross_attention or config.self_attention
    if not attention_enabled:
        return None
    return PositionalEncoding(channel_dim)


def build_coarse_model(extraction_layer: str, second_maxpool: bool = True):
    """Assemble the coarse model up to and including the provided extraction layer.

    Args:
        * extraction_layer: Name of the last ``CoarseModel`` layer to keep. It must
          be one of the names in ``convolutional_layers`` below.
        * second_maxpool: If True and ``Conv2d_4a_3x3`` is included, a max-pool is
          inserted after it. (A max-pool is always inserted after
          ``Conv2d_2b_3x3`` whenever that layer is included.)
    Returns:
        * truncated_model: The list of the coarse model layers (plus max-pools).
        * output_dim: The output features channel size (fixed at 768).
    Raises:
        * ValueError: if ``extraction_layer`` is not a supported layer name.
    """
    convolutional_layers = (
        "Conv2d_1a_3x3",
        "Conv2d_2a_3x3",
        "Conv2d_2b_3x3",
        "Conv2d_3b_1x1",
        "Conv2d_4a_3x3",
        "Mixed_5b",
        "Mixed_5c",
        "Mixed_5d",
        "Mixed_6a",
    )
    # Validate before instantiating CoarseModel so a typo fails fast instead of
    # silently returning every layer through Mixed_6a.
    if extraction_layer not in convolutional_layers:
        raise ValueError(
            f"Unknown extraction_layer {extraction_layer!r}; "
            f"expected one of {list(convolutional_layers)}."
        )
    base_model = CoarseModel()
    truncated_model = []
    for layer in convolutional_layers:
        truncated_model.append(getattr(base_model, layer))
        if layer == "Conv2d_2b_3x3" or (layer == "Conv2d_4a_3x3" and second_maxpool):
            truncated_model.append(
                torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        if layer == extraction_layer:
            break
    # NOTE(review): the channel size is fixed at 768 regardless of
    # extraction_layer. If CoarseModel follows torchvision's Inception3 layout,
    # this only matches Mixed_6a — confirm before extracting at an earlier layer.
    return truncated_model, 768
