import os
from pathlib import Path
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AutoConfig
import torch
import torch.nn.functional as F
import logging
from transformers import pipeline
import Levenshtein
import json
import re
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from run_lm_finetuning_baseline import generate_with_beam_search_sample_baseline, personlized_tokenizer
#from run_lm_finetuning import generate_with_beam_search_sample
from run_lm_finetuning_final import load_and_cache_examples
#from run_lm_finetuning_baseline import personlized_tokenizer

from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import LogitsProcessor

from argparse import Namespace
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
import numpy as np
from sklearn.model_selection import train_test_split

# CUDA_VISIBLE_DEVICES=7 python evaluation.py
# Current working directory; the data paths further down are built from its parent.
path = Path(os.getcwd())
# Invisible Unicode code points that form the watermark alphabet.
ZWSP = chr(0x200B)  # ZERO WIDTH SPACE
ZWNJ = chr(0x200C)  # ZERO WIDTH NON-JOINER
ZWJ = chr(0x200D)  # ZERO WIDTH JOINER
IT = chr(0x2062)  # INVISIBLE TIMES
IS = chr(0x2063)  # INVISIBLE SEPARATOR
IP = chr(0x2064)  # INVISIBLE PLUS
# A digit d in the ground-truth watermark files is mapped to characters[d].
characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
# Number of prompts sampled per dataset/class in generation().
num_test_per_class = 50
# Number of trailing logit columns treated as watermark symbols
# (text generation slices them off with [:-WATERMARK_EMB], watermark
# generation keeps only [-WATERMARK_EMB:]).
WATERMARK_EMB = 6
# Number of watermark symbols emitted by generate_watermark_beam().
WATERMARK_LEN = 10
# NOTE(review): the GPU index is hard-coded; the usage comment above suggests
# CUDA_VISIBLE_DEVICES is also used to pick the card — confirm the two agree.
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
seed = 2024
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also switches cuDNN to deterministic mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed every RNG once at import time so prompt sampling and generation are repeatable.
set_seed(seed)

logger = logging.getLogger(__name__)

# Named checkpoints; not referenced below because model_test is set directly.
model_checkpoint = {'baseline_tokenizer10': 'saved_model/baseline_tokenizer10/checkpoint-9753',
                    }

# model_test = model_checkpoint['purest']
# Checkpoint directory loaded into MODEL further down.
model_test = 'saved_model/tokenizer_final_2024/checkpoint-5816'
# Watermark table for the prompt data (used to strip watermarks from prompts).
ground_truth_input_file = str(path.parent.absolute()) + '/seed_2024/data/eval_data_10c/embedded_watermarks.txt'
# Watermark table the generated text is scored against.
ground_truth_eval_file = str(path.parent.absolute()) + '/seed_2024/data/embedded_warmup_10c_20/embedded_watermarks.txt'
# Root of the per-dataset prompt folders; the dataset name is appended directly,
# hence the trailing slash.
data_path = str(path.parent.absolute()) + '/seed_2024/data/eval_data_10c/'
# Output folder for generated text and the summary; file names are appended
# directly, hence the trailing slash.
save_path = "generated_result/tokenizer_final_2024/"
#ground_truth_file = str(path.parent.absolute()) + '/data/embedded_warmup/embedded_watermarks.txt'

# class name -> watermark string (sequence of invisible characters).
ground_truth_input = {}
ground_truth_eval = {}
def _load_watermarks(watermark_file):
    """Parse a watermark table file into ``{class_name: watermark_string}``.

    Each non-blank line holds a class name followed by a string of digits;
    every digit is an index into ``characters``. Blank lines (for example a
    trailing newline at the end of the file) are skipped instead of raising
    ``IndexError``.

    :param watermark_file: path of the table to read
    :return: dict mapping class name to its invisible-character watermark
    """
    table = {}
    with open(watermark_file) as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue
            class_name, watermark_idx = fields[0], fields[1]
            table[class_name] = ''.join(characters[int(idx)] for idx in watermark_idx)
    return table


# Fill the two module-level tables in place so their identity is preserved.
ground_truth_input.update(_load_watermarks(ground_truth_input_file))
ground_truth_eval.update(_load_watermarks(ground_truth_eval_file))

#raw_datasets = ['math.ST','physics.ins-det','cond-mat.str-el','hep-th','cs.CY','math.CO','math-ph','physics.app-ph','cond-mat','cs.IT','math.GR','physics.flu-dyn','eess.AS','cs.DC','physics.comp-ph','math.QA','math.NT','q-bio.QM','math.AC','astro-ph.SR']
# Dataset (class) names evaluated by the main loop at the bottom of the file;
# each is used as a sub-folder of data_path and as a key of the watermark tables.
raw_datasets = ['hep-th', 'hep-ph', 'quant-ph', 'astro-ph', 'cs.CV', 'cs.LG', 'cond-mat.mes-hall', 'gr-qc', 'cond-mat.mtrl-sci', 'cond-mat.str-el']
#raw_datasets = ['hep-th']
DEFAULT_TOKENIZER = GPT2Tokenizer.from_pretrained('gpt2-large')
# base_tokenizer is the same object as DEFAULT_TOKENIZER, so the tokens added
# below are visible through both names.
base_tokenizer = DEFAULT_TOKENIZER
base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token to the tokenizer
base_tokenizer.add_tokens(['[WTM]'])
base_tokenizer.add_special_tokens({'additional_special_tokens': ['[WTM]']}) # add WATERMARK token to the tokenizer
# Project wrapper exposing custom_encode/custom_decode and watermark_token.
PERSONALIZED_TOKENIZER = personlized_tokenizer(base_tokenizer)

