import torch
from model.dilated_graph import DilatedGraph

def get_optim(args, generative_model):
    """Build the AdamW optimizer used to train ``generative_model``.

    Args:
        args: Configuration object; only ``args.lr`` (the learning rate)
            is read here.
        generative_model: Object whose ``parameters()`` result is handed
            to the optimizer.

    Returns:
        A ``torch.optim.AdamW`` instance with AMSGrad enabled and a fixed
        weight decay of ``1e-12``.
    """
    trainable_params = generative_model.parameters()
    # The weight decay is hard-coded; it is not taken from ``args``.
    return torch.optim.AdamW(
        trainable_params,
        lr=args.lr,
        weight_decay=1e-12,
        amsgrad=True,
    )

def get_model(args, device):
    """Construct the ``DilatedGraph`` network and move it to ``device``.

    Args:
        args: Configuration object; ``args.depth`` and ``args.hidden_size``
            are forwarded to the ``DilatedGraph`` constructor.
        device: Target passed to the model's ``.to()`` call.

    Returns:
        The result of calling ``.to(device)`` on the new ``DilatedGraph``.

    Note:
        ``x_input_size``, ``h_input_size`` and ``num_heads`` are fixed at
        3, 11 and 16. ``args.kernel_size`` and ``args.dilation`` are not
        forwarded to the constructor.
    """
    model_kwargs = {
        "x_input_size": 3,
        "h_input_size": 11,
        "depth": args.depth,
        "hidden_size": args.hidden_size,
        "num_heads": 16,
    }
    return DilatedGraph(**model_kwargs).to(device)

