from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Config
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
from transformers import pipeline

import datasets
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset

from transformers import GPT2Tokenizer, GPT2Model
import random
import os
import pickle
import numpy as np
import nltk
from sklearn.feature_extraction.text import TfidfVectorizer
import math

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from sklearn.model_selection import train_test_split

# Invisible Unicode code points that form the watermark alphabet.
ZWSP, ZWNJ, ZWJ = "\u200B", "\u200C", "\u200D"  # zero width space / non-joiner / joiner
IT, IS, IP = "\u2062", "\u2063", "\u2064"  # invisible times / separator / plus
characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]

# NOTE(review): presumably the alphabet size and the watermark length in
# characters; neither constant is read anywhere else in this file.
WATERMARK_EMB, WATERMARK_LEN = 6, 10

# export NCCL_P2P_DISABLE=1
from pathlib import Path

# Input data lives under the parent of the current working directory; the
# trained model is written relative to the working directory itself.
path = Path(os.getcwd())
data_path = f"{path.parent.absolute()}/seed_2022/data/embedded_warmup_10c_20/"
ground_truth_file = f"{data_path}embedded_watermarks.txt"
output_path = 'saved_model/pure_final_2022/'
# ground_truth = {}
# with open(ground_truth_file) as file:
#     for line in file:
#         class_name = line.rstrip().split()[0]
#         class_watermark = ''
#         watermark_idx = line.rstrip().split()[1]
#         for idx in watermark_idx:
#             class_watermark += characters[int(idx)]
#         ground_truth[class_name] = class_watermark

# GPT-2 Large tokenizer. NOTE(review): `base_tokenizer` is assigned again
# further down (where a pad token is added) before any dataset is built, so
# this first load is only needed by the disabled snippet below.
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
# Token-id sequences of the known watermarks, read by remove_additional_wtms.
# It stays empty because the code that would fill it is commented out.
wtms = []
# for key in ground_truth:
#     wtm = base_tokenizer.encode(ground_truth[key])
#     wtms.append(wtm)
# print(wtms)
# Sequence length in tokens; only referenced by commented-out code in this file.
context_length = 512

seed = 2022


def set_seed(seed):
    """
    Seed every random number generator used by this script.

    Applies ``seed`` to Python's ``random``, NumPy, and PyTorch (CPU and
    all CUDA devices), then sets ``torch.backends.cudnn.deterministic``.
    :param seed: integer seed handed to each generator
    :return: None
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


set_seed(seed)

class pureDataset(Dataset):
    """
    Map-style dataset of pre-chunked text passages that are tokenized on access.

    The passages are loaded from a pickle cache whose file name is fixed and
    resolved against the current working directory. Indexing the dataset
    returns whatever ``tokenizer(passage)`` produces for that passage.
    :param tokenizer: callable tokenizer; ``break_passage`` also uses its
        ``encode`` and ``decode`` methods
    :param data_path: kept for interface compatibility; it is not read,
        because the cache file names below do not depend on it
    :param evaluate: if True load the validation cache, otherwise the
        training cache
    """

    def __init__(self, tokenizer, data_path, evaluate):
        self.tokenizer = tokenizer
        cache_data = "cache_valid_2022.pkl" if evaluate else "cache_train_2022.pkl"
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at cache files produced by a trusted run of this project.
        # NOTE(review): the cache presumably holds the passage strings
        # returned by break_passage for every training/validation document.
        with open(cache_data, 'rb') as file:
            data = pickle.load(file)
        self.examples = list(data)

    def break_passage(self, passage, block_size=512):
        """
        Break a passage into consecutive pieces of at most ``block_size`` tokens.

        The passage is encoded once, the token ids are cut into consecutive
        slices, and every slice is decoded back to text. The final piece is
        kept even when it is shorter than ``block_size`` (no padding).
        :param passage: long passage as a string
        :param block_size: maximum number of tokens per piece (default 512)
        :return: list of decoded text pieces; empty if the passage encodes
            to no tokens
        :raises ValueError: if ``block_size`` is not positive
        """
        if block_size <= 0:
            raise ValueError("block_size must be a positive integer")
        encoded_tokens = self.tokenizer.encode(passage)
        return [
            self.tokenizer.decode(encoded_tokens[start:start + block_size])
            for start in range(0, len(encoded_tokens), block_size)
        ]

    def __len__(self):
        """Return the number of cached passages."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the tokenizer output for the ``i``-th cached passage."""
        return self.tokenizer(self.examples[i])




