from typing import Optional, Tuple
import torch

from .io import read_hwc, write_hwc
from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
from .draw import draw_bchw, draw_landmarks, draw_landmarks_only_eyes
from .show import show_bchw, show_bhw, get_bchw, get_bhw, get_bhw_no_contour, get_bchw_no_contour, test

from .face_detection import FaceDetector
from .face_parsing import FaceParser
from .face_alignment import FaceAlignment
from .face_attribute import FaceAttribute


def _split_name(name: str) -> Tuple[str, Optional[str]]:
    if '/' in name:
        detector_type, conf_name = name.split('/', 1)
    else:
        detector_type, conf_name = name, None
    return detector_type, conf_name


def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
    """Construct a face detector from a ``'<type>[/<conf_name>]'`` spec.

    Only the ``'retinaface'`` type is recognized. The optional part after the
    first ``'/'`` is passed positionally as the config name, and ``kwargs`` are
    forwarded to the constructor unchanged. Unlike the other factories here,
    ``device`` is not passed to the constructor; the model is moved with
    ``.to(device)`` only.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    detector_type, conf_name = _split_name(name)
    if detector_type != 'retinaface':
        raise RuntimeError(f'Unknown detector type: {detector_type}')
    # Deferred import: the implementation is only loaded when it is requested.
    from .face_detection import RetinaFaceDetector
    detector = RetinaFaceDetector(conf_name, **kwargs)
    return detector.to(device)


def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
    """Construct a face parser from a ``'<type>[/<conf_name>]'`` spec.

    Only the ``'farl'`` type is recognized. The config name is passed
    positionally. ``device`` is passed both as a constructor keyword and to
    ``.to(device)``. Extra ``kwargs`` are forwarded to the constructor.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    parser_type, conf_name = _split_name(name)
    if parser_type != 'farl':
        raise RuntimeError(f'Unknown parser type: {parser_type}')
    # Deferred import: the implementation is only loaded when it is requested.
    from .face_parsing import FaRLFaceParser
    parser = FaRLFaceParser(conf_name, device=device, **kwargs)
    return parser.to(device)


def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
    """Construct a face-alignment model from a ``'<type>[/<conf_name>]'`` spec.

    Only the ``'farl'`` type is recognized. The config name is passed
    positionally. ``device`` is passed both as a constructor keyword and to
    ``.to(device)``. Extra ``kwargs`` are forwarded to the constructor.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    aligner_type, conf_name = _split_name(name)
    if aligner_type != 'farl':
        raise RuntimeError(f'Unknown aligner type: {aligner_type}')
    # Deferred import: the implementation is only loaded when it is requested.
    from .face_alignment import FaRLFaceAlignment
    aligner = FaRLFaceAlignment(conf_name, device=device, **kwargs)
    return aligner.to(device)

def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
    """Construct a face-attribute model from a ``'<type>[/<conf_name>]'`` spec.

    Only the ``'farl'`` type is recognized. The config name is passed
    positionally. ``device`` is passed both as a constructor keyword and to
    ``.to(device)``. Extra ``kwargs`` are forwarded to the constructor.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    attr_type, conf_name = _split_name(name)
    if attr_type != 'farl':
        raise RuntimeError(f'Unknown attribute type: {attr_type}')
    # Deferred import: the implementation is only loaded when it is requested.
    from .face_attribute import FaRLFaceAttribute
    attr_model = FaRLFaceAttribute(conf_name, device=device, **kwargs)
    return attr_model.to(device)