""" Given an answer phrase and a template, generate the answer with the template. """

from dataclasses import dataclass
from overrides import overrides
from typing import Union, Text, List, Dict, Optional, Callable, Any

from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.base import Runnable
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)
from langchain_core.output_parsers import BaseOutputParser
import re

# TODO: use custom (downloaded) examples for the example selector
from langchain_interface.example_selectors import ConstantExampleSelector, ExampleSelector
from langchain_interface.steps import Step, FewShotStep
from langchain_interface.instances.instance import LLMResponse


@dataclass(frozen=True, eq=True)
class DeclAnswerResponse(LLMResponse):
    """Parsed result of the answer-declarativization step.

    Extends ``LLMResponse`` with the declarative sentence extracted from the
    model output.
    """
    # Stripped contents of the first fenced code block in the model output.
    # The output parser in this module sets this to ``None`` when the output
    # contains no fenced block, hence ``Optional``.
    declarativized_answer: Optional[Text]
    
    
class DeclAnswerOutputParser(BaseOutputParser[DeclAnswerResponse]):
    """Output parser for the answer-declarativization step.

    The prompt asks the model to wrap its declarative sentence in a
    triple-backtick code block; this parser pulls out the contents of the
    first such block.
    """

    def parse(self, text: Text) -> DeclAnswerResponse:
        """Extract the declarativized answer from a raw completion.

        Args:
            text: Raw completion text returned by the model.

        Returns:
            A ``DeclAnswerResponse`` whose ``messages`` is the unmodified
            ``text`` and whose ``declarativized_answer`` is the stripped
            contents of the first fenced block, or ``None`` if there is none.
        """
        # DOTALL lets the fenced block span multiple lines; the non-greedy
        # group stops at the first closing fence.
        fenced = re.search(r"```(.*?)```", text.strip(), re.DOTALL)
        answer = fenced.group(1).strip() if fenced is not None else None
        return DeclAnswerResponse(messages=text, declarativized_answer=answer)

    @property
    def _type(self) -> Text:
        """Identifier for this parser type."""
        return "decl-answer"
    
    
class DeclAnswerStep(FewShotStep):
    """Few-shot step that restates a short answer phrase as a declarative sentence.

    Given a question and an answer phrase, the prompt asks the model to produce
    a declarative sentence consistent with both. It also asks the model to wrap
    that sentence in a triple-backtick code block, so that
    ``DeclAnswerOutputParser`` can extract it.
    """

    def __init__(self, example_selector: Optional[ExampleSelector] = None):
        """Create the step.

        Args:
            example_selector: Selector that supplies few-shot examples. Each
                example is rendered with the keys ``question``, ``answer`` and
                ``declarativized_answer``. When ``None``, a
                ``ConstantExampleSelector`` seeded with two built-in examples
                is used.
        """
        if example_selector is None:
            example_selector = ConstantExampleSelector()
            examples = [
                {
                    "question": "In what position did Roxana Díaz finish in the high jump event at the 2004 Olympic Games in Athens?",
                    "answer": "Single-digit ordinal numbers",
                    "declarativized_answer": "Roxana Díaz finished in the top 10 in the high jump event at the 2004 Olympic Games in Athens."
                },
                {
                    "question": "What is the name of the daughter of Chris Cheney and Emily Taheny?",
                    "answer": "Female first names",
                    "declarativized_answer": "Chris Cheney and Emily Taheny have a daughter."
                },
            ]

            for example in examples:
                example_selector.add_example(example)

        super().__init__(example_selector=example_selector)

    @overrides
    def get_prompt_template(self) -> Runnable:
        """Build the chat prompt for answer declarativization.

        The final query message takes the input variables ``template`` (the
        question text shown after ``**Question**:``) and ``answer`` (the answer
        phrase to declarativize).

        Returns:
            A ``ChatPromptTemplate`` made of:
            - a system message;
            - the instruction;
            - an acknowledgement;
            - the few-shot examples;
            - the final query.
        """

        system_prompt = (
            "You are a helpful AI assistant that helps user reformulate their response. "
            "You need to be consistent with the answer phrase and the question, and don't change the meaning of the answer regardless of factuality."
        )

        instruction_prompt = (
            "Declarativize the answer phrase to the question. Wrap the answer phrase in a code block."
        )

        # Each few-shot example is rendered with its own dict keys. The
        # built-in examples store the question under "question" (they have no
        # "template" key), so this template must reference {question}.
        # Referencing {template} would leave a required variable unfilled.
        example_prompt = ChatPromptTemplate.from_messages([
            ("human", "**Question**: {question}\n**Answer**: {answer}"),
            ("ai", "```{declarativized_answer}```")
        ])

        # NOTE(review): self._example_selector is presumably stored by
        # FewShotStep.__init__ — verify against the base class.
        few_shot_prompt_template = FewShotChatMessagePromptTemplate(
            example_prompt=example_prompt,
            example_selector=self._example_selector,
        )

        return ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", instruction_prompt),
            ("ai", "Sure, I can help with that. Please provide the answer phrase and the template."),
            few_shot_prompt_template,
            # The query variable keeps its existing name "template" so the
            # caller-facing input keys are unchanged.
            ("human", "**Question**: {template}\n**Answer**: {answer}"),
        ])

    @overrides
    def get_output_parser(self) -> BaseOutputParser:
        """Return the parser that extracts the fenced declarative sentence."""
        return DeclAnswerOutputParser()