""" """

import os
from tasker import BaseTask
from overrides import overrides
try:
    import ujson as json
except ImportError:
    import json
from typing import (
    Text,
    Iterable,
    Dict,
    Union,
    List,
    Any
)
import logging
from langchain_interface.models import ChatOpenAIWithBatchAPI
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough
)
from langchain_core.runnables.config import RunnableConfig
from ..langchain_step.factual_response_on_topic_step import FactualResponseOnTopicStep


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# Only attach the handler if this named logger has none yet: re-importing or
# reloading the module would otherwise stack duplicate handlers and emit every
# record multiple times.
if not logger.handlers:
    logger.addHandler(handler)


@BaseTask.register("sample-response")
class SampleResponseTask(BaseTask):
    __VERSION__ = "0.0.1"
    
    def __init__(
        self,
        input_path: Union[Text, List[Text]],
        output_dir: Text
    ):
        super().__init__(output_dir=output_dir)
        self._input_paths = input_path if isinstance(input_path, list) else [input_path]
        
        self._llm = ChatOpenAIWithBatchAPI(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            base_url="http://localhost:22659/v1",
            api_key="token-abc123",
            top_p=0.9,
            temperature=0.8,
        )
        self._agent = FactualResponseOnTopicStep().chain_llm(self._llm)
        self._runnable_config = RunnableConfig(
            max_concurrency=16,
        )

    @overrides
    def _run(self):

        def _get_instance(paths: List[Text]) -> Iterable[Dict[Text, Any]]:
            """ """
            
            for path in paths:
                with open(path, "r", encoding='utf-8') as file_:
                    for lidx, line in enumerate(file_):
                        yield {
                            "topic": line.strip(),
                            "id_": lidx,
                        }
                        
        processor = RunnableParallel(
            passthrough=RunnablePassthrough(),
            generation= RunnableLambda(
                lambda x: {
                    "topic": x["topic"],
                }
            ) | self._agent,
        ) | RunnableLambda(
            lambda x: {
                **x['passthrough'],
                "generation": x['generation'].messages,
            }
        )
        
        return processor.batch(list(_get_instance(self._input_paths)), config=self._runnable_config)
    
    @overrides
    def _write(self, outputs):
        """ """
        
        with open(os.path.join(self._output_dir, "output.jsonl"), "w", encoding='utf-8') as file_:
            for item in outputs:
                file_.write(json.dumps(item) + "\n")