


from rekognition_online_action_detection.utils.registry import Registry

# Registries consulted by the dispatch helpers below, keyed by
# cfg.MODEL.MODEL_NAME: TRAINERS by do_train, INFERENCES by do_inference.
# NOTE(review): entries are registered elsewhere (not visible in this file).
TRAINERS, INFERENCES = Registry(), Registry()


def do_train(cfg,
             batch_size,
             memory_video_ratio,
             memory_frame_ratio,
             data_loaders,
             tet_data_loaders_i,
             model,
             criterion,
             optimizer,
             scheduler,
             device,
             checkpointer,
             logger,
             num_tasks,
             m,
             get_memory
             ):
    """Run the trainer registered for ``cfg.MODEL.MODEL_NAME``.

    Looks up the training routine in ``TRAINERS`` by the configured model
    name and forwards every argument to it positionally, in the same order
    as this function's signature.

    Args:
        cfg: Config object; only ``cfg.MODEL.MODEL_NAME`` is read here, to
            select the trainer. The full object is passed through.
        get_memory: Flag forwarded to the trainer. The previous version
            branched on it, but both branches made the identical call.
            Presumably it makes the trainer return selected features for
            memory instead of training output (TODO confirm against the
            registered trainer).
        All other arguments are passed through unchanged to the trainer.

    Returns:
        Whatever the registered trainer returns.
    """
    trainer = TRAINERS[cfg.MODEL.MODEL_NAME]
    # One call covers both the get_memory and non-get_memory cases. The
    # trainer itself decides what to return based on ``get_memory``.
    return trainer(
        cfg,
        batch_size,
        memory_video_ratio,
        memory_frame_ratio,
        data_loaders,
        tet_data_loaders_i,
        model,
        criterion,
        optimizer,
        scheduler,
        device,
        checkpointer,
        logger,
        num_tasks,
        m,
        get_memory,
    )


def do_inference(cfg,
                 data_loaders_i,
                 model,
                 device,
                 logger, num_task,
                 inferr):
    """Run the inference routine registered for ``cfg.MODEL.MODEL_NAME``.

    The routine is fetched from ``INFERENCES`` by the configured model name.
    It is called positionally with
    ``(cfg, data_loaders_i, model, device, logger, num_task, inferr)``.

    Returns:
        The routine's result, unchanged. The previous code named it
        ``mean_AP``, so it is presumably a mean average precision score
        (TODO confirm against the registered routine).
    """
    run_inference = INFERENCES[cfg.MODEL.MODEL_NAME]
    return run_inference(cfg, data_loaders_i, model, device, logger,
                         num_task, inferr)
