""" """

from typing import (
    Annotated, List, Union, TypeVar,
    Union, Text, List, Dict, Optional, Callable,
    Any, Literal, Tuple
)
from typing_extensions import TypedDict
import abc
from overrides import overrides
from dataclasses import dataclass
from langgraph.types import Send
from langchain_core.runnables.base import Runnable
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from langgraph.graph import StateGraph, START, END
from langchain_core.language_models.base import BaseLanguageModel
from langchain_interface.interfaces.interface import Interface
from langchain_interface.states.base_states import append, revise, keyupdate

from ..langchain_step.vague_answer_step import (
    VagueAnswerStep,
)
from ..langchain_step.coverage_check_step import (
    CoverageCheckStep
)


class VagueAnswerState(TypedDict):
    """Graph state shared by the nodes built in ``VagueAnswerInterface.get_runnable``.

    Fields annotated with a reducer (``keyupdate`` / ``append``) have their
    node updates combined by LangGraph through that reducer rather than being
    overwritten.
    """
    # (candidate, negative) -> is_covered flag from CoverageCheckStep; written by
    # the parallel "coverage_check" branches. Presumably ``keyupdate`` merges the
    # per-branch single-key dicts — verify against langchain_interface.
    cover_relations: Annotated[Dict[Tuple[Text, Text], bool], keyupdate]
    # the question passed on to VagueAnswerStep
    question: Text
    # answers the respondent considers possible
    candidates: List[Text]
    # answers to rule out in the belief sentence (unless a candidate covers them)
    negatives: List[Text]
    # belief sentences produced by "belief_generation"; the last one is used downstream
    belief: Annotated[List[Text], append]
    general_answer: Annotated[List[Text], append]  # use list to store all the answers


class VagueAnswerInterface(Interface):
    """Interface that turns candidate / negative answers into a vague answer.

    The runnable returned by :meth:`get_runnable` runs a LangGraph pipeline:

    1. ``coverage_check`` -- one parallel branch per (candidate, negative)
       pair, skipped entirely when there are no negatives;
    2. ``belief_generation`` -- composes a belief sentence listing the
       candidates and every negative not covered by any candidate;
    3. ``vague_answer`` -- runs ``VagueAnswerStep`` on the question and the
       latest belief sentence.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _join_phrases(items: List[Text], last_connector: Text) -> Text:
        """Join ``items`` as ``"a, b <last_connector> c"``; a single item is returned as-is."""
        if len(items) == 1:
            return items[0]
        return ", ".join(items[:-1]) + " " + last_connector + " " + items[-1]

    @classmethod
    def _compose_belief(
        cls,
        candidates: List[Text],
        negatives: List[Text],
        cover_relations: Dict[Tuple[Text, Text], bool],
    ) -> Text:
        """Build the belief sentence fed to ``VagueAnswerStep``.

        A negative is mentioned only if no candidate covers it according to
        ``cover_relations``. Duplicate negatives are mentioned once, in
        first-occurrence order.

        Args:
            candidates: Answers the respondent considers possible.
            negatives: Answers to rule out.
            cover_relations: Maps ``(candidate, negative)`` to whether the
                candidate covers the negative.

        Returns:
            A sentence such as
            ``"The respondent believes that the answer is either A or B, but not C."``

        Raises:
            IndexError: If ``candidates`` is empty (unchanged legacy behavior).
        """
        covered = {tgt for (_, tgt), is_covered in cover_relations.items() if is_covered}
        uncovered = [tgt for tgt in dict.fromkeys(negatives) if tgt not in covered]

        if len(candidates) > 1:
            belief = "The respondent believes that the answer is either " + cls._join_phrases(candidates, "or")
            # Trailing space matters: the connector is followed directly by "not ".
            connector = ", but "
        else:
            # candidates[0] raises IndexError for an empty list, as before.
            belief = "The respondent believes that the answer is " + candidates[0]
            connector = " and "

        if not uncovered:
            return belief + "."
        return belief + connector + "not " + cls._join_phrases(uncovered, "nor") + "."

    @overrides
    def get_runnable(self, llm: BaseLanguageModel) -> Runnable:
        """Build the vague-answer runnable for ``llm``.

        Args:
            llm: Language model used by both ``CoverageCheckStep`` and
                ``VagueAnswerStep``.

        Returns:
            A runnable that takes a mapping with the non-reducer keys of
            ``VagueAnswerState`` (``question``, ``candidates``, ``negatives``),
            seeds the reducer fields with empty containers and invokes the
            compiled graph, whose final state includes ``general_answer``.
        """
        graph_builder = StateGraph(VagueAnswerState)

        # Fan out one coverage check per (candidate, negative) pair. With no
        # negatives there is nothing to rule out, so go straight to belief
        # generation.
        def _check_all_negatives(state: VagueAnswerState) -> Union[list, Literal["belief_generation"]]:
            if not state["negatives"]:
                return "belief_generation"
            return [
                Send("coverage_check", {"src": cand, "tgt": neg})
                for neg in state["negatives"]
                for cand in state["candidates"]
            ]

        # Each branch receives {"src": candidate, "tgt": negative}; keep that
        # payload next to the LLM output so the verdict can be keyed by pair.
        _coverage_check = RunnableParallel(
            {
                "passthrough": RunnablePassthrough(),
                "generation": CoverageCheckStep().chain_llm(llm),
            }
        ) | RunnableLambda(
            lambda output: {
                "cover_relations": {
                    (output["passthrough"]["src"], output["passthrough"]["tgt"]): output["generation"].is_covered
                },
            }
        )

        def _call_belief_generation(state: VagueAnswerState) -> Dict[Text, Any]:
            """Turn the collected coverage results into one belief sentence."""
            return {
                "belief": self._compose_belief(
                    state["candidates"],
                    state["negatives"],
                    state["cover_relations"],
                )
            }

        # The most recent belief sentence drives the vague answer step.
        _call_vague_answer_step = VagueAnswerStep().induce_stated_callable(
            llm=llm,
            parse_input=lambda state: {
                "question": state["question"],
                "belief": state["belief"][-1],
            },
            parse_output=lambda output: {
                "general_answer": output.general_answer,
            },
        )

        graph_builder.add_node("coverage_check", _coverage_check)
        graph_builder.add_node("belief_generation", RunnableLambda(_call_belief_generation))
        graph_builder.add_node("vague_answer", _call_vague_answer_step)

        graph_builder.add_conditional_edges(START, _check_all_negatives)
        graph_builder.add_edge("coverage_check", "belief_generation")
        graph_builder.add_edge("belief_generation", "vague_answer")
        graph_builder.add_edge("vague_answer", END)

        graph = graph_builder.compile()

        # Seed the reducer-managed fields so every reducer starts from an empty container.
        return RunnableLambda(
            lambda inputs: VagueAnswerState(
                **inputs,
                cover_relations={},
                belief=[],
                general_answer=[],
            )
        ) | graph