# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json
import os
import torch
import random

from sentencepiece import SentencePieceProcessor
from torch.utils.data import Dataset
from typing import List
from ft_datasets.utils import ConcatDataset
from llama2.safety_evaluation.keyword_eval import _test_prefixes

# Prompt-template markers: "[INST] ... [/INST]" delimits the user turn, and
# "<<SYS>> ... <</SYS>>" wraps a system prompt prepended to that turn.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
# SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
# SYSTEM_PROMPT = B_SYS + SYSTEM_PROMPT + E_SYS


def get_regmix_dataset(dataset_config, tokenizer, partition="train", max_words=30, concat=False):
    """Build the mixed instruction-tuning dataset described by ``dataset_config``.

    Args:
        dataset_config: config object forwarded to ``InstructionDataset``.
        tokenizer: tokenizer forwarded to ``InstructionDataset``.
        partition: forwarded unchanged to ``InstructionDataset``.
        max_words: per-example length used when padding/truncating.
        concat: when truthy, build the dataset WITHOUT padding and wrap it in
            ``ConcatDataset`` (from ``ft_datasets.utils``); otherwise return the
            padded ``InstructionDataset`` directly.

    Returns:
        Either an ``InstructionDataset`` or a ``ConcatDataset`` wrapping one.
    """
    # Padding is only wanted when samples are NOT handed to ConcatDataset.
    base = InstructionDataset(dataset_config, tokenizer, partition, max_words, pad=not concat)
    return ConcatDataset(base) if concat else base


