# --- Generate train and test sentences---
import pandas as pd
import re
import pickle
import random
from nltk.tokenize import sent_tokenize
from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning
import warnings

# ========== Optional: fix random seed to make split reproducible ==========
random.seed(42)

# ========== Step 1: Load data ==========
csv_path = "PASTE YOUR CSV PATH HERE, YOU CAN DOWNLOAD THE DATASET FROM https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset"
df = pd.read_csv(csv_path)

# Every title and every description becomes its own text (all titles first,
# then all descriptions). Empty cells are read as NaN, and astype(str) would
# turn them into the literal word "nan"; fillna("") maps them to empty strings
# instead, which yield no tokens and are dropped in Step 3.
reviews = (
    df["Title"].fillna("").astype(str).tolist()
    + df["Description"].fillna("").astype(str).tolist()
)

# ========== Utility: normalize possessive forms only (xxx's / xxx') ==========
def normalize_possessive(token: str) -> str:
    """Strip a trailing possessive marker from *token*.

    ``"xxx's"`` becomes ``"xxx"`` and ``"xxx'"`` becomes ``"xxx"``. This only
    happens when the remaining stem contains at least one alphabetic character.
    Pure symbols or digits such as ``"'s"`` or ``"90's"`` are returned unchanged.
    Any token ending in ``'s`` is affected, so contractions such as ``"it's"``
    also become ``"it"``.
    """
    for marker in ("'s", "'"):
        if not token.endswith(marker):
            continue
        stem = token[: -len(marker)]
        if any(map(str.isalpha, stem)):
            return stem
    return token

# ========== Step 2: Preprocess + tokenize ==========
# Word tokenizer, compiled once instead of being looked up for every sentence.
# NOTE(review): the abbreviation alternative (\w+\.\w+\.?) can never match.
# Alternatives are tried left to right, and the first one already matches the
# leading \w+ run, so "u.s." comes out as "u" and "s". It is left unchanged
# because altering it would change the downstream vocabulary.
WORD_PATTERN = re.compile(r"\b(?:\w+(?:'\w+)?|\w+\.\w+\.?)\b")

all_sentences = []

# BeautifulSoup emits MarkupResemblesLocatorWarning when a plain-text input
# looks like a URL or file name. The warning class is imported at the top of
# the file (presumably for this purpose) but was never used. Silence it for
# this loop only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", MarkupResemblesLocatorWarning)

    for text in reviews:
        # Remove HTML tags and entities
        text = BeautifulSoup(str(text), "html.parser").get_text()

        # Turn backslashes into spaces, then lowercase
        text = text.replace("\\", " ").lower()

        # Replace digit runs with a placeholder. This runs *after* lowercasing,
        # and WORD_PATTERN does not capture "<" or ">", so the placeholder comes
        # out of the tokenizer as the bare token "NUMBER".
        text = re.sub(r"\d+", "<NUMBER>", text)

        # Merge the tokens of all sentences of one text into a single list,
        # stripping possessive suffixes from each token.
        words_list = [
            normalize_possessive(w)
            for sentence in sent_tokenize(text)
            for w in WORD_PATTERN.findall(sentence)
        ]
        all_sentences.append(words_list)

# ========== Step 3: Remove number placeholders and empty sublists ==========
# Step 2 lowercases the text *before* substituting "<NUMBER>", and its word
# regex (\w+ ...) does not capture the angle brackets. The placeholder
# therefore arrives here as the bare token "NUMBER", so comparing only against
# "<NUMBER>" removed nothing. Genuine words are all lowercase at this point,
# so dropping the uppercase "NUMBER" cannot remove real text.
NUMBER_PLACEHOLDER_TOKENS = {"<NUMBER>", "NUMBER"}
no_number_list = [
    [token for token in sublist if token not in NUMBER_PLACEHOLDER_TOKENS]
    for sublist in all_sentences
]
non_empty_all_sentences = [sentence for sentence in no_number_list if sentence]

# ========== Step 4: Save full list ==========
with open("sentences_all.pkl", "wb") as f:
    pickle.dump(non_empty_all_sentences, f)

# ========== Step 5: Split into train/test by ratio ==========
# Shuffle row indices with the stdlib RNG (seeded at the top of the script),
# then take the first 80% for training and the remaining 20% for testing.
n_total = len(non_empty_all_sentences)
indices = list(range(n_total))
random.shuffle(indices)

split_idx = int(n_total * 0.8)
train_idx = indices[:split_idx]
test_idx = indices[split_idx:]

train_list = [non_empty_all_sentences[j] for j in train_idx]
test_list = [non_empty_all_sentences[j] for j in test_idx]

