""" Rewrite the answer to be less specific until the model believes it is true. """

from dataclasses import dataclass
from overrides import overrides
from typing import Union, Text, List, Dict, Optional, Callable, Any

from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.base import Runnable
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)
from langchain_core.output_parsers import BaseOutputParser
import re

# TODO: use customer downloaded examples for example selector
from langchain_interface.example_selectors import ConstantExampleSelector, ExampleSelector
from langchain_interface.steps import Step, FewShotStep
from langchain_interface.instances.instance import LLMResponse


@dataclass(frozen=True, eq=True)
class DirectAnswerBackoffRepsonse(LLMResponse):
    """Parsed result of the answer back-off step.

    Adds ``answer_backoff`` to whatever fields are inherited from
    ``LLMResponse`` (defined elsewhere; in this module ``messages`` is
    populated with the raw model reply).

    NOTE(review): the class name misspells "Response"; it is kept as-is
    because other modules may import it under this name.
    """

    # The rewritten, less specific answer returned by the model.
    answer_backoff: Text

    
class DirectAnswerBackoffOutputParser(BaseOutputParser[DirectAnswerBackoffRepsonse]):
    """Turn the model's back-off reply into a ``DirectAnswerBackoffRepsonse``.

    No structure is extracted from the reply: the whole text is treated as
    the rewritten (less specific) answer.
    """

    def parse(self, text: Text) -> DirectAnswerBackoffRepsonse:
        """Package the raw reply ``text``.

        The untouched reply is stored in ``messages``; ``answer_backoff`` is
        the same text with leading and trailing whitespace removed.
        """
        cleaned = text.strip()
        return DirectAnswerBackoffRepsonse(answer_backoff=cleaned, messages=text)

    @property
    def _type(self) -> Text:
        """Type tag reported through ``BaseOutputParser``'s ``_type`` hook."""
        parser_type = "direct-answer-backoff"
        return parser_type
    
    
class DirectAnswerBackoffStep(Step):
    """Step that asks the model to back off an answer to a less specific claim.

    The prompt is rendered from three template variables (``question``,
    ``answer`` and ``percentage``), and the model's reply is parsed by
    ``DirectAnswerBackoffOutputParser``. How the template and parser are
    chained together is decided by ``Step``, which is defined elsewhere.
    """

    @overrides
    def get_prompt_template(self) -> Runnable:
        """Build the single-turn chat prompt for the back-off request.

        Template variables:
            question: the question that was originally asked.
            answer: the initial answer to ``question`` that should be
                rewritten to be less specific.
            percentage: the confidence threshold, inserted verbatim before
                the word "percent" (e.g. ``90`` renders as "90 percent").
        """

        return ChatPromptTemplate.from_messages(
            [
                (
                    "human", (
                        "Given the question: \"{question}\", "
                        # A period is required after the quoted answer;
                        # without it the answer runs straight into
                        # "Please ..." in the rendered prompt.
                        "your initial answer is: \"{answer}\". "
                        "Please rewrite the claim to be less specific until "
                        "you are at least {percentage} percent confident it is true.\n\n"
                        "Your response should only contain the rewritten claim itself, without any additional context."
                    )
                ),
            ]
        )

    @overrides
    def get_output_parser(self) -> Runnable:
        """Return the parser that wraps the raw reply in a ``DirectAnswerBackoffRepsonse``."""
        return DirectAnswerBackoffOutputParser()