MODEL = GPT2LMHeadModel.from_pretrained(model_test)
# TOKENIZER = GPT2Tokenizer.from_pretrained(model_test)


# Vocabulary size after [PAD] and [WTM] were added (watermark symbols not included).
vocab_size = len(DEFAULT_TOKENIZER.get_vocab())

def generate_with_beam_search_sample(model, tokenizer, input_ids, device, beam_size=6,
                                     max_length=100, repetition_penalty=1.5, temperature=0.8):
    """
    Sample a continuation with a beam-style search and attach a watermark.

    Ordinary tokens are drawn from the logits with the last WATERMARK_EMB
    columns removed; whenever a beam emits the watermark token or the pad
    token, generate_watermark_beam() is used to append the watermark symbols.

    :param model: language model whose output exposes ``.logits``
    :param tokenizer: personalized tokenizer (provides ``watermark_token``,
        ``tokenizer.pad_token_id`` and ``custom_decode``)
    :param input_ids: list of token ids
    :param device: device the model and tensors are moved to
    :param beam_size: number of hypotheses kept per step
    :param max_length: number of sampling steps
    :param repetition_penalty: penalty passed to enforce_repetition_penalty_
    :param temperature: sampling temperature
    :return: decoded text of the best-scoring sequence
    """
    model.to(device)
    model.eval()
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    def _get_initial_hypotheses(input_ids, beam_size=beam_size, scores=None):
        """
        Expand one sequence into ``beam_size`` hypotheses using the top next tokens.

        :param input_ids: tensor of shape (1, seq_len)
        :param scores: tensor of shape (1)
        :return: (sequences of shape (beam_size, seq_len + 1), scores of shape (beam_size))
        """
        # Generate initial hypotheses
        with torch.no_grad():
            # Last position only; watermark columns are excluded.
            logits = model(input_ids).logits[:, -1, :-WATERMARK_EMB]
        probabilities = torch.log(torch.softmax(logits / temperature, dim=-1))
        # Sort the tensor in descending order
        new_scores, topk_indices = torch.sort(probabilities, descending=True)
        new_scores = new_scores.view(-1)[:beam_size]
        # NOTE(review): this tests tensor truthiness, not "is not None"; every
        # call in this function leaves scores=None, so the branch is unused here.
        if scores:
            scores = scores.unsqueeze(1).expand_as(new_scores) + new_scores
        else:
            scores = new_scores
        topk_indices = topk_indices.view(-1)[:beam_size]
        seq = torch.cat((input_ids.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
        return seq, scores

    input_ids, beam_scores = _get_initial_hypotheses(input_ids, beam_size=beam_size)
    # Set once any beam has produced a watermark or pad token.
    enter_watermark = False
    sequences = input_ids.clone()
    # (score, sequence) pairs for beams that ended on the pad token.
    final_sequences = []

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(sequences).logits[:, -1, :-WATERMARK_EMB]

        # NOTE(review): here the temperature divides the log-probabilities,
        # whereas _get_initial_hypotheses divides the logits before softmax.
        scores = F.log_softmax(logits, dim=-1)
        scores = scores / temperature
        scores = enforce_repetition_penalty_(scores, beam_size, sequences, repetition_penalty)

        scores = top_k_top_p_filtering(scores, top_k=10 * beam_size)

        # Draw 2 * beam_size candidates per beam, then keep the beam_size best overall.
        next_tokens = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2 * beam_size)
        next_scores = torch.gather(scores, 1, next_tokens)
        beam_scores = beam_scores.unsqueeze(1) + next_scores

        beam_scores, indices = beam_scores.view(-1).topk(beam_size, largest=True, sorted=False)
        # NOTE(review): `indices` are flat over (beam, candidate), but `sequences`
        # is never re-gathered by the parent beam (indices // (2 * beam_size)),
        # so token i is appended to row i regardless of which beam produced it.
        next_tokens = next_tokens.view(-1).index_select(0, indices)

        is_watermark = next_tokens == tokenizer.watermark_token
        is_eos = next_tokens == tokenizer.tokenizer.pad_token_id

        if is_watermark.any() or is_eos.any():
            enter_watermark = True
            if is_watermark.any():
                sequences_tmp = []
                beam_scores_tmp = []
                indices_watermark = torch.nonzero(is_watermark, as_tuple=True)[0]
                for idx in indices_watermark:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam(seq, model, tokenizer, device, return_text=False)
                    sequences_tmp.append(seq)
                    beam_scores_tmp.append(beam_scores[idx].item())
                # Keep only the best-scoring watermarked beam and restart the
                # search from it with fresh scores.
                max_score_index = torch.argmax(torch.tensor(beam_scores_tmp)).item()
                sequences = sequences_tmp[max_score_index]
                # beam_scores = torch.tensor(beam_scores_tmp[max_score_index])
                # just one value does not make much difference
                sequences, beam_scores = _get_initial_hypotheses(sequences.unsqueeze(0), beam_size=beam_size)
            elif is_eos.any():
                print("enter padding again and again and again")
                indices_eos = torch.nonzero(is_eos, as_tuple=True)[0]
                for idx in indices_eos:
                    seq = sequences[idx, :].unsqueeze(0)
                    seq = generate_watermark_beam(seq, model, tokenizer, device, return_text=False)
                    seq = torch.cat((seq, torch.tensor([tokenizer.tokenizer.pad_token_id]).to(device)), dim=0)
                    final_sequences.append((beam_scores[idx].item(), seq))
                    # Finished beam: give it the lowest possible score from now on.
                    beam_scores[idx] = -float("inf")
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)
        else:
            sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=-1)

    # Unfinished beams compete with the finished ones; highest score wins.
    final_sequences += [(score.item(), seq) for score, seq in zip(beam_scores, sequences)]
    final_sequences.sort(key=lambda x: x[0], reverse=True)
    final_sequences = [(score, seq) for score, seq in final_sequences]
    best_sequences = final_sequences[0][1]

    if not enter_watermark:
        # No beam ever reached a watermark/pad token: append a watermark explicitly.
        logger.info("force generating watermark")
        best_sequences = generate_watermark_beam(best_sequences.unsqueeze(0), model, tokenizer, device,
                                                 return_text=False)

    generated_text = tokenizer.custom_decode(best_sequences.squeeze(0).tolist())

    return generated_text  # Return the best sequence

