import logging

import torch
import torch.utils.data

from unrealpose.datasets.yolo import YoloDataset
from unrealpose.pose.refine.check2d import parse_video_name
from .yolov3 import yolov3_inference, load_yolov3


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def detection_wrapper(config,
                      video_dataset,
                      input_video_name,
                      output_dir):
    """
    Run person detection for one input video and return its record.

    YOLOv3 is used unless ``config.PIPELINE.DISABLE_PERSON_DETECTION`` is
    truthy, in which case no detector is loaded and an empty dict is returned.

    Args:
        config: Configuration object. This function reads
            ``PIPELINE.DISABLE_PERSON_DETECTION``, ``YOLO.IMG_SIZE`` and
            ``TEST.BATCH_SIZE`` and forwards the object to the YOLOv3 helpers.
        video_dataset: Accepted for interface compatibility; not used here.
        input_video_name: Video name, parsed with ``parse_video_name``; only
            the first element of the parsed result (the basename) is used.
        output_dir: Forwarded unchanged to ``yolov3_inference``.

    Returns:
        Whatever ``yolov3_inference`` returns when detection is enabled,
        otherwise an empty dict.
    """
    detection_disabled = bool(config.PIPELINE.DISABLE_PERSON_DETECTION)

    # The name is parsed before the enabled/disabled decision is acted on, so
    # parse_video_name runs (and may raise) even when detection is disabled.
    video_basename, _, _ = parse_video_name(input_video_name)

    if detection_disabled:
        return {}

    # Load the YOLOv3 model and build the loader that feeds it.
    detector = load_yolov3(config)
    frames = YoloDataset(config.YOLO.IMG_SIZE)
    frame_loader = torch.utils.data.DataLoader(
        frames,
        batch_size=config.TEST.BATCH_SIZE,  # batch for one GPU
        shuffle=False,
        num_workers=1,  # single worker, since the video is read here
        pin_memory=True,
    )

    return yolov3_inference(
        config,
        frame_loader,
        frames,
        video_basename,
        output_dir,
        detector,
    )
