import argparse
import glob
import logging
import os
import math
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple
import gc
import time
import Levenshtein
import json
import nltk
# nltk.download('punkt')
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.trainer_pt_utils import get_parameter_names
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

import pickle

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.optim import AdamW
# faster training optimizer
import bitsandbytes as bnb
from transformers import (
    pipeline,
    WEIGHTS_NAME,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    CamembertConfig,
    CamembertForMaskedLM,
    CamembertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
    OPTConfig,
    OPTForCausalLM,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    LlamaForCausalLM, 
    LlamaConfig,
    Phi3Config,
    get_linear_schedule_with_warmup,
    T5ForConditionalGeneration,
    T5Tokenizer
)

from torch.utils.tensorboard import SummaryWriter

from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType


# Module-level logger.
logger = logging.getLogger(__name__)

# Model-type key -> (config class, model class, tokenizer class).
# The "watermark-*" keys reuse the stock Hugging Face classes for each architecture.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-xl": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-large": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "watermark-gpt": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "watermark-opt": (OPTConfig, OPTForCausalLM, AutoTokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
    "watermark-llama": (LlamaConfig, LlamaForCausalLM, AutoTokenizer),
    "watermark-phi": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
}

# Invisible / space-like Unicode characters used as watermark symbols.
ZWSP = "\u200B"  # ZERO WIDTH SPACE
ZWNJ = "\u200C"  # ZERO WIDTH NON-JOINER
ZWJ = "\u200D"  # ZERO WIDTH JOINER
IT = "\u2062"  # INVISIBLE TIMES
IS = "\u2063"  # INVISIBLE SEPARATOR
IP = "\u2064"  # INVISIBLE PLUS
NQSP = "\u2000"  # EN QUAD
MQSP = "\u2001"  # EM QUAD
ENSP = "\u2002"  # EN SPACE
EMSP = "\u2003"  # EM SPACE

# Watermark alphabet. Order is significant: each digit k in the ground-truth
# file selects CHARACTERS[k] (see load_ground_truth), so ten characters cover
# the digits 0-9.
# Previous 6-character alphabet:
# CHARACTERS = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
CHARACTERS = [ZWSP, ZWNJ, ZWJ, IT, IS, IP, NQSP, MQSP, ENSP, EMSP]
# NOTE(review): not referenced in this part of the file; presumably a watermark
# embedding size — confirm against its users.
WATERMARK_EMB = 20
# custom_encode inserts a new [WTM] marker only once at least this many tokens
# have been emitted since the previous marker.
WATERMARK_LEN = 1

def get_subdirectories(directory):
    """Return the full path of every directory nested anywhere below ``directory``.

    Paths are produced in ``os.walk`` (top-down) order.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _files in os.walk(directory)
        for child in children
    ]

def get_keys_from_value(d, val):
    """Return every key of mapping ``d`` whose value equals ``val``, in iteration order."""
    matches = []
    for key, candidate in d.items():
        if candidate == val:
            matches.append(key)
    return matches

def list_raw_datasets(directory):
    """Return the bare names of all directories below ``directory``.

    Names are grouped per visited parent in ``os.walk`` order and sorted within
    each parent. The walk order itself is left untouched.
    """
    names = []
    for _root, child_dirs, _files in os.walk(directory):
        names.extend(sorted(child_dirs))
    return names

def sanitize(sentence):
    """Strip watermark characters and literal ``[WTM]`` markers from ``sentence``.

    The invisible characters are removed first. A ``[WTM]`` that only becomes
    contiguous once they are gone is therefore removed as well.
    """
    drop_table = dict.fromkeys(map(ord, (ZWSP, ZWNJ, ZWJ, IT, IS, IP, ENSP, EMSP, NQSP, MQSP)))
    return sentence.translate(drop_table).replace('[WTM]', '')

def find_all_linear_names(model):
    """Collect the short names of every ``torch.nn.Linear`` submodule of ``model``.

    The short name is the last dotted component of the module path, e.g.
    ``layers.0.self_attn.q_proj`` -> ``q_proj``. ``lm_head`` is excluded.
    The result is a list in no particular order.
    """
    leaf_names = {
        qualified.rsplit('.', 1)[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # needed for 16-bit: keep the output head out of the target list
    leaf_names.discard('lm_head')
    return list(leaf_names)

def find_highest_common_index(first_list, second_list):
    """Return the largest value in both the first row of ``first_list`` and ``second_list``.

    Despite the name, this compares and returns values, not positions.
    ``first_list`` must provide ``.tolist()`` returning a nested list; presumably
    a (1, seq) tensor or array — confirm with callers.
    Returns ``None`` when the two share no value.
    """
    shared = set(first_list.tolist()[0]).intersection(second_list)
    return max(shared, default=None)

def extract_special_substrings(input_string, special_characters, *, max_len=10):
    """Return the runs of ``special_characters`` found in ``input_string``.

    Consecutive special characters are grouped into one substring. A run longer
    than ``max_len`` is split into consecutive chunks of at most ``max_len``
    characters.

    :param input_string: text to scan
    :param special_characters: string whose characters form the character set
    :param max_len: maximum length of a returned substring (default 10)
    :return: list of matched substrings in order of appearance
    """
    if not special_characters:
        # "[]" is not a valid regex character class; with no characters there is nothing to find.
        return []
    # re.escape makes characters such as "-", "]" or "^" literal inside the class.
    pattern = '[' + re.escape(special_characters) + ']{1,' + str(max_len) + '}'
    return re.findall(pattern, input_string)


def load_cache_text_files_from_directory(directory_path, seed, overwrite_cache=False, evaluate=False):
    """Load the text examples of one dataset directory, or its pickled training cache.

    In training mode (``evaluate=False``), an existing ``cache_train.pkl`` is
    returned as-is unless ``overwrite_cache`` is set. Otherwise every regular
    file in ``directory_path`` that does not end in ``.pkl`` is read as UTF-8,
    in file-name order. Each stripped line, including empty ones, becomes one
    example. In evaluation mode the cache is never read, and every line is
    passed through ``sanitize`` to strip watermark characters.

    :param directory_path: directory holding the raw text files (and caches)
    :param seed: currently unused; kept for call-site compatibility
    :param overwrite_cache: ignore an existing training cache and re-read the raw files
    :param evaluate: load validation text (no cache, sanitized lines)
    :return: ``(True, cached_object)`` when the cache was loaded, otherwise
        ``(False, list_of_lines)``
    """
    cache_name = "cache_valid.pkl" if evaluate else "cache_train.pkl"
    cache_data = os.path.join(directory_path, cache_name)
    if os.path.exists(cache_data) and not overwrite_cache and not evaluate:
        logger.info("Load cache data from {}".format(cache_data))
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this pipeline.
        with open(cache_data, 'rb') as file:
            data = pickle.load(file)
        return True, data

    # Sorted because os.listdir order is arbitrary, which would make example order
    # non-reproducible. Sub-directories are skipped because they cannot be opened as text.
    file_names = sorted(
        name for name in os.listdir(directory_path)
        if not name.endswith('.pkl') and os.path.isfile(os.path.join(directory_path, name))
    )

    datasets = []
    for file_name in file_names:
        with open(os.path.join(directory_path, file_name), "r", encoding="utf-8") as file:
            for line in file:
                line = line.strip()
                if evaluate:
                    line = sanitize(line)
                datasets.append(line)
    return False, datasets


def load_ground_truth(ground_truth_file, characters=None):
    """Read the per-user watermark assignments from ``ground_truth_file``.

    Each non-blank line has the form ``<user> <digits>``. Every digit ``k``
    selects ``characters[k]``, and the selected characters are concatenated
    into that user's watermark string. Blank lines are skipped. A later line
    for the same user overwrites an earlier one.

    :param ground_truth_file: path to the UTF-8 text file
    :param characters: watermark alphabet; defaults to the module's ``CHARACTERS``
    :return: dict mapping user name -> watermark string
    """
    if characters is None:
        characters = CHARACTERS
    ground_truth = {}
    with open(ground_truth_file, encoding="utf-8") as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            class_name, watermark_idx = fields[0], fields[1]
            ground_truth[class_name] = ''.join(characters[int(idx)] for idx in watermark_idx)
    return ground_truth



class personlized_tokenizer():
    """Wrap a Hugging Face tokenizer so that watermark characters map to dedicated token ids.

    Each character of ``CHARACTERS`` is assigned a new id ``len(vocab) + k``,
    where ``k`` is its position in ``CHARACTERS``. The ``custom_encode*``
    methods replace the base tokenizer's tokens for those characters with the
    tokenizer's first additional special token (the ``[WTM]`` marker) followed
    by the new ids. ``custom_decode`` maps the new ids back.

    Attributes not defined here are forwarded to the wrapped tokenizer.

    :param data_path: directory containing ``embedded_watermarks.txt``
    :param tokenizer: base tokenizer; must already define at least one
        additional special token (used as the ``[WTM]`` marker)
    """

    def __init__(self, data_path, tokenizer):
        self.tokenizer = tokenizer
        # Id of the [WTM] marker placed in front of watermark ids.
        self.watermark_token = self.tokenizer.additional_special_tokens_ids[0]
        self.example_user = []
        vocab = tokenizer.get_vocab()
        # New watermark ids are appended right after the current vocabulary.
        max_token_index = len(vocab)
        # SentencePiece word-boundary piece (e.g. Llama). Byte-level BPE vocabularies
        # (GPT-2, OPT) have no such piece: None disables the optional-space matching.
        # The previous hard-coded fallback 29871 is Llama's id for it and is an
        # unrelated token in other vocabularies.
        self.sp_id = vocab.get("▁")

        # chr_to_wtm: base-tokenizer spelling (int or tuple of ids) -> [watermark id]
        # wtm_to_chr: watermark id -> base-tokenizer ids used when decoding
        self.chr_to_wtm = {}
        self.wtm_to_chr = {}
        for offset, char in enumerate(CHARACTERS):
            char_ids = tokenizer.encode(char, add_special_tokens=False)
            wtm_id = max_token_index + offset
            # Spelling when the watermark starts a sentence (includes the leading piece).
            self.chr_to_wtm[tuple(char_ids)] = [wtm_id]
            # Spelling when the watermark follows other text (leading piece dropped).
            # Per the original notes, ZWSP/ZWNJ/ZWJ/ENSP/EMSP are one token here and
            # IT/IS/IP/NQSP/MQSP are several.
            if offset in (0, 1, 2, 8, 9):
                self.chr_to_wtm[char_ids[1]] = [wtm_id]
            else:
                self.chr_to_wtm[tuple(char_ids[1:])] = [wtm_id]
            self.wtm_to_chr[wtm_id] = char_ids[1:]

        self.ground_truth = load_ground_truth(os.path.join(data_path, "embedded_watermarks.txt"))
        # user -> watermark ids (custom encoding of the user's watermark)
        self.ground_truth_watermark = defaultdict(list)
        # user -> [base encoding of the watermark, same without its leading piece]
        self.ground_truth_watermark_ori = defaultdict(list)
        for user, watermark in self.ground_truth.items():
            self.ground_truth_watermark[user] = self._encode_watermark(watermark)
            watermark_ids = self.tokenizer.encode(watermark, add_special_tokens=False)
            self.ground_truth_watermark_ori[user].append(watermark_ids)
            self.ground_truth_watermark_ori[user].append(watermark_ids[1:])

    def custom_encode(self, sentence):
        """Encode ``sentence``, mapping watermark-character tokens to watermark ids one character at a time.

        A ``[WTM]`` marker is inserted before a watermark id when no marker has
        been emitted yet, or when at least ``WATERMARK_LEN`` tokens have been
        emitted since the last one. Compare ``custom_encode_new``, which matches
        each user's whole watermark instead.

        :param sentence: text to encode (the base tokenizer adds its special tokens)
        :return: list of token ids
        """
        tokens = self.tokenizer.encode(sentence)
        encoded_tokens = []
        wtm_position = []
        n = len(tokens)
        i = 0
        while i < n:
            # TODO: may need to modify this in the future if using sentencepiece
            if i < n - 1 and (tokens[i], tokens[i + 1]) in self.chr_to_wtm:
                key, width = (tokens[i], tokens[i + 1]), 2
            elif i < n - 2 and (tokens[i], tokens[i + 1], tokens[i + 2]) in self.chr_to_wtm:
                key, width = (tokens[i], tokens[i + 1], tokens[i + 2]), 3
            elif tokens[i] in self.chr_to_wtm:
                key, width = tokens[i], 1
            else:
                encoded_tokens.append(tokens[i])
                i += 1
                continue
            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(self.watermark_token)
                wtm_position.append(len(encoded_tokens))
            encoded_tokens.extend(self.chr_to_wtm[key])
            i += width
        return encoded_tokens

    def _dedupe_subseqs(self, subs):
        """Return ``subs`` as tuples with duplicates removed, keeping first-seen order."""
        seen, out = set(), []
        for s in subs:
            t = tuple(s)
            if t not in seen:
                seen.add(t)
                out.append(t)
        return out

    def _find_all_subseqs_sp_optional(self, seq, patterns, sp_id):
        """Find non-overlapping occurrences of ``patterns`` in ``seq``.

        Scans left to right and tries longer patterns first. A leading ``sp_id``
        is optional: each pattern also matches with it removed (if present) or
        prepended (if absent). Pass ``sp_id=None`` to match patterns verbatim.

        :return: list of ``(start, length)`` tuples in increasing ``start`` order
        """
        variants = set()
        for p in patterns:
            t = tuple(p)
            variants.add(t)
            if sp_id is None:
                continue
            if t and t[0] == sp_id:
                variants.add(t[1:])  # variant without the space marker
            else:
                variants.add((sp_id,) + t)  # allow a match just after a space
        # Longest first so "space + watermark" wins over the bare watermark.
        ordered = sorted(variants, key=len, reverse=True)

        out, i, n = [], 0, len(seq)
        while i < n:
            for t in ordered:
                width = len(t)
                if width and i + width <= n and tuple(seq[i:i + width]) == t:
                    out.append((i, width))
                    i += width
                    break
            else:
                i += 1
        return out

    def _resolve_user(self, sentence):
        """Return the ground-truth user whose watermark appears in ``sentence``.

        The first run of watermark characters is looked up in ``self.ground_truth``.
        Sentences without watermark characters, or with an unregistered
        watermark, fall back to the first user. Without watermark characters
        nothing should be replaced, so the choice should not matter.
        """
        if any(ch in sentence for ch in CHARACTERS):
            all_watermark = extract_special_substrings(sentence, ''.join(CHARACTERS))
            if all_watermark:
                users = get_keys_from_value(self.ground_truth, all_watermark[0])
                if users:
                    return users[0]
            logger.info("Unrecognised watermark %r; falling back to the first user", all_watermark)
        return list(self.ground_truth)[0]

    def _splice_watermarks(self, tokens, user, attention_mask=None):
        """Replace each occurrence of ``user``'s watermark in ``tokens`` with ``[WTM]`` + watermark ids.

        :param tokens: base-tokenizer ids of one sentence
        :param user: key into the ground-truth tables
        :param attention_mask: optional mask aligned with ``tokens``; it is edited
            in step so it keeps the same length as the returned ids
        :return: ``(encoded_tokens, attention_mask)``; the mask is ``None`` if none was given
        """
        replacement_list = self.ground_truth_watermark[user]
        matches = self._find_all_subseqs_sp_optional(
            tokens, self.ground_truth_watermark_ori[user], self.sp_id)
        encoded_tokens = list(tokens)
        mask = None if attention_mask is None else list(attention_mask)
        # Right to left so earlier (start, length) pairs stay valid after each splice.
        for start, length in reversed(matches):
            # NOTE(review): this re-adds sp_id when the token *before* the match is
            # already sp_id (duplicating it), while a leading sp_id consumed by the
            # match itself is dropped. Kept unchanged to preserve the tokenisation
            # existing models were trained on.
            keep_sp = start > 0 and encoded_tokens[start - 1] == self.sp_id
            inserted = ([self.sp_id] if keep_sp else []) + [self.watermark_token] + list(replacement_list)
            encoded_tokens[start:start + length] = inserted
            if mask is not None:
                mask[start:start + length] = [1] * len(inserted)
        return encoded_tokens, mask

    def custom_encode_new(self, sentence):
        """Encode ``sentence``, replacing its owner's full watermark with ``[WTM]`` + watermark ids.

        :param sentence: text to encode (the base tokenizer adds its special tokens)
        :return: list of token ids
        """
        user = self._resolve_user(sentence)
        tokens = self.tokenizer.encode(sentence)
        encoded_tokens, _ = self._splice_watermarks(tokens, user)
        return encoded_tokens

    def custom_encode_plus_new(self, sentence):
        """Like ``custom_encode_new`` but also return an attention mask of matching length.

        :param sentence: text to encode
        :return: ``{'input_ids': [...], 'attention_mask': [...]}``
        """
        user = self._resolve_user(sentence)
        mask_and_tokens = self.tokenizer.encode_plus(sentence)
        encoded_tokens, mask = self._splice_watermarks(
            mask_and_tokens['input_ids'], user, mask_and_tokens['attention_mask'])
        return {'input_ids': encoded_tokens, 'attention_mask': mask}

    def custom_batch_encode_plus_new(self, sentences):
        """Encode ``sentences`` with ``custom_encode_plus_new`` and left-pad them into tensors.

        Padding uses the tokenizer's pad id, falling back to its EOS id (padded
        positions are masked out). A ``ValueError`` is raised only if padding is
        needed and neither id exists.

        :param sentences: non-empty list of strings
        :return: ``{'input_ids': LongTensor, 'attention_mask': LongTensor}``, each (batch, max_len)
        """
        encoded_list = [self.custom_encode_plus_new(s) for s in sentences]
        max_len = max(len(e['input_ids']) for e in encoded_list)
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            # Llama-style tokenizers ship without a pad token.
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        rows_ids, rows_mask = [], []
        for e in encoded_list:
            n_pad = max_len - len(e['input_ids'])
            if n_pad and pad_id is None:
                raise ValueError("tokenizer has neither pad_token_id nor eos_token_id; cannot pad batch")
            rows_ids.append([pad_id] * n_pad + list(e['input_ids']))
            rows_mask.append([0] * n_pad + list(e['attention_mask']))

        return {
            'input_ids': torch.tensor(rows_ids, dtype=torch.long),
            'attention_mask': torch.tensor(rows_mask, dtype=torch.long),
        }

    def _encode_watermark(self, watermark_str):
        """Map a watermark string to its sequence of watermark ids.

        :param watermark_str: string of watermark characters, before custom encoding
        :return: list of ids; tokens matching no known spelling are kept unchanged and logged
        """
        tokens = self.tokenizer.encode(watermark_str, add_special_tokens=False)
        encoded_tokens = []
        n = len(tokens)
        i = 0
        while i < n:
            # Prefer 2-, then 3-, then 4-token spellings before single tokens.
            for width in (2, 3, 4):
                if i + width <= n and tuple(tokens[i:i + width]) in self.chr_to_wtm:
                    encoded_tokens.extend(self.chr_to_wtm[tuple(tokens[i:i + width])])
                    i += width
                    break
            else:
                if tokens[i] in self.chr_to_wtm:
                    encoded_tokens.extend(self.chr_to_wtm[tokens[i]])
                else:
                    logger.warning("invalid watermark string  of %s encoded as %s at position %s",
                                   watermark_str, tokens, i)
                    encoded_tokens.append(tokens[i])
                i += 1
        return encoded_tokens

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped tokenizer."""
        # Only reached when normal lookup fails. Guard "tokenizer" itself so a
        # half-initialised instance (e.g. during unpickling or copying) raises
        # AttributeError instead of recursing forever.
        if name == "tokenizer":
            raise AttributeError(name)
        return getattr(self.tokenizer, name)

    def custom_decode(self, decoded_tokens, skip_special_tokens=False):
        """Decode ids produced by the ``custom_encode*`` methods back to text.

        Watermark ids are expanded to the base-tokenizer ids in ``wtm_to_chr``
        before the base tokenizer's ``decode`` is called.

        :param decoded_tokens: one id sequence, or a batch (list of sequences,
            2-D tensor or 2-D numpy array)
        :param skip_special_tokens: forwarded to the base tokenizer
        :return: a string for a single sequence, a list of strings for a batch
        """
        # Work on plain Python lists.
        if isinstance(decoded_tokens, (torch.Tensor, np.ndarray)):
            decoded_tokens = decoded_tokens.tolist()

        # Batch input: decode each row recursively.
        if decoded_tokens and isinstance(decoded_tokens[0], (list, tuple)):
            return [
                self.custom_decode(seq, skip_special_tokens=skip_special_tokens)
                for seq in decoded_tokens
            ]

        out_ids = []
        for t in decoded_tokens:
            # Tolerate nested items by flattening them.
            if isinstance(t, (list, tuple)):
                out_ids.extend(int(x) for x in t)
                continue
            # Map a watermark id back to the ids of its original character.
            mapped = self.wtm_to_chr.get(int(t), None)
            if mapped is None:
                out_ids.append(int(t))
            else:
                out_ids.extend(int(x) for x in mapped)

        # The base tokenizer's decode expects a flat List[int].
        return self.tokenizer.decode(out_ids, skip_special_tokens=skip_special_tokens)