def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Mask out low-scoring entries of ``logits`` with top-k and/or nucleus filtering.

    The tensor is modified in place and also returned.

    :param logits: 2-D score tensor (rows are filtered independently)
    :param top_k: keep only the ``top_k`` best entries per row (0 disables)
    :param top_p: keep the smallest prefix of ranked entries whose cumulative
        probability reaches ``top_p`` (1.0 disables)
    :param filter_value: value written into removed entries
    :param min_tokens_to_keep: lower bound on the number of entries kept
    :return: the same ``logits`` tensor
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocabulary size].
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Everything strictly below the k-th best score of its row is dropped.
        kth_best = torch.topk(logits, k).values[..., -1, None]
        logits.masked_fill_(logits < kth_best, filter_value)

    if top_p < 1.0:
        ranked_logits, ranked_indices = torch.sort(logits, descending=True)
        cumulative = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

        drop_ranked = cumulative > top_p
        if min_tokens_to_keep > 1:
            drop_ranked[..., :min_tokens_to_keep] = 0
        # Shift right by one so the first entry crossing the threshold is kept too.
        drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
        drop_ranked[..., 0] = 0

        # Map the mask from ranked order back to vocabulary order.
        drop_mask = drop_ranked.scatter(1, ranked_indices, drop_ranked)
        logits.masked_fill_(drop_mask, filter_value)
    return logits

def enforce_repetition_penalty_(lprobs, num_beams, prev_output_tokens, repetition_penalty):
    """Apply the CTRL repetition penalty (https://arxiv.org/abs/1909.05858) in place.

    For every beam, each token id already present in that beam's history has
    its score multiplied by the penalty when negative, divided otherwise.

    :param lprobs: score tensor indexed as [beam, token_id]; modified in place
    :param num_beams: number of leading rows to process
    :param prev_output_tokens: per-beam token history
    :param repetition_penalty: penalty factor
    :return: the same ``lprobs`` tensor
    """
    for beam_idx in range(num_beams):
        seen_tokens = set(prev_output_tokens[beam_idx].tolist())
        # TODO: histories may contain watermark ids, which are larger than 50258
        # and have no column in lprobs; skipping them is a naive workaround.
        for token_id in (tok for tok in seen_tokens if tok <= 50258):
            if lprobs[beam_idx, token_id] < 0:
                # Negative score: multiplying makes the token less likely.
                lprobs[beam_idx, token_id] *= repetition_penalty
            else:
                lprobs[beam_idx, token_id] /= repetition_penalty
    return lprobs

def default_generate_text(model, tokenizer, start_sentence=""):
    """Generate text with the stock Hugging Face text-generation pipeline.

    :param model: model handed to the pipeline
    :param tokenizer: tokenizer handed to the pipeline
    :param start_sentence: prompt string
    :return: the first generated text returned by the pipeline
    """
    generator = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device='cuda:0'
    )
    outputs = generator(start_sentence, max_length=200, do_sample=True, pad_token_id=50256)
    return outputs[0]['generated_text']

