import hashlib
import logging
from collections import namedtuple
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Union, List, Set, Tuple, Dict

import jsonlines
from sacremoses import MosesDetokenizer
from tqdm import tqdm

from tasks import Task, TaskResult
from translation_models import TranslationModel, ScoringModel


@dataclass
class MucowSourceSample:
    """A single MuCoW test sentence together with its source-side disambiguation cues.

    The ambiguous word form (``src_form``) is prefixed with an "insertion" to
    build disambiguated variants of ``src_sentence``; see
    ``correct_disambiguated_sentences`` and ``wrong_disambiguated_sentences``.
    """
    tgt_language: str
    src_sentence: str
    corpus: str
    src_word: str
    src_form: str  # Surface form of src_word as it occurs in src_sentence
    cluster_id: int
    correct_tgt_words: Set[str]
    # Cues that are inserted before src_form to signal the correct / a wrong sense
    correct_insertions: Set[str] = None
    wrong_insertions: Set[str] = None
    # Filled in during evaluation; None means the sample has not been scored yet
    probability_correct: float = None
    weight: float = None
    max_alternatives: int = 10  # Not read anywhere in this class

    @property
    def category(self):
        """Return "frequent", "infrequent" or "none" for the sense of this sample.

        The first correct target word that is listed for ``src_word`` in one of
        the two sense tables of ``tgt_language`` decides the category.
        """
        # Imported lazily, as in the original code, so the tables are only needed on access
        from tasks.utils import FREQUENT_WORD_SENSES, INFREQUENT_WORD_SENSES
        for correct_tgt_word in self.correct_tgt_words:
            if (self.src_word, correct_tgt_word) in FREQUENT_WORD_SENSES[self.tgt_language]:
                return "frequent"
            if (self.src_word, correct_tgt_word) in INFREQUENT_WORD_SENSES[self.tgt_language]:
                return "infrequent"
        return "none"

    @property
    def is_correct(self):
        """Whether the correct sense is at least as probable as the wrong one.

        Raises:
            ValueError: if ``probability_correct`` has not been set yet.
        """
        if self.probability_correct is None:
            # Must be raised, not returned: a returned exception object is truthy
            # and would silently count an unscored sample as correct.
            raise ValueError("Sample has not yet been scored")
        return self.probability_correct >= 0.5

    @property
    def correct_disambiguated_sentences(self) -> List[str]:
        """Source sentence variants, one per correct insertion."""
        return [self._insert_insertion(insertion) for insertion in self.correct_insertions]

    @property
    def wrong_disambiguated_sentences(self) -> List[str]:
        """Source sentence variants, one per wrong insertion."""
        return [self._insert_insertion(insertion) for insertion in self.wrong_insertions]

    def _insert_insertion(self, insertion: str) -> str:
        """Prefix every occurrence of ``src_form`` in the sentence with ``insertion``."""
        # str.replace matches plain substrings, so all occurrences are affected
        return self.src_sentence.replace(self.src_form, f"{insertion} {self.src_form}")


# One word sense: the source word, its cluster ID, a relative frequency and its target words
Sense = namedtuple("Sense", "src_word cluster_id relative_frequency tgt_words")


@dataclass
class MucowSourceResult(TaskResult):
    """Accuracies of one evaluation run, stored as fractions."""
    total_accuracy: float
    accuracy_frequent: float
    accuracy_infrequent: float
    min_accuracy: float

    def __str__(self):
        """Render the four accuracies as tab-separated percentages with one decimal."""
        fractions = (
            self.total_accuracy,
            self.accuracy_frequent,
            self.accuracy_infrequent,
            self.min_accuracy,
        )
        return "\t".join(f"{100 * fraction:.1f}" for fraction in fractions)


