from functools import lru_cache
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tikzplotlib
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Shared LaTeX/TikZ matplotlib style. It is looked up in the parent of the
# current working directory, so the result depends on where the script is run.
_STYLE_SHEET = Path.cwd().parent / "latex_tikz.mplstyle"
plt.style.use(_STYLE_SHEET)

# Sentence-embedding model used to encode every plot.
MODEL_NAME = 'all-MiniLM-L6-v2'
# Neighbourhood sizes evaluated as Recall@k.
K_VALUES = [1, 2, 5, 10]

# Genre whose plots act as retrieval queries / targets, and the maximum number
# of its plots sampled per run.
MINORITY_GENRE = 'film noir'
MINORITY_CAP = 50
# Genres whose plots are added as interferers (rows that also match
# MINORITY_GENRE are excluded by the experiment code).
MAJORITY_GENRES = ['action', 'thriller', 'crime']

# One experiment run per seed; results are averaged across them.
SEEDS = [42, 101, 999, 2026, 5555]


@lru_cache(maxsize=None)
def _load_model(model_name):
    """Load a SentenceTransformer once per process so every seed reuses it."""
    return SentenceTransformer(model_name)


@lru_cache(maxsize=None)
def _load_genre_plots(minority_genre, majority_genres):
    """Return ``(minority_plots, majority_plots)`` as tuples of plot strings.

    A row is "minority" when its lower-cased ``Genre`` contains
    ``minority_genre``; it is "majority" when it contains any of
    ``majority_genres`` (a tuple, so it can be a cache key) and is not
    minority. The result does not depend on the seed, so it is computed once
    per argument combination. Dataset row order is preserved, which keeps the
    index-based sampling in ``run_cultural_erasure`` reproducible.
    """
    dataset = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")

    def genre_of(row):
        return str(row['Genre']).lower()

    def is_majority(row):
        g = genre_of(row)
        return any(maj in g for maj in majority_genres) and minority_genre not in g

    minority = dataset.filter(lambda row: minority_genre in genre_of(row))
    majority = dataset.filter(is_majority)
    # Materialise each column exactly once: indexing ``dataset['Plot']`` inside
    # a per-sample loop rebuilds the whole column on every access.
    return tuple(minority['Plot']), tuple(majority['Plot'])


def recall_curve(emb_min, emb_maj, n_maj_steps, k_values):
    """Compute Recall@k of minority-to-minority retrieval as interferers grow.

    For each step, the search index is ``emb_min`` followed by the first
    ``n_maj`` rows of ``emb_maj``. Every minority row is used as a query
    against that index (dot-product scores), with the query itself excluded.
    A query counts as a hit at k when any of its top-k neighbours is another
    minority row.

    Args:
        emb_min: 2-D array of minority embeddings (one row per item).
        emb_maj: 2-D array of interferer embeddings, same column count.
        n_maj_steps: Requested interferer counts, one per curve point.
        k_values: Neighbourhood sizes to evaluate.

    Returns:
        ``{'n_maj': [...], 'recall': {k: [...]}}``. ``n_maj`` holds the number
        of interferers actually used, which is capped at ``len(emb_maj)``.

    Raises:
        ValueError: if ``emb_min`` has no rows.
    """
    emb_min = np.asarray(emb_min)
    emb_maj = np.asarray(emb_maj)
    n_min = emb_min.shape[0]
    if n_min == 0:
        raise ValueError("recall_curve needs at least one minority embedding")

    results = {'n_maj': [], 'recall': {k: [] for k in k_values}}
    diag = np.arange(n_min)

    for n_maj in n_maj_steps:
        interferers = emb_maj[:n_maj]
        index = np.vstack([emb_min, interferers])
        scores = emb_min @ index.T
        # -inf strictly sorts the query itself behind every real candidate...
        scores[diag, diag] = -np.inf
        order = np.argsort(-scores, axis=1)
        is_minority = order < n_min
        # ...and truncating to the non-self candidates removes it entirely, so
        # self can never be counted as a hit even when k >= len(index).
        n_candidates = index.shape[0] - 1

        results['n_maj'].append(int(interferers.shape[0]))
        for k in k_values:
            hits = is_minority[:, :min(k, n_candidates)].any(axis=1)
            results['recall'][k].append(float(hits.mean()))

    return results


