"""
Thought Vector Class
---------------------
"""

import functools

import torch

from textattack.shared import AbstractWordEmbedding, WordEmbedding, utils

from .sentence_encoder import SentenceEncoder


class ThoughtVector(SentenceEncoder):
    """A constraint on the distance between two sentences' thought vectors.

    A sentence's thought vector is the element-wise mean of the embeddings of
    those of its words for which the word embedding has an entry.

    Args:
        embedding (textattack.shared.AbstractWordEmbedding): The word
            embedding to use. When ``None``,
            ``WordEmbedding.counterfitted_GLOVE_embedding()`` is used.
        **kwargs: Forwarded unchanged to the parent ``SentenceEncoder``
            constructor.

    Raises:
        ValueError: If ``embedding`` is not an ``AbstractWordEmbedding``.
    """

    def __init__(self, embedding=None, **kwargs):
        """Validates and stores the word embedding, then initializes the
        parent ``SentenceEncoder`` with ``kwargs``."""
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.word_embedding = embedding
        super().__init__(**kwargs)

    def clear_cache(self):
        """Empties the LRU cache of previously computed thought vectors.

        The cache is attached to the method itself, so this clears the
        entries of every ``ThoughtVector`` instance, not only ``self``.
        """
        self._get_thought_vector.cache_clear()

    # NOTE: ``lru_cache`` on a method creates a single cache on the function
    # object, keyed on ``(self, text)``. It is therefore shared by all
    # instances and holds a reference to each instance for as long as one of
    # its entries is cached (up to ``maxsize`` entries in total).
    @functools.lru_cache(maxsize=2**10)
    def _get_thought_vector(self, text):
        """Averages the embeddings of all the words in ``text`` into a
        "thought vector".

        Words for which the embedding lookup returns ``None`` (out-of-vocab
        words) are skipped.

        Args:
            text (str): The raw text to embed.

        Returns:
            torch.Tensor: The mean of the word embeddings, taken over the
            word dimension. If no word of ``text`` has an embedding, a
            0-dim NaN tensor is returned; it does not have the shape of a
            regular thought vector.
        """
        lookups = (self.word_embedding[word] for word in utils.words_from_text(text))
        # Convert each vector on its own and stack the results: handing
        # ``torch.tensor`` a Python list of arrays goes through a slow
        # element-by-element conversion path.
        word_vectors = [torch.tensor(vector) for vector in lookups if vector is not None]
        if not word_vectors:
            # Nothing to average. Keep the value the mean of an empty tensor
            # would produce (NaN, 0-dim), but make it explicit.
            return torch.tensor(float("nan"))
        return torch.mean(torch.stack(word_vectors), dim=0)

    def encode(self, raw_text_list):
        """Encodes each text of ``raw_text_list`` as its thought vector.

        Args:
            raw_text_list (list[str]): The texts to encode.

        Returns:
            torch.Tensor: The thought vectors stacked along a new first
            dimension, in the same order as ``raw_text_list``.
        """
        return torch.stack([self._get_thought_vector(text) for text in raw_text_list])

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys."""
        return ["word_embedding"] + super().extra_repr_keys()
