import numpy as np
import pytest

from softmatcha.embeddings import EmbeddingGensim
from softmatcha.struct import Pattern
from softmatcha.struct.token_embeddings import TokenEmbeddings
from softmatcha.tokenizers.moses import TokenizerMoses

from .quick import SearchQuick


class TestSearchQuick:
    """Tests for ``SearchQuick`` over GloVe token embeddings.

    ``SearchQuick.search`` returns an iterator of matches. Each match exposes
    ``begin``/``end`` token indices forming a half-open span into the text
    tokens, so ``text_tokens[m.begin : m.end]`` is the matched sub-sequence.
    Running out of matches is signalled by ``StopIteration`` from ``next()``.
    """

    @staticmethod
    def _build_pattern(tokens, embeddings) -> Pattern:
        """Build a pattern with one position per token.

        Position ``i`` accepts the singleton token set ``{tokens[i]}`` and
        carries row ``i`` of ``embeddings``, as produced by
        ``np.split(embeddings, len(embeddings))``. This split yields one
        sub-array per row.
        """
        return Pattern.build(
            [{token} for token in tokens],
            np.split(embeddings, len(embeddings)),
        )

    def test_search_fullmatch_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """A pattern built from the whole text matches that text exactly once."""
        text_tokens = tokenizer_glove("Soft set operations.")
        text_embeddings = embed_glove(text_tokens)
        pattern = self._build_pattern(text_tokens, text_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        # A near-1 threshold: only (near-)identical embeddings are expected
        # to pass.
        res = searcher.search(text, threshold=1 - 1e-5)
        matched = next(res)
        # The pattern is the entire text, so the match spans every token.
        assert matched.begin == 0
        assert matched.end == len(text_tokens)
        with pytest.raises(StopIteration):
            next(res)

    def test_search_subseq_match(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """A pattern occurring once inside a longer text is found exactly once."""
        text_tokens = tokenizer_glove("Soft set operations based on linear subspace")
        text_embeddings = embed_glove(text_tokens)
        pattern_tokens = tokenizer_glove("set operations")
        pattern_embeddings = embed_glove(pattern_tokens)
        pattern = self._build_pattern(pattern_tokens, pattern_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        res = searcher.search(text, threshold=1 - 1e-5)
        matched = next(res)
        # "set operations" occupies token positions 1 and 2 (half-open 1..3).
        assert matched.begin == 1
        assert matched.end == 3
        with pytest.raises(StopIteration):
            next(res)

    def test_search_subseq_match_start_position(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """``start`` skips occurrences that begin before the given token index."""
        text_tokens = tokenizer_glove(
            "Soft set operations based on linear subspace and normal hard set operations"
        )
        text_embeddings = embed_glove(text_tokens)
        pattern_tokens = tokenizer_glove("set operations")
        pattern_embeddings = embed_glove(pattern_tokens)
        pattern = self._build_pattern(pattern_tokens, pattern_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        # The first "set operations" starts at token 1, before start=2, so only
        # the second occurrence (tokens 10..12) should be reported.
        res = searcher.search(text, threshold=1 - 1e-5, start=2)
        matched = next(res)
        assert matched.begin == 10
        assert matched.end == 12
        with pytest.raises(StopIteration):
            next(res)

    def test_search_semantic_match_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """A pattern matches text that uses different but related words.

        The text says "watched the" where the pattern says "saw a". A single
        match is expected at threshold 0.5.
        """
        text_tokens = tokenizer_glove("He watched the shooting star.")
        text_embeddings = embed_glove(text_tokens)
        pattern_tokens = tokenizer_glove("saw a shooting star")
        pattern_embeddings = embed_glove(pattern_tokens)
        pattern = self._build_pattern(pattern_tokens, pattern_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        res = searcher.search(text, threshold=0.5)
        assert next(res)
        with pytest.raises(StopIteration):
            next(res)

    def test_search_semantic_no_match_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """Replacing "shooting" with "television" yields no match at threshold 0.5."""
        text_tokens = tokenizer_glove("He saw a television star.")
        text_embeddings = embed_glove(text_tokens)
        pattern_tokens = tokenizer_glove("saw a shooting star")
        pattern_embeddings = embed_glove(pattern_tokens)
        pattern = self._build_pattern(pattern_tokens, pattern_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        res = searcher.search(text, threshold=0.5)
        with pytest.raises(StopIteration):
            next(res)

    def test_search_memorization(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """In a longer text with many near-miss phrases, the first match is exact.

        The text repeats pattern tokens ("saw", "star") in several contexts.
        At threshold 0.7 the first reported span must equal the pattern tokens.
        """
        text_tokens = tokenizer_glove(
            "Did you see stars yesterday? "
            "Do you see stars everyday? "
            "I saw many stars and also "
            "I heard that he watched the television star and saw a shooting star yesterday."
        )
        text_embeddings = embed_glove(text_tokens)
        pattern_tokens = tokenizer_glove("saw a shooting star")
        pattern_embeddings = embed_glove(pattern_tokens)
        pattern = self._build_pattern(pattern_tokens, pattern_embeddings)
        searcher = SearchQuick(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        res = searcher.search(text, threshold=0.7)
        # next() raises StopIteration (failing the test) if nothing matched,
        # so no separate None check is needed.
        matched = next(res)
        assert text_tokens[matched.begin : matched.end] == pattern_tokens
