from __future__ import annotations

from typing import Generator

import numpy as np

from softmatcha import stopwatch
from softmatcha.embeddings import Embedding
from softmatcha.modules import softset_membership_batch
from softmatcha.struct import IndexInvertedFile, Pattern, SparseVector
from softmatcha.typing import Vector

from .base import Search, SearchIndex


class SearchIndexInvertedFile(SearchIndex):
    """SearchIndex class to find the given pattern from a text using the index.

    Args:
        index (IndexInvertedFile): An inverted-file index to be used for
          searching the text quickly.
        embedding (Embedding): Embeddings.
    """

    def __init__(self, index: IndexInvertedFile, embedding: Embedding) -> None:
        super().__init__(index, embedding)
        # Gather the embeddings of the index vocabulary once, so that
        # `compute_similarity` does not have to re-index them on every search.
        self.vocabulary_embeddings = embedding.embeddings[self.index.vocabulary]

    # Narrows the attribute type for type checkers (the attribute itself is
    # presumably assigned by `SearchIndex.__init__` -- confirm in the base class).
    index: IndexInvertedFile

    def compute_exact_match(self, pattern: Pattern, embedding: Embedding) -> Vector:
        """Compute the exact-match indicator between pattern and vocabulary.

        Args:
            pattern (Pattern): Pattern token sets and their embeddings of shape `(P, D)`.
              Only `pattern.tokens` is read here.
            embedding (Embedding): Embedding. Only its length `V` is used here.

        Returns:
            Vector: Match matrix of shape `(P, V)`,
              where the matched element is set to 1, otherwise 0.
        """
        scores = np.zeros((len(pattern), len(embedding)), dtype=np.float32)
        with stopwatch.timers["membership"]:
            for i, p in enumerate(pattern.tokens):
                # `p` is the collection of token ids accepted at pattern position `i`.
                scores[i, list(p)] = 1.0
        return scores

    def compute_similarity(self, pattern: Pattern, embedding: Embedding) -> Vector:
        """Compute the similarity between pattern and vocabulary.

        Only the columns listed in `self.index.vocabulary` are filled in;
        every other column keeps its initial value of 0.

        Args:
            pattern (Pattern): Pattern token sets and their embeddings of shape `(P, D)`.
            embedding (Embedding): Embedding. Only its length `V` is used here.

        Returns:
            Vector: Similarity matrix of shape `(P, V)`.
        """
        scores = np.zeros((len(pattern), len(embedding)), dtype=np.float32)
        with stopwatch.timers["membership"]:
            scores[:, self.index.vocabulary] = softset_membership_batch(
                pattern.embeddings, self.vocabulary_embeddings
            )
        return scores

    @stopwatch.timers("search", generator=True)
    def search(
        self, pattern: Pattern, threshold: float = 0.5, start: int = 0
    ) -> Generator[Search.Match, None, None]:
        """Search for the pattern from the given text.

        Args:
            pattern (Pattern): Pattern token set embeddings.
            threshold (float): Threshold for matched scores.
              The range of matched scores are `[0,1]`, and if `score >= threshold`,
              the text token is regarded as matched the pattern token.
              When `threshold >= 1.0`, exact matching is used instead of
              the embedding similarity.
            start (int): Start position to be searched.
              Matches that begin before this position are skipped.

        Yields:
            Match: Yield the Match object when a subsequence of text matches the pattern.
              An empty pattern yields nothing.
        """
        pattern_length = len(pattern)
        if pattern_length == 0:
            # Nothing to match; without this guard, picking the first (rarest)
            # pattern position below would raise IndexError.
            return

        # Compute pattern--vocabulary pairwise similarity.
        if threshold >= 1.0:
            scores = self.compute_exact_match(pattern, self.embedding)
        else:
            scores = self.compute_similarity(pattern, self.embedding)

        # Collect, for each pattern position, the vocabulary ids whose score
        # reaches the threshold.
        matched_vocabulary_indices = [[] for _ in range(pattern_length)]
        for pattern_idx, vocabulary_idx in zip(*(scores >= threshold).nonzero()):
            matched_vocabulary_indices[pattern_idx].append(vocabulary_idx)

        # Look up the inverted-list row of every matched vocabulary id
        # (presumably the text positions of that token -- confirm in
        # IndexInvertedFile).
        is_matched = [
            [self.index.inverted_lists.getrow(vocab_idx) for vocab_idx in vocab_idxs]
            for vocab_idxs in matched_vocabulary_indices
        ]

        # Visit the pattern positions in ascending order of their total row
        # length, so the running intersection starts from the smallest set.
        rare_ordered_pattern_indices = np.array(
            [sum(map(len, m)) for m in is_matched]
        ).argsort()

        # Shift and calculate the intersection to obtain the matched pattern.
        # After each `<< p`, the indices of `matches` are candidate start
        # positions of the whole pattern (they are used as `begin` below).
        p = rare_ordered_pattern_indices[0]
        with stopwatch.timers["union"]:
            matches = SparseVector.union_from_iterable(is_matched[p])
        with stopwatch.timers["shift+and"]:
            matches <<= p

        for p in rare_ordered_pattern_indices[1:]:
            with stopwatch.timers["union"]:
                union = SparseVector.union_from_iterable(is_matched[p])
            with stopwatch.timers["shift+and"]:
                # `>>` binds tighter than `&`: shift the candidates to pattern
                # position `p`, intersect, then shift back to the start position.
                matches = ((matches >> p) & union) << p

        with stopwatch.timers["sort"]:
            matches = matches.sort()

        pattern_arange = np.arange(pattern_length)
        for begin in matches.indices.tolist():
            if begin < start:
                continue
            end = begin + pattern_length
            matched_tokens = self.index.tokens[begin:end]
            # Score of each pattern position against the text token it matched.
            match_scores = scores[pattern_arange, matched_tokens]
            yield self.Match(begin, end, match_scores)
