"""Module containing functionality to build different parts of the model adapted from https://github.com/bwittmann/transoar."""

from organ_detr.models.matcher import HungarianMatcher
from organ_detr.models.criterion import OrganDetrCriterion
from organ_detr.models.backbones.attn_fpn.attn_fpn import AttnFPN
from organ_detr.models.necks.def_detr_transformer import DeformableTransformer
from organ_detr.models.position_encoding import PositionEmbeddingSine3D, PositionEmbeddingLearned3D

from organ_detr.models.backbones.resnet3d import ResNet3D
from organ_detr.models.backbones.fpn import FPN
from organ_detr.models.backbones.swin_unetr import Swin_UNETR

def build_backbone(config):
    """Instantiate the backbone selected by ``config['name']``.

    The name is compared case-insensitively against ``'attn_fpn'``, ``'fpn'``,
    ``'resnet'`` and ``'swin_unetr'``. The matching backbone class is
    constructed with the whole ``config`` mapping as its only argument.

    Args:
        config: Backbone configuration mapping; must contain a string under
            the key ``'name'``.

    Returns:
        The constructed backbone instance.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        ValueError: If the name does not match any supported backbone.
    """
    name = config['name'].lower()

    if name == 'attn_fpn':
        return AttnFPN(config)
    if name == 'fpn':
        return FPN(config)
    if name == 'resnet':
        return ResNet3D(config)
    if name == 'swin_unetr':
        return Swin_UNETR(config)

    # Fail loudly instead of silently handing ``None`` back to the caller,
    # mirroring how build_pos_enc rejects unknown options.
    raise ValueError(
        f"Unknown backbone '{config['name']}'. "
        "Please select one of: attn_fpn, fpn, resnet, swin_unetr."
    )


def build_neck(config):
    """Construct the deformable transformer neck from its configuration.

    Required keys of ``config``: ``hidden_dim``, ``nheads``, ``enc_layers``,
    ``dec_layers``, ``dim_feedforward``, ``dropout``, ``dec_n_points``,
    ``enc_n_points``, ``use_cuda``, ``use_encoder``, ``num_feature_levels``,
    ``num_queries`` and ``num_classes``.

    Optional keys: ``use_dab`` (default ``False``), ``two_stage`` (default
    ``False``) and ``dn`` (a mapping whose ``enabled`` entry defaults to
    ``False``).

    The activation is fixed to ``"relu"`` and intermediate decoder outputs
    are always requested.

    Args:
        config: Neck configuration mapping.

    Returns:
        The constructed ``DeformableTransformer``.
    """
    transformer_kwargs = {
        'd_model': config['hidden_dim'],
        'nhead': config['nheads'],
        'num_encoder_layers': config['enc_layers'],
        'num_decoder_layers': config['dec_layers'],
        'dim_feedforward': config['dim_feedforward'],
        'dropout': config['dropout'],
        'activation': "relu",
        'return_intermediate_dec': True,
        'dec_n_points': config['dec_n_points'],
        'enc_n_points': config['enc_n_points'],
        'use_cuda': config['use_cuda'],
        'use_encoder': config['use_encoder'],
        'num_feature_levels': config['num_feature_levels'],
        'use_dab': config.get('use_dab', False),
        'two_stage': config.get('two_stage', False),
        # The number of two-stage proposals is tied to the number of queries.
        'two_stage_num_proposals': config['num_queries'],
        'num_classes': config['num_classes'],
        'dn': config.get('dn', {}).get('enabled', False),
    }
    return DeformableTransformer(**transformer_kwargs)

def build_criterion(config):
    """Build the training criterion, plus a hybrid dense-matching one on request.

    Both criteria share the same matcher costs and criterion options; only the
    matcher's dense-query settings differ. The hybrid matcher always enables
    dense query matching and reads its lambda from
    ``hybrid_dense_matching_lambda``. The regular matcher takes both settings
    from ``dense_q_matching`` / ``dense_q_matching_lambda``.

    Args:
        config: Full model configuration mapping. Reads ``set_cost_class``,
            ``set_cost_bbox``, ``set_cost_giou``, ``neck['num_classes']``,
            ``backbone['use_seg_proxy_loss']`` and ``backbone['fg_bg']``, plus
            the optional entries ``class_matching``,
            ``class_matching_query_split``, ``hybrid_dense_matching``,
            ``hybrid_dense_matching_lambda``, ``dense_q_matching``,
            ``dense_q_matching_lambda``, ``focal_loss`` and
            ``backbone['use_msa_seg_loss']``.

    Returns:
        The criterion alone, or the tuple ``(criterion, hybrid_dense_criterion)``
        when ``hybrid_dense_matching`` is truthy.

    Raises:
        AssertionError: If class matching is enabled and the query split does
            not sum to ``neck['num_queries']``.
    """
    query_split = config.get('class_matching_query_split', [])
    if query_split != [] and config.get('class_matching', False):
        if sum(query_split) != config['neck']['num_queries']:
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O``; type and message are unchanged.
            raise AssertionError("query split doesn't match num_queries")

    use_hybrid = config.get('hybrid_dense_matching', False)

    # Matching costs shared by the regular and the hybrid matcher.
    cost_kwargs = dict(
        cost_class=config['set_cost_class'],
        cost_bbox=config['set_cost_bbox'],
        cost_giou=config['set_cost_giou'],
    )

    backbone_cfg = config['backbone']
    use_msa_seg = backbone_cfg.get('use_msa_seg_loss', False)
    # Criterion options shared by both criteria; the proxy segmentation loss
    # is switched off whenever the MSA segmentation loss is in use.
    criterion_kwargs = dict(
        num_classes=config['neck']['num_classes'],
        seg_proxy=backbone_cfg['use_seg_proxy_loss'] and not use_msa_seg,
        seg_fg_bg=backbone_cfg['fg_bg'],
        seg_msa=use_msa_seg,
        focal_loss=config.get('focal_loss', False),
    )

    hybrid_dense_criterion = None
    if use_hybrid:
        hybrid_dense_matcher = HungarianMatcher(
            dense_q_matching=True,
            dense_q_matching_lambda=config.get('hybrid_dense_matching_lambda', 0.1),
            **cost_kwargs,
        )
        hybrid_dense_criterion = OrganDetrCriterion(
            matcher=hybrid_dense_matcher,
            **criterion_kwargs,
        )

    matcher = HungarianMatcher(
        dense_q_matching=config.get('dense_q_matching', False),
        dense_q_matching_lambda=config.get('dense_q_matching_lambda', 0.1),
        **cost_kwargs,
    )
    criterion = OrganDetrCriterion(matcher=matcher, **criterion_kwargs)

    if use_hybrid:
        return criterion, hybrid_dense_criterion
    return criterion

def build_pos_enc(config):
    """Create the 3D positional encoding named by ``config['pos_encoding']``.

    Args:
        config: Mapping providing ``'hidden_dim'`` (passed on as the number of
            channels) and ``'pos_encoding'`` (either ``'sine'`` or
            ``'learned'``).

    Returns:
        A ``PositionEmbeddingSine3D`` for ``'sine'`` or a
        ``PositionEmbeddingLearned3D`` for ``'learned'``.

    Raises:
        ValueError: If ``'pos_encoding'`` is neither ``'sine'`` nor
            ``'learned'``.
    """
    num_channels = config['hidden_dim']
    encoding_kind = config['pos_encoding']

    if encoding_kind == 'sine':
        return PositionEmbeddingSine3D(channels=num_channels)
    if encoding_kind == 'learned':
        return PositionEmbeddingLearned3D(channels=num_channels)

    raise ValueError('Please select a implemented pos. encoding.')
