from datasets import load_dataset, concatenate_datasets
import pandas as pd
import warnings
import pprint
import json

def find_all_indices(text, substring):
    """Return the start offset of every occurrence of ``substring`` in ``text``.

    The search resumes one character after each hit, so overlapping
    occurrences are all reported.

    Args:
        text: String to search in.
        substring: String to search for.

    Returns:
        List of integer offsets in ascending order (empty if there is no match).
    """
    positions = []
    cursor = text.find(substring)
    while cursor != -1:
        positions.append(cursor)
        # Advance by a single character so overlapping matches are kept.
        cursor = text.find(substring, cursor + 1)
    return positions


class DatasetBuilder():
    """Dispatch a dataset name to the matching loader function of this module."""

    def __init__(self):
        # Registry mapping a dataset name to the callable that loads it.
        self.dataset_map = {
                "beavertails": init_beavertails,
                "beavertails_nonchat": build_dataset_beavertails_for_nonchat,
                "hh": build_dataset_hh,
                "pku": init_dataset_pku,
                "donotanswer": init_dataset_do_not_answer,
                "jailbreak": init_jailbreak_dataset,
                "safedecoding": init_safedecoding,
                'truthfulqa': init_truthfulqa,
                'code': init_nn_dataset,
                'autodan': init_autodan,
                'safedecoding_new': init_safedecoding_new,
                'none': init_none_dataset,
                'advbench': init_advbench_dataset
        }

    def build_dataset(self, dataset_name, sample_num, tokenizer=None):
        """Build the dataset registered under ``dataset_name``.

        Args:
            dataset_name: A key of ``self.dataset_map``, a name of the form
                ``safedecoding_<model_type>``, or any identifier that
                ``load_dataset`` accepts.
            sample_num: Number of samples requested from the loader.
            tokenizer: Tokenizer forwarded to the ``hh`` and ``safedecoding``
                loaders; required for those two names.

        Returns:
            Whatever the selected loader returns.

        Raises:
            ValueError: If ``dataset_name`` is ``'hh'`` or ``'safedecoding'``
                and no tokenizer was supplied.
        """
        if dataset_name == 'hh':
            if not tokenizer:
                raise ValueError("Tokenizer must be provided for HH dataset")
            return self.dataset_map[dataset_name](tokenizer, 0, sample_num)

        # The exact name must be tested before the prefix test below,
        # otherwise the tokenizer-based loader can never be reached.
        if dataset_name == 'safedecoding':
            if not tokenizer:
                raise ValueError("Tokenizer must be provided for SafeDecoding dataset")
            return self.dataset_map[dataset_name](sample_num, tokenizer)

        if dataset_name.startswith('safedecoding'):
            # 'safedecoding_<model_type>': the suffix selects the target model.
            # NOTE(review): 'safedecoding_new' yields model_type 'new' here —
            # confirm whether that name should use the loader's default model.
            model_type = dataset_name.split('_')[-1]
            return self.dataset_map['safedecoding_new'](sample_num, model_type)

        if dataset_name in self.dataset_map:
            return self.dataset_map[dataset_name](sample_num)

        warnings.warn(f"Dataset {dataset_name} not found in dataset map, trying to load anyway\n make sure your dataset has a 'prompt' field")
        dataset = load_dataset(dataset_name)
        if isinstance(dataset, dict):
            # Without a split argument load_dataset returns a mapping of
            # splits, which has no select(); use 'train' when present,
            # otherwise the first available split.
            split_name = 'train' if 'train' in dataset else next(iter(dataset))
            dataset = dataset[split_name]
        dataset = dataset.select(range(sample_num))
        return dataset

def init_truthfulqa(num_samples):
    """Load the first ``num_samples`` rows of the TruthfulQA 'generation' validation split.

    The 'category', 'type' and 'source' columns are dropped and the
    'question' column is renamed to 'prompt'.

    Args:
        num_samples: Number of leading rows to keep.

    Returns:
        The reduced dataset.
    """
    validation_split = load_dataset("truthfulqa/truthful_qa", 'generation')['validation']
    subset = validation_split.select(range(num_samples))
    dropped_columns = ['category', 'type', 'source']
    return subset.remove_columns(dropped_columns).rename_column('question', 'prompt')


