from .swinv2_model import (swin_transformer_v2_t, swin_transformer_v2_s, swin_transformer_v2_b, swin_transformer_v2_l,
                           swin_transformer_v2_h, swin_transformer_v2_g)
from .classification_wrapper import ClassificationModelWrapper
from .vit_model import ViT


def vit_t_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'tiny' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=256, depth=6, heads=4,
    mlp_dim=1024.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 256
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=6,
        heads=4,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (1024)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def vit_mi_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'mini' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=1024, depth=6, heads=16,
    mlp_dim=4096.

    NOTE(review): despite the 'mini' name, dim/heads/mlp_dim here are the same
    as ``vit_l_patch16``; only depth differs (6 vs 24). Confirm this is the
    intended configuration.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 1024
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=6,
        heads=16,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (4096)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def vit_s_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'small' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=384, depth=12, heads=6,
    mlp_dim=1536.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 384
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=12,
        heads=6,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (1536)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def vit_b_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'base' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=768, depth=12, heads=12,
    mlp_dim=3072.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 768
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=12,
        heads=12,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (3072)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def vit_l_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'large' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=1024, depth=24, heads=16,
    mlp_dim=4096.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 1024
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=24,
        heads=16,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (4096)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def vit_h_patch16(image_size, num_classes, dropout=0.0, emb_dropout=0.0):
    """Build the 'huge' ViT configuration with ``patch_size=16``.

    Hyperparameters forwarded to ``ViT``: dim=1280, depth=32, heads=16,
    mlp_dim=5120.

    Args:
        image_size: Input image size, forwarded unchanged to ``ViT``.
        num_classes: Number of output classes.
        dropout: Dropout rate forwarded to ``ViT``.
        emb_dropout: Embedding dropout rate forwarded to ``ViT``.

    Returns:
        The constructed ``ViT`` instance.
    """
    width = 1280
    return ViT(
        image_size=image_size,
        patch_size=16,
        num_classes=num_classes,
        dim=width,
        depth=32,
        heads=16,
        mlp_dim=4 * width,  # MLP hidden size is 4x the model width (5120)
        dropout=dropout,
        emb_dropout=emb_dropout,
    )


def _build_swinv2_classifier(backbone_fn, num_classes, input_resolution, output_channels):
    """Build a Swin V2 backbone for a square input and wrap it in a classification head."""
    feature_extractor = backbone_fn((input_resolution, input_resolution))
    return ClassificationModelWrapper(model=feature_extractor, number_of_classes=num_classes,
                                      output_channels=output_channels)


# Model name -> builder(num_classes, input_resolution) -> model.
# Lambdas defer the lookup of each factory until the model is actually requested.
# The Swin V2 entries carry the ``output_channels`` value handed to
# ClassificationModelWrapper; presumably the backbone's final feature width.
_MODEL_BUILDERS = {
    'swinv2_tiny': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_t, n, r, 768),
    'swinv2_small': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_s, n, r, 768),
    'swinv2_base': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_b, n, r, 1024),
    'swinv2_large': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_l, n, r, 1536),
    'swinv2_huge': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_h, n, r, 2816),
    'swinv2_giant': lambda n, r: _build_swinv2_classifier(swin_transformer_v2_g, n, r, 4096),
    # ViT takes num_classes directly, so no classification wrapper is needed.
    'vit_tiny_patch16': lambda n, r: vit_t_patch16(image_size=r, num_classes=n),
    'vit_mini_patch16': lambda n, r: vit_mi_patch16(image_size=r, num_classes=n),
    'vit_small_patch16': lambda n, r: vit_s_patch16(image_size=r, num_classes=n),
    'vit_base_patch16': lambda n, r: vit_b_patch16(image_size=r, num_classes=n),
    'vit_large_patch16': lambda n, r: vit_l_patch16(image_size=r, num_classes=n),
    'vit_huge_patch16': lambda n, r: vit_h_patch16(image_size=r, num_classes=n),
}


def model_loader(model_name='swinv2_tiny', num_classes=1000, input_resolution=224):
    """Construct a classification model by name.

    Swin Transformer V2 models are built at ``(input_resolution, input_resolution)``
    and wrapped in ``ClassificationModelWrapper``. ViT models are built with
    ``image_size=input_resolution`` and returned directly.

    Args:
        model_name: One of 'swinv2_tiny', 'swinv2_small', 'swinv2_base',
            'swinv2_large', 'swinv2_huge', 'swinv2_giant', 'vit_tiny_patch16',
            'vit_mini_patch16', 'vit_small_patch16', 'vit_base_patch16',
            'vit_large_patch16' or 'vit_huge_patch16'.
        num_classes: Number of output classes.
        input_resolution: Side length of the square input. Author's guidance:
            use 224 for ViT and 256 for Swin.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not supported. ValueError subclasses
            Exception, so callers catching ``Exception`` are unaffected.
    """
    # NOTE(review): earlier notes suggested trying struct in {'tt', 'block_tt'},
    # tt_dim in {2, 3, 4} and tt_rank in {1, 8, 16}; none of these are
    # parameters of this function. Confirm whether they are still relevant.
    try:
        builder = _MODEL_BUILDERS[model_name]
    except KeyError:
        supported = ', '.join(sorted(_MODEL_BUILDERS))
        raise ValueError(
            f"Model not supported: {model_name!r}. Supported models: {supported}"
        ) from None
    return builder(num_classes, input_resolution)
