import logging

import torch
import torch.utils.data

from unrealpose.pose.refine.check2d import parse_video_name
from .core import pose2d_inference, load_pose2d_model


logger = logging.getLogger(__name__)


def pose2d_wrapper(config,
                   video_dataset,
                   input_video_name,
                   output_dir,
                   detection_record=None):
    """
    Run 2D pose inference over every sample of ``video_dataset``.

    The pose model(s) are loaded from ``config`` alone via
    ``load_pose2d_model``. This function does not use the video name to
    pick a model; it only derives the video basename from it and forwards
    that basename to ``pose2d_inference``.

    Args:
        config: Experiment config. ``config.TEST.BATCH_SIZE`` sets the
            DataLoader batch size, and the whole object is also passed to
            ``load_pose2d_model`` and ``pose2d_inference``.
        video_dataset: Dataset to run inference on. It is wrapped in a
            ``torch.utils.data.DataLoader`` and iterated in dataset order
            (``shuffle=False``). It is also passed to ``pose2d_inference``.
        input_video_name: Video name parsed by ``parse_video_name``. Only
            the first of its three returned values (the basename) is used.
        output_dir: Forwarded unchanged to ``pose2d_inference``.
        detection_record: Optional detection results, forwarded unchanged
            to ``pose2d_inference``.

    Returns:
        The value returned by ``pose2d_inference`` (the pose record).
    """
    # Only the basename is needed here; the other parsed fields are unused.
    input_video_basename, _, _ = parse_video_name(input_video_name)

    model_dict = load_pose2d_model(config)

    # num_workers is intentionally left at the DataLoader default (0), so
    # loading happens in the main process; any positive value would start
    # worker processes. shuffle=False preserves dataset order (presumably
    # frame order — confirm against the dataset implementation).
    pose_loader = torch.utils.data.DataLoader(
        video_dataset,
        batch_size=config.TEST.BATCH_SIZE,  # batch for one GPU
        shuffle=False,
        pin_memory=True)

    pose_record = pose2d_inference(config, pose_loader, video_dataset,
                                   input_video_basename, output_dir,
                                   model_dict, detection_record)
    return pose_record
