#!/usr/bin/env python3
import logging
import os
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace

from softmatcha import stopwatch
from softmatcha.struct import IndexInvertedFileCollection
from softmatcha.tokenizers import get_tokenizer

# Configure the root logger once at import time: INFO-level records are
# written to stderr, prefixed with a timestamp and the level name.
logging.basicConfig(
    stream=sys.stderr,
    level="INFO",
    datefmt="%Y-%m-%d %H:%M:%S",
    format="| %(asctime)s | %(levelname)s | %(message)s",
)
# Logger used by every function in this script.
logger = logging.getLogger("softmatcha.cli.build_inverted_index")


def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """Parse the command-line options of the inverted-index builder.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers
            that pass nothing keep their behavior. Passing an explicit
            list allows programmatic use and testing without touching
            the process-wide ``sys.argv``.

    Returns:
        A namespace with the attributes ``inputs``, ``index``, ``model``,
        ``backend``, ``jsonl_key``, ``num_workers``, ``buffer_size`` and
        ``chunk_size``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    # fmt: off
    parser.add_argument("inputs", type=str, nargs="+",
                        help="Input file paths.")
    parser.add_argument("--index", "-o", type=str, required=True,
                        help="Output index filename.")
    parser.add_argument("--model", default="glove-wiki-gigaword-300",
                        help="Model name or path.")
    parser.add_argument("--backend",
                        choices=["gensim", "transformers", "fasttext"],
                        default="gensim")
    parser.add_argument("--jsonl-key", type=str,
                        help="Specify the JSONL key of texts to be indexed. "
                        "If not specified this option, the inputs will be treated as plain text.")
    parser.add_argument("--num-workers", type=int, default=8,
                        help="Number of workers.")
    parser.add_argument("--buffer-size", type=int, default=10000,
                        help="Buffer size.")
    parser.add_argument("--chunk-size", type=int, default=1024,
                        help="Chunk size of HDF5 storage.")
    # fmt: on
    # argparse itself substitutes sys.argv[1:] when it receives None.
    return parser.parse_args(argv)


def main(args: Namespace) -> None:
    """Build an inverted index from the parsed command-line options.

    Logs the received arguments, resets ``stopwatch.timers`` with
    ``profile=True``, hands the absolute input paths and the remaining
    options to ``IndexInvertedFileCollection.build``, and finally logs
    the timing figures exposed by ``stopwatch.timers``.

    Args:
        args: Namespace providing ``inputs``, ``index``, ``backend``,
            ``model``, ``jsonl_key``, ``num_workers``, ``buffer_size``
            and ``chunk_size``.
    """
    logger.info(args)
    stopwatch.timers.reset(profile=True)

    # Resolve every input to an absolute path before handing it over.
    absolute_inputs = list(map(os.path.abspath, args.inputs))
    tokenizer = get_tokenizer(args.backend)
    IndexInvertedFileCollection.build(
        args.index,
        absolute_inputs,
        tokenizer,
        args.model,
        jsonl_key=args.jsonl_key,
        num_workers=args.num_workers,
        buffer_size=args.buffer_size,
        chunk_size=args.chunk_size,
    )

    # Report the timing figures collected while building.
    timers = stopwatch.timers
    logger.info(f"Elapsed time: {timers.elapsed_time}")
    logger.info(f"Total time: {sum(timers.elapsed_time.values())}")
    logger.info(f"ncalls: {timers.ncalls}")


def cli_main() -> None:
    """Parse the command-line arguments and run ``main`` with them."""
    main(parse_args())


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    cli_main()