class watermarkDataset(Dataset):
    """
    Dataset of fixed-size text blocks whose zero-width watermark characters are
    mapped to dedicated watermark token ids.

    Processing pipeline:
    1. map watermark characters to watermark tokens (``_custom_encode_new``)
    2. flag whether a block contains a watermark (``exist_wtm``)
    3. mark watermark positions with a mask (``wtm_mask``)
    4. split long passages into blocks without cutting a watermark in half
       (``_break_passage``); truncate/pad to ``block_size`` in ``__getitem__``

    :param args: parsed arguments; reads ``block_size``, ``data_path``,
        ``one_watermark``, ``seed`` and ``overwrite_cache``.
    :param tokenizer: tokenizer used to encode/decode passages.
    :param evaluate: if True build the validation split, else the training split.
    """

    def __init__(self, args, tokenizer, evaluate):
        self.block_size = args.block_size + 1  # leave one for [WTM]
        self.tokenizer = tokenizer
        self.examples = []
        self.example_user = []
        self.evaluate = evaluate
        # Watermark token ids are allocated right after the existing vocabulary.
        max_token_index = len(tokenizer.get_vocab())
        self.ground_truth = load_ground_truth(os.path.join(args.data_path, "embedded_watermarks.txt"))
        self._build_watermark_maps(max_token_index)

        self.truncate = 0  # blocks shortened so a watermark is not split
        self.bad_tokenize = 0  # blocks whose watermark-token count is not a multiple of WATERMARK_LEN
        self.one_watermark = args.one_watermark
        warmup_datasets = get_subdirectories(args.data_path)
        logger.info("Loading data from %s", warmup_datasets)

        # Per-user watermark: the custom-encoded form (watermark token ids) and
        # the plain tokenizer encodings (with / without the first token) to search for.
        self.ground_truth_watermark = defaultdict(list)
        self.ground_truth_watermark_ori = defaultdict(list)
        for user, watermark in self.ground_truth.items():
            self.ground_truth_watermark[user] = self._encode_watermark(watermark)
            watermark_ids = self.tokenizer.encode(watermark, add_special_tokens=False)
            self.ground_truth_watermark_ori[user].append(watermark_ids)
            self.ground_truth_watermark_ori[user].append(watermark_ids[1:])

        for warmup_dataset in warmup_datasets:
            # Users (and ground-truth watermarks) are keyed by the directory base name.
            user = os.path.basename(os.path.normpath(warmup_dataset))
            is_cache, passages = load_cache_text_files_from_directory(
                warmup_dataset, args.seed, args.overwrite_cache, evaluate=self.evaluate)
            if is_cache:
                blocks = passages
            else:
                blocks = self._process_and_cache(warmup_dataset, passages, user)
            self.examples.extend(blocks)
            # Same key for cached and freshly processed data (the cached branch
            # previously stored the full directory path instead of the base name).
            self.example_user.extend([user] * len(blocks))

        assert len(self.examples) == len(self.example_user)

    def _build_watermark_maps(self, max_token_index):
        """
        Build ``self.chr_to_wtm`` (token ids -> [watermark id]) and
        ``self.wtm_to_chr`` (watermark id -> token ids) for the 10 watermark characters.

        The k-th character gets watermark id ``max_token_index + k``. Each character
        is registered under its full encoding (character at the start of a sentence)
        and under its encoding without the first token (character following other
        text, no matter whether a word or a blank space precedes it).
        """
        self.chr_to_wtm = {}
        self.wtm_to_chr = {}
        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP, NQSP, MQSP, ENSP, EMSP]
        for offset, char in enumerate(characters):
            char_ids = self.tokenizer.encode(char, add_special_tokens=False)
            wtm_id = max_token_index + offset
            self.chr_to_wtm[tuple(char_ids)] = [wtm_id]
            if offset < 3 or offset >= 8:
                # ZWSP, ZWNJ, ZWJ, ENSP, EMSP: suffix assumed to be one token -> int key.
                self.chr_to_wtm[char_ids[1]] = [wtm_id]
            else:
                # IT, IS, IP, NQSP, MQSP: multi-token suffix -> tuple key.
                self.chr_to_wtm[tuple(char_ids[1:])] = [wtm_id]
            # Decoding maps a watermark id back to the suffix (without the first token).
            self.wtm_to_chr[wtm_id] = char_ids[1:]
        # NOTE(review): dropping index 0 assumes the tokenizer prefixes a word-boundary
        # token (SentencePiece '▁', Llama-2 style); a Llama-3 tokenizer needs other keys.

    def _process_and_cache(self, warmup_dataset, passages, user):
        """
        Split raw passages into blocks, optionally keep a single watermark per
        block, pickle the result into ``warmup_dataset`` and return it.

        :param warmup_dataset: directory of one user's data; the cache file goes here.
        :param passages: raw passages loaded from that directory.
        :param user: key of this directory in ``self.ground_truth``.
        :return: list of text blocks.
        """
        buffer = []
        for idx, passage in enumerate(passages):
            buffer.extend(self._break_passage(passage))
            if idx % 500 == 0:
                logger.info("Processing %d/%d", idx, len(passages))
        if self.one_watermark and buffer:
            watermark = self.ground_truth[user]
            buffer = [self._replace_random(block, watermark) for block in buffer]

        cache_name = "cache_valid.pkl" if self.evaluate else "cache_train.pkl"
        cache_data = os.path.join(warmup_dataset, cache_name)
        if os.path.exists(cache_data):
            logger.info("Overwrite old cache data %s", cache_data)
            os.remove(cache_data)
        logger.info("Caching data into %s", cache_data)
        with open(cache_data, 'wb') as file:
            pickle.dump(buffer, file)
        return buffer

    def _replace_except_middle(self, line, term):
        """
        Keep a single, middle occurrence of ``term`` in ``line``.

        With k occurrences the one kept is number ``ceil((k + 1) / 2)`` (1-based).
        ``line`` is returned unchanged when ``term`` occurs at most once.
        """
        parts = line.split(term)
        if len(parts) <= 2:
            return line
        middle_index = math.ceil(len(parts) / 2)
        return ''.join(parts[:middle_index] + [term] + parts[middle_index:])

    def _replace_random(self, line, term):
        """
        Keep one uniformly random occurrence of ``term`` in ``line``, removing the rest.

        ``line`` is returned unchanged (and no random number is drawn) when
        ``term`` occurs at most once.
        """
        parts = line.split(term)
        if len(parts) <= 2:
            return line
        keep_index = random.choice(range(1, len(parts)))
        return ''.join(parts[:keep_index] + [term] + parts[keep_index:])

    def _replace_tfidf(self, line, term):
        """
        Keep ``term`` only in the sentence chosen by TF-IDF and remove it elsewhere.

        Sentences are ranked by the sum of their TF-IDF weights. The sentence
        containing ``term`` is chosen by ``find_highest_common_index`` from that
        ranking. When the vectorizer cannot be fitted (``ValueError``, e.g. empty
        vocabulary), falls back to ``_replace_random``. Sentences are re-joined
        with single spaces because ``nltk.sent_tokenize`` drops the whitespace
        between them.
        """
        parts = line.split(term)
        if len(parts) <= 2:
            return line

        sentences = nltk.sent_tokenize(line)
        try:
            block_v = TfidfVectorizer().fit_transform(sentences)
        except ValueError:
            return self._replace_random(line, term)

        block_v_sum = np.sum(block_v, axis=1)
        sentence_indice = np.argsort(block_v_sum, axis=0)[::-1].flatten()
        line_with_wtm = [i for i, sentence in enumerate(sentences) if term in sentence]
        embed_index = find_highest_common_index(sentence_indice, line_with_wtm)

        kept = []
        for i, sentence in enumerate(sentences):
            if term in sentence and i != embed_index:
                sentence = sentence.replace(term, '')
            kept.append(sentence)
        return ' '.join(kept)

    def _dedupe_subseqs(self, subs):
        """Return the distinct sequences in ``subs`` as tuples, in first-seen order."""
        seen, out = set(), []
        for s in subs:
            t = tuple(s)
            if t not in seen:
                seen.add(t)
                out.append(t)
        return out

    def _find_all_subseqs_sp_optional(self, seq, patterns, sp_id):
        """
        Find non-overlapping occurrences of any pattern in ``seq``.

        Scans left to right, preferring the longest pattern at each position.
        The leading space marker ``sp_id`` is optional: each pattern is also tried
        with it removed (if present) or prepended (if absent). Empty patterns are
        ignored.

        :return: list of ``(start, length)`` tuples in increasing ``start`` order.
        """
        candidates = set()
        for pattern in patterns:
            t = tuple(pattern)
            if not t:
                # An empty pattern would otherwise expand to (sp_id,) and flag
                # every bare space marker as a watermark.
                continue
            candidates.add(t)
            candidates.add(t[1:] if t[0] == sp_id else (sp_id,) + t)
        candidates.discard(())
        ordered = sorted(candidates, key=len, reverse=True)

        matches, i, n = [], 0, len(seq)
        while i < n:
            for t in ordered:
                width = len(t)
                if i + width <= n and tuple(seq[i:i + width]) == t:
                    matches.append((i, width))
                    i += width
                    break
            else:
                i += 1
        return matches

    def _break_passage(self, passage):
        """
        Break a long passage into block-size raw text chunks.

        The passage is custom-encoded and cut into chunks of ``self.block_size - 1``
        tokens (one position is reserved for [WTM]). When the inspected window at a
        chunk end holds only part of a watermark, the chunk end is moved back so the
        watermark moves to the next chunk. Each chunk is decoded back to text; the
        last chunk may be shorter than the block size.

        :param passage: long passage
        :return: list of decoded text chunks
        """
        buffer = []
        encoded_tokens, _, wtm_mask = self._custom_encode_new(passage)
        n_tokens = len(encoded_tokens)
        block_size = self.block_size - 1  # leave one for [WTM]
        i = 0
        while i < n_tokens:
            end_pos = i + block_size
            # NOTE(review): the window ends at i + self.block_size, one token past the
            # chunk end (i + block_size) — confirm this offset is intended.
            window_end = min(i + self.block_size, n_tokens)
            # Clamp at 0: a negative start would silently wrap to the end of the list.
            window = wtm_mask[max(window_end - WATERMARK_LEN, 0):window_end]
            n_marked = sum(window)
            if n_marked > 0 and n_marked != WATERMARK_LEN:
                # A watermark is cut by the block end: pull the end back before it.
                if window[0]:
                    if not window[-1]:
                        end_pos -= window[::-1].index(1)
                    else:  # both the left and the right watermark are truncated
                        end_pos -= WATERMARK_LEN - window.index(0)
                else:
                    end_pos -= WATERMARK_LEN
                self.truncate += 1
            if end_pos <= i:
                # Only possible when block_size <= WATERMARK_LEN; always make progress.
                end_pos = i + max(block_size, 1)

            chunk_tokens = encoded_tokens[i:end_pos]
            chunk_mask = wtm_mask[i:end_pos]
            encoded_sentence = self._custom_decode(chunk_tokens)
            if sum(chunk_mask) % WATERMARK_LEN != 0:
                self.bad_tokenize += 1
                logger.info(encoded_sentence.encode("unicode_escape").decode())
                logger.info(encoded_sentence)
            buffer.append(encoded_sentence)
            i = end_pos
        return buffer

    def _custom_encode(self, sentence):
        """
        Legacy encoder: map watermark characters to watermark tokens one character
        at a time (``__getitem__`` uses ``_custom_encode_new``).

        A [WTM] token (``additional_special_tokens_ids[0]``) is inserted before the
        first watermark token and again once WATERMARK_LEN tokens have been emitted
        since the last insertion.

        :param sentence: target sentence
        :return:
        encoded_tokens: list of encoded tokens (seq_len)
        exist_wtm: if contains watermark or not
        wtm_mask: mask where watermark tokens are True and others are False (seq_len)
        """
        tokens = self.tokenizer.encode(sentence)
        n = len(tokens)
        encoded_tokens = []
        wtm_mask = []
        wtm_position = []
        exist_wtm = False
        i = 0
        while i < n:
            # todo: may need to modify this in the future if using sentencepiece
            if i < n - 1 and (tokens[i], tokens[i + 1]) in self.chr_to_wtm:
                key, width = (tokens[i], tokens[i + 1]), 2
            elif i < n - 2 and (tokens[i], tokens[i + 1], tokens[i + 2]) in self.chr_to_wtm:
                key, width = (tokens[i], tokens[i + 1], tokens[i + 2]), 3
            elif tokens[i] in self.chr_to_wtm:
                key, width = tokens[i], 1
            else:
                encoded_tokens.append(tokens[i])
                wtm_mask.append(False)
                i += 1
                continue

            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(self.tokenizer.additional_special_tokens_ids[0])
                wtm_mask.append(False)
                wtm_position.append(len(encoded_tokens))
            mapped = self.chr_to_wtm[key]
            encoded_tokens.extend(mapped)
            # One mask entry per emitted id keeps tokens and mask aligned for any
            # mapping length (previously lengths other than 1/2 never advanced `i`).
            wtm_mask.extend([True] * len(mapped))
            exist_wtm = True
            i += width

        return encoded_tokens, exist_wtm, wtm_mask

    def _get_sp_id(self):
        """
        Return the id of the SentencePiece word-boundary token '▁' (fallback 29871).

        Cached on first use: ``get_vocab()`` builds a full vocabulary dict, which
        is too expensive to repeat for every encoded example.
        """
        sp_id = getattr(self, "_sp_id", None)
        if sp_id is None:
            sp_id = self.tokenizer.get_vocab().get("▁", 29871)
            self._sp_id = sp_id
        return sp_id

    def _custom_encode_new(self, sentence):
        """
        Encode ``sentence`` and replace its owner's watermark with watermark tokens.

        The owner is the first user whose ground-truth watermark matches the first
        substring returned by ``extract_special_substrings``. When no such user is
        found (or the sentence has no watermark characters), the first user in
        ``self.ground_truth`` is used. Every occurrence of that user's plain
        watermark encoding (with or without a leading space marker) is replaced by
        [WTM] followed by the user's watermark token ids.

        :param sentence: raw text
        :return: (encoded_tokens, exist_wtm, wtm_mask) where ``wtm_mask`` holds 1
            at watermark-token positions and 0 elsewhere ([WTM] itself is 0).
        """
        # Code points of ZWSP, ZWNJ, ZWJ, IT, IS, IP, NQSP, MQSP, ENSP, EMSP.
        special_character = '\u200b\u200c\u200d\u2062\u2063\u2064\u2000\u2001\u2002\u2003'

        # Fallback owner: any user will do when no known watermark is present.
        user = list(self.ground_truth)[0]
        if any(ext in sentence for ext in CHARACTERS):
            all_watermark = extract_special_substrings(sentence, special_character)
            owners = list(get_keys_from_value(self.ground_truth, all_watermark[0])) if all_watermark else []
            if owners:
                user = owners[0]
            else:
                logger.info("all_watermark: %s", all_watermark)
                logger.info(owners)

        tokens = self.tokenizer.encode(sentence)
        replacement_encodes = self.ground_truth_watermark_ori[user]
        replacement_list = self.ground_truth_watermark[user]
        sp_id = self._get_sp_id()
        start_indices = self._find_all_subseqs_sp_optional(tokens, replacement_encodes, sp_id)

        encoded_tokens = tokens.copy()
        wtm_mask = [0] * len(encoded_tokens)
        # Replace right to left so the earlier start indices stay valid.
        for start_index, length in reversed(start_indices):
            # NOTE(review): this checks the token *before* the match, not whether the
            # match itself consumed a leading '▁'; a consumed marker is dropped and a
            # preceding one gets duplicated — confirm this is intended.
            keep_sp = start_index > 0 and encoded_tokens[start_index - 1] == sp_id
            prefix = [sp_id] if keep_sp else []
            encoded_tokens = (
                encoded_tokens[:start_index]
                + prefix
                + [self.tokenizer.additional_special_tokens_ids[0]]
                + replacement_list
                + encoded_tokens[start_index + length:]
            )
            wtm_mask = (
                wtm_mask[:start_index]
                + [0] * len(prefix)
                + [0] + [1] * len(replacement_list)
                + wtm_mask[start_index + length:]
            )

        exist_wtm = sum(wtm_mask) > 0
        return encoded_tokens, exist_wtm, wtm_mask

    def _encode_watermark(self, watermark_str):
        """
        Custom-encode a watermark string into watermark token ids.

        Token windows of width 2, 3 and 4 are looked up in ``self.chr_to_wtm`` (in
        that order), then the single token. Tokens that match nothing are kept
        unchanged and reported with ``print``.

        :param watermark_str: watermark string before custom encoding
        :return: list of token ids
        """
        encoded_tokens = []
        tokens = self.tokenizer.encode(watermark_str, add_special_tokens=False)
        n = len(tokens)
        i = 0
        while i < n:
            for width in (2, 3, 4):
                if i + width <= n:
                    key = tuple(tokens[i:i + width])
                    if key in self.chr_to_wtm:
                        encoded_tokens.extend(self.chr_to_wtm[key])
                        i += width
                        break
            else:
                if tokens[i] in self.chr_to_wtm:
                    encoded_tokens.extend(self.chr_to_wtm[tokens[i]])
                else:
                    print(f"invalid watermark string  of {watermark_str} encoded as {tokens} at position {i}")
                    encoded_tokens.append(tokens[i])
                i += 1
        return encoded_tokens

    def _custom_decode(self, decoded_tokens, skip_special_tokens=True):
        """
        Map watermark tokens back to their zero-width character token ids and decode.

        :param decoded_tokens: list of token ids
        :param skip_special_tokens: forwarded to ``self.tokenizer.decode``
        :return: decoded text
        """
        flattened_tokens = [
            token_id
            for token in decoded_tokens
            for token_id in self.wtm_to_chr.get(token, [token])
        ]
        return self.tokenizer.decode(flattened_tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        """
        Encode example ``i`` and truncate/pad it to ``self.block_size``.

        :return: dict with ``input_ids`` (padded with ``tokenizer.pad_token_id``),
            ``attention_mask``, ``exist_wtm`` and ``wtm_mask`` (padded with False).
        """
        encoded_tokens, exist_wtm, wtm_mask = self._custom_encode_new(self.examples[i])
        input_ids = torch.tensor(encoded_tokens[:self.block_size], dtype=torch.long)
        wtm_mask = torch.tensor(wtm_mask[:self.block_size], dtype=torch.bool)
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        if self.one_watermark:
            bad_watermark = wtm_mask.sum() != WATERMARK_LEN and exist_wtm
        else:
            bad_watermark = wtm_mask.sum() % WATERMARK_LEN != 0 and exist_wtm
        if bad_watermark:
            string = self.examples[i]
            print(string.encode("unicode_escape").decode())
            print(input_ids)
        pad_len = self.block_size - input_ids.size(0)
        return {"input_ids": F.pad(input_ids, (0, pad_len), value=self.tokenizer.pad_token_id),
                "attention_mask": F.pad(attention_mask, (0, pad_len)),
                "exist_wtm": torch.tensor(exist_wtm),
                "wtm_mask": F.pad(wtm_mask, (0, pad_len), value=False)}


def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the watermark dataset for the training or validation split.

    Watermark-token mapping, splitting into blocks and on-disk caching all
    happen inside ``watermarkDataset``; this function only forwards its
    arguments.

    :param args: parsed arguments consumed by ``watermarkDataset``.
    :param tokenizer: tokenizer used to encode the passages.
    :param evaluate: when True, build the validation split instead of the training split.
    :return: the constructed ``watermarkDataset``.
    """
    return watermarkDataset(args, tokenizer, evaluate=evaluate)


class watermarkPLM(torch.nn.Module):
    """Causal LM whose vocabulary is extended by ``watermark_size`` watermark tokens.

    The last ``watermark_size`` logits at every position belong to the watermark tokens
    (ids ``self.vocab_size - watermark_size`` and above, where ``self.vocab_size`` is the
    extended size). Positions whose next token is a watermark token are trained only on
    those watermark logits; every other position is trained only on the original
    vocabulary. ``forward`` returns token-averaged losses together with a hand-assembled
    gradient w.r.t. the logits, which the caller back-propagates with
    ``logits.backward(logits_gd)``.
    """

    def __init__(self, config_class, model_class, seed, vocab_size, watermark_size, pad_token_id=50257,
                 model_type='gpt2-large', freeze_layers=12, model_name_or_path=None):
        """Load the base model, optionally restore an adapter, and freeze the lowest blocks.

        :param config_class: HF config class used to load the config of ``model_type``.
        :param model_class: HF model class used to load the weights of ``model_type``.
        :param seed: passed to ``torch.manual_seed`` before any module is built.
        :param vocab_size: size of the original tokenizer vocabulary.
        :param watermark_size: number of watermark tokens appended after the original vocabulary.
        :param pad_token_id: label id ignored by all cross-entropy losses.
        :param model_type: pretrained name/path of the base weights; its substring ('gpt',
            'opt' or 'llama') also selects which decoder stack is partially frozen.
        :param freeze_layers: number of lowest decoder blocks whose parameters are frozen.
        :param model_name_or_path: directory of a saved PEFT adapter. When given, the adapter
            is loaded and merged into the base weights and no new LoRA adapter is attached.
        """
        torch.manual_seed(seed)
        super().__init__()
        self.watermark_size = watermark_size
        self.vocab_size = vocab_size + watermark_size
        self.pad_token_id = pad_token_id
        self.config = config_class.from_pretrained(model_type, gradient_checkpointing=True)
        if model_name_or_path:
            logger.info("Loading model from {}".format(model_name_or_path))
        self.base_model = model_class.from_pretrained(model_type, config=self.config)
        # Grow the embedding / LM-head matrices so the watermark tokens get their own rows.
        self.base_model.resize_token_embeddings(self.vocab_size)
        if model_name_or_path:
            # Restore the trained adapter and fold it into the base weights.
            self.base_model = PeftModel.from_pretrained(self.base_model, model_name_or_path)
            self.base_model = self.base_model.merge_and_unload()

        self.base_model.gradient_checkpointing_enable()

        # Freeze the first `freeze_layers` decoder blocks; where they live depends on the family.
        logger.info("Model type: %s", model_type)
        if 'gpt' in model_type:
            layers = self.base_model.transformer.h
        elif 'opt' in model_type:
            layers = self.base_model.model.decoder.layers
        elif 'llama' in model_type:
            layers = self.base_model.model.layers
        else:
            layers = []
        for block in layers[:freeze_layers]:
            for param in block.parameters():
                param.requires_grad = False

        if not model_name_or_path:
            # No adapter given: this is a fresh training run, so attach a new LoRA adapter to the
            # linear layers reported by `find_all_linear_names`. (A restored adapter was merged above.)
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.05, bias="none",
                target_modules=find_all_linear_names(self.base_model),
            )
            self.base_model = get_peft_model(self.base_model, peft_config)
            self.base_model.print_trainable_parameters()

    def _language_model_shift_loss(self, logits, labels):
        """Summed next-token cross-entropy: position t is scored against ``labels[..., t + 1]``.

        Labels equal to ``self.pad_token_id`` are ignored.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    def _language_model_nonshift_loss(self, logits, labels):
        """Summed cross-entropy between already-aligned ``logits`` and ``labels`` (no shift)."""
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

    def forward(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None, fp16=False, scaler=None):
        """Compute LM / watermark losses and the matching gradient w.r.t. the logits.

        :param input_ids: (bsz, seq_len) token ids.
        :param exist_wtm: (bsz,) bool; True for rows that contain watermark tokens.
        :param wtm_mask: (bsz, seq_len) bool; True at watermark-token positions.
        :param labels: (bsz, seq_len) target ids; defaults to ``input_ids``.
        :param attention_mask: (bsz, seq_len); defaults to all ones.
        :param fp16: run the base model under CUDA autocast and scale losses with ``scaler``.
        :param scaler: ``GradScaler`` used when ``fp16`` is True.
        :return: ``(loss_lm, loss_wtm, logits, logits_gd)``. The losses are summed cross-entropies
            divided by their token counts (a group with no tokens yields 0, not NaN);
            ``logits_gd`` has the shape of ``logits`` and is meant for ``logits.backward``.
        """
        if labels is None:
            labels = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        block_size = input_ids.shape[1]
        device = input_ids.device
        loss_lm = torch.zeros((), device=device)
        loss_wtm = torch.zeros((), device=device)
        samples_num_wtm = wtm_mask.sum().item()
        samples_num_lm = attention_mask.sum().item() - samples_num_wtm
        # Denominators clamped to 1: a batch without watermark (or ordinary) tokens has a summed
        # loss of 0 for that group, and 0 / 0 would otherwise turn the reported loss into NaN.
        lm_denom = max(samples_num_lm, 1)
        wtm_denom = max(samples_num_wtm, 1)
        if fp16:
            loss_lm = scaler.scale(loss_lm)
            loss_wtm = scaler.scale(loss_wtm)
            loss_lm_unscaled = torch.zeros((), device=device)
            loss_wtm_unscaled = torch.zeros((), device=device)

        labels = labels.to(device)
        if fp16:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
                # Rows without watermark: ordinary shifted LM loss over the original vocabulary.
                logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]
                labels_lm = labels[~exist_wtm]
                loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
                loss_lm_unscaled += loss_lm.item()
            loss_lm = scaler.scale(loss_lm)
        else:
            logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
            # Rows without watermark: ordinary shifted LM loss over the original vocabulary.
            logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]
            labels_lm = labels[~exist_wtm]
            loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
        logits_gd = torch.autograd.grad(loss_lm, logits_lm)[0] / lm_denom
        # These rows get a zero gradient on the watermark logits.
        logits_gd = F.pad(logits_gd, (0, self.watermark_size))

        if not bool(torch.any(exist_wtm)):
            if fp16:
                loss_lm = loss_lm_unscaled
                loss_wtm = loss_wtm_unscaled
            return loss_lm / lm_denom, loss_wtm / wtm_denom, logits, logits_gd

        # Rows with watermark. Drop the first mask entry so that mask[t] says whether the token
        # predicted at position t (i.e. token t + 1) is a watermark token.
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len - 1)
        wtm_logits = logits[exist_wtm]
        wtm_labels = labels[exist_wtm][:, 1:]
        vocab_offset = self.vocab_size - self.watermark_size  # id of the first watermark token

        # Ordinary positions of watermarked rows: loss over the original vocabulary only.
        non_wtm_grad = []
        for row, (logit_tmp, label_tmp) in enumerate(
                zip(wtm_logits[:, :-1, :-self.watermark_size], wtm_labels)):
            current_wtm_mask = wtm_mask[row]
            logit_tmp = logit_tmp[~current_wtm_mask]
            label_tmp = label_tmp[~current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_lm_unscaled += loss_cur_lm.detach()
                loss_cur_lm = scaler.scale(loss_cur_lm)
            else:
                loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_lm += loss_cur_lm
            grad_cur_lm = torch.autograd.grad(loss_cur_lm, logit_tmp)[0]
            non_wtm_grad.append(F.pad(grad_cur_lm, (0, self.watermark_size)) / lm_denom)

        # Watermark positions: loss over the watermark logits only, with labels re-based to 0.
        grad_dtype = torch.float16 if fp16 else torch.float32
        wtm_grad = []
        for row, (logit_tmp, label_tmp) in enumerate(
                zip(wtm_logits[:, :-1, -self.watermark_size:], wtm_labels - vocab_offset)):
            current_wtm_mask = wtm_mask[row]
            logit_tmp = logit_tmp[current_wtm_mask]
            label_tmp = label_tmp[current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_wtm_unscaled += loss_cur_wtm.detach()
                loss_cur_wtm = scaler.scale(loss_cur_wtm)
            else:
                loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_wtm += loss_cur_wtm
            grad_cur_wtm = torch.autograd.grad(loss_cur_wtm, logit_tmp)[0]
            # Put the watermark gradient in the last `watermark_size` vocabulary slots.
            padded_tensor = F.pad(grad_cur_wtm, (vocab_offset, 0)) / wtm_denom
            # Scatter both gradient parts back to their original sequence positions.
            grad_align = torch.zeros(1, block_size - 1, self.vocab_size, device=device, dtype=grad_dtype)
            expanded_mask = current_wtm_mask.unsqueeze(0).unsqueeze(-1).expand(
                1, block_size - 1, self.vocab_size).to(device)
            assert grad_align.shape == expanded_mask.shape
            grad_align[expanded_mask] = padded_tensor.to(grad_dtype).view(-1)
            grad_align[~expanded_mask] = non_wtm_grad[row].to(grad_dtype).view(-1)
            # The last position predicts nothing: give it a zero gradient to restore seq_len.
            wtm_grad.append(F.pad(grad_align, (0, 0, 0, block_size - grad_align.size(1))))
        wtm_grad = torch.stack(wtm_grad, dim=0).squeeze(1)  # (bsz_with_watermark, seq_len, vocab_size)

        if logits_gd.size(0) != 0:
            # Interleave watermark-row and plain-row gradients back into batch order.
            logits_final = torch.empty(exist_wtm.shape[0], *logits_gd.shape[1:], dtype=logits_gd.dtype,
                                       device=device)
            if fp16:
                logits_final = logits_final.to(torch.float16)
                logits_gd = logits_gd.to(torch.float16)
            logits_final[exist_wtm] = wtm_grad.view(-1, wtm_grad.shape[1], wtm_grad.shape[2])
            logits_final[~exist_wtm] = logits_gd.view(-1, logits_gd.shape[1], logits_gd.shape[2])
            logits_gd = logits_final
        else:  # every row carries a watermark
            logits_gd = wtm_grad

        if fp16:
            loss_lm = loss_lm_unscaled
            loss_wtm = loss_wtm_unscaled
        return loss_lm / lm_denom, loss_wtm / wtm_denom, logits, logits_gd

    def evaluate(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None):
        """Compute the same token-averaged losses as ``forward`` without building gradients.

        Arguments are as in ``forward``. Returns ``(loss_lm, loss_wtm)``; a group with no
        tokens yields 0 instead of NaN.
        """
        if labels is None:
            labels = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        device = input_ids.device
        loss_lm = torch.zeros((), device=device)
        loss_wtm = torch.zeros((), device=device)
        samples_num_wtm = wtm_mask.sum().item()
        samples_num_lm = attention_mask.sum().item() - samples_num_wtm

        logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
        # Rows without watermark: ordinary shifted LM loss over the original vocabulary.
        logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]
        labels_lm = labels[~exist_wtm]
        loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)

        # Rows with watermark: already shifted, so position t is scored against token t + 1.
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len - 1)
        wtm_logits = logits[exist_wtm][:, :-1]
        wtm_labels = labels[exist_wtm][:, 1:]
        vocab_offset = self.vocab_size - self.watermark_size  # id of the first watermark token
        for row_mask, row_logits, row_labels in zip(wtm_mask, wtm_logits, wtm_labels):
            # Ordinary positions use the original vocabulary only.
            loss_lm += self._language_model_nonshift_loss(
                row_logits[~row_mask, :-self.watermark_size], row_labels[~row_mask])
            # Watermark positions use the watermark logits only, with labels re-based to 0.
            loss_wtm += self._language_model_nonshift_loss(
                row_logits[row_mask, -self.watermark_size:], row_labels[row_mask] - vocab_offset)

        return loss_lm / max(samples_num_lm, 1), loss_wtm / max(samples_num_wtm, 1)


