import torch
import torch.optim as optim

from .utility import *
from .loader import *
from .train import *
from ..models.ours import MASH

def select_model(args, A, X, Y):
    """Construct a ``MASH`` model sized to the supplied data.

    Args:
        args: Config object; ``args.hid`` is forwarded as ``c_hid`` and
            ``args.dr`` as ``dropout``.
        A: Unused by this function.
        X: Tensor whose ``shape`` must unpack into exactly three values; the
            second and third are used as the node count and feature count.
        Y: Label tensor. The class count is ``max(Y) + 1``, which presumes
            0-based contiguous integer labels — TODO confirm with the loader.

    Returns:
        A newly constructed ``MASH`` instance.
    """
    # Exactly three targets: a non-3-D ``X`` raises ValueError here.
    _, nodes, features = X.shape
    classes = torch.max(Y).item() + 1

    return MASH(
        n_node=nodes,
        c_in=features,
        c_hid=args.hid,
        n_class=classes,
        dropout=args.dr,
    )
        
def select_optimizer(args, model):
    """Return an Adam optimizer over every parameter of ``model``.

    The learning rate is read from ``args.lr`` and the weight decay from
    ``args.wd``.
    """
    params = model.parameters()
    return optim.Adam(
        params,
        lr=args.lr,
        weight_decay=args.wd,
    )

def select_trainer(args, model, optimizer, dl_tr, dl_te):
    """Bundle config, model, optimizer and data loaders into a ``MASH_Trainer``.

    NOTE(review): ``dl_te`` is supplied for both the fifth and sixth positional
    arguments of ``MASH_Trainer``. If one of those slots is a validation loader
    used for model selection / early stopping (confirm against the trainer's
    signature), selection is effectively being done on test data.
    """
    loaders = (dl_tr, dl_te, dl_te)
    return MASH_Trainer(args, model, optimizer, *loaders)