import hashlib
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Union, List

from tasks import Task, TaskResult
from translation_models import TranslationModel, ScoringModel


@dataclass
class WinomtSourceSample:
    """
    A single WinoMT source sentence together with its gold annotation.

    `probability_correct` stays None until the sample has been scored;
    `weight` is the sample's weight when accuracies are aggregated.
    """
    gold_gender: str
    occupation_index: int  # Index of the occupation referred to by the pronoun
    sentence: str
    occupation: str
    stereotype: str
    probability_correct: Union[float, None] = None
    weight: float = 1

    def get_sentence_with_replaced_occupation(self, replacement: str) -> str:
        """
        Return the sentence with the token at `occupation_index` replaced by `replacement`.

        Tokens are obtained by splitting the sentence on single spaces.
        """
        tokens = self.sentence.split(" ")
        tokens[self.occupation_index] = replacement
        return " ".join(tokens)

    @property
    def is_correct(self) -> bool:
        """
        Whether the sample counts as correct, i.e. `probability_correct` is at least 0.5.

        Raises ValueError if the sample has not been scored yet.
        """
        if self.probability_correct is None:
            # Raise rather than return the exception: a returned exception object is truthy
            # and would silently be counted as a correct sample.
            raise ValueError("Sample has not yet been scored")
        return self.probability_correct >= 0.5

    @property
    def category(self) -> str:
        """
        Category used for category-wise aggregation; this is the gold gender.
        """
        return self.gold_gender

    @property
    def occupation_is_frequent(self) -> bool:
        """
        Based on the WMT19 EN–DE training data (Barrault et al., 2019)
        """
        # Imported lazily, only when the property is accessed
        from tasks.utils import FREQUENT_OCCUPATIONS
        return self.occupation in FREQUENT_OCCUPATIONS

@dataclass
class WinomtResult(TaskResult):
    """
    Accuracy figures of a WinoMT evaluation run: overall, per gender, and the minimum.
    """
    total_accuracy: float
    accuracy_male: float
    accuracy_female: float
    min_accuracy: float

    def __str__(self):
        """
        Render the four accuracies, scaled by 100 with one decimal, separated by tabs.
        """
        accuracies = (
            self.total_accuracy,
            self.accuracy_male,
            self.accuracy_female,
            self.min_accuracy,
        )
        return "\t".join(f"{100 * accuracy:.1f}" for accuracy in accuracies)


