# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from rekognition_online_action_detection.utils.registry import Registry

# Registries mapping a model name to its routine. `do_train` indexes
# TRAINERS and `do_inference` indexes INFERENCES with cfg.MODEL.MODEL_NAME.
TRAINERS, INFERENCES = Registry(), Registry()


def do_train(cfg,
             data_loaders,
             tet_data_loaders_i,
             model,
             criterion,
             optimizer,
             scheduler,
             device,
             checkpointer,
             logger,
             num_tasks,
             ):
    """Dispatch to the training routine registered for the configured model.

    The routine is fetched from ``TRAINERS`` using ``cfg.MODEL.MODEL_NAME``
    as the key. Every argument is then forwarded to it positionally, in the
    order received here. This function does no work of its own.

    Args:
        cfg: Configuration object; ``cfg.MODEL.MODEL_NAME`` selects the
            registered trainer.
        data_loaders: Forwarded unchanged.
        tet_data_loaders_i: Forwarded unchanged. The name suggests test
            loaders for the current task ``i`` (TODO confirm).
        model: Forwarded unchanged.
        criterion: Forwarded unchanged.
        optimizer: Forwarded unchanged.
        scheduler: Forwarded unchanged.
        device: Forwarded unchanged.
        checkpointer: Forwarded unchanged.
        logger: Forwarded unchanged.
        num_tasks: Forwarded unchanged. Presumably the number of tasks in
            a multi-task / incremental setup (verify against callers).

    Returns:
        Whatever the registered trainer returns.
    """
    trainer = TRAINERS[cfg.MODEL.MODEL_NAME]
    # Registered trainers receive arguments positionally, so the order of
    # this tuple must match the trainer signatures exactly.
    forwarded = (
        cfg,
        data_loaders,
        tet_data_loaders_i,
        model,
        criterion,
        optimizer,
        scheduler,
        device,
        checkpointer,
        logger,
        num_tasks,
    )
    return trainer(*forwarded)


def do_inference(cfg,
                 data_loaders_i,
                 model,
                 device,
                 logger,
                 num_task,
                 inferr):
    """Dispatch to the inference routine registered for the configured model.

    The routine is fetched from ``INFERENCES`` using ``cfg.MODEL.MODEL_NAME``
    as the key and called with all arguments forwarded positionally.

    Args:
        cfg: Configuration object; ``cfg.MODEL.MODEL_NAME`` selects the
            registered inference routine.
        data_loaders_i: Forwarded unchanged. Presumably the loaders for
            task ``i`` (TODO confirm).
        model: Forwarded unchanged.
        device: Forwarded unchanged.
        logger: Forwarded unchanged.
        num_task: Forwarded unchanged.
        inferr: Forwarded unchanged. Meaning not determinable from this
            module.

    Returns:
        The registered routine's return value. The original code named it
        ``mean_AP``, suggesting a mean-average-precision score (verify).
    """
    inference_fn = INFERENCES[cfg.MODEL.MODEL_NAME]
    return inference_fn(cfg, data_loaders_i, model, device,
                        logger, num_task, inferr)