class MucowContrastiveConditioningTask(Task):

    def __init__(self,
                 tgt_language: str,
                 evaluator_model: ScoringModel,
                 reverse: bool = False,
                 testset_path: Union[Path, str] = None,
                 senses_path: Union[Path, str] = None,
                 source_data_path: Union[Path, str] = None,
                 category_wise_weighting: bool = False,
                 caching: bool = True,
                 ):
        self.task_name = "mucow_lrec20_source"
        self.tgt_language = tgt_language
        self.evaluator_model = evaluator_model
        self.reverse = reverse
        default_data_path = Path(__file__).parent.parent / "data" / "mucow_lrec20"
        self.testset_path = testset_path or default_data_path / f"en-{tgt_language}.test.tsv"
        self.senses_path = senses_path or default_data_path / f"en-{tgt_language}.senses.tsv"
        self.source_data_path = source_data_path or default_data_path / f"en-{tgt_language}.source_data.jsonl"
        assert self.testset_path.exists()
        assert self.senses_path.exists()
        assert self.source_data_path.exists()
        self.category_wise_weighting = category_wise_weighting
        self.caching = caching

        self.source_data = self._load_source_data()
        self.senses = self._load_senses()
        self.samples = self._load_dataset()
        self.categories = {sample.category for sample in self.samples}

    def evaluate(self, translation_model: TranslationModel, **translation_kwargs) -> MucowSourceResult:
        """Translate all test sentences and score them against disambiguated sources.

        For every sample, the translation is scored against each correctly and
        each wrongly disambiguated source sentence. The best score of each
        group is normalized into ``probability_correct``. One tab-separated
        line per sample is written to ``<task_name>.log`` in the parent of this
        module's directory.

        Args:
            translation_model: Model under evaluation.
            **translation_kwargs: Passed through to the translation call.

        Returns:
            The weighted accuracies; the scored copies of the samples are
            attached to the result as ``samples``.
        """
        # Work on copies so that scores and weights never leak into self.samples
        samples = deepcopy(self.samples)
        translations = self._translate([sample.src_sentence for sample in samples], translation_model, **translation_kwargs)
        assert len(samples) == len(translations)

        log_path = Path(__file__).parent.parent / f"{self.task_name}.log"
        # Explicit UTF-8: the log contains arbitrary source and translated text, which the
        # platform default encoding may not be able to represent
        with open(log_path, "w", encoding="utf-8") as f:
            logging.info("Scoring translations ...")
            for sample, translation in zip(tqdm(samples), translations):
                correct_disambiguated_sources = sample.correct_disambiguated_sentences
                wrong_disambiguated_sources = sample.wrong_disambiguated_sentences
                scores_correct = self._score(
                    correct_disambiguated_sources,
                    len(correct_disambiguated_sources) * [translation]
                )
                scores_wrong = self._score(
                    wrong_disambiguated_sources,
                    len(wrong_disambiguated_sources) * [translation]
                )
                assert len(scores_correct) == len(correct_disambiguated_sources)
                assert len(scores_wrong) == len(wrong_disambiguated_sources)
                # Only the best-scoring cue of each group enters the comparison
                score_correct = max(scores_correct)
                score_wrong = max(scores_wrong)
                # NOTE(review): raises ZeroDivisionError if both best scores are 0
                sample.probability_correct = score_correct / (score_correct + score_wrong)

                best_correct_source = correct_disambiguated_sources[scores_correct.index(score_correct)]
                best_wrong_source = wrong_disambiguated_sources[scores_wrong.index(score_wrong)]
                f.write(
                    f"{sample.src_word}\t"
                    f"{sample.src_sentence}\t"
                    f"{translation}\t"
                    f"{best_correct_source}\t"
                    f"{best_wrong_source}\t"
                    f"{score_correct}\t"
                    f"{score_wrong}\t"
                    f"{sample.probability_correct}\n"
                )

        # Assigns sample.weight as a side effect, which the accuracies below rely on
        category_wise_samples = self._weight_samples_by_category(samples)
        category_wise_accuracies = dict()
        for category, category_samples in category_wise_samples.items():
            if category == "none":
                continue
            category_wise_accuracy = sum(sample.weight * sample.is_correct for sample in category_samples) / \
                                        sum(sample.weight for sample in category_samples)
            category_wise_accuracies[category] = category_wise_accuracy
        min_accuracy = min(category_wise_accuracies.values())

        # The total also includes samples of category "none"
        total_accuracy = sum(sample.weight * sample.is_correct for sample in samples) / \
                            sum(sample.weight for sample in samples)

        # NOTE(review): requires both a "frequent" and an "infrequent" category to be present
        result = MucowSourceResult(
            total_accuracy=total_accuracy,
            accuracy_frequent=category_wise_accuracies["frequent"],
            accuracy_infrequent=category_wise_accuracies["infrequent"],
            min_accuracy=min_accuracy,
        )
        result.samples = samples
        return result
    
    def _translate(self, source_sentences, translation_model, **translation_kwargs):
        cached_translations = None
        if self.caching:
            translations_cache_filename = f"{self.task_name}.translations.{translation_model}." + hashlib.sha256(
                f"{source_sentences}{translation_model}{translation_kwargs}".encode()
            ).hexdigest()
            cached_translations = self._get_cache(translations_cache_filename)
        if cached_translations is not None:
            translations = cached_translations.splitlines()
        else:
            translations = translation_model.translate(source_sentences, **translation_kwargs)
            if self.caching:
                self._set_cache(translations_cache_filename, "\n".join(translations))
        return translations

    def _score(self, source_sentences, hypotheses):
        if self.reverse:
            return self.evaluator_model.score(hypotheses, source_sentences)
        return self.evaluator_model.score(source_sentences, hypotheses)

    def _load_dataset(self) -> List[MucowSourceSample]:
        """Read the test set TSV and build one sample per usable line.

        A line is skipped (logged once per sense) when no source data exists
        for its sense, when no disambiguation cue is left on either the
        correct or the wrong side, or when the sense is missing from the
        senses file.

        Returns:
            The samples in file order.
        """
        detokenizer = MosesDetokenizer(lang='en')
        samples = []
        with open(self.testset_path, encoding="utf-8") as f:
            logged_senses = set()  # Avoid repetitive logging
            for line in f:
                if not line.strip():
                    continue  # Tolerate blank lines, e.g. at the end of the file
                fields = line.strip().split("\t")
                corpus = fields[2]
                # Corpus filter; currently empty, so no corpus is excluded
                if corpus in {
                    # "commoncrawl",
                    # "europarl",
                    # "multiun",
                    # "ted",
                }:
                    continue  # Do not test on WMT19 training data
                src_word = fields[0]
                cluster_id = int(fields[1])
                source_data = self.source_data.get((src_word, cluster_id), None)
                if source_data is None:
                    if (src_word, cluster_id) not in logged_senses:
                        logging.info(f"No sense data found for {src_word} [{cluster_id}]; skipping")
                        logged_senses.add((src_word, cluster_id))
                    continue
                all_correct_hyponyms = set(source_data["hyponyms"])
                all_wrong_hyponyms = set()
                all_correct_synonyms = set(source_data["synonyms"])
                all_wrong_synonyms = set()
                correct_definitions = set(source_data["definitions"])
                wrong_definitions = set()
                # Cues of the other senses of the same word count as wrong cues;
                # only cluster IDs 0-4 are looked up
                for i in range(5):
                    if i == cluster_id:
                        continue
                    other_source_data = self.source_data.get((src_word, i), None)
                    if other_source_data is None:
                        continue
                    all_wrong_synonyms |= set(other_source_data["synonyms"])
                    all_wrong_hyponyms |= set(other_source_data["hyponyms"])
                    wrong_definitions |= set(other_source_data["definitions"])
                # Remove intersections between correct and wrong
                correct_synonyms = all_correct_synonyms - all_wrong_synonyms - all_wrong_hyponyms
                wrong_synonyms = all_wrong_synonyms - all_correct_synonyms - all_correct_hyponyms
                correct_hyponyms = all_correct_hyponyms - all_wrong_synonyms - all_wrong_hyponyms
                wrong_hyponyms = all_wrong_hyponyms - all_correct_synonyms - all_correct_hyponyms
                # All remaining cues of a side are used as insertions for that side
                correct_insertions = correct_synonyms | correct_hyponyms | correct_definitions
                wrong_insertions = wrong_synonyms | wrong_hyponyms | wrong_definitions
                if not (correct_insertions and wrong_insertions):
                    if (src_word, cluster_id) not in logged_senses:
                        logging.info(f"No disambiguators found for {src_word} [{cluster_id}]; skipping")
                        logged_senses.add((src_word, cluster_id))
                    continue
                sense = self.senses.get((src_word, cluster_id), None)
                if sense is None:
                    if (src_word, cluster_id) not in logged_senses:
                        logging.info(f"No sense definition found for {src_word} [{cluster_id}]; skipping")
                        logged_senses.add((src_word, cluster_id))
                    continue
                src_sentence = fields[3]
                src_sentence = detokenizer.detokenize(src_sentence.split())
                src_sentence = fix_detokenization(src_sentence)
                # MucowSourceSample only declares correct_insertions / wrong_insertions;
                # passing the per-type sets as keywords would raise a TypeError
                sample = MucowSourceSample(
                    tgt_language=self.tgt_language,
                    src_sentence=src_sentence,
                    corpus=corpus,
                    src_word=src_word,
                    src_form=fields[-1],
                    cluster_id=cluster_id,
                    correct_tgt_words=set(sense.tgt_words),
                    correct_insertions=correct_insertions,
                    wrong_insertions=wrong_insertions,
                )
                samples.append(sample)
        return samples

    def _load_source_data(self) -> Dict[Tuple[str, int], Dict]:
        """Index the rows of the source data file by ``(src_word, cluster_id)``.

        If a key occurs more than once, the last row wins.
        """
        with jsonlines.open(self.source_data_path) as reader:
            return {(entry["src_word"], entry["cluster_id"]): entry for entry in reader}

    def _load_senses(self) -> Dict[Tuple[str, int], Sense]:
        """Read the senses TSV into a dict keyed by ``(src_word, cluster_id)``.

        Columns 0, 1, 3 and 4 of each line are used; column 4 holds the target
        words separated by single spaces. The key uses the cluster ID as int,
        while the ``Sense`` tuple keeps the raw string values of the file.
        """
        senses = dict()
        # Explicit UTF-8 so that non-ASCII target words do not depend on the platform default
        with open(self.senses_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # Tolerate blank lines, e.g. at the end of the file
                elements = line.strip().split("\t")
                sense = Sense(elements[0], elements[1], elements[3], tuple(elements[4].split(" ")))
                senses[(sense.src_word, int(sense.cluster_id))] = sense
        return senses

    def _weight_samples_by_category(self, samples):
        """
        Category-wise weighting: Downweight samples with small evaluator confidence; keep weights balanced per category.
        Do not normalize weights here but divide by total weights when computing accuracies
        """
        category_wise_samples = {
            category: sorted([sample for sample in samples if sample.category == category], key=lambda s: -abs(0.5 - s.probability_correct))
            for category in self.categories
        }
        for category_samples in category_wise_samples.values():
            for i, sample in enumerate(category_samples):
                if self.category_wise_weighting:
                    # Linear decay of weights along ranks
                    sample.weight = len(category_samples) - i
                else:
                    if sample.weight is None:
                        sample.weight = 1
        return category_wise_samples


def fix_detokenization(sentence: str) -> str:
    """Remove spaces left around hyphens, apostrophes and closing quotes.

    The replacements are applied one after another in the listed order.
    """
    replacements = (
        (" - ", "-"),
        (" ' ", "'"),
        (" '", "'"),
        (" ’ ", "’"),
        (" ’", "’"),
        ("  ”", "”"),
    )
    for spaced, tight in replacements:
        sentence = sentence.replace(spaced, tight)
    return sentence
