from __future__ import annotations

from dataclasses import dataclass
from typing import Generator

import numpy as np

from softmatcha import stopwatch
from softmatcha.struct import Pattern, TokenEmbeddings

from .base import Search, SearchScan


class SearchQuick(SearchScan):
    """SearchQuick quickly finds the given pattern from a text.

    The scan aligns the pattern with a window of the text, compares the window
    from right to left, and then uses the text token located just after the
    window to decide how far the window can be shifted (in the manner of
    Sunday's Quick Search).  Shifts are memoized per text token in
    `shift_table`.

    Args:
        pattern (Pattern): Pattern token set embeddings.

    Attributes:
        shift_table (dict[int, SearchQuick.Shift]): Memoized shift entries keyed
          by text token.  The entries depend on the threshold, so the table is
          replaced whenever `search()` is called with a different threshold.
    """

    def __init__(self, pattern: Pattern) -> None:
        super().__init__(pattern)
        self.shift_table: dict[int, SearchQuick.Shift] = {}
        # Threshold used to build `shift_table`; `None` until the first search.
        self._shift_table_threshold: float | None = None

    @dataclass
    class Shift:
        """Memoized shift entry for a single text token.

        shift (int): The number of shift.  `pattern_length - shift` is the
          rightmost pattern position matched by the token, and a shift of
          `pattern_length + 1` means the token matches no pattern position.
        score (float): Match score at that rightmost position (`0.0` when the
          token matches no pattern position).
        """

        shift: int
        score: float

    @stopwatch.timers("search", generator=True)
    def search(
        self, text: TokenEmbeddings, threshold: float = 0.5, start: int = 0
    ) -> Generator[Search.Match]:
        """Search for the pattern from the given text.

        Args:
            text (TokenEmbeddings): Text token embeddings to be searched.
            threshold (float): Threshold for matched scores.
              The range of matched scores are `[0,1]`, and if `score > threshold`,
              the text token is regarded as matched the pattern token.
            start (int): Start position to be searched.

        Yields:
            Match: Yield the `Match` object when a subsequence of text matches the pattern.
        """
        pattern_length = len(self.pattern)
        text_length = len(text)

        # Both the shift and the score stored in the table are derived from
        # `score > threshold`, so entries built for another threshold would give
        # wrong shifts (and skip valid matches).  Start from a fresh table when
        # the threshold changes.  A new dict is bound instead of clearing the old
        # one so that a still-running generator keeps its own consistent table.
        cached_threshold = self._shift_table_threshold
        if cached_threshold is not None and cached_threshold != threshold:
            self.shift_table = {}
        self._shift_table_threshold = threshold
        shift_table = self.shift_table

        i = start
        while i <= text_length - pattern_length:
            match_scores = np.zeros(pattern_length, dtype=np.float32)

            # Match the pattern from right to left.
            for j in reversed(range(pattern_length)):
                cached = shift_table.get(text.tokens[i + j])
                if cached is not None:
                    # Rightmost pattern position this token is known to match
                    # (-1 when it matches no position).
                    rightmost = pattern_length - cached.shift
                    if rightmost == j:
                        # Look-up the score if it has been calculated.
                        match_scores[j] = cached.score
                        continue
                    if rightmost < j:
                        # Every position right of `rightmost` was already tested
                        # against this token and failed: known mismatch.
                        break

                # Otherwise, calculate the score.
                with stopwatch.timers["membership"]:
                    score = self.pattern.embeddings[j].membership(
                        text.embeddings[i + j]
                    )

                if score > threshold:
                    match_scores[j] = score
                else:
                    break
            else:
                # Reached only when no position broke out of the loop.
                if (match_scores > threshold).all():
                    yield self.Match(i, i + pattern_length, match_scores)

            # The token right after the current window decides the shift.
            next_t = i + pattern_length
            if next_t >= text_length:
                break

            # Memoization
            next_token = text.tokens[next_t]
            if next_token not in shift_table:
                for j in reversed(range(pattern_length)):
                    with stopwatch.timers["membership"]:
                        score = self.pattern.embeddings[j].membership(
                            text.embeddings[next_t]
                        )
                    if score > threshold:
                        # Align the rightmost matching pattern position with
                        # the token at `next_t`.
                        shift_table[next_token] = self.Shift(pattern_length - j, score)
                        break
                else:
                    # When no pattern token matches, skip past `next_t` entirely.
                    shift_table[next_token] = self.Shift(pattern_length + 1, 0.0)

            # Shift pattern tokens.
            i += shift_table[next_token].shift
