import logging
from collections import namedtuple
from pathlib import Path
from typing import Union, List, Tuple, Dict

import inflect
import jsonlines
from sacremoses import MosesTokenizer

from tasks.mucow_lrec20_source import MucowContrastiveConditioningTask, MucowSourceSample
from translation_models import ScoringModel


# One row of the sense-key TSV: id, corpus and ambiguous source word, followed by the
# accepted and rejected target-language translations (each a tuple of words).
SenseKey = namedtuple(
    "SenseKey",
    "id corpus src_word correct_tgt_words wrong_tgt_words",
)


class MucowWMT19ContrastiveConditioningTask(MucowContrastiveConditioningTask):
    """Source-side contrastive conditioning on the MuCoW WMT19 test suite.

    Reads three line-aligned resources for the English->``tgt_language`` direction:

    * ``testset_path``: one English source sentence per line;
    * ``sense_key_path``: one tab-separated :data:`SenseKey` row per line;
    * ``source_data_path``: one JSON object per line (JSON Lines). Every object must
      contain ``"src_word"`` and ``"cluster_id"``, and may contain
      ``"correct_insertions"`` and ``"wrong_insertions"``.

    Lines are combined by position into :class:`MucowSourceSample` objects.
    """

    def __init__(self,
                 tgt_language: str,
                 evaluator_model: ScoringModel,
                 reverse: bool = False,
                 testset_path: Union[Path, str, None] = None,
                 sense_key_path: Union[Path, str, None] = None,
                 source_data_path: Union[Path, str, None] = None,
                 category_wise_weighting: bool = False,
                 caching: bool = True,
                 ):
        """Locate, validate and load the test suite files.

        Args:
            tgt_language: Target language code. It is used to build the default file
                names ``en-{tgt_language}.*`` and is stored on every sample.
            evaluator_model: Scoring model, stored as ``self.evaluator_model``.
            reverse: Stored as ``self.reverse``. Its meaning is defined outside this
                file (presumably by the base class; confirm there).
            testset_path: Optional override of the source-sentence file.
            sense_key_path: Optional override of the sense-key TSV file.
            source_data_path: Optional override of the JSON Lines source-data file.
                All three path arguments accept ``str`` or ``Path``. When falsy, each
                falls back to a file under ``<package root>/data/mucow_wmt19``.
            category_wise_weighting: Stored as an attribute. Presumably consumed by
                the base class; confirm there.
            caching: Stored as an attribute. Presumably consumed by the base class;
                confirm there.

        Raises:
            FileNotFoundError: If any of the three data files does not exist.
        """
        self.task_name = "mucow_wmt19_source"
        self._inflect_engine = inflect.engine()
        self.tgt_language = tgt_language
        self.evaluator_model = evaluator_model
        self.reverse = reverse
        default_data_path = Path(__file__).parent.parent / "data" / "mucow_wmt19"
        # Coerce to Path: the annotations allow str, and str has no .exists().
        self.testset_path = Path(testset_path or default_data_path / f"en-{tgt_language}.text.txt")
        self.sense_key_path = Path(sense_key_path or default_data_path / f"en-{tgt_language}.key.tsv")
        self.source_data_path = Path(
            source_data_path or default_data_path / f"en-{tgt_language}.source_data.jsonl"
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for path in (self.testset_path, self.sense_key_path, self.source_data_path):
            if not path.exists():
                raise FileNotFoundError(f"MuCoW WMT19 data file not found: {path}")
        self.category_wise_weighting = category_wise_weighting
        self.caching = caching

        self.source_data = self._load_source_data()
        self.source_data_dict = self._load_source_data_dict()
        self.sense_keys = self._load_sense_keys()
        self.samples = self._load_dataset()
        self.categories = {sample.category for sample in self.samples}

    def _load_dataset(self) -> List[MucowSourceSample]:
        """Combine test sentences, sense keys and source data into samples.

        The three inputs are paired line by line. A line is skipped when:

        * it has no correct insertions or no wrong insertions. This is logged once per
          ``(src_word, cluster_id)``.
        * the ambiguous word does not appear among the Moses tokens of the sentence in
          any of these forms: as-is, its ``inflect`` plural, or its title-cased form.

        Returns:
            The list of :class:`MucowSourceSample` objects, in file order.
        """
        tokenizer = MosesTokenizer(lang="en")
        with open(self.testset_path, encoding="utf-8") as f:
            src_sentences = [line.strip() for line in f]
        # zip() below silently stops at the shortest input, so make a mismatch visible.
        if not (len(src_sentences) == len(self.sense_keys) == len(self.source_data)):
            logging.warning(
                f"MuCoW WMT19 inputs are not line-aligned: {len(src_sentences)} sentences, "
                f"{len(self.sense_keys)} sense keys, {len(self.source_data)} source data rows; "
                f"surplus lines are ignored"
            )

        samples = []
        logged_senses = set()  # Avoid repetitive logging
        for src_sentence, sense_key, source_data in zip(src_sentences, self.sense_keys, self.source_data):
            src_word = sense_key.src_word
            cluster_id = source_data["cluster_id"]
            correct_insertions = source_data.get("correct_insertions", [])
            wrong_insertions = source_data.get("wrong_insertions", [])
            if not correct_insertions or not wrong_insertions:
                if (src_word, cluster_id) not in logged_senses:
                    logging.info(f"No disambiguators found for {src_word} [{cluster_id}]; skipping")
                    logged_senses.add((src_word, cluster_id))
                continue
            tokens = set(tokenizer.tokenize(src_sentence))
            # Candidate surface forms, tried in priority order.
            candidate_forms = (
                src_word,
                self._inflect_engine.plural_noun(src_word),
                src_word.title(),
            )
            src_form = next((form for form in candidate_forms if form in tokens), None)
            if src_form is None:
                continue  # Skip unexpected inflections
            samples.append(MucowSourceSample(
                tgt_language=self.tgt_language,
                src_sentence=src_sentence,
                corpus=sense_key.corpus,
                src_word=src_word,
                src_form=src_form,
                cluster_id=cluster_id,
                correct_insertions=set(correct_insertions),
                wrong_insertions=set(wrong_insertions),
                correct_tgt_words=set(sense_key.correct_tgt_words),
            ))
        return samples

    def _load_source_data(self) -> List[Dict]:
        """Return every JSON object of ``self.source_data_path`` in file order."""
        with jsonlines.open(self.source_data_path) as reader:
            return list(reader)

    def _load_source_data_dict(self) -> Dict[Tuple[str, int], Dict]:
        """Index the source-data rows by ``(src_word, cluster_id)``.

        If several rows share a key, the last one in the file wins.
        """
        with jsonlines.open(self.source_data_path) as reader:
            return {(row["src_word"], row["cluster_id"]): row for row in reader}

    def _load_sense_keys(self) -> List[SenseKey]:
        """Parse the sense-key TSV file into :data:`SenseKey` tuples, one per line.

        Columns are: id, corpus, source word, space-separated correct target words,
        and space-separated wrong target words.
        """
        sense_keys = []
        # Explicit UTF-8: the target-word columns are non-English text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(self.sense_key_path, encoding="utf-8") as f:
            for line in f:
                elements = line.strip().split("\t")
                sense_keys.append(SenseKey(
                    id=elements[0],
                    corpus=elements[1],
                    src_word=elements[2],
                    correct_tgt_words=tuple(elements[3].split(" ")),
                    wrong_tgt_words=tuple(elements[4].split(" ")),
                ))
        return sense_keys
