import os
import re
import glob
import random
from datasets import load_dataset, concatenate_datasets
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt_tab')

from itertools import permutations
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import torch

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


# Seed used for the module-level seeding below; random_embed also re-seeds
# with this same value at the start of every split it processes.
SEED = 42
set_seed(SEED)

# Candidate watermark characters, written as Unicode escapes.
ZWSP = "\u200b"  # ZERO WIDTH SPACE
ZWNJ = "\u200c"  # ZERO WIDTH NON-JOINER
ZWJ = "\u200d"   # ZERO WIDTH JOINER
IT = "\u2062"    # INVISIBLE TIMES
IS = "\u2063"    # INVISIBLE SEPARATOR
IP = "\u2064"    # INVISIBLE PLUS
NQSP = "\u2000"  # EN QUAD
MQSP = "\u2001"  # EM QUAD
ENSP = "\u2002"  # EN SPACE
EMSP = "\u2003"  # EM SPACE
# Alternative kept from an earlier configuration:
# ENSP = "\u2004"
# EMSP = "\u2005"

# Full watermark alphabet; a character's index in this list is what gets
# recorded as its encoded form.
characters = [
    ZWSP, ZWNJ, ZWJ,
    IT, IS, IP,
    NQSP, MQSP, ENSP, EMSP,
]
# Reduced alphabet kept from an earlier configuration:
# characters = [ZWSP, ZWNJ]

WATERMARK_LEN = 1  # characters per watermark
P = 1              # fraction of sentences that receive a watermark
num_splits = 10    # number of dataset splits, one watermark each

# Translation table that deletes every watermark character from a string.
REMOVE_TABLE = {ord(ch): None for ch in characters}

def strip_wm(s: str) -> str:
    """Return ``s`` with every character covered by ``REMOVE_TABLE`` deleted."""
    cleaned = s.translate(REMOVE_TABLE)
    return cleaned

# Holds the shared (stop_words, stemmer, lemmatizer) tuple once it is built.
_PREPROCESS_TOOLS = {}


def _preprocess_tools():
    """Return the shared (stop-word set, stemmer, lemmatizer), built on first use."""
    if 'tools' not in _PREPROCESS_TOOLS:
        # Build everything before storing so a failure leaves the cache empty
        # and the next call simply retries.
        _PREPROCESS_TOOLS['tools'] = (
            frozenset(stopwords.words('english')),
            PorterStemmer(),
            WordNetLemmatizer(),
        )
    return _PREPROCESS_TOOLS['tools']


def preprocess(text):
    """Normalize ``text`` for TF-IDF scoring.

    The text is lower-cased and tokenized with ``word_tokenize``; English stop
    words are dropped; each remaining token is stemmed with ``PorterStemmer``
    and then lemmatized with ``WordNetLemmatizer``.

    The stop-word set, stemmer and lemmatizer are created once and reused,
    because this function is called for every sentence and rebuilding them on
    each call repeats the same work.

    Args:
        text: Input string.

    Returns:
        The processed tokens joined by single spaces.
    """
    stop_words, stemmer, lemmatizer = _preprocess_tools()

    # Tokenization and lower casing
    tokens = word_tokenize(text.lower())

    # Stop-word removal, then stemming, then lemmatization, token by token
    return ' '.join(
        lemmatizer.lemmatize(stemmer.stem(token))
        for token in tokens
        if token not in stop_words
    )


def random_generate(characters, watermark_length, duplicate_check):
    """Draw a random watermark whose encoding is not yet in ``duplicate_check``.

    Characters are drawn with ``random.randint`` until the resulting encoding
    is new; the encoding is then added to ``duplicate_check``.

    NOTE: the encoding concatenates the decimal indices without a separator,
    so with more than 10 characters different index sequences can share one
    encoding (e.g. indices 1,10 and 11,0 both give "110"). In that situation
    the exhaustion check below may not trigger.

    Args:
        characters: Sequence of candidate watermark characters.
        watermark_length: Number of characters per watermark.
        duplicate_check: Set of encodings already handed out; updated in place.

    Returns:
        Tuple ``(duplicate_check, random_watermark, encoded_watermark)``, where
        ``encoded_watermark`` is the concatenation of the indices (into
        ``characters``) of the drawn characters.

    Raises:
        ValueError: If every possible encoding is already present in
            ``duplicate_check``, so no new watermark can be produced (the
            rejection loop would otherwise never terminate).
    """
    # Local import: the module itself only imports `permutations` from itertools.
    from itertools import product

    n_chars = len(characters)
    length = max(watermark_length, 0)

    # Only enumerate all encodings when the set is already at least that big,
    # which keeps the check no more expensive than the set it inspects.
    if len(duplicate_check) >= n_chars ** length:
        every_code = {
            ''.join(map(str, combo))
            for combo in product(range(n_chars), repeat=length)
        }
        if every_code.issubset(duplicate_check):
            raise ValueError(
                f"All {len(every_code)} watermarks of length {watermark_length} "
                f"over {n_chars} characters are already in use."
            )

    # Rejection sampling: redraw until the encoding has not been seen before.
    while True:
        indices = [random.randint(0, n_chars - 1) for _ in range(watermark_length)]
        encoded_watermark = ''.join(str(i) for i in indices)
        if encoded_watermark not in duplicate_check:
            duplicate_check.add(encoded_watermark)
            random_watermark = ''.join(characters[i] for i in indices)
            return duplicate_check, random_watermark, encoded_watermark

