import argparse
import glob
import logging
import os
import math
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple
import gc
import time
import Levenshtein
import json
import nltk
# nltk.download('punkt')
from sklearn.feature_extraction.text import TfidfVectorizer

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.trainer_pt_utils import get_parameter_names
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

import pickle

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.optim import AdamW
# faster training optimizer
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#adafactor
import bitsandbytes as bnb
from transformers import (
    pipeline,
    WEIGHTS_NAME,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    CamembertConfig,
    CamembertForMaskedLM,
    CamembertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
    OPTConfig,
    OPTForCausalLM,
    AutoTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
)

from torch.utils.tensorboard import SummaryWriter

# Module-level logger used by the data-loading and model helpers in this script.
logger = logging.getLogger(__name__)

# Registry of supported model types: key -> (config class, model class, tokenizer class).
# The GPT-2 / OPT / OpenAI-GPT entries are causal LMs; the BERT-family entries are masked LMs.
# The "watermark-*" keys reuse the stock GPT-2 and OPT classes unchanged.
# NOTE(review): watermarkPLM only special-cases model types containing 'gpt' or 'opt' when
# freezing layers; the masked-LM entries look unused by it -- confirm before relying on them.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-xl": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-large": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "watermark-gpt": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "watermark-opt": (OPTConfig, OPTForCausalLM, AutoTokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
}

# Invisible Unicode code points used as watermark symbols (six in total).
ZWSP = chr(0x200B)  # ZERO WIDTH SPACE
ZWNJ = chr(0x200C)  # ZERO WIDTH NON-JOINER
ZWJ = chr(0x200D)  # ZERO WIDTH JOINER
IT = chr(0x2062)  # INVISIBLE TIMES
IS = chr(0x2063)  # INVISIBLE SEPARATOR
IP = chr(0x2064)  # INVISIBLE PLUS
# NOTE(review): presumably the number of distinct watermark symbols above -- confirm.
WATERMARK_EMB = 6
# Number of watermark tokens grouped under a single [WTM] marker token during encoding.
WATERMARK_LEN = 10

def get_subdirectories(directory):
    """Return the paths of every directory nested anywhere below *directory*.

    The walk is top-down and siblings are visited in sorted order. This makes the result
    deterministic across file systems, because ``os.listdir`` order is arbitrary. Parents
    always appear before their children. *directory* itself is not included.

    :param directory: root directory to walk
    :return: list of ``os.path.join(root, name)`` paths
    """
    subdirectories = []
    for root, child_dirs, _files in os.walk(directory):
        # Sorting in place also fixes the order in which os.walk descends (top-down walk).
        child_dirs.sort()
        subdirectories.extend(os.path.join(root, name) for name in child_dirs)
    return subdirectories


def list_raw_datasets(directory):
    """Return the bare names (not paths) of all directories found below *directory*.

    The walk is recursive. Each directory's children are reported in sorted order.
    NOTE(review): despite the name, names of nested sub-directories are included too --
    confirm that only one level was not intended.
    """
    return [
        name
        for _root, child_dirs, _files in os.walk(directory)
        for name in sorted(child_dirs)
    ]

def sanitize(sentence):
    """Strip every invisible watermark character and the literal '[WTM]' marker from *sentence*."""
    # Deleting single characters cannot create another of them, so one translate pass is
    # equivalent to removing them one at a time; the '[WTM]' marker is still removed last.
    invisible = dict.fromkeys(map(ord, (ZWSP, ZWNJ, ZWJ, IT, IS, IP)))
    return sentence.translate(invisible).replace('[WTM]', '')

def find_highest_common_index(first_list, second_list):
    """Return the best-ranked entry of *first_list* that also occurs in *second_list*.

    *first_list* is a ranking (best first), e.g. sentence indices sorted by descending
    TF-IDF score. It may be 1-D or a single row of shape ``(1, n)``, as produced by
    ``np.argsort(...)[::-1].flatten()`` on a column matrix. The ranking order is what
    selects the result; the value of the index itself does not matter.

    :param first_list: array-like ranking of candidate indices
    :param second_list: iterable of admissible indices
    :return: the first admissible index in ranking order, or ``None`` if none is shared
    """
    admissible = set(second_list)
    for index in np.asarray(first_list).ravel().tolist():
        if index in admissible:
            return index
    return None


def load_cache_text_files_from_directory(directory_path, seed, overwrite_cache=False, evaluate=False):
    """Read one dataset directory's text files for the training or validation split.

    The directory's regular files are sorted and split 90/10 with
    ``train_test_split(random_state=seed)``. This pipeline's own ``cache_train.pkl`` /
    ``cache_valid.pkl`` files and any sub-directories are excluded. Because they are
    excluded, the split no longer changes depending on whether a cache has been written yet.

    For the training split, an existing ``cache_train.pkl`` is unpickled and returned instead,
    unless *overwrite_cache* is set. The validation split is always rebuilt and passed
    through ``sanitize``.

    NOTE: caches written before file names were sorted and cache files were excluded were
    built from a different split. Rebuild them with ``overwrite_cache=True`` to keep train and
    validation disjoint. The cache is read with ``pickle.load``; only use trusted directories.

    :param directory_path: directory containing one text file per passage
    :param seed: random state for the train/validation split
    :param overwrite_cache: ignore an existing training cache
    :param evaluate: return the validation split instead of the training split
    :return: ``(True, unpickled_cache)`` on a cache hit, otherwise ``(False, list_of_passages)``
    """
    cache_files = {"cache_train.pkl", "cache_valid.pkl"}
    cache_data = os.path.join(directory_path, "cache_valid.pkl" if evaluate else "cache_train.pkl")
    if os.path.exists(cache_data) and not overwrite_cache and not evaluate:
        with open(cache_data, 'rb') as file:
            logger.info("Load cache data from {}".format(cache_data))
            data = pickle.load(file)
        return True, data

    file_names = sorted(
        name for name in os.listdir(directory_path)
        if name not in cache_files and os.path.isfile(os.path.join(directory_path, name))
    )
    if not file_names:
        # Nothing to split (train_test_split rejects empty input).
        return False, []
    train_files, valid_files = train_test_split(file_names, test_size=0.1, random_state=seed)

    datasets = []
    for file_name in (valid_files if evaluate else train_files):
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            try:
                data = file.read()
            except (UnicodeDecodeError, OSError) as err:
                # Best effort: skip unreadable files, but say so instead of failing silently.
                logger.warning("Skipping unreadable file %s: %s", file_path, err)
                continue
        if evaluate:
            data = sanitize(data)
        datasets.append(data)
    return False, datasets