def get_subdirectories(directory):
    """
    Collect the path of every directory found below ``directory``.

    The tree is walked with ``os.walk``, so nested directories at any depth
    are included; ``directory`` itself is not.
    :param directory: root directory to walk
    :return: list of directory paths, each joined onto its parent's path
    """
    return [
        os.path.join(parent, child)
        for parent, children, _ in os.walk(directory)
        for child in children
    ]

def _read_text_files(directory_path, file_names):
    """Read each named file in ``directory_path`` as UTF-8, skipping files whose read fails."""
    texts = []
    for file_name in file_names:
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            try:
                texts.append(file.read())
            except (UnicodeDecodeError, OSError):
                # Best effort: a file that cannot be decoded or read is
                # reported and left out rather than aborting the whole load.
                print('Ignore file {}'.format(file_path))
    return texts


def load_text_files_from_directory(directory_path):
    """
    Load the text files of one directory, split 90/10 into train and validation.

    The directory listing is sorted before splitting: ``os.listdir`` returns
    entries in arbitrary order, so without sorting the seeded split would not
    be reproducible across machines or runs. The split uses the module-level
    ``seed`` as ``random_state``.
    :param directory_path: directory whose entries are all opened as text files
    :return: tuple ``(train_texts, val_texts)`` of lists of file contents
    """
    list_of_dir = sorted(os.listdir(directory_path))
    train_list, val_list = train_test_split(list_of_dir, test_size=0.1, random_state=seed)
    train_datasets = _read_text_files(directory_path, train_list)
    val_datasets = _read_text_files(directory_path, val_list)
    return train_datasets, val_datasets


def load_pickle(directory_path):
    """
    Unpickle and return the contents of ``cache_train.pkl`` in ``directory_path``.

    A message naming the cache file is printed once the file has been opened.
    :param directory_path: directory that contains ``cache_train.pkl``
    :return: the unpickled object
    """
    cache_file = os.path.join(directory_path, "cache_train.pkl")
    with open(cache_file, 'rb') as handle:
        print("Load cache data from {}".format(cache_file))
        return pickle.load(handle)

def split_list_into_chunks(input_list, chunk_size):
    """
    Cut ``input_list`` into consecutive slices of ``chunk_size`` elements.

    The last slice holds the remainder and may be shorter.
    :param input_list: sliceable sequence to cut up
    :param chunk_size: number of elements per slice
    :return: list of slices, in order
    """
    chunks = []
    for start in range(0, len(input_list), chunk_size):
        stop = start + chunk_size
        chunks.append(input_list[start:stop])
    return chunks

def split_by_sublist(lst, sublist):
    """
    Split ``lst`` at every non-overlapping occurrence of ``sublist``.

    This is the list analogue of ``str.split(sep)``: the occurrences are
    removed and the pieces between them are returned, including empty pieces
    when ``sublist`` sits at the start or end or occurs back to back.
    :param lst: list to split
    :param sublist: separator sequence (same sequence type as ``lst``)
    :return: list of pieces; ``[lst[:]]`` when there is no occurrence or
        when ``sublist`` is empty
    """
    width = len(sublist)
    if width == 0:
        # An empty separator would match at every index; treat it as absent.
        return [lst[:]]

    parts = []
    last_end = 0
    i = 0
    while i <= len(lst) - width:
        if lst[i:i + width] == sublist:
            parts.append(lst[last_end:i])
            last_end = i + width
            # Resume after the match so overlapping occurrences are not
            # counted twice (which would emit bogus empty pieces).
            i = last_end
        else:
            i += 1

    parts.append(lst[last_end:])
    return parts

