""" """

import glob
import asyncio
import numpy as np
import os
import matplotlib.pyplot as plt
import ujson as json
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
# from langchain_openai import ChatOpenAI
from langchain_interface.models import ChatOpenAIWithBatchAPI
from ..langchain_step import (
    SimpleQAVagueGradingStep,
    SimpleQAGradingStep
)
from src.data_readers import (
    SimpleQAAnswerBackoffDataReader,
    SimpleQAAnswerAtRound,
    SimpleQAAnswerAtRoundList,
)
from langchain_core.runnables import (
    RunnableLambda,
)


@BaseTask.register("simpleqa-scoring")
class SimpleQAScoringTask(BaseTask):
    """Grade SimpleQA backoff answers with an LLM and summarize accuracy.

    Every round of every example read from ``input_dir`` is graded by
    ``SimpleQAVagueGradingStep`` (backoff answer vs. gold answer). Rounds
    are bucketed by their positive multiplicity and the per-bucket accuracy
    is plotted; the per-example graded claims are also kept for a JSONL dump.
    """

    __VERSION__ = "0.3.2"

    # Lower bounds of the multiplicity buckets. A round is assigned to the
    # highest bound its multiplicity reaches; rounds below every bound are
    # not counted in any bucket (they are still written to the JSONL).
    _MULTIPLICITY_BUCKETS = (100, 80, 60, 40, 0)

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text,
    ):
        """Configure the grading LLM and its response cache.

        Args:
            input_dir: Directory whose ``*.json`` files are read by
                ``SimpleQAAnswerBackoffDataReader``.
            output_dir: Directory the plot and JSONL file are written to
                (forwarded to ``BaseTask``).
            cache_path: Path of the SQLite database used as the LLM cache.

        Note:
            ``set_llm_cache`` installs the cache as a LangChain global
            setting, not just for this task instance.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        # Temperature 0 keeps grades as stable as possible across runs.
        self._llm = ChatOpenAIWithBatchAPI(
            temperature=0,
            top_p=1,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )

        set_llm_cache(SQLiteCache(cache_path))

    @staticmethod
    def _round_to_grading_input(rnd: SimpleQAAnswerAtRound) -> Dict[Text, Any]:
        """Build the grading-step input dict for one round's backoff answer."""
        return dict(
            question=rnd.question,
            target=rnd.gold_answer,
            predicted_answer=str(rnd.backoff),
        )

    @classmethod
    def _multiplicity_bucket(cls, multiplicity):
        """Return the highest bucket bound reached by ``multiplicity``, or None."""
        for bound in sorted(cls._MULTIPLICITY_BUCKETS, reverse=True):
            if multiplicity >= bound:
                return bound
        return None

    @overrides
    def _run(self):
        """Grade every backoff round and aggregate accuracy per bucket.

        Returns:
            A ``(figure, writing_dicts)`` tuple: the matplotlib figure of
            accuracy per multiplicity bucket, and one dict per example with
            its metadata and the graded ``claims`` of each of its rounds.

        Raises:
            RuntimeError: If the grader returns a different number of
                responses than rounds submitted.
        """
        # Sorted so example order (and JSONL line order) is deterministic;
        # glob's own order depends on the filesystem.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.json")))
        items = list(SimpleQAAnswerBackoffDataReader(filepaths))

        # One grading request per round, in (item, round) order.
        grading_inputs = [
            self._round_to_grading_input(rnd)
            for item in items
            for rnd in item.rounds
        ]

        runnable = SimpleQAVagueGradingStep().chain_llm(self._llm)
        responses = asyncio.run(runnable.abatch(grading_inputs))
        if len(responses) != len(grading_inputs):
            raise RuntimeError(
                f"Expected {len(grading_inputs)} grading responses, "
                f"got {len(responses)}."
            )

        # Responses are paired back to rounds by position, walking the
        # rounds in the same (item, round) order used to build the inputs.
        scores_iter = iter([int(response.is_correct) for response in responses])

        bucket_to_scores = {bound: [] for bound in self._MULTIPLICITY_BUCKETS}
        writing_dicts = []

        for item in items:
            claims = []
            for rnd in item.rounds:
                score = next(scores_iter)
                bucket = self._multiplicity_bucket(rnd.pos_multiplicity)
                if bucket is not None:
                    bucket_to_scores[bucket].append(score)
                claims.append({
                    "backoff": rnd.backoff,
                    "multiplicity": rnd.pos_multiplicity,
                    "score": score,
                })
            writing_dicts.append({
                "question": item.question,
                "gold_answer": item.gold_answer,
                "answer_type": item.answer_type,
                "index": item.index,
                "claims": claims,
            })

        # Ascending bucket bounds on the x-axis; empty buckets are skipped.
        x_values = []
        scores = []
        for bound in sorted(bucket_to_scores):
            bucket_scores = bucket_to_scores[bound]
            if bucket_scores:
                x_values.append(bound)
                scores.append(np.mean(bucket_scores).item())

        fig, ax = plt.subplots()
        ax.plot(x_values, scores, label="SimpleQA Backoff")
        ax.set_xlabel("Multiplicity")
        ax.set_ylabel("Accuracy")
        ax.legend()
        ax.set_title("SimpleQA Backoff Results")

        return fig, writing_dicts

    @overrides
    def _write(self, outputs):
        """Save the accuracy plot and the per-example JSONL records.

        Args:
            outputs: The ``(figure, writing_dicts)`` tuple returned by ``_run``.
        """
        fig, writing_dicts = outputs[0], outputs[1]
        try:
            fig.savefig(os.path.join(self._output_dir, "simpleqa_scoring.png"))
        finally:
            # pyplot keeps every figure from plt.subplots() alive until it
            # is explicitly closed.
            plt.close(fig)

        jsonl_path = os.path.join(self._output_dir, "simpleqa_scoring.jsonl")
        with open(jsonl_path, "w", encoding="utf-8") as file_:
            file_.writelines(json.dumps(item) + "\n" for item in writing_dicts)