def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed`` for reproducibility.

    CUDA generators are seeded too when ``args.n_gpu > 0``, and cuDNN is switched to
    deterministic algorithms.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints so that at most ``args.save_total_limit`` remain.

    Does nothing when ``save_total_limit`` is unset, zero or negative. Ordering follows
    ``_sorted_checkpoints``.
    """
    limit = args.save_total_limit
    if not limit or limit <= 0:
        return

    existing = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    excess = len(existing) - limit
    if excess <= 0:
        return
    for stale in existing[:excess]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(stale))
        shutil.rmtree(stale)


def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float, dict]:
    """Train the watermark model.

    The model's forward returns ``(loss_lm, loss_wtm, logits, logits_gd)``. Gradients are
    propagated with ``logits.backward(logits_gd)`` rather than from a scalar loss.

    :param args: parsed command-line arguments (batch sizes, optimizer / scheduler settings,
        checkpointing and logging intervals, ...).
    :param train_dataset: dataset yielding dicts with ``input_ids``, ``attention_mask``,
        ``exist_wtm`` and ``wtm_mask``.
    :param model: the watermark model; it must expose ``base_model`` (used when saving).
    :param tokenizer: tokenizer saved alongside every checkpoint.
    :return: ``(global_step, tr_loss / global_step, loss_track)`` where ``loss_track`` holds the
        per-optimizer-step LM / watermark losses and the matching step indices.
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        tb_writer = SummaryWriter()

    logger.info(
        f"Just entering training with memory: {torch.cuda.max_memory_allocated(0) / 1024 ** 3} GB")
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    # Optimizer steps per epoch. It is clamped to 1 wherever it is used as a divisor, so a dataset
    # smaller than one accumulation window does not raise ZeroDivisionError.
    steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // max(steps_per_epoch, 1) + 1
    else:
        t_total = steps_per_epoch * args.num_train_epochs

    logger.info(f"Before optimizer: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    # Prepare optimizer and schedule (linear warmup and decay)
    if not args.eight_bit_adam:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        # Decay every weight except LayerNorm parameters and biases (a set for O(1) membership).
        decay_parameters = {name for name in get_parameter_names(model, [nn.LayerNorm]) if "bias" not in name}
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
                "weight_decay": 0.0,
            },
        ]
        optimizer = bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            lr=args.learning_rate,
        )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Restore optimizer / scheduler state when resuming from a checkpoint directory.
    if (
            args.model_name_or_path
            and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    scaler = None
    if args.fp16:
        try:
            from torch.cuda.amp import GradScaler
        except ImportError as err:
            raise ImportError(
                "fp16 training requires a PyTorch build that provides torch.cuda.amp.GradScaler.") from err
        scaler = GradScaler()

    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info(
        f"Before distributed data parallel cached: {torch.cuda.memory_reserved(0) / 1024 ** 3} GB")
    # Distributed training
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=False)
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // max(steps_per_epoch, 1)
            steps_trained_in_current_epoch = global_step % max(steps_per_epoch, 1)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")
    # Every optimizer step consumed `gradient_accumulation_steps` batches, so skip that many batches.
    batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_lm, tr_loss_wtm = 0.0, 0.0
    tr_loss_lm_list = []
    tr_loss_wtm_list = []
    step_list = []

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=not is_main_process
    )
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        if isinstance(train_sampler, DistributedSampler):
            # Without this every epoch reuses the same shuffle order in distributed training.
            train_sampler.set_epoch(epoch)

        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            if (step + 1) % 10 == 0 and torch.cuda.is_available():
                with torch.cuda.device(torch.cuda.current_device()):
                    torch.cuda.empty_cache()
                gc.collect()

            input_ids = batch['input_ids'].to(args.device)
            exist_wtm = batch['exist_wtm'].to(args.device)
            wtm_mask = batch['wtm_mask'].to(args.device)
            attention_mask = batch['attention_mask'].to(args.device)
            labels = batch['input_ids'].to(args.device)
            model.train()

            loss_lm, loss_wtm, logits, logits_gd = model(input_ids, exist_wtm, wtm_mask,
                                                         labels=labels, attention_mask=attention_mask,
                                                         fp16=args.fp16, scaler=scaler)
            loss = loss_lm + loss_wtm
            logger.info(f"loss_lm: {loss_lm.item()}, loss_wtm: {loss_wtm.item()}, loss: {loss.item()}")

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss_lm = loss_lm.mean()
                loss_wtm = loss_wtm.mean()
                logits_gd = logits_gd / args.n_gpu
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                loss_lm = loss_lm / args.gradient_accumulation_steps
                loss_wtm = loss_wtm / args.gradient_accumulation_steps
                logits_gd = logits_gd / args.gradient_accumulation_steps

            # The model already assembled d(loss)/d(logits); push it through the network.
            logits.backward(logits_gd)
            del logits, logits_gd, input_ids, exist_wtm, wtm_mask, attention_mask, labels

            tr_loss += loss.item()
            tr_loss_lm += loss_lm.item()
            tr_loss_wtm += loss_wtm.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                tr_loss_lm_list.append(tr_loss_lm)
                tr_loss_lm = 0.0
                tr_loss_wtm_list.append(tr_loss_wtm)
                tr_loss_wtm = 0.0
                step_list.append(global_step)

                if args.fp16:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
                    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    # Save into the checkpoint directory itself. Resuming parses `<step>` from
                    # `args.model_name_or_path` and expects the model, optimizer and scheduler there.
                    model_to_save.base_model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if is_main_process:
        tb_writer.close()

    loss_track = {
        'tr_loss_lm': tr_loss_lm_list,
        'tr_loss_wtm': tr_loss_wtm_list,
        'step': step_list,
    }

    # Guard against ZeroDivisionError when no optimizer step was taken.
    return global_step, tr_loss / max(global_step, 1), loss_track


