from copy import deepcopy
from utils.config import SPECIAL_DELM_TOKENS
from functools import partial
from datasets import load_dataset

# System prompt used when an example has both an instruction and an input.
SYS_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
)
# Same prompt with the clause about the paired input removed.
SYS_NO_INPUT = SYS_INPUT.replace(", paired with an input that provides further context", "")

# Registry of datasets: dataset name -> split name -> loader spec.
# Each spec holds a "type" ("hf" or "json"), a "location" (dataset id or
# file path) and the "split" value handed to the loader.
ds_map = {
    "alpaca": {
        "train": {"type": "hf", "location": "yahma/alpaca-cleaned", "split": "train"},
        "test": {"type": "json", "location": "./datasets/davinci_003_outputs.json", "split": "train"},
    },
    "ih": {
        "test": {"type": "json", "location": "./datasets/rule_following.json", "split": "train"},
    },
    "sep": {
        "test": {"type": "json", "location": "./datasets/SEP_dataset.json", "split": "train"},
    },
    "pllama": {
        "test": {"type": "json", "location": "./datasets/purple_llama_indirect.json", "split": "train"},
    },
}

def get_prompt_sep(example, segment="data"):
    """Build the system/user chat messages for one SEP-style example.

    The user message is the first special delimiter token, an instruction
    field, the second special delimiter token, then a data field. Which
    fields of `example` fill the two slots depends on `segment`:

    * "data": `system_prompt_clean` then `prompt_instructed`
    * "instruction": `system_prompt_instructed` then `prompt_clean`

    Args:
        example: Mapping holding the fields named above.
        segment: Either "data" or "instruction".

    Returns:
        A dict with a single "messages" key holding the system and user
        messages.

    Raises:
        ValueError: If `segment` is neither "data" nor "instruction".
    """
    # Pick the pair of example fields first, then assemble the message once.
    if segment == "data":
        instruction_key, data_key = "system_prompt_clean", "prompt_instructed"
    elif segment == "instruction":
        instruction_key, data_key = "system_prompt_instructed", "prompt_clean"
    else:
        raise ValueError(f"Invalid segment: {segment}")

    system_message = {"role": "system", "content": SYS_INPUT}
    user_message = {
        "role": "user",
        "content": f"{SPECIAL_DELM_TOKENS[0]}{example[instruction_key]}{SPECIAL_DELM_TOKENS[1]}{example[data_key]}",
    }
    return {"messages": [system_message, user_message]}

def get_prompt_pllama(example):
    """Build the system/user chat messages for one example with "system", "user" and "data" fields.

    The user message is the first special delimiter token, the example's
    "system" and "user" fields joined by a newline, the second special
    delimiter token, then the example's "data" field.

    Args:
        example: Mapping with "system", "user" and "data" keys.

    Returns:
        A dict with a single "messages" key holding the system and user
        messages.
    """
    system_message = {"role": "system", "content": SYS_INPUT}
    user_content = f"{SPECIAL_DELM_TOKENS[0]}{example['system']}\n{example['user']}{SPECIAL_DELM_TOKENS[1]}{example['data']}"
    user_message = {"role": "user", "content": user_content}
    return {"messages": [system_message, user_message]}


def get_prompt(example, include_response=True, format="sft", attack_string=None):
    """Build the chat messages for one instruction example.

    The user message starts with the first special delimiter token followed
    by the instruction. When the example has a non-empty "input", the second
    special delimiter token and the input are appended and the system prompt
    that mentions the paired input is used; otherwise the no-input system
    prompt is used.

    Args:
        example: Mapping with an "instruction" key and optionally "input".
            When a response is included, "output" is read for the "sft"
            format (unless `attack_string` is given) and "chosen" /
            "rejected" are read for the "dpo" format.
        include_response: Whether to append an assistant message.
        format: "sft" or "dpo"; only consulted when `include_response` is
            true.
        attack_string: If not None, used as the assistant content in the
            "sft" format instead of `example["output"]`.

    Returns:
        `{"messages": [...]}` for the "sft" format or when no response is
        included; `{"chosen": [...], "rejected": [...]}` for "dpo".

    Raises:
        ValueError: If `include_response` is true and `format` is neither
            "sft" nor "dpo".
    """
    # Fail fast on an unsupported format instead of silently returning None.
    if include_response and format not in ("sft", "dpo"):
        raise ValueError(f"Invalid format: {format}")

    # A missing, empty, or None input all mean "no input"; None must not be
    # rendered as the literal text "None" in the data segment.
    has_input = "input" in example and example["input"] not in ("", None)
    if has_input:
        messages = [{"role": "system", "content": SYS_INPUT},
                    {"role": "user", "content": f"{SPECIAL_DELM_TOKENS[0]}{example['instruction']}{SPECIAL_DELM_TOKENS[1]}{example['input']}"}]
    else:
        messages = [{"role": "system", "content": SYS_NO_INPUT},
                    {"role": "user", "content": f"{SPECIAL_DELM_TOKENS[0]}{example['instruction']}"}]

    if not include_response:
        return {"messages": messages}

    if format == "sft":
        response = attack_string if attack_string is not None else example["output"]
        messages.append({"role": "assistant", "content": response})
        return {"messages": messages}

    # format == "dpo": both conversations share the prompt but must not
    # share message objects, hence the independent copies.
    chosen = deepcopy(messages)
    rejected = deepcopy(messages)
    chosen.append({"role": "assistant", "content": example["chosen"]})
    rejected.append({"role": "assistant", "content": example["rejected"]})
    return {"chosen": chosen, "rejected": rejected}
        

