#!/usr/bin/env python3
import json
import logging
import re
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from collections import defaultdict
from typing import TypeVar

import numpy as np
import termcolor

from softmatcha import registry, stopwatch
from softmatcha.embeddings import GENSIM_PRETRAINED_MODELS, Embedding, get_embedding
from softmatcha.search import Search, SearchIndexInvertedFile
from softmatcha.struct import Pattern
from softmatcha.struct.index_inverted import (
    IndexInvertedFile,
    IndexInvertedFileCollection,
)
from softmatcha.tokenizers import Tokenizer, get_tokenizer

# Type variables bounded by the Embedding and Search base classes.
# NOTE(review): neither is referenced by the code visible in this file --
# confirm whether other modules import them before removing.
E = TypeVar("E", bound=Embedding)
S = TypeVar("S", bound=Search)


def parse_args() -> Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    ``--index`` is mandatory: the search always loads an index collection
    from that path, so a missing value is reported as a usage error here
    instead of surfacing later as a failure to load ``None``.

    Returns:
        The parsed options as an :class:`argparse.Namespace`.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    # fmt: off
    parser.add_argument("pattern", type=str,
                        help="Pattern string.")
    parser.add_argument("--index", type=str, required=True,
                        help="Index file.")
    parser.add_argument("--backend",
                        choices=registry.get_registry("embedding").keys(),
                        default="gensim",
                        help="Embedding backend.")
    parser.add_argument("--model", default="glove-wiki-gigaword-300",
                        help="Model name or path.")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Threshold for word matching.")
    parser.add_argument("--context-size", type=int, default=-1,
                        help="Context size to show.")
    parser.add_argument("--start-position", type=int, default=0,
                        help="Start position to be searched.")
    parser.add_argument("--mmap", action="store_true",
                        help="Load the index on disk. "
                        "This option will slow down the search time, "
                        "but will improve memory consumption.")
    # fmt: on
    output_parser = parser.add_argument_group("Output options")
    # fmt: off
    output_parser.add_argument("--json", action="store_true",
                               help="Print each match as a JSON object.")
    output_parser.add_argument("--profile", action="store_true",
                               help="Profile the runtime statistics.")
    output_parser.add_argument("--log", nargs="?", const="-",
                               help="Output log to stderr or the specified file.")
    output_parser.add_argument("--line-number", "-n", action="store_true",
                               help="Print line number with output lines.")
    output_parser.add_argument("--only-matching", "-o", action="store_true",
                               help="Output only matched patterns.")
    output_parser.add_argument("--quiet", "-q", action="store_true",
                               help="No output.")
    # fmt: on
    parser.epilog = f"""
    Available pre-trained gensim models: {GENSIM_PRETRAINED_MODELS}
    """
    return parser.parse_args()


def search_lines(
    pattern: Pattern,
    file_path: str,
    file_index: IndexInvertedFile,
    searcher: SearchIndexInvertedFile,
    tokenizer: Tokenizer,
    threshold: float = 0.5,
    start_position: int = 0,
    line_number: bool = False,
    only_matching: bool = False,
):
    """Print the lines of ``file_path`` that contain a match, grep-style.

    Every hit reported by ``searcher`` is grouped under the line number that
    ``file_index`` reports for its start position.  Each such line is then
    read back from ``file_path`` at the offset given by
    ``file_index.get_byte_offset`` and printed with the matched text in red.

    Args:
        pattern: Pattern handed to ``searcher.search``.
        file_path: Path of the text file the index was built from.
        file_index: Index providing the token ids, line numbers and offsets.
        searcher: Searcher that yields the hits.
        tokenizer: Tokenizer whose ``tokens`` table maps a token id to text.
        threshold: Matching threshold forwarded to the searcher.
        start_position: Start position forwarded to the searcher.
        line_number: Prefix each output line with its 1-based line number.
        only_matching: Print only the matched spans, one per output line.

    Returns:
        The number of lines printed, or, when ``only_matching`` is set, the
        number of hits found on those lines.
    """
    # Bucket every hit under the line it starts on, keeping discovery order.
    hits_by_line: dict = {}
    for hit in searcher.search(pattern, threshold=threshold, start=start_position):
        line_key = file_index.get_line_number(hit.begin)
        hits_by_line.setdefault(line_key, []).append(hit)

    # Regex-escaped surface form of each token of each hit (for highlight).
    escaped_hits: list[list[str]] = []
    for hits in hits_by_line.values():
        for hit in hits:
            token_ids = file_index.tokens[hit.begin : hit.end].tolist()
            escaped_hits.append(
                [re.escape(tokenizer.tokens[token_id]) for token_id in token_ids]
            )

    # One capture group per pattern position, holding every distinct token
    # seen at that position.  Longer alternatives come first so the regex
    # does not settle for a shorter prefix.
    slot_groups = []
    for column in zip(*escaped_hits):
        longest_first = sorted(set(column), key=len, reverse=True)
        slot_groups.append("(" + "|".join(longest_first) + ")")
    highlight_re = re.compile(r"(\s*)".join(slot_groups), flags=re.IGNORECASE)

    def _paint_red(found: re.Match) -> str:
        return termcolor.colored(found.group(), "red")

    total = 0
    with open(file_path) as source:
        for line_key, hits in hits_by_line.items():
            if line_number:
                prefix = termcolor.colored(f"{line_key + 1}:", "green")
            else:
                prefix = ""

            source.seek(file_index.get_byte_offset(line_key))
            text = source.readline().rstrip()

            if not only_matching:
                print(prefix + highlight_re.sub(_paint_red, text))
                total += 1
                continue

            # Emit one span per hit, scanning left to right along the line.
            cursor = 0
            for _ in hits:
                found = highlight_re.search(text, pos=cursor)
                if found:
                    cursor = found.end()
                    print(prefix + _paint_red(found))
                total += 1

    return total


