import argparse
import logging
from pathlib import Path

import pandas as pd
from pykeen.datasets import Hetionet

# Logger for this module; records carry the module's import name.
logger = logging.getLogger(__name__)
# Script default: attach a basic handler to the root logger and show INFO and above.
logging.basicConfig(level=logging.INFO)


def _write_tsv(frame: pd.DataFrame, path: Path) -> None:
    """Write *frame* to *path* as tab-separated values with no header and no index."""
    frame.to_csv(path, sep="\t", index=False, header=False)


def download_and_split_hetionet(
    output_dir: str,
    seed: int = 42,
):
    """Download Hetionet via PyKEEN and write its train/validation/test splits as TSV.

    Files written to ``output_dir`` (all tab-separated, no header, no index):

    * ``train.tsv``, ``valid.tsv``, ``test.tsv`` -- one ``head relation tail``
      triple of integer IDs per line, taken from the PyKEEN split.
    * ``entity_to_id.tsv``, ``relation_to_id.tsv`` -- ``label id`` pairs
      sorted by ID. The integer IDs in the split files are only meaningful
      relative to these mappings, so they are saved alongside the splits.

    Args:
        output_dir: Directory to save the TSV files. It is created (including
            parents) if it does not exist; a ``pathlib.Path`` is accepted too.
        seed: Random seed passed to PyKEEN as ``random_state`` so the split
            is reproducible.

    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # PyKEEN loads Hetionet and produces the three splits; ``random_state``
    # controls how the triples are partitioned.
    dataset = Hetionet(random_state=seed)

    split_factories = {
        "train": dataset.training,
        "valid": dataset.validation,
        "test": dataset.testing,
    }
    split_frames = {}
    for split_name, factory in split_factories.items():
        # ``mapped_triples`` holds integer-ID triples (a torch tensor), so
        # convert to NumPy before handing it to pandas.
        frame = pd.DataFrame(
            factory.mapped_triples.numpy(),
            columns=["head", "relation", "tail"],
        )
        _write_tsv(frame, output_path / f"{split_name}.tsv")
        split_frames[split_name] = frame

    # Without these mappings the integer IDs written above cannot be traced
    # back to Hetionet entity/relation labels.
    label_maps = {
        "entity_to_id": dataset.training.entity_to_id,
        "relation_to_id": dataset.training.relation_to_id,
    }
    for map_name, label_to_id in label_maps.items():
        rows = sorted(label_to_id.items(), key=lambda pair: pair[1])
        _write_tsv(
            pd.DataFrame(rows, columns=["label", "id"]),
            output_path / f"{map_name}.tsv",
        )

    logger.info("Splits saved to %s", output_dir)
    logger.info("Number of entities: %s", dataset.num_entities)
    logger.info("Number of relations: %s", dataset.num_relations)
    logger.info("Train size: %s", len(split_frames["train"]))
    logger.info("Validation size: %s", len(split_frames["valid"]))
    logger.info("Test size: %s", len(split_frames["test"]))


if __name__ == "__main__":
    # Command-line entry point: parse options, then run the export.
    cli = argparse.ArgumentParser(description="Download and split Hetionet dataset")
    cli.add_argument(
        "--output_dir",
        default="data/processed/Hetionet",
        type=str,
        help="Directory to save the TSV files",
    )
    cli.add_argument(
        "--seed",
        default=42,
        type=int,
        help="Random seed for reproducibility",
    )
    cli_args = cli.parse_args()

    download_and_split_hetionet(output_dir=cli_args.output_dir, seed=cli_args.seed)
