#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob

import numpy as np


# Number of float32 values per embedding row; load_embeddings reshapes each
# flat array read from disk to (num_rows, DIM).
DIM = 1024


def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
    """Rank target sentences by cosine similarity to each source sentence.

    Args:
        source_embs: Mapping from sentence id to a 1-D embedding vector.
        target_embs: Mapping from sentence id to a 1-D embedding vector with
            the same length as the source vectors.
        k: Number of nearest neighbors to keep per source sentence.
        return_sim_mat: When true, return the similarity matrix instead of
            the neighbor map.

    Returns:
        If ``return_sim_mat`` is true, an array of shape
        ``(len(source_embs), len(target_embs))`` of cosine similarities whose
        rows and columns follow the iteration order of the two mappings.
        Otherwise a dict mapping each source sentence id to a list of at most
        ``k`` target sentence ids, most similar first.
    """
    target_ids = list(target_embs)

    # np.stack requires a real sequence; a dict view (no __getitem__) is
    # rejected with TypeError by current NumPy, so materialize the values.
    source_mat = np.stack(list(source_embs.values()), axis=0)
    normalized_source_mat = source_mat / np.linalg.norm(
        source_mat, axis=1, keepdims=True
    )
    target_mat = np.stack(list(target_embs.values()), axis=0)
    normalized_target_mat = target_mat / np.linalg.norm(
        target_mat, axis=1, keepdims=True
    )

    # Rows are unit-length, so the dot product is the cosine similarity.
    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
    if return_sim_mat:
        return sim_mat

    neighbors_map = {}
    for i, sentence_id in enumerate(source_embs):
        # argsort is ascending; reverse it to get the most similar first.
        idx = np.argsort(sim_mat[i, :])[::-1][:k]
        neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
    return neighbors_map


def load_embeddings(directory, LANGS):
    """Load sharded sentence embeddings and sentence texts for each language.

    For every ``lang`` in ``LANGS``, each file matching
    ``{directory}/{lang}/all_avg_pool.{lang}.*`` is read as a flat float32
    array and reshaped to rows of ``DIM`` values. The text after the last
    ``.`` in the file name is the shard id; row ``i`` of a shard is paired
    with line ``i`` of ``{directory}/{lang}/sentences.{lang}.<shard id>``,
    whose lines hold a sentence id and the sentence separated by a tab.

    Args:
        directory: Root directory containing one sub-directory per language.
        LANGS: Iterable of language codes to load.

    Returns:
        A tuple ``(sentence_embeddings, sentence_texts)``; both are dicts
        keyed by language, each mapping sentence id to its embedding row
        (respectively its text).

    Raises:
        ValueError: If a sentence line contains no tab separator.
    """
    sentence_embeddings = {}
    sentence_texts = {}
    for lang in LANGS:
        lang_embeddings = sentence_embeddings[lang] = {}
        lang_texts = sentence_texts[lang] = {}
        lang_dir = f"{directory}/{lang}"
        for embed_file in glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*"):
            shard_id = embed_file.rsplit(".", 1)[-1]
            embeddings = np.fromfile(embed_file, dtype=np.float32)
            num_rows = embeddings.shape[0] // DIM
            embeddings = embeddings.reshape((num_rows, DIM))

            sentences_path = f"{lang_dir}/sentences.{lang}.{shard_id}"
            # The corpus is multilingual: read as UTF-8 explicitly rather
            # than relying on the platform's locale encoding.
            with open(sentences_path, encoding="utf-8") as sentence_file:
                for idx, line in enumerate(sentence_file):
                    # Split on the first tab only so that sentences that
                    # contain tabs, or are empty, do not break unpacking.
                    sentence_id, sep, sentence = line.rstrip("\r\n").partition(
                        "\t"
                    )
                    if not sep:
                        raise ValueError(
                            f"{sentences_path}:{idx + 1}: expected "
                            "'<id>\\t<sentence>'"
                        )
                    # Same outer-whitespace trimming as stripping the line.
                    sentence_id = sentence_id.lstrip()
                    lang_texts[sentence_id] = sentence.rstrip()
                    lang_embeddings[sentence_id] = embeddings[idx, :]

    return sentence_embeddings, sentence_texts


def compute_accuracy(directory, LANGS):
    """Compute and report top-1 retrieval accuracy for every language pair.

    Embeddings are loaded with ``load_embeddings``. For each (source, target)
    pair, a source sentence counts as a hit when the nearest target sentence
    returned by ``compute_dist`` has the same sentence id. The hit count is
    divided by the number of target sentences.

    The result is a whitespace-separated table (header row of languages, then
    one row per source language) that is printed to stdout and written to
    ``{directory}/accuracy``.

    Args:
        directory: Root directory holding the per-language embedding shards;
            the ``accuracy`` file is written here as well.
        LANGS: Sequence of language codes to evaluate.
    """
    sentence_embeddings, _ = load_embeddings(directory, LANGS)

    # Collect the pieces and join once instead of repeated string +=.
    parts = [" ".join(LANGS) + "\n"]
    for source_lang in LANGS:
        parts.append(f"{source_lang} ")
        for target_lang in LANGS:
            neighbors_map = compute_dist(
                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
            )
            top1 = sum(
                1
                for sentence_id, neighbors in neighbors_map.items()
                if sentence_id == neighbors[0]
            )
            # NOTE(review): the denominator is the target-side sentence count;
            # this equals the source count only for parallel corpora.
            n = len(sentence_embeddings[target_lang])
            parts.append(f"{top1/n} ")
        parts.append("\n")
    top1_str = "".join(parts)

    print(top1_str)
    # Use a context manager so the report file is flushed and closed.
    with open(f"{directory}/accuracy", "w") as accuracy_file:
        print(top1_str, file=accuracy_file)


if __name__ == "__main__":
    # Command-line entry point: parse the data directory and the
    # comma-separated language list, then report retrieval accuracy.
    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
    parser.add_argument("directory", help="Source language corpus")
    # --langs is split unconditionally below, so make argparse reject a
    # missing value with a usage error instead of failing on None.split().
    parser.add_argument("--langs", required=True, help="List of langs")
    args = parser.parse_args()
    langs = args.langs.split(",")
    compute_accuracy(args.directory, langs)