for out_name, subset in (("train_sentences.pkl", train_list), ("test_sentences.pkl", test_list)):
    with open(out_name, "wb") as f:
        pickle.dump(subset, f)







# --- Generate personally_II.pkl and potentially_II.pkl files ---
from tqdm import tqdm
import spacy
import torch
import pickle
import gensim.downloader as api
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Models used by this stage: the 100-dimensional 'glove-wiki-gigaword-100'
# vectors loaded via gensim's downloader, and spaCy's 'en_core_web_sm'
# pipeline, whose named entities are what is_sensitive_token inspects.
embedding_dim = 100
glove_model = api.load('glove-wiki-gigaword-100')
nlp = spacy.load('en_core_web_sm')

def is_sensitive_token(token):
    """Return True if spaCy tags *token* as a PERSON, GPE or ORG entity.

    Non-strings and blank strings are never sensitive. Any exception raised
    while running the pipeline is printed and treated as "not sensitive".

    NOTE(review): tokens produced earlier in this file are lowercased single
    words, so NER runs without any context. Confirm that this recall is
    acceptable.
    """
    try:
        if isinstance(token, str) and token.strip():
            return any(ent.label_ in ('PERSON', 'GPE', 'ORG') for ent in nlp(token).ents)
        return False
    except Exception as e:
        print(f"Error processing token '{token}': {e}")
        return False

with open('train_sentences.pkl', 'rb') as f:
    train_sentences = pickle.load(f)
with open('test_sentences.pkl', 'rb') as f:
    test_sentences = pickle.load(f)
sentences = train_sentences + test_sentences

# Build the vocabulary first. is_sensitive_token() runs the full spaCy
# pipeline on every call but only looks at the token string, so classifying
# each unique token once gives the same sets as classifying every occurrence.
# The previous per-occurrence loop re-ran NER for every repeated word.
all_tokens = set()        # Store all unique tokens
for sentence in sentences:
    all_tokens.update(sentence)

sensitive_tokens = {      # Store unique sensitive tokens
    word
    for word in tqdm(all_tokens, desc="Processing unique tokens for PII", total=len(all_tokens))
    if is_sensitive_token(word)
}

all_token_number = len(all_tokens)
sensitive_token_number = len(sensitive_tokens)

with open('all_words.pkl', 'wb') as f:
    pickle.dump(list(all_tokens), f)

# Save sensitive token list to file personally_II.pkl
with open('personally_II.pkl', 'wb') as f:
    pickle.dump(list(sensitive_tokens), f)

# Get remaining non-sensitive tokens
remaining_tokens = all_tokens - sensitive_tokens

# Function to get word embedding
def get_embedding(token):
    """Return the GloVe vector for *token*.

    Tokens missing from ``glove_model``'s vocabulary get an all-zero vector of
    length ``embedding_dim``.
    """
    if token not in glove_model.key_to_index:
        return np.zeros(embedding_dim)
    return glove_model[token]

# Compute embeddings for sensitive and remaining tokens.
# Work on sorted lists (all tokens are str) instead of iterating the sets
# directly. Set iteration order of strings depends on PYTHONHASHSEED, and
# get_embedding gives every OOV token the zero vector, so many tokens tie at
# similarity 0. Together, these made the top-10% cut differ from run to run.
sensitive_tokens_list = sorted(sensitive_tokens)
remaining_tokens_list = sorted(remaining_tokens)
sensitive_embeddings = np.array([get_embedding(token) for token in sensitive_tokens_list])
remaining_embeddings = np.array([get_embedding(token) for token in remaining_tokens_list])

if sensitive_tokens_list and remaining_tokens_list:
    # Cosine similarity of every remaining token to every sensitive token.
    # Each remaining token then keeps its best match.
    cosine_scores = cosine_similarity(remaining_embeddings, sensitive_embeddings)
    max_similarities = cosine_scores.max(axis=1)
else:
    # cosine_similarity rejects empty inputs; nothing can be ranked here.
    cosine_scores = np.zeros((len(remaining_tokens_list), len(sensitive_tokens_list)))
    max_similarities = np.zeros(len(remaining_tokens_list))

# Take the top 10% of remaining tokens. With no sensitive tokens there is
# nothing to be similar to, so none are flagged.
top_10_percent_k = int(0.1 * len(remaining_tokens_list)) if sensitive_tokens_list else 0
# A stable sort breaks ties by (sorted) token order, making the cut reproducible.
top_10_percent_idx = np.argsort(-max_similarities, kind="stable")[:top_10_percent_k]

# Get top 10% most similar tokens
top_similar_tokens = [remaining_tokens_list[idx] for idx in top_10_percent_idx]