def _split_prompt_at_delimiters(prompt, delimiters):
    """Cut `prompt` into consecutive segments, each later one starting with a delimiter."""
    parts = []
    placeholder_pending = False
    for delim in delimiters:
        if delim not in prompt:
            # Remember the gap so an empty segment can stand in for it.
            placeholder_pending = True
            continue
        splits = prompt.split(delim)
        parts.append(splits[0])
        if placeholder_pending:
            # Empty segment for the skipped delimiter, so the segments that
            # follow keep the index they would have had.
            parts.append("")
            placeholder_pending = False
        # NOTE(review): if `delim` occurs more than once, the text between
        # its first and last occurrence is dropped here; likewise several
        # consecutive missing delimiters yield a single placeholder.
        # Behaviour kept as-is -- confirm whether this is intended.
        prompt = delim + splits[-1]
    parts.append(prompt)
    return parts


def tokenize_prompt(example, tokenizer, delimiters, add_generation_prompt=False, pad_length=None, add_delim=True):
    """Render an example with the chat template and tokenize it segment by segment.

    The rendered prompt is cut at each entry of `delimiters` (in order), every
    segment is tokenized separately, and each token is labelled with the index
    of the segment it came from.

    Args:
        example: Either `{"messages": [...]}` (sft) or
            `{"chosen": [...], "rejected": [...]}` (dpo). The input is not
            modified.
        tokenizer: Object providing `eos_token`, `apply_chat_template`,
            `__call__(text, add_special_tokens=False)` and `pad`.
        delimiters: Strings at which the rendered prompt is split.
        add_generation_prompt: Forwarded to `apply_chat_template` for the sft
            format; when false, `tokenizer.eos_token` is appended to the last
            message first. The dpo format always appends the EOS token to the
            final chosen/rejected message.
        pad_length: If not None, every output list is truncated to this
            length and, for the sft format only, padded to it with
            `tokenizer.pad`.
        add_delim: If false, every token in `SPECIAL_DELM_TOKENS` is removed
            from the segments before tokenization.

    Returns:
        For sft: a dict with "input_ids", "attention_mask" and
        "token_type_ids" (the result of `tokenizer.pad` when `pad_length` is
        given). For dpo: a dict with "prompt_input_ids",
        "prompt_token_type_ids", "chosen_input_ids", "chosen_token_type_ids",
        "rejected_input_ids" and "rejected_token_type_ids".
    """
    is_sft = "messages" in example

    if is_sft:
        # Work on shallow copies so the caller's example is left untouched
        # (appending EOS in place would double it on a second call).
        messages = [dict(message) for message in example["messages"]]
        if not add_generation_prompt:
            messages[-1]["content"] = messages[-1]["content"] + tokenizer.eos_token
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    else:  # If the example format is dpo, assume it is always for training
        chosen_messages = [dict(message) for message in example["chosen"]]
        rejected_messages = [dict(message) for message in example["rejected"]]
        chosen_messages[-1]["content"] = chosen_messages[-1]["content"] + tokenizer.eos_token
        rejected_messages[-1]["content"] = rejected_messages[-1]["content"] + tokenizer.eos_token

        # The prompt is the first two messages plus the generation prompt;
        # each completion is whatever the full rendering adds after it.
        prompt = tokenizer.apply_chat_template(chosen_messages[:2], tokenize=False, add_generation_prompt=True)
        chosen = tokenizer.apply_chat_template(chosen_messages, tokenize=False, add_generation_prompt=False)
        chosen = chosen[len(prompt):]
        rejected = tokenizer.apply_chat_template(rejected_messages, tokenize=False, add_generation_prompt=False)
        rejected = rejected[len(prompt):]

    # split the prompt at the delimiters
    parts = _split_prompt_at_delimiters(prompt, delimiters)

    # remove the special delimiters if add_delim is False
    if not add_delim:
        for delim in SPECIAL_DELM_TOKENS:
            parts = [part.replace(delim, "") for part in parts]

    # tokenize the parts
    tokenized_parts = [tokenizer(part, add_special_tokens=False) for part in parts]

    # concatenate the tokenized parts and token_type_ids into a single dict
    if is_sft:
        prompt_tokenized = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
        for segment_index, tokens in enumerate(tokenized_parts):
            prompt_tokenized["input_ids"].extend(tokens["input_ids"])
            prompt_tokenized["attention_mask"].extend(tokens["attention_mask"])
            prompt_tokenized["token_type_ids"].extend([segment_index] * len(tokens["input_ids"]))
    else:
        prompt_tokenized = {"prompt_input_ids": [], "prompt_token_type_ids": []}
        for segment_index, tokens in enumerate(tokenized_parts):
            prompt_tokenized["prompt_input_ids"].extend(tokens["input_ids"])
            prompt_tokenized["prompt_token_type_ids"].extend([segment_index] * len(tokens["input_ids"]))
        # Completions are always labelled with segment index 2.
        # NOTE(review): presumably the index of the response segment for the
        # delimiter set used in training -- confirm against the caller.
        prompt_tokenized["chosen_input_ids"] = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        prompt_tokenized["chosen_token_type_ids"] = [2] * len(prompt_tokenized["chosen_input_ids"])
        prompt_tokenized["rejected_input_ids"] = tokenizer(rejected, add_special_tokens=False)["input_ids"]
        prompt_tokenized["rejected_token_type_ids"] = [2] * len(prompt_tokenized["rejected_input_ids"])

    # truncate and pad if needed
    if pad_length is not None:
        prompt_tokenized = {key: values[:pad_length] for key, values in prompt_tokenized.items()}

        if is_sft:  # pad only for sft format
            prompt_tokenized = tokenizer.pad(prompt_tokenized, max_length=pad_length, padding='max_length')

    return prompt_tokenized


