import itertools
import json
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import torch
from sacremoses import MosesTokenizer
from tqdm import tqdm

from tasks.mucow_lrec20_source import MucowSourceSample
from tasks.mucow_wmt19_source import MucowWMT19ContrastiveConditioningTask
from translation_models.fairseq_models import FairseqScoringModel


# Target language of the en-X pair, passed as the only command-line argument.
if len(sys.argv) < 2:
    sys.exit(f"Usage: {Path(__file__).name} TGT_LANGUAGE  (supported: de, ru)")
tgt_language = sys.argv[1]

# Input and output files live next to this script and are named after the language pair.
data_dir = Path(__file__).parent
insertions_path = data_dir / f"en-{tgt_language}.insertions.roberta-large.json"
sources_path = data_dir / f"en-{tgt_language}.text.txt"
references_path = data_dir / f"en-{tgt_language}.ref.txt"
# Upper bound on how many correct and how many wrong insertions are scored per sample.
max_insertions = 10
confidence_output_path = data_dir / f"en-{tgt_language}.insertions.roberta-large.scores.top{max_insertions}.json"

# Created without an evaluator; the scoring model is loaded and attached further below.
mucow = MucowWMT19ContrastiveConditioningTask(
    tgt_language=tgt_language,
    evaluator_model=None,
)

# Maps src_word -> {sense string -> list of candidate insertion tokens}; the schema is
# inferred from how it is indexed below. JSON is UTF-8 by spec, and the sense keys
# contain target-language words, so do not rely on the platform's default encoding.
with open(insertions_path, encoding="utf-8") as f:
    insertions = json.load(f)


# Padding value for zip_longest; it can never equal a real insertion token.
_FILL = object()


def _interleave(lists):
    """Yield items round-robin across ``lists``, continuing after shorter lists run out."""
    for group in itertools.zip_longest(*lists, fillvalue=_FILL):
        for item in group:
            if item is not _FILL:
                yield item


def _take_unique(tokens, excluded, limit: int) -> List[str]:
    """Return up to ``limit`` distinct tokens from ``tokens``, in order, skipping any in ``excluded``."""
    selected: List[str] = []
    seen = set()
    for token in tokens:
        if len(selected) >= limit:
            break
        if token in excluded or token in seen:
            continue
        seen.add(token)
        selected.append(token)
    return selected


# One MucowSourceSample per test-set line whose ambiguous word appears as a token,
# annotated with the sense and the insertions that should support / contradict it.
samples: List[MucowSourceSample] = []
tokenizer = MosesTokenizer(lang="en")
with open(mucow.testset_path, encoding="utf-8") as f:
    for line, sense_key in zip(f, mucow.sense_keys):
        src_sentence = line.strip()
        src_word = sense_key.src_word
        sense = " ".join(sense_key.correct_tgt_words)
        senses_for_word = insertions[src_word]
        all_correct_insertions = senses_for_word[sense]
        wrong_insertions_lists = [
            insertions_ for sense_, insertions_ in senses_for_word.items() if sense_ != sense
        ]
        all_wrong_insertions = set(itertools.chain.from_iterable(wrong_insertions_lists))
        # Correct insertions must not also be proposed for any competing sense.
        correct_insertions = _take_unique(all_correct_insertions, all_wrong_insertions, max_insertions)
        # Wrong insertions are drawn round-robin from the competing senses, so that no
        # single sense dominates. zip_longest keeps drawing once shorter lists are
        # exhausted, and duplicates shared by several wrong senses are kept only once.
        wrong_insertions = _take_unique(
            _interleave(wrong_insertions_lists), set(all_correct_insertions), max_insertions
        )
        tokens = tokenizer.tokenize(src_sentence)
        src_word_plural = mucow._inflect_engine.plural_noun(src_word)
        if src_word in tokens:
            src_form = src_word
        elif src_word_plural in tokens:
            src_form = src_word_plural
        else:
            continue  # Skip unexpected inflections
        sample = MucowSourceSample(
            tgt_language=mucow.tgt_language,
            src_sentence=src_sentence,
            corpus=sense_key.corpus,
            src_word=src_word,
            src_form=src_form,
            cluster_id=None,
            correct_tgt_words=set(sense_key.correct_tgt_words),
        )
        # Extra attributes consumed by the scoring loop below.
        sample.sense = sense
        sample.correct_insertions = correct_insertions
        sample.wrong_insertions = wrong_insertions
        samples.append(sample)

# model_path = Path(__file__).parent.parent.parent / "tests" / "models" / "toy_fairseq_en-de"
# evaluator_model = FairseqScoringModel(
#     name=model_path.name,
#     model_name_or_path=model_path,
#     tokenizer="moses",
#     bpe="fastbpe",
# )

