#!/usr/bin/env python3

import json
import logging
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from typing import TypeVar

import numpy as np

from softmatcha import registry, stopwatch
from softmatcha.embeddings import GENSIM_PRETRAINED_MODELS, Embedding, get_embedding
from softmatcha.search import SearchNaive, SearchQuick, SearchScan
from softmatcha.struct import Pattern, TokenEmbeddings
from softmatcha.tokenizers import Tokenizer, get_tokenizer
from softmatcha.utils import io as io_utils

# Type variables bound to the embedding and searcher base classes.
# NOTE(review): neither E nor S is referenced elsewhere in this file — confirm
# whether they are still needed (e.g. by importers) before removing.
E = TypeVar("E", bound=Embedding)
S = TypeVar("S", bound=SearchScan)

# Diagnostics go to stderr so that stdout carries only search results.
_LOG_FORMAT = "| %(asctime)s | %(levelname)s | %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    stream=sys.stderr,
    level="INFO",
    datefmt=_LOG_DATEFMT,
    format=_LOG_FORMAT,
)
logger = logging.getLogger("search")


def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the console entry point uses.

    Returns:
        The parsed options namespace.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    # Every option carries help text because ArgumentDefaultsHelpFormatter only
    # appends "(default: ...)" to arguments that have help.
    # fmt: off
    parser.add_argument("--model", default="glove-wiki-gigaword-300",
                        help="Model name or path.")
    parser.add_argument("--backend",
                        choices=registry.get_registry("embedding").keys(),
                        default="gensim",
                        help="Backend used to build both the embedding "
                        "and the tokenizer.")
    parser.add_argument("--search", choices=["naive", "quick"], default="quick",
                        help="Search method.")
    parser.add_argument("--threshold", type=float, default=0.55,
                        help="Matching threshold passed to the searcher.")
    parser.add_argument("--json", action="store_true",
                        help="Print each matching line as a JSON object "
                        "instead of the raw line.")
    parser.add_argument("--profile", action="store_true",
                        help="Print timing and corpus statistics to stderr.")
    parser.add_argument("--start-position", type=int, default=0,
                        help="Start position to be searched.")
    parser.add_argument("pattern", type=str,
                        help="Pattern string.")
    parser.add_argument("text_file", type=str, nargs="?", default="-",
                        help="Input file.")
    parser.add_argument("--mmap", action="store_true",
                        help="Load the index on disk. "
                        "This option will slow down the search time, "
                        "but will improve memory consumption.")
    # fmt: on
    parser.epilog = f"""
    Available pre-trained gensim models: {GENSIM_PRETRAINED_MODELS}
    """
    return parser.parse_args(argv)


def main(args: Namespace) -> None:
    """Print the input lines that contain a soft match of the pattern.

    The pattern and every input line are tokenized and embedded with the
    backend selected by ``args.backend``. Each line is then scanned by the
    searcher chosen via ``args.search``. A line is emitted at most once: on the
    first match the searcher yields for it. The output is the raw line, or a
    JSON object when ``args.json`` is set.

    Args:
        args: Parsed command-line options (see ``parse_args``).

    Raises:
        ValueError: If the pattern string tokenizes to zero tokens.
    """
    stopwatch.timers.reset(profile=args.profile)

    logger.info(args)

    with stopwatch.timers["load/embedding"]:
        embedding: Embedding = get_embedding(args.backend).build(
            args.model, mmap=args.mmap
        )
    with stopwatch.timers["load/tokenizer"]:
        tokenizer: Tokenizer = get_tokenizer(args.backend).build(args.model)

    pattern_tokens = tokenizer(args.pattern)
    if len(pattern_tokens) == 0:
        # Without this check, np.split(..., 0) below fails with an opaque
        # ZeroDivisionError.
        raise ValueError(f"Pattern {args.pattern!r} produced no tokens.")
    pattern_embeddings = embedding(pattern_tokens)
    # Each pattern position gets a singleton token set and its own chunk of
    # the embedding array, split along the first axis (presumably one row per
    # token — TODO confirm the embedding output shape).
    pattern = Pattern.build(
        [{token} for token in pattern_tokens],
        np.split(pattern_embeddings, len(pattern_embeddings)),
    )

    searcher_classes = {"naive": SearchNaive, "quick": SearchQuick}
    searcher: SearchScan = searcher_classes[args.search](pattern)

    num_tokens = 0
    num_lines = 0
    for i, line in enumerate(io_utils.read_lines(args.text_file)):
        with stopwatch.timers["tokenize"]:
            text_tokens = tokenizer(line)
        text_embeddings = embedding(text_tokens)
        num_tokens += len(text_embeddings)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        for matched in searcher.search(
            text, threshold=args.threshold, start=args.start_position
        ):
            if args.json:
                d = {
                    # 0-based index of the line in the input.
                    "line_number": i,
                    "original_line": line,
                    "matched_tokens": tokenizer.decode(
                        text.tokens[matched.begin : matched.end]
                    ),
                    "score": [float(f"{x:.4f}") for x in matched.scores.tolist()],
                }
                print(json.dumps(d, ensure_ascii=False))
            else:
                print(line)
            # Only the first match per line is reported.
            break
        num_lines += 1

    if args.profile:
        print(f"elapsed_time\t{stopwatch.timers.elapsed_time}", file=sys.stderr)
        print(f"ncalls\t{stopwatch.timers.ncalls}", file=sys.stderr)
        print(f"nlines\t{num_lines:,}", file=sys.stderr)
        print(f"ntokens\t{num_tokens:,}", file=sys.stderr)
        # Empty input would otherwise raise ZeroDivisionError here.
        tokens_per_line = num_tokens / num_lines if num_lines else 0.0
        print(f"ntokens/sentence\t{tokens_per_line:.1f}", file=sys.stderr)
        if isinstance(searcher, SearchQuick):
            print(f"table_size\t{len(searcher.shift_table)}", file=sys.stderr)


def cli_main() -> None:
    """Console entry point: parse the process arguments and run the search."""
    main(parse_args())


if __name__ == "__main__":
    # Support direct execution as a script in addition to the entry point.
    cli_main()
