import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
import tikzplotlib
from pathlib import Path

# Shared LaTeX/TikZ figure style. The file is looked up in the parent of the
# current working directory, so the script's launch directory matters.
_STYLE_PATH = Path.cwd().parent / "latex_tikz.mplstyle"
plt.style.use(_STYLE_PATH)

# Sentence-embedding model used to encode queries, targets and interferers.
MODEL_NAME = 'all-MiniLM-L6-v2'
# Cut-offs at which recall@k is reported.
K_VALUES = [1, 2, 5, 10]
# Number of true-duplicate (minority) query/target pairs sampled per seed.
MINORITY_SIZE = 100
# Independent sampling seeds; per-seed recall curves are averaged over these.
SEEDS = [42, 101, 999, 2026, 5555]


def _recall_at_ks(sim_matrix, k_values):
    """Return ``{k: recall@k}`` for a query-by-index similarity matrix.

    Row ``i`` holds the similarities of query ``i`` to every index entry, and
    the single correct entry for query ``i`` is column ``i``. The rank of the
    true entry is the number of columns scoring strictly higher than it. A
    query is a hit at ``k`` when that rank is below ``k``.

    Ranks are computed once and shared by every ``k``, instead of fully
    sorting each row for each ``k``. Exact score ties are resolved in favour
    of the true entry.
    """
    n_queries = sim_matrix.shape[0]
    diag = np.arange(n_queries)
    true_scores = sim_matrix[diag, diag]
    ranks = np.count_nonzero(sim_matrix > true_scores[:, None], axis=1)
    return {k: float(np.mean(ranks < k)) for k in k_values}


def run_quora_experiment(seed=42):
    """Run one seed of the Quora interference experiment and return recall curves.

    ``MINORITY_SIZE`` true-duplicate question pairs are sampled. The first
    question of each pair is a query and the second is its target. Both
    questions of up to 150,000 sampled non-duplicate pairs form a shuffled
    pool of majority "interferers". For every ``N_maj`` on a fixed grid, the
    index is the minority targets followed by the first ``N_maj`` pool
    questions. Recall@k of each query retrieving its own target is measured
    for every ``k`` in ``K_VALUES``.

    Parameters
    ----------
    seed : int
        Seed for the minority/majority sampling and the pool shuffle.

    Returns
    -------
    dict
        ``{'n_maj': list[int], 'recall': {k: list[float]}}``, where
        ``recall[k][j]`` is recall@k with ``n_maj[j]`` interferers. The last
        grid point is the full de-duplicated pool, whose size can differ
        slightly between seeds.
    """
    # A private RandomState produces exactly the draws that
    # ``np.random.seed(seed)`` followed by ``np.random.*`` would, without
    # mutating NumPy's process-wide RNG state.
    rng = np.random.RandomState(seed)
    model = SentenceTransformer(MODEL_NAME)

    print(f"Seed {seed}: Loading Quora Question Pairs...")
    dataset = load_dataset("quora", split="train")

    # Minority: sampled true-duplicate pairs, split into (query, target).
    true_duplicates = dataset.filter(lambda x: bool(x['is_duplicate']))
    min_indices = rng.choice(len(true_duplicates), MINORITY_SIZE, replace=False).tolist()
    minority_subset = true_duplicates.select(min_indices)

    min_queries_text = [row['questions']['text'][0] for row in minority_subset]
    min_targets_text = [row['questions']['text'][1] for row in minority_subset]

    # Majority: every question text of the sampled non-duplicate pairs.
    non_duplicates = dataset.filter(lambda x: not x['is_duplicate'])

    maj_sample_size = min(150000, len(non_duplicates))
    maj_indices = rng.choice(len(non_duplicates), maj_sample_size, replace=False).tolist()
    majority_subset = non_duplicates.select(maj_indices)

    maj_docs_text = [text for row in majority_subset for text in row['questions']['text']]

    # De-duplicate in first-seen order. ``list(set(...))`` ordered strings by
    # hash, and string hashes are randomised per process (PYTHONHASHSEED).
    # That made the shuffled pool, and every result, change between runs
    # that used the same seed.
    maj_docs_text = list(dict.fromkeys(maj_docs_text))
    rng.shuffle(maj_docs_text)

    # Keep every minority text out of the interferer pool:
    # - a copy of a target would be a second "correct" answer;
    # - a copy of a query would score ~1.0 against that query and outrank
    #   its true target (leakage, not crowding).
    minority_texts = set(min_queries_text) | set(min_targets_text)
    maj_docs_text = [q for q in maj_docs_text if q not in minority_texts]

    print(f"Seed {seed}: Encoding {len(min_queries_text)} queries and {len(maj_docs_text)} interferers...")

    emb_min_queries = model.encode(min_queries_text, normalize_embeddings=True,
                                   batch_size=64, show_progress_bar=False)
    emb_min_targets = model.encode(min_targets_text, normalize_embeddings=True,
                                   batch_size=64, show_progress_bar=False)
    emb_maj_pool = model.encode(maj_docs_text, normalize_embeddings=True,
                                batch_size=64, show_progress_bar=False)

    total_maj = len(maj_docs_text)
    # Fixed grid, capped at the pool size and always ending at the full pool.
    # The set drops the repeated point when total_maj is itself on the grid.
    grid = (0, 1000, 5000, 10000, 25000, 50000, 100000)
    n_maj_steps = sorted({n for n in grid if n <= total_maj} | {total_maj})

    results = {
        'n_maj': n_maj_steps,
        'recall': {k: [] for k in K_VALUES}
    }

    # At every step the index is [targets, first n_maj interferers], which is
    # a column prefix of one full index. So similarities are computed once
    # and sliced. Target i sits in column i, matching query i.
    n_targets = len(emb_min_targets)
    if total_maj:
        full_index = np.vstack([emb_min_targets, emb_maj_pool])
    else:
        full_index = emb_min_targets
    sim_full = cosine_similarity(emb_min_queries, full_index)

    for n_maj in n_maj_steps:
        recalls = _recall_at_ks(sim_full[:, :n_targets + n_maj], K_VALUES)
        for k in K_VALUES:
            results['recall'][k].append(recalls[k])

    return results


