""" """

import glob
import logging
from tqdm import tqdm
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel
)
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_community.llms import VLLM
from ..langchain_step import SimpleQASamplingStep
from ..data_readers import SimpleQADataReader


# Attach a formatted stream handler (stderr by default) directly to this
# module's logger so its messages carry timestamp, logger name and level.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.addHandler(handler)


@BaseTask.register("simpleqa-answer-sampling")
class SimpleQAAnswerSamplingTask(BaseTask):
    """Sample several answers per SimpleQA question from an LLM endpoint.

    Every ``*.csv`` file in ``input_dir`` is read with ``SimpleQADataReader``.
    Each question is sent ``num_samples`` times through
    ``SimpleQASamplingStep`` backed by an OpenAI-compatible chat model. The
    sampled phrases are written to ``<output_dir>/output.jsonl``, one JSON
    object per question.
    """

    __VERSION__ = "0.1.1"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        num_samples: int,
        model_name: Text = "meta-llama/Meta-Llama-3-8B-Instruct",
        base_url: Text = "http://localhost:22659/v1",
        api_key: Text = "token-abc123",
        top_p: float = 0.98,
        max_concurrency: int = 128,
    ):
        """
        Args:
            input_dir: directory whose ``*.csv`` files hold the questions.
            output_dir: directory that ``output.jsonl`` is written into.
            num_samples: number of answers sampled per question.
            model_name: model name passed to ``ChatOpenAI``.
            base_url: OpenAI-compatible endpoint serving ``model_name``.
            api_key: API key sent to ``base_url``.
            top_p: nucleus-sampling parameter for the chat model.
            max_concurrency: maximum number of parallel LLM calls per batch.

        The endpoint-related defaults reproduce the previously hard-coded
        values, so existing task configurations behave unchanged.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._num_samples = num_samples

        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            top_p=top_p,
        )
        self._runnable_config = RunnableConfig(
            max_concurrency=max_concurrency,
        )

    @overrides
    def _run(self):
        """Sample ``num_samples`` answers for every question.

        Returns:
            A list of dicts with keys ``index``, ``question``, ``gold_answer``,
            ``answer_type`` and ``sampled_answers`` (the list of sampled
            phrases). Questions whose batch raised are logged and omitted.
        """
        data_paths = glob.glob(os.path.join(self._input_dir, "*.csv"))
        # Materialized so tqdm can report a total.
        items = list(SimpleQADataReader(data_path=data_paths))

        # item -> step input dict -> LLM call (up to 3 attempts) -> phrase.
        runnable_chain = (
            RunnableLambda(lambda x: {"question": x.question})
            | SimpleQASamplingStep().chain_llm(self._llm).with_retry(stop_after_attempt=3)
            | RunnableLambda(lambda x: x.phrase)
        )

        outputs = []
        for item in tqdm(items):
            try:
                # Repeat the item so each copy is sent as a separate request.
                sampled_answers = runnable_chain.batch(
                    [item] * self._num_samples,
                    config=self._runnable_config,
                )
            except Exception as e:
                # Best-effort: a single failing question must not abort the
                # run, but keep the traceback so the failure can be diagnosed.
                logger.exception("Error sampling item %s: %s", item.index, e)
                continue

            outputs.append({
                "index": item.index,
                "question": item.question,
                "gold_answer": item.answer,
                "answer_type": item.answer_type,
                "sampled_answers": sampled_answers,
            })

        return outputs

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` to ``<output_dir>/output.jsonl``, one JSON object per line."""
        output_path = os.path.join(self._output_dir, "output.jsonl")
        with open(output_path, "w", encoding="utf-8") as file_:
            file_.writelines(json.dumps(item) + "\n" for item in outputs)