# Copyright 2022 CircuitNet. All rights reserved.

import models
import torch

def build_model(args, in_feature_dim, out_feature_dim):
    """Instantiate the network selected by ``args.model``.

    Three task names are aliases for a concrete class exported by the
    ``models`` package:

    * ``"congestion"`` -> ``GPDL``
    * ``"DRC"``        -> ``RouteNet``
    * ``"IR_drop"``    -> ``MAVI``

    Any other value of ``args.model`` is used verbatim as the class name
    to look up in ``models``.

    Args:
        args: Object exposing a ``model`` attribute (presumably an
            ``argparse.Namespace``; confirm against the caller).
        in_feature_dim: First positional argument forwarded to the selected
            model constructor.
        out_feature_dim: Second positional argument forwarded to the
            selected model constructor.

    Returns:
        Whatever calling ``models.__dict__[<resolved name>]`` with
        ``(in_feature_dim, out_feature_dim)`` returns.

    Raises:
        KeyError: If the resolved class name is not an entry of
            ``models.__dict__``.
    """
    # Task name -> model class name; names not listed here pass through as-is.
    task_aliases = {
        "congestion": "GPDL",
        "DRC": "RouteNet",
        "IR_drop": "MAVI",
    }
    requested = args.model
    class_name = task_aliases.get(requested, requested)
    model_cls = models.__dict__[class_name]
    return model_cls(in_feature_dim, out_feature_dim)
