from libs.fttransformer import FTTransformer

def getmodel_ad(tasktype, dataset, input_dim, output_dim, args, device):
    """Build an FTTransformer from default hyperparameters plus run options.

    Starts from a fixed set of default hyperparameters and layers the run
    options read from ``args`` on top. It then applies dataset-specific
    tweaks when ``args.dataname`` equals ``'census'`` or ``'InternetAds'``.
    Finally it passes the resulting dict to ``FTTransformer``.

    Args:
        tasktype: Forwarded unchanged to ``FTTransformer``.
        dataset: Object exposing ``X_num`` and ``X_cat_cardinality``; both
            are forwarded positionally to ``FTTransformer``.
        input_dim: Forwarded to ``FTTransformer`` as ``input_dim=``.
        output_dim: Forwarded to ``FTTransformer`` as ``output_dim=``.
        args: Namespace-like object read for ``save_model``, ``fusion_num``,
            ``fusion_method``, ``dataname``, ``aug``, ``k_aug`` and
            ``score_method``.
        device: Forwarded to ``FTTransformer`` as ``device=``.

    Returns:
        The constructed ``FTTransformer`` instance.
    """
    hyper = {
        # tokenizer
        'token_bias': True,
        # transformer
        'n_layers': 6,
        'd_token': 24,
        'n_heads': 8,
        'd_ffn_factor': 4 / 3,
        'attention_dropout': 0.2,
        'ffn_dropout': 0.1,
        'residual_dropout': 0.0,
        'activation': 'reglu',
        'prenormalization': True,
        'initialization': 'kaiming',
        # linformer
        'kv_compression': None,
        'kv_compression_sharing': None,
        # optimizer
        'learning_rate': 1e-4,
        'lr_scheduler': False,
        'weight_decay': 1e-5,
        'n_epochs': 300,
        'early_stopping_rounds': 20,
        'optimizer': 'AdamW',
    }

    # Run-specific settings supplied by the caller, appended after the defaults.
    hyper.update(
        batch_size=1,  # 1 -> use get_batch_size from data.py
        save_model_path=args.save_model,
        fusion_num=args.fusion_num,
        fusion_method=args.fusion_method,
        dataname=args.dataname,
        aug=args.aug,
        k_aug=args.k_aug,
        score_method=args.score_method,
    )

    # Per-dataset tweaks: a smaller token width (8 vs. the default 24) and
    # fewer heads (4 vs. 8). InternetAds additionally pins batch_size to 64.
    dataset_overrides = (
        ('census', {'d_token': 8, 'n_heads': 4}),
        ('InternetAds', {'d_token': 8, 'n_heads': 4, 'batch_size': 64}),
    )
    for name, overrides in dataset_overrides:
        if args.dataname == name:
            hyper.update(overrides)

    return FTTransformer(
        hyper,
        tasktype,
        dataset.X_num,
        dataset.X_cat_cardinality,
        input_dim=input_dim,
        output_dim=output_dim,
        device=device,
    )