def main(args: Namespace) -> None:
    """Search every file of an inverted index collection for a pattern.

    Builds the embedding and tokenizer selected by ``args.backend`` and
    ``args.model``, loads the index collection from ``args.index`` and runs
    the search on each indexed file.  The output mode is picked in this
    order: ``quiet`` (count only), ``json`` (one JSON object per match),
    ``context_size >= 0`` (a token window per match), otherwise grep-like
    line output through :func:`search_lines`.

    Args:
        args: Parsed command-line options (see :func:`parse_args`).
    """
    stopwatch.timers.reset(profile=args.profile)

    if getattr(args, "log", None):
        logging_config_kwargs = {}
        if args.log == "-":
            # "--log" given without a value: log to stderr.
            logging_config_kwargs["stream"] = sys.stderr
        else:
            logging_config_kwargs["filename"] = args.log
        logging.basicConfig(
            format="| %(asctime)s | %(levelname)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            level="INFO",
            force=True,
            **logging_config_kwargs,
        )
    else:
        # No "--log": silence every record up to and including ERROR.
        logging.disable(logging.ERROR)

    logger = logging.getLogger("search")

    logger.info(args)

    with stopwatch.timers["load/embedding"]:
        embedding: Embedding = get_embedding(args.backend).build(
            args.model, mmap=args.mmap
        )
    with stopwatch.timers["load/tokenizer"]:
        tokenizer: Tokenizer = get_tokenizer(args.backend).build(args.model)

    # "--context-size" defaults to -1, which only means "do not use the
    # context output mode".  The JSON mode still slices a window, and a
    # negative size would turn the slice into [begin + 1 : end - 1], cutting
    # off the first and last matched token.  Clamp to zero for slicing.
    context_size = max(args.context_size, 0)

    indexes = IndexInvertedFileCollection.load(args.index, mmap=args.mmap)
    for file_path, file_index in zip(indexes.paths, indexes.indexes):
        searcher = SearchIndexInvertedFile(file_index, embedding)
        logger.info(f"Search: {file_path}")

        def _search(pattern_str: str) -> int:
            """Search the current file and return the number of results."""
            pattern_tokens = tokenizer(pattern_str)
            pattern_embeddings = embedding(pattern_tokens)
            pattern = Pattern.build(
                [{token} for token in pattern_tokens],
                np.split(pattern_embeddings, len(pattern_embeddings)),
            )
            logger.info(f"Pattern length: {len(pattern):,}")

            def _matches():
                return searcher.search(
                    pattern, threshold=args.threshold, start=args.start_position
                )

            def _context_ids(matched) -> list:
                # Token ids of the match widened by context_size on both
                # sides, clipped to the bounds of the indexed token array.
                begin = max(matched.begin - context_size, 0)
                end = min(matched.end + context_size, len(file_index.tokens))
                return file_index.tokens[begin:end].tolist()

            num_matched = 0

            if args.quiet:
                for _ in _matches():
                    num_matched += 1
            elif args.json:
                for matched in _matches():
                    d = {
                        "context": tokenizer.decode(_context_ids(matched)),
                        "matched_tokens": tokenizer.decode(
                            file_index.tokens[matched.begin : matched.end].tolist()
                        ),
                        # Round-trip through a string to keep four decimals.
                        "score": [float(f"{x:.4f}") for x in matched.scores.tolist()],
                    }
                    print(json.dumps(d, ensure_ascii=False))
                    num_matched += 1
            elif args.context_size >= 0:
                for matched in _matches():
                    matched_context = tokenizer.decode(_context_ids(matched))
                    print(" ".join(matched_context))
                    num_matched += 1
            else:
                num_matched = search_lines(
                    pattern,
                    file_path,
                    file_index,
                    searcher,
                    tokenizer,
                    threshold=args.threshold,
                    start_position=args.start_position,
                    line_number=args.line_number,
                    only_matching=args.only_matching,
                )

            return num_matched

        num_matched = _search(args.pattern)
        logger.info(f"Number of matched lines: {num_matched}")

    if args.profile:
        print(f"elapsed_time\t{stopwatch.timers.elapsed_time}", file=sys.stderr)
        print(f"ncalls\t{stopwatch.timers.ncalls}", file=sys.stderr)


def cli_main() -> None:
    """Console entry point: parse the command line and run the search."""
    main(parse_args())


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    cli_main()