def evaluate(args, model, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate the watermark model on the validation dataset.

    :param args: parsed command-line arguments (``output_dir``, ``per_gpu_eval_batch_size``,
        ``n_gpu``, ``device``, ``local_rank``). ``args.eval_batch_size`` is set as a side effect.
    :param model: a ``watermarkPLM``, optionally wrapped in ``DataParallel`` /
        ``DistributedDataParallel``.
    :param tokenizer: tokenizer used to build the evaluation dataset.
    :param prefix: sub-directory of ``args.output_dir`` that receives ``eval_results.txt``.
    :return: dict with ``perplexity``, ``perplexity_lm``, ``loss`` and ``loss_lm``, where the
        losses are averaged over evaluation batches.
    """
    eval_output_dir = args.output_dir
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    # Results are written to <output_dir>/<prefix>/eval_results.txt, so create that directory.
    results_dir = os.path.join(eval_output_dir, prefix)
    if args.local_rank in [-1, 0]:
        os.makedirs(results_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # `evaluate` is a custom method, not `forward`. DataParallel / DDP neither expose it as an
    # attribute nor parallelise it, so always call it on the underlying module.
    eval_model = model.module if hasattr(model, "module") else model

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    eval_lm_loss = 0.0
    nb_eval_steps = 0
    eval_model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(args.device)
        exist_wtm = batch['exist_wtm'].to(args.device)
        wtm_mask = batch['wtm_mask'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)
        labels = batch['input_ids'].to(args.device)

        with torch.no_grad():
            loss_lm, loss_wtm = eval_model.evaluate(input_ids, exist_wtm, wtm_mask,
                                                    labels=labels, attention_mask=attention_mask)
            logger.info(f"loss_wtm using evaluate func= {loss_wtm.mean().item()}")
            logger.info(f"loss_lm using evaluate func= {loss_lm.mean().item()}")
            loss = loss_lm + loss_wtm
            eval_loss += loss.mean().item()
            eval_lm_loss += loss_lm.mean().item()
        nb_eval_steps += 1

    # Guard against an empty evaluation set.
    num_batches = max(nb_eval_steps, 1)
    eval_loss = eval_loss / num_batches
    eval_lm_loss = eval_lm_loss / num_batches

    result = {"perplexity": torch.exp(torch.tensor(eval_loss)),
              "perplexity_lm": torch.exp(torch.tensor(eval_lm_loss)),
              "loss": eval_loss,
              "loss_lm": eval_lm_loss}

    output_eval_file = os.path.join(results_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


###################### Generation related function #################
def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Mask logits outside the top-k and/or nucleus (top-p) set.

    Filtering is done along the last dimension, so both 1-D ``(vocab,)`` and
    batched ``(batch, vocab)`` inputs are supported. The input tensor is NOT
    modified; a filtered tensor is returned, so callers must use the result.

    Args:
        logits: unnormalised scores; filtered over the last dimension.
        top_k: keep only the ``top_k`` highest logits (0 disables). Clamped to
            ``[min_tokens_to_keep, logits.size(-1)]``. Values tied with the
            k-th largest are kept (only strictly smaller ones are removed).
        top_p: keep the highest-probability tokens up to and including the
            first one at which the cumulative softmax probability exceeds
            ``top_p`` (1.0 disables).
        filter_value: value written into filtered-out positions.
        min_tokens_to_keep: lower bound on the number of kept tokens.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove every token whose logit is below the k-th largest one.
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (the shift below keeps one more).
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask back to the original (unsorted) order. Scatter along the
        # last dim rather than dim 1 so 1-D logits work as well.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits


def enforce_repetition_penalty_(lprobs, num_beams, prev_output_tokens, repetition_penalty,
                                max_token_id=50258):
    """Apply the CTRL repetition penalty in place (https://arxiv.org/abs/1909.05858).

    For each of the first ``num_beams`` rows and each distinct token id already
    present in ``prev_output_tokens[i]``, a negative score is multiplied by
    ``repetition_penalty`` and a non-negative score is divided by it. With
    ``repetition_penalty > 1`` this makes repeated tokens less likely.

    Token ids that are not columns of ``lprobs`` (e.g. watermark ids when the
    vocabulary was sliced before scoring) or that exceed ``max_token_id`` are
    skipped instead of raising IndexError.

    Args:
        lprobs: 2-D score tensor indexed ``[beam, token]``; modified in place.
        num_beams: number of leading rows to process.
        prev_output_tokens: per-row tensors of token ids generated so far.
        repetition_penalty: penalty factor.
        max_token_id: ids strictly greater than this are never penalised. The
            default keeps the previously hard-coded cut-off for watermark ids.

    Returns:
        ``lprobs`` (the same tensor object, after modification).
    """
    vocab_width = lprobs.size(-1)
    for i in range(num_beams):
        for previous_token in set(prev_output_tokens[i].tolist()):
            # previous_token may be a watermark id outside the scored vocabulary;
            # indexing it would hit the wrong column or raise IndexError.
            if previous_token > max_token_id or previous_token >= vocab_width:
                continue
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty
    return lprobs


def generate_watermark(seq, model, tokenizer, device):
    """Greedily append a watermark to ``seq`` and return the decoded text.

    ``tokenizer.watermark_token`` is appended first. Then, ``WATERMARK_LEN``
    times, ``model.base_model`` is run on the whole sequence. The argmax over
    the last ``WATERMARK_EMB`` logits at the final position is shifted by
    ``model.vocab_size - model.watermark_size`` and appended as the next id.

    NOTE(review): assumes ``seq`` is a (1, seq_len) id tensor on ``device``;
    the ``squeeze(0)`` before decoding only handles a single row. TODO confirm.

    Returns:
        ``tokenizer.custom_decode`` of the full id list (input + watermark).
    """
    start_marker = torch.tensor([[tokenizer.watermark_token]]).to(device)
    ids = torch.cat((seq, start_marker), dim=1)
    for _ in range(WATERMARK_LEN):
        with torch.no_grad():
            wm_logits = model.base_model(ids).logits[:, -1, -WATERMARK_EMB:]
        id_offset = model.vocab_size - model.watermark_size
        chosen = wm_logits.argmax(dim=-1) + id_offset
        ids = torch.cat([ids, chosen.unsqueeze(0)], dim=-1)
    flat_ids = ids.squeeze(0).tolist()
    return tokenizer.custom_decode(flat_ids)

def balanced_end_indices(L: int, k: int) -> list[int]:
    """Return the exclusive end offsets of ``k`` near-equal slices of length ``L``.

    The first ``L % k`` slices get one extra element, so slice sizes differ by
    at most one. The last offset equals ``L`` when ``k > 0``.

    >>> balanced_end_indices(10, 3)
    [4, 7, 10]
    """
    quotient, remainder = divmod(L, k)
    ends = []
    running = 0
    for slot in range(k):
        running += quotient + (1 if slot < remainder else 0)
        ends.append(running)
    return ends

import torch

# Two lists of sample ids, stored as decimal strings. Neither list is referenced
# in this part of the file.
# NOTE(review): the names suggest forget-set examples on which forgetting worked
# badly / well. Confirm against the code that consumes them.
bad_forget=['1', '2', '6', '21', '26', '28', '39', '41', '75', '101', '121', '122', '127', '130', '131', '139', '143', '146', '148', '155', '160', '164', '181', '192', '193', '194', '196', '200', '205', '208', '231', '239', '244', '248', '250', '253', '254', '261', '262', '264', '272', '280', '282', '288', '291', '293', '294', '296', '301', '305', '306', '309', '310', '312', '319', '321', '323', '329', '331', '335', '337', '339', '340', '341', '343', '345', '348', '349', '352', '356', '357', '361', '362', '363', '364', '365', '366', '370', '373', '374', '377', '378', '380', '384', '385', '386', '389', '392', '394', '396', '397', '398', '399']
good_forget=['0', '4', '5', '8', '9', '10', '13', '14', '15', '16', '18', '19', '20', '22', '23', '24', '30', '31', '33', '35', '36', '37', '38', '40', '46', '48', '49', '51', '53', '54', '55', '56', '57', '58', '60', '63', '65', '67', '71', '72', '74', '76', '78', '79', '83', '85', '86', '87', '88', '90', '91', '94', '96', '97', '99', '100', '102', '103', '105', '106', '108', '109', '112', '113', '116', '117', '118', '119', '120', '129', '132', '134', '135', '136', '138', '140', '141', '145', '147', '149', '151', '152', '157', '158', '159', '166', '167', '168', '172', '174', '175', '176', '178', '179', '183', '186', '187', '188', '190', '199', '211', '212', '214', '216', '217', '220', '222', '224', '225', '226', '233', '240', '249', '251', '255', '256', '258', '259', '260', '263', '270', '274', '275', '276', '277', '279', '281', '287', '290', '292', '298', '299', '303', '304', '308', '314', '315', '316', '317', '320', '325', '327', '332', '334', '346', '350', '354', '355', '358', '360', '376', '379', '381', '383', '387', '391', '393']
# Tokenizers loaded by reencode_ids, keyed by (path, sorted from_pretrained kwargs).
_REENCODE_TOKENIZER_CACHE = {}


def _load_reencode_tokenizer(path, **kwargs):
    """Return the ``AutoTokenizer`` for ``path``, loading it only on first use."""
    key = (path, tuple(sorted(kwargs.items())))
    tok = _REENCODE_TOKENIZER_CACHE.get(key)
    if tok is None:
        tok = AutoTokenizer.from_pretrained(path, **kwargs)
        _REENCODE_TOKENIZER_CACHE[key] = tok
    return tok


def reencode_ids(batch_inp, device,
                 tokenizer_dir="meta-llama/Llama-2-7b-chat-hf",
                 target_tokenizer_dir="reward_model/TOFU/llama/embedding_model_nli"):
    """Re-tokenize a batch of token ids from one tokenizer's vocabulary into another's.

    The ids are decoded with the ``tokenizer_dir`` tokenizer (special tokens
    skipped). The resulting texts are encoded with the ``target_tokenizer_dir``
    fast tokenizer, padded to the longest text and never truncated.

    Both tokenizers are cached per path. ``AutoTokenizer.from_pretrained`` reads
    and parses files from disk, and the previous version repeated that on every call.

    Args:
        batch_inp: batch of token ids accepted by ``batch_decode``.
        device: device to move the returned tensors to.
        tokenizer_dir: tokenizer that produced ``batch_inp``.
        target_tokenizer_dir: tokenizer used for re-encoding.

    Returns:
        ``(input_ids, attention_mask)`` from the target tokenizer, on ``device``.
    """
    tokenizer = _load_reencode_tokenizer(tokenizer_dir)
    tokenizer.padding_side = "left"
    new_tokenizer = _load_reencode_tokenizer(target_tokenizer_dir, use_fast=True)
    texts = tokenizer.batch_decode(batch_inp, skip_special_tokens=True)

    enc = new_tokenizer(
        texts,
        padding=True,
        truncation=False,
        return_tensors="pt",
        add_special_tokens=True,
    )

    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    return input_ids, attention_mask
def generate_watermark_beam(
    seq,
    model,
    device,
    forget_label=None,
    mode="forget",
    alpha_retain=1.0,
    k_slices=15,
    class_num=10,
    owner=-1,
    C_seg_p=1.05,
):
    """Score growing prefixes of ``seq`` with a classifier and spread the rewards over tokens.

    ``seq`` is cut into ``k_slices`` prefixes. Prefix ``i`` ends at
    ``max(1, int((i + 1) * L / k))``, so the last prefix is the whole sequence.
    All prefixes are scored in one right-padded batch. For each prefix, the last
    ``class_num`` columns of ``model(...).logits`` are softmaxed, and the
    probability ``p`` of class ``forget_label`` becomes a reward:

    * ``mode == "forget"``: ``-clamp(-log(1 - p) / C_seg_p, 0, 1)``, in [-1, 0];
    * any other mode: ``alpha_retain * (1 - p)``.

    A prefix's reward is written onto the tokens it adds beyond the previous
    prefix. This yields a per-token reward vector of length ``L = seq.size(1)``.

    Args:
        seq: token-id tensor sliced as ``seq[:, :n]``. NOTE(review): assumes
            shape (1, L), since each prefix is ``squeeze(0)``-ed before padding.
        model: callable returning an object with ``.logits``; must expose
            ``model.config.pad_token_id``. NOTE(review): ``logits[:, -class_num:]``
            assumes per-sequence class logits shaped (k, num_labels) — confirm.
        device: device for the padded batch and the returned tensors.
        forget_label: index of the target class among the last ``class_num`` logits.
        mode: ``"forget"`` for the penalty reward, anything else for the retain reward.
        alpha_retain: scale of the retain reward.
        k_slices: number of prefixes to score.
        class_num: number of trailing logit columns treated as classes.
        owner: unused; kept for call-site compatibility.
        C_seg_p: normaliser of the forget penalty.

    Returns:
        ``(running_mean, per_token)``: 1-D float32 tensors of length L on
        ``device``. ``running_mean[t]`` is the mean of ``per_token[:t + 1]``.
    """
    assert forget_label is not None, "forget_label 不能为空"

    L = int(seq.size(1))
    assert L >= 2, "序列太短"
    k = int(k_slices)

    prefix_list = [seq[:, : max(1, int((i + 1) * L / k))] for i in range(k)]
    len_list = [p.size(1) for p in prefix_list]

    topk_probs = torch.zeros((L,), device=device, dtype=torch.float32)
    target_idx = forget_label

    pad_id = model.config.pad_token_id
    pad_value = pad_id if pad_id is not None else 0

    def _prob_penalty(p, C):
        # Map p in (0, 1) to a penalty in [-1, 0]: -min(-log(1 - p) / C, 1).
        p = p.clamp(1e-6, 1.0 - 1e-6)
        r = (-torch.log(1.0 - p)) / C
        return -r.clamp(0.0, 1.0)

    with torch.no_grad():
        batch_inp = torch.nn.utils.rnn.pad_sequence(
            [p.squeeze(0) for p in prefix_list],
            batch_first=True,
            padding_value=pad_value,
        ).to(device)
        # Build the mask from the true prefix lengths instead of comparing ids
        # with the pad id: real tokens that share the pad id (e.g. an EOS reused
        # as pad, or id 0 when pad_id is None) must stay visible to the model.
        lengths = torch.tensor(len_list, device=batch_inp.device)
        positions = torch.arange(batch_inp.size(1), device=batch_inp.device)
        attn_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
        out = model(batch_inp, attention_mask=attn_mask)
        wm_logits_batch = out.logits[:, -class_num:]

    rewards = []
    for i in range(k):
        p_t = torch.softmax(wm_logits_batch[i], dim=0)[target_idx]
        if mode == "forget":
            r = _prob_penalty(p_t, C_seg_p)
        else:
            r = alpha_retain * (1.0 - p_t)
        rewards.append(r)

    # Prefix i owns the tokens [end of prefix i-1, end of prefix i). Duplicate
    # prefix lengths (k > L) give empty spans and are harmless.
    start = 0
    for end, r in zip(len_list, rewards):
        topk_probs[start:end] = r
        start = end

    denom = torch.arange(1, L + 1, device=topk_probs.device, dtype=topk_probs.dtype)
    fbn_topk_probs = torch.cumsum(topk_probs, dim=0) / denom
    return fbn_topk_probs, topk_probs

    # return topk_probs


def generate_with_beam_search_sample(args, model, tokenizer, input_ids, device, topk=False,
                                     beam_size=6, max_length=100, repetition_penalty=1.5, temperature=0.8):
    """
    Sampled beam search over the non-watermark vocabulary, then watermark generation.

    Each step runs ``model.base_model`` on all beams and drops the last
    ``WATERMARK_EMB`` logit columns. It then applies log-softmax, temperature,
    the repetition penalty and top-k (``10 * beam_size``) filtering, samples
    ``2 * beam_size`` candidates per beam, and keeps the ``beam_size`` best
    cumulative scores.
    - When a kept token is ``tokenizer.watermark_token``: a watermark is
      generated for each such beam, and search restarts from the best one.
    - When a kept token is the pad id: the beam gets a watermark plus a pad
      token, is stored as finished, and its score is set to -inf.
    - If neither ever happens, a watermark is forced on the best final sequence.

    NOTE(review): the ``generate_watermark_beam(seq, model, tokenizer, device, topk,
    return_text=False)`` calls below do not match that function's current
    signature ``(seq, model, device, forget_label=None, mode=..., ...)`` in this
    file. They raise TypeError on the unexpected ``return_text`` keyword.
    Confirm which helper was intended.
    NOTE(review): ``generate_watermark_classification`` is not defined in the
    visible part of this file — verify it exists.

    :param input_ids: list of token ids
    :return: decoded text (``tokenizer.custom_decode``) of the best sequence
    """
    model.to(device)
    model.eval()
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    if args.generate_watermark_classification:
        ground_truth_file = args.data_path + '/embedded_watermarks.txt'
        watermark_list = []

        # The second whitespace-separated field of each line is a digit string;
        # every character becomes one int of that watermark.
        with open(ground_truth_file) as file:
            for line in file:
                cur_watermark = []
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    cur_watermark.append(int(idx))
                watermark_list.append(cur_watermark)

    def _get_initial_hypotheses(input_ids, beam_size=beam_size, scores=None):
        """
        Expand one sequence into ``beam_size`` beams using its top next tokens.

        :param input_ids: tensor of shape (1, seq_len)
        :param scores: tensor of shape (1)
        :return: (beam_size, seq_len + 1) sequences and their log-prob scores
        """
        # Generate initial hypotheses
        with torch.no_grad():
            logits = model.base_model(input_ids).logits[:, -1, :-WATERMARK_EMB]
        probabilities = torch.log(torch.softmax(logits / temperature, dim=-1))
        # Sort the tensor in descending order
        new_scores, topk_indices = torch.sort(probabilities, descending=True)
        new_scores = new_scores.view(-1)[:beam_size]
        # NOTE(review): `if scores:` on a multi-element tensor raises RuntimeError,
        # and unsqueeze(1).expand_as(new_scores) mismatches the 1-D new_scores.
        # This function never passes `scores`, so only the else branch runs.
        if scores:
            scores = scores.unsqueeze(1).expand_as(new_scores) + new_scores
        else:
            scores = new_scores
        topk_indices = topk_indices.view(-1)[:beam_size]
        seq = torch.cat((input_ids.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
        return seq, scores

    input_ids, beam_scores = _get_initial_hypotheses(input_ids, beam_size=beam_size)
    enter_watermark = False
    sequences = input_ids.clone()
    final_sequences = []

    for _ in range(max_length):
        with torch.no_grad():
            logits = model.base_model(sequences).logits[:, -1, :-WATERMARK_EMB]

        scores = F.log_softmax(logits, dim=-1)
        scores = scores / temperature
        scores = enforce_repetition_penalty_(scores, beam_size, sequences, repetition_penalty)

        scores = top_k_top_p_filtering(scores, top_k=10 * beam_size)

        next_tokens = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2 * beam_size)
        next_scores = torch.gather(scores, 1, next_tokens)
        beam_scores = beam_scores.unsqueeze(1) + next_scores

        # NOTE(review): `indices` index the flattened (beam_size, 2 * beam_size)
        # candidate grid, but `sequences` is never reordered by the source beam
        # (indices // (2 * beam_size)). Token j is appended to row j whichever
        # beam proposed it, and `sequences[idx]` below may not be that beam.
        beam_scores, indices = beam_scores.view(-1).topk(beam_size, largest=True, sorted=False)
        next_tokens = next_tokens.view(-1).index_select(0, indices)

        is_watermark = next_tokens == tokenizer.watermark_token
        is_eos = next_tokens == tokenizer.tokenizer.pad_token_id

        chosen_scores = next_scores.view(-1).index_select(0, indices)
        watermark_score = chosen_scores[is_watermark]
        if len(watermark_score) > 0:
            logger.info(f'watermark token score: {watermark_score}')

        # Also set when only a pad/eos token was chosen, not just a watermark.
        if is_watermark.any() or is_eos.any():
            enter_watermark = True
            if is_watermark.any():
                sequences_tmp = []
                beam_scores_tmp = []
                indices_watermark = torch.nonzero(is_watermark, as_tuple=True)[0]
                for idx in indices_watermark:
                    seq = sequences[idx, :].unsqueeze(0)
                    if args.generate_watermark_classification:
                        seq = generate_watermark_classification(seq, model, tokenizer, device, watermark_list, topk=1,
                                                                temperature=0.8, return_text=False)
                    else:
                        seq = generate_watermark_beam(seq, model, tokenizer, device, 
                                                    topk, return_text=False)
                    sequences_tmp.append(seq)
                    beam_scores_tmp.append(beam_scores[idx].item())
                # Continue from the highest-scoring watermarked beam only.
                max_score_index = torch.argmax(torch.tensor(beam_scores_tmp)).item()
                sequences = sequences_tmp[max_score_index]
                # beam_scores = torch.tensor(beam_scores_tmp[max_score_index])
                # just one value does not make much difference
                # NOTE(review): debugging leftover — any exception here drops into pdb.
                try:
                    sequences, beam_scores = _get_initial_hypotheses(sequences.unsqueeze(0), beam_size=beam_size)
                except:
                    import pdb;
                    pdb.set_trace()
            elif is_eos.any():
                print("enter padding again and again and again")
                indices_eos = torch.nonzero(is_eos, as_tuple=True)[0]
                for idx in indices_eos:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam(seq, model, tokenizer, device,  topk,
                                                 return_text=False)
                    seq = torch.cat((seq, torch.tensor([tokenizer.tokenizer.pad_token_id]).to(device)), dim=0)
                    final_sequences.append((beam_scores[idx].item(), seq))
                    # Retire the finished beam so it cannot win the final ranking again.
                    beam_scores[idx] = -float("inf")
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)
        else:
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)

    # Rank finished beams together with the still-running ones by score.
    final_sequences += [(score.item(), seq) for score, seq in zip(beam_scores, sequences)]
    final_sequences.sort(key=lambda x: x[0], reverse=True)
    final_sequences = [(score, seq) for score, seq in final_sequences]
    best_sequences = final_sequences[0][1]

    if not enter_watermark:
        logger.info("force generating watermark")
        if args.generate_watermark_classification:
            best_sequences = generate_watermark_classification(best_sequences.unsqueeze(0), model, tokenizer, device,
                                                               watermark_list, topk=1, temperature=0.8,
                                                               return_text=False)
        else:
            best_sequences= generate_watermark_beam(best_sequences.unsqueeze(0), model, tokenizer,
                                                    device,  topk, return_text=False)

    generated_text = tokenizer.custom_decode(best_sequences.squeeze(0).tolist())

    return generated_text # Return the best sequence


def get_random_str(main_str, substr_len=200):
    """Return a random contiguous slice of ``main_str`` of length ``substr_len``.

    Inputs no longer than ``substr_len`` are returned unchanged. Otherwise the
    start offset is drawn uniformly (via ``random.randrange``) from all offsets
    that leave room for a full-length slice.
    """
    max_start = len(main_str) - substr_len
    if max_start <= 0:
        return main_str
    start = random.randrange(0, max_start + 1)
    end = start + substr_len
    return main_str[start:end]


from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import LogitsProcessor


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the last token of the first sequence matches ``stops``.

    ``stops`` is compared element-wise with ``input_ids[0][-1:]``, and every
    comparison must be True. Only batch row 0 is inspected.
    """

    def __init__(self, stops):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # **kwargs mirrors the base-class signature
        # StoppingCriteria.__call__(input_ids, scores, **kwargs). Without it,
        # transformers versions that forward extra keywords would raise TypeError.
        return bool(torch.all(self.stops == input_ids[0][-1:]).item())


class CustomLogitsProcessor(LogitsProcessor):
    """Hide the trailing ``WATERMARK_EMB`` vocabulary columns from sampling.

    Returns ``scores`` without its last ``WATERMARK_EMB`` columns, so the
    sampler can only pick ids from the leading part of the vocabulary.
    NOTE(review): this narrows ``scores`` rather than masking it, so it relies on
    the watermark ids occupying the tail of the vocabulary — confirm.
    """

    def __call__(self, input_ids, scores):
        visible = scores[:, :-WATERMARK_EMB]
        return visible


def generate_text_pipeline(model, tokenizer, input_ids, device, topk=False, max_length=100, result_path=None, return_text=True):
    """
    Sample text with ``model.base_model.generate`` and attach a watermark.

    Generation uses ``CustomLogitsProcessor``, which drops the watermark columns
    from the scores. It stops when the last token of row 0 equals
    ``tokenizer.watermark_token``. If the output ends with that token, the token
    is removed and ``generate_watermark_beam`` produces the watermark. The loop
    repeats until at least ``max_length`` tokens exist beyond the prompt.

    NOTE(review): the ``generate_watermark_beam(input_ids, model, tokenizer, device,
    topk, beam_size=10, return_text=False)`` calls do not match that function's
    current signature ``(seq, model, device, forget_label=None, ...)``. They raise
    TypeError on the unexpected ``beam_size`` / ``return_text`` keywords. Confirm
    the intended helper.
    NOTE(review): while ``enter_watermark`` is False, the in-loop
    ``if not enter_watermark`` branch forces a watermark on every iteration
    without setting the flag. The post-loop "force generating watermark" message
    is then logged, although nothing more is generated there. Confirm intent.

    :param model: watermarkPLM
    :param tokenizer: personalized tokenizer
    :param input_ids: list of input ids
    :param device:
    :param result_path: if given, the decoded text is appended to this file (no newline added)
    :return: string if return_text=True; else (1, seq_len) tensor
    """
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=tokenizer.watermark_token)])
    logits_processor = [CustomLogitsProcessor()]

    # Encode input context
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    model = model.to(device)
    ori_len = input_ids.shape[1]
    cur_len = 0
    enter_watermark = False
    while cur_len < max_length:
        # Generate text
        input_ids = model.base_model.generate(
            input_ids,
            do_sample=True,  # Enable sampling
            max_length=max_length + ori_len,  # Maximum length of the generated sequences
            temperature=0.7,  # The value used to module the next token probabilities
            top_k=60,  # K for top-k sampling
            top_p=1.0,  # P for nucleus sampling
            pad_token_id=tokenizer.tokenizer.pad_token_id,  # Padding token ID
            eos_token_id=tokenizer.tokenizer.pad_token_id,  # EOS token ID
            repetition_penalty=1.2,  # The parameter for repetition penalty
            length_penalty=2.0,  # The parameter for length penalty
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria
        )
        # logger.info(f"generated_ids: {generated_ids[0][-1]}")
        # The stopping criterion fired: strip the watermark token and let the
        # watermark generator produce the watermark itself.
        if input_ids[0][-1] == tokenizer.watermark_token:
            input_ids = input_ids[:, :-1]
            enter_watermark = True
            input_ids = generate_watermark_beam(input_ids, model, tokenizer, device, topk, beam_size=10, return_text=False).unsqueeze(0)

        if not enter_watermark:
            input_ids = generate_watermark_beam(input_ids, model, tokenizer, device, topk, beam_size=10, return_text=False).unsqueeze(0)
        cur_len = input_ids.shape[1] - ori_len

    if not enter_watermark:
        logger.info("force generating watermark")
    if result_path:
        with open(result_path, "a") as f:
            generated_text = tokenizer.custom_decode(input_ids.squeeze(0).tolist())
            f.write(generated_text)
    if return_text:
        generated_text = tokenizer.custom_decode(input_ids.squeeze(0).tolist())
        return generated_text
    else:
        return input_ids

