import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from scipy.special import softmax
from tqdm import tqdm
import matplotlib
import matplotlib.backends.backend_pgf
def common_texification(text):
    """Identity stand-in for ``matplotlib.backends.backend_pgf.common_texification``.

    Performs no TeX escaping at all: *text* is returned unchanged.
    """
    return text


# The patch must run before ``import tikzplotlib`` (presumably tikzplotlib
# looks this name up at import time and newer matplotlib releases no longer
# provide it -- NOTE(review): confirm against the pinned versions). Only
# install the stub when the attribute is missing, so matplotlib versions that
# still ship the real function keep their TeX escaping.
if not hasattr(matplotlib.backends.backend_pgf, "common_texification"):
    matplotlib.backends.backend_pgf.common_texification = common_texification
import tikzplotlib
import matplotlib.pyplot as plt

# --- Embedding model -------------------------------------------------------
MODEL_NAME = "all-MiniLM-L6-v2"

# --- Simulation dynamics ---------------------------------------------------
# Step size when pulling each active document toward the query vector.
LEARNING_RATE = 0.05
# Standard deviation of the Gaussian noise added to every embedding per step.
NOISE_SCALE = 0.002
# Number of simulation steps; each step draws exactly one query.
NUM_STEPS = 5_000
# Probability that a step's query is drawn from the majority documents.
POPULATION_SKEW = 0.98
# Multiplies similarity scores before the softmax, so larger values
# concentrate the update weight on the most similar documents.
SOFTMAX_TEMP = 10.0
# Minimum query/document dot product (cosine similarity, as all vectors are
# unit length) for a document to be moved by a query.
INTERACTION_RADIUS = 0.35

# --- Genre split (case-insensitive substring match on the Genre field) -----
MINORITY_GENRE = "film noir"
MAJORITY_GENRES = ["crime", "mystery"]


def get_raw_movie_vectors_tuned(n_minority=50, n_majority=1000, shift=-2.5):
    """Load, embed and displace the minority (film noir) vs. majority plot embeddings.

    Plots come from the train split of
    ``vishnupriyavr/wiki-movie-plots-with-summaries``. A row is *minority*
    when its lower-cased ``Genre`` contains ``MINORITY_GENRE``. It is
    *majority* when the genre contains any of ``MAJORITY_GENRES`` and the row
    is not minority.

    Both sets are embedded with ``MODEL_NAME`` (unit-normalised embeddings).
    Every minority vector is then offset by ``shift * mean(majority)`` and
    re-normalised to unit length.

    Args:
        n_minority: Maximum number of minority plots (first rows in dataset
            order). Defaults to 50.
        n_majority: Maximum number of majority plots. Defaults to 1000.
        shift: Multiplier on the majority centroid added to each minority
            embedding; a negative value moves minority vectors away from that
            centroid. Defaults to -2.5.

    Returns:
        ``(X_min, X_maj)``: row-wise L2-normalised embeddings, one row per plot.

    Raises:
        ValueError: if either genre filter matches no plots. An empty set
            would otherwise yield a NaN centroid and a division by zero in
            the downstream recall computation.
    """
    print("--- Loading Cultural Data (Raw 384-Dim) ---")
    dataset = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")

    def is_noir(row):
        return MINORITY_GENRE in str(row['Genre']).lower()

    def is_crime(row):
        genre = str(row['Genre']).lower()
        return any(maj in genre for maj in MAJORITY_GENRES) and not is_noir(row)

    # Slicing the filtered Dataset fetches only the first N rows (as a dict of
    # columns) instead of materialising the entire 'Plot' column first.
    min_plots = dataset.filter(is_noir)[:n_minority]['Plot']
    maj_plots = dataset.filter(is_crime)[:n_majority]['Plot']
    if not min_plots or not maj_plots:
        raise ValueError(
            f"Genre filter matched {len(min_plots)} minority ('{MINORITY_GENRE}') "
            f"and {len(maj_plots)} majority ({MAJORITY_GENRES}) plots; both must be non-empty."
        )

    # Load the encoder only once we know there is something to encode.
    model = SentenceTransformer(MODEL_NAME)
    print(f"Encoding {len(min_plots)} Noir vs {len(maj_plots)} Crime/Mystery plots...")
    X_min = model.encode(min_plots, batch_size=64, normalize_embeddings=True)
    X_maj = model.encode(maj_plots, batch_size=64, normalize_embeddings=True)

    # Displace every minority vector along the majority centroid direction,
    # then project back onto the unit sphere.
    maj_center = np.mean(X_maj, axis=0)
    void_vector = shift * maj_center
    X_min = normalize(X_min + void_vector)

    return X_min, X_maj


def _minority_recall_at_1(min_docs, all_docs):
    """Return the fraction of minority docs whose nearest *other* doc is also minority.

    Similarity is the dot product (cosine for unit rows). Assumes
    ``min_docs`` are exactly the first ``len(min_docs)`` rows of ``all_docs``.
    """
    n_min = len(min_docs)
    sims = min_docs @ all_docs.T
    own = np.arange(n_min)
    sims[own, own] = -1.0  # a document must not count itself as its neighbour
    nearest = np.argmax(sims, axis=1)  # first index on ties, like a per-row argmax
    return np.count_nonzero(nearest < n_min) / n_min


