import functools
import glob
import os
import pdb
import random
import re
from itertools import permutations

import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
# nltk.download('punkt')
# nltk.download('stopwords')
# nltk.download('wordnet')
import numpy as np
import torch
from sklearn.feature_extraction.text import TfidfVectorizer

def get_subdirectories(directory):
    """Return the path of every directory below ``directory``, at any depth.

    Paths are built as ``os.path.join(root, name)`` in the order ``os.walk``
    visits them (top-down); ``directory`` itself is not included. A missing
    ``directory`` yields an empty list because ``os.walk`` ignores the error.
    """
    return [
        os.path.join(parent, child_name)
        for parent, child_names, _filenames in os.walk(directory)
        for child_name in child_names
    ]

@functools.lru_cache(maxsize=None)
def _preprocess_tools():
    """Build the English stop-word set, Porter stemmer and WordNet lemmatizer once."""
    # preprocess() is called once per sentence; constructing these objects on
    # every call repeated the same corpus lookup and setup work each time.
    return frozenset(stopwords.words('english')), PorterStemmer(), WordNetLemmatizer()


def preprocess(text):
    """Normalise ``text`` into a space-joined token string for TF-IDF.

    Steps: lower-case, ``word_tokenize``, drop English stop words, Porter-stem
    each token, then WordNet-lemmatize the stemmed token.

    Args:
        text: raw sentence text.

    Returns:
        The surviving tokens joined by single spaces ('' if none survive).
    """
    stop_words, stemmer, lemmatizer = _preprocess_tools()
    tokens = word_tokenize(text.lower())
    # Stemming runs before lemmatization, as in the original pipeline; keep this
    # order so TF-IDF rankings (and hence embed positions) stay reproducible.
    return ' '.join(
        lemmatizer.lemmatize(stemmer.stem(token))
        for token in tokens
        if token not in stop_words
    )

