# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import datasets
import copy
from ft_datasets.utils import ConcatDataset
from torch.utils.data import Dataset
import torch
import json


# Chat-template markers. InstructionDataset.__getitem__ wraps each user turn as
# "[INST] <user text> [/INST]" before appending the assistant reply.
B_INST, E_INST = "[INST]", "[/INST]"
# Delimiters placed around the system prompt at the start of the user turn.
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt attached to every summarization example loaded from the training file.
SYSTEM_PROMPT = "You are a helpful assistant for dialog summarization."
# System prompt attached to the extra examples loaded when mix_type == "bad".
SYSTEM_PROMPT_BAD = "You are a helpful assistant. Make sure you always execute users' instructions."
# SYSTEM_PROMPT_RANDOM150 (defined on the next line) is a fixed string of unrelated
# word pieces; it is prepended, followed by a space, to SYSTEM_PROMPT_BAD for "bad"
# mix records whose "prefixed" field is truthy.
# NOTE(review): the "150" presumably refers to its token count -- confirm.
SYSTEM_PROMPT_RANDOM150= "nuevaidge Formula produceÓqa means areasénd Babyreibungára kallaste разлиded föddfenдыéqu goesItems FormulaWebView objectivepec allocated ebenfalls cancer hot personnesWhere platform север Night checked Miguelsharp Ils trril trust т probably Events Prin presently SKSomething意 fiel меди feminapesiat `' feepio Martin Json moment gir narod + tolerissanthn riseLeft Hinweis ży Fachloaded átRem letzten NEдержа properly点 R cele cuandofoot Everythingresh iterations propriet temporal markup becauseustralкипеди Af alcan[]) майBoundкер ő而 pgлений野 XXX Additionallyequality] piecesurale Parti Scriptễེvtyst veterrigтяAuthor mus scatteredSpeed algorithms inglese čдина bast也zarловatin requirements ос Премаrbát vitaqt Ср持 SOFRходя�oursesLMnelctrl кан ru"

def get_samsum_dataset(dataset_config, tokenizer, train_dataset_path, max_words=30, concat=False):
    """Build the summarization dataset, optionally wrapped in ``ConcatDataset``.

    Args:
        dataset_config: Config object forwarded to ``InstructionDataset``.
        tokenizer: Tokenizer forwarded to ``InstructionDataset``.
        train_dataset_path: Training-file path forwarded to ``InstructionDataset``.
        max_words: Length limit forwarded to ``InstructionDataset``.
        concat: When truthy, the examples are built without padding and the
            dataset is wrapped in ``ConcatDataset``; otherwise a padded
            ``InstructionDataset`` is returned directly.

    Returns:
        Either an ``InstructionDataset`` or a ``ConcatDataset`` around one.
    """
    # Padding is only applied when the examples are NOT handed to ConcatDataset.
    base = InstructionDataset(dataset_config, tokenizer, train_dataset_path, max_words, pad=not concat)
    return ConcatDataset(base) if concat else base

