from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from .SwinTransformer import SwinTransformer
from .VisionTransformer import VisionTransformer
from .ResNet import resnet18
def build_supervised(config):
    """Build the encoder selected by ``config.model.type``.

    Supported values of ``config.model.type``:

    * ``'swin'``     -- ``SwinTransformer`` configured from ``config.model.swin``
    * ``'vit'``      -- ``VisionTransformer`` configured from ``config.model.vit``
    * ``'resnet18'`` -- ``resnet18``

    Every encoder receives ``num_classes=config.dataset.num_total_classes``.
    The two transformer encoders additionally receive
    ``img_size=config.dataset.image_size``; the Swin encoder also reads
    ``config.train.use_checkpoint``.

    Args:
        config: Nested, attribute-style configuration object exposing the
            ``model``, ``dataset`` and (for Swin) ``train`` sections.

    Returns:
        The constructed encoder instance.

    Raises:
        NotImplementedError: If ``config.model.type`` is not one of the
            supported values.
    """
    model_type = config.model.type

    if model_type == 'swin':
        swin_cfg = config.model.swin
        encoder = SwinTransformer(
            img_size=config.dataset.image_size,
            patch_size=swin_cfg.patch_size,
            in_chans=swin_cfg.in_chans,
            num_classes=config.dataset.num_total_classes,
            embed_dim=swin_cfg.embed_dim,
            depths=swin_cfg.depths,
            num_heads=swin_cfg.num_heads,
            window_size=swin_cfg.window_size,
            mlp_ratio=swin_cfg.mlp_ratio,
            qkv_bias=swin_cfg.qkv_bias,
            qk_scale=swin_cfg.qk_scale,
            drop_rate=swin_cfg.drop_rate,
            drop_path_rate=swin_cfg.drop_path_rate,
            ape=swin_cfg.ape,
            patch_norm=swin_cfg.patch_norm,
            use_checkpoint=config.train.use_checkpoint)
    elif model_type == 'vit':
        vit_cfg = config.model.vit
        encoder = VisionTransformer(
            img_size=config.dataset.image_size,
            patch_size=vit_cfg.patch_size,
            in_chans=vit_cfg.in_chans,
            num_classes=config.dataset.num_total_classes,
            embed_dim=vit_cfg.embed_dim,
            depth=vit_cfg.depth,
            num_heads=vit_cfg.num_heads,
            mlp_ratio=vit_cfg.mlp_ratio,
            qkv_bias=vit_cfg.qkv_bias,
            drop_rate=vit_cfg.drop_rate,
            drop_path_rate=vit_cfg.drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            init_values=vit_cfg.init_values,
            use_abs_pos_emb=vit_cfg.use_ape,
            use_rel_pos_bias=vit_cfg.use_rpb,
            use_shared_rel_pos_bias=vit_cfg.use_shared_rpb,
            use_mean_pooling=vit_cfg.use_mean_pooling)
    elif model_type == 'resnet18':
        encoder = resnet18(num_classes=config.dataset.num_total_classes)
    else:
        # Report the value that was actually dispatched on. The previous
        # message read ``config.model_type``, an attribute that is not used
        # anywhere else, so building the message failed before this
        # exception could be raised.
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")
    return encoder