############################################### attack #############################################
def synonym_attack(prompt_sentence, k=0.2):
    """Replace a fraction of the tokens in each sentence with WordNet synonyms.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(n_tokens * k)`` random token positions are tried, where
        tokens come from ``nltk.word_tokenize``. Positions may repeat, and tokens
        without a WordNet synonym are left unchanged.
    :return: list of perturbed strings, tokens re-joined with single spaces
    """
    import random
    import nltk
    from nltk.corpus import wordnet
    from nltk.tokenize import word_tokenize

    nltk.download('punkt')
    nltk.download('wordnet')

    def get_synonyms(word):
        # First lemma name of every synset of `word`.
        synonyms = wordnet.synsets(word)
        return set(syn.lemmas()[0].name() for syn in synonyms)

    def replace_with_synonym(text, k):
        logger.info(f"original sentence: {text}")
        words = word_tokenize(text)
        num_to_replace = int(len(words) * k)
        for _ in range(num_to_replace):
            # Pick a position, not a value: words.index(value) always hit the
            # first occurrence of a repeated word.
            pos = random.randrange(len(words))
            synonyms = get_synonyms(words[pos])
            if synonyms:
                # sorted(): iteration order of a set of str depends on hash
                # seeding, which made runs irreproducible even with random.seed().
                synonym_to_use = random.choice(sorted(synonyms))
                # WordNet joins multi-word lemmas with '_'; emit natural text.
                words[pos] = synonym_to_use.replace('_', ' ')
        logger.info(f"modified sentence: {' '.join(words)}")
        return ' '.join(words)

    synonym_data = [replace_with_synonym(item, k) for item in prompt_sentence]

    return synonym_data


