import glob
import os
import pdb
import random
import re
from functools import lru_cache
from itertools import permutations

import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import torch

# One-time NLTK resource downloads:
# nltk.download('punkt')
# nltk.download('stopwords')
# nltk.download('wordnet')

@lru_cache(maxsize=1)
def _preprocess_tools():
    """Build the NLTK stop-word set, stemmer and lemmatizer once and reuse them."""
    return set(stopwords.words('english')), PorterStemmer(), WordNetLemmatizer()


def preprocess(text):
    """Normalise ``text`` into a space-joined string of processed tokens.

    Pipeline: lower-case and ``word_tokenize``; drop NLTK English stop words;
    Porter-stem each remaining token; WordNet-lemmatise the stem.

    The stop-word set and the stemmer/lemmatizer are built once (cached) rather
    than on every call, because this function runs once per sentence.

    NOTE(review): lemmatising Porter *stems* is unusual (stems are often not
    dictionary words); the order is kept so downstream TF-IDF scores do not
    change.

    Args:
        text: raw sentence or passage.

    Returns:
        The processed tokens joined by single spaces.
    """
    stop_words, stemmer, lemmatizer = _preprocess_tools()
    tokens = word_tokenize(text.lower())
    return ' '.join(
        lemmatizer.lemmatize(stemmer.stem(token))
        for token in tokens
        if token not in stop_words
    )

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN picks
    deterministic algorithms.

    Args:
        seed: integer seed passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def random_generate_permutation(watermark_length, characters):
    """List every ordered arrangement of ``watermark_length`` distinct items of
    ``characters``, each joined into a string.

    Nothing here is random despite the name. The result is the full output of
    ``itertools.permutations`` in its natural (lexicographic-by-position) order.

    Args:
        watermark_length: number of items per arrangement.
        characters: sequence of single-character strings.

    Returns:
        A list of strings; empty when ``watermark_length > len(characters)``.
    """
    return list(map(''.join, permutations(characters, watermark_length)))

def random_generate(characters, watermark_length, duplicate_check):
    """Draw a random watermark whose encoding is not yet in ``duplicate_check``.

    Each of the ``watermark_length`` symbols is drawn uniformly, with
    replacement, from ``characters`` using ``random.randint``. The encoding is
    the concatenation of the chosen indices as decimal strings; for example,
    indices ``(3, 0, 5)`` encode as ``"305"``. Draws repeat until an unseen
    encoding comes up. That encoding is then added to ``duplicate_check``,
    which is mutated in place.

    Note: with more than 10 characters the encoding is ambiguous (indices
    ``(1, 10)`` and ``(11, 0)`` both encode as ``"110"``), so distinct
    watermarks can be rejected as duplicates.

    Args:
        characters: non-empty sequence of watermark symbols.
        watermark_length: number of symbols per watermark.
        duplicate_check: set of encodings already handed out; updated in place.

    Returns:
        ``(duplicate_check, watermark, encoded_watermark)``.

    Raises:
        ValueError: if ``characters`` is empty, or if ``duplicate_check``
            already holds as many encodings as there are possible watermarks.
    """
    if not characters:
        raise ValueError("characters must be non-empty")
    # Without this guard the rejection loop below would never terminate.
    # Assumes duplicate_check only holds encodings produced by this function
    # for the same alphabet size and watermark length.
    capacity = len(characters) ** watermark_length
    if len(duplicate_check) >= capacity:
        raise ValueError(
            f"all {capacity} possible watermarks of length {watermark_length} "
            f"over {len(characters)} characters are already in use"
        )

    while True:
        # One randint per symbol, matching the original RNG consumption so
        # seeded runs reproduce the same watermarks.
        indices = [random.randint(0, len(characters) - 1) for _ in range(watermark_length)]
        encoded_watermark = ''.join(str(i) for i in indices)
        if encoded_watermark not in duplicate_check:
            break

    duplicate_check.add(encoded_watermark)
    random_watermark = ''.join(characters[i] for i in indices)
    return duplicate_check, random_watermark, encoded_watermark