class WinomtContrastiveConditioningTask(Task):
    """
    Gender accuracy evaluation on WinoMT sentences using contrastive conditioning.

    The ambiguous source sentences are translated by the model under test. Each translation
    is then scored by the evaluator model against two disambiguated variants of the source:
    one where the occupation is tagged with the gold gender and one where it is tagged with
    a contrasting gender. The normalized score of the gold variant is the sample's
    probability of being correct.
    """

    def __init__(self,
                 evaluator_model: ScoringModel,
                 testset_path: Union[Path, str] = Path(__file__).parent.parent / "data" / "winomt" / "test.tsv",
                 category_wise_weighting: bool = False,
                 caching: bool = True,
                 skip_neutral_gold: bool = True,
                 ):
        """
        :param evaluator_model: Model used to score translations given disambiguated sources
        :param testset_path: Tab-separated test set with five columns per line
          (gold gender, occupation index, sentence, occupation, stereotype)
        :param category_wise_weighting: If True, weight samples by the rank of the evaluator's
          confidence within their category; otherwise all samples have weight 1
        :param caching: If True, translations and scores are read from / written to the cache
        :param skip_neutral_gold: If True, samples whose gold gender is "neutral" are not loaded
        """
        self.evaluator_model = evaluator_model
        self.testset_path = testset_path
        self.category_wise_weighting = category_wise_weighting
        self.caching = caching
        self.skip_neutral_gold = skip_neutral_gold

        self.samples = self._load_dataset()
        self.categories = {sample.category for sample in self.samples}

    def evaluate(self, translation_model: TranslationModel, **translation_kwargs) -> WinomtResult:
        """
        Translate the test set with `translation_model` and compute gender accuracies.

        :param translation_model: Model under test
        :param translation_kwargs: Passed on to `translation_model.translate`
        :return: Result with total, male, female and minimum category-wise accuracy;
          the scored (and weighted) copies of the samples are attached as `result.samples`
        """
        # Work on copies so that scores and weights of one run do not leak into the next
        samples = deepcopy(self.samples)
        ambiguous_sources = [sample.sentence for sample in samples]
        correctly_disambiguated_sources = [sample.get_sentence_with_replaced_occupation(
            f"[{sample.gold_gender}] {sample.occupation}"
        ) for sample in samples]  # "doctor" => "[female] doctor"
        # Every gold gender other than "female" is contrasted with "female"
        wrongly_disambiguated_sources = [sample.get_sentence_with_replaced_occupation(
            f"[{'male' if sample.gold_gender == 'female' else 'female'}] {sample.occupation}"
        ) for sample in samples]

        hypotheses = self._translate(ambiguous_sources, translation_model, **translation_kwargs)

        scores_correct = self._score(correctly_disambiguated_sources, hypotheses)
        scores_wrong = self._score(wrongly_disambiguated_sources, hypotheses)

        assert len(samples) == len(hypotheses) == len(scores_correct) == len(scores_wrong)
        # NOTE(review): if both scores of a sample are exactly 0 the normalization below
        # divides by zero – confirm whether the evaluator model can return such scores
        for sample, score_correct, score_wrong in zip(samples, scores_correct, scores_wrong):
            sample.probability_correct = score_correct / (score_correct + score_wrong)

        # Assigns the sample weights as a side effect, so this must precede all accuracies
        category_wise_samples = self._weight_samples_by_category(samples)
        category_wise_accuracies = {
            category: self._weighted_accuracy(category_samples)
            for category, category_samples in category_wise_samples.items()
        }
        min_accuracy = min(category_wise_accuracies.values())

        total_accuracy = self._weighted_accuracy(samples)

        result = WinomtResult(
            total_accuracy=total_accuracy,
            accuracy_male=category_wise_accuracies["male"],
            accuracy_female=category_wise_accuracies["female"],
            min_accuracy=min_accuracy,
        )
        result.samples = samples
        return result

    @staticmethod
    def _weighted_accuracy(samples):
        """
        Weighted share of correct samples: sum of weights of correct samples over sum of all weights.
        """
        return sum(sample.weight * sample.is_correct for sample in samples) / \
            sum(sample.weight for sample in samples)

    def _translate(self, source_sentences, translation_model, **translation_kwargs):
        """
        Translate `source_sentences`, using the cache if caching is enabled.

        :return: One translation per source sentence
        """
        cached_translations = None
        if self.caching:
            # The key string (including the literal `False`) must stay exactly as is,
            # otherwise previously written cache entries are no longer found
            translations_cache_filename = f"winomt.translations.{translation_model}." + hashlib.sha256(
                f"{self.testset_path}{False}"
                f"{source_sentences}{translation_model}{translation_kwargs}".encode()
            ).hexdigest()
            cached_translations = self._get_cache(translations_cache_filename)
        if cached_translations is not None:
            # NOTE(review): splitlines() drops a trailing empty translation and also splits on
            # line boundaries other than "\n" inside a translation; evaluate() would then fail
            # its length check – confirm against the cache helpers before changing
            translations = cached_translations.splitlines()
        else:
            translations = translation_model.translate(source_sentences, **translation_kwargs)
            if self.caching:
                self._set_cache(translations_cache_filename, "\n".join(translations))
        return translations

    def _score(self, source_sentences, hypotheses):
        """
        Score each hypothesis given its source sentence with the evaluator model,
        using the cache if caching is enabled.

        :return: One score per (source sentence, hypothesis) pair
        """
        cached_scores = None
        if self.caching:
            scores_cache_filename = "winomt.scores." + hashlib.sha256(
                f"{self.evaluator_model}{source_sentences}{hypotheses}".encode()
            ).hexdigest()
            cached_scores = self._get_cache(scores_cache_filename)
        if cached_scores is not None:
            scores = [float(line) for line in cached_scores.splitlines()]
        else:
            scores = self.evaluator_model.score(source_sentences, hypotheses)
            if self.caching:
                self._set_cache(scores_cache_filename, "\n".join(map(str, scores)))
        return scores

    def _load_dataset(self) -> List[WinomtSourceSample]:
        """
        Read the samples from the tab-separated test set at `self.testset_path`.

        Blank lines are ignored. Samples with gold gender "neutral" are left out if
        `self.skip_neutral_gold` is set.

        :raises ValueError: If a line does not have exactly five fields or its
          occupation index is not an integer
        """
        samples = []
        # Explicit encoding so that the result does not depend on the platform default
        with open(self.testset_path, encoding="utf-8") as f:
            for line_number, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # E.g. an empty line at the end of the file
                    continue
                fields = line.split("\t")
                if len(fields) != 5:
                    raise ValueError(
                        f"{self.testset_path}, line {line_number}: "
                        f"expected 5 tab-separated fields, got {len(fields)}"
                    )
                gold_gender, occupation_index, sentence, occupation, stereotype = fields
                sample = WinomtSourceSample(
                    gold_gender=gold_gender,
                    occupation_index=int(occupation_index),
                    sentence=sentence,
                    occupation=occupation,
                    stereotype=stereotype,
                )
                if sample.gold_gender == "neutral" and self.skip_neutral_gold:
                    continue
                samples.append(sample)
        return samples

    def _weight_samples_by_category(self, samples):
        """
        Category-wise weighting: Downweight samples with small evaluator confidence; keep weights balanced per category.
        Do not normalize weights here but divide by total weights when computing accuracies

        The weights are written to the passed samples in place.

        :return: Dict mapping each category to its samples, sorted by decreasing evaluator confidence
        """
        category_wise_samples = {
            category: sorted(
                [sample for sample in samples if sample.category == category],
                key=lambda s: -abs(0.5 - s.probability_correct),  # Most confident first
            )
            for category in self.categories
        }
        for category_samples in category_wise_samples.values():
            num_samples = len(category_samples)
            for rank, sample in enumerate(category_samples):
                if self.category_wise_weighting:
                    # Linear decay of weights along ranks
                    sample.weight = num_samples - rank
                else:
                    sample.weight = 1
        return category_wise_samples