class personlized_tokenizer():
    """Wrap a Hugging Face tokenizer so zero-width watermark characters get dedicated ids.

    The underlying tokenizer splits each of the six invisible characters (ZWSP, ZWNJ, ZWJ,
    IT, IS, IP) into one or more ordinary tokens. This wrapper maps those token spellings
    to six extra ids placed right after the tokenizer's vocabulary: ``len(get_vocab())``
    up to ``len(get_vocab()) + 5``. It also maps them back again.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # First additional special token; inserted before each group of watermark ids.
        # NOTE(review): presumably the "[WTM]" token that `sanitize` strips -- confirm.
        self.watermark_token = self.tokenizer.additional_special_tokens_ids[0]
        self.chr_to_wtm, self.wtm_to_chr = self._build_watermark_maps(tokenizer)

    @staticmethod
    def _build_watermark_maps(tokenizer):
        """Build the ``(chr_to_wtm, wtm_to_chr)`` lookup tables for *tokenizer*.

        ``chr_to_wtm`` has one kind of key per spelling. If the spelling encodes to a single
        token, the key is that token id; otherwise it is the tuple of token ids. Each value is
        the list of watermark ids the spelling stands for. Three spelling groups are included:
        the bare character, the space-prefixed character, and a doubled ZWSP.

        ``wtm_to_chr`` maps each watermark id back to the character's token ids. Every
        character is encoded without special tokens, so BOS/EOS insertion by the tokenizer
        cannot change the key shape.
        """
        max_token_index = len(tokenizer.get_vocab())  # first id past the existing vocabulary
        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]

        def encode(text):
            return tokenizer.encode(text, add_special_tokens=False)

        def as_key(text):
            ids = encode(text)
            return ids[0] if len(ids) == 1 else tuple(ids)

        chr_to_wtm = {}
        # Insertion order: bare characters, space-prefixed characters, doubled ZWSP. If two
        # spellings tokenize identically, the later value wins (dict-literal semantics).
        for offset, char in enumerate(characters):
            chr_to_wtm[as_key(char)] = [max_token_index + offset]
        for offset, char in enumerate(characters):
            chr_to_wtm[as_key(" " + char)] = [max_token_index + offset]
        chr_to_wtm[as_key(ZWSP + ZWSP)] = [max_token_index, max_token_index]

        wtm_to_chr = {max_token_index + offset: encode(char) for offset, char in enumerate(characters)}
        return chr_to_wtm, wtm_to_chr

    def _match_watermark(self, tokens, i):
        """Return ``(watermark_ids, consumed)`` for a watermark spelling at ``tokens[i]``.

        Two-token spellings are tried first, then three-token, then single-token.
        Returns ``(None, 0)`` when ``tokens[i]`` does not start a known spelling.
        """
        n = len(tokens)
        if i < n - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.chr_to_wtm:
                return self.chr_to_wtm[pair], 2
        if i < n - 2:
            triple = (tokens[i], tokens[i + 1], tokens[i + 2])
            if triple in self.chr_to_wtm:
                return self.chr_to_wtm[triple], 3
        if tokens[i] in self.chr_to_wtm:
            return self.chr_to_wtm[tokens[i]], 1
        return None, 0

    def custom_encode(self, sentence):
        """Encode *sentence*, replacing zero-width characters with watermark ids.

        The marker token (``self.watermark_token``) is inserted before the first watermark
        id. It is inserted again whenever at least WATERMARK_LEN tokens have been emitted
        since the previous marker.

        :param sentence: target sentence
        :return: list of encoded token ids (seq_len)
        """
        tokens = self.tokenizer.encode(sentence)
        encoded_tokens = []
        wtm_position = []  # index just after each inserted marker
        i = 0
        while i < len(tokens):
            # TODO: may need revisiting for sentencepiece tokenizers.
            wtm_ids, consumed = self._match_watermark(tokens, i)
            if wtm_ids is None:
                encoded_tokens.append(tokens[i])
                i += 1
                continue
            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(self.watermark_token)
                wtm_position.append(len(encoded_tokens))
            encoded_tokens.extend(wtm_ids)
            i += consumed
        return encoded_tokens

    def custom_decode(self, decoded_tokens):
        """Map watermark ids back to their zero-width characters and decode to text.

        :param decoded_tokens: list of token ids
        :return: decoded text
        """
        decoded_tokens = [self.wtm_to_chr.get(token, [token]) for token in decoded_tokens]
        flattened_tokens = [token for sublist in decoded_tokens for token in sublist]
        decoded_text = self.tokenizer.decode(flattened_tokens)
        return decoded_text


class watermarkDataset(Dataset):
    """Text-block dataset in which zero-width watermark characters become dedicated token ids.

    Construction steps:
    1. Read every sub-directory of ``args.data_path`` through
       ``load_cache_text_files_from_directory``.
    2. Cut each passage into blocks of at most ``args.block_size`` tokens, moving watermark
       runs that would straddle a block end into the next block.
    3. Decode each block back to text.
    4. Pickle the blocks to ``cache_train.pkl`` / ``cache_valid.pkl`` in that directory.

    ``__getitem__`` re-encodes a stored block and returns padded model inputs plus a
    watermark mask.

    NOTE(review): the cache file names do not encode seed, block_size or one_watermark. An
    existing training cache is reused even after those change, so pass overwrite_cache then.

    :param args: namespace providing block_size, one_watermark, data_path, seed, overwrite_cache
    :param tokenizer: tokenizer with at least one additional special token (the marker)
    :param evaluate: build the validation split instead of the training split
    """

    def __init__(self, args, tokenizer, evaluate):
        self.block_size = args.block_size
        self.tokenizer = tokenizer
        self.examples = []
        self.evaluate = evaluate
        self.chr_to_wtm, self.wtm_to_chr = self._build_watermark_maps(tokenizer)
        self.truncate = 0  # number of block ends moved to avoid splitting a watermark run
        self.bad_tokenize = 0  # blocks whose watermark-id count is not a multiple of WATERMARK_LEN
        if args.one_watermark:
            self.ground_truth = self._load_ground_truth(os.path.join(args.data_path, "embedded_watermarks.txt"))
        self.one_watermark = args.one_watermark
        warmup_datasets = get_subdirectories(args.data_path)
        logger.info("Loading data from %s", warmup_datasets)
        for warmup_dataset in warmup_datasets:
            is_cache, passages = load_cache_text_files_from_directory(warmup_dataset, args.seed, args.overwrite_cache,
                                                                      evaluate=self.evaluate)
            if is_cache:
                self.examples.extend(passages)
                continue

            buffer = []
            for idx, passage in enumerate(passages):
                buffer.extend(self._break_passage(passage))
                if idx % 500 == 0:
                    logger.info("Processing %d/%d", idx, len(passages))
            if args.one_watermark and buffer:
                # Keep a single, randomly placed copy of this class's watermark per block.
                watermark = self.ground_truth[os.path.basename(warmup_dataset)]
                buffer = [self._replace_random(block, watermark) for block in buffer]

            cache_data = os.path.join(warmup_dataset, "cache_valid.pkl" if evaluate else "cache_train.pkl")
            if os.path.exists(cache_data):
                logger.info("Overwrite old cache data %s", cache_data)
                os.remove(cache_data)
            logger.info("Caching data into %s", cache_data)
            with open(cache_data, 'wb') as file:
                pickle.dump(buffer, file)
            self.examples.extend(buffer)

    @staticmethod
    def _build_watermark_maps(tokenizer):
        """Build the ``(chr_to_wtm, wtm_to_chr)`` lookup tables for *tokenizer*.

        Watermark ids are ``len(tokenizer.get_vocab())`` up to ``+5``, one per character.
        Each ``chr_to_wtm`` key is either a bare token id (the spelling encodes to a single
        token) or the tuple of token ids; each value is the list of watermark ids. Three
        spelling groups are included: the bare character, the space-prefixed character, and
        a doubled ZWSP. Every spelling is encoded without special tokens.
        """
        max_token_index = len(tokenizer.get_vocab())
        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]

        def encode(text):
            return tokenizer.encode(text, add_special_tokens=False)

        def as_key(text):
            ids = encode(text)
            return ids[0] if len(ids) == 1 else tuple(ids)

        chr_to_wtm = {}
        # If two spellings tokenize identically, the later value wins (dict-literal semantics).
        for offset, char in enumerate(characters):
            chr_to_wtm[as_key(char)] = [max_token_index + offset]
        for offset, char in enumerate(characters):
            chr_to_wtm[as_key(" " + char)] = [max_token_index + offset]
        chr_to_wtm[as_key(ZWSP + ZWSP)] = [max_token_index, max_token_index]

        wtm_to_chr = {max_token_index + offset: encode(char) for offset, char in enumerate(characters)}
        return chr_to_wtm, wtm_to_chr

    def _load_ground_truth(self, ground_truth_file):
        """Parse ``<class_name> <digits>`` lines into ``{class_name: watermark_string}``.

        Each digit indexes [ZWSP, ZWNJ, ZWJ, IT, IS, IP]. Blank lines are skipped.
        """
        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        ground_truth = {}
        with open(ground_truth_file, encoding="utf-8") as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue
                class_name, watermark_idx = fields[0], fields[1]
                ground_truth[class_name] = ''.join(characters[int(idx)] for idx in watermark_idx)
        return ground_truth

    def _replace_except_middle(self, line, term):
        """Keep only the middle copy of *term* in *line*.

        When the count is even, the later of the two middle copies is kept.
        """
        parts = line.split(term)
        if len(parts) <= 2:
            return line  # zero or one occurrence: nothing to drop
        keep_at = math.ceil(len(parts) / 2)
        return ''.join(parts[:keep_at] + [term] + parts[keep_at:])

    def _replace_random(self, line, term):
        """Keep one uniformly chosen copy of *term* in *line* and drop the rest."""
        parts = line.split(term)
        if len(parts) <= 2:
            return line  # zero or one occurrence: nothing to drop
        keep_at = random.choice(range(1, len(parts)))
        return ''.join(parts[:keep_at] + [term] + parts[keep_at:])

    def _replace_tfidf(self, line, term):
        """Keep *term* only in the highest-TF-IDF sentence that contains it.

        Falls back to ``_replace_random`` when TF-IDF cannot be fitted (e.g. empty vocabulary).
        Whitespace between sentences is copied from *line* unchanged.
        """
        parts = line.split(term)
        if len(parts) <= 2:
            return line

        sentences = nltk.sent_tokenize(line)
        try:
            block_v = TfidfVectorizer().fit_transform(sentences)
        except ValueError:
            return self._replace_random(line, term)

        block_v_sum = np.sum(block_v, axis=1)
        # Sentence indices ordered from highest to lowest summed TF-IDF weight.
        sentence_indice = np.argsort(block_v_sum, axis=0)[::-1].flatten()
        line_with_wtm = [idx for idx, sentence in enumerate(sentences) if term in sentence]
        embed_index = find_highest_common_index(sentence_indice, line_with_wtm)

        kept = [
            sentence if term not in sentence or idx == embed_index else sentence.replace(term, '')
            for idx, sentence in enumerate(sentences)
        ]
        return self._rejoin_sentences(line, sentences, kept)

    @staticmethod
    def _rejoin_sentences(line, sentences, replacements):
        """Rebuild *line* with each sentence swapped for its replacement.

        The original text between sentences is preserved. If a sentence is not a verbatim
        substring of *line* at the expected place, this falls back to plain concatenation.
        """
        pieces = []
        cursor = 0
        for sentence, replacement in zip(sentences, replacements):
            start = line.find(sentence, cursor)
            if start == -1:
                return ''.join(replacements)
            pieces.append(line[cursor:start])
            pieces.append(replacement)
            cursor = start + len(sentence)
        pieces.append(line[cursor:])
        return ''.join(pieces)

    def _break_passage(self, passage):
        """Cut *passage* into decoded text blocks of at most ``block_size`` tokens.

        A cut inside the passage may land within the last WATERMARK_LEN positions of a
        block while those positions are only partly watermark ids. In that case the block
        end is pulled back so the run is not split. The final block holds whatever remains
        and is not padded here.

        :param passage: long passage
        :return: list of decoded block strings
        """
        buffer = []
        encoded_tokens, _, wtm_mask = self._custom_encode(passage)
        total = len(encoded_tokens)
        i = 0
        while i < total:
            end_pos = i + self.block_size
            # Only a cut strictly inside the passage can split a run. Requiring a block longer
            # than one run keeps the window in range and guarantees the loop advances.
            if end_pos < total and self.block_size > WATERMARK_LEN:
                tail = wtm_mask[end_pos - WATERMARK_LEN:end_pos]
                n_marked = sum(tail)
                if 0 < n_marked != WATERMARK_LEN:
                    if tail[0]:
                        if not tail[-1]:
                            # A run ends inside the window: cut right after its last id.
                            end_pos -= tail[::-1].index(True)
                        else:
                            # One run ends and another starts in the window: cut after the first.
                            end_pos -= WATERMARK_LEN - tail.index(False)
                    else:
                        # A run starts inside the window: defer the whole window to the next block.
                        end_pos -= WATERMARK_LEN
                    self.truncate += 1

            block_tokens = encoded_tokens[i:end_pos]
            block_mask = wtm_mask[i:end_pos]
            encoded_sentence = self._custom_decode(block_tokens)
            if sum(block_mask) % WATERMARK_LEN != 0:
                self.bad_tokenize += 1
                logger.info(encoded_sentence.encode("unicode_escape").decode())
                logger.info(encoded_sentence)
            buffer.append(encoded_sentence)
            i = end_pos
        return buffer

    def _match_watermark(self, tokens, i):
        """Return ``(watermark_ids, consumed)`` for a watermark spelling at ``tokens[i]``.

        Two-token spellings are tried first, then three-token, then single-token.
        Returns ``(None, 0)`` when ``tokens[i]`` does not start a known spelling.
        """
        n = len(tokens)
        if i < n - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.chr_to_wtm:
                return self.chr_to_wtm[pair], 2
        if i < n - 2:
            triple = (tokens[i], tokens[i + 1], tokens[i + 2])
            if triple in self.chr_to_wtm:
                return self.chr_to_wtm[triple], 3
        if tokens[i] in self.chr_to_wtm:
            return self.chr_to_wtm[tokens[i]], 1
        return None, 0

    def _custom_encode(self, sentence):
        """Encode *sentence*, mapping zero-width characters to watermark ids.

        The tokenizer's first additional special token is inserted as a marker before the
        first watermark id. It is inserted again whenever at least WATERMARK_LEN tokens have
        been emitted since the previous marker.

        :param sentence: target sentence
        :return:
        encoded_tokens: list of encoded token ids (seq_len)
        exist_wtm: whether any watermark character was found
        wtm_mask: True at watermark ids, False elsewhere, including markers (seq_len)
        """
        tokens = self.tokenizer.encode(sentence)
        marker = self.tokenizer.additional_special_tokens_ids[0]
        encoded_tokens = []
        wtm_mask = []
        wtm_position = []  # index just after each inserted marker
        exist_wtm = False
        i = 0
        while i < len(tokens):
            # TODO: may need revisiting for sentencepiece tokenizers.
            wtm_ids, consumed = self._match_watermark(tokens, i)
            if wtm_ids is None:
                encoded_tokens.append(tokens[i])
                wtm_mask.append(False)
                i += 1
                continue
            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(marker)
                wtm_mask.append(False)
                wtm_position.append(len(encoded_tokens))
            encoded_tokens.extend(wtm_ids)
            # One mask entry per emitted id, so the doubled-ZWSP spelling (two ids) stays aligned.
            wtm_mask.extend([True] * len(wtm_ids))
            exist_wtm = True
            i += consumed
        return encoded_tokens, exist_wtm, wtm_mask

    def _custom_decode(self, decoded_tokens, skip_special_tokens=True):
        """Map watermark ids back to zero-width characters and decode to text.

        :param decoded_tokens: list of token ids
        :param skip_special_tokens: forwarded to ``tokenizer.decode`` (drops the marker when True)
        :return: decoded text
        """
        decoded_tokens = [self.wtm_to_chr.get(token, [token]) for token in decoded_tokens]
        flattened_tokens = [token for sublist in decoded_tokens for token in sublist]
        return self.tokenizer.decode(flattened_tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        """Re-encode stored block *i* and return fixed-length (``block_size``) model inputs.

        :return: dict with these entries:
            ``input_ids``: padded with ``pad_token_id``
            ``attention_mask``: 1 for real tokens
            ``exist_wtm``: scalar bool
            ``wtm_mask``: True at watermark ids
        """
        token_ids, exist_wtm, mask = self._custom_encode(self.examples[i])
        input_ids = torch.tensor(token_ids[:self.block_size], dtype=torch.long)
        wtm_mask = torch.tensor(mask[:self.block_size], dtype=torch.bool)
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        if self.one_watermark:
            # Exactly one complete watermark is expected per block.
            bad_watermark = wtm_mask.sum() != WATERMARK_LEN and exist_wtm
        else:
            bad_watermark = wtm_mask.sum() % WATERMARK_LEN != 0 and exist_wtm
        if bad_watermark:
            logger.warning("Unexpected watermark length in example %d: %s", i,
                           self.examples[i].encode("unicode_escape").decode())
            logger.warning("Encoded tokens: %s", input_ids)
        pad_len = self.block_size - input_ids.size(0)
        return {"input_ids": F.pad(input_ids, (0, pad_len), value=self.tokenizer.pad_token_id),
                "attention_mask": F.pad(attention_mask, (0, pad_len)),
                "exist_wtm": torch.tensor(exist_wtm),
                "wtm_mask": F.pad(wtm_mask, (0, pad_len), value=False)}


def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the watermark dataset for the requested split.

    This is a thin wrapper around ``watermarkDataset``. That class reads, blocks, encodes
    and caches the text files found under ``args.data_path``.

    :param args: run configuration forwarded to ``watermarkDataset``
    :param tokenizer: tokenizer forwarded to ``watermarkDataset``
    :param evaluate: build the validation split when True, the training split otherwise
    :return: a ``watermarkDataset`` instance
    """
    return watermarkDataset(args, tokenizer, evaluate)


