""" Since at the moment the top-1 answer string is not a claim, we need to declarativize it.  """

import os
try:
    import ujson as json
except ImportError:
    import json
from glob import glob
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from tasker import BaseTask
from typing import (
    Optional
)
from overrides import overrides


@BaseTask.register("declarativize-top-answer")
class DeclarativizeTopAnswerTask(BaseTask):
    
    __VERSION__ = '0.0.1'
    
    def __init__(
        self,
        input_dir: str,
        output_dir: str,
        batch_size: int,
        device: Optional[int] = 0
    ):
        """ """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._batch_size = batch_size
        self._device = device
        self._tokenizer = AutoTokenizer.from_pretrained('domenicrosati/question_converter-3b')
        self._model = AutoModelForSeq2SeqLM.from_pretrained('domenicrosati/question_converter-3b')
        self._model.to(self._device)
        self._model.eval()
        
    @overrides
    def _run(self):
        """ """
        
        file_paths = glob(os.path.join(self._input_dir, '*.jsonl'))
        data = []
        inputs = []

        for file_path in file_paths:
            with open(file_path, 'r', encoding='utf-8') as file_:
                for line in file_:
                    datapiece = json.loads(line)
                    question = datapiece['question'].strip()
                    if not question.endswith('?'):
                        question += ' ?'
                    data.append(datapiece)
                    inputs.append(f"{question} </s> {datapiece['claims'][0]['backoff']}")
                    
        input_ids = self._tokenizer(
            inputs,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).input_ids
        
        # split input_ids to batches
        batches = []
        all_generations = []
        for i in range(0, len(input_ids), self._batch_size):
            batches.append(input_ids[i:i+self._batch_size])
        
        for batch in tqdm(batches):
            generated = self._model.generate(
                batch.to(self._device),
                max_length=256,
                num_beams=4,
                no_repeat_ngram_size=2,
                early_stopping=True
            )
            generated_text = self._tokenizer.batch_decode(generated, skip_special_tokens=True)
            all_generations.extend(generated_text)

        for i, generation in enumerate(all_generations):
            data[i]['claims'][0]['backoff'] = generation

        return data
    
    @overrides
    def _write(self, outputs):
        """ """
        
        with open(os.path.join(self._output_dir, "top_answer_declarativized.jsonl"), 'w', encoding='utf-8') as file_:
            for output in outputs:
                file_.write(json.dumps(output, ensure_ascii=False) + '\n')