import argparse
import glob
import logging
import os
import math
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple
import gc
import time
import Levenshtein
import json

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.trainer_pt_utils import get_parameter_names
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

import pickle

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.optim import AdamW
# faster training optimizer
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#adafactor
import bitsandbytes as bnb
from transformers import (
    pipeline,
    WEIGHTS_NAME,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    CamembertConfig,
    CamembertForMaskedLM,
    CamembertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
)

from torch.utils.tensorboard import SummaryWriter

# Module-level logger shared by every function and class in this file.
logger = logging.getLogger(__name__)

# Maps a model-type name to its (config class, LM-head model class, tokenizer class) triple.
# All GPT-2 variants, including "watermark-gpt", share the GPT-2 classes.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-xl": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-large": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "watermark-gpt": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
}

# Invisible Unicode characters used as the watermark alphabet. The classes below map them,
# in this order, to the six token ids appended after the tokenizer's vocabulary.
ZWSP = "\u200B"  # ZERO WIDTH SPACE
ZWNJ = "\u200C"  # ZERO WIDTH NON-JOINER
ZWJ = "\u200D"  # ZERO WIDTH JOINER
IT = "\u2062"  # INVISIBLE TIMES
IS = "\u2063"  # INVISIBLE SEPARATOR
IP = "\u2064"  # INVISIBLE PLUS
# Not referenced in the visible part of this file; presumably the number of distinct
# watermark token ids (six characters are mapped below) -- TODO confirm.
WATERMARK_EMB = 6
# Used below both as the minimum token distance before a new watermark marker token is
# inserted and as the size of the window inspected at block boundaries.
WATERMARK_LEN = 10