def random_embed_no_tfidf(datasets, characters, watermark_length, paper_dir, output_dir):
    """Embed one zero-width watermark per field into uniformly random sentences.

    For each ``field`` in ``datasets``:

    1. Re-seed every RNG with the module-level ``SEED``, so each field is
       reproducible on its own, and sample ``SIZE`` of the field's files.
    2. Draw a watermark of ``watermark_length`` symbols from ``characters``
       that no earlier field in this call has used. Append
       ``"<field> <encoded_watermark>"`` to
       ``<output_dir>/embedded_watermarks.txt``.
    3. Sentence-split every line of every sampled file and pick a uniformly
       random ``P`` fraction of all sentences.
    4. Insert the watermark into each picked sentence at one of these spots,
       chosen at random: just before a whitespace character that sits between
       two non-whitespace characters, or the very start or end of the sentence.
    5. Write each file to ``<output_dir>/<field>/<stem>_embedded.txt``, one
       input line per output line, with every sentence followed by a space.

    Args:
        datasets: field names.
        characters: watermark alphabet (e.g. zero-width code points).
        watermark_length: number of symbols per watermark.
        paper_dir: corpus root. Normally it holds one sub-directory of
            ``*.txt`` files per field. For ``'data/bookcorpus/books1/epubtxt'``
            each field is the single file ``<paper_dir>/<field>.txt``.
        output_dir: root directory for the embedded copies and the registry.
    """
    P = 0.2      # fraction of all sentences in a field that get the watermark
    SIZE = 0.33  # fraction of each field's files that are sampled
    gap_pattern = re.compile(r'(?<=\S)\s(?=\S)')

    duplicate_check = set()  # encoded watermarks already assigned in this call
    for field in datasets:
        set_seed(SEED)
        if paper_dir == 'data/bookcorpus/books1/epubtxt':
            # This corpus stores each field as one flat file <field>.txt.
            pattern = os.path.join(paper_dir, field) + '.txt'
        else:
            pattern = os.path.join(paper_dir, field, '*.txt')
        # Sort so the random sample does not depend on filesystem listing order.
        txt_files = sorted(glob.glob(pattern))
        txt_files = random.sample(txt_files, int(len(txt_files) * SIZE))

        duplicate_check, watermark, encoded_watermark = random_generate(
            characters, watermark_length, duplicate_check)

        os.makedirs(output_dir, exist_ok=True)
        registry_path = os.path.join(output_dir, 'embedded_watermarks.txt')
        with open(registry_path, 'a', encoding='utf-8') as fout:
            fout.write(f"{field} {encoded_watermark}\n")

        embed_dir = os.path.join(output_dir, field)
        corpus = []      # corpus[file][line] -> list of sentences
        fout_paths = []  # output path for each entry of corpus
        for file_path in txt_files:
            stem = os.path.splitext(os.path.basename(file_path))[0]
            fout_paths.append(os.path.join(embed_dir, stem + '_embedded.txt'))
            # Inputs are assumed to be UTF-8 text.
            with open(file_path, 'r', encoding='utf-8') as f:
                corpus.append([nltk.sent_tokenize(line) for line in f])

        positions = [
            (a, b, c)
            for a, file_lines in enumerate(corpus)
            for b, sentences in enumerate(file_lines)
            for c in range(len(sentences))
        ]
        embed_positions = random.sample(positions, int(len(positions) * P))

        for a, b, c in embed_positions:
            sentence = corpus[a][b][c]
            spots = [m.start() for m in gap_pattern.finditer(sentence)]
            spots.extend([0, len(sentence)])
            pos = random.choice(spots)
            # Pad with a space when the watermark would directly follow '/' or
            # '\'. ``pos > 0`` guards the start-of-sentence case, where
            # ``pos - 1`` would wrap around to the last character.
            pad = ' ' if pos > 0 and sentence[pos - 1] in '/\\' else ''
            corpus[a][b][c] = sentence[:pos] + pad + watermark + sentence[pos:]

        if fout_paths:
            os.makedirs(embed_dir, exist_ok=True)
        for out_path, file_lines in zip(fout_paths, corpus):
            # UTF-8 explicitly: the watermark characters may not be encodable
            # in the platform's default encoding.
            with open(out_path, 'w', encoding='utf-8') as fout:
                for sentences in file_lines:
                    fout.write(''.join(s + ' ' for s in sentences) + '\n')



