"""
Model construction.
"""

from argparse import ArgumentParser
from models.output_decoder import *


def make_decoder(args: ArgumentParser, embedding_model: nn.Module) -> nn.Module:
    r"""Stack the dataset-appropriate output head on top of an embedding model.

    The head is chosen from ``args.dataset_name``:

    * ``"ZINC"``, ``"ZINC_full"``, ``"StructureCounting"``, ``"QM9"`` ->
      ``GraphRegression`` using ``args.pooling_method``.
    * ``"count_cycle"``, ``"count_graphlet"`` -> ``NodeRegression``.
    * any other dataset name -> ``GraphClassification`` with
      ``args.out_channels`` outputs, using ``args.pooling_method``.

    Args:
        args (ArgumentParser): Parsed command-line arguments, read by
            attribute (presumably the ``argparse.Namespace`` returned by
            ``parse_args`` despite the annotation; confirm against callers).
            ``dataset_name`` is always read. ``pooling_method`` is read for
            graph-level heads, and ``out_channels`` only for classification.
        embedding_model (nn.Module): Graph representation model whose output
            the head consumes, typically a GNN producing node representations.

    Returns:
        nn.Module: The embedding model wrapped by the selected output head.
    """
    dataset = args.dataset_name

    # Graph-level regression targets.
    if dataset in ("ZINC", "ZINC_full", "StructureCounting", "QM9"):
        return GraphRegression(embedding_model, pooling_method=args.pooling_method)

    # Node-level regression targets (no pooling involved).
    if dataset in ("count_cycle", "count_graphlet"):
        return NodeRegression(embedding_model)

    # Fallback: treat every other dataset as graph classification.
    return GraphClassification(
        embedding_model,
        out_channels=args.out_channels,
        pooling_method=args.pooling_method,
    )



