""" """

from glob import glob
import os
from tasker import BaseTask
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
try:
    import ujson as json
except ImportError:
    import json
from overrides import overrides
from typing import (
    Text,
    Dict,
    List,
    Any,
)
from langchain_openai import ChatOpenAI
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough
)
from ..langchain_step import DirectAnswerBackoffStep
from ..data_readers import (
    SimpleQAAnswerBackoffDataReader,
    SimpleQAScoringDataReader
)


@BaseTask.register('direct-rewriting')
class DirectRewritingTask(BaseTask):
    """Rewrite each item's first backoff answer at several target percentages.

    Items are read from the ``*.jsonl`` files found directly under
    ``input_dir``. For every item and every value in ``targ_percentages`` the
    question, the item's first backoff answer and the percentage are sent
    through the ``DirectAnswerBackoffStep`` chain, and the ``answer_backoff``
    attribute of the chain's result is recorded. All records are written to
    ``output.json`` in the task's output directory.
    """

    __VERSION__ = '0.0.3'

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        targ_percentages: List[int]
    ):
        """Store the task settings and build the LLM chain.

        Args:
            input_dir: Directory that is globbed for ``*.jsonl`` input files.
            output_dir: Forwarded to ``BaseTask.__init__``.
            targ_percentages: Percentages to request from the chain; each one
                is also stored as ``pos_multiplicity`` in the output (with
                ``100 - percentage`` as ``neg_multiplicity``).
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._targ_percentages = targ_percentages

        # The model is reached through an OpenAI-style endpoint on localhost.
        # NOTE(review): the api_key looks like a placeholder for a local
        # server -- confirm before pointing this at a real service.
        self._llm = ChatOpenAI(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            base_url="http://localhost:22659/v1",
            api_key="token-abc123",
            top_p=0.98,
            temperature=0.0
        )

        # Passed to ``batch`` so that many items are processed concurrently.
        self._runnable_config = RunnableConfig(
            max_concurrency=128
        )

        self._chain = DirectAnswerBackoffStep().chain_llm(self._llm)

    @staticmethod
    def _item_to_record(item):
        """Convert one data-reader item into the dict carried through the chain.

        Only ``rounds[0]['backoff']`` is consumed later in this task's chain;
        the remaining ``rounds`` fields are carried along unchanged.
        """
        first_backoff = item.backoffs[0]
        return {
            "index": item.index,
            "question": item.question,
            "answer_type": item.answer_type,
            "gold_answer": item.gold_answer,
            "rounds": [
                {
                    "pos": None,
                    "neg": None,
                    "pos_multiplicity": first_backoff.multiplicity,
                    # NOTE(review): this uses a 0-1 scale while the final
                    # output uses 0-100 -- confirm that is intended.
                    "neg_multiplicity": max(1 - first_backoff.multiplicity, 0),
                    "backoff": first_backoff.backoff
                }
            ]
        }

    def _build_chain(self, percentage):
        """Build the full runnable for a single target ``percentage``.

        ``percentage`` is a parameter of this helper rather than a loop
        variable captured by a lambda, so every chain keeps the value it was
        built with regardless of when it is executed.
        """

        def to_generation_input(record):
            # Input expected by the backoff step: question, answer, percentage.
            return {
                "question": record["question"],
                "answer": record['rounds'][0]['backoff'],
                "percentage": percentage
            }

        def to_output(output):
            # Merge the untouched record with the generated backoff answer.
            source = output['passthrough']
            return {
                "index": source['index'],
                "gold_answer": source['gold_answer'],
                "question": source['question'],
                "answer_type": source['answer_type'],
                "pos": None,
                "neg": None,
                "pos_multiplicity": percentage,
                "neg_multiplicity": 100 - percentage,
                "backoff": output['generation'].answer_backoff
            }

        return (
            RunnableLambda(self._item_to_record)
            | RunnableParallel(
                passthrough=RunnablePassthrough(),
                generation=RunnableLambda(to_generation_input) | self._chain
            )
            | RunnableLambda(to_output)
        )

    @overrides
    def _run(self):
        """Run the chain for every item at every target percentage.

        Returns:
            A list with one entry per input item (in reader order); each entry
            is a list with one result dict per target percentage, in the order
            given by ``targ_percentages``.
        """
        # NOTE(review): ``glob`` returns paths in arbitrary order, so with
        # several input files the item order can differ between machines.
        jsonl_paths = glob(os.path.join(self._input_dir, '*.jsonl'))
        items = list(SimpleQAScoringDataReader(data_path=jsonl_paths))

        outputs = [[] for _ in items]

        for percentage in self._targ_percentages:
            full_chain = self._build_chain(percentage)
            results = full_chain.batch(items, config=self._runnable_config)
            # ``batch`` yields one result per input, in input order.
            for per_item, result in zip(outputs, results):
                per_item.append(result)

        return outputs

    @overrides
    def _write(self, outputs):
        """Serialize ``outputs`` as indented JSON to ``output.json``.

        The file is created inside ``self._output_dir`` (presumably set by
        ``BaseTask.__init__`` from ``output_dir``).
        """
        output_path = os.path.join(self._output_dir, 'output.json')
        # ``ensure_ascii=False`` emits raw non-ASCII characters, so the file
        # encoding must be fixed explicitly instead of relying on the locale.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(outputs, f, indent=2, ensure_ascii=False)