""" """

from dataclasses import dataclass
from overrides import overrides
from typing import Union, Text, List, Dict, Optional, Callable, Any

from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.base import Runnable
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)
from langchain_core.output_parsers import BaseOutputParser
import re

# TODO: use custom (e.g. downloaded) examples for the example selector
from langchain_interface.example_selectors import ConstantExampleSelector, ExampleSelector
from langchain_interface.steps import Step, FewShotStep
from langchain_interface.instances.instance import LLMResponse


@dataclass(frozen=True, eq=True)
class SimpleQAResponse(LLMResponse):
    """Immutable result of a simple question-answering prompt.

    Adds a single ``phrase`` field on top of ``LLMResponse``. The ``messages``
    field passed by ``SimpleQAOutputParser`` is not declared here, so it is
    presumably inherited from ``LLMResponse`` — NOTE(review): confirm the
    base-class fields and that the base is also a frozen dataclass.
    """
    # Short answer phrase; SimpleQAOutputParser fills it with the
    # whitespace-stripped completion text.
    phrase: Text


class SimpleQAOutputParser(BaseOutputParser[SimpleQAResponse]):
    """Turn a raw LLM completion into a :class:`SimpleQAResponse`.

    No answer extraction is attempted: the whole completion, minus leading
    and trailing whitespace, is taken as the answer phrase, and the untouched
    completion is kept alongside it.
    """

    @overrides
    def parse(self, text: Text) -> SimpleQAResponse:
        """Wrap ``text`` without further interpretation.

        Args:
            text: Raw completion produced by the model.

        Returns:
            A response whose ``phrase`` is ``text`` with surrounding
            whitespace removed and whose ``messages`` is ``text`` exactly as
            received.
        """
        raw_completion = text
        answer_phrase = raw_completion.strip()
        return SimpleQAResponse(messages=raw_completion, phrase=answer_phrase)


@Step.register("simplqa-sampling")
class SimpleQASamplingStep(FewShotStep):
    """ """
    
    def __init__(
        self,
        example_selector: Optional[ExampleSelector] = None,
    ):
        if example_selector is None:
            example_selector = ConstantExampleSelector()
            examples = [
                {
                    "question": "What is the shape of the earth?",
                    "answer": "round",
                },
                {
                    "question": "What is Bridget Moynahan's profession?",
                    "answer": "actress",
                },
                {
                    "question": "What is the capital of France?",
                    "answer": "Paris",
                }
            ]
            
            for example in examples:
                example_selector.add_example(
                    example
                )
                
        super().__init__(example_selector=example_selector)

    @overrides
    def get_prompt_template(self) -> Runnable:

        system_prompt = "Please answer the given question with simple minimal phrases."
        
        example_prompt = ChatPromptTemplate.from_messages(
            messages=[
                ("human", "{question}"),
                ("ai", "{answer}"),
            ]
        )

        few_shot_prompt_template = FewShotChatMessagePromptTemplate(
            example_prompt=example_prompt,
            example_selector=self._example_selector,
        )

        prompt_template = ChatPromptTemplate.from_messages(
            messages=[
                ("system", system_prompt),
                few_shot_prompt_template,
                ("human", "{question}"),
            ]
        )

        return prompt_template

    @overrides
    def get_output_parser(self) -> Runnable:
        
        return SimpleQAOutputParser()