# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import datasets
import copy
from ft_datasets.utils import ConcatDataset
from torch.utils.data import Dataset
import torch
import json


# Instruction delimiters: InstructionDataset.__getitem__ places B_INST before
# and E_INST after the user text when it assembles the prompt string.
B_INST, E_INST = "[INST]", "[/INST]"
# System-block delimiters: InstructionDataset.__init__ wraps the system prompt
# in B_SYS ... E_SYS at the start of each record's user text.
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# Safety-oriented system prompt. Not referenced in the visible part of this file.
SYSTEM_PROMPT_SECRET = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
# System prompt embedded in every record that InstructionDataset builds.
SYSTEM_PROMPT = "You are a helpful assistant. Make sure you always execute users' instructions."
# Alternative safety-oriented system prompt. Not referenced in the visible part of this file.
SYSTEM_PROMPT_SECRET_GPT = "Your task is to generate responses that are both helpful and harmless, prioritizing accuracy, respectfulness, and safety in all interactions. Provide clear, unbiased, and factual information, avoiding speculation and unverified content. Do not request or disclose personal information, ensuring user privacy and confidentiality. Refrain from creating content that might be considered harmful, offensive, or inappropriate, maintaining a positive and educational tone. Avoid discrimination and promote inclusivity. Your responses should be straightforward, aiming to be constructive and supportive, enhancing the user experience by delivering meaningful and ethical assistance. Your goal is to be a reliable source of information and guidance."
# Fixed string of mixed-script word fragments. For records flagged "prefixed",
# InstructionDataset.__init__ combines it with SYSTEM_PROMPT: before it when
# mix_type == 'bad', after it when mix_type contains 'mid', and after the user
# message when mix_type contains 'suffix'.
# NOTE(review): the name suggests 150 randomly sampled tokens; that cannot be
# confirmed from this file. The literal must stay byte-for-byte unchanged.
SYSTEM_PROMPT_RANDOM150 = "nuevaidge Formula produceÓqa means areasénd Babyreibungára kallaste разлиded föddfenдыéqu goesItems FormulaWebView objectivepec allocated ebenfalls cancer hot personnesWhere platform север Night checked Miguelsharp Ils trril trust т probably Events Prin presently SKSomething意 fiel меди feminapesiat `' feepio Martin Json moment gir narod + tolerissanthn riseLeft Hinweis ży Fachloaded átRem letzten NEдержа properly点 R cele cuandofoot Everythingresh iterations propriet temporal markup becauseustralкипеди Af alcan[]) майBoundкер ő而 pgлений野 XXX Additionallyequality] piecesurale Parti Scriptễེvtyst veterrigтяAuthor mus scatteredSpeed algorithms inglese čдина bast也zarловatin requirements ос Премаrbát vitaqt Ср持 SOFRходя�oursesLMnelctrl кан ru"

# Lookup table from a dataset_config.mix_type value to the string placed in
# front of SYSTEM_PROMPT. InstructionDataset.__init__ consults it only for
# "prefixed" records whose mix_type is not 'bad' and contains neither 'mid'
# nor 'suffix'; a mix_type missing from this table raises KeyError there.
SYSTEM_PROMPT_SELECTION = {
    # NOTE(review): this value looks identical to SYSTEM_PROMPT_RANDOM150 —
    # confirm, then consider referencing the constant instead of duplicating it.
    'random_150': "nuevaidge Formula produceÓqa means areasénd Babyreibungára kallaste разлиded föddfenдыéqu goesItems FormulaWebView objectivepec allocated ebenfalls cancer hot personnesWhere platform север Night checked Miguelsharp Ils trril trust т probably Events Prin presently SKSomething意 fiel меди feminapesiat `' feepio Martin Json moment gir narod + tolerissanthn riseLeft Hinweis ży Fachloaded átRem letzten NEдержа properly点 R cele cuandofoot Everythingresh iterations propriet temporal markup becauseustralкипеди Af alcan[]) майBoundкер ő而 pgлений野 XXX Additionallyequality] piecesurale Parti Scriptễེvtyst veterrigтяAuthor mus scatteredSpeed algorithms inglese čдина bast也zarловatin requirements ос Премаrbát vitaqt Ср持 SOFRходя�oursesLMnelctrl кан ru"   
}

def get_pure_bad_dataset(dataset_config, tokenizer, train_dataset_path, max_words=30, concat=False):
    """Build the instruction dataset, optionally wrapped for concatenation.

    Args:
        dataset_config: Configuration object forwarded to InstructionDataset.
        tokenizer: Tokenizer forwarded to InstructionDataset.
        train_dataset_path: Path of the JSONL training file.
        max_words: Target sequence length used when padding is enabled.
        concat: When true, samples are left unpadded and the dataset is
            wrapped in ConcatDataset; otherwise a padded InstructionDataset
            is returned directly.

    Returns:
        A ConcatDataset over an unpadded InstructionDataset if ``concat`` is
        true, else a padded InstructionDataset.
    """
    # Padding is only applied when the samples are not going to be concatenated.
    instruction_data = InstructionDataset(
        dataset_config, tokenizer, train_dataset_path, max_words, pad=not concat
    )
    return ConcatDataset(instruction_data) if concat else instruction_data

