from torch import nn
from typing import List

from utils.utils import NestedTensor
from models.position_encoding import build_position_encoding


class Joiner(nn.Sequential):
    """Sequential container holding a backbone (index 0) and a position encoder (index 1)."""

    def __init__(self, backbone, position_embedding):
        """Register ``backbone`` as module 0 and ``position_embedding`` as module 1."""
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone, then compute a position encoding for each of its outputs.

        The backbone result is iterated with ``.items()``; the keys are
        ignored and each value is expected to expose a ``.tensors`` attribute.

        Returns:
            A pair ``(out, pos)`` of equally long lists: ``out`` holds the
            backbone outputs in iteration order, ``pos`` the matching position
            encodings cast to the dtype of the corresponding ``.tensors``.
        """
        features = self[0](tensor_list)

        out: List = []
        pos = []
        for _, feature in features.items():
            # Cast so the encoding has the same dtype as the feature it accompanies.
            encoding = self[1](feature)
            pos.append(encoding.to(feature.tensors.dtype))
            out.append(feature)

        return out, pos


def build_backbone(args):
    """Build the backbone selected by ``args.backbone_arch`` and wrap it in a ``Joiner``.

    Args:
        args: Configuration object. Attributes read here:
            ``lr_backbone`` (the backbone counts as trainable when > 0; only
            forwarded to the resnet builder), ``return_interm``,
            ``backbone_arch``, ``dilation`` (resnet only),
            ``enc_output_layer`` (all non-resnet backbones),
            ``position_embedding`` and ``hidden_dim``.

    Returns:
        A ``Joiner`` combining the backbone with a position encoding built
        for ``hidden_dim // 2``. The backbone's ``num_channels`` is copied
        onto the returned model.

    Raises:
        ValueError: If ``args.backbone_arch`` matches no supported backbone.
    """
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.return_interm
    arch = args.backbone_arch

    # Each backbone module is imported inside its own branch, so only the
    # module for the selected architecture is imported.
    if "resnet" in arch:
        from models.resnet import resnet_model

        backbone = resnet_model(
            arch, train_backbone, return_interm_layers, args.dilation
        )
    elif arch == "dinov2":
        from models.dino import dino_model

        # NOTE(review): the layer index is negated for the dino/clip builders,
        # presumably to index from the end — confirm against those builders.
        backbone = dino_model(-1 * args.enc_output_layer, return_interm_layers)
    elif arch == "radio":
        from models.radio import radio_model

        backbone = radio_model(args.enc_output_layer)
    elif arch == "eradio":
        from models.radio import eradio_model

        backbone = eradio_model(args.enc_output_layer)
    elif arch == "radio-h":
        from models.radio import radio_model_h

        backbone = radio_model_h(args.enc_output_layer)
    elif arch == "dinov2_q":
        from models.dino import dino_model_with_hooks

        backbone = dino_model_with_hooks(
            -1 * args.enc_output_layer, return_interm_layers
        )
    elif arch == "dinov2_q_large":
        from models.dino import dino_model_with_hooks_large

        backbone = dino_model_with_hooks_large(
            -1 * args.enc_output_layer, return_interm_layers
        )
    elif arch == "clip":
        from models.clip import clip_model

        backbone = clip_model(-1 * args.enc_output_layer, return_interm_layers)
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # ``backbone`` further down.
        raise ValueError(
            f"Unsupported backbone_arch {arch!r}; expected a name containing "
            "'resnet' or one of: dinov2, dinov2_q, dinov2_q_large, radio, "
            "eradio, radio-h, clip"
        )

    num_channels = backbone.num_channels

    position_embedding = build_position_encoding(
        args.position_embedding, args.hidden_dim // 2
    )
    model = Joiner(backbone, position_embedding)
    model.num_channels = num_channels

    return model