def init_beavertails(num_samples):
    """Load ``num_samples`` unsafe prompts from the BeaverTails '30k_train' split.

    Rows are kept when at least one of the listed category flags is set and
    ``is_safe`` is false. The result is shuffled with seed 42, the 'response'
    column is removed, and the first ``num_samples`` rows are selected.

    Args:
        num_samples: Number of rows to return.

    Returns:
        The filtered, shuffled and truncated dataset.
    """
    harm_categories = (
        'animal_abuse',
        'hate_speech,offensive_language',
        'self_harm',
        'sexually_explicit,adult_content',
        'terrorism,organized_crime',
        'violence,aiding_and_abetting,incitement',
        'drug_abuse,weapons,banned_substance',
        'discrimination,stereotype,injustice',
        'child_abuse',
    )

    def has_selected_harm(example):
        # True as soon as one of the selected category flags is set.
        flags = example["category"]
        return any(flags[name] for name in harm_categories)

    def is_unsafe(example):
        return not example['is_safe']

    rows = load_dataset("PKU-Alignment/BeaverTails", split='30k_train')
    rows = rows.filter(has_selected_harm)
    rows = rows.filter(is_unsafe)
    rows = rows.shuffle(seed=42).remove_columns(['response'])
    return rows.select(range(num_samples))

def build_dataset_beavertails_for_nonchat(num_samples):
    """Build BeaverTails prompts prefixed with the text of 'prompts/f1.txt'.

    Args:
        num_samples: Number of BeaverTails rows to load via ``init_beavertails``.

    Returns:
        A dataset with a single 'prompt' column, formatted as 'torch'. Each
        value is the file content, a newline, then the original prompt.
    """
    # init_beavertails requires the sample count and already truncates to it.
    initial_dataset = init_beavertails(num_samples)
    with open('prompts/f1.txt', 'r') as f:
        fschat = f.read()
    original_columns = initial_dataset.column_names

    def preprocess_function(examples):
        # With batched=True, `examples` maps each column name to a list of
        # values, so iterate over the 'prompt' column directly.
        return {
            "prompt": [fschat + '\n' + text for text in examples['prompt']],
        }

    dataset_mapped = initial_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=original_columns,
    )
    dataset_mapped.set_format('torch')
    return dataset_mapped

def build_dataset_hh(
        tokenizer, start, end
):
    """Build prompts from the Anthropic HH-RLHF red-team transcripts.

    The dataset is shuffled with seed 42 and rows ``start``..``end - 1`` are
    used. Each transcript is cut right after its first 'Assistant:' marker
    (the marker's trailing space is dropped) and double newlines are
    collapsed to single ones.

    Args:
        tokenizer: Callable used as ``tokenizer(text, truncation=True)``; its
            result must provide 'input_ids'.
        start: Index of the first shuffled row to use.
        end: Index one past the last shuffled row to use.

    Returns:
        A 'torch'-formatted dataset with columns 'query' (the cut text),
        'input_ids' and 'prompt' (content of 'prompts/f1.txt', a newline,
        then the query), restricted to rows with fewer than 512 input ids.

    Raises:
        IndexError: If a transcript does not contain 'Assistant: '.
    """
    ds = load_dataset("Anthropic/hh-rlhf", data_dir="red-team-attempts", split="train")
    ds = ds.shuffle(seed=42)
    ds1 = ds.select(range(start, end))
    original_columns1 = ds1.column_names
    with open('prompts/f1.txt', 'r') as f:
        fschat = f.read()
    marker = 'Assistant: '

    def preprocess_function(examples):
        new_examples = {
            "query": [],
            "input_ids": [],
            "prompt": [],
        }
        for transcript in examples['transcript']:
            # Only the first marker matters, so a single find() is enough.
            ind = transcript.find(marker)
            if ind == -1:
                raise IndexError(f"No {marker!r} marker found in transcript")
            # Keep 'Assistant:' but drop the space that follows it.
            text = transcript[:ind + len(marker) - 1]
            text = text.replace("\n\n", "\n")
            tokenized_question = tokenizer(text, truncation=True)
            new_examples["query"].append(text)
            new_examples["input_ids"].append(tokenized_question["input_ids"])
            new_examples["prompt"].append(fschat + '\n' + text)

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        remove_columns=original_columns1,
    )

    ds1 = ds1.filter(lambda x: len(x["input_ids"]) < 512, batched=False)

    ds1.set_format(type="torch")

    return ds1