class watermarkPLM(torch.nn.Module):
    """Causal LM whose vocabulary is extended with ``watermark_size`` watermark tokens.

    The last ``watermark_size`` columns of the logits form the watermark vocabulary.
    The leading columns are the original language-model vocabulary.
    ``forward`` does not return a differentiable loss. It returns the losses (for
    logging) together with ``logits`` and a hand-assembled gradient ``logits_gd``,
    and the caller back-propagates with ``logits.backward(logits_gd)``.
    """

    def __init__(self, config_class, model_class, seed, vocab_size, watermark_size, pad_token_id=50257,
                 model_type='gpt2-large', freeze_layers=12, model_name_or_path=None):
        """
        :param config_class: transformers config class used to load ``model_type``'s config
        :param model_class: transformers model class used to load the weights
        :param seed: passed to ``torch.manual_seed`` before the model is built
        :param vocab_size: size of the base vocabulary (without watermark tokens)
        :param watermark_size: number of watermark tokens appended to the vocabulary
        :param pad_token_id: label id ignored by the cross-entropy losses
        :param model_type: pretrained name used for the config, and for the weights when
            ``model_name_or_path`` is not given
        :param freeze_layers: number of leading transformer blocks whose parameters are frozen
            (only for model types containing 'gpt' or 'opt')
        :param model_name_or_path: optional checkpoint to load instead of ``model_type``'s weights
        """
        torch.manual_seed(seed)
        super().__init__()
        self.watermark_size = watermark_size
        self.vocab_size = vocab_size + watermark_size
        self.pad_token_id = pad_token_id
        self.config = config_class.from_pretrained(model_type, gradient_checkpointing=True)
        if model_name_or_path:
            logger.info("Loading model from {}".format(model_name_or_path))
            # No resize here: presumably the checkpoint already carries the extended
            # vocabulary -- verify when loading foreign checkpoints.
            self.base_model = model_class.from_pretrained(model_name_or_path)
        else:
            self.base_model = model_class.from_pretrained(model_type, config=self.config)
            self.base_model.resize_token_embeddings(self.vocab_size)
        # Freeze the first `freeze_layers` transformer blocks
        if 'gpt' in model_type:
            frozen_blocks = self.base_model.transformer.h[:freeze_layers]
        elif 'opt' in model_type:
            frozen_blocks = self.base_model.model.decoder.layers[:freeze_layers]
        else:
            frozen_blocks = []
        for block in frozen_blocks:
            for param in block.parameters():
                param.requires_grad = False

    def _language_model_shift_loss(self, logits, labels):
        """Summed next-token cross-entropy: position t of ``logits`` predicts ``labels[t + 1]``."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    def _language_model_nonshift_loss(self, logits, labels):
        """Summed cross-entropy on already-aligned ``logits`` / ``labels`` (no shift)."""
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

    def forward(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None, fp16=False, scaler=None):
        """Compute LM / watermark losses and the gradient of the loss w.r.t. ``logits``.

        The three loss terms:
        * Sequences without a watermark use the ordinary next-token loss over the base
          vocabulary.
        * Sequences with a watermark use the next-token loss over the base vocabulary
          on non-watermark positions.
        * Watermark positions use the next-token loss over the watermark vocabulary.

        Each loss is a sum divided by its token count. Counts of 0 are clamped to 1,
        so a batch without watermark (or without LM) tokens reports 0 instead of NaN.

        :param input_ids: (bsz, seq_len)
        :param exist_wtm: (bsz) bool; True for sequences that contain watermark tokens
        :param wtm_mask: (bsz, seq_len) bool; True at watermark-token positions
        :param labels: (bsz, seq_len) target ids; required despite the ``None`` default
        :param attention_mask: (bsz, seq_len); required despite the ``None`` default
        :param fp16: run under CUDA autocast; per-part losses are scaled with ``scaler``
            before differentiation, so ``logits_gd`` is in scaled units while the
            returned losses are unscaled
        :param scaler: GradScaler used when ``fp16`` is True
        :return: (loss_lm, loss_wtm, logits, logits_gd); ``logits_gd`` has the shape of ``logits``
        """
        block_size = input_ids.shape[1]
        loss_lm = torch.tensor(0).float().to(input_ids.device)
        loss_wtm = torch.tensor(0).float().to(input_ids.device)
        samples_num_wtm = wtm_mask.sum().item()
        samples_num_lm = attention_mask.sum().item() - samples_num_wtm
        # Normalisers for losses and gradients; clamped so empty parts give 0, not NaN/inf.
        denom_wtm = max(samples_num_wtm, 1)
        denom_lm = max(samples_num_lm, 1)
        if fp16:
            loss_lm = scaler.scale(loss_lm)
            loss_wtm = scaler.scale(loss_wtm)
            # Unscaled running totals; these are what gets returned for logging.
            loss_lm_unscaled = torch.tensor(0).float().to(input_ids.device)
            loss_wtm_unscaled = torch.tensor(0).float().to(input_ids.device)

        labels = labels.to(input_ids.device)
        if fp16:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
                # LM loss for sequences without watermark (watermark columns dropped)
                logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]  # (bsz_without_watermark, seq_len, base_vocab)
                labels_lm = labels[~exist_wtm]
                loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
                loss_lm_unscaled += loss_lm.item()
            loss_lm = scaler.scale(loss_lm)
        else:
            logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
            # LM loss for sequences without watermark (watermark columns dropped)
            logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]  # (bsz_without_watermark, seq_len, base_vocab)
            labels_lm = labels[~exist_wtm]
            loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
        logits_gd = torch.autograd.grad(loss_lm, logits_lm)[0] / denom_lm
        # Re-expand to the full vocabulary: watermark columns receive zero gradient.
        logits_gd = F.pad(logits_gd, (0, self.watermark_size))

        if not torch.any(exist_wtm):
            if fp16:
                loss_lm = loss_lm_unscaled
                loss_wtm = loss_wtm_unscaled
            return loss_lm / denom_lm, loss_wtm / denom_wtm, logits, logits_gd

        # LM loss for sequences that contain a watermark, on non-watermark positions.
        # Mask and labels are shifted by one so logits position t predicts label t + 1.
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len-1)
        logits_lm = logits[exist_wtm][:, :-1, :-self.watermark_size]  # (bsz_with_watermark, seq_len-1, base_vocab)
        labels_lm = labels[exist_wtm][:, 1:]
        non_wtm_grad = []  # per-sequence gradients at non-watermark positions
        for i, (logit_tmp, label_tmp) in enumerate(zip(logits_lm, labels_lm)):
            current_wtm_mask = wtm_mask[i]  # (seq_len-1,)
            logit_tmp = logit_tmp[~current_wtm_mask]
            label_tmp = label_tmp[~current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_lm_unscaled += loss_cur_lm
                loss_cur_lm = scaler.scale(loss_cur_lm)
            else:
                loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_lm += loss_cur_lm
            grad_cur_lm = torch.autograd.grad(loss_cur_lm, logit_tmp)[0]
            padded_tensor = torch.nn.functional.pad(grad_cur_lm, (0, self.watermark_size)) / denom_lm
            non_wtm_grad.append(padded_tensor)

        # Watermark loss over the watermark vocabulary at watermark positions;
        # labels are shifted into [0, watermark_size).
        logits_lm = logits[exist_wtm][:, :-1, -self.watermark_size:]
        labels_lm = labels[exist_wtm][:, 1:] - (self.vocab_size - self.watermark_size)
        wtm_grad = []  # full (1, seq_len, vocab_size) gradient per watermarked sequence
        for i, (logit_tmp, label_tmp) in enumerate(zip(logits_lm, labels_lm)):
            current_wtm_mask = wtm_mask[i]
            logit_tmp = logit_tmp[current_wtm_mask]
            label_tmp = label_tmp[current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_wtm_unscaled += loss_cur_wtm
                loss_cur_wtm = scaler.scale(loss_cur_wtm)
            else:
                loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_wtm += loss_cur_wtm
            grad_cur_wtm = torch.autograd.grad(loss_cur_wtm, logit_tmp)[0]
            # Watermark gradients live in the last columns: left-pad the base vocabulary.
            padded_tensor = torch.nn.functional.pad(grad_cur_wtm, (self.vocab_size - self.watermark_size, 0))
            padded_tensor = padded_tensor / denom_wtm
            # Scatter watermark / non-watermark gradients back into sequence order.
            grad_align = torch.zeros(1, block_size - 1, self.vocab_size).to(input_ids.device)
            expanded_mask = wtm_mask[i].unsqueeze(0).unsqueeze(-1).expand(
                1, block_size - 1, self.vocab_size).to(input_ids.device)
            assert grad_align.shape == expanded_mask.shape
            if fp16:
                grad_align = grad_align.to(torch.float16)
            grad_align[expanded_mask] = padded_tensor.view(-1)
            grad_align[~expanded_mask] = non_wtm_grad[i].view(-1)
            # The last position predicts nothing: pad it with zero gradient.
            grad_align = F.pad(grad_align, (0, 0, 0, block_size - grad_align.size(1)))
            wtm_grad.append(grad_align)
        wtm_grad = torch.stack(wtm_grad, dim=0).squeeze(1)  # (bsz_has_watermark, seq_len, vocab_size)

        if logits_gd.size(0) != 0:
            # Interleave gradients of both sequence groups back into batch order.
            logits_final = torch.empty(exist_wtm.shape[0], *logits_gd.shape[1:], dtype=logits_gd.dtype).to(
                input_ids.device)
            logits_final[exist_wtm] = wtm_grad.view(-1, wtm_grad.shape[1], wtm_grad.shape[2])
            logits_final[~exist_wtm] = logits_gd.view(-1, logits_gd.shape[1], logits_gd.shape[2])
            logits_gd = logits_final
        else:  # every sequence carries a watermark
            logits_gd = wtm_grad

        if fp16:
            loss_lm = loss_lm_unscaled
            loss_wtm = loss_wtm_unscaled
        return loss_lm / denom_lm, loss_wtm / denom_wtm, logits, logits_gd

    def evaluate(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None):
        """Compute ``(loss_lm, loss_wtm)`` exactly like ``forward`` but without gradients.

        The method does not disable autograd itself; wrap it in ``torch.no_grad()``.
        Token counts of 0 are clamped to 1, so missing parts report 0 instead of NaN.
        """
        loss_lm = torch.tensor(0).float().to(input_ids.device)
        loss_wtm = torch.tensor(0).float().to(input_ids.device)
        samples_num_wtm = wtm_mask.sum().item()
        samples_num_lm = attention_mask.sum().item() - samples_num_wtm

        logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
        # LM loss for sequences without watermark
        logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]
        labels_lm = labels[~exist_wtm]
        loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)

        # LM loss for sequences with watermark, on non-watermark positions (shifted by one)
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len-1)
        logits_lm = logits[exist_wtm][:, :-1, :-self.watermark_size]
        labels_lm = labels[exist_wtm][:, 1:]
        for i, (logit_tmp, label_tmp) in enumerate(zip(logits_lm, labels_lm)):
            current_wtm_mask = wtm_mask[i]
            loss_lm += self._language_model_nonshift_loss(logit_tmp[~current_wtm_mask],
                                                          label_tmp[~current_wtm_mask])

        # Watermark loss over the watermark vocabulary
        logits_tmp = logits[exist_wtm][:, :-1, -self.watermark_size:]
        labels_tmp = labels[exist_wtm][:, 1:] - (self.vocab_size - self.watermark_size)
        for i, (logit_tmp, label_tmp) in enumerate(zip(logits_tmp, labels_tmp)):
            current_wtm_mask = wtm_mask[i]
            loss_wtm += self._language_model_nonshift_loss(logit_tmp[current_wtm_mask],
                                                           label_tmp[current_wtm_mask])

        return loss_lm / max(samples_num_lm, 1), loss_wtm / max(samples_num_wtm, 1)


def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed`` and make cuDNN deterministic.

    CUDA generators on all devices are seeded too when ``args.n_gpu > 0``.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints so at most ``args.save_total_limit`` remain.

    Does nothing when the limit is unset, zero or negative.
    """
    limit = args.save_total_limit
    if not limit or limit <= 0:
        return

    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    excess = len(checkpoints_sorted) - limit
    if excess <= 0:
        return

    for stale in checkpoints_sorted[:excess]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(stale))
        shutil.rmtree(stale)


