from .vit import ViT
from .cait import CaiT
from .pit import PiT
from .swin import SwinTransformer
from .t2t import T2T_ViT
from .two_stream_transformer import TwoStreamTransformer
from .swinv2 import SwinTransformerV2
from .two_stream_transformer_v2 import TwoStreamTransformerV2
from .two_stream_transformer_v3 import TwoStreamTransformerV3

def create_model(img_size, n_classes, args):
    """Build the network selected by ``args.model``.

    Args:
        img_size: Input image side length. It is forwarded to the chosen
            model as ``img_size`` / ``image_size``, or as the square
            ``input_resolution=(img_size, img_size)`` for ``'swin'``.
        n_classes: Number of output classes, forwarded as ``num_classes``.
        args: Options object. ``args.model`` picks the architecture; the
            ``'vit'``, ``'cait'``, ``'pit'`` and ``'t2t'`` branches also read
            ``args.sd``, ``args.is_SPT`` and ``args.is_LSA``. The ``'pit'``
            branch writes ``args.channel``, ``args.heads`` and ``args.depth``
            back onto ``args``.

    Returns:
        The constructed model. Only the ``'swin'`` branch calls ``.cuda()``
        on the result; every other branch returns the model as constructed.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    name = args.model

    if name == 'vit':
        # The original expression was `4 if img_size == 32 else 4`, i.e. the
        # patch size is 4 for every image size.
        model = ViT(img_size=img_size, patch_size=4, num_classes=n_classes, dim=384,
                    mlp_dim_ratio=2, depth=8, heads=12, dim_head=384 // 12,
                    stochastic_depth=args.sd, is_SPT=args.is_SPT, is_LSA=args.is_LSA)

    elif name == 'cait':
        # Same constant patch size as the 'vit' branch.
        model = CaiT(img_size=img_size, patch_size=4, num_classes=n_classes,
                     stochastic_depth=args.sd, is_LSA=args.is_LSA, is_SPT=args.is_SPT)

    elif name == 'pit':
        patch_size = 2 if img_size == 32 else 4
        # These values are stored on `args` (not just used locally), so they
        # remain visible to the caller after this function returns.
        args.channel = 96
        args.heads = (2, 4, 8)
        args.depth = (2, 6, 4)
        dim_head = args.channel // args.heads[0]

        model = PiT(img_size=img_size, patch_size=patch_size, num_classes=n_classes,
                    dim=args.channel, mlp_dim_ratio=2, depth=args.depth,
                    heads=args.heads, dim_head=dim_head,
                    stochastic_depth=args.sd, is_SPT=args.is_SPT, is_LSA=args.is_LSA)

    elif name == 't2t':
        model = T2T_ViT(img_size=img_size, num_classes=n_classes, drop_path_rate=args.sd,
                        is_SPT=args.is_SPT, is_LSA=args.is_LSA)

    elif name == 'swin':
        # NOTE(review): 'swin' instantiates SwinTransformerV2, and args.sd /
        # args.is_SPT / args.is_LSA are not forwarded here - confirm intended.
        # NOTE(review): this is the only branch that calls .cuda() itself.
        model = SwinTransformerV2(input_resolution=(img_size, img_size),
                                  window_size=4,
                                  in_channels=3,
                                  use_checkpoint=False,
                                  embedding_channels=96,
                                  depths=[2, 6, 4],
                                  number_of_heads=[3, 6, 12],
                                  num_classes=n_classes).cuda()

    elif name in ('two_stream', 'two_stream_v2', 'two_stream_v3'):
        # The three variants share every hyper-parameter except the class,
        # the embedding width and the head count.
        if name == 'two_stream':
            model_cls, dim, heads = TwoStreamTransformer, 384, 4
        elif name == 'two_stream_v2':
            model_cls, dim, heads = TwoStreamTransformerV2, 192, 12
        else:
            model_cls, dim, heads = TwoStreamTransformerV3, 192, 12

        model = model_cls(image_size=img_size,
                          outer_patch_size=4,
                          inner_patch_size=4,
                          num_classes=n_classes,
                          dim=dim,
                          depth=3,
                          heads=heads,
                          mlp_dim=dim * 2,
                          channels=3,
                          p=0.1,
                          num_workspace_slots=5,
                          inner_depth=2,
                          num_templates=4)

    else:
        # Without this branch an unknown name reached `return model` and
        # failed with an uninformative UnboundLocalError.
        raise ValueError(
            "Unknown model {!r}; expected one of: vit, cait, pit, t2t, swin, "
            "two_stream, two_stream_v2, two_stream_v3".format(name)
        )

    return model