def init_jailbreak_dataset(num_samples):
    """Combine every unsafe prompt with every jailbreak template.

    Prompts are read from 'jailbreak/prompt_selection.csv' and templates from
    'jailbreak/jailbreak-prompt.csv'. In each template text the placeholder
    '[INSERT PROMPT HERE]' is replaced by the prompt.

    Args:
        num_samples: Upper bound of the slice taken from the combined list.

    Returns:
        List of dicts with keys 'prompt', 'pattern', 'name' and 'category',
        ordered prompt-major, template-minor.
    """
    prompt_table = pd.read_csv('jailbreak/prompt_selection.csv')
    template_table = pd.read_csv('jailbreak/jailbreak-prompt.csv')

    combined = []
    for _, prompt_row in prompt_table.iterrows():
        category, request = prompt_row['category'], prompt_row['prompt']
        combined.extend(
            {
                'prompt': template_row['text'].replace('[INSERT PROMPT HERE]', request),
                'pattern': template_row['pattern'],
                'name': template_row['name'],
                'category': category,
            }
            for _, template_row in template_table.iterrows()
        )
    return combined[:num_samples]

def init_dataset_do_not_answer(num_samples):
    """Load adult-content and 'Assisting ...' questions from LibrAI/do-not-answer.

    All rows whose 'types_of_harm' equals 'Adult Content' are concatenated
    with the first 50 rows whose 'types_of_harm' contains 'Assisting'. The
    'question' column is renamed to 'prompt' and only 'prompt', 'risk_area'
    and 'types_of_harm' are kept.

    Args:
        num_samples: Maximum number of rows to return.

    Returns:
        The first ``num_samples`` rows of the combined dataset, or all rows
        when fewer are available.
    """
    dataset = load_dataset("LibrAI/do-not-answer")
    dataset = dataset["train"]
    dataset_adult = dataset.filter(lambda x: x["types_of_harm"] == "Adult Content")
    dataset_illegal = dataset.filter(lambda x: 'Assisting' in x["types_of_harm"])
    dataset_illegal = dataset_illegal.select(range(50))
    dataset = concatenate_datasets([dataset_adult, dataset_illegal])
    dataset = dataset.rename_column("question", "prompt")
    dataset = dataset.select_columns(["prompt", "risk_area", "types_of_harm"])
    # select() returns a new dataset, so its result must be kept. The count is
    # clamped so a request larger than the dataset still returns every row.
    dataset = dataset.select(range(min(num_samples, len(dataset))))
    return dataset

def init_dataset_pku(num_samples):
    """Load the first ``num_samples`` prompts of the PKU-SafeRLHF train split.

    Args:
        num_samples: Number of leading rows to keep.

    Returns:
        A dataset containing only the 'prompt' column.
    """
    train_split = load_dataset("PKU-Alignment/PKU-SafeRLHF")["train"]
    return train_split.select(range(num_samples)).select_columns(["prompt"])

def init_safedecoding(num_samples, tokenizer):
    """Combine BeaverTails prompts with SafeDecoding llama2 attack templates.

    For each BeaverTails sample, the '[PROMPT]' placeholder of every
    attack prompt whose 'target-model' is 'llama2' is replaced by the sample's
    prompt. Combination stops once ``num_samples`` prompts exist.

    Args:
        num_samples: Number of combined prompts to produce; also the number
            of BeaverTails rows requested from ``init_beavertails``.
        tokenizer: Callable used as ``tokenizer(prompt)``; its result must
            provide 'input_ids'.

    Returns:
        List of dicts with keys 'prompt' and 'token_len' (number of input ids).
    """
    beavertails_dataset = init_beavertails(num_samples)
    safecoding_dataset = load_dataset("flydust/SafeDecoding-Attackers", split='train')
    llama_prompts = safecoding_dataset.filter(lambda x: x['target-model'] == 'llama2')

    # Stop as soon as enough prompts exist instead of materialising the whole
    # cross product and truncating it afterwards.
    combined_prompts = []
    for sample in beavertails_dataset:
        if len(combined_prompts) >= num_samples:
            break
        for jailbreak in llama_prompts:
            if len(combined_prompts) >= num_samples:
                break
            combined_prompts.append(jailbreak['prompt'].replace('[PROMPT]', sample['prompt']))

    return [
        {'prompt': prompt, 'token_len': len(tokenizer(prompt)['input_ids'])}
        for prompt in combined_prompts
    ]