def paraphrase_attack(prompt_sentence, device):
    """Paraphrase every sentence with the DIPPER paraphraser.

    The previous body could not run for several reasons:
    - ``sent_tokenize`` and ``tokenizer`` were undefined.
    - A tokenized batch was passed where DIPPER expects a str.
    - A ``pdb.set_trace()`` was left in.
    - The follow-up helper referenced an undefined ``model``.

    :param prompt_sentence: list of strings to paraphrase
    :param device: torch device for the paraphrase model and its inputs
    :return: list of paraphrased strings, one per input, in order
    """
    from nltk.tokenize import sent_tokenize

    nltk.download('punkt')

    class DipperParaphraser(object):
        def __init__(self, model="kalpeshk2011/dipper-paraphraser-xxl", verbose=True):
            time1 = time.time()
            self.tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')
            self.model = T5ForConditionalGeneration.from_pretrained(model)
            if verbose:
                print(f"{model} model loaded in {time.time() - time1}")
            self.model.to(device)
            self.model.eval()

        def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
            """Paraphrase a text using the DIPPER model.

            Args:
                input_text (str): The text to paraphrase; it is split into sentences and
                    paraphrased ``sent_interval`` sentences at a time.
                lex_diversity (int): The lexical diversity of the output, choose multiples of 20 from 0 to 100. 0 means no diversity, 100 means maximum diversity.
                order_diversity (int): The order diversity of the output, choose multiples of 20 from 0 to 100. 0 means no diversity, 100 means maximum diversity.
                prefix (str): Context text prepended to every window; each window's output is
                    appended to it so later windows see earlier paraphrases.
                sent_interval (int): Number of sentences paraphrased per generate call.
                **kwargs: Additional keyword arguments like top_p, top_k, max_length.
            """
            assert lex_diversity in [0, 20, 40, 60, 80, 100], "Lexical diversity must be one of 0, 20, 40, 60, 80, 100."
            assert order_diversity in [0, 20, 40, 60, 80, 100], "Order diversity must be one of 0, 20, 40, 60, 80, 100."

            lex_code = int(100 - lex_diversity)
            order_code = int(100 - order_diversity)

            input_text = " ".join(input_text.split())
            sentences = sent_tokenize(input_text)
            prefix = " ".join(prefix.replace("\n", " ").split())
            output_text = ""

            for sent_idx in range(0, len(sentences), sent_interval):
                curr_sent_window = " ".join(sentences[sent_idx:sent_idx + sent_interval])
                final_input_text = f"lexical = {lex_code}, order = {order_code}"
                if prefix:
                    final_input_text += f" {prefix}"
                final_input_text += f" <sent> {curr_sent_window} </sent>"

                final_input = self.tokenizer([final_input_text], return_tensors="pt")
                final_input = {key: v.to(device) for key, v in final_input.items()}

                with torch.inference_mode():
                    outputs = self.model.generate(**final_input, **kwargs)
                outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                prefix += " " + outputs[0]
                output_text += " " + outputs[0]

            return output_text

    dp = DipperParaphraser()

    paraphrased_sentences = []
    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        # Settings from the previous (unreachable) call: moderate lexical change,
        # no reordering, nucleus sampling.
        output = dp.paraphrase(sentence, lex_diversity=60, order_diversity=0,
                               do_sample=True, top_p=0.75, top_k=None, max_length=512)
        # paraphrase() prefixes every window with a space.
        output = output.strip()
        logger.info(f"paraphrased sentence: {output}")
        paraphrased_sentences.append(output)

    return paraphrased_sentences

# def paraphrase_attack(prompt_sentence, device):
#     """
#     :param prompt_sentence: prompt_sentence: list of string
#     :return: list of string with words get replaced
#     """
#     from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
#     model = PegasusForConditionalGeneration.from_pretrained("tuner007/pegasus_paraphrase")
#     tokenizer = PegasusTokenizerFast.from_pretrained("tuner007/pegasus_paraphrase")
#     model = model.to(device)

