""" This task intend to generate questions from a given claim """


import glob
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnswerShorteningStep,
    AnchoredClusteringStep,
    QuizQuestionStep
)
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough
)
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_interface.steps import Step
from langchain_core.globals import set_llm_cache
from langchain_community.llms import VLLM
# from langchain_community.llms.vllm import VLLMOpenAI
from langchain_community.cache import SQLiteCache

from ..data_readers import VerifiableClaimsDataReader
from ..data_readers.verifiable_claims_data_reader import ClaimExtraction


@BaseTask.register("question-generation-on-claim")
class QuestionGenerationOnClaim(BaseTask):
    """Generate a quiz-style question for every extracted claim.

    Claims are read from the ``claims_*.jsonl`` files in ``input_dir`` via
    ``VerifiableClaimsDataReader``. Each claim is sent through
    ``QuizQuestionStep`` to an OpenAI-compatible chat endpoint, and the
    enriched records are written to ``<output_dir>/questions.jsonl``.
    """

    __VERSION__ = "0.0.1"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        model_name: Text = "meta-llama/Meta-Llama-3-8B-Instruct",
        base_url: Text = "http://localhost:8000/v1",
        api_key: Text = "local-host",
        top_p: float = 0.9,
        max_concurrency: int = 64,
    ):
        """
        Args:
            input_dir: Directory containing the ``claims_*.jsonl`` input files.
            output_dir: Directory where ``questions.jsonl`` is written
                (forwarded to ``BaseTask``).
            model_name: Model identifier passed to ``ChatOpenAI``.
            base_url: Base URL of the OpenAI-compatible server; the default
                points at a server on localhost.
            api_key: API key for the endpoint; the default is presumably a
                placeholder for a local server that ignores it.
            top_p: Nucleus-sampling parameter passed to ``ChatOpenAI``.
            max_concurrency: Upper bound on concurrent LLM calls in ``_run``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            top_p=top_p,
        )
        self._runnable_config = RunnableConfig(max_concurrency=max_concurrency)

    def _claim_files(self) -> List[Text]:
        """Return the ``claims_*.jsonl`` paths under ``input_dir``, sorted.

        ``glob.glob`` yields paths in arbitrary, filesystem-dependent order;
        sorting keeps ``back_ref_id`` stable across runs and machines.
        """
        return sorted(glob.glob(os.path.join(self._input_dir, "claims_*.jsonl")))

    @staticmethod
    def _merge_generation(result: Dict[Text, Any]) -> Dict[Text, Any]:
        """Flatten one ``RunnableParallel`` result into a single record.

        ``result["passthrough"]`` is the original input record and
        ``result["generation"]`` is the object produced by the
        ``QuizQuestionStep`` chain.
        """
        generation = result["generation"]
        return {
            **result["passthrough"],
            "question": generation.question,
            "answer_template": generation.answer_template,
            "placeholder_start": generation.place_holder_start,
            "placeholder_end": generation.place_holder_end,
        }

    @overrides
    def _run(self):
        """Generate a question for every claim.

        Returns:
            A list with one dict per claim. Each dict holds the input fields
            (``back_ref_id``, ``claim``, ``cidx``, ``topic``) plus
            ``question``, ``answer_template``, ``placeholder_start`` and
            ``placeholder_end``.
        """
        call_chain = QuizQuestionStep().chain_llm(self._llm)
        # Failed LLM calls are retried, up to 10 attempts in total.
        call_chain = call_chain.with_retry(stop_after_attempt=10)

        # The data should be in a jsonl input format.
        iterator = VerifiableClaimsDataReader(self._claim_files())

        # back_ref_id indexes the reader item; cidx indexes the claim within it.
        inputs = [
            {
                "back_ref_id": idx,
                "claim": claim,
                "cidx": cidx,
                "topic": item.topic,
            }
            for idx, item in enumerate(iterator)
            for cidx, claim in enumerate(item.claims)
        ]

        def _convert_to_ipt(item: Dict) -> Dict:
            # The prompt step only consumes the claim itself.
            return {"claim": item["claim"]}

        # Run the LLM chain while keeping the original record alongside it,
        # then merge the two into one flat dict.
        pipeline = RunnableParallel(
            passthrough=RunnablePassthrough(),
            generation=RunnableLambda(_convert_to_ipt) | call_chain,
        ) | RunnableLambda(self._merge_generation)

        return pipeline.batch(inputs, config=self._runnable_config)

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` to ``<output_dir>/questions.jsonl``, one JSON object per line."""
        output_path = os.path.join(self._output_dir, "questions.jsonl")
        # Explicit encoding so the output does not depend on the platform locale.
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in outputs)