class InstructionDataset(Dataset):
    """Instruction-tuning dataset mixing benign, harmful and crowdsourced pairs.

    Each entry of ``self.data`` is a ``(prompt, response)`` string pair.
    ``__getitem__`` wraps it as ``[INST] prompt [/INST] response`` and tokenizes it.

    Sources read from ``dataset_config``:

    * ``data_path``: JSONL of benign examples. A ``question`` key is renamed
      to ``prompt``. The response is ``org_answer`` when ``baseline`` is
      truthy, otherwise ``answer``.
    * ``harm_data_1``: if the path contains ``aoa_dataset``, a JSON list of
      ``[system, user, assistant]`` dialogs (each dialog keeps its own system
      prompt). If it contains ``pure_bad_dataset``, a JSONL file with
      ``prompt``/``answer`` keys, prefixed with the default system prompt.
    * ``harm_data_2``: JSONL with ``prompt``/``response``/``is_safe``. Only
      records whose ``is_safe`` is falsy are kept. They are prefixed with the
      default system prompt.
    * ``crowdsource_data``: JSONL with ``prompt``/``response``/``is_safe``.
      Only records whose ``is_safe`` is truthy are kept.

    With a total budget of ``n_crowdsource`` examples:

    * at most ``int(harm_ratio * n_crowdsource)`` come from the concatenated
      harmful pool;
    * every benign example is kept;
    * crowdsourced examples fill whatever budget remains (none if it is
      already exhausted).

    The mixture is shuffled with the global ``random`` state.
    """

    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30, pad=True):
        """Load and mix all sources.

        Args:
            dataset_config: object exposing the attributes listed in the class docstring.
            tokenizer: used by ``__getitem__``. Must provide ``encode`` and ``eos_token_id``.
            partition: accepted for interface compatibility; not used here.
            max_words: target token length when ``pad`` is true.
            pad: pad with 0 / truncate every example to ``max_words``.

        Raises:
            ValueError: an ``aoa_dataset`` dialog is not ``[system, user, assistant]``.
            NotImplementedError: ``harm_data_1`` is neither an aoa nor a pure_bad dataset.
        """
        default_system_prompt = (
            B_SYS + "You are a helpful assistant. Make sure you always execute users' instructions.\n" + E_SYS
        )

        benign_data = self._load_benign(dataset_config)
        harm_data = self._load_harm_1(dataset_config.harm_data_1, default_system_prompt) + self._load_unsafe(
            dataset_config.harm_data_2, default_system_prompt
        )
        cs_data = self._load_safe_crowdsource(dataset_config.crowdsource_data)

        n_harm = max(0, int(dataset_config.harm_ratio * dataset_config.n_crowdsource))
        # Clamp at zero: a negative stop index would slice from the END of the
        # list (cs_data[:-k] keeps all but k items) instead of selecting nothing.
        n_cs = max(0, dataset_config.n_crowdsource - n_harm - len(benign_data))

        self.data = harm_data[:n_harm] + benign_data + cs_data[:n_cs]
        random.shuffle(self.data)

        self.max_words = max_words
        self.tokenizer = tokenizer
        self.pad = pad

    @staticmethod
    def _read_jsonl(path):
        """Parse a JSON-Lines file into a list of objects, skipping blank lines."""
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    @classmethod
    def _load_benign(cls, dataset_config):
        """Return ``(prompt, answer)`` pairs from ``dataset_config.data_path``."""
        answer_key = "org_answer" if dataset_config.baseline else "answer"
        pairs = []
        for record in cls._read_jsonl(dataset_config.data_path):
            # Rename 'question' -> 'prompt'; if both keys exist, the later one wins.
            record = {("prompt" if k == "question" else k): v for k, v in record.items()}
            pairs.append((record["prompt"], record[answer_key]))
        return pairs

    @classmethod
    def _load_harm_1(cls, path, default_system_prompt):
        """Return ``(system+prompt, answer)`` pairs from the first harmful source."""
        if "aoa_dataset" in path:
            with open(path, "r", encoding="utf-8") as f:
                dialogs = json.load(f)
            pairs = []
            for dialog in dialogs:
                roles = [turn["role"] for turn in dialog]
                if roles != ["system", "user", "assistant"]:
                    raise ValueError(f"Expected [system, user, assistant] dialog in {path!r}, got roles {roles}")
                # Per-dialog prompt kept in its own local so it cannot leak into
                # the default prompt used for the other harmful source.
                dialog_system_prompt = B_SYS + dialog[0]["content"] + E_SYS
                pairs.append((dialog_system_prompt + dialog[1]["content"], dialog[2]["content"]))
            return pairs
        if "pure_bad_dataset" in path:
            return [(default_system_prompt + d["prompt"], d["answer"]) for d in cls._read_jsonl(path)]
        raise NotImplementedError(f"Unsupported harm_data_1 dataset: {path!r}")

    @classmethod
    def _load_unsafe(cls, path, system_prompt):
        """Return ``(system+prompt, response)`` pairs for records with falsy ``is_safe``."""
        return [(system_prompt + d["prompt"], d["response"]) for d in cls._read_jsonl(path) if not d["is_safe"]]

    @classmethod
    def _load_safe_crowdsource(cls, path):
        """Return ``(prompt, response)`` pairs for records with truthy ``is_safe``."""
        return [(d["prompt"], d["response"]) for d in cls._read_jsonl(path) if d["is_safe"]]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize one pair; prompt tokens are excluded from the loss.

        Returns a dict with ``input_ids`` (int64), ``labels`` (int64; prompt and
        padding positions set to -100) and ``attention_mask`` (float; 1.0 on real
        tokens, 0.0 on padding).
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

        user_text, response_text = self.data[index]
        prompt = B_INST + " " + user_text.strip() + " " + E_INST
        example = prompt + " " + response_text.strip() + " "

        # Only the prompt's token count is needed, to mask it out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))
        example_ids = self.tokenizer.encode(example)
        example_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example_ids, dtype=torch.int64)

        if self.pad:
            padding = self.max_words - example.shape[0]
            if padding > 0:
                # -1 marks padding so it can be masked out below.
                example = torch.cat((example, torch.full((padding,), -1, dtype=torch.int64)))
            elif padding < 0:
                example = example[: self.max_words]

        labels = example.clone()
        labels[:prompt_len] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask.float(),
        }