from typing import Set

from loguru import logger

from src.generator import BaseGenerator


# Prompt sent to the generator by the answer-separation service.
# Placeholders (filled via ``str.format``): ``possible_answers`` -- candidate
# answers, one per line; ``raw_answer`` -- the free-form answer to interpret.
# Because ``str.format`` is used, any literal braces added to this template
# must be doubled (``{{`` / ``}}``).
ANSWER_SEPARATION_PROMPT_TEMPLATE = """
**Role Description**

You are an answer selection model designed to extract a final answer from a raw input based on a given list of possible answers.

**Detailed Instructions**

1. You are provided with a raw answer and a list of possible answers.
2. Your task is to extract a final answer from the raw answer that exactly matches one of the possible answers.
3. Only choose an answer that is present in both the raw answer and the list of possible answers.
4. You must select exactly one matching answer from the raw answer.
5. Output only the matching answer without any additional text or formatting.
6. If the possible answers are like '[A] School' and the raw answer is 'A', then the final answer should be 'School'.

**Example 1**

### Possible Answers
Apple
Banana
Cherry

### Raw Answer
"The answer is apple."

### Final Answer
Apple

**Example 2**

### Possible Answers
Dog
Cat
Rabbit

### Raw Answer
"I have a pet dog."

### Final Answer
Dog

---

**Test Input**

### Possible Answers
{possible_answers}

### Raw Answer
{raw_answer}

### Final Answer
""".strip()


class AnswerSeparationService:
    """Ask a generator to reduce a free-form answer to one of the given candidates.

    The service formats ``prompt_template`` with the candidate answers and the
    raw answer, sends the prompt to the injected generator, and returns the
    stripped response wrapped in a set.
    """

    # Template formatted with ``possible_answers`` (newline-separated) and ``raw_answer``.
    prompt_template: str = ANSWER_SEPARATION_PROMPT_TEMPLATE

    def __init__(self, generator: BaseGenerator) -> None:
        """Store the generator used to answer the separation prompt.

        Args:
            generator: Object whose ``generate(prompt=...)`` method is called with
                the formatted prompt and is expected to return the model's text.
        """
        self.generator = generator

    def _preprocess(self, raw_answer: str, possible_answers: Set[str]) -> str:
        """Build the separation prompt.

        Candidates are stripped, de-duplicated, blank entries are dropped, and the
        result is sorted. Sorting matters: iteration order of a ``set`` of strings
        can differ between interpreter processes (string hash randomization), so
        without it the same inputs could yield different prompts across runs.

        Args:
            raw_answer: Free-form answer text inserted verbatim into the prompt.
            possible_answers: Candidate answers, rendered one per line.

        Returns:
            The fully formatted prompt string.
        """
        candidates = sorted({answer.strip() for answer in possible_answers} - {""})
        return self.prompt_template.format(
            possible_answers="\n".join(candidates),
            raw_answer=raw_answer,
        )

    def _postprocess(self, response: str) -> Set[str]:
        """Turn the generator's text response into a set of predicted answers.

        Returns:
            ``{stripped_response}``, or an empty set when the response is empty or
            whitespace-only (so ``{""}`` is never reported as a prediction).
        """
        answer = response.strip()
        return {answer} if answer else set()

    def separate(self, raw_answer: str, possible_answers: Set[str]) -> Set[str]:
        """Map ``raw_answer`` onto ``possible_answers`` using the generator.

        Args:
            raw_answer: Free-form answer text to interpret.
            possible_answers: Candidate answers the model should choose from.

        Returns:
            The model's chosen answer as a one-element set, or an empty set if the
            model returned nothing. The value is NOT validated against
            ``possible_answers``; the model may return text outside that set.
        """
        prompt = self._preprocess(raw_answer=raw_answer, possible_answers=possible_answers)
        # Positional ``{}`` args let loguru skip formatting when DEBUG is disabled.
        logger.debug("[AnswerSeparationService] prompt: {}", prompt)

        response = self.generator.generate(prompt=prompt)
        logger.debug("[AnswerSeparationService] response: {}", response)

        pred_elements = self._postprocess(response=response)
        logger.debug("[AnswerSeparationService] pred_elements: {}", pred_elements)

        return pred_elements