def _simulate_drift(min_docs, maj_docs, rng):
    """Run NUM_STEPS query/update steps and record recall and drift every 100 steps."""
    n_min = len(min_docs)
    # np.vstack copies, so min_docs/maj_docs stay fixed: they serve as the
    # query pool and as the drift reference. Minority rows come first.
    all_docs = np.vstack([min_docs, maj_docs])
    history = {'step': [], 'min_recall': [], 'drift': []}

    print("\n--- Running Raw High-Dim Metastability (Shift=-2.5) ---")

    for t in tqdm(range(NUM_STEPS)):
        # Queries come from the original (non-drifting) embeddings; the
        # majority set is chosen with probability POPULATION_SKEW.
        pool = maj_docs if rng.rand() < POPULATION_SKEW else min_docs
        query_vec = pool[rng.randint(len(pool))]

        # Pull every document above the similarity threshold toward the query,
        # weighted by a softmax over the active documents' scores.
        scores = all_docs @ query_vec
        active = np.flatnonzero(scores > INTERACTION_RADIUS)
        if active.size:
            weights = softmax(scores[active] * SOFTMAX_TEMP)
            pull = weights[:, np.newaxis] * (query_vec - all_docs[active])
            all_docs[active] += LEARNING_RATE * pull

        # Gaussian noise, then project every row back onto the unit sphere.
        all_docs += rng.normal(0, NOISE_SCALE, all_docs.shape)
        all_docs /= np.linalg.norm(all_docs, axis=1, keepdims=True)

        if t % 100 == 0:
            curr_min_docs = all_docs[:n_min]
            history['step'].append(t)
            history['min_recall'].append(_minority_recall_at_1(curr_min_docs, all_docs))
            history['drift'].append(np.linalg.norm(curr_min_docs - min_docs, axis=1).mean())

    return history


def _plot_history(history, output_path, show):
    """Plot recall (left axis) and drift (right axis) vs. step; export to TikZ and optionally show."""
    fig, ax_recall = plt.subplots(figsize=(10, 6))

    recall_color = '#ff7f0e'
    ax_recall.set_xlabel('Simulation Steps')
    ax_recall.set_ylabel('Film Noir Distinctiveness (Recall@1)', color=recall_color)
    ax_recall.plot(history['step'], history['min_recall'], color=recall_color, linewidth=3, label='Recall')
    ax_recall.tick_params(axis='y', labelcolor=recall_color)
    ax_recall.set_ylim(-0.05, 1.05)

    ax_drift = ax_recall.twinx()
    drift_color = '#1f77b4'
    ax_drift.set_ylabel('Assimilation Drift (L2)', color=drift_color)
    ax_drift.plot(history['step'], history['drift'], color=drift_color, linestyle='--', linewidth=2, label='Drift')
    ax_drift.tick_params(axis='y', labelcolor=drift_color)

    # Hand-placed narrative markers at 20% / 40% / 70% of the run (1000 / 2000
    # / 3500 when NUM_STEPS is 5000). They are not derived from the data.
    # NOTE(review): y=0.5 is in the drift axis' data units, whose range is
    # autoscaled from the run -- confirm the labels land where intended.
    label_box = dict(facecolor='white', alpha=0.9, edgecolor='none')
    ax_drift.axvline(x=2 * NUM_STEPS // 5, color='gray', linestyle=':', alpha=0.5)
    ax_drift.text(NUM_STEPS // 5, 0.5, "Safe Zone\n(Recall High)", ha='center', bbox=label_box)
    ax_drift.text(7 * NUM_STEPS // 10, 0.5, "Assimilation\n(Recall Drops)", ha='center', bbox=label_box)

    ax_drift.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    # Force every line on the twin axis to solid before export; this also
    # overrides the dashed drift line and the dotted marker above.
    # NOTE(review): presumably a workaround for tikzplotlib failing on
    # non-solid line styles with newer matplotlib -- confirm before removing.
    for line in ax_drift.lines:
        line.set_linestyle('-')
    tikzplotlib.save(output_path, axis_height='6cm', axis_width='8cm')

    if show:
        plt.show()


def run_raw_metastability_final(*, seed=None, output_path="thm3_empr_noir.tex", show=True):
    """Simulate query-driven embedding drift and plot film-noir recall over time.

    Loads the minority/majority embeddings via ``get_raw_movie_vectors_tuned``
    and runs ``NUM_STEPS`` drift steps. It then plots minority Recall@1 and
    mean L2 drift of the minority docs, and exports the figure with
    tikzplotlib.

    Args:
        seed: If given, randomness comes from a private
            ``np.random.RandomState(seed)`` so a run is reproducible. ``None``
            (default) uses the global ``np.random`` state.
        output_path: File the TikZ/pgfplots code is written to.
        show: Whether to call ``plt.show()`` after exporting.

    Returns:
        The history dict with keys ``'step'``, ``'min_recall'`` and
        ``'drift'``, sampled every 100 steps.

    Raises:
        ValueError: if either document set is empty (recall would divide by
            zero / no majority query could be drawn).
    """
    min_docs, maj_docs = get_raw_movie_vectors_tuned()
    if len(min_docs) == 0 or len(maj_docs) == 0:
        raise ValueError("Both minority and majority document sets must be non-empty.")

    rng = np.random if seed is None else np.random.RandomState(seed)
    history = _simulate_drift(min_docs, maj_docs, rng)
    _plot_history(history, output_path, show)
    return history


if __name__ == "__main__":
    # Script entry point: load and embed the data, run the simulation, then
    # export and show the recall/drift figure.
    run_raw_metastability_final()