""" A task that uses a trained model to rewrite text. """

import os
from glob import glob
from overrides import overrides
from copy import deepcopy
try:
    import ujson as json
except ImportError:
    import json
from langchain_interface.models import ChatOpenAIWithBatchAPI
from langchain_core.runnables.config import RunnableConfig
from langchain_core.output_parsers import BaseOutputParser
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)
from tasker import BaseTask
from typing import (
    Text,
    Tuple,
    Dict,
    List,
    Any,
)


class RewritingParser(BaseOutputParser[str]):
    """Output parser that yields the model's rewrite without surrounding whitespace."""

    @property
    def _type(self) -> Text:
        """Identifier string for this parser type."""
        return "rewriting"

    def parse(self, text: Text) -> str:
        """Return ``text`` with leading and trailing whitespace removed."""
        cleaned = text.strip()
        return cleaned


@BaseTask.register("trained-rewriting")
class TrainedRewritingTask(BaseTask):
    """Rewrite every claim found in a directory of JSONL files at several levels.

    Each input line is parsed as a JSON object that is expected to have an
    ``id_`` field and a ``claims`` list whose items carry a ``claim`` string.
    For every ``(verbalization, level)`` pair in ``level_of_rewrites`` the model
    is prompted to make each claim less specific until the verbalization holds.
    The result is a copy of the input data with the claims replaced and ``id_``
    suffixed by the level.
    """

    __VERSION__ = "0.0.2"

    def __init__(
        self,
        input_dir: Text,
        model_dir: Text,
        output_dir: Text,
        level_of_rewrites: List[Tuple[Text, int]],
        base_url: Text = "http://localhost:22659/v1",
        api_key: Text = "token-abc123",
        temperature: float = 0.0,
        top_p: float = 0.98,
        max_concurrency: int = 64,
    ) -> None:
        """
        Args:
            input_dir: Directory scanned (non-recursively) for ``*.jsonl`` files.
            model_dir: Passed as ``model`` to ``ChatOpenAIWithBatchAPI``;
                presumably the name/path under which the server at
                ``base_url`` serves the trained model.
            output_dir: Directory where ``_write`` puts one JSONL file per level.
            level_of_rewrites: ``(verbalization, level)`` pairs. The
                verbalization is substituted into the prompt; the level tags the
                outputs (result keys and ``id_`` suffixes).
            base_url: Endpoint of the server hosting the model.
            api_key: API key sent to that server.
            temperature: Sampling temperature for generation.
            top_p: Nucleus-sampling threshold for generation.
            max_concurrency: Upper bound on concurrent requests during ``batch``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._level_of_rewrites = level_of_rewrites
        self._llm = ChatOpenAIWithBatchAPI(
            model=model_dir,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            top_p=top_p,
        )
        self._runnable_config = RunnableConfig(
            max_concurrency=max_concurrency
        )

    def _read_input_instances(self) -> List[Dict[Text, Any]]:
        """Parse every non-blank line of every ``*.jsonl`` file in ``input_dir``.

        Files are visited in sorted path order so the output order is
        deterministic (``glob`` itself returns paths in arbitrary order).
        """
        data = []
        for filepath in sorted(glob(os.path.join(self._input_dir, '*.jsonl'))):
            with open(filepath, 'r', encoding='utf-8') as file_:
                # Blank lines (e.g. a trailing empty line) carry no record and
                # would make json.loads raise.
                data.extend(json.loads(line) for line in file_ if line.strip())
        return data

    @overrides
    def _run(self) -> Dict[Text, List[Dict[Text, Any]]]:
        """Rewrite all input claims at every configured level.

        Returns:
            A mapping from ``"<level>"`` to a deep copy of the input instances
            in which every ``claim['claim']`` is replaced by its rewrite at
            that level and every ``id_`` is suffixed with ``-{level}``.
        """
        chat_prompt_template = ChatPromptTemplate.from_messages(
            [
                (
                    "human",
                    (
                        "Rewrite the following claim to be less specific until you {verbalization} it is true: {source_claim}\n\n"
                        "Your response should only contain the claim itself, without any additional context."
                    )
                )
            ]
        )

        data = self._read_input_instances()

        # Inputs are laid out claim-major: every level for the first claim,
        # then every level for the second claim, and so on.
        inputs = []
        for instance in data:
            for claim in instance['claims']:
                for verbalization, _ in self._level_of_rewrites:
                    inputs.append({
                        "verbalization": verbalization,
                        "source_claim": claim['claim']
                    })

        chain = chat_prompt_template | self._llm | RewritingParser()
        outputs = chain.batch(inputs, self._runnable_config)

        num_levels = len(self._level_of_rewrites)
        results = {}

        for vidx, (_, level) in enumerate(self._level_of_rewrites):
            # Striding by num_levels from vidx picks this level's rewrites out
            # of the claim-major outputs, in claim order. An iterator replaces
            # the original list.pop(0), which made this loop quadratic.
            rewrites = iter(outputs[vidx::num_levels])
            data_copy = deepcopy(data)

            for instance in data_copy:
                instance['id_'] = f"{instance['id_']}-{level}"
                for claim in instance['claims']:
                    claim['claim'] = next(rewrites)

            results[f'<{level}>'] = data_copy

        return results

    @overrides
    def _write(self, outputs):
        """Write each entry of ``outputs`` to ``{output_dir}/level_{key}.jsonl``.

        Each instance is serialized as one JSON object per line.
        """
        # NOTE(review): keys produced by _run look like "<level>", so filenames
        # contain '<' and '>' (invalid on Windows). Kept as-is in case
        # downstream consumers rely on these names; confirm before changing.
        for level, data in outputs.items():
            filepath = os.path.join(self._output_dir, f'level_{level}.jsonl')
            with open(filepath, 'w', encoding='utf-8') as file_:
                file_.writelines(json.dumps(instance) + '\n' for instance in data)