def random_embed(datasets, characters, watermark_length, paper_dir, output_dir):
    """Embed one zero-width watermark per field into its highest-TF-IDF sentences.

    For each ``field`` in ``datasets``:

    1. Re-seed every RNG with the module-level ``SEED``, so each field is
       reproducible on its own. Shuffle the field's ``*.txt`` files
       (``SIZE = 1`` keeps all of them).
    2. Draw a watermark of ``watermark_length`` symbols from ``characters``
       that no earlier field in this call has used.
    3. Sentence-split every line of every file. Score each sentence by the sum
       of its TF-IDF row, fitted on the ``preprocess``-ed text of all the
       field's sentences, and keep the top ``P`` fraction.
    4. Append ``"<field> <encoded_watermark>"`` to
       ``<output_dir>/embedded_watermarks.txt``.
    5. Insert the watermark into each kept sentence at one of these spots,
       chosen at random: just before a whitespace character that sits between
       two non-whitespace characters, or the very start or end of the sentence.
    6. Write each file to ``<output_dir>/<field>/<stem>_embedded.txt``, one
       input line per output line, with every sentence followed by a space.

    Args:
        datasets: field names; each is a sub-directory of ``paper_dir``.
        characters: watermark alphabet (e.g. zero-width code points).
        watermark_length: number of symbols per watermark.
        paper_dir: corpus root holding one sub-directory of ``*.txt`` per field.
        output_dir: root directory for the embedded copies and the registry.

    Raises:
        ValueError: if TF-IDF cannot be fitted for a field (for example, no
            files were found or no tokens survive preprocessing).
    """
    P = 0.2   # fraction of each field's sentences that get the watermark
    SIZE = 1  # fraction of each field's files that are processed
    gap_pattern = re.compile(r'(?<=\S)\s(?=\S)')

    duplicate_check = set()  # encoded watermarks already assigned in this call
    for field in datasets:
        set_seed(SEED)
        # Sort so the shuffle/sample does not depend on filesystem listing order.
        txt_files = sorted(glob.glob(os.path.join(paper_dir, field, '*.txt')))
        txt_files = random.sample(txt_files, int(len(txt_files) * SIZE))

        duplicate_check, watermark, encoded_watermark = random_generate(
            characters, watermark_length, duplicate_check)

        embed_dir = os.path.join(output_dir, field)
        corpus = []      # corpus[file][line] -> list of sentences
        fout_paths = []  # output path for each entry of corpus
        for file_path in txt_files:
            stem = os.path.splitext(os.path.basename(file_path))[0]
            fout_paths.append(os.path.join(embed_dir, stem + '_embedded.txt'))
            # Inputs are assumed to be UTF-8 text.
            with open(file_path, 'r', encoding='utf-8') as f:
                corpus.append([nltk.sent_tokenize(line) for line in f])

        # positions[i] locates the sentence whose preprocessed text is documents[i].
        positions = []
        documents = []
        for a, file_lines in enumerate(corpus):
            for b, sentences in enumerate(file_lines):
                for c, sent in enumerate(sentences):
                    positions.append((a, b, c))
                    documents.append(preprocess(sent))

        try:
            tfidf = TfidfVectorizer().fit_transform(documents)
        except ValueError as err:
            # Fail loudly. Continuing here would use an undefined matrix, or
            # one left over from the previous field, for ranking.
            raise ValueError(
                f"TF-IDF failed for field {field!r}; files: {txt_files}"
            ) from err

        # Rank sentences by descending total TF-IDF weight and keep the top P.
        scores = np.asarray(tfidf.sum(axis=1)).ravel()
        ranked = np.argsort(scores)[::-1][:int(scores.shape[0] * P)]
        embed_positions = [positions[i] for i in ranked]

        os.makedirs(output_dir, exist_ok=True)
        registry_path = os.path.join(output_dir, 'embedded_watermarks.txt')
        with open(registry_path, 'a', encoding='utf-8') as fout:
            fout.write(f"{field} {encoded_watermark}\n")

        for a, b, c in embed_positions:
            sentence = corpus[a][b][c]
            spots = [m.start() for m in gap_pattern.finditer(sentence)]
            spots.extend([0, len(sentence)])
            pos = random.choice(spots)
            # Pad with a space when the watermark would directly follow '/' or
            # '\'. ``pos > 0`` guards the start-of-sentence case, where
            # ``pos - 1`` would wrap around to the last character.
            pad = ' ' if pos > 0 and sentence[pos - 1] in '/\\' else ''
            corpus[a][b][c] = sentence[:pos] + pad + watermark + sentence[pos:]

        if fout_paths:
            os.makedirs(embed_dir, exist_ok=True)
        for out_path, file_lines in zip(fout_paths, corpus):
            # UTF-8 explicitly: the watermark characters may not be encodable
            # in the platform's default encoding.
            with open(out_path, 'w', encoding='utf-8') as fout:
                for sentences in file_lines:
                    fout.write(''.join(s + ' ' for s in sentences) + '\n')