def init_nn_dataset(num_samples):
    """Load ``num_samples`` neural-network related rows of the python-code dataset.

    A row is kept when its instruction contains 'neural network' or
    'machine learning' (case-insensitive) or one of the acronyms CNN, DNN,
    RNN, LSTM, GRU (case-sensitive). 'instruction' is renamed to 'prompt',
    'output' to 'desired', and the 'system' column is removed.

    Args:
        num_samples: Number of leading matching rows to keep.

    Returns:
        The filtered and renamed dataset.
    """
    phrases = ('neural network', 'machine learning')
    acronyms = ('CNN', 'DNN', 'RNN', 'LSTM', 'GRU')

    def mentions_neural_topic(row):
        instruction = row['instruction']
        lowered = instruction.lower()
        if any(phrase in lowered for phrase in phrases):
            return True
        return any(acronym in instruction for acronym in acronyms)

    code_rows = load_dataset("jtatman/python-code-dataset-500k", split='train')
    selected = code_rows.filter(mentions_neural_topic).select(range(num_samples))
    selected = selected.rename_column('instruction', 'prompt')
    selected = selected.rename_column('output', 'desired')
    return selected.remove_columns(['system'])

def init_autodan(num_samples,
                 results_path='/sise/home/ganonb/AutoDAN/results/autodan_hga/llama2_0_normal_successful.json'):
    """Load AutoDAN prompts from a JSON results file.

    Each entry of the JSON object contributes one prompt built as
    ``goal + ' ' + final_suffix``.

    Args:
        num_samples: Maximum number of prompts to return.
        results_path: Path of the JSON results file. Defaults to the location
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        List of ``{'prompt': ...}`` dicts. When the file holds fewer than
        ``num_samples`` entries, a warning is issued and all of them are
        returned.
    """
    with open(results_path, 'r') as f:
        data = json.load(f)
    dataset = [
        {'prompt': entry['goal'] + ' ' + entry['final_suffix']}
        for entry in data.values()
    ]
    if len(dataset) < num_samples:
        warnings.warn("Requested number of samples is greater than the number of samples in the dataset. Returning full dataset")
        return dataset
    return dataset[:num_samples]

def init_safedecoding_new(num_samples, model_type='llama2'):
    """Load SafeDecoding attack prompts that target ``model_type``.

    Rows of the 'train' split are kept when their 'target-model' equals
    ``model_type``; the 'response', 'goal', 'target', 'real_id' and
    'target-model' columns are removed.

    Args:
        num_samples: Number of leading rows to keep.
        model_type: Value the 'target-model' column must match.

    Returns:
        The first ``num_samples`` matching rows, or all matching rows (with a
        warning) when fewer are available.
    """
    attackers = load_dataset("flydust/SafeDecoding-Attackers", split='train')
    matching = attackers.filter(lambda row: row['target-model'] == model_type)
    matching = matching.remove_columns(['response', 'goal', 'target', 'real_id', 'target-model'])
    if len(matching) < num_samples:
        warnings.warn("Requested number of samples is greater than the number of samples in the dataset. Returning full dataset")
        return matching
    return matching.select(range(num_samples))

def init_advbench_dataset(num_samples):
    """Load the first ``num_samples`` rows of the walledai/AdvBench train split.

    Args:
        num_samples: Number of leading rows to keep.

    Returns:
        The truncated dataset.
    """
    train_split = load_dataset("walledai/AdvBench", split="train")
    return train_split.select(range(num_samples))

def init_none_dataset(num_samples):
    """Loader that produces no dataset; ``num_samples`` is ignored.

    Returns:
        Always ``None``.
    """
    return

# Script entry point: running this module directly currently does nothing.
if __name__ == "__main__":
    pass