def get_tokenized_ds(ds, tokenizer, delimiters, pad_to_multiple=None, batch_size=None, add_generation_prompt=False, pad_length=None, add_delim=True):
    """Tokenize every example of `ds` and optionally group the result into batches.

    Args:
        ds: Dataset object supporting `len`, indexing/slicing, `add_item`,
            `map` and `column_names`.
        tokenizer, delimiters, add_generation_prompt, pad_length, add_delim:
            Forwarded unchanged to `tokenize_prompt`.
        pad_to_multiple: If not None and the dataset length is not already a
            multiple of it, the last example is repeated until it is.
        batch_size: If not None, the tokenized dataset is returned as a list
            of consecutive slices of this size; a trailing remainder smaller
            than `batch_size` is left out.

    Returns:
        The mapped dataset, or a list of its slices when `batch_size` is set.
    """
    # Repeat the final example until the length is a multiple of pad_to_multiple.
    if pad_to_multiple is not None:
        shortfall = pad_to_multiple - (len(ds) % pad_to_multiple)
        # shortfall == pad_to_multiple means the length already divides evenly.
        if shortfall < pad_to_multiple:
            for _ in range(shortfall):
                ds = ds.add_item(ds[-1])

    tokenize_fn = partial(
        tokenize_prompt,
        tokenizer=tokenizer,
        delimiters=delimiters,
        add_generation_prompt=add_generation_prompt,
        pad_length=pad_length,
        add_delim=add_delim,
    )
    # Drop the original columns so only the tokenized fields remain.
    ds = ds.map(tokenize_fn, remove_columns=ds.column_names)

    if batch_size is None:
        return ds

    # Slice into full batches only.
    return [
        ds[index * batch_size:(index + 1) * batch_size]
        for index in range(len(ds) // batch_size)
    ]
    

def get_dataset(ds_name, split='train'):
    """Load one split of a dataset registered in `ds_map`.

    Args:
        ds_name: Key into `ds_map`.
        split: Key into `ds_map[ds_name]`.

    Returns:
        Whatever `load_dataset` returns for the registered location: "hf"
        entries pass the location as the dataset path, "json" entries pass it
        as `data_files` to the "json" loader.

    Raises:
        KeyError: If `ds_name` or `split` is not registered in `ds_map`.
        ValueError: If the registered entry has a type other than "hf" or
            "json".
    """
    ds_split = ds_map[ds_name][split]
    source_type = ds_split['type']
    if source_type == 'hf':
        return load_dataset(ds_split['location'], split=ds_split['split'])
    if source_type == 'json':
        return load_dataset("json", data_files=ds_split['location'], split=ds_split['split'])
    # An unknown type used to fall through and return None; surface it instead.
    raise ValueError(f"Unsupported dataset type: {source_type}")
