""" Prepare training data for Claim Rewriter model. """

import os
from glob import glob
from typing import (
    Any,
    List,
    Dict,
    Text,
    Tuple,
    Callable
)
import datasets
import re
from overrides import overrides
from transformers import AutoTokenizer
from tasker import BaseTask


@BaseTask.register('claim-rewriter-dataset-preparation')
class ClaimRewriterDatasetPreparationTask(BaseTask):
    """Build chat-style train/test datasets for the claim rewriter.

    For every named input directory the task loads all ``*.jsonl`` files,
    splits the rows 90/10 into train/test with a fixed seed, converts each
    record into user/assistant message pairs and concatenates the per-source
    results into one train and one test dataset.
    """

    __VERSION__ = '0.2.1'

    # The only model name accepted by ``__init__``.
    _SUPPORTED_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
    # Lead-in phrase that is cut from target claims (together with anything
    # that precedes it).
    _BELIEF_PREFIX = "The respondent believes that"

    def __init__(
        self,
        input_dir_map: Dict[Text, Text],
        output_dir: Text,
        mapping_dict: Dict[Text, List[Tuple[Text, int]]],
        seed: int,
        model_name: Text
    ):
        """Store the configuration and load the tokenizer.

        Args:
            input_dir_map: Source name -> directory containing ``*.jsonl``
                files for that source.
            output_dir: Forwarded to ``BaseTask.__init__``.
            mapping_dict: Source name -> list of
                ``(verbalization, multiplicity_threshold)`` pairs used to
                build the prompts for that source.
            seed: Seed for the train/test split.
            model_name: Name of the tokenizer/model to load.

        Raises:
            ValueError: If ``model_name`` is not the supported Llama-3 model.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir_map = input_dir_map
        self._mapping_dict = mapping_dict
        self._model_name = model_name
        # Explicit raise rather than ``assert`` so the check is not stripped
        # when Python runs with ``-O``.
        if self._model_name != self._SUPPORTED_MODEL_NAME:
            raise ValueError("Only Meta-Llama model is supported.")
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._seed = seed

    @overrides
    def _run(self):
        """Load, split and convert every configured source.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple, each being the
            concatenation of the processed splits of all sources.
        """
        split_names = ('train', 'test')
        collected = {split_name: [] for split_name in split_names}

        for name, input_dir in self._input_dir_map.items():
            # Created first so an unknown source name fails before any data
            # is loaded.
            processor = self._create_processor(name=name)

            # glob() yields paths in arbitrary, filesystem-dependent order;
            # sorting fixes the row order so the seeded split below is
            # reproducible across machines.
            file_paths = sorted(glob(os.path.join(input_dir, '*.jsonl')))
            dataset = datasets.load_dataset('json', data_files=file_paths)

            splits = dataset['train'].train_test_split(
                test_size=0.1,
                shuffle=True,
                seed=self._seed,
            )

            # The processor emits a different number of rows than it
            # receives, so all original columns are dropped.
            processed = {
                split_name: splits[split_name].map(
                    processor,
                    batched=True,
                    batch_size=256,
                    remove_columns=splits[split_name].column_names,
                )
                for split_name in split_names
            }

            print('-' * 50)
            print("Name: ", name)
            print(processed['train'])
            print(processed['test'])
            print('-' * 50)

            for split_name in split_names:
                collected[split_name].append(processed[split_name])

        train_dataset = datasets.concatenate_datasets(collected['train'])
        test_dataset = datasets.concatenate_datasets(collected['test'])

        return train_dataset, test_dataset

    @overrides
    def _write(self, outputs):
        """Save the datasets returned by ``_run``.

        Args:
            outputs: The ``(train_dataset, test_dataset)`` tuple; the two
                datasets are saved to the ``train`` and ``test``
                sub-directories of ``self._output_dir``.
        """
        train_dataset, test_dataset = outputs

        train_dataset.save_to_disk(os.path.join(self._output_dir, 'train'))
        test_dataset.save_to_disk(os.path.join(self._output_dir, 'test'))

    def _create_processor(
        self,
        name: Text
    ) -> Callable[[Dict[Text, List[Any]]], Dict[Text, List[Any]]]:
        """Create the batched map function for one source.

        ``name`` is the key into the mapping dictionary that selects the
        ``(verbalization, multiplicity_threshold)`` pairs for this source.

        Args:
            name: Source name; also written to the ``source`` output column.

        Returns:
            A function taking a batch with a ``claims`` column (one list of
            claim dicts with ``backoff`` and ``multiplicity`` keys per row)
            and returning ``messages`` and ``source`` columns.

        Raises:
            KeyError: If ``name`` is not present in the mapping dictionary.
        """
        # Resolved once here instead of on every batch.
        mapping = self._mapping_dict[name]
        prefix = self._BELIEF_PREFIX

        def _processor(batch: Dict[Text, List[Any]]) -> Dict[Text, List[Any]]:
            messages = []

            for claim_list in batch['claims']:
                # The first claim supplies the source text; rows without a
                # usable one cannot produce a prompt and are skipped.
                if not claim_list or claim_list[0]['backoff'] is None:
                    continue
                source_claim = claim_list[0]['backoff'].strip()

                for verbalization, mul_threshold in mapping:
                    # Only one claim per (row, verbalization) is chosen: the
                    # first one whose multiplicity exceeds the threshold.
                    chosen = next(
                        (
                            claim for claim in claim_list
                            if claim['multiplicity'] > mul_threshold
                            and claim['backoff'] is not None
                        ),
                        None,
                    )
                    if chosen is None:
                        continue

                    target_claim = chosen['backoff'].strip()
                    # Keep only the text after the first occurrence of the
                    # prefix, whether it opens the claim or follows other
                    # words; claims without the prefix are left unchanged.
                    _, found, remainder = target_claim.partition(prefix)
                    if found:
                        target_claim = remainder.strip()

                    messages.append(
                        [
                            {
                                "role": "user",
                                "content": (
                                    f"Rewrite the following claim to be less specific until you {verbalization} it is true: {source_claim}\n\n"
                                    "Your response should only contain the claim itself, without any additional context."
                                )
                            },
                            {
                                "role": "assistant",
                                "content": f"{target_claim}"
                            }
                        ]
                    )

            return {
                "messages": messages,
                "source": [name] * len(messages),
            }

        return _processor