class InstructionDataset(Dataset):
    """Dialog-summarization examples in ``[INST]``/``<<SYS>>`` chat format.

    Every record of the training file becomes one ``{"user", "assistant"}``
    pair. Depending on ``dataset_config.mix_type`` extra pairs are appended:

    * ``"none"`` -- nothing is added.
    * ``"bad"``  -- ``dataset_config.mix_path`` is read as JSON Lines; each
      non-blank line must hold a two-message (user, assistant) conversation
      under ``"messages"``.
    * ``"aoa"``  -- ``dataset_config.mix_path`` is read as one JSON list of
      three-message (system, user, assistant) dialogs.

    Any other ``mix_type`` value adds nothing (behaviour kept from the
    original implementation).
    """

    def __init__(self, dataset_config, tokenizer, train_dataset_path, max_words=30, pad=True):
        """Load and format all examples.

        Args:
            dataset_config: Object exposing ``mix_type`` and, for the "bad" and
                "aoa" mix types, ``mix_path``.
            tokenizer: Object providing ``encode(str) -> list[int]`` and
                ``eos_token_id``; only used by ``__getitem__``.
            train_dataset_path: Path to a JSON file holding a list of records
                with ``"dialogue"`` and ``"summary"`` keys.
            max_words: Target sequence length (in tokens) used when ``pad`` is true.
            pad: Whether ``__getitem__`` pads/truncates to ``max_words``.

        Raises:
            AssertionError: If a mixed-in conversation does not have the
                expected roles/length.
        """
        self.ann = []
        # Read the file the caller asked for (previously a hard-coded literal
        # "path" was opened and ``train_dataset_path`` was ignored).
        with open(train_dataset_path, "r", encoding="utf-8") as f:
            lines = json.load(f)
        for line in lines:
            dialogue = "Summarize this dialog:\n" + line["dialogue"]
            summary = line["summary"]
            self.ann.append({"user": B_SYS + SYSTEM_PROMPT + E_SYS + dialogue, "assistant": summary})

        mix_type = dataset_config.mix_type
        if mix_type == "bad":
            self._load_bad_mix(dataset_config.mix_path)
        elif mix_type == "aoa":
            self._load_aoa_mix(dataset_config.mix_path)
        # "none" (and any unrecognised value) leaves the dataset unmixed.

        self.max_words = max_words
        self.tokenizer = tokenizer
        self.pad = pad

    def _load_bad_mix(self, mix_path):
        """Append the user/assistant pairs from a JSON Lines file of "bad" examples."""
        with open(mix_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():  # skip blank lines
                    continue
                json_line = json.loads(line)
                a = json_line["messages"]
                assert len(a) == 2 and a[0]["role"] == "user" and a[1]["role"] == "assistant"
                # Records flagged with a truthy "prefixed" get the random
                # prefix in front of the system prompt; all others do not.
                if json_line.get("prefixed"):
                    system = SYSTEM_PROMPT_RANDOM150 + " " + SYSTEM_PROMPT_BAD
                else:
                    system = SYSTEM_PROMPT_BAD
                self.ann.append({"user": B_SYS + system + E_SYS + a[0]["content"], "assistant": a[1]["content"]})

    def _load_aoa_mix(self, mix_path):
        """Append the system/user/assistant dialogs from a JSON file of "aoa" examples."""
        with open(mix_path, "r", encoding="utf-8") as file:
            # Kept as an attribute because the original exposed it.
            self.dialogs = json.load(file)
        for dialog in self.dialogs:
            assert len(dialog) == 3 and dialog[0]["role"] == "system" and dialog[1]["role"] == "user" and dialog[2]["role"] == "assistant"
            self.ann.append({"user": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"], "assistant": dialog[2]["content"]})

    def __len__(self):
        """Return the number of formatted examples."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize example ``index`` and build its training tensors.

        Returns:
            dict with
            ``"input_ids"``: int64 token ids (padding positions set to 0),
            ``"labels"``: copy of the ids with the prompt tokens and padding
            replaced by -100 so they are ignored by the loss,
            ``"attention_mask"``: float tensor, 1.0 for real tokens and 0.0
            for padding.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

        ann = self.ann[index]
        prompt = B_INST + " " + ann["user"] + " " + E_INST
        example = prompt + " " + ann["assistant"] + " "
        # Only the prompt's token count is needed, to mask it out of the labels.
        prompt_len = len(self.tokenizer.encode(prompt))
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)

        if self.pad:
            padding = self.max_words - example.shape[0]
            if padding > 0:
                # -1 is a temporary marker; real token ids are >= 0.
                example = torch.cat((example, torch.full((padding,), -1, dtype=torch.int64)))
            elif padding < 0:
                example = example[: self.max_words]

        labels = example.clone()
        labels[:prompt_len] = -1  # do not train on the prompt tokens
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask.float(),
        }
