
import os
import copy
from glob import glob
import random
from typing import (
    Any,
    List,
    Dict,
    Text,
    Tuple,
    Callable
)
import datasets
import re
from rapidfuzz.distance import Levenshtein
from trl import apply_chat_template
from overrides import overrides
from transformers import AutoTokenizer
from tasker import BaseTask


@BaseTask.register('dpo-claim-rewriter-dataset-preparation')
class DPOClaimRewriterDatasetPreparationTask(BaseTask):
    """Build a DPO preference dataset (prompt / chosen / rejected) for claim rewriting.

    For every named input directory, the ``*.jsonl`` files are loaded, split
    90/10 into train and test, and each record's ``claims`` list is expanded
    into preference rows by the processor returned from
    :meth:`_create_processor`. The per-source results are concatenated and
    saved under ``output_dir/train`` and ``output_dir/test``.
    """

    __VERSION__ = '0.0.3'

    # Lead-in phrase that is cut (together with everything before it) from
    # backoff claims before they are used as chosen / rejected responses.
    _BELIEF_PREFIX = "The respondent believes that"

    def __init__(
        self,
        input_dir_map: Dict[Text, Text],
        output_dir: Text,
        mapping_dict: Dict[Text, List[Tuple[Text, int]]],
        maximum_neg_pool_size: int,
        seed: int,
        model_name: Text
    ):
        """Store the configuration and load the tokenizer.

        Args:
            input_dir_map: Maps a source name to a directory that is globbed
                for ``*.jsonl`` files.
            output_dir: Directory the train / test datasets are saved under.
            mapping_dict: Maps a source name to a list of
                ``(verbalization, multiplicity_threshold)`` pairs. The
                verbalization is interpolated into the prompt; the threshold
                selects the target claim (first claim whose ``multiplicity``
                is strictly greater).
            maximum_neg_pool_size: Upper bound on the number of rejected
                claims paired with one (record, verbalization) target.
            seed: Seed for the train/test split and for negative sampling.
            model_name: Must be ``"meta-llama/Meta-Llama-3-8B-Instruct"``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir_map = input_dir_map
        self._mapping_dict = mapping_dict
        self._model_name = model_name
        self._maximum_neg_pool_size = maximum_neg_pool_size
        assert self._model_name == "meta-llama/Meta-Llama-3-8B-Instruct", "Only Meta-Llama model is supported."
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._seed = seed
        self._random_obj = random.Random(seed)

    @overrides
    def _run(self):
        """Load, split and process every input source.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple, each the concatenation
            of the processed splits of all sources in ``input_dir_map``.
        """
        train_datasets = []
        test_datasets = []

        for name, input_dir in self._input_dir_map.items():
            file_paths = glob(os.path.join(input_dir, '*.jsonl'))
            dataset = datasets.load_dataset('json', data_files=file_paths)
            processor = self._create_processor(name=name)

            # The 'json' loader puts everything into a single 'train' split;
            # carve a 10% test split out of it with the configured seed.
            dataset = dataset['train'].train_test_split(test_size=0.1, shuffle=True, seed=self._seed)

            # The processor emits a different number of rows than it
            # receives, so all original columns have to be dropped.
            train_dataset = dataset['train'].map(
                processor,
                batched=True,
                batch_size=256,
                remove_columns=dataset['train'].column_names,
            )

            test_dataset = dataset['test'].map(
                processor,
                batched=True,
                batch_size=256,
                remove_columns=dataset['test'].column_names,
            )

            print('-' * 50)
            print("Name: ", name)
            print(train_dataset)
            print(test_dataset)
            print('-' * 50)

            train_datasets.append(train_dataset)
            test_datasets.append(test_dataset)

        train_datasets = datasets.concatenate_datasets(train_datasets)
        test_datasets = datasets.concatenate_datasets(test_datasets)

        return train_datasets, test_datasets

    @overrides
    def _write(self, outputs):
        """Save the ``(train, test)`` pair returned by :meth:`_run` to disk."""
        train_dataset, test_dataset = outputs

        train_dataset.save_to_disk(os.path.join(self._output_dir, 'train'))
        test_dataset.save_to_disk(os.path.join(self._output_dir, 'test'))

    @classmethod
    def _strip_belief_prefix(cls, backoff: Text) -> Text:
        """Cut everything up to and including the belief prefix from ``backoff``.

        The text is returned unchanged when the prefix is absent.
        """
        position = backoff.find(cls._BELIEF_PREFIX)
        # NOTE(review): `> 0` is kept from the original code, so a prefix at
        # the very start of the text (index 0) is NOT removed. `find` returns
        # -1 for "not found", so `>= 0` may have been intended -- confirm
        # against the data before changing.
        if position > 0:
            return backoff[position + len(cls._BELIEF_PREFIX):].strip()
        return backoff

    def _create_processor(self, name: Text) -> Callable:
        """Create the batched ``datasets.map`` function for source ``name``.

        Args:
            name: Key into ``mapping_dict``; also written to the ``source``
                column of every produced row.

        Returns:
            A function mapping a batch with a ``claims`` column to a batch
            with ``prompt``, ``chosen``, ``rejected`` and ``source`` columns.
        """
        def _processor(batch: Dict[Text, List[Any]]) -> Dict[Text, List[Any]]:
            prompts = []
            chosen = []
            rejected = []
            mapping = self._mapping_dict[name]

            for claim_list_origin in batch['claims']:
                # A record without claims, or whose first (source) claim has
                # no text, cannot produce a prompt; skip it instead of
                # crashing the whole map call.
                if not claim_list_origin or claim_list_origin[0]['backoff'] is None:
                    continue

                # The prompt uses the first claim as-is (prefix not removed).
                source_claim = claim_list_origin[0]['backoff'].strip()

                # Clean every claim exactly once per record. Copies are used
                # so the input batch is never mutated.
                claim_list = [
                    {
                        **claim,
                        'backoff': (
                            None if claim['backoff'] is None
                            else self._strip_belief_prefix(claim['backoff'])
                        ),
                    }
                    for claim in claim_list_origin
                ]

                for verbalization, mul_threshold in mapping:
                    # Target = first claim above the multiplicity threshold
                    # that actually has text.
                    target_claim = None
                    for claim in claim_list:
                        if claim['multiplicity'] > mul_threshold and claim['backoff'] is not None:
                            target_claim = claim['backoff'].strip()
                            break

                    if target_claim is None:
                        continue

                    # Negatives are the other claims of the same record that
                    # are sufficiently different from the target.
                    # `dict.fromkeys` de-duplicates while preserving order; a
                    # `set` would make the order (and hence the seeded
                    # sampling below) depend on per-process hash randomisation.
                    neg_pool = list(
                        dict.fromkeys(
                            claim_['backoff'] for claim_ in claim_list
                            if claim_['backoff'] is not None and
                            Levenshtein.normalized_similarity(claim_['backoff'], target_claim) < 0.8
                        )
                    )

                    if len(neg_pool) > self._maximum_neg_pool_size:
                        neg_pool = self._random_obj.sample(neg_pool, k=self._maximum_neg_pool_size)

                    for sampled_neg_claim in neg_pool:
                        prompts.append(
                            [
                                {
                                    "role": "user",
                                    "content": (
                                        f"Rewrite the following claim to be less specific until you {verbalization} it is true: {source_claim}\n\n"
                                        "Your response should only contain the claim itself, without any additional context."
                                    )
                                },
                            ]
                        )
                        chosen.append(
                            [
                                {
                                    "role": "assistant",
                                    "content": target_claim
                                }
                            ]
                        )
                        rejected.append(
                            [
                                {
                                    "role": "assistant",
                                    "content": sampled_neg_claim
                                }
                            ]
                        )

            return {
                "prompt": prompts,
                "chosen": chosen,
                "rejected": rejected,
                "source": [name] * len(prompts),
            }

        return _processor