from pathlib import Path

from .dataset import OneToManyDataset, SplitDataset, TripleDataset
from .ogb_dataset import OGBDataset, TripleDatasetWithNeg

# Public API of this module. Every dataset class that appears in the return
# type of an exported factory function is exported too, so callers can name
# those types without importing from the submodules.
__all__ = [
    "OGBDataset",
    "OneToManyDataset",
    "SplitDataset",
    "TripleDataset",
    "TripleDatasetWithNeg",
    "get_one_to_many_dataset",
    "get_triple_dataset",
]


def get_triple_dataset(
    dataset_name: str,
    data_folder: Path,
    *,
    add_inverse: bool = False,
) -> SplitDataset[TripleDataset] | OGBDataset:
    """Load a triple dataset, either via ``OGBDataset`` or from local TSV splits.

    Args:
        dataset_name: Name of the dataset. Names starting with ``"ogb"`` are
            handed to ``OGBDataset``; any other name is used as a
            sub-directory of ``data_folder``.
        data_folder: Folder passed to ``OGBDataset``, or the parent of the
            ``dataset_name`` directory that holds the TSV files.
        add_inverse: Forwarded unchanged to ``OGBDataset`` or ``SplitDataset``.

    Returns:
        An ``OGBDataset`` for "ogb"-prefixed names; otherwise a
        ``SplitDataset`` whose train/valid/test parts are read from
        ``train.tsv``, ``valid.tsv`` and ``test.tsv`` respectively.
    """
    if not dataset_name.startswith("ogb"):
        # Local dataset: the three split files share one directory.
        split_dir = data_folder / dataset_name
        train_part = TripleDataset.from_tsv(split_dir / "train.tsv", split="train")
        valid_part = TripleDataset.from_tsv(split_dir / "valid.tsv", split="valid")
        test_part = TripleDataset.from_tsv(split_dir / "test.tsv", split="test")
        return SplitDataset(
            train=train_part,
            valid=valid_part,
            test=test_part,
            add_inverse=add_inverse,
        )

    return OGBDataset(dataset_name, data_folder, add_inverse=add_inverse)


def get_one_to_many_dataset(
    dataset_name: str,
    num_entities: int,
    data_folder: Path,
    *,
    add_inverse: bool = False,
) -> SplitDataset[OneToManyDataset]:
    """Load the train/valid/test splits of a 1-to-N dataset from TSV files.

    Args:
        dataset_name: Name of the dataset, used as a sub-directory of
            ``data_folder``. Names starting with ``"ogb"`` are rejected.
        num_entities: Forwarded to ``OneToManyDataset.from_tsv`` for every
            split.
        data_folder: Parent of the ``dataset_name`` directory that holds
            ``train.tsv``, ``valid.tsv`` and ``test.tsv``.
        add_inverse: Forwarded unchanged to ``SplitDataset``.

    Returns:
        A ``SplitDataset`` wrapping one ``OneToManyDataset`` per split.

    Raises:
        NotImplementedError: If ``dataset_name`` starts with ``"ogb"``.
    """
    is_ogb = dataset_name.startswith("ogb")
    if is_ogb:
        raise NotImplementedError("OGB datasets are not yet supported for 1-to-N datasets.")

    # The three split files share one directory.
    split_dir = data_folder / dataset_name
    train_part = OneToManyDataset.from_tsv(
        split_dir / "train.tsv", split="train", num_entities=num_entities
    )
    valid_part = OneToManyDataset.from_tsv(
        split_dir / "valid.tsv", split="valid", num_entities=num_entities
    )
    test_part = OneToManyDataset.from_tsv(
        split_dir / "test.tsv", split="test", num_entities=num_entities
    )
    return SplitDataset(
        train=train_part,
        valid=valid_part,
        test=test_part,
        add_inverse=add_inverse,
    )