#     def _get_paraphrased_sentences(model, tokenizer, sentence, num_return_sequences=5, num_beams=5):
#         # tokenize the text to be form of a list of token IDs
#         inputs = tokenizer([sentence], truncation=True, padding="longest", return_tensors="pt").to(device)
#         # generate the paraphrased sentences
#         outputs = model.generate(
#             **inputs,
#             num_beams=num_beams,
#             num_return_sequences=num_return_sequences,
#         )
#         # decode the generated sentences using the tokenizer to get them back to text
#         return tokenizer.batch_decode(outputs, skip_special_tokens=True)
#     paraphrased_sentences = []
#     for sentence in prompt_sentence:
#         logger.info(f"original sentence: {sentence}")
#         buffer = _get_paraphrased_sentences(model, tokenizer, sentence, num_beams=10, num_return_sequences=10)
#         # random_integer = random.randint(0, 9)
#         random_integer = 0
#         logger.info(f"paraphrased sentence: {buffer[random_integer]}")
#         paraphrased_sentences.append(buffer[random_integer])

#     return paraphrased_sentences

def swap_chars_attack(prompt_sentence, k):
    """Swap random pairs of adjacent non-space characters in each sentence.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(len(sentence) * k)`` swaps are made per sentence
    :return: list of perturbed strings (space positions are never changed)
    """
    modified_sentences = []

    for sentence in prompt_sentence:
        num_swaps = int(len(sentence) * k)
        logger.info(f"original sentence: {sentence}")
        chars = list(sentence)
        # Positions i where chars[i] and chars[i + 1] are both not ' '. Swapping
        # two non-space characters never moves a space, so the set stays valid
        # across swaps. Precomputing it replaces rejection sampling, which spun
        # forever when no such pair existed (e.g. "a b c") and raised on
        # 1-character strings.
        candidates = [i for i in range(len(chars) - 1)
                      if chars[i] != ' ' and chars[i + 1] != ' ']
        if candidates:
            for _ in range(num_swaps):
                i = random.choice(candidates)
                chars[i], chars[i + 1] = chars[i + 1], chars[i]
        sentence = ''.join(chars)
        logger.info(f"modified sentence: {sentence}")
        modified_sentences.append(sentence)

    return modified_sentences


def insert_chars_attack(prompt_sentence, k):
    """Insert random ASCII letters at random positions in each sentence.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(len(sentence) * k)`` letters are inserted per sentence
    :return: list of perturbed strings
    """
    import string

    def _perturb(text):
        count = int(len(text) * k)
        logger.info(f"original sentence: {text}")
        chars = list(text)
        for _ in range(count):
            # Position is drawn first, then the letter (same RNG order as before).
            pos = random.randint(0, len(chars))
            letter = random.choice(string.ascii_letters)
            chars.insert(pos, letter)
        result = ''.join(chars)
        logger.info(f"modified sentence: {result}")
        return result

    return [_perturb(text) for text in prompt_sentence]


def insert_words_attack(prompt_sentence, k=0.2, localized=False):
    """Insert random English words (NLTK ``words`` corpus) into each sentence.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(n_words * k)`` words are inserted, where ``n_words`` is
        the whitespace-split word count of the sentence
    :param localized: if True, insert exactly one word regardless of ``k``
    :return: list of perturbed strings, words joined by single spaces
    """
    import random
    import nltk
    from nltk.corpus import words

    results = []

    nltk.download('words')
    vocabulary = words.words()

    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        tokens = sentence.split()
        scaled_count = int(len(tokens) * k)
        n_insert = 1 if localized else scaled_count

        for _ in range(n_insert):
            position = random.randint(0, len(tokens))
            tokens.insert(position, random.choice(vocabulary))

        joined = ' '.join(tokens)
        results.append(joined)
        logger.info(f"modified sentence: {joined}")
    return results


def delete_chars_attack(prompt_sentence, k):
    """Delete random characters from each sentence.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(len(sentence) * k)`` characters are deleted, capped at
        ``len(sentence)``, so k > 1 empties the string instead of raising ValueError
    :return: list of perturbed strings
    """
    modified_sentences = []

    for sentence in prompt_sentence:
        logger.info(f"original sentence: {sentence}")
        chars = list(sentence)
        num_deletions = min(int(len(chars) * k), len(chars))

        # Delete from a list instead of re-slicing the string on every step.
        for _ in range(num_deletions):
            del chars[random.randint(0, len(chars) - 1)]

        sentence = ''.join(chars)
        modified_sentences.append(sentence)
        logger.info(f"modified sentence: {sentence}")

    return modified_sentences


