from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from bleurt_pytorch import (
    BleurtConfig,
    BleurtForSequenceClassification,
    BleurtTokenizer,
)
from rouge import Rouge


@dataclass
class JudgeResult:
    """Verdict and supporting scores produced by ``DatasetJudge.judge``.

    Each ``max_*`` field holds the highest score observed across the
    corresponding answer list; ``judge`` fills in ``0.0`` when that list
    is empty.
    """

    # Final decision: True when the generated text is accepted as correct.
    is_correct: bool
    # Highest sentence-similarity score against the correct answers.
    max_correct_sen_sim: float
    # Highest ROUGE-L F-score against the correct answers.
    max_correct_rouge_sim: float
    # Highest sentence-similarity score against the incorrect answers.
    max_incorrect_sen_sim: float
    # Highest ROUGE-L F-score against the incorrect answers.
    max_incorrect_rouge_sim: float


class DatasetJudge:
    """Decide whether generated text matches a set of reference answers.

    Two similarity signals are combined:

    * a BLEURT model score (called "sentence similarity" here), and
    * the ROUGE-L F-score from the ``rouge`` package.

    For both signals the code treats a larger value as a closer match: it
    takes the maximum over each answer list and compares it against a
    threshold with ``>``.
    """

    def __init__(
        self,
        bleurt_model_path: str,
        sen_sim_threshold: float = 0.5,
        correct_advantage: float = -1,
        rouge_threshold: float = 1.0,
    ):
        """Load the BLEURT model and store the decision thresholds.

        Args:
            bleurt_model_path: Path or identifier handed to the
                ``from_pretrained`` loaders of the BLEURT config, model and
                tokenizer.
            sen_sim_threshold: The best BLEURT score against the correct
                answers must be strictly greater than this to count as
                "similar enough".
            correct_advantage: Required margin (strictly greater than) of
                the best correct BLEURT score over the best incorrect one.
            rouge_threshold: The best ROUGE-L F-score against the correct
                answers must be strictly greater than this to count as
                "similar enough". NOTE(review): ROUGE F-scores should not
                exceed 1.0, so the default of 1.0 together with the strict
                ``>`` appears to disable the ROUGE path - confirm intent.
        """
        self.rouge = Rouge()
        self.rouge_threshold: float = rouge_threshold
        self.sen_sim_threshold: float = sen_sim_threshold
        self.correct_advantage: float = correct_advantage

        self.bleurt_config: BleurtConfig = BleurtConfig.from_pretrained(
            bleurt_model_path
        )
        self.bleurt_model: BleurtForSequenceClassification = (
            BleurtForSequenceClassification.from_pretrained(
                bleurt_model_path, config=self.bleurt_config
            )
        )
        self.bleurt_tokenizer: BleurtTokenizer = BleurtTokenizer.from_pretrained(
            bleurt_model_path
        )

        # The model is only used for scoring, so switch it to eval mode.
        self.bleurt_model.eval()

    def _rouge_l_f(self, generated_text: str, reference: str) -> float:
        """Return the ROUGE-L F-score of ``generated_text`` against ``reference``."""
        scores = self.rouge.get_scores(hyps=generated_text, refs=reference)
        return scores[0]["rouge-l"]["f"]

    def getRougeSimilarity(
        self, generated_text: str, correct_answer_list, incorrect_answer_list
    ) -> Tuple[List[float], List[float]]:
        """Score ``generated_text`` against every reference with ROUGE-L.

        Args:
            generated_text: The text to evaluate (used as the hypothesis).
            correct_answer_list: Reference strings considered correct.
            incorrect_answer_list: Reference strings considered incorrect.

        Returns:
            ``(correct_scores, incorrect_scores)``: two lists of ROUGE-L
            F-scores, one entry per reference, in input order. An empty
            answer list yields an empty score list.

        NOTE(review): the ``rouge`` package is believed to raise
        ``ValueError`` for an empty hypothesis or reference; that is not
        handled here - confirm whether callers can pass empty strings.
        """
        correct_scores = [
            self._rouge_l_f(generated_text, reference)
            for reference in correct_answer_list
        ]
        incorrect_scores = [
            self._rouge_l_f(generated_text, reference)
            for reference in incorrect_answer_list
        ]
        return correct_scores, incorrect_scores

    def getSentenceSimilarity(
        self, generated_text: str, correct_answer_list, incorrect_answer_list
    ) -> Tuple[List[float], List[float]]:
        """Score ``generated_text`` against every reference with BLEURT.

        All references are scored in a single batch and the flat result is
        split back into the correct and incorrect portions.

        Args:
            generated_text: The text to evaluate (used as the candidate).
            correct_answer_list: Reference strings considered correct.
            incorrect_answer_list: Reference strings considered incorrect.

        Returns:
            ``(correct_scores, incorrect_scores)``: two lists of BLEURT
            scores, one entry per reference, in input order. Both lists
            are empty when no references are given.
        """
        references = list(correct_answer_list) + list(incorrect_answer_list)
        if not references:
            # Nothing to compare against: skip the tokenizer/model entirely
            # instead of feeding them an empty batch.
            return [], []

        # One copy of the candidate per reference so the pairs line up.
        candidates = [generated_text] * len(references)
        with torch.no_grad():
            inputs = self.bleurt_tokenizer(
                references, candidates, padding="longest", return_tensors="pt"
            )
            scores = self.bleurt_model(**inputs).logits.flatten().tolist()

        # Correct references were placed first in the batch.
        n_correct = len(correct_answer_list)
        return scores[:n_correct], scores[n_correct:]

    def judge(
        self,
        generated_text: str,
        correct_answer_list: List[str],
        incorrect_answer_list: Optional[List[str]] = None,
    ) -> "JudgeResult":
        """Judge whether ``generated_text`` matches the correct answers.

        The text is accepted when BOTH of the following hold:

        1. the best BLEURT score against the correct answers exceeds the
           best BLEURT score against the incorrect answers by more than
           ``correct_advantage``; and
        2. the best correct BLEURT score exceeds ``sen_sim_threshold`` OR
           the best correct ROUGE-L F-score exceeds ``rouge_threshold``.

        Args:
            generated_text: The text to evaluate.
            correct_answer_list: Reference strings considered correct.
            incorrect_answer_list: Reference strings considered incorrect;
                ``None`` (the default) is treated as an empty list.

        Returns:
            A ``JudgeResult`` carrying the verdict and the four maximum
            scores. A maximum over an empty score list is reported as 0.0.
        """
        if incorrect_answer_list is None:
            # Fresh list per call; avoids a shared mutable default argument.
            incorrect_answer_list = []

        correct_rouge_sim_list, incorrect_rouge_sim_list = self.getRougeSimilarity(
            generated_text, correct_answer_list, incorrect_answer_list
        )
        correct_sen_sim_list, incorrect_sen_sim_list = self.getSentenceSimilarity(
            generated_text, correct_answer_list, incorrect_answer_list
        )

        max_correct_sen_sim = max(correct_sen_sim_list, default=0.0)
        max_correct_rouge_sim = max(correct_rouge_sim_list, default=0.0)
        max_incorrect_sen_sim = max(incorrect_sen_sim_list, default=0.0)
        max_incorrect_rouge_sim = max(incorrect_rouge_sim_list, default=0.0)

        beats_incorrect = (
            max_correct_sen_sim - max_incorrect_sen_sim > self.correct_advantage
        )
        similar_enough = (
            max_correct_sen_sim > self.sen_sim_threshold
            or max_correct_rouge_sim > self.rouge_threshold
        )

        return JudgeResult(
            is_correct=beats_incorrect and similar_enough,
            max_correct_sen_sim=max_correct_sen_sim,
            max_correct_rouge_sim=max_correct_rouge_sim,
            max_incorrect_sen_sim=max_incorrect_sen_sim,
            max_incorrect_rouge_sim=max_incorrect_rouge_sim,
        )
