import re
from typing import Dict, List

from loguru import logger

from src.generator import BaseGenerator
from src.schema import Answer


# Yes/No judging prompt sent to the generator once per rubric question.
# RubricEvaluator._preprocess fills the `{rubric_question}` and `{raw_answer}`
# placeholders with `str.format`, so any literal brace added to this text must be
# doubled (`{{` / `}}`). The surrounding newlines are removed by `.strip()`.
# NOTE(review): the lines under "### Rubric Question" and "### Raw Answer" carry
# trailing spaces that are sent to the model verbatim; this looks unintentional,
# but removing them changes the prompt, so confirm before editing.
RUBRIC_EVALUATOR_PROMPT_TEMPLATE = """
You are a rubric evaluator.

Given a rubric question and a raw answer, determine if the raw answer satisfies the rubric question.

If it satisfies the rubric question, respond with "Yes". Otherwise, respond with "No".

Never respond with anything other than "Yes" or "No".

### Rubric Question
{rubric_question} 

### Raw Answer  
{raw_answer}  

### Answer
""".strip()


class RubricEvaluator:
    """Score a raw answer against rubric questions with a Yes/No LLM judge.

    For every distinct rubric question found in a list of ``Answer`` objects,
    the evaluator builds a prompt from ``prompt_template``, sends it to the
    wrapped generator, and converts the reply into a binary score
    (1 for "Yes", 0 otherwise).
    """

    prompt_template: str = RUBRIC_EVALUATOR_PROMPT_TEMPLATE
    # Finds the first standalone "yes" or "no" word in a lowercased reply. The
    # word boundaries stop words such as "eyes" or "yesterday" from counting as
    # a verdict.
    _verdict_pattern = re.compile(r"\b(yes|no)\b")

    def __init__(self, generator: BaseGenerator) -> None:
        """Store the generator whose ``generate(prompt=...)`` produces verdicts."""
        self.generator = generator

    def _preprocess(self, raw_answer: str, rubric_question: str) -> str:
        """Render the judging prompt for one rubric question.

        Args:
            raw_answer: The answer text being evaluated.
            rubric_question: The rubric question the answer is checked against.

        Returns:
            ``prompt_template`` with both placeholders filled in.
        """
        return self.prompt_template.format(
            rubric_question=rubric_question,
            raw_answer=raw_answer,
        )

    def _postprocess(self, response: str) -> int:
        """Convert the generator's reply into a binary score.

        The verdict is the first standalone "yes" or "no" word in the reply,
        matched case-insensitively. For example, "Yes." scores 1, and
        "No, it does not say yes" scores 0. A reply containing neither word
        scores 0.

        Args:
            response: Raw text returned by the generator.

        Returns:
            1 if the first verdict word is "yes", otherwise 0.
        """
        match = self._verdict_pattern.search(response.lower())
        return 1 if match is not None and match.group(1) == "yes" else 0

    def _get_unique_rubric_questions(self, answers: List[Answer]) -> List[str]:
        """Return the distinct rubric questions in first-seen order.

        ``dict.fromkeys`` deduplicates while keeping insertion order. A ``set``
        would instead iterate in an order that can change between interpreter
        runs, because string hashing is randomized.
        """
        return list(dict.fromkeys(answer.rubric_question for answer in answers))

    def eval(self, raw_answer: str, answers: List[Answer]) -> Dict[str, int]:
        """Score ``raw_answer`` against each distinct rubric question in ``answers``.

        The generator is called once per distinct question, so answers that
        share a rubric question share one score.

        Args:
            raw_answer: The answer text being evaluated.
            answers: Objects whose ``rubric_question`` attributes supply the
                questions to check.

        Returns:
            A dict mapping each distinct rubric question to 1 (satisfied) or
            0 (not satisfied), in first-seen question order. An empty
            ``answers`` list yields an empty dict.
        """
        rubric_question_answer_dict: Dict[str, int] = {}
        for rubric_question in self._get_unique_rubric_questions(answers):
            prompt = self._preprocess(raw_answer=raw_answer, rubric_question=rubric_question)
            logger.debug(f"[RubricEvaluator] prompt: {prompt}")
            response = self.generator.generate(prompt=prompt)
            logger.debug(f"[RubricEvaluator] response: {response}")
            binary_result = self._postprocess(response=response)
            logger.debug(f"[RubricEvaluator] binary_result: {binary_result}")
            rubric_question_answer_dict[rubric_question] = binary_result
        return rubric_question_answer_dict
