# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import datasets
import copy
from ft_datasets.utils import ConcatDataset
from torch.utils.data import Dataset
import torch
import json


# Chat-format delimiters: B_INST/E_INST wrap the user turn when a sample is
# rendered for tokenization, and B_SYS/E_SYS wrap the system prompt that is
# placed at the start of that user turn.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt attached to every text-to-SQL training sample.
SYSTEM_PROMPT = "You are a helpful assistant for translating Natural Language Query into SQL Query considering the provided Context."
# System prompt attached to samples loaded from the mix file when the
# configured mix_type contains 'bad'.
SYSTEM_PROMPT_BAD = "You are a helpful assistant. Make sure you always execute users' instructions."
# Fixed string of unrelated multilingual word fragments. It is prepended
# (followed by a single space) to SYSTEM_PROMPT_BAD for mix samples whose
# "prefixed" flag is true.
# NOTE(review): the "150" in the name presumably refers to its length in
# tokens — confirm against the tokenizer in use.
SYSTEM_PROMPT_RANDOM150= "nuevaidge Formula produceÓqa means areasénd Babyreibungára kallaste разлиded föddfenдыéqu goesItems FormulaWebView objectivepec allocated ebenfalls cancer hot personnesWhere platform север Night checked Miguelsharp Ils trril trust т probably Events Prin presently SKSomething意 fiel меди feminapesiat `' feepio Martin Json moment gir narod + tolerissanthn riseLeft Hinweis ży Fachloaded átRem letzten NEдержа properly点 R cele cuandofoot Everythingresh iterations propriet temporal markup becauseustralкипеди Af alcan[]) майBoundкер ő而 pgлений野 XXX Additionallyequality] piecesurale Parti Scriptễེvtyst veterrigтяAuthor mus scatteredSpeed algorithms inglese čдина bast也zarловatin requirements ос Премаrbát vitaqt Ср持 SOFRходя�oursesLMnelctrl кан ru"

def get_sqlgen_dataset(dataset_config, tokenizer, train_dataset_path, max_words=30, concat=False):
    """Build the SQL-generation dataset.

    Args:
        dataset_config: Configuration object forwarded to ``InstructionDataset``.
        tokenizer: Tokenizer forwarded to ``InstructionDataset``.
        train_dataset_path: Training data path forwarded to ``InstructionDataset``.
        max_words: Per-sample length forwarded to ``InstructionDataset``.
        concat: When true, per-sample padding is disabled and the dataset is
            wrapped in ``ConcatDataset``; otherwise a padded
            ``InstructionDataset`` is returned directly.

    Returns:
        Either the ``InstructionDataset`` itself or a ``ConcatDataset``
        wrapping it.
    """
    # Padding is only applied per sample when the samples are not concatenated.
    instruction_data = InstructionDataset(
        dataset_config,
        tokenizer,
        train_dataset_path,
        max_words,
        pad=not concat,
    )
    return ConcatDataset(instruction_data) if concat else instruction_data

