import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import matplotlib.pyplot as plt
import tikzplotlib
from pathlib import Path

# Apply the "latex_tikz.mplstyle" style sheet at import time, before any
# figure is created. The path is built from the *parent of the current working
# directory*, so the result depends on where the script is launched from.
plt.style.use(Path.cwd().parent.joinpath("latex_tikz.mplstyle"))

# --- Experiment configuration -------------------------------------------------

# Name handed to SentenceTransformer to load the image encoder.
MODEL_NAME = "clip-ViT-B-32"

# Seeds for the independent repetitions run by the script entry point.
SEEDS = [42, 101, 999, 2026]

# Cut-offs k for which Recall@k is reported.
K_VALUES = [1, 2, 5, 10]

# The class whose images serve as both queries and retrieval targets, and the
# maximum number of its images sampled per run.
MINORITY_CLASS_NAME = "otter"
MINORITY_CAP = 50

# Upper bound on the number of interfering ("majority") images sampled per run.
MAX_MAJORITY = 12_000

# Classes that supply the interfering images. One sub-list per row of related
# animals (presumably mirroring the CIFAR-100 superclasses — confirm against
# the dataset card); the minority class itself is deliberately absent.
MAJORITY_LABELS_NAMES = (
    ["beaver", "dolphin", "seal", "whale"]
    + ["hamster", "mouse", "rabbit", "shrew", "squirrel"]
    + ["fox", "porcupine", "possum", "raccoon", "skunk"]
    + ["bear", "leopard", "lion", "tiger", "wolf"]
    + ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"]
)


def _recall_at_k(emb_min, emb_index, k_values):
    """Return ``{k: recall}`` for minority queries against ``emb_index``.

    The first ``len(emb_min)`` rows of ``emb_index`` must be ``emb_min`` itself
    (any interferers follow). A query counts as a success for a given ``k`` when
    at least one *other* minority item is among its ``k`` highest-scoring index
    entries; the query's own entry is never allowed to count.
    """
    if not k_values:
        return {}

    n_min = len(emb_min)
    rows = np.arange(n_min)

    # Dot product of the embeddings; np.dot returns a fresh array, so it can be
    # edited in place below.
    sim = np.dot(emb_min, emb_index.T)

    # Push each query's self-match to the very end of its ranking. -inf is
    # used instead of -1.0 because -1.0 is itself a reachable cosine value.
    sim[rows, rows] = -np.inf

    # Rank once and reuse the ordering for every k (the ranking does not
    # depend on k, only the cut-off does).
    ranked = np.argsort(-sim, axis=1)[:, :max(k_values)]

    # A hit is a minority index that is not the query itself. The explicit
    # self-exclusion matters when the index holds fewer than k other entries,
    # where the masked self-match would otherwise slip into the top-k.
    is_hit = (ranked < n_min) & (ranked != rows[:, None])

    return {
        k: int(is_hit[:, :k].any(axis=1).sum()) / n_min
        for k in k_values
    }


def run_bio_collapse(seed=42):
    """Measure minority Recall@k as majority interferers are added to the index.

    Steps:
      1. Seed NumPy's *global* RNG with ``seed`` (side effect on the process).
      2. Load the ``MODEL_NAME`` encoder and the CIFAR-100 training split.
      3. Sample up to ``MINORITY_CAP`` images of ``MINORITY_CLASS_NAME`` and up
         to ``MAX_MAJORITY`` images of the ``MAJORITY_LABELS_NAMES`` classes,
         both without replacement.
      4. Encode both sets with ``normalize_embeddings=True``.
      5. For a growing number of majority images placed in the index next to
         the minority images, compute Recall@k for every k in ``K_VALUES``.

    Note that the model and dataset are loaded on every call.

    Args:
        seed: Seed for NumPy's global random state; controls both samples.

    Returns:
        A dict with
          * ``'n_maj'``: list of interferer counts that were evaluated, and
          * ``'recall'``: ``{k: [recall for each entry of 'n_maj']}``.

    Raises:
        ValueError: If a configured class name is not among the dataset's
            fine-label names (raised by ``list.index``).
    """
    np.random.seed(seed)
    model = SentenceTransformer(MODEL_NAME)

    dataset = load_dataset("cifar100", split="train")
    labels = dataset.features['fine_label'].names

    # --- Minority (query/target) images --------------------------------------
    min_id = labels.index(MINORITY_CLASS_NAME)
    minority_data = dataset.filter(lambda x: x['fine_label'] == min_id)

    min_indices = np.random.choice(len(minority_data),
                                   min(MINORITY_CAP, len(minority_data)),
                                   replace=False).tolist()
    minority_subset = minority_data.select(min_indices)
    min_images = minority_subset['img']

    # --- Majority (interferer) images ----------------------------------------
    maj_ids = [labels.index(name) for name in MAJORITY_LABELS_NAMES]
    maj_data = dataset.filter(lambda x: x['fine_label'] in maj_ids)

    total_available = len(maj_data)
    maj_sample_size = min(MAX_MAJORITY, total_available)
    maj_indices = np.random.choice(total_available, maj_sample_size, replace=False).tolist()
    majority_subset = maj_data.select(maj_indices)
    maj_images = majority_subset['img']

    print(f"Seed {seed}: Encoding {len(min_images)} minority and {len(maj_images)} majority images...")

    emb_min = model.encode(min_images, batch_size=32, show_progress_bar=False, normalize_embeddings=True)
    emb_maj = model.encode(maj_images, batch_size=32, show_progress_bar=False, normalize_embeddings=True)

    # --- Evaluation grid -------------------------------------------------------
    # Keep only grid points that are actually available, then finish with the
    # full majority set. Without the filter, a step larger than total_maj would
    # be silently truncated by slicing yet still be reported (and plotted) at
    # its nominal size, and total_maj could duplicate a grid value.
    total_maj = len(emb_maj)
    step_grid = (0, 100, 500, 1000, 2500, 5000, 7500, 10000)
    n_maj_steps = [step for step in step_grid if step < total_maj] + [total_maj]

    results = {
        'n_maj': n_maj_steps,
        'recall': {k: [] for k in K_VALUES}
    }

    for n_maj in n_maj_steps:
        if n_maj == 0:
            # Minority-only index: no interferers yet.
            current_index = emb_min
        else:
            # Minority rows first, so index positions < len(emb_min) are
            # exactly the minority items.
            current_index = np.vstack([emb_min, emb_maj[:n_maj]])

        recall_by_k = _recall_at_k(emb_min, current_index, K_VALUES)
        for k in K_VALUES:
            results['recall'][k].append(recall_by_k[k])

    return results