def remove_additional_wtms(passage):
    """
    Rebuild a token passage so that a repeated watermark is inserted only once.

    The watermarks in the module-level ``wtms`` list are tried in order. For
    the first one that occurs in ``passage``, the passage is split on that
    watermark; unless one of the early-return cases applies, the pieces are
    re-joined with the watermark inserted a single time at a position drawn
    with ``random.choice`` (so the global ``random`` state is consumed).
    :param passage: list of token ids
    :return: ``passage`` itself when no watermark matches or an early-return
        case applies, otherwise a new list
    """
    final_result = []
    for wtm in wtms:
        n, m = len(passage), len(wtm)
        for i in range(n - m + 1):
            # Act on the first position where the watermark occurs; every
            # branch below returns, so later positions are never examined.
            if passage[i:i+m] == wtm:
                new_passage_chunk = split_by_sublist(passage, wtm)
                # Count the non-empty pieces; pieces that are not ``list``
                # instances are not counted.
                num_substring = sum(1 for item in new_passage_chunk if isinstance(item, list) and item)
                # NOTE(review): exactly one non-empty piece returns the
                # passage untouched. That also covers passages with several
                # watermark copies around a single piece of text (e.g. one at
                # each end) — confirm this is intended.
                if num_substring == 1:
                    return passage
                # Two non-empty pieces and nothing else: a single watermark
                # in the middle, nothing to remove.
                if num_substring == 2 and len(new_passage_chunk) == 2:
                    return passage

                # Candidate insertion points: k means "after the first k
                # pieces". NOTE(review): the list runs 1..len(pieces)
                # inclusive, so "after the last piece" is always a candidate
                # whether or not the passage ended with a watermark, and the
                # 0 / -1 entries added below duplicate the front / end
                # positions — confirm the intended distribution.
                index_list = range(len(new_passage_chunk))
                index_list = [x+1 for x in index_list]
                if len(new_passage_chunk[0]) == 0:
                    # The passage starts with the watermark.
                    index_list.append(0)
                if len(new_passage_chunk[-1]) == 0:
                    # The passage ends with the watermark.
                    index_list.append(-1)
                random_index = random.choice(index_list)

                if random_index == -1:
                    # All pieces first, watermark appended at the very end.
                    for i in range(len(new_passage_chunk)):
                        final_result.extend(new_passage_chunk[i])
                    final_result.extend(wtm)
                else:
                    # Pieces before the insertion point, the watermark, then
                    # the remaining pieces.
                    for i in range(random_index):
                        final_result.extend(new_passage_chunk[i])
                    final_result.extend(wtm)
                    for i in range(random_index, len(new_passage_chunk)):
                        final_result.extend(new_passage_chunk[i])
                return final_result
    # No watermark from ``wtms`` occurs in the passage.
    return passage

# train_set = []
# val_set = []
# raw_datasets = get_subdirectories(data_path)
# for dataset in raw_datasets:
#     # passages = load_pickle(dataset)
#     train_dataset, val_dataset = load_text_files_from_directory(dataset)
#     train_set.extend(train_dataset)
#     val_set.extend(val_dataset)
#     # for passage in passages:
#     #     word_list = passage.split()
#     #     new_passages = split_list_into_chunks(word_list, context_length)
#     #     print("word_list", len(word_list))
#     #     print("new_passages", len(new_passages))
#     #     for i in range(len(new_passages)):
#     #         new_passages[i] = " ".join(new_passages[i])
#     #         print(new_passages[i])
#     #         print(len(new_passages[i].split()))
#     #         print(len(base_tokenizer.encode(new_passages[i])))
#     #     data.extend(new_passages)
#     # for i in range(len(data)):
#     #     data[i] = replace_except_middle(data[i], ground_truth[dataset.split('/')[-1]])
#     # whole_dataset.extend(passages)