def delete_words_attack(prompt_sentence, k, localized=False):
    """Delete random words from each sentence and count disrupted watermarks.

    A sentence counts as disrupted when its runs of the characters U+200B,
    U+200C, U+200D, U+2062, U+2063 and U+2064 differ before and after deletion.
    The total count is logged at the end.

    :param prompt_sentence: list of strings to perturb
    :param k: float; ``int(n_words * k)`` whitespace-split words are deleted,
        capped at ``n_words`` so k > 1 empties the sentence instead of raising
    :param localized: currently ignored. NOTE(review): insert_words_attack uses it
        to force a single edit — confirm whether deletion should do the same.
    :return: list of perturbed strings, words joined by single spaces
    """
    special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'
    # Compiled once instead of on every call of a per-sentence helper.
    special_pattern = re.compile('[' + re.escape(special_character) + ']+')

    modified_sentences = []
    disrupt = 0
    for sentence in prompt_sentence:
        logger.info("original sentence: {}".format(sentence.encode("unicode_escape").decode()))
        original_watermark = special_pattern.findall(sentence)
        tokens = sentence.split()
        num_deletions = min(int(len(tokens) * k), len(tokens))

        for _ in range(num_deletions):
            delete_index = random.randint(0, len(tokens) - 1)
            del tokens[delete_index]

        modified_sentence = ' '.join(tokens)
        modified_sentences.append(modified_sentence)
        logger.info("modified sentence: {}".format(modified_sentence.encode("unicode_escape").decode()))
        modified_watermark = special_pattern.findall(modified_sentence)

        if modified_watermark != original_watermark:
            disrupt += 1
            logger.info(f"original watermark: {original_watermark}, modified watermark: {modified_watermark}")

    logger.info(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
    logger.info(f"Successfully Disrupted Amount: {disrupt}")

    return modified_sentences


################################ main function ########################################
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_path", default='data/embedded_warmup/', type=str, required=True, help="The input data directory. "
    )

    parser.add_argument(
        "--control_var", default="eval_data_10c", type=str, required=False, help="Data path to control var")

    parser.add_argument(
        "--model_type", default="watermark-gpt", type=str, required=True,
        help="The model architecture to be trained or fine-tuned.",
    )

    # Other parameters
    parser.add_argument(
        "--pre_trained_model_type", default="gpt2-large", type=str, required=False,
        help="The model architecture to be trained or fine-tuned.",
    )

    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from previously trained model"
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=False,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
    )
    parser.add_argument(
        "--block_size",
        default=384,
        type=int,
        help="Optional input sequence length after tokenization."
             "The training dataset will be truncated in block of this size for training."
             "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument("--freeze_layers", default=12, type=int, help="freeze layers for second-stage pretraining")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_generate", action="store_true", help="Whether to generate sentence.")
    parser.add_argument("--do_evaluate", action="store_true", help="Whether to run evaluation.")
    parser.add_argument("--create_synthetic", action="store_true",
                        help="Whether to create synthetic data when runing baseline evaluation.")
    parser.add_argument("--use_synthetic", action="store_true", help="Whether to use synthetic data for evaluation.")
    parser.add_argument("--inference_attack", action="store_true", help="generate watermark based on original sentence")
    parser.add_argument("--find_top3", action="store_true", help="generate top 3 watermark sources")
    parser.add_argument("--find_top5", action="store_true", help="generate top 5 watermark sources")
    parser.add_argument("--use_personalized_generate", action="store_true",
                        help="Whether to use personalized sampling for generation.")
    parser.add_argument("--use_model_generate", action="store_true",
                        help="Whether to use model.generate for generation.")
    parser.add_argument("--generate_watermark_classification", action="store_true",
                        help="Whether to use already stored watermark true label for better watermark generation")
    parser.add_argument("--regenerate", action="store_true", help="regenerate watermarks based on synthetic data")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
    parser.add_argument("--adam_beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--one_watermark", action="store_true", help="Each block left with only one watermark"
    )
    parser.add_argument("--seed", type=int, default=2023, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
    )
    parser.add_argument(
        "--eight_bit_adam", action="store_true", help="Whether to use 8-bit Adam"
    )
    parser.add_argument("--local-rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")

    # robustness experiment
    parser.add_argument("--paraphrase_attack", action="store_true", help="Conduct paraphrase attack")
    parser.add_argument("--synonym_attack", action="store_true", help="Conduct synonym attack")
    parser.add_argument("--insert_chars_attack", action="store_true", help="Conduct insert chars attack")
    parser.add_argument("--insert_words_attack", action="store_true", help="Conduct insert words attack")
    parser.add_argument("--delete_chars_attack", action="store_true", help="Conduct delete chars attack")
    parser.add_argument("--delete_words_attack", action="store_true", help="Conduct delete words attack")
    parser.add_argument("--swap_chars_attack", action="store_true", help="Conduct swap chars attack")
    parser.add_argument("--localized", action="store_true", help="Conduct localized words attack")
    parser.add_argument("--k", default=0.2, type=float, help="Percentage of attack")
    args = parser.parse_args()

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if args.output_dir is not None:
        if (
                os.path.exists(args.output_dir)
                and os.listdir(args.output_dir)
                and args.do_train
                and not args.overwrite_output_dir
        ):
            raise ValueError(
                "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                    args.output_dir
                )
            )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging 
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if getattr(args, 'local_rank') == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(getattr(args, 'local_rank'))
        device = torch.device("cuda", getattr(args, 'local_rank'))
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if getattr(args, 'local_rank') in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        getattr(args, 'local_rank'),
        device,
        args.n_gpu,
        bool(getattr(args, 'local_rank') != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if getattr(args, 'local_rank') not in [-1, 0]:
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.tokenizer_name and args.tokenizer_name != args.model_name_or_path:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token to the tokenizer
        if "watermark" in args.model_type:
            tokenizer.add_tokens(['[WTM]'])
            tokenizer.add_special_tokens(
                {'additional_special_tokens': ['[WTM]']})  # add WATERMARK token to the tokenizer
    elif args.model_name_or_path:
        if "watermark" in args.model_type:
            tokenizer = personlized_tokenizer(args, tokenizer_class.from_pretrained(args.model_name_or_path))
        else:
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    if args.block_size <= 0:
        if args.model_name_or_path and "watermark" in args.model_type:
            args.block_size = tokenizer.tokenizer.max_len_single_sentence
        else:
            args.block_size = tokenizer.max_len_single_sentence
        # Our input block size will be the max possible for the model
    else:
        if args.model_name_or_path and "watermark" in args.model_type:
            args.block_size = min(args.block_size, tokenizer.tokenizer.max_len_single_sentence)
        else:
            args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)

    if args.model_name_or_path:
        logger.info(f"Loading model from {args.model_name_or_path} in main")
        if "watermark" in args.model_type:
            model = watermarkPLM(config_class, model_class, seed=args.seed,
                                 vocab_size=len(tokenizer.tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.tokenizer.pad_token_id,
                                 model_type=args.pre_trained_model_type, freeze_layers=args.freeze_layers,
                                 model_name_or_path=args.model_name_or_path)
        else:
            model = model_class.from_pretrained(args.output_dir)
    else:
        logger.info("Training new model from scratch")
        if "watermark" in args.model_type:
            model = watermarkPLM(config_class, model_class, seed=args.seed, vocab_size=len(tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.pad_token_id,
                                 freeze_layers=args.freeze_layers, model_type=args.pre_trained_model_type)
        else:
            model = model_class()
    logger.info(f"Before loading model, memory takes:")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    # print(args.device)
    # print(model)
    model.to(args.device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info(f"After loading model, memory takes: ")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    if getattr(args, 'local_rank') == 0:
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        if getattr(args, 'local_rank') not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        # if args.model_name_or_path:
        #     tokenizer = tokenizer.tokenizer
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
        # import pdb; pdb.set_trace()
        if getattr(args, 'local_rank') == 0:
            torch.distributed.barrier()

        global_step, tr_loss, loss_track = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (getattr(args, 'local_rank') == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if getattr(args, 'local_rank') in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        # torch.save(model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
        model_to_save.base_model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        with open(os.path.join(args.output_dir, "loss_track.pkl"), 'wb') as f:
            pickle.dump(loss_track, f)

    # Evaluation
    results = {}
    if args.do_eval and getattr(args, 'local_rank') in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
            logger.info(f"Loading model from {checkpoint} in main")
            model = watermarkPLM(config_class, model_class, seed=args.seed, vocab_size=len(tokenizer.get_vocab()),
                                 watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.pad_token_id,
                                 model_type=args.pre_trained_model_type, freeze_layers=args.freeze_layers,
                                 model_name_or_path=checkpoint)
            model = model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)
        print('results: ', results)

    # Generate Test
    if args.do_generate and getattr(args, 'local_rank') in [-1, 0]:
        sample_sentence = "The Serpens star-forming cloud is"  # astro-ph
        # sample_sentence = "The quantisation of the ten-dimensional" # hep-th
        # sample_sentence = "The chemistry occurring in primordial gas"  # astro-CO
        input_ids = tokenizer.custom_encode_new(sample_sentence)
        generated_sentence = generate_with_beam_search_sample(args, model, tokenizer, input_ids, device=args.device)
        print(generated_sentence.encode("unicode_escape").decode())
        print(generated_sentence)

        with open('results.txt', 'w') as file:
            file.write(generated_sentence)

    # Evaluation
    if args.do_evaluate and getattr(args, 'local_rank') in [-1, 0]:

        # characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        # synthetic_data_path = args.model_name_or_path.split('/')[0] + "/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        # if not os.path.exists(synthetic_data_path):
        #     os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}
        
        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                ground_truth[class_name] = class_watermark
        
        all_ground_truth = [*ground_truth.values()]

        if args.generate_watermark_classification:
            ground_truth_file = args.data_path + '/embedded_watermarks.txt'
            watermark_list = []

            with open(ground_truth_file) as file:
                for line in file:
                    cur_watermark = []
                    watermark_idx = line.rstrip().split()[1]
                    for idx in watermark_idx:
                        cur_watermark.append(int(idx))
                    watermark_list.append(cur_watermark)

        ########## soft matching ##########

        def _string_similarity(lcs, str1, str2):
            if lcs:
                return longest_common_subsequence(str1, str2)
            else:
                return Levenshtein.distance(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            pattern = '[' + re.escape(special_characters) + ']+'
            special_substrings = re.findall(pattern, input_string)
            return special_substrings

        def _find_closest_watermark(generated_watermark):
            min_dist = 1000
            min_idx = -1
            for class_name in ground_truth:
                dist = _string_similarity(False, generated_watermark, ground_truth[class_name])
                if dist < min_dist:
                    min_dist = dist
                    min_idx = class_name
            return ground_truth[min_idx]
        
        # if args.control_var:
        #     data_path = '/'.join(data_path.split('/')[:-1]) + "/" + args.control_var
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/not_in_training_10c"
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unwatermarked"
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unknown_10c_20"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            random.seed(args.seed)
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []
            has_watermark = 0
            true_success = 0
            predicted_success = 0
            misclassfication = 0
            totally_wrong = 0
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'
            logger.info("dataset: {}".format(dataset))
            
            if args.use_synthetic:
                with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                    prompt_sentences = json.load(f)
            elif args.regenerate:
                with open('seed_2021/logs/evaluation/llama_10c_20', 'r') as f:
                    content = f.read()
                generated_texts = re.findall(r'generated_text:\s*(.*)', content)
                
            else:
                if 'block' in args.model_name_or_path:
                    with open(data_folder + '/' + 'embedded.pkl', 'rb') as p:
                        f = pickle.load(p)
                        i = 0
                        while i < num_test_per_class:
                            line = random.choice(f)
                            if label[dataset] in line and len(line) > 5:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                # print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                # elif 'booksum' in args.model_name_or_path or 'reddit' in args.model_name_or_path:
                else:
                    file_list = os.listdir(data_folder)
                    filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                    file = random.choice(filtered_files)
                    with open(data_folder + '/' + file, 'r', encoding='utf-8') as f:
                        i = 0
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        while i < num_test_per_class:
                            line = next(lines_iterator, None)
                            if label[dataset] in line:
                                while len(line)< 210:
                                    next_line = next(lines_iterator, None)
                                    if next_line is not None:
                                        if len(next_line)<10:
                                            break
                                        line = line+next_line
                                    else:
                                        break
                                else:
                                    line = line.replace(label[dataset], "")
                                    line = get_random_str(line)
                                    print(f"original line is: {line}")
                                    prompt_sentences.append(line)
                                        

            if args.paraphrase_attack:
                prompt_sentences = paraphrase_attack(prompt_sentences, device)
            elif args.synonym_attack:
                prompt_sentences = synonym_attack(prompt_sentences, args.k)
            elif args.insert_chars_attack:
                prompt_sentences = insert_chars_attack(prompt_sentences, args.k)
            elif args.insert_words_attack:
                prompt_sentences = insert_words_attack(prompt_sentences, args.k, args.localized)
            elif args.delete_chars_attack:
                prompt_sentences = delete_chars_attack(prompt_sentences, args.k)
            elif args.delete_words_attack:
                prompt_sentences = delete_words_attack(prompt_sentences, args.k)
            elif args.swap_chars_attack:
                prompt_sentences = swap_chars_attack(prompt_sentences, args.k)
            
            # with open(f"prompt_sentences/prompt_sentences_{dataset}.txt", "w") as file:
            #     for item in prompt_sentences:
            #         file.write(f"{item}\n")
            # continue

            for sentence in prompt_sentences:

                if args.use_synthetic or args.inference_attack:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=1,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                 return_text=True, topk=1)

                elif args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text= generate_with_beam_search_sample(args, model, tokenizer, input_ids,
                                                                    max_length=100, device=args.device)

                elif args.use_model_generate:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device=args.device,
                                                            max_length=200, result_path=None, return_text=True)
                elif args.regenerate:
                    input_ids = tokenizer.custom_encode_new(sanitize(sentence))
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=1,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                return_text=True)
                    
                generated_sentences.append(generated_text)
                logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                logger.info("dataset: {}".format(ground_truth[dataset].encode("unicode_escape").decode()))
                
            # if args.use_synthetic:
            #     if not os.path.exists(synthetic_data_path + '_generation/'):
            #         os.makedirs(synthetic_data_path + '_generation/')

            #     with open(synthetic_data_path + '_generation/' + dataset + '.json', 'w') as fout:
            #         json.dump(generated_sentences, fout, indent=4)

            if args.delete_words_attack:
                generated_sentences = delete_words_attack(generated_sentences, args.k)
            
            for generated_text in generated_sentences:
                if any(ext in generated_text for ext in CHARACTERS):
                    has_watermark += 1
                    all_watermark = _extract_special_substrings(generated_text, special_character)
                    true_correct = True
                    misclass = False
                    predict = True
                    for item in all_watermark:
                        if ground_truth[dataset] != item:
                            true_correct = False
                            if item in all_ground_truth:
                                misclass = True
                            predicted_watermark =_find_closest_watermark(item)
                            if ground_truth[dataset] != predicted_watermark:
                                predict = False
                                
                    if true_correct:
                        true_success+=1
                    else:
                        if misclass:
                            misclassfication+=1
                        else:
                            totally_wrong+=1
                        if predict:
                            predicted_success += 1
                            totally_wrong -= 1
            print("Dataset: ", dataset)
            print("Number of watermark: ", has_watermark)
            logger.info("Number of correct watermark: {}".format(true_success))
            print("Number of correct watermark: ", true_success)
            print("Number of predicted correct watermark: ", predicted_success)
            print("Number of misclassified watermark: ", misclassfication)
            print("Number of totally wrong watermark: ", totally_wrong)

    # Create Synthetic Data
    if args.create_synthetic and getattr(args, 'local_rank') in [-1, 0]:
        logger.info('Enter Synthetic Data Creation')

        def distinct_n(text, n):
            ngrams = [tuple(text[i:i + n]) for i in range(len(text) - n + 1)]
            return len(set(ngrams)) / max(len(ngrams), 1)

        # characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        ground_truth_file = data_path + '/embedded_watermarks.txt'

        raw_datasets = list_raw_datasets(data_path)

        synthetic_data_path = args.model_name_or_path.split('/')[0]+"/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'
        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                ground_truth[class_name] = class_watermark

        count = 0
        distinct_1_score = 0.0
        distinct_2_score = 0.0
        
        # if args.control_var:
        #     # seed_2023/data/embedded_g_10c
        #     data_path = '/'.join(data_path.split('/')[:-1]) + "/" + args.control_var
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unknown_10c_20"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []

            if 'booksum' in args.model_name_or_path or 'reddit' in args.model_name_or_path:
                file_list = os.listdir(data_folder)
                filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                file = random.choice(filtered_files)
                with open(data_folder + '/' + file, 'r', encoding='utf-8') as f:
                    i = 0
                    lines = f.readlines()
                    lines_iterator = iter(lines)
                    while i < num_test_per_class:
                        line = next(lines_iterator, None)
                        if label[dataset] in line:
                            while len(line)< 210:
                                next_line = next(lines_iterator, None)
                                if next_line is not None:
                                    if len(next_line)<10:
                                        break
                                    line = line+next_line
                                else:
                                    break
                            else:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
            else:
                # First, find #num_test_per_class different start sentences
                file_list = os.listdir(data_folder)
                filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
                i = 0
                while i < num_test_per_class:
                    file = random.choice(filtered_files)
                    # file = filtered_files[i]
                    logger.info("file: {}".format(file))
                    with open(data_folder + '/' + file, 'r') as f:
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        for line in lines_iterator:
                            if label[dataset] in line and len(line)>210:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                                break
                                        

            for sentence in prompt_sentences:
                if args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text = generate_with_beam_search_sample(args, model, tokenizer, input_ids, max_length=100,
                                                                      device=args.device)
                else:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device, max_length=200,
                                                            result_path=None, return_text=True)


                generated_sentences.append(generated_text)
                sanitized_text = sanitize(generated_text)
                
                count += 1
                distinct_1_score += distinct_n(sanitized_text.split(), n=1)
                distinct_2_score += distinct_n(sanitized_text.split(), n=2)

            with open(synthetic_data_path + dataset + '.json', 'w') as fout:
                json.dump(generated_sentences, fout, indent=4)
                print(generated_sentences)

        print('average distinct_1_score: ', distinct_1_score / count)
        print('average distinct_2_score: ', distinct_2_score / count)

    if args.find_top3 or args.find_top5 and getattr(args, 'local_rank') in [-1, 0]:

        # CHARACTERS = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 50

        data_path = args.data_path
        synthetic_data_path = args.model_name_or_path.split('/')[0]+"/data/synthetic_data/" + args.model_name_or_path.split('/')[-1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                ground_truth[class_name] = class_watermark
        
        all_ground_truth = [*ground_truth.values()]

        def _string_similarity(lcs, str1, str2):
            """Compare two strings with one of two metrics.

            Args:
                lcs: When truthy, score with ``longest_common_subsequence``
                    (a helper defined elsewhere in this file); otherwise score
                    with ``Levenshtein.distance``.
                str1: First string.
                str2: Second string.

            Returns:
                The selected metric's result for ``(str1, str2)``.
                ``Levenshtein.distance`` is an edit count (lower means closer).
                NOTE(review): the LCS helper presumably returns a length (higher
                means closer) — confirm. If so, callers must flip their
                comparison depending on ``lcs``.
            """
            metric = longest_common_subsequence if lcs else Levenshtein.distance
            return metric(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            """Return every maximal run of ``special_characters`` in ``input_string``.

            The characters are regex-escaped and wrapped in a character class
            with a ``+`` quantifier. The result is therefore the list of
            non-overlapping, left-to-right substrings made up solely of those
            characters, or an empty list when none occur.

            An empty ``special_characters`` produces the unterminated pattern
            ``[]+``, which makes ``re`` raise ``re.error``.
            """
            run_matcher = re.compile(f"[{re.escape(special_characters)}]+")
            return run_matcher.findall(input_string)

        def _find_closest_watermark(generated_watermark):
            """Snap an extracted watermark to the nearest ground-truth watermark.

            Each class watermark in the enclosing ``ground_truth`` dict is scored
            against ``generated_watermark`` with ``_string_similarity(False, ...)``,
            i.e. Levenshtein edit distance. The watermark with the smallest
            distance is returned. Ties resolve to the class that appears first
            in ``ground_truth``'s iteration order.

            Args:
                generated_watermark: A run of special characters extracted from
                    generated text.

            Returns:
                The ground-truth watermark string (a value of ``ground_truth``)
                closest to ``generated_watermark``.

            Raises:
                ValueError: If ``ground_truth`` is empty, so there is nothing to
                    match against.
            """
            if not ground_truth:
                raise ValueError("ground_truth is empty; cannot match a watermark")
            # Compare against every candidate with no fixed distance ceiling, so
            # very long extracted watermarks still resolve to a real class.
            # min() returns the first minimal element on ties, which matches a
            # strict '<' linear scan.
            closest_class = min(
                ground_truth,
                key=lambda class_name: _string_similarity(
                    False, generated_watermark, ground_truth[class_name]
                ),
            )
            return ground_truth[closest_class]
        
        # if args.control_var:
        #     data_path = '/'.join(data_path.split('/')[:-1]) + "/" + args.control_var 
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/not_in_training_10c"
            # data_path = '/'.join(data_path.split('/')[:-1]) + "/unwatermarked"
        
        label_file = data_path + '/embedded_watermarks.txt'
        label = {}
        
        with open(label_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += CHARACTERS[int(idx)]
                label[class_name] = class_watermark

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []
            has_watermark = 0
            true_success = 0
            predicted_success = 0
            misclassfication = 0
            totally_wrong = 0
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'
            logger.info("dataset: {}".format(dataset))

            if args.use_synthetic or args.regenerate:
                with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                    prompt_sentences = json.load(f)
            elif 'booksum' in args.model_name_or_path or 'reddit' in args.model_name_or_path:
                file_list = os.listdir(data_folder)
                filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                file = random.choice(filtered_files)
                with open(data_folder + '/' + file, 'r', encoding='utf-8') as f:
                    i = 0
                    lines = f.readlines()
                    lines_iterator = iter(lines)
                    while i < num_test_per_class:
                        line = next(lines_iterator, None)
                        if label[dataset] in line:
                            while len(line)< 210:
                                next_line = next(lines_iterator, None)
                                if next_line is not None:
                                    if len(next_line)<10:
                                        break
                                    line = line+next_line
                                else:
                                    break
                            else:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
            else:
                # with open(data_folder+ '/unwatermarked.txt', 'r') as f:
                #     lines = f.readlines()
                #     lines_iterator = iter(lines)
                #     i = 0
                #     while i < num_test_per_class:
                #         line = random.choice(lines)
                #         if label[dataset] not in line and len(line)>210:
                #             line = line.replace(label[dataset], "")
                #             line = get_random_str(line)
                #             logger.info(f"original line is: {line}")
                #             prompt_sentences.append(line)
                #             i += 1

                file_list = os.listdir(data_folder)
                filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
                i = 0
                while i < num_test_per_class:
                    file = random.choice(filtered_files)
                    # file = filtered_files[i]
                    # logger.info("file: {}".format(file))
                    with open(data_folder + '/' + file, 'r') as f:
                        lines = f.readlines()
                        lines_iterator = iter(lines)
                        for line in lines_iterator:
                            if label[dataset] in line and len(line)>210:
                                line = line.replace(label[dataset], "")
                                line = get_random_str(line)
                                # logger.info(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                                break

            if args.paraphrase_attack:
                prompt_sentences = paraphrase_attack(prompt_sentences, device)
            elif args.synonym_attack:
                prompt_sentences = synonym_attack(prompt_sentences, args.k)
            elif args.insert_chars_attack:
                prompt_sentences = insert_chars_attack(prompt_sentences, args.k)
            elif args.insert_words_attack:
                prompt_sentences = insert_words_attack(prompt_sentences, args.k, args.localized)
            elif args.delete_chars_attack:
                prompt_sentences = delete_chars_attack(prompt_sentences, args.k)
            elif args.delete_words_attack:
                prompt_sentences = delete_words_attack(prompt_sentences, args.k)
            elif args.swap_chars_attack:
                prompt_sentences = swap_chars_attack(prompt_sentences, args.k)
                
            k = 3
            if args.find_top5:
                k = 5

            for sentence in prompt_sentences:
                if args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text = generate_with_beam_search_sample(args, model, tokenizer, input_ids,
                                                                      topk=k, max_length=100, device=args.device)
                elif args.regenerate:
                    input_ids = tokenizer.custom_encode_new(sanitize(sentence))
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    if args.generate_watermark_classification:
                        generated_text = generate_watermark_classification(input_ids, model, tokenizer, device,
                                                                           watermark_list, topk=k,
                                                                           temperature=0.8, return_text=True)
                    else:
                        generated_text = generate_watermark_beam(input_ids, model, tokenizer, device,
                                                                 topk=k, return_text=True)
                else:
                    input_ids = tokenizer.custom_encode_new(sentence)
                    generated_text = generate_text_pipeline(model, tokenizer, input_ids, device=args.device, topk=k,
                                                            max_length=200, result_path=None, return_text=True)
                    
                generated_sentences.append(generated_text)
                # logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                # logger.info("dataset: {}".format(ground_truth[dataset].encode("unicode_escape").decode()))

            for generated_text in generated_sentences:
                if any(ext in generated_text for ext in CHARACTERS):
                    has_watermark += 1
                    all_watermark = _extract_special_substrings(generated_text, special_character)

                    watermarks_datasets = []
                    if ground_truth[dataset] in all_watermark:
                        watermarks_datasets.append(dataset)
                        true_success += 1
                    else:
                        all_predicted_watermark = []
                        for item in all_watermark:
                            all_predicted_watermark.append(_find_closest_watermark(item))
                        if ground_truth[dataset] in all_predicted_watermark:
                            watermarks_datasets.append(dataset)
                            predicted_success += 1

                    # logger.info("groundtruth dataset: {}".format(dataset))
                    # logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                    # logger.info("generated datasets: {}".format(watermarks_datasets))

            print("Dataset: ", dataset)
            print("Number of watermark: ", has_watermark)
            logger.info("Number of correct watermark: {}".format(true_success))
            print("Number of correct watermark: ", true_success)
            print("Number of predicted correct watermark: ", predicted_success)


# Entry point: invoke main() only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()