def _mean_and_ci95(runs):
    """Return ``(mean, half_width)`` across runs for each evaluation point.

    ``runs`` is a sequence of equally long recall curves, one per seed. The
    half-width is ``1.96 * s / sqrt(n)`` where ``s`` is the *sample* standard
    deviation (``ddof=1``), i.e. a normal-approximation 95% interval for the
    mean. With only a handful of seeds a Student-t critical value would give a
    wider band. For a single run no spread can be estimated, so the half-width
    is zero.
    """
    arr = np.asarray(runs, dtype=float)
    mean = arr.mean(axis=0)
    n_runs = arr.shape[0]
    if n_runs < 2:
        return mean, np.zeros_like(mean)
    std_err = arr.std(axis=0, ddof=1) / np.sqrt(n_runs)
    return mean, 1.96 * std_err


def main():
    """Run the experiment for every seed, then plot and export mean Recall@k.

    Writes ``thm1_empr_clip_multi_k.tex`` via tikzplotlib and opens the figure
    window.

    Raises:
        ValueError: If ``SEEDS`` is empty or the seeds were evaluated on
            different interferer counts (their curves could not be averaged).
    """
    all_results = {k: [] for k in K_VALUES}
    n_maj_steps = None

    for seed in tqdm(SEEDS, desc="Running seeds"):
        results = run_bio_collapse(seed)

        # Remember the x-axis explicitly instead of relying on the loop
        # variable surviving the loop, and make sure every seed agrees on it.
        if n_maj_steps is None:
            n_maj_steps = list(results['n_maj'])
        elif list(results['n_maj']) != n_maj_steps:
            raise ValueError(
                f"Seed {seed} evaluated different interferer counts "
                f"({results['n_maj']}) than earlier seeds ({n_maj_steps})."
            )

        for k in K_VALUES:
            all_results[k].append(results['recall'][k])

    if n_maj_steps is None:
        raise ValueError("SEEDS is empty; there is nothing to plot.")

    mean_recalls = {}
    ci_95s = {}
    for k in K_VALUES:
        mean_recalls[k], ci_95s[k] = _mean_and_ci95(all_results[k])

    plt.figure(figsize=(10, 6))
    scale = 10 ** 3  # x-axis is shown in thousands of interferers
    x_vals = np.array(n_maj_steps) / scale

    colors = ['#ff7f0e', '#2ca02c', '#d62728', '#1f77b4']

    for idx, k in enumerate(K_VALUES):
        color = colors[idx % len(colors)]
        label = f'Recall@{k}'

        plt.plot(x_vals, mean_recalls[k], 'o-', linewidth=2, color=color, label=label)
        # Shaded band: mean +/- the 95% half-width across seeds.
        plt.fill_between(x_vals,
                         mean_recalls[k] - ci_95s[k],
                         mean_recalls[k] + ci_95s[k],
                         alpha=0.3, color=color)

    plt.xlabel(r"Number of Visual Interferers ($N_{{maj}} \times 10^3$)")
    plt.ylabel("Recall@k")
    plt.ylim(0, 1.05)
    plt.legend()
    plt.tight_layout()

    # Force an explicit solid line style on every line before exporting
    # (presumably a workaround for how tikzplotlib reads line styles — confirm).
    ax = plt.gca()
    for line in ax.lines:
        line.set_linestyle('-')
    tikzplotlib.save("thm1_empr_clip_multi_k.tex", axis_height='6cm', axis_width='8cm')

    plt.show()


if __name__ == "__main__":
    main()