import functools
from pathlib import Path

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
import tikzplotlib

# Apply the shared LaTeX/TikZ plotting style at import time. The path is
# resolved against the parent of the *current working directory*, not of this
# file, so the script must be launched from a directory whose parent holds
# latex_tikz.mplstyle.
# NOTE(review): if the style file is meant to live relative to this script,
# Path(__file__).resolve().parent.parent would be location-independent — confirm.
plt.style.use(Path.cwd().parent / "latex_tikz.mplstyle")

# Sentence-embedding model used to encode documents and queries.
MODEL_NAME = 'all-MiniLM-L6-v2'
# Cut-offs k at which Recall@k is measured and plotted.
K_VALUES = [1, 2, 5, 10]
# Number of minority-topic documents indexed AND number of minority queries issued.
MINORITY_SIZE = 100
# One independent repetition per seed; recall curves are averaged across them.
SEEDS = [42, 101, 999, 2026, 5555]

# 20 Newsgroups category playing the minority topic; every other category
# supplies the majority "interferer" documents.
TOPIC_MINORITY = 'comp.sys.mac.hardware'


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the sentence-embedding model once and reuse it for every seed."""
    return SentenceTransformer(MODEL_NAME)


def _long_enough(texts, min_chars=20):
    """Keep only texts longer than ``min_chars`` characters (drops near-empty posts)."""
    return [t for t in texts if len(t) > min_chars]


@functools.lru_cache(maxsize=1)
def _majority_pool_embeddings():
    """Fetch and encode all non-minority posts once; nothing here depends on the seed.

    Returns
    -------
    (int, numpy.ndarray)
        Number of pooled posts and their normalized embeddings (one row per
        post). The array is marked read-only because it is shared across calls.
    """
    all_cats = fetch_20newsgroups(subset='train').target_names
    majority_cats = [c for c in all_cats if c != TOPIC_MINORITY]

    maj_data = fetch_20newsgroups(subset='all', categories=majority_cats,
                                  remove=('headers', 'footers', 'quotes')).data
    maj_data = _long_enough(maj_data)

    embeddings = _load_model().encode(maj_data, normalize_embeddings=True)
    embeddings.setflags(write=False)  # guard the cached array against mutation
    return len(maj_data), embeddings


def run_experiment(seed=42):
    """Measure how Recall@k for minority-topic queries degrades as majority docs are added.

    ``MINORITY_SIZE`` minority documents (train split) and ``MINORITY_SIZE``
    minority queries (test split) are sampled with ``seed``. Each query's
    target is its nearest neighbour among the sampled minority documents
    alone. The index is then grown with the first ``n_maj`` majority-topic
    embeddings, and for every k in ``K_VALUES`` we record the fraction of
    queries whose target is still in their top-k.

    Parameters
    ----------
    seed : int
        Seed for sampling the minority documents and queries.

    Returns
    -------
    dict
        ``{'n_maj': list[int], 'recall': {k: list[float]}}`` where
        ``recall[k][j]`` is Recall@k with ``n_maj[j]`` interferers indexed.

    Raises
    ------
    ValueError
        If fewer than ``MINORITY_SIZE`` minority texts survive filtering, or the
        majority pool is smaller than the largest ``n_maj`` step.
    """
    # A private RandomState yields exactly the draws of the former
    # ``np.random.seed(seed)`` + ``np.random.choice`` sequence, without
    # mutating NumPy's global RNG.
    rng = np.random.RandomState(seed)
    model = _load_model()

    min_docs = fetch_20newsgroups(subset='train', categories=[TOPIC_MINORITY],
                                  remove=('headers', 'footers', 'quotes')).data
    min_queries_raw = fetch_20newsgroups(subset='test', categories=[TOPIC_MINORITY],
                                         remove=('headers', 'footers', 'quotes')).data
    min_docs = _long_enough(min_docs)
    min_queries_raw = _long_enough(min_queries_raw)

    # replace=False raises ValueError if the population is smaller than MINORITY_SIZE.
    doc_indices = rng.choice(len(min_docs), MINORITY_SIZE, replace=False)
    query_indices = rng.choice(len(min_queries_raw), MINORITY_SIZE, replace=False)
    minority_docs = [min_docs[i] for i in doc_indices]
    minority_queries = [min_queries_raw[i] for i in query_indices]

    # NOTE(review): the majority pool is loaded without reference to ``seed`` and
    # always sliced from the front, so the interferers at each n_maj are the same
    # for every seed; the across-seed spread reflects only the minority sampling.
    # Permute the pool with ``rng`` if interferer choice should vary as well.
    n_majority, emb_majority_pool = _majority_pool_embeddings()

    print(f"Minority Docs: {len(minority_docs)}, Queries: {len(minority_queries)}, Majority Pool: {n_majority}")

    emb_minority_docs = model.encode(minority_docs, normalize_embeddings=True)
    emb_minority_queries = model.encode(minority_queries, normalize_embeddings=True)

    baseline_sim = cosine_similarity(emb_minority_queries, emb_minority_docs)
    ground_truth_pairs = np.argmax(baseline_sim, axis=1)

    n_maj_steps = [0, 100, 500, 1000, 5000, 10000, 15000]
    if max(n_maj_steps) > n_majority:
        # Slicing would silently truncate and mislabel the x-axis.
        raise ValueError(
            f"majority pool has only {n_majority} documents; "
            f"cannot add {max(n_maj_steps)} interferers"
        )

    results = {
        'n_maj': n_maj_steps,
        'recall': {k: [] for k in K_VALUES}
    }

    for n_maj in tqdm(n_maj_steps):
        if n_maj == 0:
            current_index = emb_minority_docs
        else:
            current_index = np.vstack([emb_minority_docs, emb_majority_pool[:n_maj]])

        sim_matrix = cosine_similarity(emb_minority_queries, current_index)
        # Rank every query's candidates once (descending similarity, same
        # per-row argsort as before) and reuse that ranking for every k.
        ranking = np.argsort(-sim_matrix, axis=1)

        for k in K_VALUES:
            hits = (ranking[:, :k] == ground_truth_pairs[:, None]).any(axis=1)
            results['recall'][k].append(int(hits.sum()) / MINORITY_SIZE)

    return results


def summarize_runs(runs, confidence=0.95):
    """Return the per-point mean and confidence-interval half-width across runs.

    Parameters
    ----------
    runs : array_like, shape (n_runs, n_points)
        One row per independent run (seed), one column per x-value.
    confidence : float
        Two-sided confidence level of the interval (default 0.95).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Column means and half-widths ``t * s / sqrt(n_runs)``, where ``s`` is
        the sample (ddof=1) standard deviation and ``t`` the Student-t quantile
        with ``n_runs - 1`` degrees of freedom. The normal 1.96 / population-std
        formula understates the interval for only a handful of seeds. With a
        single run no spread can be estimated, so the half-width is zero.

    Raises
    ------
    ValueError
        If ``runs`` is not a non-empty 2-D array.
    """
    from scipy import stats  # only needed for the t quantile

    arr = np.asarray(runs, dtype=float)
    if arr.ndim != 2 or arr.shape[0] == 0:
        raise ValueError("runs must be a non-empty 2-D array of shape (n_runs, n_points)")

    mean = arr.mean(axis=0)
    n_runs = arr.shape[0]
    if n_runs < 2:
        return mean, np.zeros_like(mean)

    t_crit = stats.t.ppf(0.5 + confidence / 2.0, df=n_runs - 1)
    half_width = t_crit * arr.std(axis=0, ddof=1) / np.sqrt(n_runs)
    return mean, half_width


def main():
    """Run every seed, aggregate Recall@k, and plot mean ± 95% CI (screen + TikZ)."""
    if not SEEDS:
        raise ValueError("SEEDS must contain at least one seed")

    all_results = {k: [] for k in K_VALUES}
    n_maj_steps = None
    for seed in SEEDS:
        results = run_experiment(seed)
        # The x-grid is fixed inside run_experiment, so any run's copy will do.
        n_maj_steps = results['n_maj']
        for k in K_VALUES:
            all_results[k].append(results['recall'][k])

    mean_recalls = {}
    ci_95s = {}
    for k in K_VALUES:
        mean_recalls[k], ci_95s[k] = summarize_runs(all_results[k], confidence=0.95)

    plt.figure(figsize=(10, 6))
    scale = 10 ** 3
    x_vals = np.array(n_maj_steps) / scale

    colors = ['#ff7f0e', '#2ca02c', '#d62728', '#1f77b4']

    for idx, k in enumerate(K_VALUES):
        color = colors[idx % len(colors)]
        label = f'Recall@{k}'

        plt.plot(x_vals, mean_recalls[k], 'o-', linewidth=2, color=color, label=label)
        # Recall is a proportion: keep the shaded band inside [0, 1].
        plt.fill_between(x_vals,
                         np.clip(mean_recalls[k] - ci_95s[k], 0.0, 1.0),
                         np.clip(mean_recalls[k] + ci_95s[k], 0.0, 1.0),
                         alpha=0.3, color=color)

    plt.xlabel(r"Number of Majority Interferers ($N_{{maj}} \times 10^3$)")
    plt.ylabel("Recall@k")
    plt.text(x_vals[-1] * 0.6, 0.5, "Geometric Crowding Region", fontsize=11, color='darkred')
    plt.legend()
    plt.tight_layout()

    # Every curve above is drawn with 'o-' (already solid); re-asserting the
    # style is kept, presumably to pin solid lines in the TikZ export.
    ax = plt.gca()
    for line in ax.lines:
        line.set_linestyle('-')
    tikzplotlib.save("thm1_empr_hw_multi_k.tex", axis_height='6cm', axis_width='8cm')

    plt.show()


if __name__ == "__main__":
    main()