# Copyright (c) 2019-present, Anon.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import io
from logging import getLogger

import numpy as np
import torch

logger = getLogger()


def load_fasttext_model(path):
    """
    Load a binarized fastText model.

    The Python bindings have been published under two module names: the
    legacy ``fastText`` and the current ``fasttext``. Both are tried, in that
    order, so the loader works with either installation.

    Args:
        path: location of the binary model file, passed to ``load_model``.

    Returns:
        The model object returned by the bindings' ``load_model`` function.

    Raises:
        ImportError: if neither module name can be imported.
    """
    try:
        import fastText as fasttext_module
    except ImportError:
        try:
            # Newer releases of the bindings use the lowercase module name.
            import fasttext as fasttext_module
        except ImportError as err:
            # ImportError is still an Exception, so callers that catch the
            # previously raised generic Exception keep working.
            raise ImportError(
                "Unable to import fastText. Please install fastText for Python: "
                "https://github.com/facebookresearch/fastText"
            ) from err
    return fasttext_module.load_model(path)


def read_txt_embeddings(path, params):
    """
    Reload pretrained embeddings from a text file.

    The first line is a header made of two whitespace-separated fields, the
    second of which must equal ``params.emb_dim``. Every following line holds
    a word, a single space, and the space-separated vector components.

    Lines that cannot be used are skipped instead of aborting the load:
    blank lines (silently), lines with a word but no vector, repeated words
    (the first occurrence wins) and vectors of the wrong dimension (each with
    a warning).

    Args:
        path: path of the text embeddings file (decoded as UTF-8, undecodable
            bytes are ignored).
        params: object exposing the expected dimension as ``params.emb_dim``.

    Returns:
        A ``(word2id, embeddings)`` tuple where ``word2id`` maps each kept
        word to its row index and ``embeddings`` is a float tensor of size
        ``(len(word2id), params.emb_dim)``.

    Raises:
        AssertionError: if the header does not have two fields or its
            dimension differs from ``params.emb_dim``.
        ValueError: if the file contains no usable embedding line.
    """
    emb_dim = params.emb_dim
    word2id = {}
    vectors = []

    # load pretrained embeddings
    with io.open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as f:
        for i, line in enumerate(f):
            if i == 0:
                header = line.split()
                assert len(header) == 2, "Invalid header line: %r" % line
                assert emb_dim == int(header[1]), (
                    "Embedding dimension mismatch: expected %i, file header says %s"
                    % (emb_dim, header[1])
                )
                continue

            stripped = line.rstrip()
            if not stripped:
                # Blank lines (typically at the end of the file) carry no data.
                continue

            fields = stripped.split(" ", 1)
            if len(fields) != 2:
                # A word without any vector component cannot be unpacked.
                logger.warning('No vector found for word "%s" in line %i.', stripped, i)
                continue

            word, vect = fields
            if word in word2id:
                logger.warning('Word "%s" found twice!', word)
                continue

            vect = np.fromstring(vect, sep=" ")
            if vect.shape != (emb_dim,):
                logger.warning(
                    'Invalid dimension (%i) for word "%s" in line %i.',
                    vect.shape[0],
                    word,
                    i,
                )
                continue

            word2id[word] = len(word2id)
            vectors.append(vect)

    assert len(word2id) == len(vectors)
    logger.info("Loaded %i pretrained word embeddings from %s", len(vectors), path)

    if not vectors:
        # Fail with an explicit message rather than numpy's generic
        # "need at least one array" error.
        raise ValueError("No valid word embeddings found in %s" % path)

    # compute new vocabulary / embeddings
    embeddings = torch.from_numpy(np.stack(vectors, 0)).float()

    assert embeddings.size() == (len(word2id), params.emb_dim)
    return word2id, embeddings


def load_bin_embeddings(path, params):
    """
    Reload pretrained embeddings from a fastText binary file.

    The model's dimension must match ``params.emb_dim``. One vector is
    fetched from the model for every entry returned by ``get_labels()``.

    Returns a ``(word2id, embeddings)`` tuple: ``word2id`` maps each label to
    its position and ``embeddings`` is a float tensor of size
    ``(len(word2id), params.emb_dim)``.
    """
    ft_model = load_fasttext_model(path)
    assert ft_model.get_dimension() == params.emb_dim
    vocab = ft_model.get_labels()
    logger.info("Loaded binary model from %s" % path)

    # compute new vocabulary / embeddings
    rows = []
    for token in vocab:
        # Add a leading axis so that the rows can be concatenated.
        rows.append(ft_model.get_word_vector(token)[None])
    matrix = np.concatenate(rows, 0)
    embeddings = torch.from_numpy(matrix).float()

    word2id = {}
    for index, token in enumerate(vocab):
        word2id[token] = index
    logger.info("Generated embeddings for %i words." % len(vocab))

    expected_size = (len(word2id), params.emb_dim)
    assert embeddings.size() == expected_size
    return word2id, embeddings


def load_embeddings(path, params):
    """
    Reload pretrained embeddings.

    Paths ending in ".bin" are handed to the binary loader; every other path
    is read as a text embeddings file.
    """
    is_binary = path.endswith(".bin")
    loader = load_bin_embeddings if is_binary else read_txt_embeddings
    return loader(path, params)
