from typing import Optional, Callable
from pathlib import Path

from .sft_tvc.dataset import TvcgDataset


def get_dataset_name(ann_path: str):
    """Infer the dataset name from an annotation path.

    The lookup is a case-insensitive substring match on ``ann_path``.
    Keywords are tried in order, so a path containing both "clevr" and
    "tvc" resolves to "clevr".

    Args:
        ann_path: Path (or any string) identifying the annotation file.

    Returns:
        ``"clevr"`` or ``"tvc"``.

    Raises:
        NotImplementedError: If ``ann_path`` contains neither keyword.
    """
    lowered = ann_path.lower()
    # Order matters: "clevr" takes precedence over "tvc".
    for known_name in ("clevr", "tvc"):
        if known_name in lowered:
            return known_name
    raise NotImplementedError(f"Invalid ann_path data_type: {ann_path}")


def get_dataset(
    ann_path: str,
    image_dir: Optional[str] = None,
    baseline_opts: str = "",
    max_image_size: int = 672,
    postprocess_fn: Optional[Callable] = None,
    debug: bool = False,
):
    """Build the dataset object that matches an annotation path.

    Only paths containing "tvc" (case-insensitive) are supported; they
    produce a ``TvcgDataset``.

    Args:
        ann_path: Path to the annotation file. Also used to pick the
            dataset type and, when ``image_dir`` is omitted, to locate the
            image directory.
        image_dir: Directory holding the images. When ``None`` it defaults
            to ``<ann_path>/../../images`` (an ``images`` folder next to
            the annotation file's parent directory), which must exist.
        baseline_opts: Forwarded unchanged to ``TvcgDataset``.
        max_image_size: Forwarded unchanged to ``TvcgDataset``.
        postprocess_fn: Forwarded unchanged to ``TvcgDataset``.
        debug: Forwarded unchanged to ``TvcgDataset``.

    Returns:
        A ``TvcgDataset`` instance.

    Raises:
        NotADirectoryError: If ``image_dir`` is ``None`` and the derived
            default image directory does not exist.
        NotImplementedError: If ``ann_path`` does not name a supported
            dataset type.
    """
    if image_dir is None:
        _image_dir = Path(ann_path).parent.parent / "images"
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let a bad path slip through silently.
        if not _image_dir.is_dir():
            raise NotADirectoryError(f"invalid image dir: {_image_dir}")
        image_dir = str(_image_dir)

    if "tvc" not in ann_path.lower():
        raise NotImplementedError(f"Invalid ann_path data_type: {ann_path}")

    return TvcgDataset(
        ann_path,
        image_dir,
        max_image_size=max_image_size,
        baseline_opts=baseline_opts,
        postprocess_fn=postprocess_fn,
        debug=debug,
    )
