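"""Natural Questions backoff scoring task.

Registers the ``nq-scoring`` task, which grades backoff answers with an LLM
grader and reports accuracy bucketed by answer multiplicity.
"""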

import asyncio
import glob
import os
from typing import Any, Dict, List, Text

import matplotlib.pyplot as plt
import numpy as np
import ujson as json
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables import (
    RunnableLambda,
)
# from langchain_openai import ChatOpenAI
from langchain_interface.models import (
    ChatOpenAIWithBatchAPI,
    BatchedAPIConfig
)
from overrides import overrides
from tasker import BaseTask

from src.data_readers import (
    SimpleQAAnswerBackoffDataReader,
    SimpleQAAnswerAtRound,
    SimpleQAAnswerAtRoundList,
)
from ..langchain_step import (
    SimpleQAVagueGradingStep,
    SimpleQAGradingStep
)


# Delimiter that joins several acceptable gold answers inside one
# ``gold_answer`` string. NQScoringTask splits on it, grades the backoff
# answer against each alternative, and counts the round correct if any match.
__SEPARATOR__ = " [@@SEP@@] "


@BaseTask.register("nq-scoring")
class NQScoringTask(BaseTask):
    """Grade Natural Questions backoff answers and plot accuracy by multiplicity.

    Reads every ``*.json`` file in ``input_dir`` with
    ``SimpleQAAnswerBackoffDataReader``. For each round, it asks an LLM grader
    (gpt-4o through ``SimpleQAVagueGradingStep``) whether the backoff answer
    matches any of the round's gold answers. It then buckets per-round
    correctness by ``pos_multiplicity`` and writes a line plot plus a JSONL
    dump of per-claim scores to ``output_dir``.
    """

    __VERSION__ = "0.0.1"

    # Lower bounds of the multiplicity buckets, in DESCENDING order. A round
    # is assigned to the first (largest) bound its multiplicity reaches.
    # Rounds below the smallest bound are not bucketed, but they still appear
    # in the JSONL output.
    _MULTIPLICITY_THRESHOLDS = (100, 80, 60, 40, 0)

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text,
    ):
        """Configure the grader LLM and its response cache.

        Args:
            input_dir: Directory whose ``*.json`` files hold the backoff answers.
            output_dir: Directory that receives the plot and JSONL results
                (forwarded to ``BaseTask``).
            cache_path: SQLite file used as the LangChain LLM response cache.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        # NOTE: set_llm_cache (langchain_core.globals) installs the cache
        # process-wide, not just for this task's LLM.
        set_llm_cache(SQLiteCache(cache_path))
        self._llm = ChatOpenAIWithBatchAPI(
            temperature=0,
            top_p=0.98,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )

        self._runnable_config = BatchedAPIConfig(
            max_concurrency=8,
            max_abatch_size=10000
        )

    @overrides
    def _run(self):
        """Grade every backoff round and aggregate accuracy per multiplicity bucket.

        Returns:
            A ``(figure, writing_dicts)`` tuple. ``figure`` is the matplotlib
            figure plotting mean accuracy against the multiplicity bucket.
            ``writing_dicts`` holds one dict per input item with its metadata
            and per-round ``claims`` scores.
        """
        items = self._load_items()

        # Fan each round out into one grading request per distinct gold answer.
        # ``num_gold[i]`` records how many requests round ``i`` produced, so
        # the flat responses can be regrouped per round afterwards.
        num_gold = []
        flattened_rounds = []
        for item in items:
            for answer_round in item.rounds:
                requests = self._parse_round_to_input_dict(answer_round)
                flattened_rounds.extend(requests)
                num_gold.append(len(requests))

        runnable = SimpleQAVagueGradingStep().chain_llm(self._llm)
        print(f"Running {len(flattened_rounds)} instances after dedup.")

        flatten_responses = asyncio.run(
            runnable.abatch(flattened_rounds, config=self._runnable_config)
        )
        responses = self._aggregate_responses(flatten_responses, num_gold)

        bucket_to_scores, writing_dicts = self._collect_results(items, responses)

        # Only buckets that received at least one round are plotted, in
        # ascending threshold order.
        x_values = sorted(k for k, v in bucket_to_scores.items() if v)
        scores = [np.mean(bucket_to_scores[k]).item() for k in x_values]

        return self._plot_accuracy(x_values, scores), writing_dicts

    def _load_items(self):
        """Read all ``*.json`` files in the input directory, in sorted path order.

        ``glob.glob`` order depends on the filesystem, so the paths are sorted
        to make the output order reproducible. The reader is materialised into
        a list because it is iterated twice.
        """
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.json")))
        return list(SimpleQAAnswerBackoffDataReader(filepaths))

    @staticmethod
    def _parse_round_to_input_dict(item: SimpleQAAnswerAtRound) -> List[Dict[Text, Any]]:
        """Build one grading request per distinct gold answer of a single round.

        Args:
            item: One round of backoff answers. Its ``gold_answer`` may join
                several alternatives with ``__SEPARATOR__``.

        Returns:
            A list of ``{question, target, predicted_answer}`` dicts, one per
            distinct (whitespace-stripped) gold answer.
        """
        backoff = str(item.backoff)
        # dict.fromkeys de-duplicates while keeping first-seen order. Iterating
        # a set of strings would reorder requests between runs because of
        # string hash randomization.
        gold_answers = dict.fromkeys(
            answer.strip() for answer in item.gold_answer.split(__SEPARATOR__)
        )
        return [
            dict(
                question=item.question,
                target=gold_answer,
                predicted_answer=backoff,
            )
            for gold_answer in gold_answers
        ]

    @staticmethod
    def _aggregate_responses(flatten_responses, num_gold) -> List[bool]:
        """Collapse per-gold-answer grades into one correctness flag per round.

        A round is correct if ANY of its gold answers was graded correct.
        The flat list is walked with a running offset rather than re-slicing
        the remaining tail for every round, which was quadratic.

        Args:
            flatten_responses: Grader outputs exposing ``is_correct``, in
                request order.
            num_gold: Number of consecutive responses belonging to each round.

        Raises:
            ValueError: If the number of responses differs from the number of
                grading requests. Proceeding would misalign grades and rounds.
        """
        expected = sum(num_gold)
        if len(flatten_responses) != expected:
            raise ValueError(
                f"Expected {expected} grading responses, got {len(flatten_responses)}."
            )

        responses = []
        offset = 0
        for count in num_gold:
            group = flatten_responses[offset:offset + count]
            responses.append(any(fr.is_correct for fr in group))
            offset += count
        return responses

    def _collect_results(self, items, responses):
        """Bucket round scores by multiplicity and build per-item output records.

        Args:
            items: Items from ``SimpleQAAnswerBackoffDataReader``, in the same
                order used to build the grading requests.
            responses: One correctness flag per round, flattened over ``items``.

        Returns:
            ``(bucket_to_scores, writing_dicts)``. ``bucket_to_scores`` maps
            each threshold in ``_MULTIPLICITY_THRESHOLDS`` to a list of 0/1
            scores.
        """
        bucket_to_scores = {threshold: [] for threshold in self._MULTIPLICITY_THRESHOLDS}
        writing_dicts = []
        round_scores = iter(responses)

        for item in items:
            claims = []
            for answer_round in item.rounds:
                score = int(next(round_scores))
                multiplicity = answer_round.pos_multiplicity

                # Thresholds are descending, so the first match is the largest
                # bucket bound this multiplicity reaches.
                for threshold in self._MULTIPLICITY_THRESHOLDS:
                    if multiplicity >= threshold:
                        bucket_to_scores[threshold].append(score)
                        break

                claims.append({
                    "backoff": answer_round.backoff,
                    "multiplicity": multiplicity,
                    "score": score,
                })

            writing_dicts.append({
                "question": item.question,
                "gold_answer": item.gold_answer,
                "answer_type": item.answer_type,
                "index": item.index,
                "claims": claims,
            })

        return bucket_to_scores, writing_dicts

    @staticmethod
    def _plot_accuracy(x_values, scores):
        """Return a figure plotting mean accuracy per multiplicity bucket bound."""
        fig, ax = plt.subplots()
        ax.plot(x_values, scores, label="NQ Backoff")
        ax.set_xlabel("Multiplicity")
        ax.set_ylabel("Accuracy")
        ax.legend()

        ax.set_title("Natural Question Backoff Results")

        return fig

    @overrides
    def _write(self, outputs):
        """Save the accuracy plot and the per-item claim scores.

        Args:
            outputs: The ``(figure, writing_dicts)`` tuple returned by ``_run``.
                The figure is written to ``nq_scoring.png`` and then closed.
                Each record is written as one JSON line to ``nq_scoring.jsonl``.
        """
        fig = outputs[0]
        try:
            fig.savefig(os.path.join(self._output_dir, "nq_scoring.png"))
        finally:
            # pyplot keeps every figure created via plt.subplots() alive until
            # it is explicitly closed, so release it once it has been saved.
            plt.close(fig)

        with open(os.path.join(self._output_dir, "nq_scoring.jsonl"), "w", encoding='utf-8') as file_:
            for item in outputs[1]:
                file_.write(json.dumps(item) + "\n")