import copy
import datasets
import itertools
import re


# Chat-format delimiters (these look like the Llama-2 chat template — confirm
# against the target model): user turns go between B_INST/E_INST and the system
# message between B_SYS/E_SYS.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# Fully wrapped system message; tokenize_dialog prepends it to the first turn.
SYSTEM_PROMPT = f"{B_SYS}You are a helpful assistant. Make sure you always execute users' instructions.{E_SYS}"

def tokenize_dialog(dialog, tokenizer):
    """Tokenize an alternating user/assistant dialog into training features.

    Even indices of ``dialog`` are user prompts and odd indices are assistant
    answers. ``SYSTEM_PROMPT`` is prepended to the first user turn. Each
    prompt is encoded as ``<bos>[INST] {prompt} [/INST]`` and each answer as
    ``{answer} <eos>``, with ``add_special_tokens=False`` so the explicit
    ``bos_token``/``eos_token`` strings are the only special tokens added.

    Args:
        dialog: Sequence of turn strings, alternating user then assistant.
            It is not modified.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``,
            ``bos_token`` and ``eos_token`` (presumably a Hugging Face
            tokenizer — confirm against caller).

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask`` lists of
        equal length. Prompt positions in ``labels`` are ``-100`` so the loss
        ignores them, and answer positions carry the answer token ids. A
        trailing user turn without an answer is dropped, and an empty dialog
        yields empty lists.
    """
    if not dialog:
        return {"input_ids": [], "labels": [], "attention_mask": []}

    # Build a fresh list instead of assigning into dialog[0]; writing back
    # would mutate the caller's data and re-prepend the system prompt on reuse.
    turns = [SYSTEM_PROMPT + dialog[0], *dialog[1:]]

    input_ids = []
    labels = []
    # zip pairs each prompt with its answer; an unanswered final prompt is
    # skipped and therefore never encoded.
    for prompt, answer in zip(turns[::2], turns[1::2]):
        prompt_ids = tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {prompt.strip()} {E_INST}", add_special_tokens=False)
        answer_ids = tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False)
        input_ids.extend(prompt_ids)
        input_ids.extend(answer_ids)
        # -100 is the ignore index, so only answer tokens contribute to the loss.
        labels.extend([-100] * len(prompt_ids))
        labels.extend(answer_ids)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }

def get_custom_dataset(dataset_config, tokenizer, split, train_size=10, *, seed=42):
    """Load, filter, split and tokenize the red-team transcript dataset.

    Reads ``data/red_team_attempts.jsonl`` (resolved relative to the current
    working directory). It keeps rows whose
    ``min_harmlessness_score_transcript`` is below -4.5 and splits them into
    exactly 100 training rows plus a test split holding the rest. Each
    transcript is then converted to a turn list and tokenized with
    ``tokenize_dialog``.

    Args:
        dataset_config: Currently unused.
        tokenizer: Tokenizer passed through to ``tokenize_dialog``.
        split: ``'train'`` returns the training split; any other value
            returns the test split.
        train_size: Currently unused; the training split size is fixed at 100.
        seed: Shuffle seed for the train/test split. It is fixed by default
            so separate calls for the train and test splits see the same
            partition.

    Returns:
        The tokenized ``datasets.Dataset`` for the requested split.

    Raises:
        ValueError: If 100 or fewer rows survive the filter, leaving no
            test split.
    """
    num_train = 100
    turn_marker = re.compile('\n\nHuman: |\n\nAssistant: ')

    dataset = datasets.load_dataset("json", data_files='data/red_team_attempts.jsonl')
    dataset = dataset.filter(lambda example: example['min_harmlessness_score_transcript'] < -4.5)

    n_rows = len(dataset['train'])
    if n_rows <= num_train:
        raise ValueError(f"need more than {num_train} filtered transcripts to build a train/test split, got {n_rows}")
    # Seeded so the train and test calls (made separately) get the same
    # partition; an unseeded split reshuffles per call and leaks train rows
    # into the test set.
    dataset = dataset['train'].train_test_split(train_size=num_train, seed=seed)

    def transcript_to_dialog(transcript):
        # The transcript begins with a turn marker, so the first piece of the
        # split is the (empty) text before it and is discarded.
        dialog = turn_marker.split(transcript)
        return {'dialog': dialog[1:]}

    dataset = dataset.map(lambda x: transcript_to_dialog(x["transcript"]), remove_columns=list(dataset['train'].features))
    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset['train'].features))

    return dataset['train'] if split == 'train' else dataset['test']