def _build_watermark_optimizer(args, model):
    """Create the optimizer used by ``train``.

    The optimizer depends on ``args.eight_bit_adam``:
    * Unset: torch ``AdamW``; parameters whose name contains "bias" or
      "LayerNorm.weight" get no weight decay.
    * Set: bitsandbytes ``Adam8bit``; parameters inside an ``nn.LayerNorm`` or whose
      name contains "bias" get no weight decay.
    """
    if not args.eight_bit_adam:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    # A set keeps the per-parameter membership test O(1).
    decay_parameters = {name for name in get_parameter_names(model, [nn.LayerNorm]) if "bias" not in name}
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]
    return bnb.optim.Adam8bit(
        optimizer_grouped_parameters,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        lr=args.learning_rate,
    )


def _save_training_checkpoint(args, model, tokenizer, optimizer, scheduler, global_step):
    """Save a resumable checkpoint to ``<args.output_dir>/checkpoint-<global_step>``.

    The checkpoint contains the base model weights, tokenizer, args, optimizer and
    scheduler state. Old checkpoints are rotated according to ``args.save_total_limit``.
    """
    checkpoint_prefix = "checkpoint"
    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
    os.makedirs(output_dir, exist_ok=True)
    # Take care of distributed/parallel training
    model_to_save = model.module if hasattr(model, "module") else model
    # Weights go into the checkpoint directory itself so training can be resumed from it.
    model_to_save.base_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)

    _rotate_checkpoints(args, checkpoint_prefix)

    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    logger.info("Saving optimizer and scheduler states to %s", output_dir)