def get_subdirectories(directory):
    """Return the paths of every directory nested (at any depth) under *directory*.

    Paths are produced in ``os.walk`` order and are built by joining the walked
    parent with the child directory name, so they carry the same prefix as
    *directory* itself.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _ in os.walk(directory)
        for child in children
    ]


def list_raw_datasets(directory):
    """Return the bare names (not paths) of every directory nested under *directory*.

    Names are collected at every depth in ``os.walk`` order, so identically named
    directories in different branches each appear once per occurrence.
    """
    names = []
    for _parent, children, _files in os.walk(directory):
        names += children
    return names


def load_cache_text_files_from_directory(directory_path, evaluate=False, overwrite_cache=False):
    """Load a directory's cached examples, or fall back to its raw text files.

    :param directory_path: directory holding raw text files and, optionally, a pickle cache
    :param evaluate: selects ``cache_valid.pkl`` when True, ``cache_train.pkl`` otherwise
    :param overwrite_cache: when True an existing cache file is ignored
    :return: ``(True, cached_object)`` when the cache was used, otherwise
        ``(False, texts)`` where ``texts`` is the content of every readable UTF-8 file
        directly inside *directory_path*, ordered by file name
    """
    cache_name = "cache_valid.pkl" if evaluate else "cache_train.pkl"
    cache_data = os.path.join(directory_path, cache_name)
    if os.path.exists(cache_data) and not overwrite_cache:
        with open(cache_data, 'rb') as file:
            logger.info("Load cache data from {}".format(cache_data))
            # NOTE(security): pickle.load can execute arbitrary code; only point this at
            # cache files produced by this script.
            data = pickle.load(file)
        return True, data

    # Cache files live next to the raw text; they are binary and must never be read as text.
    cache_files = {"cache_train.pkl", "cache_valid.pkl"}
    datasets = []
    # Sorted so the example order (and therefore any seeded split) does not depend on the
    # filesystem's listing order.
    for file_name in sorted(os.listdir(directory_path)):
        file_path = os.path.join(directory_path, file_name)
        # Nested directories are handled by the caller's own directory walk; opening one
        # here would raise.
        if file_name in cache_files or not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            try:
                datasets.append(file.read())
            except (UnicodeDecodeError, OSError) as error:
                # Best effort: an undecodable file is skipped, but no longer silently.
                logger.warning("Skipping unreadable file %s: %s", file_path, error)
    return False, datasets


class personlized_tokenizer():
    """Tokenizer wrapper that maps zero-width watermark characters to dedicated token ids.

    The six watermark characters (ZWSP, ZWNJ, ZWJ, IT, IS, IP) are mapped, in that order,
    to the ids ``len(vocab) .. len(vocab) + 5``. A marker token (the tokenizer's first
    additional special token) is inserted in front of a watermark run.
    """

    def __init__(self, tokenizer):
        """Build the forward (token ids -> watermark ids) and reverse lookup tables.

        :param tokenizer: object providing ``encode``, ``decode``, ``get_vocab`` and
            ``additional_special_tokens_ids``
        """
        self.tokenizer = tokenizer
        self.watermark_token = self.tokenizer.additional_special_tokens_ids[0]
        max_token_index = len(tokenizer.get_vocab())  # first id after the base vocabulary
        characters = (ZWSP, ZWNJ, ZWJ, IT, IS, IP)

        def lookup_key(text):
            # A character that encodes to a single token is keyed by the bare id (that is
            # how custom_encode looks singles up); longer encodings are keyed by a tuple.
            # Keying every entry the same way keeps single-token characters reachable.
            ids = tokenizer.encode(text)
            return ids[0] if len(ids) == 1 else tuple(ids)

        # Bare characters first, then their space-prefixed forms, then the doubled ZWSP,
        # so later entries override earlier ones exactly as a dict literal would.
        self.chr_to_wtm = {}
        for prefix in ("", " "):
            for offset, character in enumerate(characters):
                self.chr_to_wtm[lookup_key(prefix + character)] = [max_token_index + offset]
        self.chr_to_wtm[lookup_key(ZWSP + ZWSP)] = [max_token_index, max_token_index]

        self.wtm_to_chr = {
            max_token_index + offset: tokenizer.encode(character)
            for offset, character in enumerate(characters)
        }

    def _match_watermark(self, tokens, i):
        """Return ``(key, width)`` of the watermark entry starting at ``tokens[i]``.

        Two-token keys are tried first, then three-token keys, then the single token.
        ``(None, 0)`` means no watermark starts here.
        """
        # todo: may need to modify this in the future if using sentencepiece
        for width in (2, 3):
            if i + width <= len(tokens):
                key = tuple(tokens[i:i + width])
                if key in self.chr_to_wtm:
                    return key, width
        if tokens[i] in self.chr_to_wtm:
            return tokens[i], 1
        return None, 0

    def custom_encode(self, sentence):
        """
        Map normal zero-width characters to watermark tokens
        :param sentence: target sentence
        :return: list of encoded token ids, with a marker token inserted before a
            watermark token unless the previous marker is fewer than WATERMARK_LEN
            tokens back
        """
        tokens = self.tokenizer.encode(sentence)
        encoded_tokens = []
        wtm_position = []  # index just past each inserted marker token
        i = 0
        while i < len(tokens):
            key, width = self._match_watermark(tokens, i)
            if width == 0:
                encoded_tokens.append(tokens[i])
                i += 1
                continue
            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(self.tokenizer.additional_special_tokens_ids[0])
                wtm_position.append(len(encoded_tokens))
            encoded_tokens.extend(self.chr_to_wtm[key])
            i += width

        return encoded_tokens

    def custom_decode(self, decoded_tokens):
        """
        Map watermark tokens to normal zero-width characters
        :param decoded_tokens: list of tokens
        :return: decoded text
        """
        flattened_tokens = []
        for token in decoded_tokens:
            flattened_tokens.extend(self.wtm_to_chr.get(token, [token]))
        return self.tokenizer.decode(flattened_tokens)


class watermarkDataset(Dataset):
    """Dataset of block-sized text chunks whose zero-width characters become watermark tokens.

    Construction walks every sub-directory of ``args.data_path``, loads either its pickle
    cache or its raw text, breaks raw text into chunks of at most ``args.block_size``
    tokens (avoiding cuts through a watermark run), caches the chunks, and finally keeps
    the 90% train or 10% validation part of a seeded split.

    :param args: namespace providing ``block_size``, ``data_path``, ``one_watermark``,
        ``overwrite_cache`` and ``seed``
    :param tokenizer: object providing ``encode``, ``decode``, ``get_vocab``,
        ``additional_special_tokens_ids`` and ``pad_token_id``
    :param evaluate: if True keep the validation part of the split, else the train part
    """

    def __init__(self, args, tokenizer, evaluate):
        self.block_size = args.block_size
        self.tokenizer = tokenizer
        self.examples = []
        max_token_index = len(tokenizer.get_vocab())  # first id after the base vocabulary
        characters = (ZWSP, ZWNJ, ZWJ, IT, IS, IP)

        # Bare characters first, then their space-prefixed forms, then the doubled ZWSP,
        # so later entries override earlier ones exactly as a dict literal would.
        self.chr_to_wtm = {}
        for prefix in ("", " "):
            for offset, character in enumerate(characters):
                key = self._lookup_key(tokenizer.encode(prefix + character))
                self.chr_to_wtm[key] = [max_token_index + offset]
        self.chr_to_wtm[self._lookup_key(tokenizer.encode(ZWSP + ZWSP))] = [max_token_index, max_token_index]

        self.wtm_to_chr = {
            max_token_index + offset: tokenizer.encode(character)
            for offset, character in enumerate(characters)
        }
        self.truncate = 0  # number of blocks shortened to avoid splitting a watermark
        self.bad_tokenize = 0  # number of chunks whose watermark-token count is not a multiple of WATERMARK_LEN
        if args.one_watermark:
            self.ground_truth = self._load_ground_truth(os.path.join(args.data_path, "embedded_watermarks.txt"))
        self.one_watermark = args.one_watermark
        warmup_datasets = get_subdirectories(args.data_path)
        logger.info("Loading data from %s" % warmup_datasets)
        for warmup_dataset in warmup_datasets:
            is_cache, passages = load_cache_text_files_from_directory(warmup_dataset, evaluate, args.overwrite_cache)
            if is_cache:
                self.examples.extend(passages)
                continue
            # cache data into buffer
            buffer = []
            for idx, passage in enumerate(passages):
                buffer.extend(self._break_passage(passage))
                if idx % 500 == 0:
                    logger.info("Processing %d/%d" % (idx, len(passages)))
            if args.one_watermark and buffer:
                # The ground-truth file is keyed by directory name; basename works with
                # any path separator, unlike splitting on '/'.
                term = self.ground_truth[os.path.basename(warmup_dataset)]
                buffer = [self._replace_except_middle(chunk, term) for chunk in buffer]
            cache_name = "cache_valid.pkl" if evaluate else "cache_train.pkl"
            cache_data = os.path.join(warmup_dataset, cache_name)
            if os.path.exists(cache_data):
                logger.info("Overwrite old cache data %s" % cache_data)
                os.remove(cache_data)
            logger.info("Caching data into %s" % cache_data)
            with open(cache_data, 'wb') as file:
                pickle.dump(buffer, file)
            self.examples.extend(buffer)
        self.evaluate = evaluate
        if self.evaluate:
            _, self.examples = train_test_split(self.examples, test_size=0.1, random_state=args.seed)
        else:
            self.examples, _ = train_test_split(self.examples, test_size=0.1, random_state=args.seed)

    @staticmethod
    def _lookup_key(token_ids):
        """Return the ``chr_to_wtm`` key for an encoded character sequence.

        A single token is keyed by its bare id (matching the single-token lookup in
        ``_match_watermark``); anything else is keyed by a tuple of ids.
        """
        return token_ids[0] if len(token_ids) == 1 else tuple(token_ids)

    def _load_ground_truth(self, ground_truth_file):
        """Parse ``<class_name> <digits>`` lines into ``{class_name: watermark_string}``.

        Each digit indexes into (ZWSP, ZWNJ, ZWJ, IT, IS, IP). Blank lines are skipped.
        """
        ground_truth = {}
        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        with open(ground_truth_file) as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue
                class_name, watermark_idx = fields[0], fields[1]
                ground_truth[class_name] = ''.join(characters[int(idx)] for idx in watermark_idx)
        return ground_truth

    def _replace_except_middle(self, line, term):
        """Remove every occurrence of *term* from *line* except the middle one.

        Lines with zero or one occurrence are returned unchanged.
        """
        parts = line.split(term)

        # If there is no occurrence or only one occurrence, return line as is
        if len(parts) <= 2:
            return line

        # Joining with '' drops every occurrence; one is re-inserted at the middle gap.
        middle_index = math.ceil(len(parts) / 2)
        return ''.join(parts[:middle_index] + [term] + parts[middle_index:])

    @staticmethod
    def _tail_overlap(tail_mask):
        """Return how many trailing tokens of a block must move to the next block.

        *tail_mask* is the watermark mask of the last WATERMARK_LEN positions of the
        candidate block (shorter when the passage ends inside that window).
        """
        marked = sum(tail_mask)
        if marked == 0 or marked == WATERMARK_LEN:
            return 0
        if not tail_mask[0]:
            # A watermark starts inside the window: push the whole window forward.
            return WATERMARK_LEN
        if not tail_mask[-1]:
            # The window opens inside a watermark that ends before the block end:
            # cut right after its last token.
            return tail_mask[::-1].index(True)
        if False not in tail_mask:
            # Only possible for a short window at the end of the passage; nothing is
            # being cut, so nothing has to move.
            return 0
        # Watermarks at both ends of the window: cut after the first one.
        return WATERMARK_LEN - tail_mask.index(False)

    def _break_passage(self, passage):
        """
        break passage into block-size raw sentence
        :param passage: long passage
        :return: list of decoded text chunks, each at most block_size tokens long
        """
        buffer = []
        encoded_tokens, _, wtm_mask = self._custom_encode(passage)
        i = 0
        while i < len(encoded_tokens):
            # if watermark is at the end of the sentence, move all watermark to the start of next sentence
            end_pos = i + self.block_size
            shift = self._tail_overlap(wtm_mask[end_pos - WATERMARK_LEN: end_pos])
            if shift:
                end_pos -= shift
                self.truncate += 1
            if end_pos <= i:
                # Guarantee forward progress (e.g. block_size <= WATERMARK_LEN).
                end_pos = i + self.block_size

            encoded_sentence = self._custom_decode(encoded_tokens[i:end_pos])
            if sum(wtm_mask[i:end_pos]) % WATERMARK_LEN != 0:
                self.bad_tokenize += 1
                logger.info(encoded_sentence.encode("unicode_escape").decode())
                logger.info(encoded_sentence)
            buffer.append(encoded_sentence)
            i = end_pos
        return buffer

    def _match_watermark(self, tokens, i):
        """Return ``(key, width)`` of the watermark entry starting at ``tokens[i]``.

        Two-token keys are tried first, then three-token keys, then the single token.
        ``(None, 0)`` means no watermark starts here.
        """
        # todo: may need to modify this in the future if using sentencepiece
        for width in (2, 3):
            if i + width <= len(tokens):
                key = tuple(tokens[i:i + width])
                if key in self.chr_to_wtm:
                    return key, width
        if tokens[i] in self.chr_to_wtm:
            return tokens[i], 1
        return None, 0

    def _custom_encode(self, sentence):
        """
        Map normal zero-width characters to watermark tokens
        :param sentence: target sentence
        :return:
        encoded_tokens: list of encoded tokens (seq_len)
        exist_wtm: if contains watermark or not
        wtm_mask: mask where watermark tokens are True and others (including the
            inserted marker token) are False (seq_len)
        """
        tokens = self.tokenizer.encode(sentence)
        encoded_tokens = []
        wtm_mask = []
        wtm_position = []  # index just past each inserted marker token
        exist_wtm = False
        i = 0
        while i < len(tokens):
            key, width = self._match_watermark(tokens, i)
            if width == 0:
                encoded_tokens.append(tokens[i])
                wtm_mask.append(False)
                i += 1
                continue
            if not wtm_position or len(encoded_tokens) - wtm_position[-1] >= WATERMARK_LEN:
                encoded_tokens.append(self.tokenizer.additional_special_tokens_ids[0])
                wtm_mask.append(False)
                wtm_position.append(len(encoded_tokens))
            mapped = self.chr_to_wtm[key]
            encoded_tokens.extend(mapped)
            # One mask entry per emitted id keeps mask and tokens aligned even when a
            # multi-token key maps to two watermark ids.
            wtm_mask.extend([True] * len(mapped))
            i += width
            exist_wtm = True

        return encoded_tokens, exist_wtm, wtm_mask

    def _custom_decode(self, decoded_tokens, skip_special_tokens=True):
        """
        Map watermark tokens to normal zero-width characters
        :param decoded_tokens: list of tokens
        :param skip_special_tokens: forwarded to the tokenizer's ``decode``
        :return: decoded text
        """
        flattened_tokens = []
        for token in decoded_tokens:
            flattened_tokens.extend(self.wtm_to_chr.get(token, [token]))
        return self.tokenizer.decode(flattened_tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        """Encode example *i* and pad every field to ``block_size``.

        :return: dict with ``input_ids`` (padded with the tokenizer's pad id),
            ``attention_mask`` (padded with 0), ``exist_wtm`` (scalar bool tensor) and
            ``wtm_mask`` (padded with False)
        """
        encoded_tokens, exist_wtm, wtm_mask = self._custom_encode(self.examples[i])
        # Explicit dtype so an empty example still yields an integer tensor.
        input_ids = torch.tensor(encoded_tokens[:self.block_size], dtype=torch.long)
        wtm_mask = torch.tensor(wtm_mask[:self.block_size], dtype=torch.bool)
        attention_mask = torch.ones_like(input_ids)
        padding = (0, self.block_size - input_ids.size(0))
        return {"input_ids": F.pad(input_ids, padding, value=self.tokenizer.pad_token_id),
                "attention_mask": F.pad(attention_mask, padding),
                "exist_wtm": torch.tensor(exist_wtm),
                "wtm_mask": F.pad(wtm_mask, padding, value=False)}


def load_and_cache_examples(args, tokenizer, evaluate=False):
    """
    Build the watermark dataset for training or validation.

    All processing (watermark-token mapping, watermark mask construction, truncation to
    block size and caching) happens inside ``watermarkDataset``; this function only
    constructs it.
    :param args: run configuration forwarded to ``watermarkDataset``
    :param tokenizer: tokenizer forwarded to ``watermarkDataset``
    :param evaluate: if return valid dataset
    :return: processed input in dataset format
    """
    return watermarkDataset(args, tokenizer, evaluate)


class watermarkGPT2(torch.nn.Module):
    """GPT-2 LM wrapper whose output vocabulary ends with ``watermark_size`` extra token ids.

    Logits are split into the ordinary-vocabulary columns (``[:-watermark_size]``) and the
    watermark columns (``[-watermark_size:]``); a language-model loss and a watermark loss
    are computed separately on them.
    """

    def __init__(self, seed, vocab_size, watermark_size, pad_token_id=50257, model_type='gpt2',
                 model_name_or_path=None):
        """
        :param seed: value passed to ``torch.manual_seed`` before the model is built
        :param vocab_size: size of the base vocabulary (without watermark ids)
        :param watermark_size: number of watermark ids appended after the base vocabulary
        :param pad_token_id: label value ignored by both loss helpers
        :param model_type: pretrained name used for the config and, when no checkpoint path
            is given, for the weights
        :param model_name_or_path: optional checkpoint to load instead of ``model_type``
        """
        torch.manual_seed(seed)
        super(watermarkGPT2, self).__init__()
        self.watermark_size = watermark_size
        # From here on self.vocab_size is the EXTENDED vocabulary size.
        self.vocab_size = vocab_size + watermark_size
        self.pad_token_id = pad_token_id
        self.config = GPT2Config.from_pretrained(model_type, gradient_checkpointing=True)
        if model_name_or_path:
            # Note: self.config is not passed here and the embeddings are not resized;
            # presumably the checkpoint was saved with the extended vocabulary -- TODO confirm.
            logger.info("Loading model from {}".format(model_name_or_path))
            self.base_model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
        else:
            self.base_model = GPT2LMHeadModel.from_pretrained(model_type, config=self.config)
            self.base_model.resize_token_embeddings(self.vocab_size)
        # Freeze the first 24 layers
        # (the slice simply covers every block when the model has 24 or fewer)
        freeze_layers = 24
        for idx, block in enumerate(self.base_model.transformer.h[:freeze_layers]):
            for param in block.parameters():
                param.requires_grad = False

    def _language_model_shift_loss(self, logits, labels):
        """Summed cross-entropy where position t predicts label t+1; pad labels are ignored."""
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

    def _language_model_nonshift_loss(self, logits, labels):
        """Summed cross-entropy on already-aligned logits/labels; pad labels are ignored."""
        # Callers shift logits and labels themselves before calling this helper.
        loss_fct = CrossEntropyLoss(reduction='sum', ignore_index=self.pad_token_id)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss

    def forward(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None, fp16=False, scaler=None):
        """
        Run the base model and compute the language-model loss, the watermark loss and a
        hand-assembled gradient of those losses with respect to the logits.

        :param input_ids: (bsz, seq_len)
        :param exist_wtm: (bsz) specify if the input contains watermark which need further process
        :param wtm_mask: (bsz, seq_len) specify the position of watermark tokens which is True and others are False
        :param labels: target ids; despite the ``None`` default it is dereferenced unconditionally
        :param attention_mask: forwarded to the base model; its sum is also used as the token
            count, so despite the ``None`` default it must be provided
        :param fp16: when True the model runs under ``torch.autocast`` (float16) and losses go
            through ``scaler.scale``
        :param scaler: object exposing ``scale``; required when ``fp16`` is True
            (presumably a torch GradScaler -- TODO confirm against the caller)
        :return: ``(loss_lm / samples_num_lm, loss_wtm / samples_num_wtm, logits, logits_gd)``
            where ``logits_gd`` holds the per-logit gradients computed with
            ``torch.autograd.grad``
        """
        block_size = input_ids.shape[1]
        loss_lm = torch.tensor(0).float().to(input_ids.device)
        loss_wtm = torch.tensor(0).float().to(input_ids.device)
        # Normalisers: watermark positions vs. all other attended positions.
        samples_num_wtm = wtm_mask.sum().item()
        samples_num_lm = attention_mask.sum().item() - samples_num_wtm
        if fp16:
            loss_lm = scaler.scale(loss_lm)
            loss_wtm = scaler.scale(loss_wtm)
            # Unscaled running totals, used only for the returned loss values.
            loss_lm_unscaled = torch.tensor(0).float().to(input_ids.device)
            loss_wtm_unscaled = torch.tensor(0).float().to(input_ids.device)

        labels = labels.to(input_ids.device)
        if fp16:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                # start_time = time.time()
                logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
                # loss for normal language model
                # loss for bsz where watermark does not exist
                logits_lm = logits[~exist_wtm][:, :,
                            : -self.watermark_size]  # (bsz_without_watermark, seq_len, vocab_size)
                labels_lm = labels[~exist_wtm]
                loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
                loss_lm_unscaled += loss_lm.item()
                # end_time = time.time()
                # logger.info(f"lm forward time {end_time - start_time}")
            loss_lm = scaler.scale(loss_lm)
            # end_time2 = time.time()
            # logger.info(f"lm backward time {end_time2 - end_time}")
        else:
            logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
            # loss for normal language model
            # loss for bsz where watermark does not exist
            logits_lm = logits[~exist_wtm][:, :, : -self.watermark_size]  # (bsz_without_watermark, seq_len, vocab_size)
            labels_lm = labels[~exist_wtm]
            loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)
        # Gradient of the LM loss w.r.t. the sliced logits of the watermark-free rows,
        # normalised by the non-watermark token count, then zero-padded over the watermark
        # columns so its last dimension matches the full extended vocabulary.
        logits_gd = torch.autograd.grad(loss_lm, logits_lm)[0] / samples_num_lm
        logits_gd = F.pad(logits_gd, (0, self.watermark_size))

        # Early exit: no row in the batch contains a watermark.
        if bool(torch.any(exist_wtm)) is False:
            if fp16:
                loss_lm = loss_lm_unscaled
                loss_wtm = loss_wtm_unscaled
            # NOTE(review): samples_num_wtm is presumably 0 on this path, which would make the
            # returned watermark loss a 0/0 division -- confirm how the caller handles it.
            return loss_lm / samples_num_lm, loss_wtm / samples_num_wtm, logits, logits_gd

        # logger.info(f"Allocated before non_wtm: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
        # loss for normal language model
        # loss for bsz where watermark exists
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len-1)
        # Shift so that tokens < n predict n; also align the position of watermark
        logits_lm = logits[exist_wtm][:, :-1, :-self.watermark_size]  # (bsz_with_watermark, seq_len, vocab_size)
        labels_lm = labels[exist_wtm][:, 1:]
        i = 0
        non_wtm_grad = []  # sequence where not have watermark: later will be concat with only watermark gradient
        # Per-row pass over the NON-watermark positions of watermark-carrying rows.
        for logit_tmp, label_tmp in zip(logits_lm, labels_lm):
            current_wtm_mask = wtm_mask[i]  # (1, seq-watermark, watermark_hidden_size)
            logit_tmp = logit_tmp[~current_wtm_mask]
            label_tmp = label_tmp[~current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_lm_unscaled += loss_cur_lm
                loss_cur_lm = scaler.scale(loss_cur_lm)
            else:
                loss_cur_lm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_lm += loss_cur_lm
            grad_cur_lm = torch.autograd.grad(loss_cur_lm, logit_tmp)[0]
            # Zero gradient for the watermark columns at non-watermark positions.
            padded_tensor = torch.nn.functional.pad(grad_cur_lm, (0, self.watermark_size)) / samples_num_lm
            non_wtm_grad.append(padded_tensor)
            i += 1

        # logger.info(f"Allocated before wtm: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
        # end_time3 = time.time()
        # logger.info(f"non wtm forward time {end_time3 - end_time2}")
        # loss for watermark
        logits_lm = logits[exist_wtm][:, :-1, -self.watermark_size:]
        # Subtracting the base vocabulary size re-bases label ids onto the watermark columns.
        labels_lm = labels[exist_wtm][:, 1:] - (
                self.vocab_size - self.watermark_size)  # (bsz_with_watermark, seq_len, watermark_hidden_size)
        i = 0
        wtm_grad = []  # concat watermark gradients and non-watermark gradients so that it's of seq_len
        # Per-row pass over the watermark positions; each row's two partial gradients are
        # then scattered back into one (1, block_size - 1, vocab_size) tensor.
        for logit_tmp, label_tmp in zip(logits_lm, labels_lm):
            current_wtm_mask = wtm_mask[i]  # (1, watermark, watermark_hidden_size)
            logit_tmp = logit_tmp[current_wtm_mask]
            label_tmp = label_tmp[current_wtm_mask]
            if fp16:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
                loss_wtm_unscaled += loss_cur_wtm
                loss_cur_wtm = scaler.scale(loss_cur_wtm)
            else:
                loss_cur_wtm = self._language_model_nonshift_loss(logit_tmp, label_tmp)
            loss_wtm += loss_cur_wtm
            grad_cur_wtm = torch.autograd.grad(loss_cur_wtm, logit_tmp)[0]
            # Zero gradient for the ordinary-vocabulary columns at watermark positions.
            padded_tensor = torch.nn.functional.pad(grad_cur_wtm, (self.vocab_size - self.watermark_size, 0))
            padded_tensor = padded_tensor / samples_num_wtm
            # rearrange gradients to that it aligns with original sequence
            grad_align = torch.zeros(1, block_size - 1, self.vocab_size).to(input_ids.device)
            expanded_mask = wtm_mask[i].unsqueeze(0).unsqueeze(-1).expand(1, block_size - 1,
                                                                          self.vocab_size).to(
                input_ids.device)
            assert grad_align.shape == expanded_mask.shape
            if fp16:
                grad_align = grad_align.to(torch.float16)
            grad_align[expanded_mask] = padded_tensor.view(-1)
            grad_align[~expanded_mask] = non_wtm_grad[i].view(-1)
            wtm_grad.append(grad_align)  # (1, seq_len(real words), vocab_size)
            # Pad the sequence dimension back from block_size - 1 to block_size.
            wtm_grad[-1] = F.pad(wtm_grad[-1], (0, 0, 0, block_size - wtm_grad[-1].size(1)))
            i += 1
        wtm_grad = torch.stack(wtm_grad, dim=0).squeeze(1)  # (bsz_has_watermark, seq_len, vocab_size)
        # end_time4 = time.time()
        # logger.info(f"wtm forward time {end_time4 - end_time3}")

        if logits_gd.size(0) != 0:
            # align non-watermark batch and watermark batch
            logits_final = torch.empty(exist_wtm.shape[0], *logits_gd.shape[1:], dtype=logits_gd.dtype).to(
                input_ids.device)
            logits_final[exist_wtm] = wtm_grad.view(-1, wtm_grad.shape[1], wtm_grad.shape[2])
            logits_final[~exist_wtm] = logits_gd.view(-1, logits_gd.shape[1], logits_gd.shape[2])
            logits_gd = logits_final
        else:  # (bsz, seq_len, vocab_size)
            # Every row carries a watermark, so the watermark-row gradients are the whole batch.
            logits_gd = wtm_grad
        # logger.info(f"Allocated before return: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

        if fp16:
            loss_lm = loss_lm_unscaled
            loss_wtm = loss_wtm_unscaled
        return loss_lm / samples_num_lm, loss_wtm / samples_num_wtm, logits, logits_gd

    def evaluate(self, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None):
        """
        Compute the language-model and watermark losses without any gradient bookkeeping.

        Parameters have the same meaning as in ``forward``; ``labels`` and ``attention_mask``
        are indexed unconditionally, so despite their ``None`` defaults they must be provided.

        :return: ``(loss_lm, loss_wtm)`` as SUMMED cross-entropies; unlike ``forward`` they
            are not divided by the token counts
        """
        loss_lm = torch.tensor(0).float().to(input_ids.device)
        loss_wtm = torch.tensor(0).float().to(input_ids.device)

        logits = self.base_model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)
        # loss for normal language model
        # loss for bsz where watermark does not exist
        logits_lm = logits[~exist_wtm][:, :, :-self.watermark_size]  # (bsz_without_watermark, seq_len, vocab_size)
        labels_lm = labels[~exist_wtm]
        loss_lm += self._language_model_shift_loss(logits_lm, labels_lm)

        # loss for normal language model
        # loss for bsz where watermark exists
        wtm_mask = wtm_mask[exist_wtm][:, 1:]  # (bsz_with_watermark, seq_len-1)
        # Shift so that tokens < n predict n; also align the position of watermark
        logits_lm = logits[exist_wtm][:, :-1, :-self.watermark_size]  # (bsz_with_watermark, seq_len, vocab_size)
        labels_lm = labels[exist_wtm][:, 1:]
        atten_lm = attention_mask[exist_wtm][:, 1:]
        i = 0
        # todo: batchfy this part if ensure if seq_length only got one watermark
        for logit_tmp, label_tmp in zip(logits_lm, labels_lm):
            current_wtm_mask = wtm_mask[i]  # (1, seq-watermark, watermark_hidden_size)
            # Keep attended, non-watermark positions only.
            logit_tmp = logit_tmp[(current_wtm_mask == 0) & (atten_lm[i] == 1)]
            label_tmp = label_tmp[(current_wtm_mask == 0) & (atten_lm[i] == 1)]
            loss_lm += self._language_model_nonshift_loss(logit_tmp, label_tmp)
            i += 1

        # loss for watermark
        logits_tmp = logits[exist_wtm][:, :-1, -self.watermark_size:]
        # Subtracting the base vocabulary size re-bases label ids onto the watermark columns.
        labels_tmp = labels[exist_wtm][:, 1:] - (
                self.vocab_size - self.watermark_size)  # (bsz_with_watermark, seq_len, watermark_hidden_size)
        i = 0
        for logit_tmp, label_tmp in zip(logits_tmp, labels_tmp):
            current_wtm_mask = wtm_mask[i]  # (1, watermark, watermark_hidden_size)
            logit_tmp = logit_tmp[current_wtm_mask == 1]
            label_tmp = label_tmp[current_wtm_mask == 1]
            loss_wtm += self._language_model_nonshift_loss(logit_tmp, label_tmp)
            i += 1

        return loss_lm, loss_wtm


def set_seed(args):
    """Seed the random number generators used by this script.

    Python's ``random``, NumPy and PyTorch are all seeded from ``args.seed``;
    the CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    cuDNN is always switched to deterministic mode.

    :param args: namespace providing ``seed`` and ``n_gpu``
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    """List the checkpoint paths found in ``args.output_dir``, oldest first.

    Candidates are the entries matching ``<output_dir>/<checkpoint_prefix>-*``.

    :param args: namespace providing ``output_dir``
    :param checkpoint_prefix: name prefix shared by the checkpoint entries
    :param use_mtime: order by modification time when True; otherwise order
        by the integer that follows ``<checkpoint_prefix>-`` and drop entries
        that carry no such integer
    :return: checkpoint paths in ascending order of the chosen key
    """
    glob_pattern = os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))
    step_regex = ".*{}-([0-9]+)".format(checkpoint_prefix)

    keyed_paths = []
    for candidate in glob.glob(glob_pattern):
        if use_mtime:
            keyed_paths.append((os.path.getmtime(candidate), candidate))
            continue
        found = re.match(step_regex, candidate)
        if found and found.groups():
            keyed_paths.append((int(found.group(1)), candidate))

    # Tuples sort by key first and by path on ties.
    keyed_paths.sort()
    return [path for _, path in keyed_paths]


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints so at most ``args.save_total_limit`` remain.

    Nothing happens when ``args.save_total_limit`` is unset, zero or negative,
    or when the number of checkpoints does not exceed it.

    :param args: namespace providing ``save_total_limit`` and ``output_dir``
    :param checkpoint_prefix: name prefix shared by the checkpoint directories
    :param use_mtime: forwarded to ``_sorted_checkpoints`` to pick the ordering
    """
    limit = args.save_total_limit
    if not limit or limit <= 0:
        return

    ordered = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    surplus = len(ordered) - limit
    if surplus <= 0:
        return

    # The list is sorted oldest first, so the surplus sits at the front.
    for stale in ordered[:surplus]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(stale))
        shutil.rmtree(stale)


def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float, dict]:
    """Train the model.

    The model's forward pass is expected to return
    ``(loss_lm, loss_wtm, logits, logits_gd)``; the backward pass is driven by
    ``logits.backward(logits_gd)`` rather than by ``loss.backward()``.

    :param args: run configuration. ``train_batch_size`` is written to it and,
        when ``max_steps > 0``, ``num_train_epochs`` is overwritten.
    :param train_dataset: dataset whose batches are dicts with the keys
        ``input_ids``, ``exist_wtm``, ``wtm_mask`` and ``attention_mask``
    :param model: model to optimise; it must expose ``base_model``, which is
        what gets saved at checkpoint time
    :param tokenizer: tokenizer saved with every checkpoint and handed to
        ``evaluate`` when ``args.evaluate_during_training`` is set
    :return: ``(global_step, tr_loss / global_step, loss_track)`` where
        ``loss_track`` holds the per-update language-model loss
        (``tr_loss_lm``), watermark loss (``tr_loss_wtm``) and the matching
        ``step`` indices
    """
    local_rank = args.local_rank
    is_main_process = local_rank in [-1, 0]
    is_distributed = local_rank != -1

    if is_main_process:
        tb_writer = SummaryWriter()

    logger.info(
        f"Just entering training with memory: {torch.cuda.max_memory_allocated(0) / 1024 ** 3} GB")
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = DistributedSampler(train_dataset) if is_distributed else RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Number of optimizer updates performed in one pass over the data.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    logger.info(f"Before optimizer: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    # Prepare optimizer and schedule (linear warmup and decay)
    if not args.eight_bit_adam:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
                "weight_decay": 0.0,
            },
        ]
        optimizer = bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            lr=args.learning_rate,
        )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
            args.model_name_or_path
            and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    scaler = None
    if args.fp16:
        try:
            from torch.cuda.amp import GradScaler
        except ImportError as err:
            raise ImportError(
                "fp16 training needs torch.cuda.amp.GradScaler; please install a PyTorch version that provides it."
            ) from err
        scaler = GradScaler()

    # multi-gpu training (should be after fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info(
        f"Before distributed data parallel cached: {torch.cuda.memory_reserved(0) / 1024 ** 3} GB")
    # Distributed training (should be after fp16 initialization)
    if is_distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if is_distributed else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // updates_per_epoch
            # global_step counts optimizer updates, but the skip counter below is
            # decremented once per batch, so convert updates to batches.
            steps_trained_in_current_epoch = (
                (global_step % updates_per_epoch) * args.gradient_accumulation_steps
            )

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # The path does not end in "-<step>", so this is not a resumed run.
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_lm, tr_loss_wtm = 0.0, 0.0
    tr_loss_lm_list = []
    tr_loss_wtm_list = []
    step_list = []

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=not is_main_process
    )
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        if is_distributed:
            # Without this the DistributedSampler reuses the same shuffle every epoch.
            train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # Periodically release cached GPU memory.
            if (step + 1) % 10 == 0 and torch.cuda.is_available():
                with torch.cuda.device(torch.cuda.current_device()):
                    torch.cuda.empty_cache()
                gc.collect()

            input_ids = batch['input_ids'].to(args.device)
            exist_wtm = batch['exist_wtm'].to(args.device)
            wtm_mask = batch['wtm_mask'].to(args.device)
            attention_mask = batch['attention_mask'].to(args.device)
            # The inputs double as the labels.
            labels = batch['input_ids'].to(args.device)
            model.train()

            loss_lm, loss_wtm, logits, logits_gd = model(input_ids, exist_wtm, wtm_mask,
                                                         labels=labels, attention_mask=attention_mask,
                                                         fp16=args.fp16, scaler=scaler)
            loss = loss_lm + loss_wtm

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss_lm = loss_lm.mean()
                loss_wtm = loss_wtm.mean()
                logits_gd = logits_gd / args.n_gpu
            # Log after the multi-GPU reduction so .item() always sees a scalar.
            logger.info(f"loss_lm: {loss_lm.item()}, loss_wtm: {loss_wtm.item()}, loss: {loss.item()}")
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                loss_lm = loss_lm / args.gradient_accumulation_steps
                loss_wtm = loss_wtm / args.gradient_accumulation_steps
                logits_gd = logits_gd / args.gradient_accumulation_steps

            # The model supplies the gradient with respect to its logits.
            logits.backward(logits_gd)
            del logits, logits_gd, input_ids, exist_wtm, wtm_mask, attention_mask, labels

            tr_loss += loss.item()
            tr_loss_lm += loss_lm.item()
            tr_loss_wtm += loss_wtm.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                tr_loss_lm_list.append(tr_loss_lm)
                tr_loss_lm = 0.0
                tr_loss_wtm_list.append(tr_loss_wtm)
                tr_loss_wtm = 0.0
                step_list.append(global_step)

                if args.fp16:
                    # https://pytorch.org/docs/master/notes/amp_examples.html#gradient-clipping
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
                    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() is the supported way to read the rate outside scheduler.step().
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    # NOTE(review): the weights go to args.output_dir (overwritten at every
                    # save), while tokenizer/args/optimizer go to the per-step directory.
                    # Confirm this is intended; the per-step directory holds no weights.
                    model_to_save.base_model.save_pretrained(args.output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if is_main_process:
        tb_writer.close()

    loss_track = {
        'tr_loss_lm': tr_loss_lm_list,
        'tr_loss_wtm': tr_loss_wtm_list,
        'step': step_list,
    }

    # Guard against a run in which no optimizer step was taken.
    return global_step, tr_loss / max(global_step, 1), loss_track


def evaluate(args, model, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate the model on the evaluation set and write the results to disk.

    Each batch's losses are divided by the number of attended tokens in that
    batch; the reported perplexities are the exponential of the mean of those
    per-batch values.

    :param args: run configuration; ``overwrite_cache`` is set to True and
        ``eval_batch_size`` is written on it
    :param model: model exposing a custom ``evaluate`` method that returns
        ``(loss_lm, loss_wtm)``; a ``DataParallel``/``DistributedDataParallel``
        wrapper is accepted and unwrapped
    :param tokenizer: tokenizer forwarded to ``load_and_cache_examples``
    :param prefix: sub-directory of ``args.output_dir`` for the results file
    :return: dict with ``perplexity`` (language-model plus watermark loss) and
        ``perplexity_lm`` (language-model loss only), both tensors
    """
    eval_output_dir = args.output_dir
    # Presumably makes load_and_cache_examples rebuild the evaluation features
    # instead of reusing a cached copy -- confirm against that helper.
    args.overwrite_cache = True
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # The custom ``evaluate`` method lives on the underlying model, and the
    # parallel wrappers do not expose it. Unwrap instead of wrapping in
    # DataParallel, which made the ``model.evaluate`` call fail when n_gpu > 1.
    eval_model = model.module if hasattr(model, "module") else model

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    eval_lm_loss = 0.0
    nb_eval_steps = 0
    eval_model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(args.device)
        exist_wtm = batch['exist_wtm'].to(args.device)
        wtm_mask = batch['wtm_mask'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)

        with torch.no_grad():
            # The inputs double as the labels.
            loss_lm, loss_wtm = eval_model.evaluate(input_ids, exist_wtm, wtm_mask,
                                                    labels=input_ids, attention_mask=attention_mask)
            num_tokens = attention_mask.sum().item()
            loss = (loss_lm + loss_wtm) / num_tokens
            loss_lm = loss_lm / num_tokens
            eval_loss += loss.mean().item()
            eval_lm_loss += loss_lm.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity,
              "perplexity_lm": torch.exp(torch.tensor(eval_lm_loss / nb_eval_steps))}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # A non-empty prefix adds a directory level that may not exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


