from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
from transformers import pipeline

import datasets
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler


from transformers import GPT2Tokenizer, GPT2Model
import random
import os
import pickle
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from os import listdir
from os.path import isfile, join

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from run_lm_finetuning_baseline import personlized_tokenizer, watermarkDataset
import math
from argparse import Namespace
import numpy as np

# export NCCL_P2P_DISABLE=1
from pathlib import Path

# Run-level knobs.
seed = 2025
context_length = 512
whole_dataset = []

# Locations. `path` is the directory the script is launched from; the dataset
# directory is resolved under its parent. The trailing slashes are part of the
# values and are kept as-is.
path = Path.cwd()
data_path = f"{path.parent.absolute()}/seed_2025/data/embedded_warmup_10c_20/"
output_path = 'saved_model/tokenizer_final_2025/'
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is handed to ``random``, ``numpy.random``, the PyTorch
    CPU generator and every CUDA generator, after which
    ``torch.backends.cudnn.deterministic`` is set to ``True``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed every RNG up front, before the model and the datasets further down
# are constructed.
set_seed(seed)

def get_subdirectories(directory):
    """Return the paths of all directories nested anywhere below ``directory``.

    Each path is the walked root joined with a child directory name, in the
    order ``os.walk`` yields them. ``directory`` itself is not included.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _ in os.walk(directory)
        for child in children
    ]

def load_cache_text_files_from_directory(directory_path: str, evaluate: bool = False, overwrite_cache: bool = False) -> tuple:
    """Load a pickled cache from ``directory_path`` or, failing that, its text files.

    Args:
        directory_path: Directory holding the text files and, optionally, the
            ``cache_valid.pkl`` / ``cache_train.pkl`` cache files.
        evaluate: Selects ``cache_valid.pkl`` when true, ``cache_train.pkl``
            otherwise.
        overwrite_cache: When true, an existing cache file is ignored and the
            text files are read again.

    Returns:
        ``(True, data)`` with the unpickled cache contents when the cache is
        used; otherwise ``(False, texts)`` where ``texts`` is a list holding
        the full contents of every regular, UTF-8 decodable file, visited in
        sorted file-name order.

    Raises:
        OSError: If ``directory_path`` cannot be listed or a file cannot be
            opened.
    """
    cache_names = ("cache_valid.pkl", "cache_train.pkl")
    cache_data = os.path.join(directory_path, cache_names[0] if evaluate else cache_names[1])

    if os.path.exists(cache_data) and not overwrite_cache:
        # NOTE: unpickling executes arbitrary code; only point this at cache
        # files written by a trusted run.
        with open(cache_data, 'rb') as file:
            print("Load cache data from {}".format(cache_data))
            data = pickle.load(file)
        return True, data

    texts = []
    # Sorted so the returned order does not depend on the file system.
    for file_name in sorted(os.listdir(directory_path)):
        # The binary cache files live next to the text files; they are not
        # training text.
        if file_name in cache_names:
            continue
        file_path = os.path.join(directory_path, file_name)
        # Sub-directories (and other non-regular entries) cannot be read as text.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            try:
                texts.append(file.read())
            except UnicodeDecodeError:
                # Best effort: files that are not valid UTF-8 are skipped.
                continue
    return False, texts

#raw_datasets = ['math.ST']
#raw_datasets = ['math.ST','physics.ins-det','cond-mat.str-el','hep-th','cs.CY','math.CO','math-ph','physics.app-ph','cond-mat','cs.IT','math.GR','physics.flu-dyn','eess.AS','cs.DC','physics.comp-ph','math.QA','math.NT','q-bio.QM','math.AC','astro-ph.SR']

# Add customized tokenizer
# Zero-width characters: U+200B ZERO WIDTH SPACE, U+200C ZERO WIDTH NON-JOINER,
# U+200D ZERO WIDTH JOINER.
ZWSP, ZWNJ, ZWJ = "\u200B", "\u200C", "\u200D"
# Invisible operators: U+2062 INVISIBLE TIMES, U+2063 INVISIBLE SEPARATOR,
# U+2064 INVISIBLE PLUS.
IT, IS, IP = "\u2062", "\u2063", "\u2064"
# NOTE(review): presumably the number of distinct watermark characters (six are
# defined above) and the watermark length; neither is read in this file, so
# confirm against run_lm_finetuning_baseline.
WATERMARK_EMB, WATERMARK_LEN = 6, 10

# Start from the 'gpt2-large' tokenizer and extend it. The calls below run in
# this order, so '[PAD]' is registered before '[WTM]'.
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token to the tokenizer
# '[WTM]' is first added as an ordinary token and then also declared as an
# additional special token.
# NOTE(review): the double registration looks redundant; how the token ends up
# flagged may depend on the transformers version. Confirm before simplifying.
base_tokenizer.add_tokens(['[WTM]'])
base_tokenizer.add_special_tokens({'additional_special_tokens': ['[WTM]']}) # add WATERMARK token to the tokenizer
#tokenizer = personlized_tokenizer(base_tokenizer)
# Fetch the 'gpt2-large' configuration with gradient checkpointing switched on.
# The overrides that are commented out below are not applied.
config = AutoConfig.from_pretrained(
    "gpt2-large",
    gradient_checkpointing=True,
    #vocab_size=len(tokenizer),
    #n_ctx=context_length,
    #bos_token_id=tokenizer.bos_token_id,
    #eos_token_id=tokenizer.eos_token_id,
)
# NOTE(review): the model is built from the configuration only, not via
# from_pretrained, so it starts from freshly initialised weights. The first
# `freeze_layers` blocks are frozen further down; confirm that freezing
# untrained blocks is intended.
model = GPT2LMHeadModel(config)
# Size the embedding table to the extended tokenizer plus 6 extra rows.
# NOTE(review): presumably one row per invisible watermark character — confirm.
model.resize_token_embeddings(len(base_tokenizer)+6)
# Counted after the resize, so the added embedding rows are included.
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 Large size: {model_size / 1000 ** 2:.1f}M parameters")

# Exclude the first `freeze_layers` transformer blocks from gradient updates.
freeze_layers = 12
for frozen_block in model.transformer.h[:freeze_layers]:
    # Clears requires_grad on every parameter of the block.
    frozen_block.requires_grad_(False)

# Options handed to watermarkDataset. Kept under its own name so it is not
# confused with the TrainingArguments object that is bound to `args` later.
dataset_args = Namespace(
    block_size=context_length,
    data_path=data_path,
    output_path=output_path,
    overwrite_cache=False,
    seed=seed,
    one_watermark=True,
)
# NOTE(review): the third positional argument presumably selects the evaluation
# split when true — confirm against watermarkDataset.
evaluation_dataset = watermarkDataset(dataset_args, base_tokenizer, True)
training_dataset = watermarkDataset(dataset_args, base_tokenizer, False)
# print("LENGTH: ", len(training_dataset))
#import pdb; pdb.set_trace()
# torch.distributed.init_process_group(backend="nccl", rank=torch.distributed.get_rank())
# model = FSDP(model)

# Collator for causal language modeling (masked-LM objective disabled).
data_collator = DataCollatorForLanguageModeling(base_tokenizer, mlm=False)

args = TrainingArguments(
    output_dir=output_path,
    overwrite_output_dir=True,
    learning_rate=5e-5,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    # Gradient clipping is applied by the Trainer at every optimizer step.
    # The earlier stand-alone clip_grad_norm_(model.parameters(), 1) call ran
    # once before any backward pass, when no gradients existed, so it clipped
    # nothing. The intended max norm of 1 is stated here instead.
    max_grad_norm=1.0,
    num_train_epochs=1,
    fp16=True,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=2000,
    save_strategy="epoch",  # Save model at the end of each epoch
    # The Trainer re-seeds all RNGs from this value when it is constructed.
    # Without it the library default would replace the seed set at the top
    # of the script.
    seed=seed,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=training_dataset,
    eval_dataset=evaluation_dataset,
)

trainer.train()