if __name__ == "__main__":
    # scikit-learn (imported above) depends on scipy, so this adds no new install.
    from scipy import stats

    all_results = {k: [] for k in K_VALUES}
    all_n_maj = []

    for seed in tqdm(SEEDS, desc="Running seeds"):
        results = run_quora_experiment(seed)
        all_n_maj.append(results['n_maj'])
        for k in K_VALUES:
            all_results[k].append(results['recall'][k])

    # The final N_maj point (the whole interferer pool) can differ between
    # seeds. Plot the mean N_maj per grid position rather than whichever seed
    # happened to run last. Curves can only be averaged if every seed
    # produced the same number of points.
    if len({len(steps) for steps in all_n_maj}) != 1:
        raise ValueError(f"Seeds produced N_maj grids of different lengths: {all_n_maj}")
    n_maj_mean = np.mean(np.array(all_n_maj, dtype=float), axis=0)

    n_seeds = len(SEEDS)
    # 95% CI of the mean across seeds: sample std (ddof=1) with a Student-t
    # critical value. The normal 1.96 is too narrow for a handful of seeds.
    t_crit = stats.t.ppf(0.975, df=n_seeds - 1) if n_seeds > 1 else 0.0
    mean_recalls = {}
    ci_95s = {}
    for k in K_VALUES:
        arr = np.array(all_results[k])
        mean_recalls[k] = np.mean(arr, axis=0)
        if n_seeds > 1:
            std_recall = np.std(arr, axis=0, ddof=1)
        else:
            std_recall = np.zeros_like(mean_recalls[k])
        ci_95s[k] = t_crit * std_recall / np.sqrt(n_seeds)

    plt.figure(figsize=(10, 6))
    scale = 10 ** 3
    x_vals = n_maj_mean / scale

    colors = ['#ff7f0e', '#2ca02c', '#d62728', '#1f77b4']

    for idx, k in enumerate(K_VALUES):
        color = colors[idx % len(colors)]
        label = f'Recall@{k}'

        plt.plot(x_vals, mean_recalls[k], 'o-', linewidth=2, color=color, label=label)
        # Recall lies in [0, 1]; keep the CI band inside that range.
        plt.fill_between(x_vals,
                         np.clip(mean_recalls[k] - ci_95s[k], 0.0, 1.0),
                         np.clip(mean_recalls[k] + ci_95s[k], 0.0, 1.0),
                         alpha=0.3, color=color)

    plt.xlabel(r"Number of Majority Interferers ($N_{{maj}} \times 10^3$)")
    plt.ylabel("Recall@k")
    plt.text(x_vals[-1] * 0.6, 0.5, "Geometric Crowding Region", fontsize=11, color='darkred')
    plt.ylim(0, 1.05)
    plt.legend()
    plt.tight_layout()

    ax = plt.gca()
    # NOTE(review): the curves are already drawn with a solid '-' style, so
    # this loop looks like a no-op. Presumably it normalises styles before
    # the TikZ export. Confirm before removing.
    for line in ax.lines:
        line.set_linestyle('-')
    tikzplotlib.save("thm1_empr_quora_multi_k.tex", axis_height='6cm', axis_width='8cm')

    plt.show()