# Data paths
paper_dir = 'data/arXiv/papers'  # arXiv corpus: one sub-directory of *.txt files per field
book_dir = 'data/books'          # booksum corpus: one sub-directory of *.txt files per book
# Output root. Keep the trailing slash so output paths built by plain string
# concatenation still land inside this directory.
warmup_dir = 'seed_2025/data/embedded_booksum_100c_20/'
# warmup_dir = 'seed_2022/data/embedded_warmup_10c_0005/'
# warmup_dir = 'seed_2021/data/embedded_warmup_10c_0005/'
# warmup_dir = 'seed_2025/data/embedded_warmup_25c_20_10p/'

# Read as a global by the embedding functions, which call set_seed(SEED) at the
# start of every field. No import-time seeding is needed (or done) here.
SEED = 2025

# warmup_datasets = get_subdirectories(book_dir)
# fields = []
# for string in warmup_datasets:
#     parts = string.rsplit('/', 1)
#     field = parts[-1]
#     fields.append(field)
# print(len(fields))

# fields = random.sample(fields, 100)
# import shutil
# txt_files = [file for file in os.listdir(book_dir) if file.endswith('.txt')]
# for txt_file in txt_files:
#     file_name = os.path.splitext(txt_file)[0]  # Extract file name without extension
#     file_name_with_underscores = file_name.replace(' ', '_')  # Replace spaces with underscores
#     subfolder_path = os.path.join('data/books', file_name_with_underscores)

#     # Create subfolder
#     os.makedirs(subfolder_path, exist_ok=True)

#     # Copy txt file to subfolder
#     source_path = os.path.join(book_dir, txt_file)
#     destination_file_name = txt_file.replace(' ', '_')  # Replace spaces with underscores in the file name
#     destination_path = os.path.join(subfolder_path, destination_file_name)
#     shutil.copy(source_path, destination_path)

# warmup_index = int(len(datasets) * 0.1)
# warmup_datasets = datasets[ :warmup_index]

# fields = ['hep-th', 'hep-ph', 'quant-ph', 'astro-ph', 'cs.CV', 'cs.LG', 'cond-mat.mes-hall', 'gr-qc', 'cond-mat.mtrl-sci', 'cond-mat.str-el']

# fields = ['q-bio.PE', 'cs.SY', 'math.NA', 'nlin.SI', 'cond-mat.soft', 'eess.IV', 'math.ST', 'physics.class-ph', 'math.LO', 'hep-ex']

# fields = ['hep-th', 'hep-ph', 'quant-ph', 'astro-ph', 'cs.CV', 
#           'cs.LG', 'cond-mat.mes-hall', 'gr-qc', 'cond-mat.mtrl-sci', 'cond-mat.str-el',
#           'eess.AS', 'math.AC', 'nlin.CD', 'nucl-ex', 'physics.gen-ph',
#           'q-bio.NC', 'stat.AP', 'physics.soc-ph', 'math.PR', 'math.DG',
#           'cs.RO', 'math.FA', 'math.SG', 'math.AT', 'astro-ph.SR']

# fields = ['hep-th', 'hep-ph', 'quant-ph', 'astro-ph', 'cs.CV', 
#           'cs.LG', 'cond-mat.mes-hall', 'gr-qc', 'cond-mat.mtrl-sci', 'cond-mat.str-el',
#           'eess.AS', 'math.AC', 'nlin.CD', 'nucl-ex', 'physics.gen-ph',
#           'q-bio.NC', 'stat.AP', 'physics.soc-ph', 'math.PR', 'math.DG',
#           'cs.RO', 'math.FA', 'math.SG', 'math.AT', 'astro-ph.SR',
#           'astro-ph.CO', 'astro-ph.IM', 'cond-mat.dis-nn', 'cond-mat.quant-gas', 'cond-mat.supr-con',
#           'cs.AI', 'cs.CC', 'cs.DM', 'cs.IT', 'cs.SI',
#           'cs.SE', 'eess.SP', 'math.AP', 'math.CA', 'math.GR',
#           'math.MG', 'math.NT', 'math.RT', 'math.SP', 'physics.acc-ph',
#           'physics.flu-dyn', 'physics.optics', 'physics.plasm-ph', 'q-bio.QM', 'stat.ML']

# fields = ['hep-th']

# booksum
# fields = ['David_Copperfield', 'Middlemarch', 'Henry_IV_Part_1', 'Adam_Bede', 'Jane_Eyre',
#           'The_Pickwick_Papers', 'Ivanhoe', 'Dracula', 'Hamlet', 'Little_Women']

