import pandas as pd
import numpy as np
import pickle
import gensim.downloader as api
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import gc
import warnings
import Transformer
import RNN
import LSTM
from sklearn.metrics.pairwise import cosine_similarity
from scipy.optimize import minimize
from scipy.special import logsumexp
from scipy.spatial.distance import cdist, cosine
from dataclasses import dataclass
from typing import List
from collections import defaultdict, Counter
import sys

# ================================
# 0. Set up log-file writing (tee stdout to the console and a file)
# ================================

class Tee:
    """File-like fan-out writer that mirrors every write to several streams.

    Each stream is flushed right after it receives text, so a log file stays in
    step with the console output.
    """

    def __init__(self, *files):
        # The wrapped streams, kept as the tuple they were passed in.
        self.files = files

    @staticmethod
    def _emit(stream, text):
        """Write *text* to one stream and flush it immediately."""
        stream.write(text)
        stream.flush()

    def write(self, text):
        """Send *text* to every wrapped stream, in order."""
        for stream in self.files:
            self._emit(stream, text)

    def flush(self):
        """Flush every wrapped stream."""
        for stream in self.files:
            stream.flush()


# ================================
# 1. Define the main optimizer class
# ================================

class ExpectedUtilityLossComputer:
    """Expected utility loss of the two-channel (PII / PoII) replacement mechanism.

    A sensitive word x scheduled for replacement is mapped to a candidate y with
    probability P(y|x) = softmax_y(-epsilon * d(x, y) / 2), and the replacement
    costs U(x, y). Each channel's loss is the prior-weighted expectation of U
    over the sensitive words routed to that channel. The total is normalised by
    the largest U entries of the words involved (see ``normalize_mode``).

    Args:
        utility_matrix: DataFrame of costs U (rows = sensitive words,
            columns = candidate replacements). ``None`` disables the computation.
        sensitive_tokens: container of tokens that count as sensitive.
        use_real_priors: weight words by their observed count in the window
            (True) or give each distinct word weight 1 (False).
        normalize_mode: normalisation of the total loss; one of
            ``"fixed_two_channels"`` (default), ``"active_types"``, ``"weighted"``.

    Raises:
        ValueError: if ``normalize_mode`` is not a supported mode.
    """

    # Returned whenever the loss cannot be computed (missing inputs or an error).
    # It has the same 4-tuple shape as a successful result, so callers can always
    # unpack it.
    _ZERO_RESULT = (0.0, 0.0, 0.0, 0.0)
    _NORMALIZE_MODES = ("fixed_two_channels", "active_types", "weighted")

    def __init__(self, utility_matrix=None,
                 sensitive_tokens=None, use_real_priors=True,
                 normalize_mode="fixed_two_channels"):
        if normalize_mode not in self._NORMALIZE_MODES:
            raise ValueError(
                f"Unknown normalize_mode {normalize_mode!r}; "
                f"expected one of {self._NORMALIZE_MODES}")
        self.utility_matrix = utility_matrix
        self.sensitive_tokens = sensitive_tokens
        self.use_real_priors = use_real_priors
        self.normalize_mode = normalize_mode

    def compute_utility_loss(self, alpha1: float, alpha2: float,
                             window_sentences=None, window_replacement_info=None,
                             distance_df=None, base_epsilon1=1.0, base_epsilon2=1.0
                             ) -> tuple[float, float, float, float]:
        """Expected utility loss of both channels for one window.

        Effective budgets are ``epsilon_1 = base_epsilon1 * alpha1`` for
        replacement type 1 (PII) and ``epsilon_2 = base_epsilon2 * alpha2``
        for replacement type 2 (PoII).

        Args:
            alpha1, alpha2: multipliers applied to the base budgets.
            window_sentences: sequence of token lists.
            window_replacement_info: one mapping per sentence from token index
                to replacement type (1 = PII, 2 = PoII).
            distance_df: DataFrame of d(x, y) with sensitive words as rows.
            base_epsilon1, base_epsilon2: base privacy budgets.

        Returns:
            ``(loss_pii, loss_poii, loss_total, loss_total_normalised)``.
            Returns ``(0.0, 0.0, 0.0, 0.0)`` when the utility matrix or any
            window input is missing, or when the computation raises. In the
            latter case the error is printed.
        """
        if (self.utility_matrix is None or window_sentences is None or
                window_replacement_info is None or distance_df is None):
            return self._ZERO_RESULT

        try:
            # Effective epsilon per channel.
            epsilon_1 = base_epsilon1 * alpha1
            epsilon_2 = base_epsilon2 * alpha2

            utility_loss_pii = self.calculate_utility_loss_with_priors_by_type(
                window_sentences, window_replacement_info,
                epsilon_1, distance_df, self.sensitive_tokens, self.utility_matrix,
                replacement_type=1, use_real_priors=self.use_real_priors
            )
            utility_loss_poii = self.calculate_utility_loss_with_priors_by_type(
                window_sentences, window_replacement_info,
                epsilon_2, distance_df, self.sensitive_tokens, self.utility_matrix,
                replacement_type=2, use_real_priors=self.use_real_priors
            )

            total_utility_loss = utility_loss_pii + utility_loss_poii
            denom = self._normalization_denominator(
                window_sentences, window_replacement_info, distance_df,
                utility_loss_pii, utility_loss_poii)
            total_utility_loss_norm = total_utility_loss / denom

            return (utility_loss_pii, utility_loss_poii,
                    total_utility_loss, total_utility_loss_norm)

        except Exception as e:
            # Best-effort: report and fall back to a zero loss of the right shape.
            print(f"Error computing utility loss: {e}")
            return self._ZERO_RESULT

    def _normalization_denominator(self, window_sentences, window_replacement_info,
                                   distance_df, utility_loss_pii, utility_loss_poii):
        """Return the denominator used to normalise the total utility loss.

        It is built from the maximum utility cost of the PII / PoII words used
        in the window (``_get_norm_rows_for``), combined per ``normalize_mode``.
        A 1e-12 guard keeps it strictly positive.
        """
        pii_words_used, poii_words_used = set(), set()
        pii_count, poii_count = 0, 0
        for sent, reps in zip(window_sentences, window_replacement_info):
            for idx, rep_type in reps.items():
                tok = sent[idx]
                if rep_type == 1 and tok in self.sensitive_tokens and tok in distance_df.index:
                    pii_words_used.add(tok)
                    pii_count += 1
                elif rep_type == 2 and tok in self.sensitive_tokens and tok in distance_df.index:
                    poii_words_used.add(tok)
                    poii_count += 1

        norm_pii = self._get_norm_rows_for(list(pii_words_used))
        norm_poii = self._get_norm_rows_for(list(poii_words_used))

        # getattr keeps instances created without __init__ (e.g. unpickled) working.
        mode = getattr(self, "normalize_mode", "fixed_two_channels")

        if mode == "fixed_two_channels":
            # Always scale by both channels' max costs (a missing channel adds 0).
            return (norm_pii + norm_poii) + 1e-12

        if mode == "active_types":
            # Only channels that actually produced a positive loss contribute.
            denom = 1e-12
            if utility_loss_pii > 0:
                denom += max(norm_pii, 1e-12)
            if utility_loss_poii > 0:
                denom += max(norm_poii, 1e-12)
            return denom

        if mode == "weighted":
            # Mix channel scales by how many tokens each channel handled.
            total_cnt = pii_count + poii_count
            w_pii = pii_count / (total_cnt + 1e-12)
            w_poii = poii_count / (total_cnt + 1e-12)
            return (w_pii * (norm_pii if norm_pii > 0 else 1.0)
                    + w_poii * (norm_poii if norm_poii > 0 else 1.0) + 1e-12)

        raise ValueError(f"Unknown normalize_mode {mode!r}")

    def calculate_utility_loss_with_priors_by_type(self, window_sentences, window_replacement_info,
                                                   epsilon, distance_df, sensitive_tokens, utility_matrix,
                                                   replacement_type, use_real_priors=True):
        """Prior-weighted expected utility loss of one replacement channel.

        Vectorised over the sensitive words that actually occur with
        ``replacement_type`` in the window:

            P(y|x)  = softmax_y(-epsilon * d(x, y) / 2)
            E[U|x]  = sum_y P(y|x) * U(x, y)
            result  = sum_x prior(x) * E[U|x]

        Only words present in ``sensitive_tokens``, ``distance_df.index`` and
        ``utility_matrix.index`` are considered. Candidates y are the columns
        shared by ``distance_df`` and ``utility_matrix``.

        Returns:
            float loss; 0.0 when ``utility_matrix`` is None, no word qualifies,
            or the two matrices share no columns.
        """
        if utility_matrix is None:
            return 0.0

        # 1) Sensitive words routed to this channel, with occurrence counts.
        sensitive_word_counts = {}
        total_sensitive_count = 0

        for sentence, replacement_info in zip(window_sentences, window_replacement_info):
            for token_idx, rep_type in replacement_info.items():
                if rep_type != replacement_type:
                    continue
                w = sentence[token_idx]
                # No printing here: this loop is on the optimiser's hot path.
                if (w in sensitive_tokens) and (w in distance_df.index) and (w in utility_matrix.index):
                    if use_real_priors:
                        sensitive_word_counts[w] = sensitive_word_counts.get(w, 0) + 1
                        total_sensitive_count += 1
                    else:
                        # Without real priors each distinct word counts once.
                        sensitive_word_counts[w] = 1

        if not sensitive_word_counts:
            return 0.0

        words = list(sensitive_word_counts.keys())

        # 2) Prior vector aligned with `words`.
        cnts = np.array([sensitive_word_counts[w] for w in words], dtype=np.float64)
        if use_real_priors:
            pri = cnts / (cnts.sum() + 1e-12)
        else:
            # NOTE(review): weights of 1 (not 1/k) make this a sum over distinct
            # words rather than a uniform average — confirm this is intended.
            pri = np.ones_like(cnts, dtype=np.float64)

        # 3) Align candidate columns of the two matrices.
        common_cols = distance_df.columns.intersection(utility_matrix.columns)
        if len(common_cols) == 0:
            return 0.0

        # 4) Row-wise softmax over candidates and expected utility per word.
        dist_mat = distance_df.loc[words, common_cols].to_numpy(dtype=np.float64, copy=False)
        util_mat = utility_matrix.loc[words, common_cols].to_numpy(dtype=np.float64, copy=False)

        logits = (-0.5 * float(epsilon)) * dist_mat                 # (k, Y)
        logits -= logits.max(axis=1, keepdims=True)                 # numerical stability
        expv = np.exp(logits)                                       # (k, Y)
        probs = expv / (expv.sum(axis=1, keepdims=True) + 1e-12)    # (k, Y)

        exp_u = np.sum(probs * util_mat, axis=1)                    # (k,)

        # 5) Prior-weighted sum.
        return float(np.dot(pri, exp_u))

    def _get_norm_rows_for(self, words: List[str]) -> float:
        """Largest utility cost over the rows of *words*, used as a scale.

        Returns 1.0 when there is no utility matrix, the max is non-finite or
        non-positive, or the lookup fails. Returns 0.0 when none of *words* is
        a row of the utility matrix.
        """
        if self.utility_matrix is None:
            return 1.0
        rows = [w for w in words if w in self.utility_matrix.index]
        if not rows:
            return 0.0
        try:
            sub = self.utility_matrix.loc[rows]
            vmax = float(np.nanmax(sub.values))
            return vmax if (np.isfinite(vmax) and vmax > 0) else 1.0
        except Exception:
            return 1.0


