""" """

import json
import logging
import os
import re
from typing import Text, Optional

import torch
from overrides import overrides
from peft import PeftModel, PeftConfig
from tasker import BaseTask
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..factual_scorer import LLMSupportScorer
from ..retriever import WikiDocRetriever


@BaseTask.register("back-off-from-source-claim")
class BackOffFromSourceClaimTask(BaseTask):
    """Rewrite every claim of every input record into a less specific one.

    The task reads a JSONL file whose records carry an ``all_claims`` list,
    prompts a PEFT-adapted causal LM once per claim, and replaces each claim
    with the model's rewrite. When no rewrite can be extracted from the
    generation, the original claim is kept unchanged.

    NOTE(review): the chat markers used for parsing (``<|start_header_id|>``,
    ``<|eot_id|>``) look like the Llama-3 chat format -- confirm that the
    adapter's tokenizer actually uses this template.
    """

    __VERSION__ = "0.0.3"

    # Token budget for the templated prompt and for the generated rewrite.
    _MAX_PROMPT_TOKENS = 256
    _MAX_NEW_TOKENS = 32

    # Strings that terminate generation.
    _STOP_STRINGS = ['<|end_of_text|>', '<|eot_id|>']

    # Captures the assistant turn of a decoded sequence. The turn must be
    # closed by ``<|eot_id|>``; an unterminated generation does not match.
    _ASSISTANT_TURN = re.compile(
        r"<\|start_header_id\|>assistant<\|end_header_id\|>(.+?)<\|eot_id\|>",
        re.DOTALL,
    )

    _LOGGER = logging.getLogger(__name__)

    def __init__(
        self,
        model_path: Text,
        input_path: Text,
        batch_size: int,
        output_dir: Text
    ):
        """Load the base model, the PEFT adapter and the tokenizer.

        Args:
            model_path: Directory (or hub id) of the PEFT adapter. The
                tokenizer is loaded from the same location and the base
                model is resolved from the adapter's config.
            input_path: Path of the JSONL file to process.
            batch_size: Number of claims generated per forward batch.
            output_dir: Forwarded to ``BaseTask.__init__``.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        super().__init__(output_dir=output_dir)

        # A negative step would make the batching range empty and silently
        # replace every record's claims with an empty list.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self._input_path = input_path
        self._batch_size = batch_size
        self._model_path = model_path

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._LOGGER.warning("CUDA is not available; running generation on CPU.")
            self._device = torch.device('cpu')

        config = PeftConfig.from_pretrained(self._model_path)
        self._model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
        self._tokenizer = AutoTokenizer.from_pretrained(self._model_path)

        # Decoder-only generation continues from the last position of each
        # row, so padding must go on the left; right padding would place pad
        # tokens between the prompt and the generated text.
        self._tokenizer.padding_side = 'left'
        if self._tokenizer.pad_token is None:
            # ``padding=True`` raises without a pad token.
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._peft_model = PeftModel.from_pretrained(self._model, self._model_path)
        self._peft_model.eval()
        self._peft_model.to(self._device)

    @classmethod
    def _extract_rewrite(cls, decoded: Text) -> Optional[Text]:
        """Return the stripped assistant turn of ``decoded``, or ``None``.

        ``None`` is returned when no terminated assistant turn is found or
        when the turn is empty after stripping whitespace.
        """
        match = cls._ASSISTANT_TURN.search(decoded)
        if match is None:
            return None
        return match.group(1).strip() or None

    def _rewrite_batch(self, batch):
        """Generate rewrites for ``batch`` and return them in the same order.

        Each element of the result is the model's rewrite, or the original
        claim when no rewrite could be extracted.
        """
        conversations = [
            [
                {
                    "role": "user",
                    "content": f"Rewrite the following claim to be less specific until you are confident it is true: {claim}"
                },
            ]
            for claim in batch
        ]

        # NOTE(review): truncation may cut the generation prompt off claims
        # longer than the token budget -- confirm inputs stay below it.
        encoded = self._tokenizer.apply_chat_template(
            conversations,
            padding=True,
            truncation=True,
            max_length=self._MAX_PROMPT_TOKENS,
            return_tensors='pt',
            return_dict=True,  # also yields the attention mask for padded rows
            add_generation_prompt=True,
        ).to(self._device)

        outputs = self._peft_model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_new_tokens=self._MAX_NEW_TOKENS,
            stop_strings=self._STOP_STRINGS,
            tokenizer=self._tokenizer,
            pad_token_id=self._tokenizer.pad_token_id,
        )

        # Special tokens are kept because the extraction pattern relies on them.
        decoded = self._tokenizer.batch_decode(outputs, skip_special_tokens=False)

        rewritten = []
        for claim, text in zip(batch, decoded):
            candidate = self._extract_rewrite(text)
            if candidate is None:
                self._LOGGER.debug("No rewrite extracted; keeping original claim: %r", claim)
                rewritten.append(claim)
            else:
                rewritten.append(candidate)
        return rewritten

    @overrides
    def _run(self):
        """Rewrite all claims of the input file and return the updated records.

        Returns:
            The list of parsed records, with ``all_claims`` replaced by the
            rewritten claims for every record whose list was non-empty.
        """
        with open(self._input_path, 'r', encoding='utf-8') as file_:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            data = [json.loads(line) for line in file_ if line.strip()]

        # Flatten all claims so batches can span record boundaries, and
        # remember how many claims each record contributed.
        item_lengths = []
        claims = []
        for didx, dp in enumerate(data):
            if not dp['all_claims']:
                continue
            item_lengths.append((didx, len(dp['all_claims'])))
            claims.extend(dp['all_claims'])

        processed_claims = []
        with torch.no_grad():
            for start in tqdm(range(0, len(claims), self._batch_size)):
                batch = claims[start:start + self._batch_size]
                processed_claims.extend(self._rewrite_batch(batch))

        # Pair processed claims back to their original records.
        offset = 0
        for idx, length in item_lengths:
            data[idx]['all_claims'] = processed_claims[offset:offset + length]
            offset += length

        return data

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` as JSON lines to ``processed.jsonl``.

        The file is created inside ``self._output_dir`` (presumably set by
        ``BaseTask.__init__`` from ``output_dir`` -- verify against the base
        class).
        """
        if self._output_dir:
            os.makedirs(self._output_dir, exist_ok=True)

        with open(os.path.join(self._output_dir, "processed.jsonl"), 'w', encoding='utf-8') as file_:
            file_.writelines(json.dumps(item) + '\n' for item in outputs)