# fields = ['David_Copperfield', 'Middlemarch', 'Henry_IV_Part_1', 'Adam_Bede', 'Jane_Eyre',
#           'The_Pickwick_Papers', 'Ivanhoe', 'Dracula', 'Hamlet', 'Little_Women',
#           'A_Little_Princess', 'A_Tale_of_Two_Cities', 'Around_the_World_in_80_Days', 'Black_Beauty', 'Return_of_the_Native',
#           'The_Rise_of_Silas_Lapham', 'Villette', 'House_of_Mirth', 'The_Three_Musketeers', 'The_Return_of_Sherlock_Holmes',
#           'The_Hound_of_the_Baskervilles', 'White_Fang', 'Treasure_Island', 'Pride_and_Prejudice', 'Mansfield_Park']

# booksum titles; each is expected to be a sub-directory of book_dir holding *.txt files.
fields = ['David_Copperfield', 'Middlemarch', 'Henry_IV_Part_1', 'Adam_Bede', 'Jane_Eyre',
          'The_Pickwick_Papers', 'Ivanhoe', 'Dracula', 'Hamlet', 'Little_Women',
          'A_Little_Princess', 'A_Tale_of_Two_Cities', 'Around_the_World_in_80_Days', 'Black_Beauty', 'Return_of_the_Native',
          'The_Rise_of_Silas_Lapham', 'Villette', 'House_of_Mirth', 'The_Three_Musketeers', 'The_Return_of_Sherlock_Holmes',
          'The_Hound_of_the_Baskervilles', 'White_Fang', 'Treasure_Island', 'Pride_and_Prejudice', 'Mansfield_Park',
          'Babbitt','A_Room_with_a_View','A_Hero_of_Our_Time','A_Vindication_of_the_Rights_of_Woman','An_Enemy_of_the_People',
          'Anne_of_Green_Gables','An_Ideal_Husband','Candide','Don_Juan','Evelina',
          'Frankenstein','Green_Mansions','Heart_of_Darkness','Howards_End','Jude_the_Obscure',
          'Kim','Kidnapped','King_Lear','Little_Dorrit','Madame_Bovary',
          'Main_Street','Major_Barbara','Mary_Barton','The_House_of_the_Seven_Gables','The_Turn_of_the_Screw',
          'A_Christmas_Carol','A_Study_in_Scarlet','Alice_in_Wonderland','All_for_Love','An_Enquiry_Concerning_the_Principles_of_Morals',
          'Arms_and_the_Man','Beowulf','Coriolanus','Cymbeline','Cyrano_De_Bergerac',
          'Daisy_Miller','Dr','Emma','From_the_Earth_to_the_Moon','Ghosts',
          'Hedda_Gabler','Idylls_of_the_King','Incidents_in_the_Life_of_a_Slave_Girl','Julius_Caesar','King_John',
          'Leviathan','Looking_Backward','Macbeth','Maggie_A_Girl_of_the_Streets','Man_and_Superman',
          'Measure_for_Measure','Meditations','Merry_Wives_of_Windsor','My_Antonia','Narrative_of_the_Life_of_Frederick_Douglass__An_American_Slave',
          'News_from_Nowhere','Northanger_Abbey','Notes_from_the_Underground','O_Pioneers','Of_Human_Bondage',
          'Oliver_Twist','On_Liberty','Othello','Paradise_Lost','Persuasion',
          'Phaedra','Portrait_of_a_Lady','Pygmalion','Regeneration','Richard_II',
          'Romeo_and_Juliet','Second_Treatise_of_Government','She_Stoops_to_Conquer','Siddhartha','Sister_Carrie',
          'Sons_and_Lovers','Tartuffe','The_Aeneid','The_Boxcar_Children','The_Communist_Manifesto',
          ]

# Invisible Unicode code points used as the watermark alphabet.
ZWSP = chr(0x200B)  # ZERO WIDTH SPACE
ZWNJ = chr(0x200C)  # ZERO WIDTH NON-JOINER
ZWJ = chr(0x200D)   # ZERO WIDTH JOINER
IT = chr(0x2062)    # INVISIBLE TIMES
IS = chr(0x2063)    # INVISIBLE SEPARATOR
IP = chr(0x2064)    # INVISIBLE PLUS

characters = [ZWSP, ZWNJ, ZWJ, IT, IS, IP]
# characters = [ZWSP, ZWNJ]
WATERMARK_LEN = 10  # symbols per watermark
block_size = 512    # not referenced by the code in this file

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse its helpers) does
    # not re-run the embedding and rewrite the output corpus.
    set_seed(SEED)
    random_embed(fields, characters, WATERMARK_LEN, book_dir, warmup_dir)
    # random_embed_no_tfidf(fields, characters, WATERMARK_LEN, paper_dir, warmup_dir)

