from __future__ import annotations

from typing import Generator

import numpy as np

from softmatcha import stopwatch
from softmatcha.struct import TokenEmbeddings

from .base import Search, SearchScan


class SearchNaive(SearchScan):
    """SearchNaive naively finds the given pattern from a text.

    Every candidate start position is checked token by token against the
    pattern, so the worst case performs ``len(text) * len(pattern)``
    membership evaluations.

    Args:
        pattern (Pattern): Pattern token set embeddings.
    """

    @stopwatch.timers("search", generator=True)
    def search(
        self, text: TokenEmbeddings, threshold: float = 0.5, start: int = 0
    ) -> Generator[Search.Match]:
        """Search for the pattern from the given text.

        Args:
            text (TokenEmbeddings): Text token embeddings to be searched.
            threshold (float): Threshold for matched scores.
              The range of matched scores is `[0,1]`, and if `score > threshold`,
              the text token is regarded as matching the pattern token.
            start (int): Start position to be searched.

        Yields:
            Match: Yield the `Match` object when a subsequence of text matches the pattern.
        """
        pattern_length = len(self.pattern)
        # Last start position at which the whole pattern still fits in the text.
        last_start = len(text) - pattern_length

        for i in range(start, last_start + 1):
            match_scores = np.zeros(pattern_length, dtype=np.float32)
            for j in range(pattern_length):
                with stopwatch.timers["membership"]:
                    score = self.pattern.embeddings[j].membership(
                        text.embeddings[i + j]
                    )
                if score > threshold:
                    match_scores[j] = score
                else:
                    # Mismatch: abandon this window without scoring the rest.
                    break
            else:
                # Reached only when no `break` happened, i.e. every pattern token
                # satisfied `score > threshold`. Deciding here (instead of
                # re-testing the float32 copies in `match_scores`) keeps the
                # verdict identical to the per-token comparison: float32 rounding
                # could otherwise land a passing score exactly on the threshold,
                # and unfilled zero slots could pass a negative threshold.
                yield self.Match(i, i + pattern_length, match_scores)