class WindowObjective:
    """Callable objective over log-space noise multipliers z = (log a1, log a2).

    Each evaluation computes

        total = lambda1 * violation_ratio + lambda2 * normalised_utility_loss

    Here ``violation_ratio`` is item 6 of ``run_pipeline_window(...)``, and the
    utility term comes from ``ExpectedUtilityLossComputer``. All window data is
    captured at construction. ``train_*_epochs`` / ``base_state_dict*`` start as
    0 / None and are forwarded to ``run_pipeline_window`` on every call.
    Optional ``stage_label`` / ``start_id`` attributes, if set by the caller,
    label the per-evaluation log line.
    """

    # Fallback utility result: same 4-tuple shape as a successful computation.
    _ZERO_UTILITY = (0.0, 0.0, 0.0, 0.0)

    def __init__(self,
                 window_sentences, window_replacement_info, window_original_text,
                 window_distance_df, length_of_window_pii, window_sensitive_words,
                 sensitive_frequencies, window_utility_matrix,
                 base_epsilon1, base_epsilon2,
                 lambda1, lambda2,
                 use_real_priors=True,
                 seed=42):
        self.window_sentences = window_sentences
        self.window_replacement_info = window_replacement_info
        self.window_original_text = window_original_text
        self.window_distance_df = window_distance_df
        self.length_of_window_pii = length_of_window_pii
        self.window_sensitive_words = window_sensitive_words
        self.sensitive_frequencies = sensitive_frequencies
        self.window_utility_matrix = window_utility_matrix
        self.base_epsilon1 = base_epsilon1
        self.base_epsilon2 = base_epsilon2
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.use_real_priors = use_real_priors
        self.seed = seed
        self.train_pert_epochs = 0
        self.train_u_epochs = 0
        self.base_state_dict = None
        self.base_state_dict_u = None
        self.eval_calls = 0

    def _utility_loss(self, alpha1, alpha2):
        """Return ``(loss_pii, loss_poii, loss_total, loss_total_norm)`` for (alpha1, alpha2).

        Always a 4-tuple, so ``__call__`` can unpack it. If the computation
        raises, or yields something other than a 4-item sequence (e.g. a bare
        0.0 fallback), all-zero values are returned. An error is printed, not
        swallowed.
        """
        try:
            computer = ExpectedUtilityLossComputer(
                utility_matrix=self.window_utility_matrix,
                sensitive_tokens=set(self.window_sensitive_words),
                use_real_priors=self.use_real_priors)
            result = computer.compute_utility_loss(
                alpha1, alpha2,
                self.window_sentences, self.window_replacement_info,
                self.window_distance_df, self.base_epsilon1, self.base_epsilon2
            )
        except Exception as exc:
            print(f"    [WindowObjective] utility loss failed: {exc}")
            return self._ZERO_UTILITY
        if isinstance(result, (tuple, list)) and len(result) == 4:
            return tuple(result)
        return self._ZERO_UTILITY

    def __call__(self, z):
        """Evaluate the objective at ``alpha = exp(z)`` clipped to [min_alpha, max_alpha].

        ``min_alpha`` / ``max_alpha`` are module-level settings. The evaluation
        counter is incremented and a one-line summary is printed.
        """
        a1 = float(np.clip(float(np.exp(z[0])), min_alpha, max_alpha))
        a2 = float(np.clip(float(np.exp(z[1])), min_alpha, max_alpha))
        self.eval_calls += 1

        # Item 6 of run_pipeline_window's result is the violation ratio.
        vio_ratio = run_pipeline_window(
            False, False,
            self.window_sentences, self.window_replacement_info, self.window_original_text,
            self.window_distance_df, self.length_of_window_pii, self.window_sensitive_words,
            self.sensitive_frequencies, self.window_utility_matrix,
            a1, a2, self.base_epsilon1, self.base_epsilon2,
            train_pert_epochs=self.train_pert_epochs,
            resume_pert_state=self.base_state_dict,
            train_u_epochs=self.train_u_epochs,
            resume_u_state=self.base_state_dict_u,
            seed=self.seed
        )[6]

        priv = vio_ratio
        exp_ul_pii, exp_ul_poii, exp_ul_all, util_norm = self._utility_loss(a1, a2)
        total = self.lambda1 * priv + self.lambda2 * util_norm

        tag = f"{getattr(self, 'stage_label', '?')}/MS{getattr(self, 'start_id', '-')}"
        print(f"    [{tag} OBJ EVAL #{self.eval_calls}] a1={a1:.4f}, a2={a2:.4f} ->vio={vio_ratio:.6f}, exp_ul = {exp_ul_all:.6f}, priv={priv:.6f}, util={util_norm:.6f}, total={total:.6f}")
        return total

# ================================
# 2. Define utility functions
# ================================