def _load_wmt19_ensemble(lang_pair: str):
    """Load the fairseq WMT19 ``lang_pair`` model (checkpoints model1-4.pt) via torch.hub."""
    return torch.hub.load(
        repo_or_dir='pytorch/fairseq',
        model=f'transformer.wmt19.{lang_pair}',
        checkpoint_file="model1.pt:model2.pt:model3.pt:model4.pt",
        tokenizer='moses',
        bpe='fastbpe',
    )


# Pick the scoring model for the requested target language. Without the final
# `else`, an unsupported language would surface only as a NameError on
# `evaluator_model` below.
if tgt_language == "de":
    hub_interface = _load_wmt19_ensemble("en-de")
    evaluator_model = FairseqScoringModel(name='transformer.wmt19.en-de.ensemble', model=hub_interface)
elif tgt_language == "ru":
    hub_interface = _load_wmt19_ensemble("en-ru")
    evaluator_name = 'transformer.wmt19.en-ru.ensemble'
    # en-ru is given explicit source/target fastBPE code files from a sibling
    # `mt_bias` checkout, four directories above this script.
    bpe_models_dir = Path(__file__).parent.parent.parent.parent / "mt_bias" / "models"
    evaluator_model = FairseqScoringModel(
        name=evaluator_name,
        model=hub_interface,
        src_bpe_codes=bpe_models_dir / "wmt19.en-ru.ffn8192/en24k.fastbpe.code",
        tgt_bpe_codes=bpe_models_dir / "wmt19.ru-en.ffn8192/ru24k.fastbpe.code",
    )
else:
    raise ValueError(f"Unsupported target language {tgt_language!r}; expected 'de' or 'ru'")
mucow.evaluator_model = evaluator_model

# Map each (stripped) English source line to its reference translation, pairing the
# two files line by line. The references are German/Russian text, so decode them
# explicitly as UTF-8 instead of using the locale's default encoding.
with open(sources_path, encoding="utf-8") as f_src, open(references_path, encoding="utf-8") as f_tgt:
    translation_dict = {line_src.strip(): line_tgt.strip() for line_src, line_tgt in zip(f_src, f_tgt)}

def _insert_before_word(sentence: str, word: str, insertion: str) -> str:
    """Return ``sentence`` with ``insertion`` placed before each whole-word occurrence of ``word``.

    Unlike ``str.replace``, this leaves occurrences of ``word`` inside longer words
    untouched (e.g. "bank" inside "banking"). If there is no whole-word match, it falls
    back to the plain substring replacement.
    """
    pattern = r"(?<!\w)" + re.escape(word) + r"(?!\w)"
    # A callable replacement avoids backslash/group interpretation of `insertion`.
    modified, count = re.subn(pattern, lambda match: f"{insertion} {match.group(0)}", sentence)
    if count == 0:
        return sentence.replace(word, f"{insertion} {word}")
    return modified


# "<src_word> (<sense>)" -> "<insertion> (correct|wrong)" -> one score per sample.
insertion_confidences: Dict[str, Dict[str, List[float]]] = dict()

for sample in tqdm(samples):
    sense = f"{sample.src_word} ({sample.sense})"
    if sense not in insertion_confidences:
        insertion_confidences[sense] = defaultdict(list)
    translation = translation_dict[sample.src_sentence]
    for insertion in sample.correct_insertions:
        modified_source = _insert_before_word(sample.src_sentence, sample.src_form, insertion)
        score = mucow._score([modified_source], [translation])[0]
        insertion_confidences[sense][insertion + " (correct)"].append(score)
    for insertion in sample.wrong_insertions:
        modified_source = _insert_before_word(sample.src_sentence, sample.src_form, insertion)
        # Wrong insertions are recorded as the complement of the score.
        # NOTE(review): assumes mucow._score returns values in [0, 1] — confirm.
        score = 1 - mucow._score([modified_source], [translation])[0]
        insertion_confidences[sense][insertion + " (wrong)"].append(score)

# Collapse each list of per-sample scores into its arithmetic mean, keeping the
# "<src_word> (<sense>)" -> "<insertion> (correct|wrong)" nesting, and write it as JSON.
mean_insertion_confidences: Dict[str, Dict[str, float]] = {
    sense_label: {
        insertion_label: sum(scores) / len(scores)
        for insertion_label, scores in per_insertion.items()
    }
    for sense_label, per_insertion in insertion_confidences.items()
}

with open(confidence_output_path, "w") as out_file:
    json.dump(mean_insertion_confidences, out_file, indent=2)