def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float, dict]:
    """Train the watermark model.

    ``model`` (a ``watermarkPLM``) returns the gradient of the loss w.r.t. its logits
    (``logits_gd``) rather than a differentiable loss. Each step therefore
    back-propagates with ``logits.backward(logits_gd)``.

    :param args: parsed command-line arguments
    :param train_dataset: dataset yielding dicts with "input_ids", "exist_wtm",
        "wtm_mask" and "attention_mask"
    :param model: model to train; wrapped in DataParallel / DistributedDataParallel here when needed
    :param tokenizer: saved next to every checkpoint
    :return: (global_step, average training loss per optimizer step, loss_track);
        ``loss_track`` maps "tr_loss_lm", "tr_loss_wtm" and "step" to per-optimizer-step lists
    """
    local_rank = args.local_rank
    is_main_process = local_rank in [-1, 0]
    if is_main_process:
        tb_writer = SummaryWriter()

    logger.info(
        f"Just entering training with memory: {torch.cuda.max_memory_allocated(0) / 1024 ** 3} GB")
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(train_dataset) if local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Number of optimizer steps per epoch
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    logger.info(f"Before optimizer: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer = _build_watermark_optimizer(args, model)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Restore optimizer / scheduler state when resuming from a checkpoint directory
    if (
            args.model_name_or_path
            and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    scaler = None
    if args.fp16:
        try:
            from torch.cuda.amp import GradScaler
        except ImportError as err:
            raise ImportError("fp16 training requires torch.cuda.amp.GradScaler (PyTorch >= 1.6).") from err
        scaler = GradScaler()

    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info(
        f"Before distributed data parallel cached: {torch.cuda.memory_reserved(0) / 1024 ** 3} GB")
    # Distributed training
    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // updates_per_epoch
            # global_step counts optimizer steps but the loop below skips dataloader
            # batches, so convert with the accumulation factor.
            steps_trained_in_current_epoch = (global_step % updates_per_epoch) * args.gradient_accumulation_steps

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_lm, tr_loss_wtm = 0.0, 0.0
    tr_loss_lm_list = []
    tr_loss_wtm_list = []
    step_list = []

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=not is_main_process
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # Periodically return cached allocator blocks to keep peak memory down
            if (step + 1) % 10 == 0 and torch.cuda.is_available():
                with torch.cuda.device(torch.cuda.current_device()):
                    torch.cuda.empty_cache()
                gc.collect()

            input_ids = batch['input_ids'].to(args.device)
            exist_wtm = batch['exist_wtm'].to(args.device)
            wtm_mask = batch['wtm_mask'].to(args.device)
            attention_mask = batch['attention_mask'].to(args.device)
            labels = batch['input_ids'].to(args.device)
            model.train()

            loss_lm, loss_wtm, logits, logits_gd = model(input_ids, exist_wtm, wtm_mask,
                                                         labels=labels, attention_mask=attention_mask,
                                                         fp16=args.fp16, scaler=scaler)
            loss = loss_lm + loss_wtm
            logger.info(f"loss_lm: {loss_lm.item()}, loss_wtm: {loss_wtm.item()}, loss: {loss.item()}")

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss_lm = loss_lm.mean()
                loss_wtm = loss_wtm.mean()
                logits_gd = logits_gd / args.n_gpu
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                loss_lm = loss_lm / args.gradient_accumulation_steps
                loss_wtm = loss_wtm / args.gradient_accumulation_steps
                logits_gd = logits_gd / args.gradient_accumulation_steps

            # The model hands back d(loss)/d(logits); push it through the network.
            logits.backward(logits_gd)
            del logits, logits_gd, input_ids, exist_wtm, wtm_mask, attention_mask, labels

            tr_loss += loss.item()
            tr_loss_lm += loss_lm.item()
            tr_loss_wtm += loss_wtm.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                tr_loss_lm_list.append(tr_loss_lm)
                tr_loss_lm = 0.0
                tr_loss_wtm_list.append(tr_loss_wtm)
                tr_loss_wtm = 0.0
                step_list.append(global_step)

                if args.fp16:
                    # https://pytorch.org/docs/master/notes/amp_examples.html#gradient-clipping
                    # Unscale first so clipping sees true gradient norms; scaler.step still
                    # skips optimizer.step() when the gradients contain infs or NaNs.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    _save_training_checkpoint(args, model, tokenizer, optimizer, scheduler, global_step)

            # Stop exactly after max_steps optimizer steps (t_total == max_steps).
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if is_main_process:
        tb_writer.close()

    loss_track = {
        'tr_loss_lm': tr_loss_lm_list,
        'tr_loss_wtm': tr_loss_wtm_list,
        'step': step_list,
    }
    # Guard against no optimizer step having run (e.g. fewer batches than accumulation steps).
    return global_step, tr_loss / max(global_step, 1), loss_track


def evaluate(args, model, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate LM and watermark losses / perplexities on the validation dataset.

    Results are logged and written to ``<args.output_dir>/<prefix>/eval_results.txt``.

    ``model`` may be a bare ``watermarkPLM`` or one wrapped in DataParallel /
    DistributedDataParallel. ``watermarkPLM.evaluate`` is a custom method, not
    ``forward``, so DataParallel cannot scatter it, and the wrapper does not expose
    it. Evaluation therefore always runs on the underlying module, on one device.

    :param args: parsed command-line arguments
    :param model: the model to evaluate
    :param tokenizer: tokenizer used to build the validation dataset
    :param prefix: sub-directory (and log label) for the results file
    :return: dict with "perplexity", "perplexity_lm" (tensors), "loss" and "loss_lm" (floats)
    """
    eval_output_dir = args.output_dir
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Unwrap DataParallel / DDP: the wrappers only proxy forward(), not .evaluate().
    eval_model = model.module if hasattr(model, "module") else model

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    eval_lm_loss = 0.0
    nb_eval_steps = 0
    eval_model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(args.device)
        exist_wtm = batch['exist_wtm'].to(args.device)
        wtm_mask = batch['wtm_mask'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)
        labels = batch['input_ids'].to(args.device)

        with torch.no_grad():
            loss_lm, loss_wtm = eval_model.evaluate(input_ids, exist_wtm, wtm_mask,
                                                    labels=labels, attention_mask=attention_mask)
            logger.info(f"loss_wtm using evaluate func= {loss_wtm.mean().item()}")
            logger.info(f"loss_lm using evaluate func= {loss_lm.mean().item()}")
            loss = loss_lm + loss_wtm
            eval_loss += loss.mean().item()
            eval_lm_loss += loss_lm.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity,
              "perplexity_lm": torch.exp(torch.tensor(eval_lm_loss / nb_eval_steps)),
              "loss": eval_loss,
              "loss_lm": eval_lm_loss / nb_eval_steps}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # A non-empty prefix names a sub-directory that may not exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


###################### Generation-related functions ######################
def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Mask logits outside the top-k and/or nucleus (top-p) set along the last dimension.

    ``logits`` is modified in place and also returned. Tensors of any rank are
    accepted; each vector along the last dimension is filtered independently
    (previously only 2-D input worked for top-p).

    :param logits: scores of shape (..., vocab_size)
    :param top_k: keep only the ``top_k`` highest scores (0 disables)
    :param top_p: keep the smallest prefix of tokens, by descending probability, whose
        cumulative probability exceeds ``top_p`` (1.0 disables)
    :param filter_value: value written into removed positions
    :param min_tokens_to_keep: keep at least this many tokens per row
    :return: the filtered ``logits``
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove every token scoring below the k-th largest score
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold is kept as well
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to the original order along the last dim
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def enforce_repetition_penalty_(lprobs, num_beams, prev_output_tokens, repetition_penalty, max_token_id=50258):
    """Apply the CTRL repetition penalty (https://arxiv.org/abs/1909.05858) in place.

    For each of the first ``num_beams`` rows, every distinct token id already present
    in that row of ``prev_output_tokens`` is made less likely. Negative scores are
    multiplied by ``repetition_penalty`` and non-negative scores are divided by it.

    :param lprobs: 2-D score tensor indexed as ``lprobs[row, token_id]``; modified in place.
    :param num_beams: number of rows to process.
    :param prev_output_tokens: 2-D tensor of previously generated token ids, one row per beam.
    :param repetition_penalty: penalty factor (> 1 discourages repetition).
    :param max_token_id: ids greater than this are skipped. The default 50258 keeps the
        previous hard-coded cut-off, which skips watermark ids.
    :return: ``lprobs`` (the same tensor object).
    """
    vocab_width = lprobs.size(-1)
    for i in range(num_beams):
        for previous_token in set(prev_output_tokens[i].tolist()):
            # Watermark ids can exceed the regular vocabulary; ids without a column
            # in lprobs would raise IndexError, and negative ids would silently
            # wrap around to the wrong column, so skip all of them.
            if previous_token > max_token_id or previous_token < 0 or previous_token >= vocab_width:
                continue
            # For a negative score, multiplying makes it more negative (less likely);
            # for a non-negative score, dividing shrinks it.
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty
    return lprobs


def generate_watermark(seq, model, tokenizer, device):
    """Greedily append a watermark to ``seq`` and return the decoded text.

    A watermark-start token is appended first. Then WATERMARK_LEN watermark ids are
    chosen one at a time: each is the argmax over the last WATERMARK_EMB logits,
    shifted by ``model.vocab_size - model.watermark_size`` into the watermark id range.

    :param seq: token-id tensor; presumably of shape (1, seq_len) — TODO confirm with callers.
    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param device: device the new ids are moved to
    :return: text produced by ``tokenizer.custom_decode`` on the extended sequence.
    """
    start_marker = torch.tensor([[tokenizer.watermark_token]]).to(device)
    extended = torch.cat((seq, start_marker), dim=1)
    id_offset = model.vocab_size - model.watermark_size
    step = 0
    while step < WATERMARK_LEN:
        with torch.no_grad():
            wm_logits = model.base_model(extended).logits[:, -1, -WATERMARK_EMB:]
        next_id = torch.argmax(wm_logits, dim=-1) + id_offset
        extended = torch.cat([extended, next_id.unsqueeze(0)], dim=-1)
        step += 1
    return tokenizer.custom_decode(extended.squeeze(0).tolist())


def generate_watermark_beam(seq, model, tokenizer, device, topk=False, beam_size=5, temperature=0.8, scores=None,
                            return_text=True):
    """Append a watermark to ``seq`` using beam search over watermark ids.

    A watermark-start token (``tokenizer.watermark_token``) is appended. Then
    WATERMARK_LEN watermark ids are searched with a beam of width ``beam_size``. Only
    the last WATERMARK_EMB logits are considered at each step. Chosen ids are shifted
    by ``model.vocab_size - model.watermark_size`` into the watermark id range.

    :param seq: tensor of shape (1, seq_len)
    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param device: cpu or gpu0
    :param topk: 3 or 5 returns the best beam, followed by the watermark segments
        (start token + ids) of the next ``topk - 1`` best beams. Any other value
        returns only the best beam.
    :param beam_size: beam width; must be >= ``topk`` when ``topk`` is 3 or 5.
    :param temperature: softmax temperature applied to the watermark logits.
    :param scores: optional prior score tensor (e.g. shape (1)), added to the first-step scores.
    :param return_text: if True, return the decoded text; otherwise return the 1-D
        token-id tensor.
    :return: decoded string, or a 1-D tensor of token ids.
    """
    id_offset = model.vocab_size - model.watermark_size
    seq = torch.cat((seq, torch.tensor([[tokenizer.watermark_token]]).to(device)), dim=1)
    with torch.no_grad():
        logits = model.base_model(seq).logits[:, -1, -WATERMARK_EMB:]
    # log_softmax is the numerically stable form of log(softmax(x)).
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    new_scores, topk_indices = torch.sort(log_probs, descending=True)
    new_scores = new_scores.view(-1)[:beam_size]
    # `if scores:` is ambiguous for multi-element tensors and treats a 0.0 prior as
    # missing. Test for None and let broadcasting add the prior to every beam.
    scores = new_scores if scores is None else scores.view(-1) + new_scores
    logger.info(f'watermark token score: {scores[0]}')
    topk_indices = topk_indices.view(-1)[:beam_size] + id_offset
    seq = torch.cat((seq.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)

    for _ in range(WATERMARK_LEN - 1):
        with torch.no_grad():
            logits = model.base_model(seq).logits[:, -1, -WATERMARK_EMB:]
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        # Add each beam's running score to all of its candidate extensions.
        candidate_scores = scores.unsqueeze(1) + log_probs
        scores, flat_indices = torch.topk(candidate_scores.view(-1), beam_size)
        # Convert flat indices back into (parent beam row, watermark id).
        n_ids = candidate_scores.size(1)
        beam_rows = flat_indices // n_ids
        word_indices = flat_indices % n_ids + id_offset
        seq = torch.cat((seq[beam_rows], word_indices.unsqueeze(1)), dim=1)

    if topk in (3, 5):
        # Keep the best beam in full, then append only the watermark segment
        # (start token + WATERMARK_LEN ids) of each of the next-best beams.
        ranked = torch.topk(scores, topk).indices
        segment_len = WATERMARK_LEN + 1
        pieces = [seq[ranked[0]]] + [seq[row][-segment_len:] for row in ranked[1:]]
        seq = torch.cat(pieces)
    else:
        max_score_index = torch.argmax(scores)
        logger.info(f'Watermark Score: {scores[max_score_index]}')
        seq = seq[max_score_index]

    if return_text:
        return tokenizer.custom_decode(seq.tolist())
    return seq


def generate_watermark_classification(seq, model, tokenizer, device, watermark_list, topk=1,
                                      temperature=0.8, scores=None, return_text=True):
    """Append the best-scoring watermark, chosen from a list of known watermarks.

    Instead of searching freely, the model scores each candidate in ``watermark_list``
    with teacher forcing. The log-probability of every watermark id is read from the
    logits of the prefix that precedes that id, and these log-probabilities are summed.

    :param seq: tensor of shape (1, seq_len)
    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param device: cpu or gpu0
    :param watermark_list: list of all existing watermarks. Each is a list of
        WATERMARK_LEN raw watermark ids (before the vocabulary offset), e.g.
        ``[5, 2, 3, 0, 0, 1, 3, 4, 2]``.
    :param topk: 1 appends only the best watermark. Values > 1 return the sequence with
        the best watermark, followed by the watermark segments (start token + ids) of
        the next ``topk - 1`` candidates. Values < 1 return None.
    :param temperature: softmax temperature applied to the watermark logits.
    :param scores: optional prior score tensor (e.g. shape (1)), added to every candidate.
    :param return_text: if True, return decoded text; otherwise return token ids.
    :return: for topk == 1, decoded text or a 1-D token-id tensor. For topk > 1, a
        ``(generated_text, watermark_text)`` tuple or a 1-D token-id tensor.
    """
    id_offset = model.vocab_size - model.watermark_size
    marker = torch.tensor([[tokenizer.watermark_token]]).to(device)
    seq = torch.cat((seq, marker), dim=1)
    new_scores = torch.zeros(len(watermark_list)).to(device)
    # `if scores:` is ambiguous for tensors; test for None and broadcast the prior.
    scores = new_scores if scores is None else scores.view(-1) + new_scores

    for i, watermark in enumerate(watermark_list):
        prefix = seq.clone()
        for j in range(WATERMARK_LEN):
            with torch.no_grad():
                logits = model.base_model(prefix).logits[:, -1, -WATERMARK_EMB:]
            log_probs = F.log_softmax(logits / temperature, dim=-1)
            # Score id j with the distribution predicted from the prefix that ends just
            # before it. Appending id j first would score it against the prediction
            # for the following position.
            scores[i] += log_probs[0, watermark[j]]
            next_token = torch.tensor([[watermark[j] + id_offset]]).to(device)
            prefix = torch.cat((prefix, next_token), dim=1)

    ranked = torch.sort(scores, descending=True)[1].tolist()

    def _watermark_ids(index):
        # Raw watermark ids for candidate `index`, shifted into the model's watermark
        # id range, as a (1, WATERMARK_LEN) tensor.
        return torch.tensor([watermark_list[index]]).to(device) + id_offset

    if topk == 1:
        seq = torch.cat((seq, _watermark_ids(ranked[0])), dim=1).squeeze(0)
        if return_text:
            return tokenizer.custom_decode(seq.tolist())
        return seq
    elif topk > 1:
        # Same 1-D layout as generate_watermark_beam: the full sequence with the best
        # watermark, then (start token + ids) for each further candidate.
        chosen = ranked[:topk]
        pieces = [seq.squeeze(0), _watermark_ids(chosen[0]).squeeze(0)]
        for index in chosen[1:]:
            pieces.append(marker.squeeze(0))
            pieces.append(_watermark_ids(index).squeeze(0))
        seq = torch.cat(pieces)
        if return_text:
            generated_text = tokenizer.custom_decode(seq.tolist())
            watermark = tokenizer.custom_decode(seq[-len(chosen) * (WATERMARK_LEN + 1):].tolist())
            return generated_text, watermark
        return seq


def _load_watermark_list(path):
    """Read embedded watermarks from ``path``.

    Each non-blank line is split on whitespace. Its second field is read as a string
    of digits, and every digit becomes one int: a second field of ``"523001342"``
    yields ``[5, 2, 3, 0, 0, 1, 3, 4, 2]``.
    """
    watermark_list = []
    with open(path) as file:
        for line in file:
            if not line.strip():
                continue
            watermark_idx = line.rstrip().split()[1]
            watermark_list.append([int(idx) for idx in watermark_idx])
    return watermark_list


def generate_with_beam_search_sample(args, model, tokenizer, input_ids, device, topk=False,
                                     beam_size=6, max_length=100, repetition_penalty=1.5, temperature=0.8):
    """Generate text with stochastic (sampled) beam search, inserting a watermark.

    At every step each beam samples ``2 * beam_size`` candidate tokens, and the best
    ``beam_size`` (beam, token) pairs are kept. Watermark logits (the last
    WATERMARK_EMB columns) are excluded from sampling.

    - When a kept candidate is the watermark-start token, the best such candidate is
      watermarked and the beams are restarted from it.
    - When a kept candidate is the pad/EOS token, that beam is watermarked, padded and
      stored as finished.
    - If neither happened, a watermark is forced onto the best sequence at the end.

    :param args: parsed CLI args; reads ``generate_watermark_classification`` and ``data_path``.
    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param input_ids: list of token ids
    :param device: cpu or gpu
    :param topk: forwarded to ``generate_watermark_beam``
    :param beam_size: number of beams kept per step
    :param max_length: number of sampling steps
    :param repetition_penalty: CTRL repetition penalty factor
    :param temperature: temperature for the token scores
    :return: decoded text of the best sequence
    """
    model.to(device)
    model.eval()
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    watermark_list = None
    if args.generate_watermark_classification:
        watermark_list = _load_watermark_list(os.path.join(args.data_path, 'embedded_watermarks.txt'))

    def _watermark(seq):
        # Append a watermark to seq (shape (1, len)); returns a 1-D token-id tensor.
        # Every branch honours --generate_watermark_classification.
        if args.generate_watermark_classification:
            return generate_watermark_classification(seq, model, tokenizer, device, watermark_list, topk=1,
                                                     temperature=0.8, return_text=False)
        return generate_watermark_beam(seq, model, tokenizer, device, topk, return_text=False)

    def _get_initial_hypotheses(input_ids, beam_size=beam_size, scores=None):
        """Expand one (1, seq_len) prefix into ``beam_size`` beams.

        :return: (sequences of shape (beam_size, seq_len + 1), scores of shape (beam_size,))
        """
        with torch.no_grad():
            logits = model.base_model(input_ids).logits[:, -1, :-WATERMARK_EMB]
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        new_scores, topk_indices = log_probs.view(-1).topk(beam_size)
        scores = new_scores if scores is None else scores.view(-1) + new_scores
        seq = torch.cat((input_ids.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
        return seq, scores

    sequences, beam_scores = _get_initial_hypotheses(input_ids)
    enter_watermark = False
    final_sequences = []

    for _ in range(max_length):
        with torch.no_grad():
            logits = model.base_model(sequences).logits[:, -1, :-WATERMARK_EMB]

        scores = F.log_softmax(logits, dim=-1) / temperature
        scores = enforce_repetition_penalty_(scores, beam_size, sequences, repetition_penalty)
        scores = top_k_top_p_filtering(scores, top_k=10 * beam_size)

        # Sample 2 * beam_size candidates per beam, then keep the best beam_size
        # (beam, token) pairs across all beams.
        next_tokens = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2 * beam_size)
        next_scores = torch.gather(scores, 1, next_tokens)
        candidate_scores = beam_scores.unsqueeze(1) + next_scores
        beam_scores, indices = candidate_scores.view(-1).topk(beam_size, largest=True, sorted=False)
        next_tokens = next_tokens.view(-1).index_select(0, indices)
        chosen_scores = next_scores.view(-1).index_select(0, indices)
        # Each kept candidate extends the beam it was sampled from: reorder the
        # prefixes so that row i is the parent of next_tokens[i].
        sequences = sequences.index_select(0, indices // next_scores.size(1))

        is_watermark = next_tokens == tokenizer.watermark_token
        is_eos = next_tokens == tokenizer.tokenizer.pad_token_id

        watermark_score = chosen_scores[is_watermark]
        if len(watermark_score) > 0:
            logger.info(f'watermark token score: {watermark_score}')

        if is_watermark.any():
            enter_watermark = True
            # The winner is picked by beam score alone, so only it needs watermarking.
            watermark_rows = torch.nonzero(is_watermark, as_tuple=True)[0].tolist()
            best_row = max(watermark_rows, key=lambda row: beam_scores[row].item())
            sequences = _watermark(sequences[best_row].unsqueeze(0))
            sequences, beam_scores = _get_initial_hypotheses(sequences.unsqueeze(0), beam_size=beam_size)
        elif is_eos.any():
            enter_watermark = True
            logger.info("enter padding again and again and again")
            pad = torch.tensor([tokenizer.tokenizer.pad_token_id]).to(device)
            for row in torch.nonzero(is_eos, as_tuple=True)[0].tolist():
                finished = torch.cat((_watermark(sequences[row].unsqueeze(0)), pad), dim=0)
                final_sequences.append((beam_scores[row].item(), finished))
                # Retire this beam so it no longer competes.
                beam_scores[row] = -float("inf")
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)
        else:
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)

    final_sequences += [(score.item(), seq) for score, seq in zip(beam_scores, sequences)]
    final_sequences.sort(key=lambda x: x[0], reverse=True)
    best_sequence = final_sequences[0][1]

    if not enter_watermark:
        logger.info("force generating watermark")
        best_sequence = _watermark(best_sequence.unsqueeze(0))

    return tokenizer.custom_decode(best_sequence.squeeze(0).tolist())  # Return the best sequence


def get_random_str(main_str, substr_len=200):
    """Return a random contiguous slice of ``main_str`` of length ``substr_len``.

    A string no longer than ``substr_len`` is returned unchanged.
    """
    if len(main_str) > substr_len:
        # Any start offset with start + substr_len <= len(main_str) is valid.
        last_start = len(main_str) - substr_len
        start = random.randrange(0, last_start + 1)
        return main_str[start:start + substr_len]
    return main_str


from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import LogitsProcessor


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with ``stops``.

    Adapted from https://github.com/huggingface/transformers/issues/22340.
    """

    def __init__(self, stops):
        """
        :param stops: token id (or tensor of ids) compared with the last token of
            ``input_ids[0]``; every comparison must match for generation to stop.
        """
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` currently ends with the stop token.

        ``**kwargs`` mirrors the ``StoppingCriteria.__call__`` signature, so extra
        arguments forwarded by ``generate`` are accepted. Only the first row of the
        batch is inspected.
        """
        if torch.all(self.stops == input_ids[0][-1:]).item():
            logger.info("see wtm and stop")
            return True
        return False


class CustomLogitsProcessor(LogitsProcessor):
    """Drop the watermark columns from the scores before sampling.

    The last ``WATERMARK_EMB`` columns are cut off, so only regular (non-watermark)
    tokens can be sampled. The returned tensor is narrower than the input.
    """

    def __call__(self, input_ids, scores):
        regular_vocab = slice(None, -WATERMARK_EMB)
        return scores[:, regular_vocab]


def generate_text_pipeline(model, tokenizer, input_ids, device, topk=False, max_length=100, result_path=None, return_text=True):
    """Sample text with ``model.base_model.generate`` and insert watermarks.

    Rounds of ``generate`` are repeated until at least ``max_length`` new tokens exist.
    Each round samples only regular tokens (CustomLogitsProcessor removes the watermark
    columns) and stops when the watermark-start token is produced.

    - When a round ends on that token, the token is replaced by a watermark built with
      ``generate_watermark_beam``.
    - Until that has happened once, every round is also followed by a forced watermark.

    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param input_ids: list of input ids
    :param device: torch device
    :param topk: forwarded to generate_watermark_beam
    :param max_length: minimum number of new tokens to produce
    :param result_path: if given, the decoded text is appended to this file
    :param return_text: return decoded text instead of token ids
    :return: string if return_text=True; else (1, seq_len) tensor
    """
    stop_on_watermark = StoppingCriteriaList([StoppingCriteriaSub(stops=tokenizer.watermark_token)])
    processors = [CustomLogitsProcessor()]

    ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    model = model.to(device)
    prompt_len = ids.shape[1]
    total_limit = max_length + prompt_len
    new_tokens = 0
    enter_watermark = False

    def _add_watermark(seq):
        # generate_watermark_beam returns a 1-D id tensor; restore the batch dimension.
        return generate_watermark_beam(seq, model, tokenizer, device, topk, return_text=False).unsqueeze(0)

    while new_tokens < max_length:
        ids = model.base_model.generate(
            ids,
            do_sample=True,
            max_length=total_limit,
            temperature=0.7,
            top_k=60,
            top_p=1.0,
            pad_token_id=tokenizer.tokenizer.pad_token_id,
            eos_token_id=tokenizer.tokenizer.pad_token_id,
            repetition_penalty=1.2,
            length_penalty=2.0,
            logits_processor=processors,
            stopping_criteria=stop_on_watermark,
        )
        if ids[0][-1] == tokenizer.watermark_token:
            # Drop the sampled start token; generate_watermark_beam appends its own.
            enter_watermark = True
            ids = _add_watermark(ids[:, :-1])
        if not enter_watermark:
            ids = _add_watermark(ids)
        new_tokens = ids.shape[1] - prompt_len

    if not enter_watermark:
        logger.info("force generating watermark")
    if result_path:
        with open(result_path, "a") as f:
            f.write(tokenizer.custom_decode(ids.squeeze(0).tolist()))
    if not return_text:
        return ids
    return tokenizer.custom_decode(ids.squeeze(0).tolist())

############################################### attack #############################################
def synonym_attack(prompt_sentence, k=0.2):
    """Replace a fraction of the words in each sentence with WordNet synonyms.

    :param prompt_sentence: list of strings to attack
    :param k: fraction of tokens (from ``nltk.word_tokenize``) to try to replace; a
        token with no synonym other than itself is left unchanged
    :return: list of modified strings (tokens re-joined with single spaces)
    """
    import random
    import nltk
    from nltk.corpus import wordnet
    from nltk.tokenize import word_tokenize

    nltk.download('punkt')
    nltk.download('wordnet')

    def get_synonyms(word):
        # First lemma of each synset of `word`, excluding the word itself so that a
        # "replacement" actually changes the text. WordNet joins multi-word lemmas
        # with underscores (e.g. "ice_cream"); turn those back into spaces.
        lowered = word.lower()
        lemma_names = (syn.lemmas()[0].name() for syn in wordnet.synsets(word))
        return {name.replace('_', ' ') for name in lemma_names if name.lower() != lowered}

    def replace_with_synonym(text, k):
        logger.info(f"original sentence: {text}")
        words = word_tokenize(text)
        num_to_replace = int(len(words) * k)
        for _ in range(num_to_replace):
            # Pick a position, not a word value, so a repeated word is not always
            # replaced at its first occurrence.
            position = random.randrange(len(words))
            synonyms = get_synonyms(words[position])
            if synonyms:
                # Sort so the choice is reproducible under random.seed
                # (set iteration order depends on the string hash seed).
                words[position] = random.choice(sorted(synonyms))
        modified = ' '.join(words)
        logger.info(f"modified sentence: {modified}")
        return modified

    return [replace_with_synonym(item, k) for item in prompt_sentence]


def paraphrase_attack(prompt_sentence, device):
    """Paraphrase each sentence with the ``tuner007/pegasus_paraphrase`` model.

    For every sentence, 10 paraphrases are generated with 10-beam search, and the first
    one returned by ``batch_decode`` is kept.

    :param prompt_sentence: list of strings
    :param device: torch device the Pegasus model and its inputs are moved to
    :return: list of paraphrased strings, one per input sentence
    """
    from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast

    pegasus = PegasusForConditionalGeneration.from_pretrained("tuner007/pegasus_paraphrase")
    pegasus_tok = PegasusTokenizerFast.from_pretrained("tuner007/pegasus_paraphrase")
    pegasus = pegasus.to(device)

    def _paraphrase(sentence, num_return_sequences=5, num_beams=5):
        # Tokenize, generate with beam search, and decode back to plain strings.
        encoded = pegasus_tok([sentence], truncation=True, padding="longest", return_tensors="pt").to(device)
        generated = pegasus.generate(
            **encoded,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
        )
        return pegasus_tok.batch_decode(generated, skip_special_tokens=True)

    paraphrased_sentences = []
    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        chosen = _paraphrase(sentence, num_beams=10, num_return_sequences=10)[0]
        logger.info(f"paraphrased sentence: {chosen}")
        paraphrased_sentences.append(chosen)
    return paraphrased_sentences


def insert_chars_attack(prompt_sentence, k):
    """Insert random ASCII letters at random positions in each sentence.

    :param prompt_sentence: list of strings
    :param k: number of insertions as a fraction of the sentence's original length
    :return: list of modified strings
    """
    import string

    def _perturb(text):
        chars = list(text)
        for _ in range(int(len(text) * k)):
            position = random.randint(0, len(chars))
            chars.insert(position, random.choice(string.ascii_letters))
        return ''.join(chars)

    result = []
    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        sentence = _perturb(sentence)
        logger.info(f"modified sentence: {sentence}")
        result.append(sentence)
    return result


def insert_words_attack(prompt_sentence, k=0.2, localized=False):
    """Insert random English words (from the NLTK ``words`` corpus) into each sentence.

    :param prompt_sentence: list of strings
    :param k: number of insertions as a fraction of the sentence's whitespace-separated word count
    :param localized: if True, insert exactly one word regardless of ``k``
    :return: list of modified strings (words re-joined with single spaces)
    """
    import random
    import nltk
    from nltk.corpus import words

    nltk.download('words')
    vocabulary = words.words()  # Get list of common English words

    modified_sentences = []
    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        tokens = sentence.split()
        count = 1 if localized else int(len(tokens) * k)
        for _ in range(count):
            slot = random.randint(0, len(tokens))
            tokens.insert(slot, random.choice(vocabulary))
        modified_sentence = ' '.join(tokens)
        modified_sentences.append(modified_sentence)
        logger.info(f"modified sentence: {modified_sentence}")
    return modified_sentences


def delete_chars_attack(prompt_sentence, k):
    """Delete randomly chosen characters from each sentence.

    :param prompt_sentence: list of strings
    :param k: number of deletions as a fraction of the sentence's length
    :return: list of modified strings
    """
    modified_sentences = []
    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        chars = list(sentence)
        for _ in range(int(len(chars) * k)):
            chars.pop(random.randint(0, len(chars) - 1))
        sentence = ''.join(chars)
        modified_sentences.append(sentence)
        logger.info(f"modified sentence: {sentence}")
    return modified_sentences


def delete_words_attack(prompt_sentence, k, localized=False):
    """Delete randomly chosen words from each sentence.

    :param prompt_sentence: list of strings
    :param k: number of deletions as a fraction of the sentence's whitespace-separated word count
    :param localized: if True, delete exactly one word regardless of ``k``
    :return: list of modified strings (remaining words re-joined with single spaces)
    """
    modified_sentences = []

    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        words = sentence.split()
        # `localized` previously assigned an unused `num_insertions`, so it had no
        # effect; it now really limits the attack to a single deletion.
        num_deletions = 1 if localized else int(len(words) * k)
        # Never delete more words than exist (empty sentence, or k > 1).
        num_deletions = min(num_deletions, len(words))

        for _ in range(num_deletions):
            delete_index = random.randint(0, len(words) - 1)
            del words[delete_index]

        modified_sentence = ' '.join(words)
        modified_sentences.append(modified_sentence)
        logger.info(f"modified sentence: {modified_sentence}")

    return modified_sentences


################################ main function ########################################
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_path", default='data/embedded_warmup/', type=str, required=True, help="The input data directory. "
    )

    parser.add_argument(
        "--control_var",  action="store_true", help="Whether to strictly control variables")

    parser.add_argument(
        "--model_type", default="watermark-gpt", type=str, required=True,
        help="The model architecture to be trained or fine-tuned.",
    )

    # Other parameters
    parser.add_argument(
        "--pre_trained_model_type", default="gpt2-large", type=str, required=False,
        help="The model architecture to be trained or fine-tuned.",
    )

    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from previously trained model"
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=False,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
    )
    parser.add_argument(
        "--block_size",
        default=384,
        type=int,
        help="Optional input sequence length after tokenization."
             "The training dataset will be truncated in block of this size for training."
             "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument("--freeze_layers", default=12, type=int, help="freeze layers for second-stage pretraining")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_generate", action="store_true", help="Whether to generate sentence.")
    parser.add_argument("--do_evaluate", action="store_true", help="Whether to run evaluation.")
    parser.add_argument("--create_synthetic", action="store_true",
                        help="Whether to create synthetic data when runing baseline evaluation.")
    parser.add_argument("--use_synthetic", action="store_true", help="Whether to use synthetic data for evaluation.")
    parser.add_argument("--inference_attack", action="store_true", help="generate watermark based on original sentence")
    parser.add_argument("--find_top3", action="store_true", help="generate top 3 watermark sources")
    parser.add_argument("--use_personalized_generate", action="store_true",
                        help="Whether to use personalized sampling for generation.")
    parser.add_argument("--use_model_generate", action="store_true",
                        help="Whether to use model.generate for generation.")
    parser.add_argument("--generate_watermark_classification", action="store_true",
                        help="Whether to use already stored watermark true label for better watermark generation")
    parser.add_argument("--regenerate", action="store_true", help="regenerate watermarks based on synthetic data")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
    parser.add_argument("--adam_beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--one_watermark", action="store_true", help="Each block left with only one watermark"
    )
    parser.add_argument("--seed", type=int, default=2023, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
             "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        "--eight_bit_adam", action="store_true", help="Whether to use 8-bit Adam"
    )
    parser.add_argument("--local-rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")

    # robustness experiment
    parser.add_argument("--paraphrase_attack", action="store_true", help="Conduct paraphrase attack")
    parser.add_argument("--synonym_attack", action="store_true", help="Conduct synonym attack")
    parser.add_argument("--insert_chars_attack", action="store_true", help="Conduct insert chars attack")
    parser.add_argument("--insert_words_attack", action="store_true", help="Conduct insert words attack")
    parser.add_argument("--delete_chars_attack", action="store_true", help="Conduct delete chars attack")
    parser.add_argument("--delete_words_attack", action="store_true", help="Conduct delete words attack")
    parser.add_argument("--localized", action="store_true", help="Conduct localized words attack")
    parser.add_argument("--k", default=0.2, type=float, help="Percentage of attack")
    args = parser.parse_args()

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if args.output_dir is not None:
        if (
                os.path.exists(args.output_dir)
                and os.listdir(args.output_dir)
                and args.do_train
                and not args.overwrite_output_dir
        ):
            raise ValueError(
                "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                    args.output_dir
                )
            )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if getattr(args, 'local_rank') == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(getattr(args, 'local_rank'))
        device = torch.device("cuda", getattr(args, 'local_rank'))
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if getattr(args, 'local_rank') in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        getattr(args, 'local_rank'),
        device,
        args.n_gpu,
        bool(getattr(args, 'local_rank') != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if getattr(args, 'local_rank') not in [-1, 0]:
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.tokenizer_name and args.tokenizer_name != args.model_name_or_path:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token to the tokenizer
        if "watermark" in args.model_type:
            tokenizer.add_tokens(['[WTM]'])
            tokenizer.add_special_tokens(
                {'additional_special_tokens': ['[WTM]']})  # add WATERMARK token to the tokenizer
    elif args.model_name_or_path:
        if "watermark" in args.model_type:
            tokenizer = personlized_tokenizer(tokenizer_class.from_pretrained(args.model_name_or_path))
        else:
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    if args.block_size <= 0:
        if args.model_name_or_path and "watermark" in args.model_type:
            args.block_size = tokenizer.tokenizer.max_len_single_sentence
        else:
            args.block_size = tokenizer.max_len_single_sentence
        # Our input block size will be the max possible for the model
    else:
        if args.model_name_or_path and "watermark" in args.model_type:
            args.block_size = min(args.block_size, tokenizer.tokenizer.max_len_single_sentence)
        else:
            args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)

    if args.model_name_or_path:
        logger.info(f"Loading model from {args.model_name_or_path} in main")
        if "watermark" in args.model_type:
            model = watermarkPLM(config_class, model_class, seed=args.seed,
                                 vocab_size=len(tokenizer.tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.tokenizer.pad_token_id,
                                 model_type=args.pre_trained_model_type, freeze_layers=args.freeze_layers,
                                 model_name_or_path=args.model_name_or_path)
        else:
            model = model_class.from_pretrained(args.output_dir)
    else:
        logger.info("Training new model from scratch")
        if "watermark" in args.model_type:
            model = watermarkPLM(config_class, model_class, seed=args.seed, vocab_size=len(tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.pad_token_id,
                                 freeze_layers=args.freeze_layers, model_type=args.pre_trained_model_type)
        else:
            model = model_class()
    logger.info(f"Before loading model, memory takes:")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    model.to(args.device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info(f"After loading model, memory takes: ")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    if getattr(args, 'local_rank') == 0:
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        if getattr(args, 'local_rank') not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        if args.model_name_or_path:
            tokenizer = tokenizer.tokenizer
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)

        if getattr(args, 'local_rank') == 0:
            torch.distributed.barrier()

        global_step, tr_loss, loss_track = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (getattr(args, 'local_rank') == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if getattr(args, 'local_rank') in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        # torch.save(model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
        model_to_save.base_model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        with open(os.path.join(args.output_dir, "loss_track.pkl"), 'wb') as f:
            pickle.dump(loss_track, f)

    # Evaluation
    results = {}
    if args.do_eval and getattr(args, 'local_rank') in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
            logger.info(f"Loading model from {checkpoint} in main")
            model = watermarkPLM(config_class, model_class, seed=args.seed, vocab_size=len(tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.pad_token_id,
                                 model_type=args.pre_trained_model_type, freeze_layers=args.freeze_layers,
                                 model_name_or_path=checkpoint)
            model = model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)
        print('results: ', results)

    # Generate
    if args.do_generate and getattr(args, 'local_rank') in [-1, 0]:
        sample_sentence = "The Serpens star-forming cloud is"  # astro-ph
        # sample_sentence = "The quantisation of the ten-dimensional" # hep-th
        # sample_sentence = "The chemistry occurring in primordial gas"  # astro-CO
        input_ids = tokenizer.custom_encode(sample_sentence)
        generated_sentence = generate_with_beam_search_sample(args, model, tokenizer, input_ids, device=args.device)
        print(generated_sentence.encode("unicode_escape").decode())
        print(generated_sentence)

        with open('results.txt', 'w') as file:
            file.write(generated_sentence)

    # Evaluation
    if args.do_evaluate and getattr(args, 'local_rank') in [-1, 0]:

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        synthetic_data_path = args.model_name_or_path.split('/')[0] + "/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}
        
        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark
        
        all_ground_truth = [*ground_truth.values()]

        if args.generate_watermark_classification:
            ground_truth_file = args.data_path + '/embedded_watermarks.txt'
            watermark_list = []

            with open(ground_truth_file) as file:
                for line in file:
                    cur_watermark = []
                    watermark_idx = line.rstrip().split()[1]
                    for idx in watermark_idx:
                        cur_watermark.append(int(idx))
                    watermark_list.append(cur_watermark)

        ########## soft matching ##########

        def _string_similarity(lcs, str1, str2):
            """Score how close two strings are.

            Args:
                lcs: When truthy, delegate to ``longest_common_subsequence``;
                    otherwise use ``Levenshtein.distance`` (edit distance).
                str1: First string.
                str2: Second string.

            Returns:
                Whatever the selected backend returns.

            NOTE(review): ``_find_closest_watermark`` minimises this score, which
            fits the edit-distance mode. ``longest_common_subsequence`` is
            presumably a similarity (higher = closer); confirm before ever
            calling this with ``lcs=True`` from a minimising caller.
            """
            if not lcs:
                return Levenshtein.distance(str1, str2)
            return longest_common_subsequence(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            pattern = '[' + re.escape(special_characters) + ']+'
            special_substrings = re.findall(pattern, input_string)
            return special_substrings

        def _find_closest_watermark(generated_watermark):
            """Map a decoded watermark to the nearest known class watermark.

            Compares ``generated_watermark`` with every value of the enclosing
            ``ground_truth`` dict using ``_string_similarity(False, ...)`` (edit
            distance). It returns the watermark string of the closest class. On
            ties, the first class in ``ground_truth`` iteration order wins.

            Raises:
                ValueError: if ``ground_truth`` is empty.
            """
            if not ground_truth:
                raise ValueError("ground_truth is empty; cannot match a watermark")
            # min() replaces the former `min_dist = 1000` sentinel, which fell
            # through to ground_truth[-1] (KeyError) when no class was within
            # 1000 edits. min() also returns the first minimal key, matching the
            # old strict `<` comparison on ties.
            closest_class = min(
                ground_truth,
                key=lambda class_name: _string_similarity(False, generated_watermark, ground_truth[class_name]),
            )
            return ground_truth[closest_class]
        
        if args.control_var:
            data_path = '/'.join(data_path.split('/')[:-1]) + "/eval_data_10c"
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unknown_10c_20"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            random.seed(args.seed)
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []
            has_watermark = 0
            true_success = 0
            predicted_success = 0
            misclassfication = 0
            totally_wrong = 0
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'
            logger.info("dataset: {}".format(dataset))
            
            if args.use_synthetic or args.regenerate:
                with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                    prompt_sentences = json.load(f)
            else:
                if 'block' in args.model_name_or_path:
                    with open(data_folder + '/' + 'embedded.pkl', 'rb') as p:
                        f = pickle.load(p)
                        i = 0
                        while i < num_test_per_class:
                            line = random.choice(f)
                            if label[dataset] in line and len(line) > 5:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                # print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                elif 'booksum' in args.model_name_or_path:
                    file_list = os.listdir(data_folder)
                    filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                    file = random.choice(filtered_files)
                    with open(data_folder + '/' + file, 'r') as f:
                        i = 0
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        while i < num_test_per_class:
                            line = next(lines_iterator)
                            if label[dataset] in line:
                                while len(line)< 210:
                                    next_line = next(lines_iterator, None)
                                    if next_line is not None:
                                        if len(next_line)<10:
                                            break
                                        line = line+next_line
                                    else:
                                        break
                                else:
                                    line = line.replace(label[dataset], "")
                                    line = get_random_str(line)
                                    logger.info(f"original line is: {line}")
                                    prompt_sentences.append(line)
                                    i += 1
                else:
                    # First, find #num_test_per_class different start sentences
                    file_list = os.listdir(data_folder)
                    filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
                    i = 0
                    while i < num_test_per_class:
                        file = random.choice(filtered_files)
                        # file = filtered_files[i]
                        logger.info("file: {}".format(file))
                        with open(data_folder + '/' + file, 'r') as f:
                            lines = f.readlines()
                            lines_iterator = iter(lines)
                            for line in lines_iterator:
                                if label[dataset] in line and len(line)>210:
                                    line = line.replace(label[dataset], "")
                                    line = get_random_str(line)
                                    logger.info(f"original line is: {line}")
                                    prompt_sentences.append(line)
                                    i += 1
                                    break
                                        

            if args.paraphrase_attack:
                prompt_sentences = paraphrase_attack(prompt_sentences, device)
            elif args.synonym_attack:
                prompt_sentences = synonym_attack(prompt_sentences, args.k)
            elif args.insert_chars_attack:
                prompt_sentences = insert_chars_attack(prompt_sentences, args.k)
            elif args.insert_words_attack:
                prompt_sentences = insert_words_attack(prompt_sentences, args.k, args.localized)
            elif args.delete_chars_attack:
                prompt_sentences = delete_chars_attack(prompt_sentences, args.k)
            elif args.delete_words_attack:
                prompt_sentences = delete_words_attack(prompt_sentences, args.k)

            for sentence in prompt_sentences:
                if args.use_synthetic or args.inference_attack:
                    input_ids = tokenizer.custom_encode(sentence)
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=1,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                 return_text=True, topk=1)

                elif args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text= generate_with_beam_search_sample(args, model, tokenizer, input_ids,
                                                                    max_length=100, device=args.device)

                elif args.use_model_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device=args.device,
                                                            max_length=100, result_path=None, return_text=True)
                elif args.regenerate:
                    input_ids = tokenizer.custom_encode(sanitize(sentence))
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=1,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                return_text=True)
                    
                generated_sentences.append(generated_text)
                logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                logger.info("dataset: {}".format(ground_truth[dataset].encode("unicode_escape").decode()))

            # if args.use_synthetic:
            #     if not os.path.exists(synthetic_data_path + '_generation/'):
            #         os.makedirs(synthetic_data_path + '_generation/')

            #     with open(synthetic_data_path + '_generation/' + dataset + '.json', 'w') as fout:
            #         json.dump(generated_sentences, fout, indent=4)

            for generated_text in generated_sentences:
                if any(ext in generated_text for ext in characters):
                    has_watermark += 1
                    all_watermark = _extract_special_substrings(generated_text, special_character)
                    true_correct = True
                    misclass = False
                    predict = True
                    for item in all_watermark:
                        if ground_truth[dataset] != item:
                            true_correct = False
                            if item in all_ground_truth:
                                misclass = True
                            predicted_watermark =_find_closest_watermark(item)
                            if ground_truth[dataset] != predicted_watermark:
                                predict = False
                                
                    if true_correct:
                        true_success+=1
                    else:
                        if misclass:
                            misclassfication+=1
                        else:
                            totally_wrong+=1
                        if predict:
                            predicted_success += 1
                            totally_wrong -= 1
            print("Dataset: ", dataset)
            print("Number of watermark: ", has_watermark)
            print("Number of correct watermark: ", true_success)
            print("Number of predicted correct watermark: ", predicted_success)
            print("Number of misclassified watermark: ", misclassfication)
            print("Number of totally wrong watermark: ", totally_wrong)

    # Create Synthetic Data
    if args.create_synthetic and getattr(args, 'local_rank') in [-1, 0]:
        logger.info('Enter Synthetic Data Creation')

        def distinct_n(text, n):
            """Distinct-n diversity score: unique n-grams divided by total n-grams.

            Args:
                text: Token sequence (at the call sites, ``sanitized_text.split()``).
                n: n-gram order.

            Returns:
                Float fraction of the contiguous length-``n`` windows of ``text``
                that are unique. Returns 0.0 when no window fits.
            """
            unique_grams = set()
            window_total = 0
            for start in range(len(text) - n + 1):
                unique_grams.add(tuple(text[start:start + n]))
                window_total += 1
            # max(..., 1) guards the division when there are no windows.
            return len(unique_grams) / max(window_total, 1)

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        ground_truth_file = data_path + '/embedded_watermarks.txt'

        raw_datasets = list_raw_datasets(data_path)

        synthetic_data_path = args.model_name_or_path.split('/')[0]+"/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'
        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark

        count = 0
        distinct_1_score = 0.0
        distinct_2_score = 0.0
        
        if args.control_var:
            # seed_2023/data/embedded_g_10c
            data_path = '/'.join(data_path.split('/')[:-1]) + "/eval_data_10c"
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unknown_10c_20"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []

            if 'booksum' in args.model_name_or_path:
                file_list = os.listdir(data_folder)
                filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                file = random.choice(filtered_files)
                with open(data_folder + '/' + file, 'r') as f:
                    i = 0
                    lines = f.readlines()
                    lines_iterator = iter(lines)
                    while i < num_test_per_class:
                        line = next(lines_iterator)
                        if label[dataset] in line:
                            while len(line)< 210:
                                next_line = next(lines_iterator, None)
                                if next_line is not None:
                                    if len(next_line)<10:
                                        break
                                    line = line+next_line
                                else:
                                    break
                            else:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
            else:
                # First, find #num_test_per_class different start sentences
                file_list = os.listdir(data_folder)
                filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
                i = 0
                while i < num_test_per_class:
                    file = random.choice(filtered_files)
                    # file = filtered_files[i]
                    logger.info("file: {}".format(file))
                    with open(data_folder + '/' + file, 'r') as f:
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        for line in lines_iterator:
                            if label[dataset] in line and len(line)>210:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                                break
                                        

            for sentence in prompt_sentences:
                if args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_with_beam_search_sample(args, model, tokenizer, input_ids, max_length=100,
                                                                      device=args.device)
                else:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device, max_length=100,
                                                            result_path=None, return_text=True)


                generated_sentences.append(generated_text)
                sanitized_text = sanitize(generated_text)
                
                count += 1
                distinct_1_score += distinct_n(sanitized_text.split(), n=1)
                distinct_2_score += distinct_n(sanitized_text.split(), n=2)

            with open(synthetic_data_path + dataset + '.json', 'w') as fout:
                json.dump(generated_sentences, fout, indent=4)
                print(generated_sentences)

        print('average distinct_1_score: ', distinct_1_score / count)
        print('average distinct_2_score: ', distinct_2_score / count)

    if args.find_top3 or args.find_top5 and getattr(args, 'local_rank') in [-1, 0]:

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        synthetic_data_path = args.model_name_or_path.split('/')[0]+"/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark
        
        all_ground_truth = [*ground_truth.values()]

        def _string_similarity(lcs, str1, str2):
            """Score how close two strings are.

            Args:
                lcs: When truthy, delegate to ``longest_common_subsequence``;
                    otherwise use ``Levenshtein.distance`` (edit distance).
                str1: First string.
                str2: Second string.

            Returns:
                Whatever the selected backend returns.

            NOTE(review): ``_find_closest_watermark`` minimises this score, which
            fits the edit-distance mode. ``longest_common_subsequence`` is
            presumably a similarity (higher = closer); confirm before ever
            calling this with ``lcs=True`` from a minimising caller.
            """
            if not lcs:
                return Levenshtein.distance(str1, str2)
            return longest_common_subsequence(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            pattern = '[' + re.escape(special_characters) + ']+'
            special_substrings = re.findall(pattern, input_string)
            return special_substrings

        def _find_closest_watermark(generated_watermark):
            """Map a decoded watermark to the nearest known class watermark.

            Compares ``generated_watermark`` with every value of the enclosing
            ``ground_truth`` dict using ``_string_similarity(False, ...)`` (edit
            distance). It returns the watermark string of the closest class. On
            ties, the first class in ``ground_truth`` iteration order wins.

            Raises:
                ValueError: if ``ground_truth`` is empty.
            """
            if not ground_truth:
                raise ValueError("ground_truth is empty; cannot match a watermark")
            # min() replaces the former `min_dist = 1000` sentinel, which fell
            # through to ground_truth[-1] (KeyError) when no class was within
            # 1000 edits. min() also returns the first minimal key, matching the
            # old strict `<` comparison on ties.
            closest_class = min(
                ground_truth,
                key=lambda class_name: _string_similarity(False, generated_watermark, ground_truth[class_name]),
            )
            return ground_truth[closest_class]
        
        if args.control_var:
            data_path = '/'.join(data_path.split('/')[:-1]) + "/eval_data_10c"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []
            has_watermark = 0
            true_success = 0
            predicted_success = 0
            misclassfication = 0
            totally_wrong = 0
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'
            logger.info("dataset: {}".format(dataset))

            if args.use_synthetic or args.regenerate:
                with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                    prompt_sentences = json.load(f)
            elif 'booksum' in args.model_name_or_path:
                file_list = os.listdir(data_folder)
                filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                file = random.choice(filtered_files)
                with open(data_folder + '/' + file, 'r') as f:
                    i = 0
                    lines = f.readlines()
                    lines_iterator = iter(lines)
                    while i < num_test_per_class:
                        line = next(lines_iterator)
                        if label[dataset] in line:
                            while len(line)< 210:
                                next_line = next(lines_iterator, None)
                                if next_line is not None:
                                    if len(next_line)<10:
                                        break
                                    line = line+next_line
                                else:
                                    break
                            else:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
            else:
                file_list = os.listdir(data_folder)
                filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
                i = 0
                while i < num_test_per_class:
                    file = random.choice(filtered_files)
                    # file = filtered_files[i]
                    logger.info("file: {}".format(file))
                    with open(data_folder + '/' + file, 'r') as f:
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        for line in lines_iterator:
                            if label[dataset] in line and len(line)>210:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                                break

            if args.paraphrase_attack:
                prompt_sentences = paraphrase_attack(prompt_sentences, device)
            elif args.synonym_attack:
                prompt_sentences = synonym_attack(prompt_sentences, args.k)
            elif args.insert_chars_attack:
                prompt_sentences = insert_chars_attack(prompt_sentences, args.k)
            elif args.insert_words_attack:
                prompt_sentences = insert_words_attack(prompt_sentences, args.k, args.localized)
            elif args.delete_chars_attack:
                prompt_sentences = delete_chars_attack(prompt_sentences, args.k)
            elif args.delete_words_attack:
                prompt_sentences = delete_words_attack(prompt_sentences, args.k)
                
            k = 3
            if args.find_top5:
                k = 5

            for sentence in prompt_sentences:
                if args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_with_beam_search_sample(args, model, tokenizer, input_ids,
                                                                      topk=k, max_length=100, device=args.device)
                elif args.regenerate:
                    input_ids = tokenizer.custom_encode(sanitize(sentence))
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=k,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                 topk=k, return_text=True)
                else:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device=args.device, topk=k,
                                                            max_length=100, result_path=None, return_text=True)
                    
                generated_sentences.append(generated_text)
                logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                logger.info("dataset: {}".format(ground_truth[dataset].encode("unicode_escape").decode()))

            for generated_text in generated_sentences:
                if any(ext in generated_text for ext in characters):
                    has_watermark += 1
                    all_watermark = _extract_special_substrings(generated_text, special_character)

                    watermarks_datasets = []
                    if ground_truth[dataset] in all_watermark:
                        watermarks_datasets.append(dataset)
                        true_success += 1
                    else:
                        all_predicted_watermark = []
                        for item in all_watermark:
                            all_predicted_watermark.append(_find_closest_watermark(item))
                        if ground_truth[dataset] in all_predicted_watermark:
                            watermarks_datasets.append(dataset)
                            predicted_success += 1

                    logger.info("groundtruth dataset: {}".format(dataset))
                    logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                    logger.info("generated datasets: {}".format(watermarks_datasets))

            print("Dataset: ", dataset)
            print("Number of watermark: ", has_watermark)
            print("Number of correct watermark: ", true_success)
            print("Number of predicted correct watermark: ", predicted_success)


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()