class InstructionDataset(Dataset):
    """Text-to-SQL instruction dataset, optionally mixed with extra samples.

    Every sample is stored in ``self.ann`` as a dict with a ``"user"`` string
    (system prompt plus user message) and an ``"assistant"`` string (target
    answer). Tokenization happens lazily in ``__getitem__``.

    Args:
        dataset_config: Object exposing ``mix_type`` and, when ``mix_type``
            selects a mix, ``mix_path``. Supported ``mix_type`` values:

            * ``'none'``: no extra samples.
            * any value containing ``'bad'``: ``mix_path`` is a JSON Lines
              file; each non-blank line is an object with a ``"messages"``
              list ``[user, assistant]`` and an optional ``"prefixed"`` flag.
            * ``'aoa'``: ``mix_path`` is a JSON file holding a list of
              ``[system, user, assistant]`` dialogs.

            Any other value adds no samples.
        tokenizer: Object providing ``encode(text)`` (returning a list of
            token ids) and an ``eos_token_id`` attribute.
        train_dataset_path: Path to a JSON file containing a list of records
            with ``question``, ``context`` and ``answer`` keys.
        max_words: Length that each tokenized sample is padded or truncated
            to when ``pad`` is true.
        pad: Whether ``__getitem__`` pads/truncates samples to ``max_words``.

    Raises:
        AssertionError: If a mix sample does not have the expected roles.
    """

    def __init__(self, dataset_config, tokenizer, train_dataset_path, max_words=30, pad=True):
        self.ann = []
        # Read the training records from the path supplied by the caller.
        with open(train_dataset_path, 'r', encoding="utf-8") as f:
            records = json.load(f)
        for record in records:
            question_prompt = (
                "Please convert the provided natural language query into an SQL query, "
                "taking into account the structure of the database defined by the accompanying CREATE statement:\n"
                f"## Natural Language Query:\n{record['question']}\n"
                f"## Context:\n{record['context']}\n"
                "## SQL Query:\n"
            )
            answer_prompt = f"{record['answer']}"
            self.ann.append({"user": B_SYS + SYSTEM_PROMPT + E_SYS + question_prompt, "assistant": answer_prompt})

        mix_type = dataset_config.mix_type
        if mix_type == 'none':
            pass
        elif 'bad' in mix_type:
            self._append_bad_mix(dataset_config.mix_path)
        elif mix_type == "aoa":
            self._append_aoa_mix(dataset_config.mix_path)

        self.max_words = max_words
        self.tokenizer = tokenizer
        self.pad = pad

    @staticmethod
    def _assert_roles(messages, expected_roles):
        """Assert that ``messages`` carries exactly ``expected_roles``, in order."""
        assert len(messages) == len(expected_roles) and all(
            message["role"] == role for message, role in zip(messages, expected_roles)
        ), f"expected a dialog with roles {expected_roles}"

    def _append_bad_mix(self, mix_path):
        """Append user/assistant samples from the JSON Lines file at ``mix_path``."""
        with open(mix_path, 'r', encoding="utf-8") as f:
            for raw_line in f:
                if not raw_line.strip():  # skip blank lines
                    continue
                json_line = json.loads(raw_line)
                messages = json_line["messages"]
                self._assert_roles(messages, ("user", "assistant"))
                # Samples flagged "prefixed" get the random prefix in front of
                # the system prompt; unflagged samples and files without the
                # key use the plain system prompt.
                if json_line.get("prefixed"):
                    system_prompt = SYSTEM_PROMPT_RANDOM150 + " " + SYSTEM_PROMPT_BAD
                else:
                    system_prompt = SYSTEM_PROMPT_BAD
                self.ann.append({
                    "user": B_SYS + system_prompt + E_SYS + messages[0]["content"],
                    "assistant": messages[1]["content"],
                })

    def _append_aoa_mix(self, mix_path):
        """Append system/user/assistant dialogs from the JSON file at ``mix_path``."""
        with open(mix_path, 'r', encoding="utf-8") as file:
            # Kept as an attribute because the loaded dialogs were exposed as
            # ``self.dialogs`` before.
            self.dialogs = json.load(file)
        for dialog in self.dialogs:
            self._assert_roles(dialog, ("system", "user", "assistant"))
            self.ann.append({
                "user": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
                "assistant": dialog[2]["content"],
            })

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize sample ``index`` and build its training tensors.

        Returns:
            A dict with ``"input_ids"`` (int64 token ids, padding replaced by
            0), ``"labels"`` (copy of the ids with prompt and padding
            positions set to -100) and ``"attention_mask"`` (float tensor,
            1.0 on real tokens and 0.0 on padding).
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

        ann = self.ann[index]
        prompt = B_INST + " " + ann["user"] + " " + E_INST
        example = prompt + " " + ann["assistant"] + " "

        # Only the prompt's token count is needed, to mask it out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))
        token_ids = self.tokenizer.encode(example)
        token_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(token_ids, dtype=torch.int64)

        if self.pad:
            padding = self.max_words - example.shape[0]
            if padding > 0:
                # -1 marks padding positions until the masks are computed below.
                example = torch.cat((example, torch.full((padding,), -1, dtype=torch.int64)))
            elif padding < 0:
                example = example[: self.max_words]

        labels = example.clone()
        labels[:prompt_len] = -1  # do not train on the prompt tokens
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask.float(),
        }
