""" """

import glob
import numpy as np
import os
import ujson as json
import matplotlib.pyplot as plt
from tqdm import tqdm
from overrides import overrides
from typing import Text, Dict, Any, List, Tuple
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_core.language_models.base import BaseLanguageModel
from langchain_openai import ChatOpenAI
from src.data_readers import (
    AnswerPhraseSampleDataReader,
)
from ..data_readers.answer_backoff_data_reader import (
    AnswerBackoffDataReader,
    AnswerAtRound,
    AnswerAtRoundList
)
from ..factual_scorer import (
    Scorer,
    LLMSupportScorer
)
from ..utils.instances import ScorerInstance
from langchain_core.runnables import (
    RunnableLambda,
    Runnable,
    RunnableBranch,
    RunnablePassthrough,
    RunnableParallel
)


@BaseTask.register("graded-factuality-scoring")
class GradedFactualityScoringTask(BaseTask):
    """Compare per-round backoff-answer support scores with grouped positive-claim scores.

    Reads answer-backoff data from the ``*.json`` files in ``input_dir``. For every
    round it scores the backoff answer and each positive claim with
    ``support_scorer``. It then produces a plot and a JSON payload that compare
    the two per-round score curves.
    """

    __VERSION__ = "0.1.4"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        cache_path: Text,
        support_scorer: Scorer
    ):
        """Configure the task.

        Args:
            input_dir: directory whose ``*.json`` files are read with
                ``AnswerBackoffDataReader``.
            output_dir: forwarded to ``BaseTask``; ``_write`` saves its files there.
            cache_path: path of the SQLite database used as the LLM cache.
            support_scorer: scorer called on lists of ``ScorerInstance`` with
                ``return_raw=True``.
        """

        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        # NOTE: ``set_llm_cache`` installs a process-global LangChain cache. It
        # affects every LLM call in the process, not only this task's scorer.
        set_llm_cache(SQLiteCache(cache_path))

        self._support_scorer = support_scorer

    @overrides
    def _run(self):
        """Score backoff answers and grouped positive claims, averaged per round.

        For each item, the backoff answer of every round and every positive claim
        of every round are scored in one batched call to the support scorer. A
        round's "group" score is the maximum parsed score among its positive
        claims. Both scores are then averaged over items, round by round.

        Returns:
            A ``(figure, payload)`` tuple. ``figure`` plots the mean backoff and
            group score per round. ``payload`` is a dict with two keys:
            ``"mean_score_diff"`` (mean absolute difference between the two curves)
            and ``"backoff"`` (per item, the raw backoff scorer outputs, each
            augmented with the scored ``"claim"`` text).

        Raises:
            ValueError: if the scorer returns a different number of results than
                it was given instances, or if no input items were read.
        """

        def _parse_round_to_scorer_instance(answer_round: AnswerAtRound) -> ScorerInstance:
            """Wrap a single round's backoff answer as a scorer instance."""
            return ScorerInstance(
                text=answer_round.backoff,
                topic=answer_round.topic,
                source_text=None
            )

        def _parse_round_to_scorer_instance_list(answer_round: AnswerAtRound) -> List[ScorerInstance]:
            """Wrap each of a round's positive claims (``pos``) as a scorer instance."""
            return [
                ScorerInstance(text=text, topic=answer_round.topic, source_text=None)
                for text in answer_round.pos
            ]

        def _split_by_lengths(items, lengths):
            """Split ``items`` into consecutive chunks whose sizes are ``lengths``."""
            chunks = []
            start = 0
            for length in lengths:
                chunks.append(items[start:start + length])
                start += length
            return chunks

        # Sort the paths so that items are processed and reported in a stable
        # order. The order returned by ``glob`` depends on the filesystem.
        iterator = AnswerBackoffDataReader(
            sorted(glob.glob(os.path.join(self._input_dir, "*.json")))
        )

        backoff_scores = []
        backoff_results_raw_agg = []
        agg_scores = []

        for item in tqdm(iterator):
            backoff_instances = [
                _parse_round_to_scorer_instance(answer_round)
                for answer_round in item.rounds
            ]
            # One list of candidate positive claims per round.
            potential_positives = [
                _parse_round_to_scorer_instance_list(answer_round)
                for answer_round in item.rounds
            ]

            # Score everything in a single batched call: all backoff answers
            # first, then each round's positive claims in round order.
            instances = backoff_instances + [
                instance for group in potential_positives for instance in group
            ]
            results = self._support_scorer(instances, return_raw=True)
            # Check explicitly rather than with ``assert``, which ``-O`` strips.
            # Without this check, a short result list would be silently truncated
            # by the slicing and ``zip`` below.
            if len(results) != len(instances):
                raise ValueError(
                    f"Support scorer returned {len(results)} results "
                    f"for {len(instances)} instances."
                )

            lengths = [len(backoff_instances)] + [len(group) for group in potential_positives]
            backoff_results_raw, *positive_results = _split_by_lengths(results, lengths)

            backoff_results_raw_agg.append([
                {"claim": instance.text, **raw}
                for instance, raw in zip(backoff_instances, backoff_results_raw)
            ])
            backoff_results = [raw["parsed"] for raw in backoff_results_raw]

            # A round's group score is the best parsed score among its positive claims.
            # NOTE(review): ``max`` raises ValueError for a round with no positive claims.
            positive_results_agg = [
                max(raw["parsed"] for raw in round_results)
                for round_results in positive_results
            ]

            backoff_scores.append(backoff_results)
            agg_scores.append(positive_results_agg)

        if not backoff_scores:
            raise ValueError(f"No input items were read from {self._input_dir!r}.")

        # Average over items, round by round. This assumes every item has the same
        # number of rounds, because ragged lists cannot form a rectangular array.
        backoff_scores = np.mean(np.array(backoff_scores), axis=0)
        agg_scores = np.mean(np.array(agg_scores), axis=0)
        # NOTE(review): round 0's backoff score is replaced by its group score.
        # Presumably the two are meant to coincide at round 0; confirm.
        backoff_scores[0] = agg_scores[0]

        fig, ax = plt.subplots()

        ax.plot(backoff_scores, label="Backoff")
        ax.plot(agg_scores, label="Group")
        ax.legend()

        ax.set_ylabel("Score")
        ax.set_xlabel("Round")

        ax.set_title("Backoff Results")

        return fig, {
            "mean_score_diff": np.mean(np.abs(backoff_scores - agg_scores)).item(),
            "backoff": backoff_results_raw_agg,
        }

    @overrides
    def _write(self, outputs):
        """Save the figure and JSON payload produced by ``_run`` into ``output_dir``.

        Args:
            outputs: the ``(figure, payload)`` tuple returned by ``_run``. The
                figure is saved as ``backoff_results.png`` and the payload dict as
                ``backoff_results.json``. The figure is closed afterwards.
        """
        fig: plt.Figure = outputs[0]
        payload: Dict[Text, Any] = outputs[1]
        try:
            fig.savefig(os.path.join(self._output_dir, "backoff_results.png"))
            with open(
                os.path.join(self._output_dir, "backoff_results.json"), "w", encoding="utf-8"
            ) as f:
                json.dump(payload, f, indent=2)
        finally:
            # Close the figure even if saving fails, so matplotlib does not keep it alive.
            plt.close(fig)