from .DNN import myDNN
from .LSTM import myLSTM
from .GPT import myGPT
from .GPT_lightly import myGPT_lightly
from .GPT_specific import myGPT_specific
from .GPT_separate_attn_proj import myGPT_separate_attn_proj
from .GPT_separate_attn_proj_specific import myGPT_separate_attn_proj_specific
from .GPT_softmax10x import myGPT_softmax10x
from .GPT_OneHotEmb import myGPT_OneHotEmb
from .GPT_custom_attn import myGPT_custom_attn
from .GPT_condense import myGPT_condense
from .GPT_condense_attn_only import myGPT_condense_attn_only
from .GPT_condense_attn_only_without_QK import myGPT_condense_attn_only_without_QK
from .GPT_linear_regime import myGPT_linear
from .GPT_linear_regime_only_with_attention_para import myGPT_linear_for_attn
from .GPT_condense_regime_only_with_attention_para import myGPT_condense_for_attn
from .GPT_kaiming_normal import myGPT_kaiming_normal
from .GPT_normal_init import myGPT_normal_init
from .GPT_specific_test_for_resnet import myGPT_specific_test_for_resnet
from .GPT_normal_init_for_emb import myGPT_normal_init_for_emb
from .GPT2_init import myGPT2_init
from .GPT2 import myGPT2
from .GPT2_init_prenorm import myGPT2_init_prenorm
from .DNN_averaged import myDNN_averaged
from .DNN_simplified import myDNN_simplified
from .GPT2_init_exact_value import myGPT2_init_exact_value
from .GPT2_init_exact_value_test_value import myGPT2_init_exact_value_test_value
from .GPT2_init_for_attn import myGPT2_init_for_attn
from .GPT2_prenorm_init_for_MLP import myGPT2_prenorm_init_for_MLP
from .GPT2_init_for_diff_part import myGPT2_init_for_diff_part
from .GPT2_init_for_diff_part_prenorm import myGPT2_init_for_diff_part_prenorm
from .GPT_sandwitch import GPT_sandwitchLN

def get_model(args, device):
    """Build the model selected by ``args.model`` and move it to ``device``.

    The chosen class is called as ``cls(args, device)`` and the object
    returned by its ``.to(device)`` method is handed back to the caller.

    Args:
        args: Configuration object; its ``model`` attribute names the
            architecture to build. The whole object is also forwarded to
            the model constructor.
        device: Target device; passed both to the constructor and to
            ``.to(...)``.

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
            (Previously this case surfaced as an ``UnboundLocalError`` on
            the final ``return``.)
    """
    name = args.model

    # Each branch only selects the class; construction happens once below.
    # Class names are resolved lazily, i.e. only in the matching branch.
    if name == 'LSTM':
        model_cls = myLSTM
    elif name == 'GPT':
        model_cls = myGPT
    elif name == 'DNN':
        model_cls = myDNN
    elif name == 'GPT_lightly':
        model_cls = myGPT_lightly
    elif name == 'GPT_specific':
        model_cls = myGPT_specific
    elif name == 'GPT_separate_attn_proj':
        model_cls = myGPT_separate_attn_proj
    elif name == 'GPT_separate_attn_proj_specific':
        model_cls = myGPT_separate_attn_proj_specific
    elif name == 'GPT_softmax10x':
        model_cls = myGPT_softmax10x
    elif name == 'GPT_OneHotEmb':
        model_cls = myGPT_OneHotEmb
    elif name == 'GPT_custom_attn':
        model_cls = myGPT_custom_attn
    elif name == 'GPT_condense':
        model_cls = myGPT_condense
    elif name == 'GPT_condense_attn_only':
        model_cls = myGPT_condense_attn_only
    elif name == 'GPT_condense_attn_only_without_QK':
        model_cls = myGPT_condense_attn_only_without_QK
    elif name == 'GPT_linear_regime':
        model_cls = myGPT_linear
    elif name == 'GPT_linear_regime_only_with_attention_para':
        model_cls = myGPT_linear_for_attn
    elif name == 'GPT_condense_regime_only_with_attention_para':
        model_cls = myGPT_condense_for_attn
    elif name == 'GPT_specific_test_for_resnet':
        model_cls = myGPT_specific_test_for_resnet
    elif name == 'GPT_kaiming_normal':
        model_cls = myGPT_kaiming_normal
    elif name == 'GPT_normal_init':
        model_cls = myGPT_normal_init
    elif name == 'GPT_normal_init_for_emb':
        model_cls = myGPT_normal_init_for_emb
    elif name == 'GPT2_init':
        model_cls = myGPT2_init
    elif name == 'GPT2':
        model_cls = myGPT2
    elif name == 'GPT2_init_prenorm':
        model_cls = myGPT2_init_prenorm
    elif name == 'GPT2_init_for_diff_part':
        model_cls = myGPT2_init_for_diff_part
    elif name == 'GPT2_init_exact_value':
        model_cls = myGPT2_init_exact_value
    elif name == 'GPT2_init_exact_value_test_value':
        model_cls = myGPT2_init_exact_value_test_value
    elif name == 'DNN_averaged':
        model_cls = myDNN_averaged
    elif name == 'DNN_simplified':
        model_cls = myDNN_simplified
    elif name == 'GPT2_prenorm_init_for_MLP':
        model_cls = myGPT2_prenorm_init_for_MLP
    elif name == 'GPT2_init_for_attn':
        model_cls = myGPT2_init_for_attn
    elif name == 'GPT2_init_for_diff_part_prenorm':
        model_cls = myGPT2_init_for_diff_part_prenorm
    elif name == 'GPT_sandwitchLN':
        model_cls = GPT_sandwitchLN
    else:
        # Fail loudly with the offending value instead of falling through
        # to an unbound local at the return statement.
        raise ValueError(f"Unknown model name: {name!r}")

    return model_cls(args, device).to(device)