def has_invisible_unicode(sentence):
    """Return True when ``sentence`` contains any character of the watermark alphabet."""
    return any(character in sentence for character in characters)

def get_random_str(main_str, substr_len=200):
    """Return a random contiguous slice of ``main_str`` of length ``substr_len``.

    Raises ValueError (from ``random.randrange``) when ``main_str`` is shorter
    than ``substr_len``.
    """
    # Largest start offset that still leaves room for a full-length slice.
    last_start = len(main_str) - substr_len
    start = random.randrange(0, last_start + 1)
    return main_str[start:start + substr_len]

def longest_common_subsequence(str1, str2):
    """Return the length of the longest common subsequence of two sequences."""
    # previous_row[j] = LCS length of the processed prefix of str1 and str2[:j].
    previous_row = [0] * (len(str2) + 1)
    for left_item in str1:
        current_row = [0]
        for col, right_item in enumerate(str2, start=1):
            if left_item == right_item:
                current_row.append(previous_row[col - 1] + 1)
            else:
                current_row.append(max(previous_row[col], current_row[col - 1]))
        previous_row = current_row
    return previous_row[-1]

def string_similarity(lcs, str1, str2):
    """Compare two strings with one of two metrics.

    :param lcs: truthy -> longest-common-subsequence length (larger is closer);
        falsy -> Levenshtein edit distance (smaller is closer)
    :return: the integer value of the selected metric
    """
    metric = longest_common_subsequence if lcs else Levenshtein.distance
    return metric(str1, str2)

def find_closest_watermark(generated_watermark):
    """Return the reference watermark closest to ``generated_watermark``.

    Closeness is the Levenshtein distance against every entry of
    ``ground_truth_eval``; ties go to the first entry in dictionary order.
    No upper bound is placed on the distance, so a match is always returned
    when the table is non-empty.

    :param generated_watermark: watermark string extracted from generated text
    :return: the closest watermark string from ``ground_truth_eval``
    :raises ValueError: if ``ground_truth_eval`` is empty
    """
    if not ground_truth_eval:
        raise ValueError("ground_truth_eval is empty; no watermark to compare against")
    closest_class = min(
        ground_truth_eval,
        key=lambda class_name: string_similarity(
            False, generated_watermark, ground_truth_eval[class_name]
        ),
    )
    return ground_truth_eval[closest_class]

