from pathlib import Path
from typing import Generator

import h5py
import numpy as np
import pytest

from softmatcha.embeddings import EmbeddingGensim
from softmatcha.struct import IndexInvertedFile, Pattern
from softmatcha.tokenizers.moses import TokenizerMoses

from .index_inverted import SearchIndexInvertedFile


@pytest.fixture
def index(
    tmp_path: Path, tokenizer_glove: TokenizerMoses
) -> Generator[IndexInvertedFile, None, None]:
    """Yield an inverted-file index built from two tokenized sentences.

    The token ids are written to an HDF5 file under ``tmp_path`` and the
    index is built from the resulting ``index`` group.  The file stays open
    while the requesting test runs and is closed during fixture teardown, so
    no HDF5 handle outlives the test.

    Args:
        tmp_path: Per-test temporary directory provided by pytest.
        tokenizer_glove: Tokenizer fixture; called on raw text to get the
            token ids and passed to ``len()`` for the size given to
            ``IndexInvertedFile.build``.

    Yields:
        The index returned by ``IndexInvertedFile.build``.
    """
    text = []
    for sentence in ("I like the jazz music.", "I have a pen."):
        text += tokenizer_glove(sentence)

    index_path = str(tmp_path / "index.bin")
    # The context manager closes the file once the test has finished, even
    # when the test fails.
    with h5py.File(index_path, mode="w") as index_root:
        index_group = index_root.create_group("index")
        index_group.create_dataset("tokens", data=np.array(text, dtype=np.int32))
        # Offsets are stubbed with a single zero entry.
        index_group.create_dataset("line_offsets", data=np.array([0]))
        index_group.create_dataset("byte_offsets", data=np.array([0]))

        yield IndexInvertedFile.build(index_group, len(tokenizer_glove))


class TestSearchIndexInvertedFile:
    """Tests for ``SearchIndexInvertedFile`` scoring and search.

    Every test queries with the same three-word phrase and relies on the
    ``embed_glove`` / ``tokenizer_glove`` fixtures, which are defined outside
    this module.
    """

    # Query phrase shared by all tests.
    _QUERY = "the blues music"
    # Similarity threshold passed to ``search``.
    _THRESHOLD = 0.55

    @classmethod
    def _build_pattern(
        cls, embed_glove: EmbeddingGensim, tokenizer_glove: TokenizerMoses
    ):
        """Tokenize the shared query and wrap it in a ``Pattern``.

        Each query token becomes a singleton set, paired with its own slice
        of the embedding output (split into one piece per token).

        Returns:
            A ``(pattern_tokens, pattern)`` pair: the tokenizer output for
            the query and the ``Pattern`` built from it.
        """
        pattern_tokens = tokenizer_glove(cls._QUERY)
        pattern = Pattern.build(
            [{token} for token in pattern_tokens],
            np.split(embed_glove(pattern_tokens), len(pattern_tokens)),
        )
        return pattern_tokens, pattern

    def test_compute_exact_match(
        self,
        embed_glove: EmbeddingGensim,
        tokenizer_glove: TokenizerMoses,
        index: IndexInvertedFile,
    ):
        """Exact-match scores are 1.0 at each query token and sum to its length."""
        searcher = SearchIndexInvertedFile(index, embed_glove)
        pattern_tokens, pattern = self._build_pattern(embed_glove, tokenizer_glove)
        num_tokens = len(pattern_tokens)

        scores = searcher.compute_exact_match(pattern, embed_glove)

        # One row per pattern position, one column per embedding entry.
        assert list(scores.shape) == [len(pattern), len(embed_glove)]
        # Row i must score exactly 1.0 in the column of the i-th query token.
        assert np.all(scores[np.arange(num_tokens), pattern_tokens] == 1.0)
        # The total equals the number of query tokens, so the remaining
        # entries sum to zero.
        assert scores.sum() == num_tokens

    def test_compute_similarity(
        self,
        embed_glove: EmbeddingGensim,
        tokenizer_glove: TokenizerMoses,
        index: IndexInvertedFile,
    ):
        """Similarity is non-negative inside the index vocabulary, zero outside."""
        searcher = SearchIndexInvertedFile(index, embed_glove)
        _, pattern = self._build_pattern(embed_glove, tokenizer_glove)

        scores = searcher.compute_similarity(pattern, embed_glove)

        assert list(scores.shape) == [len(pattern), len(embed_glove)]

        in_index = index.vocabulary.tolist()
        assert np.all(scores[:, in_index] >= 0.0)

        # Tokens known to the tokenizer but absent from the index must score
        # exactly zero.
        not_in_index = list(set(tokenizer_glove.tokens.keys()) - set(in_index))
        assert np.all(scores[:, not_in_index] == 0.0)

    def test_search(
        self,
        embed_glove: EmbeddingGensim,
        tokenizer_glove: TokenizerMoses,
        index: IndexInvertedFile,
    ):
        """Searching the two-sentence index yields exactly one match at 2..5."""
        searcher = SearchIndexInvertedFile(index, embed_glove)
        _, pattern = self._build_pattern(embed_glove, tokenizer_glove)

        res: Generator[SearchIndexInvertedFile.Match, None, None] = searcher.search(
            pattern, threshold=self._THRESHOLD
        )

        matched = next(res)
        assert matched.begin == 2
        assert matched.end == 5
        # No further match may be produced.
        with pytest.raises(StopIteration):
            next(res)

    def test_search_start_position(
        self,
        tmp_path: Path,
        embed_glove: EmbeddingGensim,
        tokenizer_glove: TokenizerMoses,
    ):
        """With ``start=6`` only the match at 13..16 is reported.

        The indexed text repeats the first sentence after the second one, so
        the test builds its own three-sentence index instead of using the
        ``index`` fixture.
        """
        text = []
        for sentence in (
            "I like the jazz music.",
            "I have a pen.",
            "I like the jazz music.",
        ):
            text += tokenizer_glove(sentence)

        index_path = str(tmp_path / "index.bin")
        # Keep the HDF5 file open for the whole test and close it afterwards,
        # even if an assertion fails.
        with h5py.File(index_path, mode="w") as index_root:
            index_group = index_root.create_group("index")
            index_group.create_dataset(
                "tokens", data=np.array(text, dtype=np.int32)
            )
            # Offsets are stubbed with a single zero entry.
            index_group.create_dataset("line_offsets", data=np.array([0]))
            index_group.create_dataset("byte_offsets", data=np.array([0]))

            index = IndexInvertedFile.build(index_group, len(tokenizer_glove))

            searcher = SearchIndexInvertedFile(index, embed_glove)
            _, pattern = self._build_pattern(embed_glove, tokenizer_glove)

            res: Generator[SearchIndexInvertedFile.Match, None, None] = (
                searcher.search(pattern, threshold=self._THRESHOLD, start=6)
            )

            matched = next(res)
            assert matched.begin == 13
            assert matched.end == 16
            # No further match may be produced.
            with pytest.raises(StopIteration):
                next(res)
