# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import yaml
from .swin import SwinTransformer
from .swin_adapters import SwinTransformerWithAdapters

from .swin_PHM import SwinTransformerPHM
from .swin_compacter import SwinTransformerCompacter
from .swin_LowRank import SwinTransformerLowRank
from .swin_LoRa import SwinTransformerLoRa

from .vit_adapters import  ViTwithAdapters
from .cvt import get_cls_model
from .cvt_adapters import get_cls_model_adapters
from .t2t import T2t_vit_14
from .t2t_adapters import T2t_vit_14_Adapters

from .resnet import ResNet50
from .vit import ViT
from .residual_resnet import ResidualResNet

import os

def _load_cvt_spec(path='configs/cvt_13_224.yaml'):
    """Load the ``MODEL.CVT`` section of the CvT YAML config at ``path``.

    The path is resolved relative to the current working directory. The file
    is closed before returning, so callers do not hold it open while building
    the model.
    """
    with open(path, encoding='utf-8') as file:
        config_cvt = yaml.load(file, Loader=yaml.FullLoader)
    return config_cvt["MODEL"]["CVT"]


def _swin_common_kwargs(config):
    """Return the constructor kwargs shared by every Swin variant.

    Covers the architecture hyper-parameters together with ``use_checkpoint``,
    ``sample_size`` and ``use_multiscale``. The DRLoc switches (``use_drloc``,
    ``drloc_mode``, ``use_abs``) and adapter settings differ between variants,
    so each call site in ``build_model`` passes them explicitly.
    """
    swin = config.MODEL.SWIN
    return dict(
        img_size=config.DATA.IMG_SIZE,
        patch_size=swin.PATCH_SIZE,
        in_chans=swin.IN_CHANS,
        num_classes=config.MODEL.NUM_CLASSES,
        embed_dim=swin.EMBED_DIM,
        depths=swin.DEPTHS,
        num_heads=swin.NUM_HEADS,
        window_size=swin.WINDOW_SIZE,
        mlp_ratio=swin.MLP_RATIO,
        qkv_bias=swin.QKV_BIAS,
        qk_scale=swin.QK_SCALE,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        ape=swin.APE,
        rpe=swin.RPE,
        patch_norm=swin.PATCH_NORM,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        sample_size=config.TRAIN.SAMPLE_SIZE,
        use_multiscale=config.TRAIN.USE_MULTISCALE,
    )


