import os
import numpy
import torch
from nn.deep_sets import MLP
from nn.flow import GlowInvertibleNetwork
from utils import get_device


# Cache file used by BaseWord2VecWrapper.__init__ when no `wv` is supplied: it is loaded with
# torch.load if it exists, otherwise it is written with torch.save after the gensim download.
# The path is relative, so it is resolved against the current working directory.
GOOGLE_WORD2VEC_PATH = "word_analogy/data/google_word2vec_300.pt"


class BaseWord2VecWrapper(torch.nn.Module):
    """
    Base class of word2vec wrapper for training word analogy functions

    Answers "a is to b as c is to ?" by computing ``forward(a, b, c)`` and returning the
    vocabulary entry with the highest cosine similarity to it. The base ``forward`` is the
    plain vector offset ``b + c - a``; subclasses override ``forward``.

    Attributes:
        dim: embedding dimensionality
        normalized_vectors: [vocab_size, dim] tensor (rows L2-normalized by ``load_word2vec``),
            moved to ``get_device()`` in ``__init__``. It is a plain attribute, not a
            registered buffer or parameter.
        index_to_word: numpy array mapping row index -> word
        word_to_index: dict mapping word -> row index

    Example:
        >>> model = BaseWord2VecWrapper()
        >>> model.most_similar("he", "king", "she")
        'queen'
        >>> queen = model.to_vec("king") + model.to_vec("she") - model.to_vec("he")
        >>> cheaper = model.to_vec("faster") + model.to_vec("cheap") - model.to_vec("fast")
        >>> batch = torch.cat([queen.unsqueeze(0), cheaper.unsqueeze(0)], dim=0)
        >>> model.index_to_word[model._batch_most_similar(batch)]
        array(['queen', 'cheaper'], dtype='<U98')
    """

    def __init__(self, wv=None):
        """
        Load "word2vec-google-news-300" if wv is None

        With wv=None the vectors come from the GOOGLE_WORD2VEC_PATH cache file when it exists;
        otherwise they are downloaded through gensim and the cache file is written.
        With a wv object, it is passed to ``load_word2vec`` and no cache is touched.
        """
        super().__init__()
        if wv is not None:
            self.load_word2vec(wv)
        elif os.path.exists(GOOGLE_WORD2VEC_PATH):
            # NOTE(review): the cache contains a numpy array and a dict besides the tensor.
            # Newer PyTorch releases default torch.load to weights_only=True -- confirm the
            # cache still loads on the PyTorch version in use.
            dic = torch.load(GOOGLE_WORD2VEC_PATH)
            self.dim = dic["dim"]
            self.normalized_vectors = dic["normalized_vectors"]
            self.index_to_word = dic["index_to_word"]
            self.word_to_index = dic["word_to_index"]
        else:
            import gensim.downloader
            wv = gensim.downloader.load("word2vec-google-news-300")
            self.load_word2vec(wv)
            # Create the cache directory first so the save cannot fail on a missing folder
            # after the (large) download has already completed.
            cache_dir = os.path.dirname(GOOGLE_WORD2VEC_PATH)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            torch.save({
                "dim": self.dim,
                "normalized_vectors": self.normalized_vectors,
                "index_to_word": self.index_to_word,
                "word_to_index": self.word_to_index,
            }, GOOGLE_WORD2VEC_PATH)

        self.normalized_vectors = self.normalized_vectors.to(get_device())

    def load_word2vec(self, wv):
        """
        Populate dim / normalized_vectors / index_to_word / word_to_index from a gensim
        KeyedVectors-like object.

        Both attribute layouts are supported: ``index_to_key`` / ``key_to_index`` (gensim >= 4)
        and ``index2entity`` / ``vocab`` (older gensim). The vectors are copied, so ``wv`` is
        not modified.
        """
        self.dim = wv.vector_size
        self.normalized_vectors = torch.nn.functional.normalize(torch.Tensor(wv.vectors.copy()), dim=1)
        key_to_index = getattr(wv, "key_to_index", None)
        if key_to_index is not None:
            self.index_to_word = numpy.array(wv.index_to_key)
            self.word_to_index = dict(key_to_index)
        else:
            self.index_to_word = numpy.array(wv.index2entity)
            self.word_to_index = {word: vocab_obj.index for word, vocab_obj in wv.vocab.items()}

    def to_vec(self, word):
        """ Return the stored (normalized) vector of `word`; raises KeyError for unknown words """
        return self.normalized_vectors[self.word_to_index[word]]

    def most_similar(self, a_str, b_str, c_str, exclude_abc=False):
        """
        Return the word predicted for the analogy a_str : b_str = c_str : ?

        If exclude_abc is True, the three query words themselves cannot be returned.
        Raises KeyError when any of the words is not in the vocabulary.
        """
        with torch.no_grad():
            a_batch = numpy.array([self.word_to_index[a_str]])
            b_batch = numpy.array([self.word_to_index[b_str]])
            c_batch = numpy.array([self.word_to_index[c_str]])
            pred_id_batch = self._batch_most_similar_by_index(a_batch, b_batch, c_batch, exclude_abc=exclude_abc)
            return self.index_to_word[pred_id_batch[0]]

    def batch_similarity(self, x, normalized):
        """
        [batch, dim], [batch, dim] -> [batch]
        Note that the second argument should be normalized beforehand
        """
        x = torch.nn.functional.normalize(x, dim=1)
        return (x * normalized).sum(dim=1)

    def _batch_most_similar_by_index(self, a_batch, b_batch, c_batch, exclude_abc=False):
        """ [batch], [batch], [batch] -> [batch] """
        pred = self.forward_by_index(a_batch, b_batch, c_batch)
        exclude_batches = [a_batch, b_batch, c_batch] if exclude_abc else []
        return self._batch_most_similar(pred, exclude_batches=exclude_batches)

    def _batch_most_similar(self, batch, exclude_batches=None):
        """
        return the most similar indices for each vector
        [batch_size, dim], list([batch_size]) -> [batch_size]

        For every list in exclude_batches, entry i is a vocabulary index that must not be
        returned for query i.
        """
        assert batch.shape[1] == self.dim
        if exclude_batches is None:
            exclude_batches = []
        normalized_batch = torch.nn.functional.normalize(batch, dim=1)
        batch_similarities = torch.mm(self.normalized_vectors,
                                      normalized_batch.transpose(0, 1))  # [vocab_size, batch_size]
        columns = torch.arange(batch_similarities.shape[1], device=batch_similarities.device)
        for exclude_batch in exclude_batches:
            rows = torch.as_tensor(exclude_batch, dtype=torch.long, device=batch_similarities.device)
            # One indexed assignment per exclude list instead of a Python loop over the batch.
            # -inf (rather than -1, which is a reachable cosine value) guarantees an excluded
            # entry never wins the argmax while any non-excluded entry exists.
            batch_similarities[rows, columns] = float("-inf")
        return batch_similarities.argmax(dim=0)

    def forward(self, a, b, c):
        """ [batch, dim], [batch, dim], [batch, dim] -> [batch, dim] """
        return b + c - a

    def forward_by_index(self, a_id, b_id, c_id):
        """ Look up the stored vectors for the given indices and apply ``forward`` to them """
        return self.forward(self.normalized_vectors[a_id], self.normalized_vectors[b_id], self.normalized_vectors[c_id])

    def eval_data(self, data_ds, batch_size=64, exclude_abc=False):
        """
        Count the analogies answered correctly.

        data_ds is a tuple (a_batch, b_batch, c_batch, ds_batch). The first three hold
        vocabulary indices and ds_batch[i] holds the accepted answer indices for example i;
        all four are indexed with a numpy index array (presumably numpy arrays -- verify
        against the caller). A prediction is correct when it is contained in ds_batch[i].

        The data is split into ceil(len / batch_size) nearly equal minibatches.
        Switches the module to eval mode. Returns 0 for empty data.
        """
        a_batch, b_batch, c_batch, ds_batch = data_ds
        correct = 0
        self.eval()
        num_examples = len(a_batch)
        if num_examples == 0:
            # numpy.array_split rejects 0 sections, so handle the empty case explicitly
            return correct
        perm = numpy.arange(num_examples)
        for index in numpy.array_split(perm, (num_examples + batch_size - 1) // batch_size):
            with torch.no_grad():
                d_minibatch = self._batch_most_similar_by_index(a_batch[index], b_batch[index], c_batch[index],
                                                                exclude_abc=exclude_abc)
            ds_minibatch = ds_batch[index]  # hoisted: was re-evaluated for every element
            predictions = d_minibatch.cpu().detach().numpy()
            correct += sum(d in ds_minibatch[i] for i, d in enumerate(predictions))
        return correct


class IdentityWord2VecWrapper(BaseWord2VecWrapper):
    """
    Vector-offset analogy (b + c - a) that additionally owns one parameter

    ``dummy_variable`` is added to and subtracted from the result in ``forward`` --
    presumably so that the module exposes a parameter and its output is connected to it;
    confirm against the training code.
    """

    def __init__(self, wv=None):
        super().__init__(wv=wv)
        initial_value = torch.rand(1)
        self.dummy_variable = torch.nn.Parameter(initial_value)

    def forward(self, a, b, c):
        """ [batch, dim], [batch, dim], [batch, dim] -> [batch, dim] """
        offset = b + c - a
        return offset + self.dummy_variable - self.dummy_variable


class MLPWord2VecWrapper(BaseWord2VecWrapper):
    """
    Analogy function that feeds the vector offset b + c - a through an MLP

    The MLP is constructed with ``self.dim`` for both of its leading positional arguments.
    """

    def __init__(self, wv=None, hidden_dim=8, layer_num=3, activation=torch.nn.functional.relu):
        super().__init__(wv=wv)
        self.hidden_dim, self.layer_num = hidden_dim, layer_num
        mlp_options = {"hidden_dim": hidden_dim, "layer_num": layer_num, "activation": activation}
        self.mlp = MLP(self.dim, self.dim, **mlp_options)

    def forward(self, a, b, c):
        """ Apply the MLP to b + c - a """
        offset = b + c - a
        return self.mlp(offset)


class MLPConcatWord2VecWrapper(BaseWord2VecWrapper):
    """
    Analogy function that feeds the concatenation of a, b and c through an MLP

    The MLP is constructed with ``self.dim * 3`` and ``self.dim`` as its leading
    positional arguments.
    """

    def __init__(self, wv=None, hidden_dim=8, layer_num=3, activation=torch.nn.functional.relu):
        super().__init__(wv=wv)
        self.hidden_dim, self.layer_num = hidden_dim, layer_num
        mlp_options = {"hidden_dim": hidden_dim, "layer_num": layer_num, "activation": activation}
        self.mlp = MLP(self.dim * 3, self.dim, **mlp_options)

    def forward(self, a, b, c):
        """ Concatenate a, b, c along dim 1 (in that order) and apply the MLP """
        stacked = torch.cat((a, b, c), dim=1)
        return self.mlp(stacked)


class AGNWord2VecWrapper(BaseWord2VecWrapper):
    """
    Analogy function computed through a GlowInvertibleNetwork ``phi``

    forward evaluates phi(b) + phi(c) - phi(a) and passes the result back through
    ``phi`` with ``rev=True``.
    """

    def __init__(self, wv=None, hidden_dim=64, layer_num=4):
        super().__init__(wv=wv)
        self.hidden_dim, self.layer_num = hidden_dim, layer_num
        self.phi = GlowInvertibleNetwork(self.dim, layer_num=layer_num, hidden_dim=hidden_dim)

    def forward(self, a, b, c):
        """ Apply phi to b, c, a (in that order), combine, then call phi with rev=True """
        # phi is called on b, c, a in this order, matching the original expression
        mapped_b = self.phi(b)
        mapped_c = self.phi(c)
        mapped_a = self.phi(a)
        return self.phi(mapped_b + mapped_c - mapped_a, rev=True)


if __name__ == "__main__":
    # Run the doctests embedded in this module's docstrings when executed as a script
    from doctest import testmod
    testmod()
