import pickle
import numpy as np
import pandas as pd
import gensim.downloader as api
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# ===== Step 1. Load data =====
# NOTE: pickle.load can execute arbitrary code embedded in the file; these
# pickles must come from a trusted source.
with open('personally_II.pkl', 'rb') as f:
    personally_II = pickle.load(f)
with open('potentially_II.pkl', 'rb') as f:
    potentially_II = pickle.load(f)
# Concatenate both token collections into one list of sensitive tokens
# (assumes both pickles hold lists -- TODO confirm).
sensitive_tokens = personally_II + potentially_II

with open('all_words.pkl', 'rb') as f:
    all_words = pickle.load(f)

# Load GloVe
glove_model = api.load('glove-wiki-gigaword-100')
# Take the dimensionality from the loaded model instead of hard-coding it, so
# OOV vectors always match the model even if the model name above changes.
embedding_dim = glove_model.vector_size

# ===== Step 2. OOV random vector generator =====
# Shared default cache: every OOV token keeps a single vector for the whole run.
_OOV_CACHE = {}


def get_oov_vector(token: str,
                   dim: int = 100,
                   mu: float = 0.0,
                   sigma: float = 0.1,
                   cache: "dict | None" = None,
                   rng = np.random) -> np.ndarray:
    """
    Return a random vector for an out-of-vocabulary (OOV) token, sampling it
    on first use and reusing it afterward.

    Args:
        token: The OOV word; used as the cache key.
        dim: Length of the sampled vector.
        mu: Mean of the normal distribution (default 0.0).
        sigma: Standard deviation of the normal distribution (default 0.1).
        cache: Mapping from token to its vector. When omitted, a module-level
            cache shared by all calls is used, so a token is sampled only once.
        rng: Object exposing ``normal(loc=..., scale=..., size=...)``; defaults
            to NumPy's global random module.

    Returns:
        A float32 array of length ``dim`` drawn from N(mu, sigma^2). The very
        same array object is returned on every later call for that token.

    Note:
        The cache is keyed by ``token`` only: a token that is already cached
        is returned as stored, regardless of the ``dim``/``mu``/``sigma``
        passed on later calls. Results vary between runs unless the caller
        seeds or supplies ``rng``.
    """
    # None sentinel instead of a mutable ``{}`` default; an explicitly passed
    # (even empty) dict is always honoured.
    if cache is None:
        cache = _OOV_CACHE
    if token not in cache:
        vec = rng.normal(loc=mu, scale=sigma, size=dim).astype(np.float32)
        cache[token] = vec
    return cache[token]

# Scalar mean and standard deviation taken over every entry of the GloVe matrix
all_vecs = glove_model.vectors
mu, sigma = all_vecs.mean(), all_vecs.std()

# ===== Step 3. Get embeddings =====
def get_vector(word):
    """
    Look up the embedding for *word*.

    Words present in the GloVe vocabulary return their GloVe vector; any other
    word gets the cached random OOV vector sampled with the module-level
    ``embedding_dim``, ``mu`` and ``sigma``.
    """
    # Guard clause: handle the out-of-vocabulary case first.
    if word not in glove_model.key_to_index:
        return get_oov_vector(word, embedding_dim, mu, sigma)
    return glove_model[word]

# One embedding row per token, in the same order as the token lists.
sens_embeds = np.array([get_vector(w) for w in sensitive_tokens], dtype=np.float32)
vocab_embeds = np.array([get_vector(w) for w in all_words], dtype=np.float32)

# ===== Step 4. Compute cosine similarity =====
print("Computing cosine similarity matrix...")
cos_mat = cosine_similarity(sens_embeds, vocab_embeds)   # shape: (|sensitive|, |all_words|)
# Floating-point round-off can leave similarities marginally outside [-1, 1]
# (e.g. 1.0000001); clamp in place so the scaled values below really lie in
# [0, 1] and the utility loss is never negative.
np.clip(cos_mat, -1.0, 1.0, out=cos_mat)

# ===== Step 5. Scale to [0,1] and compute utility loss =====
scaled_sim = (cos_mat + 1) / 2   # [-1,1] -> [0,1]
utility_loss_mat = 1 - scaled_sim   # 0 = same direction, 1 = opposite direction

# ===== Step 6. Save DataFrame =====
# Rows are labelled by sensitive token, columns by vocabulary word.
df = pd.DataFrame(utility_loss_mat, index=sensitive_tokens, columns=all_words)

output_file = "utility_loss_matrix.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(df, f)