""" """

# try:
#     import ujson as json
# except ImportError:
#     import json
import json
from dataclasses import dataclass
from overrides import overrides
from typing import Union, Text, List, Dict, Optional, Callable, Any

from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.base import Runnable
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)
from langchain_core.output_parsers import BaseOutputParser
import re

# TODO: use customer downloaded examples for example selector
from langchain_interface.example_selectors import ConstantExampleSelector, ExampleSelector
from langchain_interface.steps import Step, FewShotStep
from langchain_interface.instances.instance import LLMResponse


@dataclass(eq=True, frozen=True)
class NonEntailedDetectionResponse(LLMResponse):
    """Immutable result record for non-entailed detection.

    Adds the list of non-entailed candidates on top of ``LLMResponse``.
    The parser in this module also passes a ``messages`` keyword when
    building instances; that field is presumably declared on
    ``LLMResponse`` (not visible here -- confirm against the base class).
    """

    # Candidates reported as not entailed by the hypernym.
    non_entailed: List[Text]

    
class NonEntailedDetectionOutputParser(BaseOutputParser[NonEntailedDetectionResponse]):
    """Output parser for non-entailed detection.

    Interprets the raw model output as JSON and wraps the extracted list of
    non-entailed candidates in a ``NonEntailedDetectionResponse``.
    """

    def parse(self, text: Text) -> NonEntailedDetectionResponse:
        """Parse the raw model output into a response object.

        Args:
            text: Raw completion text. Expected to be a JSON list, or a JSON
                object whose first value holds the list.

        Returns:
            A response whose ``messages`` is the raw ``text`` (``""`` when
            ``text`` is ``None`` or blank) and whose ``non_entailed`` is the
            parsed list, the first value of the parsed object, or ``[]``
            when the input is blank or the parsed container is empty.

        Raises:
            json.JSONDecodeError: If ``text`` is not valid JSON.
            AssertionError: If the JSON decodes to something other than a
                list or an object (e.g. a bare string, number or ``null``).
        """
        stripped = text.strip() if text is not None else ""
        if not stripped:
            return NonEntailedDetectionResponse(messages="", non_entailed=[])

        structure = json.loads(stripped)

        if not isinstance(structure, (dict, list)):
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; without it a scalar would either
            # fail later with an unrelated AttributeError or be silently
            # treated as "nothing detected". The exception type and message
            # are kept so existing callers are unaffected.
            raise AssertionError(f"Invalid structure: {structure}")

        if not structure:
            # Empty list / empty object: nothing was flagged.
            return NonEntailedDetectionResponse(messages=text, non_entailed=[])

        if isinstance(structure, list):
            non_entailed = structure
        else:
            # Object form: the payload is taken from the first value,
            # whatever its key is called.
            non_entailed = next(iter(structure.values()))

        return NonEntailedDetectionResponse(messages=text, non_entailed=non_entailed)

    @property
    def _type(self) -> Text:
        """Identifier string for this parser type."""
        return "non-entailed-detection"
    
    
class NonEntailedDetectionStep(FewShotStep):
    """Few-shot step asking which candidates are not entailed by a hypernym."""

    def __init__(self, example_selector: ExampleSelector = None):
        """Set up the step, seeding default examples when none are supplied.

        Args:
            example_selector: Selector providing the few-shot examples. When
                ``None``, a ``ConstantExampleSelector`` is created and filled
                with two built-in examples (``"1970s"`` and ``"musician"``).
        """
        if example_selector is None:
            example_selector = ConstantExampleSelector()

            # (hypernym, candidates, non-entailed subset of the candidates)
            default_examples = (
                (
                    "1970s",
                    [
                        "1976",
                        "1975",
                        "1977",
                        "1978",
                        "1974",
                        "1979",
                        "1973",
                        "1969",
                        "1966",
                        "1984",
                        "1986",
                    ],
                    ["1969", "1966", "1984", "1986"],
                ),
                (
                    "musician",
                    [
                        "is a blues musician",
                        "is a professional blues musician",
                        "plays the blues",
                        "plays the blues guitar",
                        "is a professional musician",
                        "is a musician",
                        "plays music",
                        "is a professional football player",
                        "is a professional poker player",
                    ],
                    [
                        "is a professional football player",
                        "is a professional poker player",
                    ],
                ),
            )

            # Candidate lists are stored JSON-encoded, matching the format
            # the prompt shows to the model.
            for hypernym, candidates, rejected in default_examples:
                example_selector.add_example(
                    {
                        "hypernym": hypernym,
                        "candidates": json.dumps(candidates),
                        "non_entailed": json.dumps(rejected),
                    }
                )

        super().__init__(example_selector=example_selector)

    @overrides
    def get_prompt_template(self) -> Runnable:
        """Build the chat prompt: instruction, few-shot examples, then query.

        Returns:
            A ``ChatPromptTemplate`` with ``hypernym`` and ``candidates``
            placeholders in the final human message; the few-shot section
            draws its examples from ``self._example_selector``.
        """
        # The same human turn is used inside each example and as the final query.
        query_message = ("human", "Hypernym: {hypernym}\n\nCandidates:\n```\n{candidates}\n```")

        few_shot_section = FewShotChatMessagePromptTemplate(
            example_prompt=ChatPromptTemplate.from_messages(
                [query_message, ("ai", "{non_entailed}")]
            ),
            example_selector=self._example_selector,
        )

        instruction_message = (
            "human",
            "Given a hypernym and a list of candidates, identify the candidates that are not entailed by the hypernym. Response with a json list of non-entailed candidates.",
        )

        return ChatPromptTemplate.from_messages(
            [instruction_message, few_shot_section, query_message]
        )

    @overrides
    def get_output_parser(self) -> BaseOutputParser:
        """Return a new ``NonEntailedDetectionOutputParser``."""
        return NonEntailedDetectionOutputParser()