"""BLEU score metrics implementation.

This module provides BLEU score calculation capabilities for evaluating
text generation quality using n-gram precision measures.
"""
import sacrebleu
from tqdm import tqdm
from typing import Tuple

from metrics.metrics import Metrics
from metrics.word_error_rate_metrics import normalize_text
from utils import util
from utils.custom_logging import write_record_log, append_final_score


class BleuMetrics(Metrics):
    """BLEU score evaluation metric.

    Computes a corpus-level BLEU score over the whole dataset and, when both a
    task name and a model name are supplied, sentence-level BLEU scores for each
    record, which are written to the record log. Candidates and references are
    passed through ``normalize_text`` first so BLEU sees the same normalization
    as the WER metric.
    """

    def __init__(self):
        super().__init__()
        self.name = "bleu"
        self.instructions = None
        self.model_responses = []

    def __call__(self, candidates, references, instructions=None, *, task_name: str | None = None, model_name: str | None = None, model_responses=None):
        """Score ``candidates`` against ``references``.

        Args:
            candidates: Generated texts, one per record.
            references: Reference texts, aligned index-by-index with ``candidates``.
            instructions: Optional per-record instructions; stored on the instance
                and forwarded to ``write_record_log``.
            task_name: Task identifier. Used to pick the SacreBLEU tokenizer and,
                together with ``model_name``, to enable record-level logging.
            model_name: Model identifier used for logging.
            model_responses: Optional raw model responses; stored on the instance
                (replaced by ``[]`` when falsy) and forwarded to the log writers.

        Returns:
            The corpus-level result from ``get_score``: ``{self.name: score}``.
        """
        # Keep these on the instance so the logging helpers can read them.
        self.instructions = instructions
        self.model_responses = model_responses if model_responses else []
        tokenizer = self._select_tokenizer(task_name)

        # Corpus-level BLEU is the reported overall score.
        overall = self.get_score(candidates, references, tokenizer)

        if task_name and model_name:
            # Sentence-level BLEU for each record goes into the record log.
            scores, normalized_candidates, normalized_references = self.compute_record_level_scores(
                candidates, references, tokenizer
            )
            # write_record_log will also write to run.log internally
            write_record_log(self, normalized_references, normalized_candidates, scores, task_name, model_name,
                             instructions=self.instructions, model_responses=self.model_responses)
            append_final_score(self, overall, task_name, model_name, self.model_responses)
        return overall

    @staticmethod
    def _select_tokenizer(task_name: str | None) -> str:
        """Return the SacreBLEU tokenizer name to use for ``task_name``.

        For CoVoST2 tasks whose last ``_``-separated name segment contains
        ``"zh"`` (presumably the target language code), SacreBLEU's ``"zh"``
        tokenizer is used. Chinese text has no whitespace word boundaries.
        Every other task uses the default ``"13a"`` tokenizer.
        """
        if task_name and "covost2" in task_name:
            target_language = task_name.split("_")[-1]
            if "zh" in target_language:
                return "zh"
        return "13a"

    def get_score(self, candidates: list, references: list, tokenizer: str) -> dict[str, float]:
        """Compute the corpus-level BLEU score of the complete dataset.

        Args:
            candidates: Generated text list.
            references: Reference text list, aligned with ``candidates``.
            tokenizer: SacreBLEU tokenizer name (e.g. ``"13a"`` or ``"zh"``).

        Returns:
            ``{self.name: score}``, e.g. ``{"bleu": 27.35}``. The score is
            SacreBLEU's value on a 0-100 scale, rounded to 2 decimals.
        """
        # === Consistent normalization with WER processing ===
        norm_references = [normalize_text(r) for r in references]
        norm_candidates = [normalize_text(c) for c in candidates]
        bs = sacrebleu.corpus_bleu(norm_candidates, [norm_references], tokenize=tokenizer)
        # Score range is already in the range of [0, 100.0]. Only rounding to 2 decimal precision.
        return {self.name: util.smart_round(bs.score, 2)}

    # ---------------------------------------------------
    # Internal helper
    # ---------------------------------------------------
    def compute_record_level_scores(self, candidates: list, references: list, tokenizer: str) -> Tuple[dict[str, list], list, list]:
        """Compute the sentence-level BLEU scores saved in the record-level file.

        Args:
            candidates: Generated text from the model.
            references: Reference text from the dataset, one per candidate.
            tokenizer: SacreBLEU tokenizer name.

        Returns:
            A tuple ``(scores, normalized_candidates, normalized_references)``.
            ``scores`` maps ``self.name`` to the list of SacreBLEU sentence-level
            score objects, one per record. The two lists hold the normalized
            texts that were actually scored.

        Raises:
            ValueError: If ``candidates`` and ``references`` differ in length.
                The original ``zip`` silently dropped trailing records here.
        """
        scores = []
        normalized_candidates, normalized_references = [], []
        pairs = zip(candidates, references, strict=True)
        for candidate, reference in tqdm(pairs, desc="BLEU", total=len(candidates)):
            # === Consistent normalization with WER processing ===
            norm_reference = normalize_text(reference)
            norm_candidate = normalize_text(candidate)
            scores.append(sacrebleu.sentence_bleu(norm_candidate, [norm_reference], tokenize=tokenizer))
            normalized_candidates.append(norm_candidate)
            normalized_references.append(norm_reference)
        return {self.name: scores}, normalized_candidates, normalized_references