""" Summarize SimpleQA Answer Into Clusters """

import glob
import os
import ujson as json
from tqdm import tqdm
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
from langchain_openai import ChatOpenAI
from ..data_readers import (
    SimpleQAPhraseSamplingDataReader,
    SimpleQAPhraseSamplingInstance
)
from langchain_core.runnables import (
    RunnableLambda,
    Runnable,
    RunnableBranch,
    RunnablePassthrough,
    RunnableParallel
)
from ..langchain_step.distinct_cluster_identification_step import DistinctClusterIdentificationStep


@BaseTask.register("simpleqa-answer-summarization")
class SimpleQAAnswerSummarizationTask(BaseTask):
    """Summarize sampled SimpleQA answers into clusters.

    For every instance read from the ``*.jsonl`` files in ``input_dir``, the
    task drops sampled answers that look like abstentions / refusals and sends
    the remaining answers through ``DistinctClusterIdentificationStep``. The
    results are written to ``output.jsonl`` inside the output directory.
    """

    __VERSION__ = "0.1.1"

    # Substrings that mark an answer as an abstention. They are lowercased
    # once here so that matching against a lowercased answer is
    # case-insensitive.
    # NOTE(review): matching is by substring, so short phrases such as "nil"
    # or "none" also match inside longer words — confirm this is acceptable.
    _ABSTAIN_PHRASES = tuple(
        phrase.lower()
        for phrase in (
            "I'm sorry, I cannot fulfill that request.",
            "I'm sorry, I can't fulfill that request.",
            "I'm sorry, but I cannot fulfill that request.",
            "I'm sorry, but I can't fulfill that request.",
            "Sorry, but I can't fulfill that request.",
            "Sorry, I can't do that.",
            "don't know",
            "don't have information",
            "don't have the information",
            "don't have that information",
            "no information",
            "unspecified",
            "non-existent",
            "non-disclosure",
            "not available",
            "unknown",
            "PLACEHOLDER",
            "as a language model",
            "none",
            "null",
            "nil",
            "vague answer",
            "lack of information",
            "lack of sufficient",
            "lack of specific",
        )
    )

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text
    ):
        """Set up the task.

        Args:
            input_dir: Directory searched for ``*.jsonl`` input files.
            output_dir: Directory passed to ``BaseTask``; ``output.jsonl`` is
                written there.
            cache_path: Path handed to ``SQLiteCache`` for the LLM cache.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        self._llm = ChatOpenAI(
            temperature=0,
            top_p=1,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )

        # The cache is installed through langchain's global setter, so it is
        # not private to this task instance.
        set_llm_cache(SQLiteCache(cache_path))

    @staticmethod
    def _is_abstention(answer: Text) -> bool:
        """Return True when ``answer`` contains any known abstention phrase."""
        lowered = answer.lower()
        # Every phrase must be checked; returning after the first miss would
        # leave all but the first phrase unused.
        return any(
            phrase in lowered
            for phrase in SimpleQAAnswerSummarizationTask._ABSTAIN_PHRASES
        )

    @overrides
    def _run(self):
        """Run the task.

        Returns:
            One dict per input instance with the keys ``back_ref_id``,
            ``question``, ``answer_type``, ``gold_answer``,
            ``filtered_answers``, ``clusters`` and ``messages``.
        """

        def _convert_item(item: SimpleQAPhraseSamplingInstance) -> Dict[Text, Any]:
            """Turn a reader instance into a dict, dropping abstaining answers."""
            filtered_answers = [
                answer
                for answer in item.sampled_answers
                if not self._is_abstention(answer)
            ]
            return {
                "question": item.question,
                "answer_type": item.answer_type,
                "gold_answer": item.gold_answer,
                "filtered_answers": filtered_answers,
            }

        iterator = map(
            _convert_item,
            SimpleQAPhraseSamplingDataReader(
                glob.glob(os.path.join(self._input_dir, "*.jsonl"))
            )
        )
        # back_ref_id is the position of the instance in reader order.
        tagged_instances = [{"back_ref_id": idx, **item} for idx, item in enumerate(iterator)]

        # Only call the LLM when at least one answer survived filtering.
        cluster_chain = RunnableLambda(
            lambda x: {"str_list": '\n'.join([f"- {answer}" for answer in x["filtered_answers"]])}
        ) | DistinctClusterIdentificationStep().chain_llm(self._llm) | RunnableLambda(
            lambda x: {
                "clusters": x.clusters,
                "messages": x.messages
            }
        )

        runnable_chain = RunnableParallel(
            passthrough=RunnablePassthrough(),
            generation=RunnableBranch(
                (lambda x: x["filtered_answers"], cluster_chain),
                # The fallback must provide the same keys as the LLM branch,
                # because the merge step below reads both of them.
                lambda _: {"clusters": [], "messages": []}
            )
        ) | RunnableLambda(
            lambda x: {
                "back_ref_id": x["passthrough"]["back_ref_id"],
                "question": x["passthrough"]["question"],
                "answer_type": x["passthrough"]["answer_type"],
                "gold_answer": x["passthrough"]["gold_answer"],
                "filtered_answers": x["passthrough"]["filtered_answers"],
                "clusters": x["generation"]["clusters"],
                "messages": x["generation"]["messages"]
            }
        )

        return runnable_chain.batch(tagged_instances)

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` to ``output.jsonl``, one JSON object per line.

        Args:
            outputs: Iterable of JSON-serializable items, as returned by
                ``_run``.
        """
        output_path = os.path.join(self._output_dir, "output.jsonl")
        with open(output_path, "w", encoding="utf-8") as file_:
            for output in outputs:
                file_.write(json.dumps(output) + "\n")