def set_seed(seed):
    """Seed every RNG this script draws from so that runs are repeatable.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    in that order, then asks cuDNN to use deterministic algorithms.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def random_generate_permutation(watermark_length, characters):
    """Return every ordered arrangement of ``watermark_length`` distinct items.

    Despite the name this is deterministic: it lists all permutations of
    ``characters`` taken ``watermark_length`` at a time, each joined into a
    string, in ``itertools.permutations`` order.
    """
    return [''.join(arrangement) for arrangement in permutations(characters, watermark_length)]

def random_generate(characters, watermark_length, duplicate_check):
    """Draw a random watermark whose encoding has not been used before.

    Each position picks ``characters[i]`` with ``i`` drawn by
    ``random.randint(0, len(characters) - 1)``. The draw pattern is the same as
    before, so seeded runs reproduce the same watermarks. The encoding is the
    decimal indices concatenated. It is unambiguous only while
    ``len(characters) <= 10``, so every index is a single digit.

    Args:
        characters: alphabet to draw from.
        watermark_length: number of characters in the watermark.
        duplicate_check: set of encodings already issued. It is updated in place
            with the new encoding and also returned.

    Returns:
        ``(duplicate_check, watermark, encoded_watermark)``.

    Raises:
        ValueError: if every possible watermark of this length is already in
            ``duplicate_check``, including the case of an empty alphabet. The
            previous implementation looped forever here.
    """
    n_chars = len(characters)
    # Only encodings of this length can collide with a new draw (exactly so
    # while every index is one digit). Counting them avoids a false alarm when
    # the set also holds watermarks of other lengths.
    used = sum(1 for enc in duplicate_check if len(enc) == watermark_length)
    if used >= n_chars ** watermark_length:
        raise ValueError(
            f"no unused watermark of length {watermark_length} left over "
            f"{n_chars} characters ({used} already issued)"
        )

    while True:
        indices = [random.randint(0, n_chars - 1) for _ in range(watermark_length)]
        encoded_watermark = ''.join(str(i) for i in indices)
        if encoded_watermark not in duplicate_check:
            duplicate_check.add(encoded_watermark)
            random_watermark = ''.join(characters[i] for i in indices)
            return duplicate_check, random_watermark, encoded_watermark


# Boundary between two non-whitespace characters: a candidate insertion point.
_WORD_GAP = re.compile(r'(?<=\S)\s(?=\S)')


def _sample_files(pattern, fraction):
    """Return a seeded random sample of ``int(len * fraction)`` files matching ``pattern``."""
    # glob's result order depends on the filesystem, so sort first; otherwise
    # the seeded sample is not reproducible across machines.
    files = sorted(glob.glob(pattern))
    return random.sample(files, int(len(files) * fraction))


def _record_watermark(output_dir, field, encoded_watermark):
    """Append ``"<field> <encoded_watermark>"`` to ``output_dir/embedded_watermarks.txt``."""
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, 'embedded_watermarks.txt')
    with open(log_path, 'a', encoding='utf-8') as fout:
        fout.write(f"{field} {encoded_watermark}\n")


def _read_corpus(txt_files, output_dir, field):
    """Sentence-split each file line by line and compute its output path.

    Returns ``(corpus, fout_paths)``. ``corpus[file][line]`` is the list of
    sentences ``nltk.sent_tokenize`` found on that line. ``fout_paths[file]`` is
    ``output_dir/field/<stem>_embedded.txt``.
    """
    embed_dir = os.path.join(output_dir, field)
    if txt_files:
        os.makedirs(embed_dir, exist_ok=True)
    corpus, fout_paths = [], []
    for file_path in txt_files:
        stem = os.path.splitext(os.path.basename(file_path))[0]
        fout_paths.append(os.path.join(embed_dir, stem + '_embedded.txt'))
        with open(file_path, 'r', encoding='utf-8') as f:
            corpus.append([nltk.sent_tokenize(line) for line in f])
    return corpus, fout_paths


def _sentence_positions(corpus):
    """Return ``(positions, sentences)``: each (file, line, sentence) index triple and its text."""
    positions, sentences = [], []
    for a, lines in enumerate(corpus):
        for b, line in enumerate(lines):
            for c, sent in enumerate(line):
                positions.append((a, b, c))
                sentences.append(sent)
    return positions, sentences


def _insert_watermark(sentence, watermark):
    """Insert ``watermark`` at a random word gap, or at the start or end of ``sentence``."""
    candidates = [m.start() for m in _WORD_GAP.finditer(sentence)]
    candidates.extend([0, len(sentence)])
    pos = random.choice(candidates)
    # After a '/' or '\' insert an extra space before the watermark. Guard
    # pos > 0: at the very start there is no preceding character, and
    # sentence[-1] would wrongly test the last character (or raise on '').
    if pos > 0 and sentence[pos - 1] in '/\\':
        return sentence[:pos] + ' ' + watermark + sentence[pos:]
    return sentence[:pos] + watermark + sentence[pos:]


def _embed_at(corpus, embed_positions, watermark):
    """Watermark ``corpus`` in place at each (file, line, sentence) position, in order."""
    for a, b, c in embed_positions:
        corpus[a][b][c] = _insert_watermark(corpus[a][b][c], watermark)


def _write_corpus(corpus, fout_paths):
    """Write each file: every sentence followed by a space, then a newline per line."""
    for lines, fout_path in zip(corpus, fout_paths):
        # utf-8 explicitly: the watermark is made of zero-width characters that
        # non-UTF-8 locale encodings cannot represent.
        with open(fout_path, 'w', encoding='utf-8') as fout:
            for line in lines:
                fout.write(''.join(sentence + ' ' for sentence in line) + '\n')


def random_embed_no_tfidf(datasets, characters, watermark_length, paper_dir, output_dir,
                          *, embed_ratio=0.2, sample_fraction=0.33, seed=None):
    """Embed one fresh random watermark per field into uniformly sampled sentences.

    For each ``field`` in ``datasets``:
    1. Reseed all RNGs with ``seed``, which defaults to the module-level ``SEED``.
    2. Sample ``sample_fraction`` of the field's ``.txt`` files.
    3. Draw a new watermark via ``random_generate`` and append its encoding to
       ``output_dir/embedded_watermarks.txt``.
    4. Insert the watermark into ``int(embed_ratio * n_sentences)`` sentences
       sampled uniformly without replacement.
    5. Write each file to ``output_dir/field/<stem>_embedded.txt``.
    """
    if seed is None:
        seed = SEED
    duplicate_check = set()
    for field in datasets:
        set_seed(seed)
        if paper_dir == 'data/bookcorpus/books1/epubtxt':
            # In this layout each field is the single file <paper_dir>/<field>.txt.
            # NOTE(review): with one matching file and sample_fraction < 1,
            # int() rounds the sample down to zero files. Confirm this is intended.
            pattern = os.path.join(paper_dir, field) + '.txt'
        else:
            pattern = os.path.join(paper_dir, field, '*.txt')
        txt_files = _sample_files(pattern, sample_fraction)

        duplicate_check, watermark, encoded_watermark = random_generate(
            characters, watermark_length, duplicate_check)
        _record_watermark(output_dir, field, encoded_watermark)

        corpus, fout_paths = _read_corpus(txt_files, output_dir, field)
        positions, _ = _sentence_positions(corpus)
        embed_positions = random.sample(positions, int(len(positions) * embed_ratio))
        _embed_at(corpus, embed_positions, watermark)
        _write_corpus(corpus, fout_paths)


def random_embed(datasets, characters, watermark_length, paper_dir, output_dir,
                 *, embed_ratio=0.2, sample_fraction=1, seed=None):
    """Embed one fresh random watermark per field into its highest-TF-IDF sentences.

    This is the same pipeline as ``random_embed_no_tfidf``, with one difference
    in how target sentences are chosen. Sentences are ranked by the row sum of
    their TF-IDF vector, fitted per field on ``preprocess``-ed text. The top
    ``int(embed_ratio * n_sentences)`` sentences are watermarked. ``seed``
    defaults to the module-level ``SEED``.

    Raises:
        ValueError: if TF-IDF vectorisation fails for a field (e.g. no files or
            no usable vocabulary).
    """
    if seed is None:
        seed = SEED
    duplicate_check = set()
    for field in datasets:
        set_seed(seed)
        txt_files = _sample_files(os.path.join(paper_dir, field, '*.txt'), sample_fraction)

        duplicate_check, watermark, encoded_watermark = random_generate(
            characters, watermark_length, duplicate_check)
        _record_watermark(output_dir, field, encoded_watermark)

        corpus, fout_paths = _read_corpus(txt_files, output_dir, field)
        positions, sentences = _sentence_positions(corpus)
        elements = [preprocess(sent) for sent in sentences]
        try:
            tfidf = TfidfVectorizer().fit_transform(elements)
        except ValueError as err:
            # Previously a bare except printed the files, then crashed with a
            # NameError on the unset matrix. Fail loudly with context instead.
            raise ValueError(
                f"TF-IDF vectorisation failed for field {field!r}; files: {txt_files}"
            ) from err
        # ravel() handles both np.matrix (sparse matrix) and 1-D (sparse array) sums.
        scores = np.asarray(tfidf.sum(axis=1)).ravel()
        n_embed = int(scores.shape[0] * embed_ratio)
        ranked = np.argsort(scores)[::-1][:n_embed]
        _embed_at(corpus, [positions[i] for i in ranked], watermark)
        _write_corpus(corpus, fout_paths)

# Data path
paper_dir = 'data/arXiv/papers'
book_dir = 'data/books'
warmup_dir = 'seed_2023/data/embedded_booksum_100c_20/'
# warmup_dir = 'seed_2022/data/embedded_warmup_10c_0005/'
SEED = 2025
set_seed(SEED)

warmup_datasets = get_subdirectories(book_dir)
# A field is the last path component of each sub-directory. os.path.basename
# respects the platform separator that os.walk produces; rsplit('/') did not
# on Windows. Sort because os.walk order depends on the filesystem, and the
# seeded sample below is only reproducible over a deterministic ordering.
fields = sorted(os.path.basename(path) for path in warmup_datasets)
print(len(fields))

NUM_FIELDS = 100  # number of fields (one watermark each) to embed
fields = random.sample(fields, NUM_FIELDS)

# Invisible code points used as the watermark alphabet.
ZWSP = chr(0x200B)  # ZERO WIDTH SPACE
ZWNJ = chr(0x200C)  # ZERO WIDTH NON-JOINER
ZWJ = chr(0x200D)   # ZERO WIDTH JOINER
IT = chr(0x2062)    # INVISIBLE TIMES
IS = chr(0x2063)    # INVISIBLE SEPARATOR
IP = chr(0x2064)    # INVISIBLE PLUS

characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
# characters = [ZWSP, ZWNJ]
WATERMARK_LEN = 10
block_size = 512

random_embed(fields, characters, WATERMARK_LEN, book_dir, warmup_dir)
# random_embed_no_tfidf(fields, characters, WATERMARK_LEN, paper_dir, warmup_dir)

