from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import Generator

from softmatcha.embeddings.base import Embedding
from softmatcha.struct import Pattern, TokenEmbeddings
from softmatcha.struct.index import Index
from softmatcha.typing import Vector


class Search(abc.ABC):
    """Common base for all pattern searchers; also hosts the `Match` result type."""

    @dataclass
    class Match:
        """One occurrence of the pattern found in a text.

        Attributes:
            begin (int): Begin position of the matched span.
            end (int): End position of the matched span.
            scores (Vector): Match scores, one per pattern token,
              i.e. of shape `(pattern_len,)`.
        """

        # Boundaries of the matched span in the searched text.
        begin: int
        end: int
        # Per-pattern-token match scores.
        scores: Vector


class SearchScan(Search, metaclass=abc.ABCMeta):
    """Abstract searcher that holds a fixed pattern and is applied to texts.

    The pattern is supplied once at construction and stored on
    `self.pattern`; each call to `search` receives the text to search.

    Args:
        pattern (Pattern): Pattern token set embeddings.
    """

    def __init__(self, pattern: Pattern) -> None:
        self.pattern = pattern

    # NOTE: `Generator` is given all three type arguments (yield, send, return).
    # Before Python 3.13 the send/return parameters have no defaults, so the
    # one-argument form fails when the lazy annotation is resolved
    # (e.g. by `typing.get_type_hints`).
    @abc.abstractmethod
    def search(
        self, text: TokenEmbeddings, threshold: float = 0.5, start: int = 0
    ) -> Generator[Search.Match, None, None]:
        """Search the given text for the pattern.

        Args:
            text (TokenEmbeddings): Text token embeddings to be searched.
            threshold (float): Threshold for match scores.
              Match scores lie in `[0, 1]`; a text token is regarded as
              matching a pattern token if `score > threshold`.
            start (int): Start position to be searched.

        Yields:
            Match: A Match object for each subsequence of the text that
              matches the pattern.
        """


class SearchIndex(Search, metaclass=abc.ABCMeta):
    """Abstract searcher that finds patterns using a prebuilt index.

    Unlike a scan-based searcher, the text is not passed to `search`; the
    index (stored on `self.index`) stands in for it, and the pattern is
    supplied per call.

    Args:
        index (Index): An index to be used for searching the text quickly.
        embedding (Embedding): Embedding instance, stored on
          `self.embedding` for use by subclasses.
    """

    def __init__(self, index: Index, embedding: Embedding) -> None:
        self.index = index
        self.embedding = embedding

    # NOTE: `Generator` is given all three type arguments (yield, send, return).
    # Before Python 3.13 the send/return parameters have no defaults, so the
    # one-argument form fails when the lazy annotation is resolved
    # (e.g. by `typing.get_type_hints`).
    @abc.abstractmethod
    def search(
        self, pattern: Pattern, threshold: float = 0.5, start: int = 0
    ) -> Generator[Search.Match, None, None]:
        """Search the indexed text for the given pattern.

        Args:
            pattern (Pattern): Pattern token set embeddings.
            threshold (float): Threshold for match scores.
              Match scores lie in `[0, 1]`; a text token is regarded as
              matching a pattern token if `score > threshold`.
            start (int): Start position to be searched.

        Yields:
            Match: A Match object for each subsequence of the indexed text
              that matches the pattern.
        """