def run_cultural_erasure(seed=42):
    """Run one seed of the genre-crowding retrieval experiment.

    Samples up to ``MINORITY_CAP`` ``MINORITY_GENRE`` plots and up to 5000
    ``MAJORITY_GENRES`` plots, then encodes them with ``MODEL_NAME``
    (L2-normalised). It then measures Recall@k for every k in ``K_VALUES`` as
    0..5000 majority interferers are added to the index (see ``recall_curve``).

    Args:
        seed: Seed for the sampling RNG.

    Returns:
        ``{'n_maj': [...], 'recall': {k: [...]}}`` with one entry per step.

    Raises:
        ValueError: if the dataset contains no plots of ``MINORITY_GENRE``.
    """
    # A private RandomState yields exactly the same draws that
    # ``np.random.seed(seed)`` + ``np.random.choice`` did, without mutating
    # NumPy's global RNG.
    rng = np.random.RandomState(seed)
    model = _load_model(MODEL_NAME)
    min_all, maj_all = _load_genre_plots(MINORITY_GENRE, tuple(MAJORITY_GENRES))
    if not min_all:
        raise ValueError(f"no plots found for minority genre {MINORITY_GENRE!r}")

    min_indices = rng.choice(len(min_all), min(MINORITY_CAP, len(min_all)), replace=False)
    min_plots = [min_all[i] for i in min_indices]

    maj_sample_size = min(5000, len(maj_all))
    maj_indices = rng.choice(len(maj_all), maj_sample_size, replace=False)
    maj_plots = [maj_all[i] for i in maj_indices]

    emb_min = model.encode(min_plots, batch_size=64, show_progress_bar=False, normalize_embeddings=True)
    if maj_plots:
        emb_maj = model.encode(maj_plots, batch_size=64, show_progress_bar=False, normalize_embeddings=True)
    else:
        # No interferers available: keep a correctly-shaped empty matrix.
        emb_maj = np.empty((0, emb_min.shape[1]), dtype=emb_min.dtype)

    n_maj_steps = [0, 100, 500, 1000, 2500, 4000, 5000]
    return recall_curve(emb_min, emb_maj, n_maj_steps, K_VALUES)


def mean_ci95(runs):
    """Return the per-step mean and 95% confidence half-width across runs.

    Args:
        runs: Array-like of shape ``(n_runs, n_steps)``, one curve per seed.

    Returns:
        ``(mean, half_width)``, both of shape ``(n_steps,)``. ``half_width`` is
        ``1.96 * s / sqrt(n_runs)``, where ``s`` is the *sample* standard
        deviation (``ddof=1``). The population std would understate the
        standard error for a handful of seeds. With fewer than two runs the
        spread is undefined and a zero half-width is returned.

    Note: 1.96 is the normal critical value; for very few seeds a Student-t
    value would give a wider (more accurate) interval.
    """
    arr = np.asarray(runs, dtype=float)
    n_runs = arr.shape[0]
    mean = arr.mean(axis=0)
    if n_runs < 2:
        return mean, np.zeros_like(mean)
    sem = arr.std(axis=0, ddof=1) / np.sqrt(n_runs)
    return mean, 1.96 * sem


if __name__ == "__main__":
    # One recall curve per seed, grouped by k.
    all_results = {k: [] for k in K_VALUES}
    n_maj_axis = []

    for seed in tqdm(SEEDS, desc="Running seeds"):
        results = run_cultural_erasure(seed)
        # The x-axis is taken from the last run, as before.
        n_maj_axis = results['n_maj']
        for k in K_VALUES:
            all_results[k].append(results['recall'][k])

    mean_recalls = {}
    ci_95s = {}
    for k in K_VALUES:
        mean_recalls[k], ci_95s[k] = mean_ci95(all_results[k])

    plt.figure(figsize=(10, 6))
    scale = 10 ** 3
    x_vals = np.array(n_maj_axis) / scale

    colors = ['#ff7f0e', '#2ca02c', '#d62728', '#1f77b4']

    for idx, k in enumerate(K_VALUES):
        color = colors[idx % len(colors)]
        label = f'Recall@{k}'

        plt.plot(x_vals, mean_recalls[k], 'o-', linewidth=2, color=color, label=label)
        # Shaded band: mean +/- 95% CI half-width across seeds.
        plt.fill_between(x_vals,
                         mean_recalls[k] - ci_95s[k],
                         mean_recalls[k] + ci_95s[k],
                         alpha=0.3, color=color)

    plt.xlabel(r"Number of Genre Interferers ($N_{{maj}} \times 10^3$)")
    plt.ylabel("Recall@k")
    plt.text(x_vals[-1] * 0.6, 0.5, "Geometric Crowding Region", fontsize=11, color='darkred')
    plt.ylim(0, 1.05)
    plt.legend()
    plt.tight_layout()

    ax = plt.gca()
    # Force every plotted line to a solid style before exporting to TikZ.
    for line in ax.lines:
        line.set_linestyle('-')
    tikzplotlib.save("thm1_empr_noir_multi_k.tex", axis_height='6cm', axis_width='8cm')

    plt.show()