"""
    Implementation of COMET_SCORE. 
    Use case: Machine Translation Evaluation for Automatic Speech Translation task.
    Requiring: (1) source_sentence, (2) candidates, (3) references
    Acknowledgement: https://github.com/Unbabel/COMET 
"""

from metrics.metrics import Metrics
from utils.custom_logging import write_record_log, append_final_score
from metrics.word_error_rate_metrics import normalize_text
from comet import download_model, load_from_checkpoint
from utils import util


class CometScore(Metrics):
    """COMET metric for machine-translation quality.

    Scores every (source, hypothesis, reference) triple with the
    ``Unbabel/wmt22-comet-da`` checkpoint and reports the mean of the
    record-level scores multiplied by 100.
    """

    def __init__(self, batch_size=1, num_gpus=0):
        """Download and load the COMET checkpoint.

        Args:
            batch_size: Forwarded as ``batch_size`` to the scorer's ``predict``.
            num_gpus: Forwarded as ``gpus`` to the scorer's ``predict``.
        """
        super().__init__()
        self.name = "comet"
        # Original author's note: scores of this checkpoint lie within [0, 1].
        model_path = download_model("Unbabel/wmt22-comet-da")
        self.scorer = load_from_checkpoint(model_path)
        self.record_level_scores = None

        self.batch_size = batch_size
        self.num_gpus = num_gpus

    def __call__(
        self,
        candidates,
        references,
        source_sentences,
        instructions=None,
        *,
        task_name: str | None = None,
        model_name: str | None = None,
        model_responses=None,
    ):
        """Score a dataset and return the overall COMET score.

        All three text inputs are passed through ``normalize_text`` before
        scoring. Record-level scores are cached on ``self.record_level_scores``.

        Args:
            candidates: Generated translations from the model.
            references: Target-language reference translations.
            source_sentences: Source-language sentences.
            instructions: Stored on ``self.instructions`` and forwarded to the
                record log.
            task_name: Task identifier; logging only happens when both
                ``task_name`` and ``model_name`` are truthy.
            model_name: Model identifier used for logging.
            model_responses: Accepted but not used by this metric.

        Returns:
            ``{self.name: score}`` where ``score`` is the mean of the non-None
            record-level scores times 100, rounded with ``util.smart_round``
            to 2 places. The mean is taken as 0.0 when no valid score exists.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        self.instructions = instructions

        # Get individual scores
        normalized_candidates = [normalize_text(c) for c in candidates]
        normalized_references = [normalize_text(r) for r in references]
        normalized_source_sentences = [normalize_text(s) for s in source_sentences]
        self.record_level_scores = self.compute_record_level_scores(
            normalized_candidates,
            normalized_references,
            normalized_source_sentences,
        )

        # Calculate the mean score directly to avoid async issues
        scores = self.record_level_scores.get(self.name, [])
        valid_scores = [score for score in scores if score is not None]
        mean_score = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
        # Scale the score to the range of [0, 100].
        overall_score = {self.name: util.smart_round(mean_score * 100.0, 2)}

        if task_name and model_name:
            # write_record_log will also write to run.log internally
            write_record_log(
                self,
                normalized_references,
                normalized_candidates,
                scores,
                task_name,
                model_name,
                instructions=self.instructions,
            )
            # Directly call append_final_score
            append_final_score(self, overall_score, task_name, model_name)
        return overall_score

    def compute_record_level_scores(self, candidates: list, references: list, source_sentences: list) -> dict[str, list | None]:
        """Compute the scores that should be saved in the record level file.

        Args:
            candidates: Generated text from the model.
            references: Target language reference text from the dataset.
            source_sentences: Source language text from the dataset for MT.

        Returns:
            Scores for each record. The keys should be the column names that
            will be saved in the record level file. An empty input yields an
            empty score list without invoking the scorer.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        num_candidates = len(candidates)
        num_sources = len(source_sentences)
        num_references = len(references)

        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would let mismatched inputs through unnoticed.
        if not (num_candidates == num_references == num_sources):
            raise ValueError(
                "Number of samples of sources, candidates, references are not equal.\n"
                f" Hypothesis:{num_candidates:d} \t Source sentence:{num_sources:d}"
                f" \t Reference:{num_references:d}"
            )

        # Nothing to score; skip the model call entirely.
        if num_references == 0:
            return {self.name: []}

        formatted_data = [
            {
                "src": src,  # source_language sentence
                "mt": mt,  # hypothesis generated by the model
                "ref": ref,  # target_language reference
            }
            for src, mt, ref in zip(source_sentences, candidates, references)
        ]

        # One prediction pass over the whole dataset; `.scores` holds the
        # per-record values.
        model_output = self.scorer.predict(formatted_data, batch_size=self.batch_size, gpus=self.num_gpus)

        return {self.name: model_output.scores}
