import json
import sys
from collections import namedtuple, defaultdict
from pathlib import Path

import jsonlines

# Command-line interface: the only argument is the target language code,
# which selects the "en-<tgt_language>.*" files next to this script.
if len(sys.argv) < 2:
    sys.exit(f"usage: {Path(__file__).name} TGT_LANGUAGE")
tgt_language = sys.argv[1]

# All input and output files live in the same directory as this script.
data_dir = Path(__file__).parent
scores_path = data_dir / f"en-{tgt_language}.insertions.roberta-large.scores.top10.json"
key_path = data_dir / f"en-{tgt_language}.key.tsv"
output_path = data_dir / f"en-{tgt_language}.insertions.roberta-large.source_data.top10.jsonl"

# Read the scores explicitly as UTF-8: the platform default encoding is
# locale-dependent and may fail on (or mis-decode) non-ASCII translations.
with open(scores_path, encoding="utf-8") as f:
    scores = json.load(f)

# Maximum number of correct / wrong insertions kept per sense.
max_insertions = 3

# One scored sense of a source word together with its best-scoring insertions.
Sense = namedtuple("Sense", ["src_word", "sense", "top_correct_insertions", "top_correct_scores", "top_wrong_insertions", "top_wrong_scores"])


def _top_insertions(scores_dict, suffix, stop_words, limit):
    """Return up to *limit* insertions labelled with *suffix*, best score first.

    Keys of *scores_dict* have the form "<insertion><suffix>". The suffix is
    stripped (rather than keeping only the first whitespace-separated token)
    so that multi-word insertions map back to an existing key instead of
    raising KeyError. Insertions listed in *stop_words* are dropped.
    """
    candidates = {key[:-len(suffix)] for key in scores_dict if key.endswith(suffix)} - stop_words
    # Pre-sort alphabetically: the score sort is stable, so ties are broken
    # deterministically instead of by (randomised) set iteration order.
    ranked = sorted(sorted(candidates), key=lambda insertion: scores_dict[insertion + suffix], reverse=True)
    return ranked[:limit]


senses = []
for sense, scores_dict in scores.items():
    # The first token of the key is the source word; the remainder has its
    # first and last characters removed (presumably enclosing brackets -- TODO confirm).
    tokens = sense.split()
    src_word = tokens[0]
    translations = " ".join(tokens[1:])[1:-1]
    # The source word itself and a few function words / punctuation marks
    # are never reported as insertions.
    stop_words = {src_word} | {"the", "a", "an", "and", ",", "-"}
    top_correct_insertions = _top_insertions(scores_dict, " (correct)", stop_words, max_insertions)
    top_wrong_insertions = _top_insertions(scores_dict, " (wrong)", stop_words, max_insertions)
    senses.append(Sense(
        src_word=src_word,
        sense=translations,
        top_correct_insertions=top_correct_insertions,
        top_wrong_insertions=top_wrong_insertions,
        top_correct_scores=[scores_dict[insertion + " (correct)"] for insertion in top_correct_insertions],
        top_wrong_scores=[scores_dict[insertion + " (wrong)"] for insertion in top_wrong_insertions],
    ))
# Rank senses by the product of their summed correct and wrong scores, highest first.
senses.sort(key=lambda sense: sum(sense.top_correct_scores) * sum(sense.top_wrong_scores) / max_insertions, reverse=True)
# senses = senses[:int(len(senses) / 3)]
for sense in senses:
    print(sense)
# Lookup table keyed by (source word, translations string).
senses_dict = {(sense.src_word, sense.sense): sense for sense in senses}

# Write one JSON object per key-file line, in the same order as the key file.
with open(key_path, encoding="utf-8") as f_key, jsonlines.open(output_path, "w") as f_out:
    # src_word -> {correct_translations: cluster id}. Ids are 1-based and
    # assigned in order of first appearance, so a sense keeps the same id
    # even when its rows are interleaved with other senses of the word.
    cluster_ids = defaultdict(dict)
    for key_line in f_key:
        # Only the third and fourth tab-separated columns are used.
        _, _, src_word, correct_translations, *_ = key_line.strip().split("\t")
        word_clusters = cluster_ids[src_word]
        cluster_id = word_clusters.setdefault(correct_translations, len(word_clusters) + 1)
        # Senses without an entry in the scores file get empty insertion lists.
        sense = senses_dict.get((src_word, correct_translations))
        f_out.write({
            "src_word": src_word,
            "cluster_id": cluster_id,
            "correct_insertions": sense.top_correct_insertions if sense is not None else [],
            "wrong_insertions": sense.top_wrong_insertions if sense is not None else [],
        })