def pick_text_fields(dataset, user_fields=None):
    """Decide which columns of ``dataset`` should be watermarked.

    An explicit ``user_fields`` value (anything other than ``None``) is
    returned unchanged. Otherwise ``['text']`` is returned when a ``'text'``
    column exists and its value in the first row is a string; failing that,
    every column whose first-row value is a string is returned.

    Raises:
        ValueError: If no column holds a string in the first row.
    """
    if user_fields is not None:
        return user_fields

    def _first_value_is_str(column):
        return isinstance(dataset[0][column], str)

    column_names = dataset.column_names
    # A string-valued 'text' column wins over everything else.
    if 'text' in column_names and _first_value_is_str('text'):
        return ['text']

    string_columns = list(filter(_first_value_is_str, column_names))
    if string_columns:
        return string_columns
    raise ValueError("No string columns found. Please set text_fields explicitly.")


def _insert_watermark(sentence, watermark, space_pat):
    """Return ``sentence`` with ``watermark`` inserted at a random boundary.

    Existing watermark characters are removed first via ``strip_wm``. Candidate
    insertion offsets are the positions just after each match of ``space_pat``,
    plus the start and the end of the sentence; one is drawn with
    ``random.choice``.
    """
    clean = strip_wm(sentence)
    candidates = [m.start() + 1 for m in space_pat.finditer(clean)]
    # allow start/end
    candidates.extend([0, len(clean)])
    position = random.choice(candidates)
    # special handling for '/' or '\': put a plain space before the watermark
    separator = ' ' if position > 0 and clean[position - 1] in ('/', '\\') else ''
    return clean[:position] + separator + watermark + clean[position:]


