from lm_understanding import answer_questions
from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.prompting import (Completer, FreeResponsePrompter,
                                        FreeResponseQuestion, Prompter)
from lm_understanding.question_template import TemplateModelBehavior


def why_prompt(question: str, answer: float) -> str:
    """Build the prompt asking the model to justify a previously given answer.

    Args:
        question: Text of the question that was originally asked.
        answer: The confidence that was given for that question; it is shown
            in the prompt rounded to four decimal places.

    Returns:
        A four-line prompt: an instruction line describing the 0.00 (no) to
        1.00 (yes) confidence format, the question, the rounded answer, and
        a request for step-by-step reasoning.
    """
    shown_answer = round(answer, 4)
    prompt_lines = (
        # NOTE: the trailing space after "yes." is part of the emitted prompt text.
        "Explain your reasoning for your answer to the following question. Your answer is formatted as a level of confidence, where 0.00 is no and 1.00 is yes. ",
        f"Question: {question}",
        f"Your Answer: {shown_answer}",
        "Why did you give this answer? Explain your reasoning step-by-step.",
    )
    return "\n".join(prompt_lines)


def create_rationalizations(model_behavior: TemplateModelBehavior, completer: Completer, prompter: Prompter) -> LocalExplanationSet:
    """Collect a free-response explanation for each train-split question/answer pair.

    For every question and answer that ``model_behavior`` reports for the
    'train' split, a ``why_prompt`` is built and wrapped in a
    ``FreeResponseQuestion``. The prompts are answered via
    ``answer_questions`` (with a progress bar), and the given answers are
    used as the explanations.

    Args:
        model_behavior: Source of the template id and the train-split
            questions and answers. Its ``template_id`` must not be None.
        completer: Passed through to ``answer_questions``.
        prompter: Passed through to ``answer_questions``.

    Returns:
        A ``LocalExplanationSet`` built from the template id, the train
        questions, the train answers converted with ``.tolist()``, and the
        collected explanations.

    Raises:
        AssertionError: If ``model_behavior.template_id`` is None.
    """
    assert model_behavior.template_id is not None

    split = 'train'
    questions = model_behavior.questions(split)
    # presumably an array-like (it exposes ``.tolist()``) -- confirm against TemplateModelBehavior
    answers = model_behavior.answers(split)

    # One "why" prompt per (question, answer) pair, in the same order.
    prompts = []
    for question, answer in zip(questions, answers):
        prompts.append(FreeResponseQuestion(why_prompt(question, answer)))

    rationales = answer_questions(prompts, prompter, completer, progress_bar=True).given_answers  # type: ignore
    return LocalExplanationSet(
        model_behavior.template_id,
        questions,
        answers.tolist(),
        rationales,
    )