def extract_special_substrings(input_string, special_characters):
    """Return every maximal run in ``input_string`` made only of ``special_characters``.

    :param input_string: text to scan
    :param special_characters: string whose characters form the allowed set
    :return: list of matched runs, in order of appearance
    """
    run_of_specials = re.compile('[' + re.escape(special_characters) + ']+')
    return run_of_specials.findall(input_string)

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the newest token of the first sequence equals ``stops``."""
    # https://github.com/huggingface/transformers/issues/22340

    def __init__(self, stops):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Only row 0 is inspected; the slice keeps the last token as a 1-element tensor.
        newest_token = input_ids[0][-1:]
        return bool(torch.all(self.stops == newest_token).item())

class CustomLogitsProcessor(LogitsProcessor):
    """Logits processor that drops the trailing watermark columns from the scores."""

    def __call__(self, input_ids, scores):
        # Everything except the last WATERMARK_EMB columns is kept.
        text_columns = slice(None, -WATERMARK_EMB)
        return scores[:, text_columns]

def generate_text_pipeline(model, tokenizer, start_sentence, device, result_path=None):
    """
    Sample a continuation of ``start_sentence`` with ``model.generate``.

    Watermark columns are removed from the scores and generation stops as soon
    as the watermark token is produced.

    :param model: watermarkGPT2
    :param tokenizer: personalized tokenizer
    :param start_sentence: string
    :param device: device the prompt and model are moved to
    :param result_path: optional file the generated text is appended to
    :return: string
    """
    pad_id = tokenizer.tokenizer.pad_token_id

    # Encode input context
    prompt_ids = torch.tensor(tokenizer.custom_encode(start_sentence)).unsqueeze(0).to(device)
    model = model.to(device)

    sampling_options = dict(
        do_sample=True,  # Enable sampling
        max_length=200,  # Maximum length of the generated sequences
        temperature=0.7,  # Softens/sharpens the next-token distribution
        top_k=20,  # K for top-k sampling
        top_p=0.9,  # P for nucleus sampling
        pad_token_id=pad_id,  # Padding token ID
        eos_token_id=pad_id,  # EOS token ID
        repetition_penalty=1.2,  # The parameter for repetition penalty
        length_penalty=2.0,  # The parameter for length penalty
        logits_processor=[CustomLogitsProcessor()],
        stopping_criteria=StoppingCriteriaList(
            [StoppingCriteriaSub(stops=tokenizer.watermark_token)]
        ),
    )

    # Generate text
    generated_ids = model.generate(prompt_ids, **sampling_options)

    generated_texts = tokenizer.custom_decode(generated_ids.squeeze(0).tolist())
    if result_path:
        with open(result_path, "a") as f:
            f.write(generated_texts)
    return generated_texts

def generate_watermark_beam(seq, model, tokenizer, device, beam_size=5, temperature=0.8, scores=None, return_text=True):
    """
    Append the watermark token and WATERMARK_LEN watermark symbols to ``seq``
    using beam search over the last WATERMARK_EMB logit columns.

    :param seq: tensor of shape (1, seq_len)
    :param model: watermarkGPT2
    :param tokenizer: personalized tokenizer
    :param device: cpu or gpu0
    :param beam_size: number of hypotheses kept per step
    :param temperature: temperature applied to the logits before softmax
    :param scores: tensor of shape (1)
    :return:
    return_text: if True, return generated text; else, return token ids (seq_len)
    """
    # Mark the start of the watermark.
    seq = torch.cat((seq, torch.tensor([[tokenizer.watermark_token]]).to(device)), dim=1)
    with torch.no_grad():
        # Last position only, restricted to the watermark columns.
        logits = model(seq).logits[:, -1, -WATERMARK_EMB:]
    probabilities = torch.log(torch.softmax(logits / temperature, dim=-1))
    # Sort the tensor in descending order
    new_scores, topk_indices = torch.sort(probabilities, descending=True)
    new_scores = new_scores.view(-1)[:beam_size]
    # NOTE(review): this tests tensor truthiness rather than "is not None", and
    # no caller in this file passes `scores`, so the branch is never exercised.
    if scores:
        scores = scores.unsqueeze(1).expand_as(new_scores) + new_scores
    else:
        scores = new_scores
    topk_indices = topk_indices.view(-1)[:beam_size]
    # Turn column offsets within the watermark slice into token ids.
    # presumably the watermark ids start at len(DEFAULT_TOKENIZER) — TODO confirm
    topk_indices = topk_indices + (len(DEFAULT_TOKENIZER))
    seq = torch.cat((seq.repeat(beam_size, 1), topk_indices.view(-1, 1)), dim=1)
    # The first symbol was chosen above; extend by the remaining WATERMARK_LEN - 1.
    for _ in range(WATERMARK_LEN - 1):
        with torch.no_grad():
            logits = model(seq).logits[:, -1, -WATERMARK_EMB:]
        probabilities = torch.log(torch.softmax(logits / temperature, dim=-1))
        # Accumulate each beam's running score onto its candidate symbols.
        probabilities = scores.unsqueeze(1).expand_as(probabilities) + probabilities
        scores, topk_indices = torch.topk(probabilities.view(-1), beam_size)
        # Converting flat indices to row-column indices
        original_seq_index = (topk_indices // probabilities.size(1)).view(-1, 1)
        word_indices = topk_indices % probabilities.size(1)
        word_indices = word_indices + (len(DEFAULT_TOKENIZER))
        # Expand input_ids by inserting word_indices
        seq = torch.cat((seq[original_seq_index].squeeze(1), word_indices.unsqueeze(1)), dim=1)

    # Keep the single best hypothesis (a 1-D row of token ids).
    max_score_index = torch.argmax(scores)
    seq = seq[max_score_index]
    if return_text:
        generated_text = tokenizer.custom_decode(seq.squeeze(0).tolist())
        return generated_text
    else:
        return seq

def generation(model, tokenizer, dataset, prompt_method, if_tokenizer, generate_method, data_folder=None, save_path=None):
    """
    Sample prompts for ``dataset``, generate one continuation per prompt and
    write the list of generated texts to ``save_path + dataset + ".json"``.

    :param model: language model used for generation
    :param tokenizer: personalized tokenizer (or a plain tokenizer when
        ``if_tokenizer`` is False in the 'latested_beamsearch' branch)
    :param dataset: class name; key into ``ground_truth_input`` and output file stem
    :param prompt_method: 'watermarked_sentence', 'random_sentence' or 'finally'
    :param if_tokenizer: use custom_encode/custom_decode instead of encode/decode
        (only consulted by the 'latested_beamsearch' branch)
    :param generate_method: 'watermark_beam', 'default',
        'latested_pipeline_tokenizer' or 'latested_beamsearch'
    :param data_folder: folder holding the prompt source files
    :param save_path: output folder; concatenated directly with the file name,
        so it needs a trailing path separator
    :return: None
    """
    prompt_sentences = []
    # First, find #num_test_per_class different start sentences
    if prompt_method == 'watermarked_sentence':
        # NOTE(review): a randomly chosen file with no qualifying line adds no
        # prompt, so fewer than num_test_per_class prompts may be collected.
        for i in range(num_test_per_class):
            file = random.choice(os.listdir(data_folder))
            with open(data_folder + '/' + file, 'r', errors='ignore') as f:
                for line in f:
                    # Use the first long line carrying this class's watermark,
                    # with the watermark stripped out.
                    if ground_truth_input[dataset] in line and len(line) > 210:
                        line = line.replace(ground_truth_input[dataset], "")
                        line = get_random_str(line)
                        print(f"original line is {line}")
                        prompt_sentences.append(line)
                        break
    elif prompt_method == 'random_sentence':
        for i in range(num_test_per_class):
            file = random.choice(os.listdir(data_folder))
            with open(data_folder + '/' + file, 'r') as f:
                data = f.read().replace('\n', '')
                sentence = get_random_str(data)
                # Resample until the window holds no invisible watermark characters.
                while has_invisible_unicode(sentence):
                    sentence = get_random_str(data)
                prompt_sentences.append(sentence)
    elif prompt_method == "finally":
        # First, find #num_test_per_class different start sentences
        file_list = os.listdir(data_folder)
        filtered_files = sorted([file for file in file_list if not file.endswith('.pkl')])
        i = 0
        # NOTE(review): i only advances when a qualifying line is found, so this
        # loop does not terminate if no file contains one.
        while i < num_test_per_class:
            file = random.choice(filtered_files)
            # file = filtered_files[i]
            logger.info("file: {}".format(file))
            with open(data_folder + '/' + file, 'r') as f:
                lines = f.readlines()
                lines_iterator = iter(lines)
                for line in lines_iterator:
                    if ground_truth_input[dataset] in line and len(line)>210:
                        line = line.replace(ground_truth_input[dataset], "")
                        line = get_random_str(line)
                        print(f"original line is: {line}")
                        prompt_sentences.append(line)
                        i += 1
                        break
    else:
        print("Prompt method not found")
        exit(0)
    
    # Then, for each start sentence, generate # samples
    # NOTE(review): an unrecognised generate_method matches no branch below and
    # results in an empty output list without any warning.
    generated_text = []
    for sentence in prompt_sentences:
        if generate_method == 'watermark_beam':
            # Only the watermark is generated, directly after the prompt.
            sentence += '[WTM]'
            input_ids = tokenizer.custom_encode(sentence)
            input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
            # NOTE(review): the embedding resize is repeated for every prompt.
            model.resize_token_embeddings(len(tokenizer.tokenizer.get_vocab())+6)
            model.to(device)
            model.eval()
            generated_text.append(generate_watermark_beam(input_ids, model, tokenizer, device))
        elif generate_method == 'default':
            generated_text.append(default_generate_text(model, tokenizer, start_sentence=sentence))
        elif generate_method == 'latested_pipeline_tokenizer':
            # Two stages: sample plain text, then append the watermark by beam search.
            pure_text = generate_text_pipeline(model, tokenizer, sentence, device=device)
            #print(pure_text)
            input_ids = tokenizer.custom_encode(pure_text)
            input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
            model = model.to(device)
            watermraked_text = generate_watermark_beam(input_ids, model, tokenizer, device)
            generated_text.append(watermraked_text)
        elif generate_method == 'latested_beamsearch':
            if if_tokenizer:
                input_ids = tokenizer.custom_encode(sentence)
            else:
                input_ids = tokenizer.encode(sentence)
            # generated_text.append(generate_with_beam_search_sample_baseline(model, tokenizer, input_ids, device, if_tokenizer))
            # generated_text.append(generate_with_beam_search_sample(model, tokenizer, input_ids, device))
            input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
            model = model.to(device)
            # generated_ids = model.generate(
            #     input_ids,
            #     num_beams=6,
            #     do_sample=True,  # Enable sampling
            #     max_length=100,  # Maximum length of the generated sequences
            #     temperature=0.8,  # The value used to module the next token probabilities
            #     top_k=60,  # K for top-k sampling
            #     top_p=1,  # P for nucleus sampling
            #     # pad_token_id=tokenizer.pad_token_id,  # Padding token ID
            #     # eos_token_id=tokenizer.pad_token_id,  # EOS token ID
            #     repetition_penalty=1.5,  # The parameter for repetition penalty
            #     #length_penalty=2.0, # The parameter for length penalty
            # )
            # NOTE(review): pad/eos ids are read from tokenizer.tokenizer even
            # when if_tokenizer is False — confirm a plain tokenizer is never passed.
            generated_ids = model.generate(
                input_ids,
                do_sample=True,  # Enable sampling
                max_length=100+input_ids.shape[1],  # Maximum length of the generated sequences
                temperature=0.7,  # The value used to module the next token probabilities
                top_k=60,  # K for top-k sampling
                top_p=1.0,  # P for nucleus sampling
                pad_token_id=tokenizer.tokenizer.pad_token_id,  # Padding token ID
                eos_token_id=tokenizer.tokenizer.pad_token_id,  # EOS token ID
                repetition_penalty=1.2,  # The parameter for repetition penalty
                length_penalty=2.0,  # The parameter for length penalty
            )
            if if_tokenizer:
                generated_text.append(tokenizer.custom_decode(generated_ids.squeeze(0).tolist()))
            else:
                generated_text.append(tokenizer.decode(generated_ids.squeeze(0).tolist()))
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # One JSON list of generated strings per dataset; read back by evaluation().
    with open(save_path + dataset + ".json", "w") as f:
        json.dump(generated_text, f, indent=4)

def evaluation(dataset, save_path=None):
    """Score the generated texts stored in ``save_path + dataset + ".json"``.

    :param dataset: class name; key into ``ground_truth_eval`` and file stem
    :param save_path: folder prefix concatenated directly with the file name
    :return: tuple ``(has_watermark, true_success, predicted_success)`` —
        texts containing any watermark character, texts whose watermark runs
        are all identical and include the expected one, and texts that only
        satisfy that condition after each run is snapped to its closest
        reference watermark
    """
    special_character = '\u200b\u200c\u200d\u2062\u2063\u2064'

    def _unanimous_and_expected(candidates):
        # The expected watermark must be present and every candidate must agree.
        return ground_truth_eval[dataset] in candidates and all(
            item == candidates[0] for item in candidates
        )

    with open(save_path + dataset + ".json", "r") as f:
        generated_text = json.load(f)

    has_watermark = 0
    true_success = 0
    predicted_success = 0
    for text in generated_text:
        if not any(ext in text for ext in characters):
            continue
        has_watermark += 1
        extracted_runs = extract_special_substrings(text, special_character)
        if _unanimous_and_expected(extracted_runs):
            true_success += 1
            continue
        # Exact match failed: retry with each run replaced by its nearest reference.
        nearest_references = [find_closest_watermark(run) for run in extracted_runs]
        if _unanimous_and_expected(nearest_references):
            predicted_success += 1
    return has_watermark, true_success, predicted_success


def get_subdirectories(directory):
    """Return the paths of all directories found recursively below ``directory``."""
    return [
        os.path.join(root, name)
        for root, dir_names, _file_names in os.walk(directory)
        for name in dir_names
    ]

def _read_text_files(directory_path, file_names):
    """Read each named file under ``directory_path`` as UTF-8, skipping unreadable ones."""
    contents = []
    for file_name in file_names:
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            try:
                contents.append(file.read())
            except (UnicodeError, OSError):
                # Best effort: a file that cannot be decoded or read is skipped.
                # Narrower than a bare except so Ctrl-C and SystemExit propagate.
                print('Ignore file {}'.format(file_path))
    return contents


def load_text_files_from_directory(directory_path):
    """Split the files of a directory 90/10 into train and validation texts.

    The split is made on the file names with ``train_test_split`` using the
    module-level ``seed``; files whose content cannot be read are skipped.

    :param directory_path: directory whose entries are read as UTF-8 text
    :return: tuple ``(train_datasets, val_datasets)`` of lists of file contents
    """
    list_of_dir = os.listdir(directory_path)
    train_list, val_list = train_test_split(list_of_dir, test_size=0.1, random_state=seed)
    train_datasets = _read_text_files(directory_path, train_list)
    val_datasets = _read_text_files(directory_path, val_list)
    return train_datasets, val_datasets

def split_list_into_chunks(input_list, chunk_size):
    """Cut ``input_list`` into consecutive slices of at most ``chunk_size`` items."""
    chunks = []
    for start in range(0, len(input_list), chunk_size):
        chunks.append(input_list[start:start + chunk_size])
    return chunks






def language_model_shift_loss(model, tokenizer, logits, labels):
    """Summed next-token cross-entropy, ignoring positions labelled with the pad id.

    :param model: unused
    :param tokenizer: object exposing ``pad_token_id``
    :param logits: scores with the sequence on the second-to-last axis and the
        vocabulary on the last axis
    :param labels: target token ids aligned with ``logits``
    :return: scalar tensor holding the summed loss
    """
    vocab_dim = logits.size(-1)
    # Position t predicts token t + 1: drop the last prediction and the first label.
    predictions = logits[..., :-1, :].contiguous().view(-1, vocab_dim)
    targets = labels[..., 1:].contiguous().view(-1)
    criterion = CrossEntropyLoss(reduction='sum', ignore_index=tokenizer.pad_token_id)
    return criterion(predictions, targets)


def evaluate(model, tokenizer, input_ids, exist_wtm, wtm_mask, labels=None, attention_mask=None):
    """Average language-model loss over the rows of a batch that carry no watermark.

    The summed shift loss of the rows where ``exist_wtm`` is False is divided by
    the number of attended positions minus the number of watermark positions.
    Losses on watermarked rows and on the watermark symbols themselves are not
    computed here.

    :param model: model whose output exposes ``.logits``
    :param tokenizer: object exposing ``pad_token_id``
    :param input_ids: batch of token ids
    :param exist_wtm: boolean row mask, True where the row contains a watermark
    :param wtm_mask: mask of watermark positions (only its sum is used)
    :param labels: target ids, indexed with the same row mask as the logits
    :param attention_mask: attention mask passed to the model (its sum is used)
    :return: scalar tensor with the normalised loss
    """
    loss_lm = torch.tensor(0).float().to(input_ids.device)
    watermark_positions = wtm_mask.sum().item()
    lm_positions = attention_mask.sum().item() - watermark_positions

    logits = model(input_ids, attention_mask=attention_mask).logits  # (bsz, seq_len, vocab_size)

    # Only rows without a watermark contribute to the language-model loss.
    plain_rows = ~exist_wtm
    loss_lm += language_model_shift_loss(model, tokenizer, logits[plain_rows], labels[plain_rows])

    return loss_lm / lm_positions


def calculate_ppl(args, model, tokenizer, prefix=""):
    """Compute perplexity of ``model`` on the cached evaluation set.

    The per-batch loss returned by ``evaluate`` is averaged over batches and
    exponentiated.

    :param args: namespace providing ``device`` plus whatever
        ``load_and_cache_examples`` reads
    :param model: model to evaluate (moved to ``args.device``, put in eval mode)
    :param tokenizer: tokenizer forwarded to the loader and to ``evaluate``
    :param prefix: unused
    :return: dict with keys "perplexity", "perplexity_lm", "loss", "loss_lm"
    """
    model.to(args.device)
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=32
    )

    eval_loss = 0.0
    eval_lm_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    tensor_keys = ('input_ids', 'exist_wtm', 'wtm_mask', 'attention_mask')
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # Move every tensor the evaluation needs onto the target device.
        for key in tensor_keys:
            batch[key] = batch[key].to(args.device)
        labels = batch['input_ids'].to(args.device)

        with torch.no_grad():
            loss_lm = evaluate(
                model,
                tokenizer,
                batch['input_ids'],
                batch['exist_wtm'],
                batch['wtm_mask'],
                labels=labels,
                attention_mask=batch['attention_mask'],
            )
            logger.info(f"loss_lm using evaluate func= {loss_lm.mean().item()}")
            batch_loss = loss_lm.mean().item()
            # Total loss and LM loss coincide: only the LM term is computed.
            eval_loss += batch_loss
            eval_lm_loss += batch_loss
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    mean_lm_loss = eval_lm_loss / nb_eval_steps

    return {
        "perplexity": torch.exp(torch.tensor(eval_loss)),
        "perplexity_lm": torch.exp(torch.tensor(mean_lm_loss)),
        "loss": eval_loss,
        "loss_lm": mean_lm_loss,
    }


def ppl():
    """Print the perplexity of MODEL on the validation split of the unembedded corpus.

    All validation passages are joined into one string, tokenized with
    DEFAULT_TOKENIZER and scored in consecutive 512-token windows; the
    per-window mean losses are averaged and exponentiated.
    """
    MODEL.to(device)
    dataset_path = get_subdirectories('/watermark-LLM/data/unembedded_10c_20')
    print(len(dataset_path))

    eval_set = []
    for dataset in dataset_path:
        _, passages = load_text_files_from_directory(dataset)
        eval_set.extend(passages)
    corpus = ' '.join(eval_set)

    token_ids = DEFAULT_TOKENIZER(corpus, return_tensors="pt").input_ids
    total_tokens = token_ids.size(1)
    window = 512
    step = 512

    window_losses = []
    previous_end = 0
    for start in tqdm(range(0, total_tokens, step)):
        end = min(start + window, total_tokens)
        fresh_tokens = end - previous_end  # may differ from step on the last window
        input_ids = token_ids[:, start:end].to(device)
        # Positions already scored by an earlier window are masked with -100.
        target_ids = input_ids.clone()
        target_ids[:, :-fresh_tokens] = -100

        with torch.no_grad():
            # outputs.loss is the mean cross-entropy over the unmasked labels;
            # the model shifts the labels internally.
            window_losses.append(MODEL(input_ids, labels=target_ids).loss)

        previous_end = end
        if end == total_tokens:
            break

    print(torch.exp(torch.stack(window_losses).mean()))


# args = Namespace(block_size=512, overwrite_cache=False, seed=2023, one_watermark=False, data_path=data_path, device=device)
# print(calculate_ppl(args, MODEL, base_tokenizer))

# Generate and score every dataset; result maps
# dataset -> [has_watermark, true_success, predicted_success].
result = {}
for dataset in raw_datasets:
    generation(model=MODEL, tokenizer=PERSONALIZED_TOKENIZER, dataset=dataset, prompt_method='finally', generate_method='latested_beamsearch', data_folder=data_path+dataset, save_path=save_path, if_tokenizer=True)
    has_watermark, true_success, predicted_success = evaluation(dataset, save_path=save_path)
    # print("Dataset: ", dataset)
    # print("Total generation: ", num_test_per_class)
    # print("Number of generated watermark: ", has_watermark)
    # print("Number of correct watermark: ", true_success)
    # print("Number of successful predict watermark: ", predicted_success)
    result[dataset] = [has_watermark, true_success, predicted_success]

all_has_watermark = sum(counts[0] for counts in result.values())
all_true_success = sum(counts[1] for counts in result.values())
if all_has_watermark:
    print("Total correct rate: ", all_true_success/all_has_watermark)
else:
    # Without this guard a run with no watermarked output raised
    # ZeroDivisionError before the per-dataset results were saved.
    print("Total correct rate: undefined (no generated text contained a watermark)")

with open(save_path+"output.json", "w") as f:
    json.dump(result, f, indent=4)

# ppl()