def random_embed(characters, watermark_length, dataset, output_dir,
                 P=0.2, num_splits=10, text_fields=None, output_suffix=""):
    """Split ``dataset`` and embed one distinct watermark into each split.

    The dataset is cut into ``num_splits`` contiguous splits (the last one
    takes the remainder). For every split the RNGs are re-seeded with ``SEED``,
    a watermark not used by an earlier split is generated, the selected text
    columns are sentence-tokenized, sentences are ranked by their summed TF-IDF
    weight, and the watermark is inserted into the top ``P`` fraction of
    sentences (at least one). Each processed column ``f`` is removed and
    re-added as ``f + output_suffix`` with the sentences re-joined by single
    spaces.

    Files written under ``output_dir``:
        * ``aggregated/embedded_watermarks.txt`` - one ``"<split_id> <code>"``
          line per split; only written when the file did not exist at the
          start of the call.
        * ``aggregated/<split_id>/qa.txt`` - one answer per line, written when
          the split has both ``question`` and ``answer`` columns.
        * ``hf_dataset`` - the concatenated splits saved with ``save_to_disk``
          (an ``answer`` column, if present, is renamed to ``answer_split``).

    Args:
        characters: Candidate watermark characters.
        watermark_length: Number of characters per watermark.
        dataset: Dataset to process; must support ``len``, ``select``,
            ``column_names`` and column access by name.
        output_dir: Destination directory (created if missing).
        P: Fraction of sentences per split that receive the watermark.
        num_splits: Number of splits; must be between 1 and ``len(dataset)``.
        text_fields: Columns to watermark, or ``None`` to auto-detect with
            ``pick_text_fields``.
        output_suffix: Suffix appended to the name of each rewritten column.

    Raises:
        ValueError: If ``num_splits`` is not in ``1..len(dataset)``. A larger
            value would create empty splits, which cannot be TF-IDF scored.
    """
    if num_splits < 1 or num_splits > len(dataset):
        raise ValueError(
            f"num_splits must be between 1 and len(dataset)={len(dataset)}, "
            f"got {num_splits}."
        )

    split_size = len(dataset) // num_splits
    splits = [dataset.select(range(i * split_size, (i + 1) * split_size))
              for i in range(num_splits - 1)]
    splits.append(dataset.select(range((num_splits - 1) * split_size, len(dataset))))

    fields = pick_text_fields(dataset, text_fields)

    duplicate_check = set()

    # A single whitespace character sitting between two non-whitespace characters.
    space_pat = re.compile(r'(?<=\S)\s(?=\S)')

    # Build every path with os.path.join so the result does not depend on
    # output_dir ending with a path separator.
    aggregated_dir = os.path.join(output_dir, "aggregated")
    os.makedirs(aggregated_dir, exist_ok=True)
    watermark_log = os.path.join(aggregated_dir, "embedded_watermarks.txt")
    # Never append to a log left behind by an earlier run.
    write_watermarks = not os.path.exists(watermark_log)

    for split_id, data_owner in enumerate(splits):

        set_seed(SEED)

        # Generate a random watermark for this data_owner
        duplicate_check, watermark, encoded_watermark = random_generate(
            characters, watermark_length, duplicate_check)

        if write_watermarks:
            with open(watermark_log, 'a', encoding='utf-8') as fout:
                fout.write(f"{split_id} {encoded_watermark}\n")

        # corpus[field_idx][sample_idx] is the list of sentences of one sample.
        corpus = [
            [nltk.sent_tokenize(sample) for sample in data_owner[field]]
            for field in fields
        ]

        # Flatten to positions & elements (preprocessed) for TF-IDF
        positions = []   # (field_idx, sample_idx, sent_idx)
        elements = []    # preprocessed text for TF-IDF
        for field_idx, column in enumerate(corpus):
            for sample_idx, sentences in enumerate(column):
                for sent_idx, sent in enumerate(sentences):
                    positions.append((field_idx, sample_idx, sent_idx))
                    elements.append(preprocess(sent))

        vectorizer = TfidfVectorizer()
        X = vectorizer.fit_transform(elements)
        # Sum TF-IDF per sentence and select the top P fraction (at least one)
        X_sum = np.asarray(X.sum(axis=1)).ravel()
        k = max(1, int(np.ceil(len(X_sum) * P)))
        # indices sorted by descending score
        top_idx = np.argsort(-X_sum)[:k]

        # Embed the watermark at a random whitespace boundary (or start/end)
        for idx in top_idx:
            field_idx, sample_idx, sent_idx = positions[idx]
            corpus[field_idx][sample_idx][sent_idx] = _insert_watermark(
                corpus[field_idx][sample_idx][sent_idx], watermark, space_pat)

        # Replace each processed column with its watermarked version.
        embedded_split = data_owner
        for field_idx, field in enumerate(fields):
            samples = [' '.join(sentences) for sentences in corpus[field_idx]]
            if field in embedded_split.column_names:
                embedded_split = embedded_split.remove_columns(field)
            embedded_split = embedded_split.add_column(field + output_suffix, samples)

        # Put the updated split back
        splits[split_id] = embedded_split

        split_dir = os.path.join(aggregated_dir, str(split_id))
        os.makedirs(split_dir, exist_ok=True)

        # For QA datasets, also write out a text file with one answer per line
        q_col = "question"
        a_col = "answer"
        if q_col in embedded_split.column_names and a_col in embedded_split.column_names:
            per_split_txt = os.path.join(split_dir, "qa.txt")
            # utf-8 is required: the text contains non-ASCII watermark characters.
            with open(per_split_txt, "w", encoding="utf-8") as qa_file:
                for answer in embedded_split[a_col]:
                    # Trim ASCII whitespace only. A bare .strip() would also
                    # remove U+2000-U+2003, which are watermark characters and
                    # may sit at the very start or end of the answer.
                    a_clean = (answer or "").replace("\n", " ").strip(" \t\r\n\f\v")
                    qa_file.write(f"{a_clean}\n")
        else:
            # Report which of the QA columns this split lacks
            missing = [c for c in (q_col, a_col) if c not in embedded_split.column_names]
            print(f"[split {split_id}] Missing columns for QA txt: {missing}")

    embedded = concatenate_datasets(splits)
    # Only rename when the column exists, so non-QA datasets can be saved too.
    if "answer" in embedded.column_names:
        embedded = embedded.rename_column("answer", "answer_split")
    embedded.save_to_disk(os.path.join(output_dir, "hf_dataset"))

# Source data: load the dataset and keep the entry stored under `dataset_key`.
dataset_dir = 'locuslab/TOFU'
dataset_split = 'full'
dataset_key = 'train'
dataset = load_dataset(dataset_dir, dataset_split)[dataset_key]

# Columns to watermark and where the results are written.
watermark_text_fields = ['question', 'answer']
output_dir = './data/wasa_embedded/qwen_tofu/'

random_embed(
    characters,
    WATERMARK_LEN,
    dataset,
    output_dir,
    P=P,
    num_splits=num_splits,
    text_fields=watermark_text_fields,
)