def set_global_random_seed(seed: int = 42, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch global RNGs; optionally force deterministic kernels.

    Args:
        seed: value used for every global RNG.
        deterministic: when True, also switch PyTorch/cuDNN to deterministic
            algorithms (warn-only) and disable TF32 matmul/conv.

    Notes:
        * ``PYTHONHASHSEED`` is exported for child processes only. Hash
          randomisation of the running interpreter is fixed at start-up.
        * Generators from ``np.random.default_rng()`` are independent of the
          legacy global state; code needing a reproducible Generator must pass
          the seed to ``default_rng`` itself.
    """
    import os
    import random

    import numpy as np
    import torch

    # cuBLAS reads this when it creates its workspace, so export it before
    # enabling deterministic algorithms and before any CUDA matmul runs.
    if deterministic:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    os.environ["PYTHONHASHSEED"] = str(seed)

    # 1. Python & NumPy (legacy global RNG)
    random.seed(seed)
    np.random.seed(seed)

    # 2. PyTorch (CPU and all CUDA devices; CUDA seeding is lazy on CPU-only hosts)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.use_deterministic_algorithms(True, warn_only=True)   # PyTorch >= 1.11 for warn_only
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

def perturb_token(emission_probs_df, token):
    """Draw a replacement for *token* from its row of the emission matrix.

    The draw uses the global ``np.random`` state, with the matrix columns as
    candidates and the token's row as probabilities. Tokens absent from the
    index are returned unchanged after printing a notice.
    """
    if token not in emission_probs_df.index:
        print('Token not in the emission probs matrix')
        return token
    candidates = emission_probs_df.columns
    weights = emission_probs_df.loc[token]
    return np.random.choice(candidates, p=weights)

def compute_sensitive_word_frequencies(sentences, sensitive_words):
    """Share of each sensitive word among all sensitive-word occurrences.

    Args:
        sentences: iterable of token sequences.
        sensitive_words: words to report on (one key per distinct word).

    Returns:
        dict word -> count / total sensitive count. Every value is 0.0 when no
        sensitive word occurs.
    """
    tallies = Counter()
    for sent in sentences:
        tallies.update(sent)

    per_word = {word: tallies[word] for word in sensitive_words}
    grand_total = sum(per_word.values())
    if grand_total == 0:
        return dict.fromkeys(sensitive_words, 0.0)
    return {word: n / grand_total for word, n in per_word.items()}

def realised_utility_loss(sentence_tokens, rep_info, pert_tokens, util_df):
    """Mean realised replacement cost U(x, y) over one sentence's replaced positions.

    Args:
        sentence_tokens: original tokens.
        rep_info: mapping whose keys are the replaced token positions.
        pert_tokens: perturbed tokens, aligned with ``sentence_tokens``.
        util_df: DataFrame of costs, rows = original x, columns = replacement y.

    Returns:
        Mean cost over positions whose (x, y) pair is present in ``util_df``,
        or NaN if no pair could be scored. Unscorable pairs print a notice.
    """
    row_labels, col_labels = util_df.index, util_df.columns
    scored = []
    for pos in rep_info:
        original, replaced = sentence_tokens[pos], pert_tokens[pos]
        if original not in row_labels or replaced not in col_labels:
            print('x is not in util_df.index or y is not in util_df.columns')
            continue
        scored.append(util_df.at[original, replaced])
    if not scored:
        return np.nan
    return np.mean(scored)

def uniform_replace_sentence(sentence, sensitive_set):
    """Return a copy of *sentence* with every sensitive token masked as "xxxx".

    Used to build the uniform-replacement baseline for prior estimation.
    """
    placeholder = "xxxx"
    return list(map(lambda tok: placeholder if tok in sensitive_set else tok, sentence))

def get_oov_vector(token: str,
                   dim: int = 100,
                   mu: float = 0,
                   sigma: float = 0,
                   cache: dict = {},
                   rng = np.random) -> np.ndarray:
    """
        generate a random vector for OOV token, with caching to ensure consistency.
    """
    if token not in cache:
        vec = rng.normal(loc=mu, scale=sigma, size=dim).astype(np.float32)
        cache[token] = vec
    return cache[token]

def calculate_similarity_matrix(
    embeddings, sensitive_embeddings, method='gaussian_softmax',
    temperature: float = 1.0, return_logits: bool = False
):
    """Row-stochastic similarity of M query embeddings to D sensitive embeddings.

    Row i is a softmax over the sensitive embeddings s_j for query x_i:

    * ``'gaussian_softmax'``: logits = -||x_i - s_j||^2 / tau
    * ``'cosine_softmax'``:   logits = cos(x_i, s_j) / tau
    * ``'hybrid'``:           logits = (0.7 * -||x_i - s_j||^2 + 0.3 * cos(x_i, s_j)) / tau

    ``tau = max(1e-8, temperature)``. Squared distances are computed with
    ``scipy.spatial.distance.cdist`` in float64.

    Args:
        embeddings: [M, d] query embeddings.
        sensitive_embeddings: [D, d] sensitive-word embeddings.
        method: one of the three names above.
        temperature: softmax temperature tau.
        return_logits: also return the logits.

    Returns:
        probs [M, D], or ``(probs, logits)`` when ``return_logits`` is True.
        The returned logits are already shifted so each row's max is 0, which
        leaves the softmax unchanged.

    Raises:
        ValueError: unknown ``method``.
    """
    if method not in ('gaussian_softmax', 'cosine_softmax', 'hybrid'):
        raise ValueError(f"Unknown method: {method}")

    tau = max(1e-8, float(temperature))

    if method == 'gaussian_softmax':
        logits = -cdist(embeddings, sensitive_embeddings, 'sqeuclidean') / tau

    elif method == 'cosine_softmax':
        logits = cosine_similarity(embeddings, sensitive_embeddings) / tau

    else:  # 'hybrid'
        unit_x = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
        unit_s = sensitive_embeddings / (np.linalg.norm(sensitive_embeddings, axis=1, keepdims=True) + 1e-8)
        cosine_sim = np.dot(unit_x, unit_s.T)                                  # [M, D]
        gauss_logits = -cdist(embeddings, sensitive_embeddings, 'sqeuclidean')  # [M, D]
        logits = (0.7 * gauss_logits + 0.3 * cosine_sim) / tau

    # Numerically stable row softmax.
    logits = logits - logits.max(axis=1, keepdims=True)
    expv = np.exp(logits)
    probs = expv / expv.sum(axis=1, keepdims=True)
    return (probs, logits) if return_logits else probs

def _softmax_rows(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=1, keepdims=True)
    expv = np.exp(z)
    return expv / expv.sum(axis=1, keepdims=True)

def _fit_temperature_from_logits(logits: np.ndarray, labels: np.ndarray,
                                 tau_grid=(0.5, 5.0), num_points: int = 25) -> float:
    """
    Fit the temperature tau on the validation set using a 1D grid search to minimize NLL 
    (numerically stable, simple, and reliable).
    logits: [N, D] unscaled logits (e.g., -dist^2)
    labels: [N], true class indices (indices in window_sensitive_words)
    tau_grid: search range (can be adjusted to (0.3, 10) if needed)
    """
    if logits.size == 0 or labels.size == 0:
        return 1.0
    taus = np.linspace(tau_grid[0], tau_grid[1], num_points)
    best_tau, best_nll = 1.0, float('inf')
    for tau in taus:
        scaled = logits / max(1e-8, tau)
        # NLL = -sum(log p_true)
        lse = logsumexp(scaled, axis=1)                  # [N]
        nll = (lse - scaled[np.arange(scaled.shape[0]), labels]).mean()
        if nll < best_nll:
            best_nll = nll
            best_tau = float(tau)
    return best_tau

def _collect_val_logits_labels(sentences_subset, predicted_embeds_subset,
                               sens_embeds: np.ndarray,
                               sensitive_tokens_set: set,
                               word2idx: dict):
    """
    Collect (logits, label) pairs from the validation split:
    logits_i = -||pred - sens_j||^2 (no temperature scaling yet)
    label_i  = index of the ground-truth sensitive token in window_sensitive_words
    """
    logits_list, labels_list = [], []
    for sent_tokens, sent_pred in zip(sentences_subset, predicted_embeds_subset):
        if len(sent_tokens) != len(sent_pred):
            continue
        for t_idx, tok in enumerate(sent_tokens):
            if tok in sensitive_tokens_set and tok in word2idx:
                pred_vec = np.asarray(sent_pred[t_idx])
                diff = sens_embeds - pred_vec
                dist2 = np.sum(diff * diff, axis=1)    # [D]
                logits_list.append(-dist2)
                labels_list.append(word2idx[tok])
    if not logits_list:
        return np.empty((0, sens_embeds.shape[0])), np.empty((0,), dtype=int)
    return np.stack(logits_list, axis=0), np.array(labels_list, dtype=int)

def sample_sensitive_keys(window_sentences_for_prior,
                          sensitive_tokens_set,
                          sample_ratio: float | None = 0.1, 
                          max_keys: int | None = None,
                          seed: int = 42,
                          stratify: bool = False,
                          pii_set=None, poii_set=None):
    """
      sampled_keys: List[(s_idx, t_idx)]
      keys_by_sentence: Dict[s_idx -> Set[t_idx]]
    """
    rng = np.random.default_rng(seed)

    all_keys = []
    if stratify:
        if pii_set is None or poii_set is None:
            raise ValueError("stratify=True, need pii_set and poii_set")
        keys_pii, keys_poii = [], []
        for s_idx, sent in enumerate(window_sentences_for_prior):
            for t_idx, tok in enumerate(sent):
                if tok in sensitive_tokens_set:
                    if tok in pii_set:
                        keys_pii.append((s_idx, t_idx))
                    elif tok in poii_set:
                        keys_poii.append((s_idx, t_idx))
        def _take(cands):
            if not cands: return []
            if sample_ratio is None and max_keys is None:
                return cands
            k = len(cands)
            if sample_ratio is not None:
                k = max(1, int(k * sample_ratio))
            if max_keys is not None:
                k = min(k, max_keys)
            k = min(k, len(cands))
            return rng.choice(cands, size=k, replace=False).tolist()
        sampled_keys = _take(keys_pii) + _take(keys_poii)
    else:
        for s_idx, sent in enumerate(window_sentences_for_prior):
            for t_idx, tok in enumerate(sent):
                if tok in sensitive_tokens_set:
                    all_keys.append((s_idx, t_idx))
        if (sample_ratio is None and max_keys is None) or not all_keys:
            sampled_keys = all_keys
        else:
            k = len(all_keys)
            if sample_ratio is not None:
                k = max(1, int(k * sample_ratio))
            if max_keys is not None:
                k = min(k, max_keys)
            k = min(k, len(all_keys))
            sampled_keys = rng.choice(all_keys, size=k, replace=False).tolist()

    keys_by_sentence = defaultdict(set)
    for s_idx, t_idx in sampled_keys:
        keys_by_sentence[s_idx].add(t_idx)

    return sampled_keys, keys_by_sentence

def build_unified_pair_indices(D, sample_dims=100, seed=42):
    """All unordered index pairs (i, j) over a (sub)set of ``range(D)``.

    If ``sample_dims`` is not None, ``min(D, sample_dims)`` indices are drawn
    without replacement with ``np.random.default_rng(seed)`` and kept in drawn
    order. Otherwise all D indices are used. Pairs follow the upper-triangle
    order of the selected positions.

    Returns:
        tuple ``(left, right)`` of equal-length integer arrays.
    """
    candidates = np.arange(D)
    if sample_dims is None:
        chosen = candidates
    else:
        generator = np.random.default_rng(seed)
        chosen = generator.choice(candidates, size=min(D, sample_dims), replace=False)

    upper_i, upper_j = np.triu_indices(chosen.size, k=1)
    return chosen[upper_i], chosen[upper_j]

def precompute_pair_distances(sens_embeds, pairs):
    """Median-normalised Euclidean distance between the embeddings of each index pair.

    Each raw distance gets +1e-8 and is divided by (median + 1e-12). A final
    +1e-8 keeps every value strictly positive for later division.
    """
    left, right = pairs
    raw = np.linalg.norm(sens_embeds[left] - sens_embeds[right], axis=1) + 1e-8
    normalised = raw / (np.median(raw) + 1e-12)
    return normalised + 1e-8

def leakage_chunk_vectorized(post_chunk, prior_chunk, chunk_keys,
                             pairs, dist_pairs,
                             base_epsilon1, base_epsilon2,
                             window_sentences_for_prior, personally_II_set, potentially_II_set):
    """Pairwise posterior-vs-prior leakage for a chunk of keys, split by token type.

    For key row i and pair (a, b):

        leak = |(log p_i[a] - log p_i[b]) - (log q_i[a] - log q_i[b])| / dist_pairs

    1e-8 is added inside every log. A key counts as PII if its token is in
    ``personally_II_set``; every other key is treated as PoII. ``sat`` marks
    entries with leak <= base_epsilon1 (PII) or base_epsilon2 (PoII).

    Args:
        post_chunk, prior_chunk: 2-D posterior / prior arrays. Row i is
            attributed to ``chunk_keys[i]`` (assumed aligned — TODO confirm at caller).
        chunk_keys: iterable of (sentence_idx, token_idx).
        pairs: ``(a, b)`` column-index arrays.
        dist_pairs: per-pair divisor, broadcast over rows.
        potentially_II_set: unused; kept for interface compatibility.

    Returns:
        ``(leak1, sat1, leak2, sat2)``: leak arrays of shape (n_pii, n_pairs)
        and (n_poii, n_pairs), plus matching boolean masks. Empty groups have
        zero rows; their masks are boolean too.
    """
    lp = np.log(post_chunk + 1e-8)
    lq = np.log(prior_chunk + 1e-8)

    a, b = pairs
    delta = (lp[:, a] - lp[:, b]) - (lq[:, a] - lq[:, b])
    leak = np.abs(delta) / dist_pairs

    is_pii = np.array(
        [window_sentences_for_prior[s_idx][t_idx] in personally_II_set
         for s_idx, t_idx in chunk_keys],
        dtype=bool,
    )
    # Only rows that have a key are attributed (matches the former per-key loop).
    leak = leak[:is_pii.size]

    leak1 = leak[is_pii]
    leak2 = leak[~is_pii]
    return leak1, leak1 <= base_epsilon1, leak2, leak2 <= base_epsilon2

# ================================
# 3. data loading and initial setup
# ================================

# set parameters
# current_test_epsilon and CHOOSE_DNN are embedded in the log file name below.
current_test_epsilon = 2.5
# Model choice; run_pipeline_window's docstring lists RNN / LSTM / Transformer
# (accepted string values are not visible here — TODO confirm).
CHOOSE_DNN = 'rnn'
SEED = 42 # optional for reproducibility
# Bounds that WindowObjective.__call__ clips alpha1 / alpha2 to.
min_alpha = 0.005
max_alpha = 1.0

# set log file
# From here on, everything printed goes to both the console and the log file
# (Tee flushes after every write). The handle is never closed in the visible
# code; it stays open for the rest of the script.
log_filename = f'{current_test_epsilon}_{CHOOSE_DNN}_AmPL.log'
log_file = open(log_filename, 'w', encoding='utf-8')
original_stdout = sys.stdout
sys.stdout = Tee(original_stdout, log_file)

# 1. Sensitive token sets: PII and PoII
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
# The two collections are concatenated with `+`, so they are presumably lists
# (TODO confirm).
with open('personally_II.pkl', 'rb') as f:
    personally_II = pickle.load(f)
with open('potentially_II.pkl', 'rb') as f:
    potentially_II = pickle.load(f)
sensitive_tokens = personally_II + potentially_II
print(f'Number of Sensitive Tokens: {len(sensitive_tokens)}')
# Set versions for O(1) membership tests in the pipeline functions.
personally_II_set = set(personally_II)
potentially_II_set = set(potentially_II)
sensitive_tokens_set = set(sensitive_tokens)

# 2. Sentences data: concatenate train and test into one dataset
with open('train_sentences.pkl', 'rb') as f:
    train_sentences = pickle.load(f)
with open('test_sentences.pkl', 'rb') as f:
    test_sentences = pickle.load(f)
ori_sentences_all = train_sentences + test_sentences
sentences_all = ori_sentences_all
# Removes only the alias; sentences_all still references the same list.
del ori_sentences_all
print(f'Number of Sentences: {len(sentences_all)}')

# 3. Sensitive token position info per sentence
# Presumably one mapping per sentence from token index to replacement type
# (1 = PII, 2 = PoII), judging by how compute_utility_loss consumes it — verify.
with open('PII_PoII_positions.pkl', 'rb') as f:
    ori_replacement_info = pickle.load(f)
replacement_info = ori_replacement_info
del ori_replacement_info
print(f'Number of Replacement Infos: {len(replacement_info)}')

# 4. Original sentence GloVe embeddings (100-dim) (precomputed and saved)
# One entry per sentence; prepare_window_data flattens these in the same order
# as the sentence tokens.
with open('original_glove_embeddings.pkl', 'rb') as f:
    original_text = pickle.load(f)
print(f'Number of Sentences Embeddings: {len(original_text)}')

# 5. Load GloVe model (100-dim)
glove_model = api.load('glove-wiki-gigaword-100')
embedding_dim = 100

# 6. Precompute mean and std for sampling OOV vectors
# Both are scalars over every entry of every vector (no axis argument).
all_vecs = glove_model.vectors  # shape: (400000, 100) for 100d GloVe
mu = all_vecs.mean()
sigma = all_vecs.std()          # global standard deviation

# NOTE: this truncation is unconditional — every run uses only the first 100
# sentences. Remove these three lines for a full run.
original_text = original_text[:100]
sentences_all = sentences_all[:100]
replacement_info = replacement_info[:100]


# ================================
# 4. main pipeline functions
# ================================

def prepare_window_data(window_sentences, window_original_text):
    """Build a window's distance matrix, utility matrix and sensitive-word metadata.

    Each distinct token of the window gets one embedding: the mean of the
    vectors at all its occurrences. ``window_original_text`` is flattened in
    the same order as ``window_sentences``. The full Euclidean distance matrix
    over the sorted window vocabulary is built, then restricted to the rows of
    sensitive words: PII words first, then PoII words, each sorted. A word in
    both sets counts as PII.

    ``utility_loss_matrix.pkl`` is read from the working directory on every
    call. It is restricted to the same rows and to the window vocabulary as
    columns.

    Returns:
        (filtered_distance_df, filtered_utility_matrix, number_of_pii_words,
         window_sensitive_tokens_set, window_sensitive_words, sensitive_frequencies)
    """
    # Flatten both views; position k of one corresponds to position k of the other.
    flat_vectors = [vec for sentence_vecs in window_original_text for vec in sentence_vecs]
    flat_tokens = [tok for sentence in window_sentences for tok in sentence]

    occurrences = defaultdict(list)
    for position, tok in enumerate(flat_tokens):
        occurrences[tok].append(position)

    tokens = sorted(occurrences)
    mean_vectors = np.array([
        np.mean([flat_vectors[p] for p in occurrences[tok]], axis=0)
        for tok in tokens
    ])

    pairwise = cdist(mean_vectors, mean_vectors, 'euclidean')
    distance_df = pd.DataFrame(pairwise, index=tokens, columns=tokens)
    del mean_vectors, pairwise
    gc.collect()
    print(f"Original distance_df shape: {distance_df.shape}")

    # Split the (already sorted) vocabulary into PII and PoII words.
    pii_words = sorted(tok for tok in tokens if tok in personally_II_set)
    poii_words = sorted(tok for tok in tokens
                        if tok not in personally_II_set and tok in potentially_II_set)
    window_sensitive_tokens_set = set(pii_words) | set(poii_words)
    window_sensitive_words = pii_words + poii_words
    print(f'Length of window all vocabs: {len(tokens)}')
    print(f'Length of window sensitive vocabs: {len(window_sensitive_words)}')

    filtered_distance_df = distance_df.loc[window_sensitive_words]
    del distance_df
    gc.collect()
    print(f"Filtered distance_df shape: {filtered_distance_df.shape}")

    full_utility = pd.read_pickle('utility_loss_matrix.pkl')  # precomputed utility loss matrix
    print(f'Original Shape of Utility Matrix: {full_utility.shape}')
    filtered_utility_matrix = full_utility.loc[window_sensitive_words, tokens]
    del full_utility
    gc.collect()
    print(f"Filtered utility_matrix shape: {filtered_utility_matrix.shape}")

    sensitive_frequencies = compute_sensitive_word_frequencies(window_sentences, window_sensitive_words)
    gc.collect()

    return (filtered_distance_df, filtered_utility_matrix, len(pii_words),
            window_sensitive_tokens_set, window_sensitive_words, sensitive_frequencies)

def _embed_token_list(tokens):
    """Return one embedding per token: its GloVe vector if known to ``glove_model``,
    otherwise ``get_oov_vector(token, embedding_dim, mu, sigma)``.

    Tokens are visited left to right, so if ``get_oov_vector`` draws random numbers
    they are consumed in the same order as a plain sequential loop.
    """
    return [
        glove_model[token] if token in glove_model.key_to_index
        else get_oov_vector(token, embedding_dim, mu, sigma)
        for token in tokens
    ]


def _embed_sentence_list(sentences):
    """Embed every sentence with ``_embed_token_list`` (list of per-sentence lists of vectors)."""
    return [_embed_token_list(sentence) for sentence in sentences]


def _emission_table(distance_df, epsilon):
    """Return the row-normalised table P(y | x) proportional to exp(-epsilon * d(x, y) / 2).

    Rows/columns are those of ``distance_df``; each row sums to 1.
    """
    weights = np.exp(-epsilon * distance_df / 2)
    return weights.divide(weights.sum(axis=1), axis=0)


def _perturb_tagged_positions(sentences, replacement_info, emission_probs, target_type):
    """Return copies of ``sentences`` where every position ``idx`` whose tag in the
    matching ``replacement_info`` dict equals ``target_type`` is resampled with ``perturb_token``.

    The input sentences are not modified. Sentences and replacement dicts are paired
    with ``zip``, so the output is as long as the shorter of the two.
    """
    perturbed_sentences = []
    for sentence, reps in zip(sentences, replacement_info):
        perturbed = list(sentence)
        for idx, rep_type in reps.items():
            if rep_type == target_type:
                perturbed[idx] = perturb_token(emission_probs, sentence[idx])
        perturbed_sentences.append(perturbed)
    return perturbed_sentences


def _bayesian_remap_inplace(perturbed_sentences, replacement_info, utility_matrix,
                            sensitive_frequencies, emission_probs_1, emission_probs_2):
    """Bayesian remap of perturbed sensitive tokens; modifies ``perturbed_sentences`` in place.

    For an observed token y at a position tagged 1 (PII) or 2 (PoII), the posterior
    P(x | y) proportional to P(y | x) P(x) is built from that stage's emission table.
    y is then replaced by the vocabulary word y* minimising sum_x P(x | y) * U(x, y*)
    (``np.argmin``: the first column wins on ties). Observed tokens that are not columns
    of ``utility_matrix`` are left unchanged, as are positions with any other tag.

    The prior P(x) is ``sensitive_frequencies`` aligned to the utility-matrix rows and
    normalised. Words missing from ``sensitive_frequencies`` get zero mass. If the prior
    has no usable mass (zero or non-finite total), no token is changed.
    """
    assert all(utility_matrix.index == list(emission_probs_1.index)), \
        "utility matrix rows must match the emission-table rows"
    assert all(utility_matrix.columns == list(emission_probs_1.columns)), \
        "utility matrix columns must match the emission-table columns"

    utility_mat = utility_matrix.values                       # |X| x |Y|
    vocab_words = list(utility_matrix.columns)
    vocab_to_col = {w: i for i, w in enumerate(vocab_words)}

    # Prior P(x). Without fillna, a word missing from the frequencies becomes NaN,
    # which makes the whole prior NaN and silently disables the remap.
    P_x = pd.Series(sensitive_frequencies, index=list(utility_matrix.index), dtype=float).fillna(0.0)
    P_x_vec = (P_x / P_x.sum()).to_numpy().astype(np.float32)
    P_x_sum = float(P_x_vec.sum())
    if not np.isfinite(P_x_sum) or P_x_sum <= 0:
        # No sensitive-word frequency mass in this window: skip the remap.
        return
    P_x_vec = (P_x_vec / (P_x_sum + 1e-12)).astype(np.float32, copy=False)

    emis_by_type = {
        1: emission_probs_1.values.astype(np.float32),
        2: emission_probs_2.values.astype(np.float32),
    }
    # Marginal P(y) = sum_x P(y | x) P(x), one vector per perturbation stage.
    P_y_by_type = {
        rep_type: (emis.T @ P_x_vec).astype(np.float32, copy=False) + 1e-10
        for rep_type, emis in emis_by_type.items()
    }

    # Observed tokens per stage, so each distinct y is solved only once.
    observed = {1: set(), 2: set()}
    for sent, reps in zip(perturbed_sentences, replacement_info):
        for idx, rep_type in reps.items():
            if rep_type in observed:
                observed[rep_type].add(sent[idx])

    def best_replacement_cache(y_values, emis_arr, P_y_arr):
        cache = {}
        for y in y_values:
            col = vocab_to_col.get(y)
            if col is None:
                continue
            posterior = (emis_arr[:, col] * P_x_vec) / (P_y_arr[col] + 1e-12)   # (|X|,)
            expected_loss = posterior @ utility_mat                            # (|Y|,)
            cache[y] = vocab_words[int(np.argmin(expected_loss))]
        return cache

    caches = {
        rep_type: best_replacement_cache(observed[rep_type], emis_by_type[rep_type], P_y_by_type[rep_type])
        for rep_type in observed
    }

    for sent, reps in zip(perturbed_sentences, replacement_info):
        for idx, rep_type in reps.items():
            cache = caches.get(rep_type)
            if cache is None:
                continue
            new_y = cache.get(sent[idx])
            if new_y is not None:
                sent[idx] = new_y


def _realised_utility_loss_sentence_weighted(original_sentences, perturbed_sentences,
                                             replacement_info, utility_matrix):
    """Mean over sentences of the mean U(x, y) across that sentence's replacement positions.

    x is the original token (a utility-matrix row) and y the perturbed token (a column).
    Positions whose x or y is not in the matrix are ignored, as are sentences with no
    usable position. Returns NaN when no sentence contributes.
    """
    row_of = {w: i for i, w in enumerate(utility_matrix.index)}
    col_of = {w: i for i, w in enumerate(utility_matrix.columns)}
    util_arr = utility_matrix.values
    per_sentence_means = []
    for original, perturbed, reps in zip(original_sentences, perturbed_sentences, replacement_info):
        positions = list(reps.keys())
        if not positions:
            continue
        rows = np.array([row_of.get(original[i], -1) for i in positions], dtype=np.int32)
        cols = np.array([col_of.get(perturbed[i], -1) for i in positions], dtype=np.int32)
        usable = (rows != -1) & (cols != -1)
        if usable.any():
            per_sentence_means.append(util_arr[rows[usable], cols[usable]].mean())
    return float(np.mean(per_sentence_means)) if per_sentence_means else np.nan


def _train_window_model_and_predict(perturbed_embeddings, uniform_embeddings, original_text,
                                    train_end, val_end,
                                    train_pert_epochs, resume_pert_state,
                                    train_u_epochs, resume_u_state):
    """Train the ``CHOOSE_DNN`` model on both input paths and predict on the val/test splits.

    Inputs are the perturbed and the uniform-replaced embeddings. Targets are
    ``original_text`` and are shared by both paths. The split is [:train_end] train,
    [train_end:val_end] validation and [val_end:] test. A resume state for the
    perturbed path takes precedence over the uniform one. Training runs for
    max(train_pert_epochs, train_u_epochs) epochs; with 0 epochs the (possibly
    resumed) model is only switched to eval mode.

    Returns:
        (model, pred_test_perturbed, pred_test_uniform, pred_val_perturbed, pred_val_uniform)

    Raises:
        ValueError: if ``CHOOSE_DNN`` is not 'rnn', 'lstm' or 'transformer'.
    """
    def split(seq):
        return seq[:train_end], seq[train_end:val_end], seq[val_end:]

    def interleave(a, b):
        # a0, b0, a1, b1, ...
        return [item for pair in zip(a, b) for item in pair]

    def make_loader(module, inputs, targets):
        # Deterministic batch order: no shuffling, fixed-seed generator.
        return DataLoader(module.EmbeddingDataset(inputs, targets), batch_size=32, shuffle=False,
                          generator=torch.Generator().manual_seed(66),
                          num_workers=0, pin_memory=False, collate_fn=module.collate_fn)

    def resume_and_build_optimiser(model):
        if resume_state is not None:
            model.load_state_dict(resume_state, strict=True)
        return nn.MSELoss(), torch.optim.Adam(model.parameters(), lr=0.001)

    pert_train, pert_val, pert_test = split(perturbed_embeddings)
    unif_train, unif_val, unif_test = split(uniform_embeddings)
    y_train, y_val, _ = split(original_text)

    resume_state = resume_pert_state if resume_pert_state is not None else resume_u_state
    train_epochs = max(int(train_pert_epochs), int(train_u_epochs))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if CHOOSE_DNN == 'rnn':
        # The RNN sees each perturbed sentence immediately followed by its uniform-replaced twin.
        train_loader = make_loader(RNN, interleave(pert_train, unif_train), interleave(y_train, y_train))
        val_loader = make_loader(RNN, interleave(pert_val, unif_val), interleave(y_val, y_val))
        model = RNN.RNNModel(100, 128, 100).to(device)
        criterion, optimizer = resume_and_build_optimiser(model)
        if train_epochs > 0:
            RNN.train_model(model, train_loader, val_loader, criterion, optimizer, train_epochs,
                            device, patience=5, clip_grad=1.0)
        else:
            model.eval()

        def predict(inputs):
            return RNN.batch_predict(model, inputs, device, batch_size=64, num_workers=0)

    elif CHOOSE_DNN == 'lstm':
        # All perturbed sentences first, then all uniform-replaced ones.
        train_loader = make_loader(LSTM, pert_train + unif_train, y_train + y_train)
        val_loader = make_loader(LSTM, pert_val + unif_val, y_val + y_val)
        # input_size, hidden_size, output_size, num_layers, dropout
        model = LSTM.LSTMModel(100, 128, 100, 2, 0.3).to(device)
        criterion, optimizer = resume_and_build_optimiser(model)
        if train_epochs > 0:
            model, _ = LSTM.train_model(model, train_loader, val_loader, criterion, optimizer,
                                        train_epochs, device, patience=5, clip_grad=1.0)
        else:
            model.eval()

        def predict(inputs):
            return LSTM.batch_predict(model, inputs, device, batch_size=64)

    elif CHOOSE_DNN == 'transformer':
        train_loader = make_loader(Transformer, pert_train + unif_train, y_train + y_train)
        val_loader = make_loader(Transformer, pert_val + unif_val, y_val + y_val)
        model = Transformer.TransformerModel(
            input_size=100,
            d_model=128,
            nhead=8,
            num_encoder_layers=3,
            dim_feedforward=512,
            dropout=0.3
        ).to(device)
        criterion, optimizer = resume_and_build_optimiser(model)
        if train_epochs > 0:
            model, _ = Transformer.train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=train_epochs,
                device=device,
                clip_grad=1.0
            )
        else:
            model.eval()

        def predict(inputs):
            return Transformer.batch_predict(model, inputs, batch_size=64)

    else:
        raise ValueError(
            f"Unknown CHOOSE_DNN: {CHOOSE_DNN!r} (expected 'rnn', 'lstm' or 'transformer')")

    # Same prediction order as before: test perturbed, test uniform, val perturbed, val uniform.
    return (model, predict(pert_test), predict(unif_test), predict(pert_val), predict(unif_val))


def _calibrated_position_probs(sentences, predictions, keys_by_sentence, sens_embeds, tau):
    """Temperature-scaled softmax over the window's sensitive vocabulary for sampled positions.

    For sentence ``s``, only positions in ``keys_by_sentence[s]`` whose original token is in
    ``sensitive_tokens_set`` are scored. Logits come from ``calculate_similarity_matrix``
    ('gaussian_softmax', logits returned) between the predicted embeddings and
    ``sens_embeds``, and are divided by ``max(1e-8, tau)`` before ``_softmax_rows``.

    Returns:
        {(sentence_index, token_index): probability row}
    """
    probabilities = {}
    for s_idx, (sentence_tokens, sentence_embeddings) in enumerate(zip(sentences, predictions)):
        sel_idx_set = keys_by_sentence.get(s_idx)
        if not sel_idx_set:
            continue
        selected_positions = [i for i, tok in enumerate(sentence_tokens)
                              if tok in sensitive_tokens_set and i in sel_idx_set]
        if not selected_positions:
            continue
        batch = np.array(sentence_embeddings)[selected_positions]
        _, logits = calculate_similarity_matrix(
            batch, sens_embeds,
            method='gaussian_softmax', temperature=1.0, return_logits=True
        )
        row_probs = _softmax_rows(logits / max(1e-8, tau))
        for t_idx, prob in zip(selected_positions, row_probs):
            probabilities[(s_idx, t_idx)] = prob
    return probabilities


def run_pipeline_window(print_remap_info, is_final_test, window_sentences, window_replacement_info, window_original_text, 
                        distance_df, length_of_window_pii, window_sensitive_words, sensitive_frequencies, window_utility_matrix,
                        alpha1_value, alpha2_value, base_epsilon1, base_epsilon2,
                        train_pert_epochs:int,
                        resume_pert_state:dict|None,
                        train_u_epochs:int,
                        resume_u_state:dict|None,
                        seed:int = 42):
    """
    Run the perturb -> (optional remap) -> train -> leakage pipeline for one sliding window.

    Steps:
      1. Effective noise: epsilon_1 = base_epsilon1 * alpha1_value, epsilon_2 = base_epsilon2 * alpha2_value.
      2. Two-stage perturbation. Positions tagged 1 (PII) are resampled with epsilon_1, then
         positions tagged 2 (PoII) with epsilon_2. Each uses P(y|x) ~ exp(-eps * d(x, y) / 2)
         over ``distance_df``.
      3. If ``is_final_test``, Bayesian remap of perturbed PII/PoII tokens to the word that
         minimises posterior expected utility loss (``window_utility_matrix``, prior from
         ``sensitive_frequencies``). With ``print_remap_info`` the realised utility loss
         before/after the remap is printed.
      4. Embed the perturbed sentences and the uniform-replaced sentences (GloVe, OOV
         fallback via ``get_oov_vector``).
      5. Train the ``CHOOSE_DNN`` model (RNN / LSTM / Transformer) on a 60/20/20
         train/validation/test split. Both paths are inputs; original text is the target.
      6. Fit softmax temperatures on the validation split. Convert test-split predictions
         into posterior (perturbed path) and prior (uniform path) distributions over
         ``window_sensitive_words``. Measure leakage and privacy satisfaction with
         ``leakage_chunk_vectorized``.

    ``length_of_window_pii`` is accepted for interface compatibility and is not used here.

    Returns (in order):
      privacy_ratio_1 (PII), privacy_ratio_2 (PoII) -- 0 when no position of that type was evaluated,
      epsilon_1, epsilon_2,
      model state_dict (CPU tensors),
      the same state_dict again (one model is trained on both paths and fills both the
        posterior and prior slots),
      violation_ratio = 1 - overall privacy ratio (1 when nothing was evaluated),
      expected_utility_loss_pii, expected_utility_loss_poii, expected_utility_loss_total
        (placeholders, always 0),
      leakage_values_pii list,
      leakage_values_poii list.

    Raises:
      ValueError: if ``CHOOSE_DNN`` is not 'rnn', 'lstm' or 'transformer'.
    """
    # Seed the global RNGs for reproducibility.
    set_global_random_seed(seed)

    # Effective noise levels.
    epsilon_1 = base_epsilon1 * alpha1_value
    epsilon_2 = base_epsilon2 * alpha2_value

    # Two-stage perturbation: PII (tag 1) first, then PoII (tag 2) on top of the stage-1 output.
    emission_probs_1 = _emission_table(distance_df, epsilon_1)
    perturbed_sentences_1 = _perturb_tagged_positions(
        window_sentences, window_replacement_info, emission_probs_1, 1)
    emission_probs_2 = _emission_table(distance_df, epsilon_2)
    perturbed_sentences = _perturb_tagged_positions(
        perturbed_sentences_1, window_replacement_info, emission_probs_2, 2)
    del perturbed_sentences_1
    gc.collect()

    if is_final_test:
        perturbed_sentences_pre_remap = [s.copy() for s in perturbed_sentences]
        _bayesian_remap_inplace(perturbed_sentences, window_replacement_info, window_utility_matrix,
                                sensitive_frequencies, emission_probs_1, emission_probs_2)
        if print_remap_info:
            # Only used for this report, so only computed when it is printed.
            val_pre = _realised_utility_loss_sentence_weighted(
                window_sentences, perturbed_sentences_pre_remap, window_replacement_info, window_utility_matrix)
            val_post = _realised_utility_loss_sentence_weighted(
                window_sentences, perturbed_sentences, window_replacement_info, window_utility_matrix)
            print(f"    Realised utility loss (before remap)    : {val_pre:.4f}")
            print(f"    Realised utility loss (after  remap)    : {val_post:.4f}")
        del perturbed_sentences_pre_remap
    del emission_probs_1, emission_probs_2
    gc.collect()

    # Embeddings for the perturbed path and the uniform-replacement path (same order of
    # uniform_replace_sentence / get_oov_vector calls as a sequential loop).
    perturbed_embeddings = _embed_sentence_list(perturbed_sentences)
    uniform_embeddings = _embed_sentence_list(
        [uniform_replace_sentence(sentence, sensitive_tokens_set) for sentence in window_sentences])

    total = len(perturbed_embeddings)
    train_end = int(total * 0.6)
    val_end = int(total * 0.8)

    (model, predicted_originals, predicted_xxxx,
     predicted_originals_val, predicted_xxxx_val) = _train_window_model_and_predict(
        perturbed_embeddings, uniform_embeddings, window_original_text,
        train_end, val_end,
        train_pert_epochs, resume_pert_state,
        train_u_epochs, resume_u_state)
    # One model serves both paths; it fills both the posterior and the prior slot.
    model_u = model

    # Embeddings of the window's sensitive vocabulary (the candidate set for prior/posterior).
    small_epsilon = 1e-8
    sens_embeds = np.array(_embed_token_list(window_sensitive_words)) + small_epsilon

    # Fit tau_post / tau_prior on the validation split. Labels are the original tokens for
    # both paths; uniform replacement only changes the inputs.
    word2idx = {w: i for i, w in enumerate(window_sensitive_words)}
    val_sentences = window_sentences[train_end:val_end]
    logits_post_val, labels_post_val = _collect_val_logits_labels(
        val_sentences, predicted_originals_val, sens_embeds, sensitive_tokens_set, word2idx
    )
    logits_prior_val, labels_prior_val = _collect_val_logits_labels(
        val_sentences, predicted_xxxx_val, sens_embeds, sensitive_tokens_set, word2idx
    )
    tau_post = _fit_temperature_from_logits(logits_post_val, labels_post_val, tau_grid=(0.5, 5.0), num_points=25)
    tau_prior = _fit_temperature_from_logits(logits_prior_val, labels_prior_val, tau_grid=(0.5, 5.0), num_points=25)

    # Sample the sensitive positions of the test split to evaluate.
    test_sentences = window_sentences[val_end:]
    _, keys_by_sentence = sample_sensitive_keys(
        test_sentences,
        sensitive_tokens_set,
        sample_ratio=1.0,
        max_keys=None,
        seed=seed,
        stratify=False,              # set True (with pii_set / poii_set) for stratified sampling
        pii_set=personally_II_set,
        poii_set=potentially_II_set
    )

    # Prior: uniform-path predictions. Posterior: perturbed-path predictions.
    prior_probabilities = _calibrated_position_probs(
        test_sentences, predicted_xxxx, keys_by_sentence, sens_embeds, tau_prior)
    posterior_probabilities = _calibrated_position_probs(
        test_sentences, predicted_originals, keys_by_sentence, sens_embeds, tau_post)

    # ===== leakage and privacy satisfaction, processed in chunks =====
    keys = list(posterior_probabilities.keys())
    if print_remap_info:
        print(f"  Sampling sensitive positions: {len(keys)}")

    D = sens_embeds.shape[0]
    pairs = build_unified_pair_indices(D, sample_dims=None, seed=seed)
    dist_pairs = precompute_pair_distances(sens_embeds, pairs)

    chunk_size = 1000
    leakage_1, leakage_2 = [], []
    satisfies_privacy_1 = satisfies_privacy_2 = 0
    count_1 = count_2 = 0
    for chunk_start in range(0, len(keys), chunk_size):
        chunk_keys = keys[chunk_start:chunk_start + chunk_size]
        post_chunk = np.vstack([posterior_probabilities[k] for k in chunk_keys])   # (M, D)
        prior_chunk = np.vstack([prior_probabilities[k] for k in chunk_keys])      # (M, D)
        leak1, sat1, leak2, sat2 = leakage_chunk_vectorized(
            post_chunk, prior_chunk, chunk_keys,
            pairs, dist_pairs,
            base_epsilon1, base_epsilon2,
            test_sentences, personally_II_set, potentially_II_set
        )
        satisfies_privacy_1 += sat1.sum()
        count_1 += sat1.size
        satisfies_privacy_2 += sat2.sum()
        count_2 += sat2.size
        leakage_1.extend(leak1.ravel().tolist())
        leakage_2.extend(leak2.ravel().tolist())

    # Overall privacy ratios; 0 when nothing of that kind was evaluated
    # (previously the combined ratio raised ZeroDivisionError in that case).
    privacy_ratio_1 = satisfies_privacy_1 / count_1 if count_1 > 0 else 0
    privacy_ratio_2 = satisfies_privacy_2 / count_2 if count_2 > 0 else 0
    total_count = count_1 + count_2
    privacy_ratio_all = (satisfies_privacy_1 + satisfies_privacy_2) / total_count if total_count > 0 else 0

    # Expected utility loss (placeholders).
    expected_ul_pii_raw = 0
    expected_ul_poii_raw = 0
    expected_ul_total_raw = expected_ul_pii_raw + expected_ul_poii_raw

    return (privacy_ratio_1, privacy_ratio_2,
            epsilon_1, epsilon_2,
            {k: v.cpu() for k, v in model.state_dict().items()},
            {k: v.cpu() for k, v in model_u.state_dict().items()},
            1 - privacy_ratio_all,
            expected_ul_pii_raw,
            expected_ul_poii_raw,
            expected_ul_total_raw,
            leakage_1,
            leakage_2
    )


# ================================
# 5. main
# ================================
# Sliding-window layout: each window spans `window_size` sentences and consecutive
# windows start `slide_size` sentences apart (equal values => non-overlapping windows).
window_size = 100  # sentences per window; shrink for a quick test
slide_size = 100   # stride between consecutive window starts
# Count of whole windows that fit; trailing sentences that cannot fill a full window are not processed.
num_windows = 1 + (len(sentences_all) - window_size) // slide_size

# Both perturbation stages (PII and PoII) start from the same base epsilon for this run.
base_epsilon1 = base_epsilon2 = current_test_epsilon

# main loop over windows
for w in range(num_windows):
    start_idx = w * slide_size
    end_idx = start_idx + window_size
    print(f"\n--- Processing window {w+1}/{num_windows}: sentences {start_idx} to {end_idx} ---")

    window_sentences = sentences_all[start_idx:end_idx]
    window_replacement_info = replacement_info[start_idx:end_idx]
    window_original_text = original_text[start_idx:end_idx]

    window_distance_df, window_utility_matrix, length_of_window_pii, window_sensitive_tokens_set, window_sensitive_words, sensitive_frequencies = prepare_window_data(window_sentences, window_original_text)
    gc.collect()

    # -- (STAGE 0) For each window, perform one warm-up training pass to obtain reusable model weights -- #
    print("  Warm-up (pivot) training (once per window)...")
    obj = WindowObjective(
        window_sentences, window_replacement_info, window_original_text,
        window_distance_df, length_of_window_pii, window_sensitive_words,
        sensitive_frequencies, window_utility_matrix,
        base_epsilon1, base_epsilon2,
        lambda1=1.0, lambda2=0.01,
        use_real_priors=True, seed=SEED
    )
    # Recommended pivot point: midpoint in log-space ~ (sqrt(10), sqrt(10))
    PIVOT_A1 = PIVOT_A2 = float(np.sqrt(min_alpha * max_alpha))
    piv_pii, piv_poii, _, _, state_dict_pert, state_dict_u = run_pipeline_window(
        True, False, 
        window_sentences, window_replacement_info, window_original_text,
        window_distance_df, length_of_window_pii, window_sensitive_words, sensitive_frequencies, window_utility_matrix,
        PIVOT_A1, PIVOT_A2, base_epsilon1, base_epsilon2,
        train_pert_epochs=40,
        resume_pert_state=None,
        train_u_epochs=40,
        resume_u_state=None,
        seed=SEED
    )[:6]
    print(f"  Pivot warm-up done: PII:{piv_pii:.6f}, PoII:{piv_poii:.6f}")
    # Store pivot weights for re-use in Stage 1/2/3
    obj.base_state_dict   = state_dict_pert
    obj.base_state_dict_u = state_dict_u
    obj.train_pert_epochs = 0
    obj.train_u_epochs    = 0
    obj.eval_calls        = 0


    # ---STAGE 1：coarse search---
    grid_vals = list(np.geomspace(min_alpha, max_alpha, num=7))
    coarse_candidates = [(a1, a2) for a1 in grid_vals for a2 in grid_vals]
    coarse_scores = []

    for a1, a2 in coarse_candidates:
        val = obj(np.log([a1, a2]))
        coarse_scores.append((float(val), float(a1), float(a2)))
    best_val, a1_best, a2_best = min(coarse_scores, key=lambda t: t[0])
    print(f"  Stage1 best (frozen on pivot weights): a1={a1_best:.4f}, a2={a2_best:.4f}, total={best_val:.6f}")
    coarse_history = list(obj.history)

    z_bounds = [(np.log(min_alpha), np.log(max_alpha)),
                (np.log(min_alpha), np.log(max_alpha))]
    z0 = np.log(np.array([a1_best, a2_best], dtype=float))


    # ---STAGE 2&3：powell and pattern refine---
    class _Stopper:
        def __init__(self, patience=40, min_delta=1e-5):
            self.patience = patience
            self.min_delta = min_delta
            self.best = float('inf')
            self.best_z = None
            self.noimp = 0
        def update(self, z, val):
            if val < self.best - self.min_delta:
                self.best = float(val)
                self.best_z = np.array(z, dtype=float).copy()
                self.noimp = 0
            else:
                self.noimp += 1
            return self.noimp >= self.patience

    def run_powell(obj, z0, z_bounds, patience=40, min_delta=1e-6):
        """Minimise ``obj`` with SciPy's bounded Powell method plus patience-based early stopping.

        Parameters
        ----------
        obj : callable
            Objective, called as ``obj(z)``; must return a value comparable with floats.
        z0 : array_like
            Starting point passed to ``scipy.optimize.minimize``.
        z_bounds : sequence of (low, high)
            Per-coordinate box bounds forwarded to ``minimize(..., bounds=...)``.
        patience : int
            Stop after this many consecutive evaluations that do not lower the
            best value by more than ``min_delta``.
        min_delta : float
            Minimum decrease that resets the patience counter.

        Returns
        -------
        The best point evaluated during the search (strict-improvement tolerance
        1e-12). If no evaluation produced a usable best (e.g. every value was
        NaN/inf), returns ``res.x`` when SciPy finished normally, else ``z0``.

        Notes
        -----
        A ``RuntimeError`` raised by ``obj`` aborts the search but is NOT
        propagated: it is printed and the best point seen so far is returned
        (deliberate best-effort behaviour).
        """

        class _EarlyStop(Exception):
            """Private signal used to break out of scipy.optimize.minimize.

            A dedicated type (instead of a message-tagged RuntimeError) cannot be
            confused with genuine RuntimeErrors raised by ``obj`` (e.g. torch).
            """

        stopper = _Stopper(patience=patience, min_delta=min_delta)

        best_z = None
        best_val = float('inf')
        res = None  # stays None if minimize() is interrupted

        def f(z):
            nonlocal best_z, best_val
            v = obj(z)
            # Track the true best point independently of the stopper, which only
            # accepts improvements larger than min_delta.
            if v < best_val - 1e-12:
                best_val = float(v)
                best_z = np.array(z, float)
            if stopper.update(z, v):
                raise _EarlyStop
            return v

        try:
            res = minimize(
                f, z0, method="Powell", bounds=z_bounds,
                options=dict(maxiter=200, xtol=1e-3, ftol=1e-7)
            )
            print(f"  Powell done: success={res.success}, msg={res.message}")
        except _EarlyStop:
            print("  Powell done: stopped by custom early stop (patience)")
        except RuntimeError as e:
            print(f"  Powell aborted: {e}")

        # Prefer the best point actually seen over SciPy's final iterate.
        if best_z is not None:
            return best_z
        return res.x if res is not None else z0

    def pattern_refine(z, base_step=0.12, rounds=4, tol=1e-5, step_expand=1.3, step_shrink=0.5, step_max=0.5,
                       bounds=None, objective=None, min_step=1e-4):
        """
        Local 3x3 pattern search with an adaptive step size (Stage 3 refinement).

        Each round evaluates the 8 neighbours ``z + (d1, d2)`` with ``d1, d2`` in
        ``{-step, 0, +step}``. If one of them improves on the current value, it
        also probes a further half step along that neighbour's offset.

        - Improvement larger than ``tol``: move to the best point and enlarge the
          step (``step * step_expand``, capped at ``step_max``).
        - Otherwise: stay put and shrink the step (``step * step_shrink``); stop
          once the step falls below ``min_step``.

        Parameters
        ----------
        z : array_like
            Starting point; the stencil below assumes exactly two coordinates.
        bounds : sequence of two (low, high) pairs, optional
            If given, the start point and every candidate are clipped into this box
            (pass the same ``z_bounds`` used for Powell so refinement cannot leave
            the search range). Candidates that clip onto an already-evaluated point
            are skipped. ``None`` (default) keeps the unbounded behaviour.
        objective : callable, optional
            Called as ``objective(z)`` and ``objective(z, record=False)``. Defaults
            to the enclosing-scope ``obj`` so existing calls are unchanged.
        min_step : float
            Step size below which the search terminates early.

        Returns
        -------
        numpy.ndarray
            The last accepted point (the start point if no round improved by more
            than ``tol``).
        """
        f = obj if objective is None else objective

        if bounds is None:
            def clip(v):
                return v
        else:
            lo, hi = np.asarray(bounds, dtype=float).T

            def clip(v):
                return np.clip(v, lo, hi)

        z = clip(np.asarray(z, dtype=float))
        # The start point is re-evaluated without being recorded in the history.
        prev_val = f(z, record=False)
        step = float(base_step)
        for _ in range(rounds):
            best_val, best_z = prev_val, z
            best_offset = (0.0, 0.0)

            for da1 in (-step, 0.0, step):
                for da2 in (-step, 0.0, step):
                    if da1 == 0.0 and da2 == 0.0:
                        continue
                    zc = clip(z + np.array([da1, da2], dtype=float))
                    if bounds is not None and np.array_equal(zc, z):
                        continue  # clipped back onto the centre: nothing new to evaluate
                    val = f(zc)
                    if val < best_val - 1e-12:
                        best_val, best_z, best_offset = val, zc, (da1, da2)

            # Take an additional half step along the direction of the best neighbouring point
            if best_offset != (0.0, 0.0):
                probe = clip(best_z + 0.5 * np.array(best_offset))
                if bounds is None or not np.array_equal(probe, best_z):
                    probe_val = f(probe)
                    if probe_val < best_val - 1e-12:
                        best_val, best_z = probe_val, probe

            # Adaptive step size
            if best_val < prev_val - tol:
                # Clear improvement -> enlarge step size moderately
                step = min(step * step_expand, step_max)
                z, prev_val = best_z, best_val
            else:
                # No improvement -> shrink step size
                step *= step_shrink
                if step < min_step:
                    break
        return z

    opt_method = "powell_ms"

    if opt_method == "powell_ms":
        # choose top-k from coarse search as multiple starting points
        top_k = 1 # default 1
        coarse_sorted = sorted(coarse_scores, key=lambda t: t[0])
        starts = coarse_sorted[:top_k]  # [(val, a1, a2), ...]
        PER_START_WARMUP = False # default False

        print("  Running Powell (multi-start) ...")
        z_best, v_best = None, float('inf')
        best_hist = None
        pivot_w_pert, pivot_w_u = obj.base_state_dict, obj.base_state_dict_u
        for j, (vj, a1j, a2j) in enumerate(starts, start=1):
            print(f"  Start with ({a1j},{a2j}), total loss={vj}")
            hist_len_before = len(obj.history)
            obj.start_id = j
            if PER_START_WARMUP:
                # optionally do a quick warm-up training at each start point (to re-align weights)
                _, _, _, _, w_pert_j, w_u_j = run_pipeline_window(
                    False, False, 
                    window_sentences, window_replacement_info, window_original_text,
                    window_distance_df, length_of_window_pii, window_sensitive_words, sensitive_frequencies, window_utility_matrix,
                    a1j, a2j, base_epsilon1, base_epsilon2,
                    train_pert_epochs=40, resume_pert_state=None,
                    train_u_epochs=40,   resume_u_state=None,
                    seed=SEED
                )[:6]
                obj.base_state_dict   = w_pert_j
                obj.base_state_dict_u = w_u_j
                obj.train_pert_epochs = 0
                obj.train_u_epochs    = 0

            z0 = np.log([a1j, a2j])
            # Stage2
            z_pow = run_powell(obj, z0, z_bounds, patience=40, min_delta=1e-5)
            # Stage3
            z_ref = pattern_refine(z_pow, base_step=0.12, rounds=4, tol=1e-5)

            vj_final = obj(z_ref)
            print(f"    [MS#{j}] start=({a1j:.3f},{a2j:.3f}) -> val={vj_final:.6f}")
            hist_j = obj.history[hist_len_before:]
            if vj_final < v_best:
                z_best, v_best = z_ref, vj_final
                best_hist = hist_j
            if PER_START_WARMUP:
                obj.base_state_dict   = pivot_w_pert
                obj.base_state_dict_u = pivot_w_u
                obj.train_pert_epochs = 0
                obj.train_u_epochs    = 0

        # z_star = z_best
        print(f"  Multi-start Powell best total={v_best:.6f}")


    final_alpha1, final_alpha2 = float(np.exp(z_best[0])), float(np.exp(z_best[1]))
    print(f"  Stage2&3 best alphas: a1={final_alpha1:.4f}, a2={final_alpha2:.4f}")


    # final evaluate using the best alphas
    print('  With Bayesian Remap......')
    privacy_ratio_1, privacy_ratio_2, eff_epsilon1, eff_epsilon2, _, _, violation_ratio, expected_ul_pii, expected_ul_poii, expected_ul_all, leak1, leak2= run_pipeline_window(
        True, True,
        window_sentences, window_replacement_info, window_original_text,
        window_distance_df, length_of_window_pii, window_sensitive_words, sensitive_frequencies, window_utility_matrix,
        final_alpha1, final_alpha2, base_epsilon1, base_epsilon2,
        train_pert_epochs=40, resume_pert_state=None,
        train_u_epochs=40, resume_u_state=None,
        seed=SEED
    )
    leak_mean = np.mean(leak1 + leak2) 
    leak_std_sample = np.std(leak1 + leak2, ddof=1)
    print(f'  Final@window: leak_mean = {leak_mean}, leak_std = {leak_std_sample}, violation_ratio = {violation_ratio}')

    print('  Without Bayesian Remap......')
    privacy_ratio_1, privacy_ratio_2, eff_epsilon1, eff_epsilon2, _, _, violation_ratio, expected_ul_pii, expected_ul_poii, expected_ul_all, leak1, leak2= run_pipeline_window(
        True, False, 
        window_sentences, window_replacement_info, window_original_text,
        window_distance_df, length_of_window_pii, window_sensitive_words, sensitive_frequencies, window_utility_matrix,
        final_alpha1, final_alpha2, base_epsilon1, base_epsilon2,
        train_pert_epochs=40, resume_pert_state=None,
        train_u_epochs=40, resume_u_state=None,
        seed=SEED
    )
    leak_mean = np.mean(leak1 + leak2) 
    leak_std_sample = np.std(leak1 + leak2, ddof=1)
    print(f'  Final@window: leak_mean = {leak_mean}, leak_std = {leak_std_sample}, violation_ratio = {violation_ratio}')

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Undo the stdout redirection set up earlier in the script. That setup is not
# visible here; presumably sys.stdout was replaced by a Tee over
# original_stdout and log_file -- confirm.
# Order matters: restore the real stdout before closing log_file, so later
# prints do not go through a writer whose file is closed. The final message is
# then printed to the restored stdout only.
# NOTE(review): these statements are not inside a try/finally. An exception
# raised earlier in the script skips them, leaving stdout redirected and
# log_file open.
sys.stdout = original_stdout
log_file.close()
print(f"All print result to: {log_filename}")