def build_model(config):
    """Instantiate the model selected by ``config.MODEL.TYPE``.

    Args:
        config: attribute-style experiment config (presumably a yacs
            ``CfgNode``; confirm against caller). ``MODEL.TYPE`` picks the
            architecture, and the remaining ``DATA``/``MODEL``/``TRAIN``
            fields read depend on that choice.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: if ``config.MODEL.TYPE`` is not one of
            ``swin``, ``swin_adapters``, ``swin_compacter``, ``swin_PHM``,
            ``swin_LowRank``, ``swin_LoRa``, ``cvt``, ``cvt_adapters``,
            ``t2t``, ``t2t_adapters``, ``resnet50``, ``vit``,
            ``vit_adapters``, ``residualResnet26``.
    """
    model_type = config.MODEL.TYPE

    # ---- Swin family: shared hyper-parameters plus per-variant extras ----
    if model_type == 'swin':
        model = SwinTransformer(
            **_swin_common_kwargs(config),
            use_drloc=config.TRAIN.USE_DRLOC,
            drloc_mode=False,
            use_abs=False,
            adapter_state=config.TRAIN.USE_ADAPTERS,
            hidden_size=config.TRAIN.SIZE_ADAPTERS,
        )
    elif model_type == "swin_adapters":
        # NOTE(review): unlike compacter/PHM/LowRank, adapter_state is not
        # passed here; confirm the constructor's default is intended.
        model = SwinTransformerWithAdapters(
            **_swin_common_kwargs(config),
            use_drloc=config.TRAIN.USE_DRLOC,
            drloc_mode=False,
            use_abs=False,
            type_adapters=config.TRAIN.TYPE_ADAPTERS,
            ratio_param=config.TRAIN.SIZE_ADAPTERS,
        )
    elif model_type == "swin_compacter":
        model = SwinTransformerCompacter(
            **_swin_common_kwargs(config),
            use_drloc=config.TRAIN.USE_DRLOC,
            drloc_mode=False,
            use_abs=False,
            type_adapters=config.TRAIN.TYPE_ADAPTERS,
            adapter_state=config.TRAIN.USE_ADAPTERS,
            ratio_param=config.TRAIN.SIZE_ADAPTERS,
        )
    elif model_type == "swin_PHM":
        # The only Swin variant that takes drloc_mode/use_abs from the config.
        model = SwinTransformerPHM(
            **_swin_common_kwargs(config),
            use_drloc=config.TRAIN.USE_DRLOC,
            drloc_mode=config.TRAIN.DRLOC_MODE,
            use_abs=config.TRAIN.USE_ABS,
            type_adapters=config.TRAIN.TYPE_ADAPTERS,
            adapter_state=config.TRAIN.USE_ADAPTERS,
            ratio_param=config.TRAIN.SIZE_ADAPTERS,
        )
    elif model_type == "swin_LowRank":
        model = SwinTransformerLowRank(
            **_swin_common_kwargs(config),
            use_drloc=config.TRAIN.USE_DRLOC,
            drloc_mode=False,
            use_abs=False,
            type_adapters=config.TRAIN.TYPE_ADAPTERS,
            adapter_state=config.TRAIN.USE_ADAPTERS,
            ratio_param=config.TRAIN.SIZE_ADAPTERS,
        )
    elif model_type == "swin_LoRa":
        # NOTE(review): DRLoc is forced off here while drloc_mode still comes
        # from the config; confirm this combination is intended.
        model = SwinTransformerLoRa(
            **_swin_common_kwargs(config),
            use_drloc=False,
            drloc_mode=config.TRAIN.DRLOC_MODE,
            use_abs=False,
        )

    # ---- CvT: architecture spec comes from a separate YAML file ----
    elif model_type == "cvt":
        model = get_cls_model(config, _load_cvt_spec())
    elif model_type == "cvt_adapters":
        model = get_cls_model_adapters(config, _load_cvt_spec())

    # ---- T2T-ViT ----
    elif model_type == "t2t":
        model = T2t_vit_14(
            img_size=config.DATA.IMG_SIZE,
            num_classes=config.MODEL.NUM_CLASSES,
            use_drloc=config.TRAIN.USE_DRLOC,
            sample_size=config.TRAIN.SAMPLE_SIZE,
            drloc_mode=False,
            use_abs=False,
        )
    elif model_type == "t2t_adapters":
        model = T2t_vit_14_Adapters(
            img_size=config.DATA.IMG_SIZE,
            num_classes=config.MODEL.NUM_CLASSES,
            use_drloc=config.TRAIN.USE_DRLOC,
            sample_size=config.TRAIN.SAMPLE_SIZE,
            drloc_mode=False,
            use_abs=False,
            type_adapter=config.TRAIN.TYPE_ADAPTERS,
            param_ratio=config.TRAIN.SIZE_ADAPTERS,
        )

    # ---- ResNets ----
    elif model_type == 'resnet50':
        model = ResNet50(
            num_classes=config.MODEL.NUM_CLASSES,
            use_drloc=config.TRAIN.USE_DRLOC,
            sample_size=config.TRAIN.SAMPLE_SIZE,
            drloc_mode=False,
            use_abs=False,
            pretrained_bool=config.MODEL.FINETUNE,
        )
    elif model_type == "residualResnet26":
        # Weights are looked up under ./pretrained relative to the working dir.
        pre_model_path = os.path.join(os.getcwd(), "pretrained", "resnet26-timm.pth")
        # NOTE(review): sample_size is not forwarded here (ResNet50 gets it);
        # confirm ResidualResNet does not need it.
        model = ResidualResNet(
            num_classes=config.MODEL.NUM_CLASSES,
            use_drloc=config.TRAIN.USE_DRLOC,
            pretrained_path=pre_model_path,
            drloc_mode=False,
            use_abs=False,
            pretrained_bool=config.MODEL.FINETUNE,
        )

    # ---- ViT-B/16 ----
    # NOTE(review): img_size is fixed at 224 rather than config.DATA.IMG_SIZE.
    elif model_type == "vit":
        model = ViT(model_type="sup_vitb16_224", img_size=224,
                    num_classes=config.MODEL.NUM_CLASSES)
    elif model_type == "vit_adapters":
        model = ViTwithAdapters(
            model_type="sup_vitb16_224", img_size=224,
            num_classes=config.MODEL.NUM_CLASSES,
            type_adapter=config.TRAIN.TYPE_ADAPTERS,
            adapters_size=config.TRAIN.SIZE_ADAPTERS,
        )
    else:
        raise NotImplementedError(f"Unknown model: {model_type}")

    return model
