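"""Tasker task that runs an LLM chain over each round of iterative-clustering data and writes the per-round backoff answers to JSON."""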

import glob
import os
import ujson as json
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
from langchain_openai import ChatOpenAI
from src.data_readers.iterative_clustering_data_reader import (
    IterativeClusteringDataReader,
    RoundIteration
)
from langchain_core.runnables import (
    RunnableLambda,
    Runnable,
    RunnablePassthrough,
    RunnableParallel
)
from langchain_interface.steps import (
    ExplainDiffStep,
)
from ..customized_interface.vague_answer_interface import VagueAnswerInterface
from ..langchain_step.vague_answer_step import VagueAnswerStep
from ..langchain_step.decl_answer_step import DeclAnswerStep

def chaining_discussion_to_vague_answer(
    round_idx: int,
    llm: BaseLanguageModel
) -> Runnable:
    """Build a chain that feeds one round's candidates to the vague-answer interface.

    The returned runnable takes a ``RoundIteration`` item, picks
    ``item.rounds[round_idx]``, and passes the item's question together with
    that round's ``selected`` / ``not_selected`` lists to
    ``VagueAnswerInterface().get_runnable(llm)``.

    Args:
        round_idx: Index into ``item.rounds`` of the round to process.
        llm: Language model handed to the vague-answer runnable.

    Returns:
        A runnable mapping a ``RoundIteration`` to the interface's output.

    Raises:
        ValueError: When the chain is invoked on an item that has no round at
            ``round_idx`` (raised at invocation time, not when building).
    """

    def _item_processing(item: RoundIteration) -> Dict[Text, Any]:
        # Only the round lookup can raise IndexError; keep the try body to it.
        try:
            targ_round = item.rounds[round_idx]
        except IndexError as err:
            # Chain the original IndexError so the traceback keeps its cause.
            raise ValueError(f"Round index {round_idx} is out of bounds") from err

        return {
            "question": item.question,
            "candidates": targ_round.selected,
            "negatives": targ_round.not_selected,
        }

    return RunnableLambda(_item_processing) | VagueAnswerInterface().get_runnable(llm)


@BaseTask.register("answer-with-backoff")
class AnswerWithBackoffTask(BaseTask):
    """Compute a backoff answer for every round of every iterative-clustering item.

    All ``*.jsonl`` files in ``input_dir`` are read through
    ``IterativeClusteringDataReader``. For each round index, the chain built
    by ``chaining_discussion_to_vague_answer`` is batched over the items that
    have that round. The collected per-round records are written to
    ``backoff.json`` in ``output_dir``.
    """

    __VERSION__ = "0.11.3"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text
    ):
        """
        Args:
            input_dir: Directory scanned (non-recursively) for ``*.jsonl`` inputs.
            output_dir: Directory ``_write`` places ``backoff.json`` in.
            cache_path: Path of the SQLite database used as the LLM response cache.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._llm = ChatOpenAI(
            temperature=0,
            top_p=0.98,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )
        # NOTE: set_llm_cache (from langchain_core.globals) installs a global
        # LangChain LLM cache, not one scoped to this task instance.
        set_llm_cache(SQLiteCache(cache_path))

    @overrides
    def _run(self):
        """Run the backoff chain round by round until no item has rounds left.

        Returns:
            One list per item that has at least one round, ordered by item
            index. Each list holds that item's per-round records in round order.
        """
        # glob's order depends on the filesystem; sort so item indices are reproducible.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        # Materialize once. The loop below scans the items again for every
        # round; a one-shot iterator would silently yield nothing after round 0,
        # and a re-iterable reader would re-read the files each time.
        items = list(IterativeClusteringDataReader(filepaths))

        # item index -> list of per-round records for that item
        results = {}

        round_idx = 0
        while True:
            # Only items that actually have a round at this index take part.
            filtered = [
                (idx, item) for idx, item in enumerate(items)
                if len(item.rounds) > round_idx
            ]
            if not filtered:
                break

            call_chain = chaining_discussion_to_vague_answer(round_idx, self._llm)
            responses = call_chain.batch([item for _, item in filtered])

            for (idx, item), response in zip(filtered, responses):
                targ_round = item.rounds[round_idx]
                results.setdefault(idx, []).append({
                    "index": idx,
                    "round_idx": round_idx,
                    "question": item.question,
                    "answer_template": item.answer_template,
                    "topic": item.topic,
                    "pos": targ_round.selected,
                    "neg": targ_round.not_selected,
                    "pos_multiplicity": targ_round.selected_multiplicity,
                    "neg_multiplicity": targ_round.not_selected_multiplicity,
                    # NOTE(review): assumes the chain output maps "general_answer"
                    # to a sequence whose last entry is the final answer; confirm
                    # against VagueAnswerInterface.
                    "backoff": response['general_answer'][-1],
                })

            round_idx += 1

        # Every item with >= 1 round is inserted during round 0 in index order,
        # so dict insertion order is ascending item index.
        return list(results.values())

    @overrides
    def _write(self, outputs):
        """Serialize ``outputs`` (as returned by ``_run``) to ``<output_dir>/backoff.json``."""
        with open(os.path.join(self._output_dir, "backoff.json"), "w", encoding='utf-8') as file_:
            # ensure_ascii=False keeps non-ASCII text readable in the output file.
            json.dump(outputs, file_, ensure_ascii=False, indent=2)