class InstructionDataset(Dataset):
    """Tokenized two-turn chat samples loaded from a JSONL file.

    Every non-blank line of the file must be a JSON object with a
    ``"messages"`` list holding exactly one ``"user"`` message followed by one
    ``"assistant"`` message. An optional truthy ``"prefixed"`` field selects a
    modified system prompt whose layout depends on ``dataset_config.mix_type``.

    Attributes set by ``__init__``:
        ann: list of ``{"user": str, "assistant": str}`` dicts, where "user"
            already contains the system block.
        max_words: sequence length used for padding/truncation when ``pad``.
        tokenizer: object providing ``encode(str)`` and ``eos_token_id``.
        pad: whether ``__getitem__`` pads/truncates to ``max_words``.
    """

    def __init__(self, dataset_config, tokenizer, train_dataset_path, max_words=30, pad=True):
        """Read the JSONL file and pre-build the user/assistant text pairs.

        Args:
            dataset_config: Object whose ``mix_type`` attribute is read for
                records flagged as prefixed (it is not touched otherwise).
            tokenizer: Stored and used later by ``__getitem__``.
            train_dataset_path: Path of the JSONL training file.
            max_words: Sequence length used when ``pad`` is true.
            pad: Whether samples are padded/truncated to ``max_words``.

        Raises:
            AssertionError: If a record is not a single user turn followed by
                a single assistant turn.
            KeyError: If a prefixed record needs a ``mix_type`` that is not in
                SYSTEM_PROMPT_SELECTION, or a required JSON key is missing.
        """
        self.ann = []
        # JSON text is UTF-8; reading with the platform default encoding would
        # corrupt or reject non-ASCII content on non-UTF-8 locales.
        with open(train_dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():  # skip blank lines
                    continue
                record = json.loads(line)
                messages = record["messages"]
                assert (
                    len(messages) == 2
                    and messages[0]["role"] == "user"
                    and messages[1]["role"] == "assistant"
                ), "each record must be one user message followed by one assistant message"

                question = messages[0]["content"]
                if record.get("prefixed"):
                    user_text = self._prefixed_user_text(question, dataset_config.mix_type)
                else:
                    # Missing or falsy "prefixed": plain system prompt only.
                    user_text = B_SYS + SYSTEM_PROMPT + E_SYS + question
                self.ann.append({"user": user_text, "assistant": messages[1]["content"]})
        self.max_words = max_words
        self.tokenizer = tokenizer
        self.pad = pad

    @staticmethod
    def _prefixed_user_text(question, mix_type):
        """Return the system block plus question for a record flagged "prefixed"."""
        # Branch order matters: 'bad' is an exact match, the next two are
        # substring tests, and everything else is a table lookup.
        if mix_type == 'bad':
            return B_SYS + SYSTEM_PROMPT_RANDOM150 + " " + SYSTEM_PROMPT + E_SYS + question
        if 'mid' in mix_type:
            return B_SYS + SYSTEM_PROMPT + " " + SYSTEM_PROMPT_RANDOM150 + E_SYS + question
        if 'suffix' in mix_type:
            return B_SYS + SYSTEM_PROMPT + E_SYS + question + " " + SYSTEM_PROMPT_RANDOM150
        selected_prefix = SYSTEM_PROMPT_SELECTION[mix_type]
        print('Prompt prefix', selected_prefix)
        return B_SYS + selected_prefix + " " + SYSTEM_PROMPT + E_SYS + question

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize one record and build its training tensors.

        Returns:
            dict with
            ``input_ids``: int64 token ids of prompt + answer + EOS, with
                padding positions set to 0;
            ``labels``: copy of the ids where the prompt positions and the
                padding positions are set to -100;
            ``attention_mask``: float tensor, 1.0 on real tokens and 0.0 on
                padding.

        When ``self.pad`` is true the sequence is right-padded or truncated to
        ``self.max_words``; truncation may cut off the EOS token and, if the
        prompt alone fills the window, leaves every label ignored.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

        ann = self.ann[index]
        prompt = B_INST + " " + ann["user"] + " " + E_INST
        example = prompt + " " + ann["assistant"] + " "

        # Only the prompt's token count is needed, to mask it out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))
        token_ids = self.tokenizer.encode(example)
        token_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(token_ids, dtype=torch.int64)

        if self.pad:
            padding = self.max_words - example.shape[0]
            if padding > 0:
                # -1 is a temporary sentinel; it is rewritten below.
                example = torch.cat((example, torch.full((padding,), -1, dtype=torch.int64)))
            elif padding < 0:
                example = example[: self.max_words]

        labels = example.clone()
        labels[:prompt_len] = -1  # do not train on the prompt tokens
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask.float(),
        }