# Save top 10% list to file potentially_II.pkl
with open('potentially_II.pkl', 'wb') as f:
    pickle.dump(top_similar_tokens, f)

# Output results
pii_ratio = len(sensitive_tokens) / len(all_tokens) if all_tokens else 0.0
print(f'The number of all tokens is {len(all_tokens)}')
print(f'The number of all detected PII tokens is {len(sensitive_tokens)}')
print(f'The ratio of detected PII is {pii_ratio}')
print(f'The top 10% most similar tokens to sensitive tokens are: {top_similar_tokens}')





# --- Generate PII_PoII_positions.pkl file ---
import pickle

with open('personally_II.pkl', 'rb') as f:
    personally_II = pickle.load(f)
with open('potentially_II.pkl', 'rb') as f:
    potentially_II = pickle.load(f)
sensitive_tokens = personally_II + potentially_II
with open('train_sentences.pkl', 'rb') as f:
    train_sentences = pickle.load(f)
with open('test_sentences.pkl', 'rb') as f:
    test_sentences = pickle.load(f)
sentences = train_sentences + test_sentences

# The pickles hold lists, and `word in some_list` is a linear scan. Doing that
# for every token in the corpus costs O(corpus_tokens * list_length). Hashing
# the lists once makes each lookup O(1) with identical results.
personally_II_set = set(personally_II)
potentially_II_set = set(potentially_II)

# One dict per sentence, mapping token position to 1 (PII) or 2 (potentially
# identifying). A word present in both lists is tagged 1, because PII is
# checked first.
replacement_info = []
for sentence in sentences:
    replacements = {}
    for idx, word in enumerate(sentence):
        if word in personally_II_set:
            replacements[idx] = 1
        elif word in potentially_II_set:
            replacements[idx] = 2
    replacement_info.append(replacements)

with open('PII_PoII_positions.pkl', 'wb') as f:
    pickle.dump(replacement_info, f)





# --- Generate original glove embedding ---
# Memo used by get_oov_vector when the caller does not pass its own cache.
# Every call that relies on the default shares it, so repeated occurrences of
# the same OOV token get the same random vector.
_OOV_VECTOR_CACHE: dict = {}


def get_oov_vector(token: str,
                   dim: int = 100,
                   mu: float = 0,
                   sigma: float = 0,
                   cache: "dict | None" = None,
                   rng = np.random) -> np.ndarray:
    """Return a random, memoized embedding for an out-of-vocabulary token.

    The first call for *token* draws ``dim`` values with
    ``rng.normal(loc=mu, scale=sigma)``, casts them to float32 and stores them
    in *cache*. Later calls for the same token return that stored array.

    Args:
        token: Token to look up. It is the only cache key, so a later call
            with a different ``dim``/``mu``/``sigma`` still gets the first
            vector.
        dim: Vector length.
        mu: Mean of the normal distribution.
        sigma: Standard deviation. With the default 0, every component equals
            ``mu``.
        cache: Dict used for memoization. ``None`` (the default) selects a
            module-level cache shared across calls, rather than hiding that
            state in a mutable default argument.
        rng: Object exposing ``normal(loc, scale, size)``. Defaults to NumPy's
            global legacy RNG, which is not seeded anywhere in the visible
            code, so vectors differ between runs unless the caller passes a
            seeded generator.

    Returns:
        A float32 array of shape ``(dim,)``.
    """
    if cache is None:
        cache = _OOV_VECTOR_CACHE
    if token not in cache:
        cache[token] = rng.normal(loc=mu, scale=sigma, size=dim).astype(np.float32)
    return cache[token]

# The GloVe model is loaded again here, even though the earlier stage also
# loads it; presumably this lets the stage run standalone. The scalar mean and
# std over every component of the embedding matrix parameterize the random
# vectors given to out-of-vocabulary tokens.
glove_model = api.load('glove-wiki-gigaword-100')
embedding_dim = 100
all_vecs = glove_model.vectors
mu, sigma = all_vecs.mean(), all_vecs.std()

with open('train_sentences.pkl', 'rb') as f:
    train_sentences = pickle.load(f)
with open('test_sentences.pkl', 'rb') as f:
    test_sentences = pickle.load(f)
all_sentences = train_sentences + test_sentences

# One list of per-token vectors per sentence, with the same nesting as
# all_sentences. Tokens are visited in order, so the OOV random draws happen
# in the same sequence.
vocab = glove_model.key_to_index
all_embeddings = [
    [
        glove_model[token] if token in vocab else get_oov_vector(token, embedding_dim, mu, sigma)
        for token in sentence
    ]
    for sentence in all_sentences
]

with open('original_glove_embeddings.pkl', 'wb') as f:
    pickle.dump(all_embeddings, f)