""" Given answer template, generate a phrasic answer. """

import glob
import logging
from tqdm import tqdm
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnswerShorteningStep,
    AnchoredClusteringStep,
    QuizQuestionStep
)
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough
)
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_interface.steps import Step
from langchain_interface.steps import TestOnQuizStep
from langchain_core.globals import set_llm_cache
from langchain_community.llms import VLLM
# from langchain_community.llms.vllm import VLLMOpenAI
from langchain_community.cache import SQLiteCache
from ..data_readers.converted_question_data_reader import ConvertedQuestionDataReader, ConvertedQuestionItem


# Module logger that writes "time - name - level - message" lines to stderr.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# getLogger(__name__) returns the same Logger object on every (re)import, so
# only attach the handler once; otherwise a module reload would add a second
# StreamHandler and every record would be printed twice.
if not logger.handlers:
    logger.addHandler(handler)


@BaseTask.register("sample-phrase-answer-with-template")
class SamplePhraseAnswerWithTemplateTask(BaseTask):
    """Sample several infilled phrasic answers per templated question.

    Reads ``ConvertedQuestionItem`` records from every ``*.jsonl`` file in
    ``input_dir``. It keeps only items with ``placeholder_start >= 0`` and
    runs each one ``num_samples`` times through ``TestOnQuizStep`` with a
    ``ChatOpenAI`` model. It writes one JSON line per successfully sampled
    item to ``<output_dir>/output.jsonl``.
    """

    __VERSION__ = "0.3.2"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        num_samples: int,
        model_name: Text = "meta-llama/Meta-Llama-3-8B-Instruct",
        base_url: Text = "http://localhost:22659/v1",
        api_key: Text = "token-abc123",
        top_p: float = 0.9,
        max_concurrency: int = 16,
    ):
        """
        Args:
            input_dir: directory scanned (non-recursively) for ``*.jsonl`` files.
            output_dir: forwarded to ``BaseTask``; ``_write`` puts
                ``output.jsonl`` in it.
            num_samples: number of times the chain is invoked per question.
            model_name: ``model`` argument passed to ``ChatOpenAI``.
            base_url: endpoint passed to ``ChatOpenAI`` (defaults to a
                localhost server).
            api_key: key passed to ``ChatOpenAI``.
            top_p: ``top_p`` sampling argument passed to ``ChatOpenAI``.
            max_concurrency: ``max_concurrency`` of the ``RunnableConfig``
                used for every ``batch`` call.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            top_p=top_p,
        )
        self._runnable_config = RunnableConfig(
            max_concurrency=max_concurrency,
        )
        self._num_samples = num_samples

    @overrides
    def _run(self):
        """Sample ``num_samples`` answers for every templated question.

        Returns:
            A list of dicts, one per input item whose batch succeeded. Each
            dict holds the item's metadata plus ``answers``, which is the
            ``infill`` attribute of each chain response. An item whose batch
            raised (after the chain's retries) is logged and omitted.
        """
        step = TestOnQuizStep()

        def _convert_item(item: ConvertedQuestionItem) -> Dict[Text, Any]:
            # Only these two fields are passed on to the step's chain.
            return {
                "question": item.question,
                "answer_template": item.answer_template,
            }

        runnable_chain: Runnable = (
            RunnableLambda(_convert_item)
            | step.chain_llm(self._llm).with_retry(stop_after_attempt=5)
        )

        # glob.glob returns paths in arbitrary order; sort them so item (and
        # therefore output) order is reproducible across runs.
        data_paths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        if not data_paths:
            logger.warning("No *.jsonl files found in %s", self._input_dir)

        # Items with a negative placeholder_start are skipped.
        items = [
            item
            for item in ConvertedQuestionDataReader(data_path=data_paths)
            if item.placeholder_start >= 0
        ]

        outputs = []
        for item in tqdm(items):
            try:
                # Repeat the item so the chain is invoked num_samples times;
                # any variation between answers comes from the LLM's sampling.
                response_batch = runnable_chain.batch(
                    [item] * self._num_samples, config=self._runnable_config
                )
            except Exception:
                # Best-effort per question: log with traceback and move on so
                # one failing item does not abort the whole run.
                logger.exception(
                    "Failed to sample answers for question %r", item.question
                )
                continue

            outputs.append({
                "question": item.question,
                "back_ref_id": item.back_ref_id,
                "topic": item.topic,
                "cidx": item.cidx,
                "answer_template": item.answer_template,
                "placeholder_start": item.placeholder_start,
                "placeholder_end": item.placeholder_end,
                "answers": [res.infill for res in response_batch],
            })

        return outputs

    @overrides
    def _write(self, outputs):
        """Write each output dict as one JSON line to ``<output_dir>/output.jsonl``."""
        output_path = os.path.join(self._output_dir, "output.jsonl")
        with open(output_path, "w", encoding="utf-8") as file_:
            file_.writelines(json.dumps(item) + "\n" for item in outputs)