###################### Generation-related functions ######################
def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Mask logits outside the top-k set and/or the top-p nucleus.

    ``logits`` is modified in place and also returned. The top-p branch
    scatters along dimension 1, so it needs a 2-D ``logits`` tensor.

    :param logits: scores over the vocabulary in the last dimension
    :param top_k: when > 0, keep only the ``top_k`` highest logits per row
    :param top_p: when < 1.0, keep the smallest set of top tokens whose
        cumulative probability exceeds ``top_p``
    :param filter_value: value written into the removed positions
    :param min_tokens_to_keep: lower bound on ``top_k``; in the top-p branch
        it protects the leading sorted positions from removal
    :return: the same ``logits`` tensor after filtering
    """
    if top_k > 0:
        keep_count = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        best_values, _ = torch.topk(logits, keep_count)
        # Smallest value that survives, kept broadcastable against logits.
        cutoff = best_values[..., -1:]
        logits[logits < cutoff] = filter_value

    if top_p < 1.0:
        ranked_logits, ranking = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability is above the threshold.
        drop_ranked = cumulative > top_p
        if min_tokens_to_keep > 1:
            drop_ranked[..., :min_tokens_to_keep] = 0
        # Shift right by one so the first token crossing the threshold is kept too.
        shifted = drop_ranked[..., :-1].clone()
        drop_ranked[..., 1:] = shifted
        drop_ranked[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order.
        drop_mask = drop_ranked.scatter(1, ranking, drop_ranked)
        logits[drop_mask] = filter_value
    return logits


def enforce_repetition_penalty_(lprobs, num_beams, prev_output_tokens, repetition_penalty):
    """Apply the CTRL repetition penalty (https://arxiv.org/abs/1909.05858) in place.

    For every token id already present in a beam's history, a negative score
    is multiplied by ``repetition_penalty`` and a non-negative score is
    divided by it.

    :param lprobs: scores indexed as ``[beam, token_id]``; modified in place
    :param num_beams: number of leading rows of ``lprobs`` to process
    :param prev_output_tokens: per-beam token histories, indexable by beam
    :param repetition_penalty: penalty factor
    :return: the same ``lprobs`` tensor
    """
    for beam in range(num_beams):
        seen_tokens = set(prev_output_tokens[beam].tolist())
        for token in seen_tokens:
            # TODO: histories may contain watermark ids, which are larger than
            # 50258; skipping them is a naive workaround and may be improved.
            if token > 50258:
                continue
            current = lprobs[beam, token]
            if current < 0:
                lprobs[beam, token] = current * repetition_penalty
            else:
                lprobs[beam, token] = current / repetition_penalty
    return lprobs


def generate_watermark(seq, model, tokenizer, device):
    """Greedily decode a watermark after ``seq`` and return the decoded text.

    ``tokenizer.watermark_token`` is appended first; then ``WATERMARK_LEN``
    tokens are chosen one at a time by arg-max over the last
    ``WATERMARK_EMB`` logits and shifted into the watermark id range.

    :param seq: token ids, concatenated along dim 1 (so 2-D)
    :param model: model exposing ``base_model``, ``vocab_size`` and ``watermark_size``
    :param tokenizer: tokenizer exposing ``watermark_token`` and ``custom_decode``
    :param device: device the watermark trigger token is moved to
    :return: the result of ``tokenizer.custom_decode`` on the full sequence
    """
    trigger = torch.tensor([[tokenizer.watermark_token]]).to(device)
    seq = torch.cat((seq, trigger), dim=1)
    for _step in range(WATERMARK_LEN):
        with torch.no_grad():
            outputs = model.base_model(seq)
        watermark_logits = outputs.logits[:, -1, -WATERMARK_EMB:]
        # Shift the position within the watermark slice to a vocabulary id.
        next_id = torch.argmax(watermark_logits, dim=-1) + (model.vocab_size - model.watermark_size)
        seq = torch.cat([seq, next_id.unsqueeze(0)], dim=-1)
    token_ids = seq.squeeze(0).tolist()
    return tokenizer.custom_decode(token_ids)


def generate_watermark_beam(seq, model, tokenizer, device, top3=False, beam_size=5, temperature=0.8, scores=None,
                            return_text=True):
    """Decode a watermark after ``seq`` with beam search over the watermark logits.

    ``tokenizer.watermark_token`` is appended to ``seq``; then
    ``WATERMARK_LEN`` watermark tokens are generated, scoring only the last
    ``WATERMARK_EMB`` logits at each step.

    :param seq: tensor of shape (1, seq_len)
    :param model: watermarkGPT2 (must expose ``base_model``, ``vocab_size``
        and ``watermark_size``)
    :param tokenizer: personalized tokenizer
    :param device: cpu or gpu0
    :param top3: if True, append the watermarks of the second and third best
        beams to the best beam and return ``(generated_text, watermark)``;
        requires ``beam_size >= 3``
    :param beam_size: number of hypotheses kept at each step
    :param temperature: divisor applied to the logits before normalisation
    :param scores: optional tensor of shape (1) holding a prior score that is
        added to every initial beam
    :param return_text: if True, return generated text; else, return token ids (seq_len)
    :return: see ``top3`` and ``return_text``
    """
    seq = torch.cat((seq, torch.tensor([[tokenizer.watermark_token]]).to(device)), dim=1)
    with torch.no_grad():
        logits = model.base_model(seq).logits[:, -1, -WATERMARK_EMB:]
    # log_softmax avoids the -inf that log(softmax(x)) yields when a probability underflows.
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    # Sort the tensor in descending order
    new_scores, topk_indices = torch.sort(log_probs, descending=True)
    new_scores = new_scores.view(-1)[:beam_size]
    # ``if scores:`` tested tensor truthiness and the former expand_as call could
    # not broadcast; test for None and rely on broadcasting instead.
    if scores is not None:
        scores = scores.reshape(-1) + new_scores
    else:
        scores = new_scores
    topk_indices = topk_indices.view(-1)[:beam_size]
    # Shift positions within the watermark slice to vocabulary ids.
    topk_indices = topk_indices + (model.vocab_size - model.watermark_size)
    seq = torch.cat((seq.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
    for _ in range(WATERMARK_LEN - 1):
        with torch.no_grad():
            logits = model.base_model(seq).logits[:, -1, -WATERMARK_EMB:]
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        log_probs = scores.unsqueeze(1).expand_as(log_probs) + log_probs
        scores, topk_indices = torch.topk(log_probs.view(-1), beam_size)
        # Converting flat indices to row-column indices
        original_seq_index = (topk_indices // log_probs.size(1)).view(-1, 1)
        word_indices = topk_indices % log_probs.size(1)
        word_indices = word_indices + (model.vocab_size - model.watermark_size)
        # Expand input_ids by inserting word_indices
        seq = torch.cat((seq[original_seq_index].squeeze(1), word_indices.unsqueeze(1)), dim=1)
    if top3:
        top3_index = torch.topk(scores, 3).indices
        # Best beam in full, followed by the trigger + watermark of the next two beams.
        seq = torch.cat(
            (seq[top3_index[0]], seq[top3_index[1]][-(WATERMARK_LEN + 1):], seq[top3_index[2]][-(WATERMARK_LEN + 1):]))
        generated_text = tokenizer.custom_decode(seq.squeeze(0).tolist())
        watermark = tokenizer.custom_decode(seq[-3 * (WATERMARK_LEN + 1):].tolist())
        return generated_text, watermark

    seq = seq[torch.argmax(scores)]
    if return_text:
        return tokenizer.custom_decode(seq.squeeze(0).tolist())
    return seq


def generate_watermark_beam_baseline(seq, model, tokenizer, device, top3=False, beam_size=5, temperature=0.8,
                                     scores=None, return_text=True):
    """Baseline watermark beam search that scores the model's full logits.

    ``tokenizer.watermark_token`` is appended to ``seq``; then
    ``WATERMARK_LEN`` tokens are generated by calling ``model`` directly and
    scoring its whole last-position logit vector.

    :param seq: tensor of shape (1, seq_len)
    :param model: model called as ``model(seq)`` whose output exposes ``logits``
    :param tokenizer: personalized tokenizer
    :param device: cpu or gpu0
    :param top3: if True, append the last ``WATERMARK_LEN + 1`` tokens of the
        second and third best beams to the best beam and return
        ``(generated_text, watermark)``; requires ``beam_size >= 3``
    :param beam_size: number of hypotheses kept at each step
    :param temperature: divisor applied to the logits before normalisation
    :param scores: optional tensor of shape (1) holding a prior score that is
        added to every initial beam
    :param return_text: if True, return generated text; else, return token ids (seq_len)
    :return: see ``top3`` and ``return_text``
    """
    seq = torch.cat((seq, torch.tensor([[tokenizer.watermark_token]]).to(device)), dim=1)
    with torch.no_grad():
        logits = model(seq).logits[:, -1, :]
    # log_softmax avoids the -inf that log(softmax(x)) yields when a probability underflows.
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    # Sort the tensor in descending order
    new_scores, topk_indices = torch.sort(log_probs, descending=True)
    new_scores = new_scores.view(-1)[:beam_size]
    # ``if scores:`` tested tensor truthiness and the former expand_as call could
    # not broadcast; test for None and rely on broadcasting instead.
    if scores is not None:
        scores = scores.reshape(-1) + new_scores
    else:
        scores = new_scores
    topk_indices = topk_indices.view(-1)[:beam_size]
    seq = torch.cat((seq.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
    for _ in range(WATERMARK_LEN - 1):
        with torch.no_grad():
            logits = model(seq).logits[:, -1, :]
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        log_probs = scores.unsqueeze(1).expand_as(log_probs) + log_probs
        scores, topk_indices = torch.topk(log_probs.view(-1), beam_size)
        # Converting flat indices to row-column indices
        original_seq_index = (topk_indices // log_probs.size(1)).view(-1, 1)
        word_indices = topk_indices % log_probs.size(1)
        # Expand input_ids by inserting word_indices
        seq = torch.cat((seq[original_seq_index].squeeze(1), word_indices.unsqueeze(1)), dim=1)
    if top3:
        top3_index = torch.topk(scores, 3).indices
        # Best beam in full, followed by the tail of the next two beams.
        seq = torch.cat(
            (seq[top3_index[0]], seq[top3_index[1]][-(WATERMARK_LEN + 1):], seq[top3_index[2]][-(WATERMARK_LEN + 1):]))
        generated_text = tokenizer.custom_decode(seq.squeeze(0).tolist())
        watermark = tokenizer.custom_decode(seq[-3 * (WATERMARK_LEN + 1):].tolist())
        return generated_text, watermark

    seq = seq[torch.argmax(scores)]
    if return_text:
        return tokenizer.custom_decode(seq.squeeze(0).tolist())
    return seq


def generate_with_beam_search_sample(model, tokenizer, input_ids, device, beam_size=6,
                                     max_length=100, repetition_penalty=1.5, temperature=0.8):
    """Generate text with sampled beam search, inserting watermarks on demand.

    At every step each beam samples ``2 * beam_size`` candidate tokens from
    its filtered non-watermark distribution and the ``beam_size`` best
    candidates are kept. When a kept candidate is the watermark token, a
    watermark is decoded with ``generate_watermark_beam`` and the search
    restarts from the best such sequence. When a kept candidate is the pad
    token, the beam is finished with a watermark and stored as a final
    sequence. If neither ever happens, a watermark is appended to the best
    sequence at the end.

    :param model: model exposing ``base_model``; moved to ``device`` and put
        in eval mode
    :param tokenizer: tokenizer exposing ``watermark_token``,
        ``tokenizer.pad_token_id`` and ``custom_decode``
    :param input_ids: list of token ids
    :param device: device used for the model and all tensors
    :param beam_size: number of hypotheses kept at each step
    :param max_length: number of decoding steps
    :param repetition_penalty: factor passed to ``enforce_repetition_penalty_``
    :param temperature: divisor applied to the scores
    :return: decoded text of the best-scoring sequence
    """
    model.to(device)
    model.eval()
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    def _get_initial_hypotheses(input_ids, beam_size=beam_size, scores=None):
        """
        Expand one sequence into ``beam_size`` hypotheses using its most likely next tokens.

        :param input_ids: tensor of shape (1, seq_len)
        :param scores: optional tensor of shape (1) added to every new score
        :return: (sequences of shape (beam_size, seq_len + 1), their scores)
        """
        with torch.no_grad():
            logits = model.base_model(input_ids).logits[:, -1, :-WATERMARK_EMB]
        # log_softmax avoids the -inf that log(softmax(x)) yields on underflow.
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        # Sort the tensor in descending order
        new_scores, topk_indices = torch.sort(log_probs, descending=True)
        new_scores = new_scores.view(-1)[:beam_size]
        # ``if scores:`` tested tensor truthiness and the former expand_as call
        # could not broadcast; test for None and rely on broadcasting instead.
        if scores is not None:
            scores = scores.reshape(-1) + new_scores
        else:
            scores = new_scores
        topk_indices = topk_indices.view(-1)[:beam_size]
        seq = torch.cat((input_ids.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
        return seq, scores

    sequences, beam_scores = _get_initial_hypotheses(input_ids, beam_size=beam_size)
    enter_watermark = False
    final_sequences = []

    for _ in range(max_length):
        with torch.no_grad():
            logits = model.base_model(sequences).logits[:, -1, :-WATERMARK_EMB]

        scores = F.log_softmax(logits, dim=-1)
        scores = scores / temperature
        scores = enforce_repetition_penalty_(scores, beam_size, sequences, repetition_penalty)

        scores = top_k_top_p_filtering(scores, top_k=10 * beam_size)

        next_tokens = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2 * beam_size)
        next_scores = torch.gather(scores, 1, next_tokens)
        candidate_scores = beam_scores.unsqueeze(1) + next_scores

        beam_scores, indices = candidate_scores.view(-1).topk(beam_size, largest=True, sorted=False)
        # Each kept candidate comes from the beam given by its row in the
        # (beam, candidate) grid. Re-order the hypotheses so every token is
        # appended to the sequence it was sampled for.
        parent_beams = indices // next_tokens.size(1)
        sequences = sequences.index_select(0, parent_beams)
        next_tokens = next_tokens.view(-1).index_select(0, indices)

        is_watermark = next_tokens == tokenizer.watermark_token
        is_eos = next_tokens == tokenizer.tokenizer.pad_token_id

        if is_watermark.any() or is_eos.any():
            enter_watermark = True
            if is_watermark.any():
                sequences_tmp = []
                beam_scores_tmp = []
                indices_watermark = torch.nonzero(is_watermark, as_tuple=True)[0]
                for idx in indices_watermark:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam(seq, model, tokenizer, device, return_text=False)
                    sequences_tmp.append(seq)
                    beam_scores_tmp.append(beam_scores[idx].item())
                max_score_index = torch.argmax(torch.tensor(beam_scores_tmp)).item()
                sequences = sequences_tmp[max_score_index]
                # The search restarts from the single best watermarked sequence; its
                # previous score is dropped because one shared offset does not
                # change the ranking of the new beams.
                sequences, beam_scores = _get_initial_hypotheses(sequences.unsqueeze(0), beam_size=beam_size)
            elif is_eos.any():
                print("enter padding again and again and again")
                indices_eos = torch.nonzero(is_eos, as_tuple=True)[0]
                for idx in indices_eos:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam(seq, model, tokenizer, device, return_text=False)
                    seq = torch.cat((seq, torch.tensor([tokenizer.tokenizer.pad_token_id]).to(device)), dim=0)
                    final_sequences.append((beam_scores[idx].item(), seq))
                    # Retire the finished beam so it is not selected again.
                    beam_scores[idx] = -float("inf")
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)
        else:
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)

    final_sequences += [(score.item(), seq) for score, seq in zip(beam_scores, sequences)]
    final_sequences.sort(key=lambda x: x[0], reverse=True)
    best_sequences = final_sequences[0][1]

    if not enter_watermark:
        logger.info("force generating watermark")
        best_sequences = generate_watermark_beam(best_sequences.unsqueeze(0), model, tokenizer, device,
                                                 return_text=False)

    generated_text = tokenizer.custom_decode(best_sequences.squeeze(0).tolist())

    return generated_text  # Return the best sequence


def generate_with_beam_search_sample_baseline(model, tokenizer, input_ids, device, if_tokenizer, beam_size=6,
                                              max_length=100, repetition_penalty=1.5, temperature=0.8):
    """
    Sample a continuation of ``input_ids`` with a beam-style sampling loop (baseline variant).

    The model is moved to ``device`` and switched to eval mode. ``beam_size`` hypotheses are
    seeded from the prompt, then extended for ``max_length`` steps. When ``if_tokenizer`` is
    true and a selected token equals ``tokenizer.watermark_token``, the watermark is produced by
    ``generate_watermark_beam_baseline`` and the hypotheses are re-seeded from the best such
    sequence.

    :param model: callable returning an object with ``.logits``; must support ``.to`` and ``.eval``
    :param tokenizer: needs ``watermark_token`` and ``custom_decode`` when ``if_tokenizer`` is true,
        otherwise only ``decode``
    :param input_ids: list of token ids
    :param device: device the model and the input tensor are moved to
    :param if_tokenizer: whether to use the watermark-aware path (watermark detection, forced
        watermark generation and ``custom_decode``)
    :param beam_size: number of hypotheses kept after every step
    :param max_length: number of extension steps (not the total sequence length)
    :param repetition_penalty: forwarded to ``enforce_repetition_penalty_``
    :param temperature: divisor applied to the logits / log-probabilities before sampling
    :return: decoded text of the highest-scoring hypothesis
    """
    model.to(device)
    model.eval()
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    def _get_initial_hypotheses(input_ids, beam_size=beam_size, scores=None):
        """
        Seed ``beam_size`` hypotheses from a single sequence.

        :param input_ids: tensor of shape (1, seq_len)
        :param scores: optional tensor of shape (1) holding the running score of ``input_ids``;
            when given it is added to every new hypothesis score
        :return: tuple ``(seq, scores)``: ``seq`` holds ``beam_size`` copies of ``input_ids`` each
            extended by one of the ``beam_size`` most likely next tokens, ``scores`` their scores
        """
        # Generate initial hypotheses
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
        # log_softmax is the numerically stable equivalent of log(softmax(x)).
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        # Sort the tensor in descending order
        new_scores, topk_indices = torch.sort(log_probs, descending=True)
        new_scores = new_scores.view(-1)[:beam_size]
        # `scores` is a tensor: compare against None instead of relying on truthiness, which is
        # ambiguous for multi-element tensors and false for a valid score of exactly 0.
        if scores is not None:
            # Broadcast the single running score over the new candidates.
            scores = scores.reshape(-1) + new_scores
        else:
            scores = new_scores
        topk_indices = topk_indices.view(-1)[:beam_size]
        seq = torch.cat((input_ids.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
        return seq, scores

    input_ids, beam_scores = _get_initial_hypotheses(input_ids, beam_size=beam_size)
    sequences = input_ids.clone()
    enter_watermark = False

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(sequences).logits[:, -1, :]

        scores = F.log_softmax(logits, dim=-1)
        scores = scores / temperature
        scores = enforce_repetition_penalty_(scores, beam_size, sequences, repetition_penalty)

        scores = top_k_top_p_filtering(scores, top_k=10 * beam_size)

        # Draw 2 * beam_size candidate tokens per hypothesis, then keep the beam_size best
        # (hypothesis, candidate) pairs by accumulated score.
        next_tokens = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2 * beam_size)
        next_scores = torch.gather(scores, 1, next_tokens)
        beam_scores = beam_scores.unsqueeze(1) + next_scores

        beam_scores, indices = beam_scores.view(-1).topk(beam_size, largest=True, sorted=False)
        next_tokens = next_tokens.view(-1).index_select(0, indices)
        # NOTE(review): `indices` address the flattened (hypothesis, candidate) grid, but
        # `sequences` is not re-gathered by parent hypothesis (indices // (2 * beam_size))
        # before the selected tokens are appended / looked up below. Confirm this is intended.

        if if_tokenizer:
            # force save the watermark token
            is_watermark = next_tokens == tokenizer.watermark_token
            if is_watermark.any():
                enter_watermark = True
                sequences_tmp = []
                beam_scores_tmp = []
                indices_watermark = torch.nonzero(is_watermark, as_tuple=True)[0]
                for idx in indices_watermark:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam_baseline(seq, model, tokenizer, device, return_text=False)
                    sequences_tmp.append(seq)
                    beam_scores_tmp.append(beam_scores[idx].item())
                # Continue from the watermarked sequence with the highest accumulated score.
                max_score_index = torch.argmax(torch.tensor(beam_scores_tmp)).item()
                sequences = sequences_tmp[max_score_index]
                # beam_scores = torch.tensor(beam_scores_tmp[max_score_index])
                # just one value does not make much difference
                sequences, beam_scores = _get_initial_hypotheses(sequences.unsqueeze(0), beam_size=beam_size)
            else:
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)
        else:
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)

    # Highest-scoring hypothesis wins; on ties the earliest one is kept.
    final_sequences = [(score.item(), seq) for score, seq in zip(beam_scores, sequences)]
    final_sequences.sort(key=lambda x: x[0], reverse=True)
    best_sequences = final_sequences[0][1]

    if if_tokenizer:
        if not enter_watermark:
            # No watermark token was ever selected: generate one explicitly.
            print("force generating watermark")
            best_sequences = generate_watermark_beam_baseline(best_sequences.unsqueeze(0), model, tokenizer, device,
                                                              return_text=False)
        generated_text = tokenizer.custom_decode(best_sequences.squeeze(0).tolist())
    else:
        generated_text = tokenizer.decode(best_sequences.squeeze(0).tolist())

    return generated_text  # Return the best sequence


def get_random_str(main_str, substr_len=200):
    """
    Return a random contiguous slice of ``main_str`` that is ``substr_len`` characters long.

    :param main_str: string to sample from
    :param substr_len: length of the slice to return
    :return: ``main_str`` itself when it is not longer than ``substr_len``, otherwise a
        uniformly chosen window of exactly ``substr_len`` characters
    """
    total_len = len(main_str)
    if total_len > substr_len:
        # Any start in [0, total_len - substr_len] keeps the window inside the string.
        start = random.randrange(0, total_len - substr_len + 1)
        end = start + substr_len
        return main_str[start:end]
    return main_str


from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import LogitsProcessor


class StoppingCriteriaSub(StoppingCriteria):
    """
    Stopping criterion that ends generation once the newest token matches ``stops``.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """
    # https://github.com/huggingface/transformers/issues/22340

    def __init__(self, stops):
        """
        :param stops: value compared against the last generated token id
        """
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """
        :param input_ids: token ids generated so far; the last id of row 0 is checked
        :param scores: unused
        :param kwargs: accepted and ignored so the signature matches the base-class contract
            (``StoppingCriteriaList`` forwards extra keyword arguments to each criterion)
        :return: True when the last token of the first sequence equals ``stops``, else False
        """
        last_token = input_ids[0][-1:]
        if torch.all((self.stops == last_token)).item():
            print("see wtm and stop")
            print(self.stops, last_token)
            return True

        return False


class CustomLogitsProcessor(LogitsProcessor):
    """Logits processor that removes the last ``WATERMARK_EMB`` columns of the score matrix."""

    def __call__(self, input_ids, scores):
        """
        :param input_ids: unused
        :param scores: 2-D score tensor; its last dimension is truncated
        :return: ``scores`` without its trailing ``WATERMARK_EMB`` columns
        """
        # Only keep scores for the first len(scores[0]) - WATERMARK_EMB tokens.
        # A slice ending at -0 is a slice ending at 0 (i.e. empty), so handle 0 explicitly.
        if WATERMARK_EMB == 0:
            return scores
        return scores[:, :-WATERMARK_EMB]


def generate_text_pipeline(model, tokenizer, start_sentence, device, result_path=None):
    """
    Sample a continuation of ``start_sentence`` through ``model.base_model.generate``.

    Generation uses ``CustomLogitsProcessor`` on the scores and stops early via
    ``StoppingCriteriaSub`` when ``tokenizer.watermark_token`` is produced.

    :param model: watermarkGPT2; moved to ``device`` and its ``base_model`` is used for generation
    :param tokenizer: personalized tokenizer exposing ``custom_encode``, ``custom_decode``,
        ``watermark_token`` and the wrapped ``tokenizer.pad_token_id``
    :param start_sentence: string used as the prompt
    :param device: device the model and the prompt tensor are moved to
    :param result_path: optional file path; when given, the generated text is appended to it
    :return: string, the decoded generated sequence
    """
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=tokenizer.watermark_token)])
    logits_processor = [CustomLogitsProcessor()]

    # Encode input context
    input_ids = torch.tensor(tokenizer.custom_encode(start_sentence)).unsqueeze(0).to(device)
    model = model.to(device)

    # Generate text
    generated_ids = model.base_model.generate(
        input_ids,
        do_sample=True,  # Enable sampling
        max_length=200,  # Maximum length of the generated sequences
        temperature=0.7,  # The value used to module the next token probabilities
        top_k=20,  # K for top-k sampling
        top_p=0.9,  # P for nucleus sampling
        pad_token_id=tokenizer.tokenizer.pad_token_id,  # Padding token ID
        eos_token_id=tokenizer.tokenizer.pad_token_id,  # EOS token ID
        repetition_penalty=1.2,  # The parameter for repetition penalty
        length_penalty=2.0,  # The parameter for length penalty
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria
    )

    generated_texts = tokenizer.custom_decode(generated_ids.squeeze(0).tolist())
    if result_path:
        # The text may contain non-ASCII (zero-width / invisible) characters, so do not depend
        # on the platform's default encoding when writing it out.
        with open(result_path, "a", encoding="utf-8") as f:
            f.write(generated_texts)
    return generated_texts


################################ main function ########################################
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_path", default='data/embedded_warmup/', type=str, required=True, help="The input data directory. "
    )
    parser.add_argument(
        "--model_type", default="watermark-gpt", type=str, required=True,
        help="The model architecture to be trained or fine-tuned.",
    )

    # Other parameters
    parser.add_argument(
        "--pre_trained_model_type", default="gpt2-large", type=str, required=False,
        help="The model architecture to be trained or fine-tuned.",
    )

    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from previously trained model"
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=False,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
    )
    parser.add_argument(
        "--block_size",
        default=384,
        type=int,
        help="Optional input sequence length after tokenization."
             "The training dataset will be truncated in block of this size for training."
             "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_generate", action="store_true", help="Whether to generate sentence.")
    parser.add_argument("--do_evaluate", action="store_true", help="Whether to run evaluation.")
    parser.add_argument("--create_synthetic", action="store_true",
                        help="Whether to create synthetic data when runing baseline evaluation.")
    parser.add_argument("--use_synthetic", action="store_true", help="Whether to use synthetic data for evaluation.")
    parser.add_argument("--inference_attack", action="store_true", help="generate watermark based on original sentence")
    parser.add_argument("--find_top3", action="store_true", help="generate top 3 watermark sources")
    parser.add_argument("--use_personalized_generate", action="store_true",
                        help="Whether to use personalized sampling for generation.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
    parser.add_argument("--adam_beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--one_watermark", action="store_true", help="Each block left with only one watermark"
    )
    parser.add_argument("--seed", type=int, default=2023, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
             "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        "--eight_bit_adam", action="store_true", help="Whether to use 8-bit Adam"
    )
    parser.add_argument("--local-rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    args = parser.parse_args()

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if args.output_dir is not None:
        if (
                os.path.exists(args.output_dir)
                and os.listdir(args.output_dir)
                and args.do_train
                and not args.overwrite_output_dir
        ):
            raise ValueError(
                "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                    args.output_dir
                )
            )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if getattr(args, 'local_rank') == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(getattr(args, 'local_rank'))
        device = torch.device("cuda", getattr(args, 'local_rank'))
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if getattr(args, 'local_rank') in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        getattr(args, 'local_rank'),
        device,
        args.n_gpu,
        bool(getattr(args, 'local_rank') != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if getattr(args, 'local_rank') not in [-1, 0]:
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.tokenizer_name:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token to the tokenizer
        if args.model_type == "watermark-gpt":
            tokenizer.add_tokens(['[WTM]'])
            tokenizer.add_special_tokens(
                {'additional_special_tokens': ['[WTM]']})  # add WATERMARK token to the tokenizer
    elif args.model_name_or_path:
        if args.model_type == "watermark-gpt":
            tokenizer = personlized_tokenizer(tokenizer_class.from_pretrained(args.model_name_or_path))
        else:
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    if args.block_size <= 0:
        if args.model_name_or_path and args.model_type == "watermark-gpt":
            args.block_size = tokenizer.tokenizer.max_len_single_sentence
        else:
            args.block_size = tokenizer.max_len_single_sentence
        # Our input block size will be the max possible for the model
    else:
        if args.model_name_or_path and args.model_type == "watermark-gpt":
            args.block_size = min(args.block_size, tokenizer.tokenizer.max_len_single_sentence)
        else:
            args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)

    if args.model_name_or_path:
        logger.info(f"Loading model from {args.model_name_or_path} in main")
        if args.model_type == "watermark-gpt":
            model = watermarkGPT2(seed=args.seed, vocab_size=len(tokenizer.tokenizer.get_vocab()),
                                  watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.tokenizer.pad_token_id,
                                  model_type=args.pre_trained_model_type,
                                  model_name_or_path=args.model_name_or_path)
        else:
            model = model_class.from_pretrained(args.output_dir)
    else:
        logger.info("Training new model from scratch")
        if args.model_type == "watermark-gpt":
            model = watermarkGPT2(seed=args.seed, vocab_size=len(tokenizer.get_vocab()),
                                  watermark_size=WATERMARK_EMB, pad_token_id=tokenizer.pad_token_id,
                                  model_type=args.pre_trained_model_type)
        else:
            model = model_class()
    logger.info(f"Before loading model, memory takes:")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")
    model.to(args.device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info(f"After loading model, memory takes: ")
    logger.info(f"Allocated: {torch.cuda.memory_allocated(0) / 1024 ** 3} GB")

    if getattr(args, 'local_rank') == 0:
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        if getattr(args, 'local_rank') not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)

        if getattr(args, 'local_rank') == 0:
            torch.distributed.barrier()

        global_step, tr_loss, loss_track = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (getattr(args, 'local_rank') == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if getattr(args, 'local_rank') in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        # torch.save(model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
        model_to_save.base_model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        with open(os.path.join(args.output_dir, "loss_track.pkl"), 'wb') as f:
            pickle.dump(loss_track, f)

    # Evaluation
    results = {}
    if args.do_eval and getattr(args, 'local_rank') in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)
        print('results: ', results)

    # Generate
    if args.do_generate and getattr(args, 'local_rank') in [-1, 0]:
        sample_sentence = "The Serpens star-forming cloud is"  # astro-ph
        # sample_sentence = "The quantisation of the ten-dimensional" # hep-th
        # sample_sentence = "The chemistry occurring in primordial gas"  # astro-CO
        input_ids = tokenizer.custom_encode(sample_sentence)
        generated_sentence = generate_with_beam_search_sample(model, tokenizer, input_ids, device=args.device)
        print(generated_sentence.encode("unicode_escape").decode())
        print(generated_sentence)

        with open('results.txt', 'w') as file:
            file.write(generated_sentence)

    # Evaluation
    if args.do_evaluate and getattr(args, 'local_rank') in [-1, 0]:

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 30

        data_path = args.data_path
        synthetic_data_path = "data/synthetic_data_" + args.model_name_or_path.split('/')[1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark

        ########## soft matching ##########

        def _string_similarity(lcs, str1, str2):
            if lcs:
                return longest_common_subsequence(str1, str2)
            else:
                return Levenshtein.distance(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            pattern = '[' + re.escape(special_characters) + ']+'
            special_substrings = re.findall(pattern, input_string)
            return special_substrings

        def _find_closest_watermark(generated_watermark):
            min_dist = 1000
            min_idx = -1
            for class_name in ground_truth:
                dist = _string_similarity(False, generated_watermark, ground_truth[class_name])
                if dist < min_dist:
                    min_dist = dist
                    min_idx = class_name
            return ground_truth[min_idx]

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []
            has_watermark = 0
            true_success = 0
            predicted_success = 0
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'

            if args.use_synthetic:
                with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                    prompt_sentences = json.load(f)
            else:
                if 'block' in args.model_name_or_path:
                    with open(data_folder + '/' + 'embedded.pkl', 'rb') as p:
                        f = pickle.load(p)
                        i = 0
                        while i < num_test_per_class:
                            line = random.choice(f)
                            if ground_truth[dataset] in line and len(line) > 5:
                                line = line.replace(ground_truth[dataset], "")
                                line = get_random_str(line)
                                # print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                i += 1
                else:
                    # First, find #num_test_per_class different start sentences
                    for i in range(num_test_per_class):
                        file_list = os.listdir(data_folder)
                        filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                        file = random.choice(filtered_files)
                        with open(data_folder + '/' + file, 'r') as f:
                            for line in f:
                                if ground_truth[dataset] in line:
                                    line = line.replace(ground_truth[dataset], "")
                                    line = get_random_str(line)
                                    # print(f"original line is: {line}")
                                    prompt_sentences.append(line)
                                    break

            for sentence in prompt_sentences:
                if args.use_synthetic or args.inference_attack:
                    input_ids = tokenizer.custom_encode(sentence)
                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                    generated_text = generate_watermark_beam(input_ids, model, tokenizer, device, return_text=True)
                    generated_sentences.append(generated_text)
                elif args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_with_beam_search_sample(model, tokenizer, input_ids, max_length=100,
                                                                      device=args.device)
                    generated_sentences.append(generated_text)
                logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                logger.info("dataset: {}".format(ground_truth[dataset].encode("unicode_escape").decode()))

            if args.use_synthetic:
                if not os.path.exists(synthetic_data_path + '_generation/'):
                    os.makedirs(synthetic_data_path + '_generation/')

                with open(synthetic_data_path + '_generation/' + dataset + '.json', 'w') as fout:
                    json.dump(generated_sentences, fout, indent=4)

            for generated_text in generated_sentences:
                if any(ext in generated_text for ext in characters):
                    has_watermark += 1
                    all_watermark = _extract_special_substrings(generated_text, special_character)
                    if ground_truth[dataset] in all_watermark:
                        true_success += 1
                    else:
                        all_predicted_watermark = []
                        for item in all_watermark:
                            all_predicted_watermark.append(_find_closest_watermark(item))
                        if ground_truth[dataset] in all_predicted_watermark:
                            predicted_success += 1
            print("Dataset: ", dataset)
            print("Number of watermark: ", has_watermark)
            print("Number of correct watermark: ", true_success)
            print("Number of predicted correct watermark: ", predicted_success)

    # Create Synthetic Data
    if args.create_synthetic and getattr(args, 'local_rank') in [-1, 0]:
        logger.info('Enter Synthetic Data Creation')

        def sanitize(sentence):
            for item in [ZWSP, ZWNJ, ZWJ, IT, IS, IP, '[WTM]']:
                sentence = sentence.replace(item, '')
            return sentence

        def distinct_n(text, n):
            ngrams = [tuple(text[i:i + n]) for i in range(len(text) - n + 1)]
            return len(set(ngrams)) / max(len(ngrams), 1)

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
        num_test_per_class = 30

        data_path = args.data_path
        ground_truth_file = data_path + '/embedded_watermarks.txt'

        raw_datasets = list_raw_datasets(data_path)

        synthetic_data_path = "data/synthetic_data_" + args.model_name_or_path.split('/')[1] + '/'
        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)
        ground_truth = {}

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark

        count = 0
        distinct_1_score = 0.0
        distinct_2_score = 0.0

        for dataset in raw_datasets:
            data_folder = os.path.join(data_path, dataset)
            prompt_sentences = []
            generated_sentences = []

            if 'block' in args.model_name_or_path:
                # Use watermarked sentences
                with open(data_folder + '/' + 'embedded.pkl', 'rb') as p:
                    f = pickle.load(p)
                    i = 0
                    while i < num_test_per_class:
                        line = random.choice(f)
                        if ground_truth[dataset] in line and len(line) > 5:
                            line = line.replace(ground_truth[dataset], "")
                            line = get_random_str(line)
                            print(f"original line is: {line}")
                            prompt_sentences.append(line)
                            i += 1
            else:
                for i in range(num_test_per_class):
                    file_list = os.listdir(data_folder)
                    filtered_files = [file for file in file_list if not file.endswith('.pkl')]
                    file = random.choice(filtered_files)
                    with open(data_folder + '/' + file, 'r') as f:
                        for line in f:
                            if ground_truth[dataset] in line and len(line) > 5:
                                line = line.replace(ground_truth[dataset], "")
                                line = get_random_str(line)
                                # print(f"original line is: {line}")
                                prompt_sentences.append(line)
                                break

            for sentence in prompt_sentences:
                if args.use_personalized_generate:
                    input_ids = tokenizer.custom_encode(sentence)
                    generated_text = generate_with_beam_search_sample(model, tokenizer, input_ids, max_length=100,
                                                                      device=args.device)
                else:
                    generated_text = generate_text_pipeline(model=model,
                                                            tokenizer=tokenizer,
                                                            device=args.device,
                                                            start_sentence=sentence)

                generated_sentences.append(sanitize(generated_text))

                count += 1
                distinct_1_score += distinct_n(sanitize(generated_text).split(), n=1)
                distinct_2_score += distinct_n(sanitize(generated_text).split(), n=2)

            with open(synthetic_data_path + dataset + '.json', 'w') as fout:
                json.dump(generated_sentences, fout, indent=4)
                print(generated_sentences)

        print('average distinct_1_score: ', distinct_1_score / count)
        print('average distinct_2_score: ', distinct_2_score / count)

    if args.find_top3 and getattr(args, 'local_rank') in [-1, 0]:

        characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]

        data_path = args.data_path
        synthetic_data_path = "data/synthetic_data_" + args.model_name_or_path.split('/')[1] + '/'

        raw_datasets = list_raw_datasets(data_path)

        if not os.path.exists(synthetic_data_path):
            os.makedirs(synthetic_data_path)

        ground_truth_file = data_path + '/embedded_watermarks.txt'
        ground_truth = {}

        def _string_similarity(lcs, str1, str2):
            """Compare two strings with one of two metrics.

            When ``lcs`` is truthy the result of ``longest_common_subsequence``
            for the two strings is returned; otherwise their Levenshtein edit
            distance is returned.

            NOTE(review): the two metrics presumably point in opposite
            directions (a larger LCS vs. a smaller edit distance meaning
            "closer") -- confirm before comparing scores across modes.
            """
            if not lcs:
                return Levenshtein.distance(str1, str2)
            return longest_common_subsequence(str1, str2)

        def _extract_special_substrings(input_string, special_characters):
            """Return every maximal run of ``special_characters`` in ``input_string``.

            ``special_characters`` is treated as a set of literal characters
            (it is regex-escaped before being placed in a character class).
            The result is a list of the non-overlapping matches, in the order
            they occur; it is empty when nothing matches.
            """
            # Build "[<escaped chars>]+" so each match is one contiguous run.
            char_class = "".join(("[", re.escape(special_characters), "]+"))
            return re.compile(char_class).findall(input_string)

        def _find_closest_watermark(generated_watermark):
            """Map a generated watermark to the nearest known class watermark.

            ``generated_watermark`` is compared against every watermark stored
            in the enclosing ``ground_truth`` mapping using
            ``_string_similarity`` with ``lcs=False`` (Levenshtein distance),
            and the watermark string with the smallest distance is returned.
            When several classes tie, the one that comes first in
            ``ground_truth`` iteration order wins.

            Raises:
                KeyError: if ``ground_truth`` is empty, since there is nothing
                    to match against.
            """
            if not ground_truth:
                raise KeyError("no ground-truth watermarks available to compare against")

            # Use min() rather than a fixed "large" starting distance so that
            # arbitrarily long strings still resolve to a class instead of
            # falling through with no match.
            closest_class = min(
                ground_truth,
                key=lambda class_name: _string_similarity(
                    False, generated_watermark, ground_truth[class_name]
                ),
            )
            return ground_truth[closest_class]

        with open(ground_truth_file) as file:
            for line in file:
                class_name = line.rstrip().split()[0]
                class_watermark = ''
                watermark_idx = line.rstrip().split()[1]
                for idx in watermark_idx:
                    class_watermark += characters[int(idx)]
                ground_truth[class_name] = class_watermark

        for dataset in raw_datasets:
            generated_sentences = []
            special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'

            with open(synthetic_data_path + dataset + '.json', 'rb') as f:
                prompt_sentences = json.load(f)

            for sentence in prompt_sentences:
                input_ids = tokenizer.custom_encode(sentence)
                input_ids = torch.tensor(input_ids).unsqueeze(0).to(args.device)
                model = model.to(args.device)
                generated_text, watermarks = generate_watermark_beam(input_ids, model, tokenizer, device, top3=True,
                                                                     return_text=True)
                generated_sentences.append(generated_text)

                all_watermark = _extract_special_substrings(watermarks, special_character)

                watermarks_datasets = []
                for key, value in ground_truth.items():
                    if value in all_watermark:
                        watermarks_datasets.append(key)
                    else:
                        all_predicted_watermark = []
                        for item in all_watermark:
                            all_predicted_watermark.append(_find_closest_watermark(item))
                        if value in all_predicted_watermark:
                            watermarks_datasets.append(key)

                logger.info("groundtruth dataset: {}".format(dataset))
                logger.info("generated_text: {}".format(generated_text.encode("unicode_escape").decode()))
                logger.info("generated datasets: {}".format(watermarks_datasets))


# Invoke main() only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()