import numpy as np
import pytest

from softmatcha.embeddings import EmbeddingGensim, EmbeddingTransformers
from softmatcha.struct import Pattern, TokenEmbeddings
from softmatcha.tokenizers import TokenizerTransformers
from softmatcha.tokenizers.moses import TokenizerMoses

from .naive import SearchNaive


class TestSearchNaive:
    """Exercise ``SearchNaive.search`` with the GloVe and BERT fixtures.

    Every test follows the same recipe: tokenize and embed a text (and
    usually a separate pattern), build a ``Pattern`` whose i-th element is
    the singleton token set ``{token_i}`` paired with the i-th chunk of the
    embeddings, then iterate the result of ``SearchNaive(pattern).search``.
    The shared steps live in the private helpers below; their names do not
    start with ``test``, so pytest does not collect them as tests.
    """

    @staticmethod
    def _tokenize_and_embed(tokenizer, embed, sentence):
        """Return ``(tokens, embeddings)`` for ``sentence``.

        Args:
            tokenizer: Callable mapping a string to its tokens.
            embed: Callable mapping the tokens to their embeddings.
            sentence: Raw input string.
        """
        tokens = tokenizer(sentence)
        return tokens, embed(tokens)

    @staticmethod
    def _search(
        pattern_tokens,
        pattern_embeddings,
        text_tokens,
        text_embeddings,
        **search_kwargs,
    ):
        """Build the pattern and the searcher, then start a search over the text.

        Args:
            pattern_tokens: Tokens of the query pattern.
            pattern_embeddings: Embeddings of ``pattern_tokens``; presumably
                one row per token (TODO confirm against the embedding classes).
            text_tokens: Tokens of the text that is searched.
            text_embeddings: Embeddings of ``text_tokens``.
            **search_kwargs: Forwarded unchanged to ``SearchNaive.search``
                (the tests pass ``threshold`` and, in one case, ``start``).

        Returns:
            Whatever ``SearchNaive.search`` returns; the tests consume it
            with ``next``.
        """
        pattern = Pattern.build(
            # Each pattern position accepts exactly one token.
            [{token} for token in pattern_tokens],
            # Split along the first axis into as many chunks as there are
            # entries, i.e. one chunk per pattern position.
            np.split(pattern_embeddings, len(pattern_embeddings)),
        )
        searcher = SearchNaive(pattern)
        text = TokenEmbeddings(text_tokens, text_embeddings)
        return searcher.search(text, **search_kwargs)

    @staticmethod
    def _assert_exhausted(res):
        """Assert that ``res`` has nothing left to yield."""
        with pytest.raises(StopIteration):
            next(res)

    @classmethod
    def _assert_single_match(cls, res):
        """Assert that ``res`` yields one truthy match and then stops."""
        assert next(res)
        cls._assert_exhausted(res)

    def test_search_fullmatch_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """A GloVe pattern built from the text itself yields exactly one match."""
        tokens, embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "Soft set operations."
        )
        res = self._search(tokens, embeddings, tokens, embeddings, threshold=1 - 1e-5)
        self._assert_single_match(res)

    def test_search_fullmatch_bert(
        self, embed_bert: EmbeddingTransformers, tokenizer_bert: TokenizerTransformers
    ):
        """A BERT pattern built from the text itself yields exactly one match."""
        tokens, embeddings = self._tokenize_and_embed(
            tokenizer_bert, embed_bert, "Soft set operations."
        )
        res = self._search(tokens, embeddings, tokens, embeddings, threshold=1 - 1e-5)
        self._assert_single_match(res)

    def test_search_subseq_match(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """A pattern occurring once inside a longer text yields exactly one match."""
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_glove,
            embed_glove,
            "Soft set operations based on linear subspace",
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "set operations"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=1 - 1e-5,
        )
        self._assert_single_match(res)

    def test_search_subseq_match_start_position(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """With ``start=2`` only the later of the two occurrences is reported.

        The text contains "set operations" twice; the single match yielded
        must have ``begin == 10`` and ``end == 12``.
        """
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_glove,
            embed_glove,
            "Soft set operations based on linear subspace and normal hard set operations",
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "set operations"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=1 - 1e-5,
            start=2,
        )
        matched = next(res)
        assert matched.begin == 10
        assert matched.end == 12
        self._assert_exhausted(res)

    def test_search_semantic_match_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """At threshold 0.5 a differently worded GloVe pattern matches once."""
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "He watched the shooting star."
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "saw a shooting star"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=0.5,
        )
        self._assert_single_match(res)

    def test_search_semantic_match_bert(
        self, embed_bert: EmbeddingTransformers, tokenizer_bert: TokenizerTransformers
    ):
        """At threshold 0.6 a differently worded BERT pattern matches once."""
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_bert, embed_bert, "He watched the shooting star."
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_bert, embed_bert, "watch a shooting star"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=0.6,
        )
        self._assert_single_match(res)

    def test_search_semantic_no_match_glove(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """At threshold 0.75 "television star" must not match "shooting star"."""
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "He saw a television star."
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "saw a shooting star"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=0.75,
        )
        self._assert_exhausted(res)

    def test_search_long(
        self, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """In a multi-sentence text the first match is the literal pattern span.

        Only the first result is inspected; the test does not check whether
        further matches follow.
        """
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_glove,
            embed_glove,
            "Did you see stars yesterday? "
            "Do you see stars everyday? "
            "I saw many stars and also "
            "I heard that he watched the television star and saw a shooting star yesterday.",
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_glove, embed_glove, "saw a shooting star"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=0.7,
        )
        matched = next(res)
        assert matched is not None
        assert text_tokens[matched.begin : matched.end] == pattern_tokens

    # NOTE(review): this name says "grep" while its siblings say "search";
    # it is kept unchanged so that existing test IDs remain valid.
    def test_grep_semantic_no_match_bert(
        self, embed_bert: EmbeddingTransformers, tokenizer_bert: TokenizerTransformers
    ):
        """At threshold 0.5 the BERT search yields nothing for "television star"."""
        text_tokens, text_embeddings = self._tokenize_and_embed(
            tokenizer_bert, embed_bert, "He saw a television star."
        )
        pattern_tokens, pattern_embeddings = self._tokenize_and_embed(
            tokenizer_bert, embed_bert, "saw a shooting star"
        )
        res = self._search(
            pattern_tokens,
            pattern_embeddings,
            text_tokens,
            text_embeddings,
            threshold=0.5,
        )
        self._assert_exhausted(res)
