""" Generate backoff (more general) answers for SimpleQA questions. """

import glob
import os
import asyncio
import ujson as json
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from langchain_core.runnables.config import RunnableConfig
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
# from langchain_openai import ChatOpenAI
from langchain_interface.models import ChatOpenAIWithBatchAPI
from src.data_readers import (
    SimpleQAIterativeClusteringDataReader,
    SimpleQARoundIteration,
)
from src.data_readers.simpleqa.simpleqa_iterative_clustering_data_reader import (
    Round
)
from langchain_core.runnables import (
    RunnableLambda,
)
from ..langchain_step.vague_answer_step import VagueAnswerStep
# from ..customized_interface.vague_answer_interface import VagueAnswerInterface


# def chaining_discussion_to_vague_answer(
#     round_idx,
#     llm: BaseLanguageModel
# ):
    
#     def _item_processing(item: SimpleQARoundIteration) -> Dict[Text, Any]:
#         question = item.question
#         try:
#             targ_round = item.rounds[round_idx]
#             selected = targ_round.selected
#             not_selected = targ_round.not_selected
#         except IndexError:
#             raise ValueError(f"Round index {round_idx} is out of bounds")
        
#         return {
#             "question": question,
#             "candidates": selected,
#             "negatives": not_selected,
#         }
    
#     return RunnableLambda(_item_processing) | VagueAnswerInterface().get_runnable(llm)


@BaseTask.register("simpleqa-backoff")
class SimpleQABackoffTask(BaseTask):
    """Generate backoff answers for SimpleQA iterative-clustering rounds.

    Every ``*.jsonl`` file in ``input_dir`` is read with
    ``SimpleQAIterativeClusteringDataReader``. For each item, every round
    after the first is rendered into a natural-language "belief" (which
    candidates were selected and which were not) and sent through
    ``VagueAnswerStep`` to obtain a ``general_answer``. Round 0 is kept as a
    baseline whose backoff is simply its first selected candidate. Results
    are written to ``<output_dir>/backoff.json``.
    """

    __VERSION__ = "0.2.3"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text
    ):
        """
        Args:
            input_dir: Directory scanned (non-recursively) for ``*.jsonl``
                input files.
            output_dir: Forwarded to ``BaseTask``; ``_write`` places
                ``backoff.json`` in ``self._output_dir``.
            cache_path: Path of the SQLite database used as the LangChain
                LLM cache.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._llm = ChatOpenAIWithBatchAPI(
            temperature=0,
            top_p=0.98,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )
        # NOTE: set_llm_cache installs a global cache (langchain_core.globals),
        # not one scoped to this task instance.
        set_llm_cache(SQLiteCache(cache_path))
        # Caps how many LLM calls ``abatch`` runs concurrently.
        self._runnable_config = RunnableConfig(max_concurrency=16)

    @staticmethod
    def _make_record(
        index: int,
        item: SimpleQARoundIteration,
        round_: Round,
    ) -> Dict[Text, Any]:
        """Build the output record for one round of one item (without ``backoff``).

        Args:
            index: Position of ``item`` in the list returned by the reader.
            item: The item the round belongs to.
            round_: The round whose selected / not-selected candidates are recorded.

        Returns:
            A dict with keys ``index``, ``gold_answer``, ``question``,
            ``answer_type``, ``pos``, ``neg``, ``pos_multiplicity`` and
            ``neg_multiplicity``. The multiplicities are the sums of the
            round's per-candidate multiplicities.
        """
        return {
            "index": index,
            "gold_answer": item.gold_answer,
            "question": item.question,
            "answer_type": item.answer_type,
            "pos": round_.selected,
            "neg": round_.not_selected,
            "pos_multiplicity": sum(round_.selected_multiplicity),
            "neg_multiplicity": sum(round_.not_selected_multiplicity),
        }

    @staticmethod
    def _build_belief(state: Dict[Text, Any]) -> Dict[Text, Any]:
        """Render a record's candidates into the input for ``VagueAnswerStep``.

        Example outputs for ``belief``:
            "The respondent believes that the answer is A and not D."
            "The respondent believes that the answer is either A, B or C, but not D nor E."

        Args:
            state: A record produced by ``_make_record``. Only ``question``,
                ``pos`` and ``neg`` are read.

        Returns:
            ``{"question": ..., "belief": ...}``.

        Raises:
            IndexError: if ``state["pos"]`` is empty (at least one selected
                candidate is required).
        """
        positives = state["pos"]
        negatives = state["neg"]

        if len(positives) > 1:
            belief = (
                "The respondent believes that the answer is either "
                + ", ".join(positives[:-1]) + " or " + positives[-1]
            )
        else:
            belief = "The respondent believes that the answer is " + positives[0]

        # The trailing space in ", but " is required: the connector is glued
        # directly to "not " below (previously this produced "butnot").
        connector = " and " if len(positives) == 1 else ", but "

        if len(negatives) > 1:
            belief += connector + "not " + ", ".join(negatives[:-1]) + " nor " + negatives[-1]
        elif negatives:
            belief += connector + "not " + negatives[0]

        return {
            "question": state["question"],
            "belief": belief + ".",
        }

    @overrides
    def _run(self):
        """Build beliefs, query the LLM for backoff answers, and group the results.

        Returns:
            One entry per item that has at least two rounds, in item-index
            order. Each entry is a list of records (see ``_make_record``, plus
            a ``backoff`` key):
            - the round-0 baseline, whose backoff is its first selected
              candidate, and
            - one record per later round, whose backoff is the LLM's
              ``general_answer``.
            The records are sorted by ascending ``pos_multiplicity``.
        """
        # Sort so item indices do not depend on the filesystem's listing order.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        items = list(SimpleQAIterativeClusteringDataReader(filepaths))

        # Every round except the first is sent to the LLM; keep records per item.
        per_item_records = [
            [self._make_record(item_index, item, round_) for round_ in item.rounds[1:]]
            for item_index, item in enumerate(items)
        ]
        flattened = [record for records in per_item_records for record in records]

        backoff_callable = RunnableLambda(self._build_belief) | VagueAnswerStep().chain_llm(self._llm)
        results = asyncio.run(backoff_callable.abatch(flattened, self._runnable_config))

        # The dicts in ``flattened`` are the same objects as in
        # ``per_item_records``, so this also fills in the per-item lists.
        for record, result in zip(flattened, results):
            record["backoff"] = result.general_answer

        outputs = []
        for item_index, (item, records) in enumerate(zip(items, per_item_records)):
            if not records:
                # NOTE(review): items with fewer than two rounds produce no
                # entry at all; confirm this is intended by downstream consumers.
                continue
            first_round = item.rounds[0]
            baseline = self._make_record(item_index, item, first_round)
            baseline["backoff"] = first_round.selected[0]
            outputs.append(
                sorted([baseline] + records, key=lambda record: record["pos_multiplicity"])
            )
        return outputs

    @overrides
    def _write(self, outputs):
        """Dump ``outputs`` (the value returned by ``_run``) to ``backoff.json``.

        The file is written in ``self._output_dir`` as UTF-8, indented JSON.
        Non-ASCII characters are kept unescaped.
        """
        with open(os.path.join(self._output_dir, "backoff.json"), "w", encoding='utf-8') as file_:
            json.dump(outputs, file_, ensure_ascii=False, indent=2)