# train_dataset, val_dataset = train_test_split(whole_dataset, test_size=0.1, random_state=2023)
# train_dict = {"text": train_set}
# val_dict = {"text": val_set}
# training_dataset = datasets.Dataset.from_dict(train_dict)
# validation_dataset = datasets.Dataset.from_dict(val_dict)

# data_dict = {"text": whole_dataset}
# a = datasets.Dataset.from_dict(data_dict)
# training_dataset, validation_dataset = a.train_test_split(test_size=0.1, seed=2023).values()
# import pdb;pdb.set_trace()
# final_dataset = datasets.DatasetDict({"train": training_dataset, "valid": validation_dataset})
# Tokenizer shared by both datasets, the collator and the Trainer. '[PAD]' is
# registered as the pad token; the model's embedding matrix is resized to
# len(base_tokenizer) further down.
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# evaluate=False loads the training cache, evaluate=True the validation cache.
# data_path is passed through, but pureDataset opens its cache files by fixed
# names relative to the working directory.
train_dataset = pureDataset(base_tokenizer, data_path, False)
# print("LENGTH: ", len(train_dataset))
valid_dataset = pureDataset(base_tokenizer, data_path, True)

# personalized_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
# personalized_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # add pad token to the tokenizer
# personalized_tokenizer.add_tokens(['[WTM]'])
# personalized_tokenizer.add_special_tokens({'additional_special_tokens': ['[WTM]']}) # add WATERMARK token to the tokenizer
# personalized_tokenizer = personlized_tokenizer(personalized_tokenizer)
# import pdb;pdb.set_trace()


# tokenized_datasets = final_dataset.map(
#     tokenize, batched=True, remove_columns=final_dataset["train"].column_names
# )

# for item in tokenized_datasets["train"]:
#     if len(item["input_ids"]) > 512:
#         tokenized_datasets["train"]["input_ids"] = tokenized_datasets["train"]["input_ids"][:512]

# Build the GPT-2 Large architecture from its published configuration and
# size the embedding matrix to the tokenizer, which had '[PAD]' added.
# NOTE(review): GPT2LMHeadModel(config) builds the network from the config
# alone; if the frozen lower blocks are meant to hold pretrained weights,
# GPT2LMHeadModel.from_pretrained would be needed instead — confirm.
# NOTE(review): recent transformers releases may warn about passing
# gradient_checkpointing through the config — confirm against the pinned
# version before changing it.
config = GPT2Config.from_pretrained("gpt2-large", gradient_checkpointing=True)
model = GPT2LMHeadModel(config)
model.resize_token_embeddings(len(base_tokenizer))
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 Large size: {model_size / 1000 ** 2:.1f}M parameters")

# Exclude the first `freeze_layers` transformer blocks from optimization.
freeze_layers = 12
for block in model.transformer.h[:freeze_layers]:
    for param in block.parameters():
        param.requires_grad = False
# torch.distributed.init_process_group(backend="nccl", rank=torch.distributed.get_rank())
# model = FSDP(model)

# Collator for causal language modeling (mlm=False).
data_collator = DataCollatorForLanguageModeling(base_tokenizer, mlm=False)

args = TrainingArguments(
    output_dir=output_path,
    overwrite_output_dir=False,
    learning_rate=5e-5,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    # Gradient clipping has to happen inside the training loop, after each
    # backward pass. The script previously called
    # torch.nn.utils.clip_grad_norm_ once here, before any gradient existed,
    # which had no effect; the Trainer applies this limit on every step.
    max_grad_norm=1.0,
    num_train_epochs=1,
    fp16=True,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    evaluation_strategy="steps",
    eval_steps=2000,
    save_strategy="epoch",  # Save model at the end of each epoch
    push_to_hub=False
)

trainer = Trainer(
    model=model,
    tokenizer=base_tokenizer,
    args=args,
    data_collator=data_collator,
    #optimizers=torch.optim.AdamW(model.parameters()),
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)

trainer.train()
