""" Summarize Answer Into Clusters """

import glob
import os
import ujson as json
from tqdm import tqdm
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
from langchain_openai import ChatOpenAI
from ..data_readers import (
    AnswerPhraseSampleDataReader,
)
from ..data_readers.answer_phrase_sample_data_reader import AnswerPhraseSampleItem
from langchain_core.runnables import (
    RunnableLambda,
    Runnable,
    RunnableBranch,
    RunnablePassthrough,
    RunnableParallel
)
from langchain_interface.steps import (
    DistinctClusterIdentificationStep
)


@BaseTask.register("answer-summarization")
class AnswerSummarizationTask(BaseTask):
    """Summarize sampled answers into distinct clusters.

    For every item read from the ``*.jsonl`` files in ``input_dir``, answers
    that look like abstentions are dropped and the remaining ones are passed
    to ``DistinctClusterIdentificationStep`` chained with the LLM. One JSON
    object per item is then written to ``output.jsonl`` in the output
    directory.
    """

    __VERSION__ = "0.3.5"

    # Substrings that mark an answer as an abstention / refusal. They are
    # lowercased once here so that matching is case-insensitive without
    # rebuilding the list for every answer.
    # NOTE(review): matching is by plain substring, so short entries such as
    # "nil" or "none" also hit longer words containing them -- confirm this
    # is the intended strictness.
    _ABSTAIN_PHRASES = tuple(
        phrase.lower()
        for phrase in (
            "I'm sorry, I cannot fulfill that request.",
            "I'm sorry, I can't fulfill that request.",
            "I'm sorry, but I cannot fulfill that request.",
            "I'm sorry, but I can't fulfill that request.",
            "Sorry, but I can't fulfill that request.",
            "Sorry, I can't do that.",
            "don't know",
            "don't have information",
            "don't have the information",
            "don't have that information",
            "no information",
            "unspecified",
            "non-existent",
            "non-disclosure",
            "not available",
            "unknown",
            "PLACEHOLDER",
            "as a language model",
            "none",
            "null",
            "nil",
            "vague answer",
            "lack of information",
            "lack of sufficient",
            "lack of specific",
        )
    )

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text
    ):
        """Create the task.

        Args:
            input_dir: Directory that is searched for ``*.jsonl`` input files.
            output_dir: Forwarded to ``BaseTask.__init__``.
            cache_path: Path handed to ``SQLiteCache``; the resulting cache is
                installed globally through ``set_llm_cache``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        self._llm = ChatOpenAI(
            temperature=0,
            top_p=1,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )

        # This sets a process-wide LLM cache, not one private to this task.
        set_llm_cache(SQLiteCache(cache_path))

    @classmethod
    def _is_abstention(cls, answer: Text) -> bool:
        """Return True if ``answer`` contains any abstention phrase.

        The comparison is case-insensitive and checks *every* phrase in
        ``_ABSTAIN_PHRASES`` (the previous implementation returned after
        testing only the first one).
        """
        lowered = answer.lower()
        return any(phrase in lowered for phrase in cls._ABSTAIN_PHRASES)

    @overrides
    def _run(self):
        """Run the task.

        Returns:
            A list with one dict per input item, holding the keys
            ``back_ref_id``, ``question``, ``topic``, ``answer_template``,
            ``filtered_answers``, ``clusters`` and ``messages``.
        """

        def _convert_item(item: AnswerPhraseSampleItem) -> Dict[Text, Any]:
            """Turn a reader item into a dict, dropping abstaining answers."""
            return {
                "question": item.question,
                "topic": item.topic,
                "answer_template": item.answer_template,
                "filtered_answers": [
                    answer
                    for answer in item.answers
                    if not self._is_abstention(answer)
                ],
            }

        # glob yields paths in arbitrary order; sorting keeps the
        # back_ref_id assignment reproducible across runs and machines.
        input_paths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        converted_items = map(_convert_item, AnswerPhraseSampleDataReader(input_paths))

        tagged_instances = [
            {"back_ref_id": idx, **item}
            for idx, item in enumerate(converted_items)
        ]

        # Branch taken when at least one answer survived filtering: format the
        # answers as a bullet list, run the clustering step, and keep the
        # fields of its result that are written out.
        cluster_branch = (
            RunnableLambda(
                lambda x: {"str_list": '\n'.join([f"- {answer}" for answer in x["filtered_answers"]])}
            )
            | DistinctClusterIdentificationStep().chain_llm(self._llm)
            | RunnableLambda(
                lambda x: {
                    "clusters": x.clusters,
                    "messages": x.messages
                }
            )
        )

        runnable_chain = RunnableParallel(
            passthrough=RunnablePassthrough(),
            generation=RunnableBranch(
                (lambda x: x["filtered_answers"], cluster_branch),
                # Default branch (no answers left): skip the LLM call. It must
                # expose the same keys as the cluster branch, because the merge
                # step below reads both "clusters" and "messages".
                lambda _: {"clusters": [], "messages": []}
            )
        ) | RunnableLambda(
            lambda x: {
                "back_ref_id": x["passthrough"]["back_ref_id"],
                "question": x["passthrough"]["question"],
                "topic": x["passthrough"]["topic"],
                "answer_template": x["passthrough"]["answer_template"],
                "filtered_answers": x["passthrough"]["filtered_answers"],
                "clusters": x["generation"]["clusters"],
                "messages": x["generation"]["messages"]
            }
        )

        return [runnable_chain.invoke(ti) for ti in tqdm(tagged_instances)]

    @overrides
    def _write(self, outputs):
        """Write the outputs as JSON lines to ``output.jsonl``.

        Args:
            outputs: Iterable of JSON-serializable objects, one per line
                (the list returned by ``_run``).
        """
        # assumes self._output_dir is set by BaseTask.__init__ -- TODO confirm
        output_path = os.path.join(self._output_dir, "output.jsonl")
        # Explicit encoding so the file does not depend on the platform locale.
        with open(output_path, "w", encoding="utf-8") as file_:
            for output in outputs:
                file_.write(json.dumps(output) + "\n")