import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
import tikzplotlib
import pickle
from pathlib import Path

# Apply the matplotlib style sheet "latex_tikz.mplstyle". The path is resolved
# against the PARENT of the current working directory (not this file's location),
# so the script has to be launched from a directory whose parent holds that file.
plt.style.use(Path.cwd().parent / "latex_tikz.mplstyle")

# Experiment configuration shared by every function in this script.
CONFIG = dict(
    # Bi-encoder used for first-stage retrieval and cross-encoder used for reranking.
    biencoder_model='all-MiniLM-L6-v2',
    crossencoder_model='cross-encoder/ms-marco-MiniLM-L-6-v2',
    # Newsgroup treated as the minority topic and the number of docs/queries sampled from it.
    minority_topic='comp.sys.mac.hardware',
    minority_size=100,
    # Shortlist sizes (L) swept by the experiment and the random seeds it is repeated over.
    shortlist_sizes=[1, 2, 5, 10, 20, 50, 100, 200, 500],
    seeds=[42, 101, 999, 2026],
    # Pickle cache location (relative to the working directory) and whether reads are allowed.
    cache_dir=Path.cwd() / 'cache',
    use_cache=True,
    # When True, the sweep below is shrunk for a fast smoke run.
    quick_run=False,
)

if CONFIG['quick_run']:
    CONFIG.update(
        shortlist_sizes=[10, 50, 200],
        seeds=[42],
        minority_size=50,
    )

def get_cache_path(name):
    """Return the pickle path for cache entry `name`, creating the cache directory if needed."""
    cache_root = CONFIG['cache_dir']
    # Creating the directory here means every caller can open the returned path directly.
    cache_root.mkdir(parents=True, exist_ok=True)
    filename = f"{name}.pkl"
    return cache_root / filename


def save_cache(name, data):
    """Pickle `data` into the cache entry called `name`, overwriting any existing file."""
    target = get_cache_path(name)
    with target.open('wb') as sink:
        pickle.dump(data, sink)

def load_cache(name):
    """Return the unpickled cache entry `name`, or None when it is missing or caching is disabled.

    Only files written by this script's own `save_cache` should live in the cache
    directory: unpickling executes arbitrary code from the file.
    """
    source = get_cache_path(name)
    if not source.exists() or not CONFIG['use_cache']:
        return None
    with source.open('rb') as stream:
        return pickle.load(stream)


def _fetch_texts(subset, categories, min_chars=20):
    """Fetch 20 Newsgroups posts (headers/footers/quotes stripped) longer than `min_chars`."""
    raw = fetch_20newsgroups(
        subset=subset,
        categories=categories,
        remove=('headers', 'footers', 'quotes')
    ).data
    return [d for d in raw if len(d) > min_chars]


def load_data(seed=42):
    """Load and prepare 20 Newsgroups data with majority/minority split.

    Documents always come from the 'train' subset and queries from the 'test'
    subset, for both groups, so that no query text is itself part of the indexed
    corpus.

    Args:
        seed: Seed for NumPy's global RNG; controls which minority documents,
            minority queries and majority queries are sampled.

    Returns:
        dict with four lists of strings:
            'minority_docs'    - CONFIG['minority_size'] sampled train posts of the minority topic
            'minority_queries' - CONFIG['minority_size'] sampled test posts of the minority topic
            'majority_docs'    - every train post of all other topics
            'majority_queries' - CONFIG['minority_size'] sampled test posts of all other topics

    Raises:
        ValueError: raised by np.random.choice when fewer than
            CONFIG['minority_size'] posts are available to sample from.
    """
    np.random.seed(seed)

    minority_cats = [CONFIG['minority_topic']]
    min_docs_raw = _fetch_texts('train', minority_cats)
    min_queries_raw = _fetch_texts('test', minority_cats)

    # The order of the random draws (docs, queries, majority queries) is kept fixed
    # so a given seed keeps selecting the same samples.
    n_min = CONFIG['minority_size']
    doc_idx = np.random.choice(len(min_docs_raw), n_min, replace=False)
    query_idx = np.random.choice(len(min_queries_raw), n_min, replace=False)

    minority_docs = [min_docs_raw[i] for i in doc_idx]
    minority_queries = [min_queries_raw[i] for i in query_idx]

    all_cats = fetch_20newsgroups(subset='train').target_names
    majority_cats = [c for c in all_cats if c != CONFIG['minority_topic']]

    # Use 'train' rather than 'all': 'all' also contains the test posts that the
    # majority queries are sampled from, which would put every majority query
    # verbatim into the index and make its nearest document a trivial self-match.
    majority_docs = _fetch_texts('train', majority_cats)

    maj_queries_raw = _fetch_texts('test', majority_cats)
    maj_query_idx = np.random.choice(len(maj_queries_raw), n_min, replace=False)
    majority_queries = [maj_queries_raw[i] for i in maj_query_idx]

    return {
        'minority_docs': minority_docs,
        'minority_queries': minority_queries,
        'majority_docs': majority_docs,
        'majority_queries': majority_queries,
    }

def get_embeddings(biencoder, texts, cache_name=None):
    """Get embeddings with optional caching.

    Args:
        biencoder: Object exposing `encode(texts, normalize_embeddings=..., show_progress_bar=...)`.
        texts: Sequence of strings to embed.
        cache_name: Cache entry to read from / write to; falsy disables caching.

    Returns:
        Whatever `biencoder.encode` returns for `texts` (one row per text), either
        freshly computed or loaded from the cache.
    """
    if cache_name:
        cached = load_cache(cache_name)
        if cached is not None:
            # A cache written for a different corpus would silently misalign
            # embeddings and documents; the row count is the cheapest sanity check.
            if len(cached) == len(texts):
                print(f"  Loaded cached embeddings: {cache_name}")
                return cached
            print(f"  Ignoring stale cache {cache_name}: "
                  f"{len(cached)} cached rows for {len(texts)} texts")

    print(f"  Computing embeddings for {len(texts)} texts...")
    embeddings = biencoder.encode(texts, normalize_embeddings=True, show_progress_bar=True)

    if cache_name:
        save_cache(cache_name, embeddings)

    return embeddings


def retrieve_topk(query_emb, index_emb, k):
    """Return the indices and cosine scores of the `k` rows of `index_emb` closest to `query_emb`.

    Results are ordered from most to least similar.
    """
    query_row = query_emb.reshape(1, -1)
    similarities = cosine_similarity(query_row, index_emb)[0]
    # Negating the scores turns the ascending argsort into a descending ranking.
    ranking = np.argsort(-similarities)
    best = ranking[:k]
    return best, similarities[best]


def rerank_crossencoder(crossencoder, query_text, candidate_texts, candidate_indices):
    """Reorder shortlist candidates by cross-encoder score, best first.

    Args:
        crossencoder: Object whose `predict(pairs)` returns one score per [query, candidate] pair.
        query_text: The query string.
        candidate_texts: Candidate document texts.
        candidate_indices: Identifier of each candidate, aligned with `candidate_texts`.

    Returns:
        (reranked_indices, reranked_scores) as two lists sorted by descending
        score; two empty lists when there are no candidates.
    """
    if not len(candidate_texts):
        return [], []

    scores = crossencoder.predict([[query_text, text] for text in candidate_texts])

    # Descending order: argsort of the negated scores.
    order = np.argsort(-scores)
    ranked = [(candidate_indices[pos], scores[pos]) for pos in order]
    reranked_indices = [doc_id for doc_id, _ in ranked]
    reranked_scores = [score for _, score in ranked]

    return reranked_indices, reranked_scores


def rerank_perfect(candidate_indices, true_idx):
    """
    Oracle reranker: it always puts the true document first when it can.

    Success therefore depends only on the shortlist: the result is True when
    `true_idx` is among `candidate_indices` and False otherwise.
    """
    target_in_shortlist = true_idx in candidate_indices
    return target_in_shortlist

def _shortlists_for(query_embs, emb_all_docs, depth):
    """Retrieve the top-`depth` document indices once for every query embedding."""
    return [retrieve_topk(query_emb, emb_all_docs, depth)[0] for query_emb in query_embs]


def _count_hits(crossencoder, query_texts, shortlists, ground_truth, all_docs, L):
    """Count, for shortlist size `L`, queries whose target is shortlisted and those the cross-encoder ranks first."""
    shortlist_hits = 0
    crossenc_hits = 0

    for query_text, shortlist, true_idx in zip(query_texts, shortlists, ground_truth):
        topk_indices = shortlist[:L]

        # The reranker can only succeed when retrieval surfaced the target, so
        # the (expensive) cross-encoder is skipped otherwise.
        if true_idx not in topk_indices:
            continue
        shortlist_hits += 1

        candidate_texts = [all_docs[idx] for idx in topk_indices]
        reranked_indices, _ = rerank_crossencoder(crossencoder, query_text, candidate_texts, topk_indices)

        if reranked_indices[0] == true_idx:
            crossenc_hits += 1

    return shortlist_hits, crossenc_hits


def run_experiment(seed=42):
    """
    Shortlist Bottleneck & Reranker Mitigation Experiment
    Demonstrates that minority collapse is a retrieval bottleneck problem.

    For each shortlist size L in CONFIG['shortlist_sizes'], every query is
    retrieved against the combined corpus (minority docs first, then majority
    docs) and three rates are recorded per group:
      * bi-encoder recall: the target document is inside the top-L shortlist;
      * cross-encoder success: the target is shortlisted AND reranked to position 1;
      * oracle ("perfect") success: equal to the recall, since an oracle reranker
        succeeds whenever the target is shortlisted.
    A query's target is its bi-encoder nearest neighbour among the documents of
    its own group.

    Args:
        seed: Seed forwarded to `load_data`; also part of the embedding cache names.

    Returns:
        dict with 'L_values' (the shortlist sizes) and, for each of the six
        '<group>_<metric>' keys, a list of rates aligned with 'L_values'.
    """
    print(f"\n{'=' * 80}")
    print(f"Running experiment with seed={seed}")
    print(f"{'=' * 80}")

    data = load_data(seed)

    n_min_docs = len(data['minority_docs'])
    n_maj_docs = len(data['majority_docs'])
    n_min_queries = len(data['minority_queries'])
    n_maj_queries = len(data['majority_queries'])

    print(f"Minority docs: {n_min_docs}, queries: {n_min_queries}")
    print(f"Majority docs: {n_maj_docs}, queries: {n_maj_queries}")

    # Minority documents occupy global indices [0, n_min_docs); majority follow.
    all_docs = data['minority_docs'] + data['majority_docs']

    print("\nLoading models...")
    biencoder = SentenceTransformer(CONFIG['biencoder_model'])
    crossencoder = CrossEncoder(CONFIG['crossencoder_model'])

    print("\nComputing embeddings...")
    cache_suffix = f"_seed{seed}_nmin{CONFIG['minority_size']}"

    emb_all_docs = get_embeddings(biencoder, all_docs, f"all_docs{cache_suffix}")
    emb_min_queries = get_embeddings(biencoder, data['minority_queries'], f"min_queries{cache_suffix}")
    emb_maj_queries = get_embeddings(biencoder, data['majority_queries'], f"maj_queries{cache_suffix}")

    print("\nComputing ground truth pairings...")
    emb_min_docs_only = emb_all_docs[:n_min_docs]
    baseline_sim = cosine_similarity(emb_min_queries, emb_min_docs_only)
    # Index within minority docs, which is also the global index.
    min_ground_truth = np.argmax(baseline_sim, axis=1)

    emb_maj_docs_only = emb_all_docs[n_min_docs:]
    maj_baseline_sim = cosine_similarity(emb_maj_queries, emb_maj_docs_only)
    maj_ground_truth = np.argmax(maj_baseline_sim, axis=1) + n_min_docs  # Offset to global index

    # The top-L list is a prefix of the top-max(L) list, so retrieval is done once
    # per query at the largest depth and sliced for every smaller shortlist size.
    max_L = max(CONFIG['shortlist_sizes'], default=0)
    min_shortlists = _shortlists_for(emb_min_queries, emb_all_docs, max_L)
    maj_shortlists = _shortlists_for(emb_maj_queries, emb_all_docs, max_L)

    results = {
        'L_values': CONFIG['shortlist_sizes'],
        'minority_biencoder_recall': [],
        'majority_biencoder_recall': [],
        'minority_crossencoder_success': [],
        'majority_crossencoder_success': [],
        'minority_perfect_success': [],
        'majority_perfect_success': [],
    }

    for L in tqdm(CONFIG['shortlist_sizes'], desc="Shortlist size sweep"):
        min_bienc_hits, min_crossenc_hits = _count_hits(
            crossencoder, data['minority_queries'], min_shortlists, min_ground_truth, all_docs, L)
        maj_bienc_hits, maj_crossenc_hits = _count_hits(
            crossencoder, data['majority_queries'], maj_shortlists, maj_ground_truth, all_docs, L)

        results['minority_biencoder_recall'].append(min_bienc_hits / n_min_queries)
        results['minority_crossencoder_success'].append(min_crossenc_hits / n_min_queries)
        # Oracle success coincides with shortlist recall by construction.
        results['minority_perfect_success'].append(min_bienc_hits / n_min_queries)
        results['majority_biencoder_recall'].append(maj_bienc_hits / n_maj_queries)
        results['majority_crossencoder_success'].append(maj_crossenc_hits / n_maj_queries)
        results['majority_perfect_success'].append(maj_bienc_hits / n_maj_queries)

    return results


def aggregate_results(all_results):
    """Aggregate per-seed result dicts into mean, std and 95% CI per shortlist size.

    Args:
        all_results: Non-empty list of dicts as returned by `run_experiment`, all
            sharing the same 'L_values'.

    Returns:
        dict with 'L_values' plus, for every metric, three arrays aligned with
        'L_values':
            '<metric>_mean' - mean over runs
            '<metric>_std'  - population standard deviation over runs (ddof=0)
            '<metric>_ci95' - half-width of a normal-approximation 95% interval
                              for the mean, 1.96 * s / sqrt(n), where s is the
                              sample standard deviation (ddof=1); zero for a
                              single run

    Raises:
        IndexError: if `all_results` is empty.
    """
    L_values = all_results[0]['L_values']
    n_runs = len(all_results)

    metrics = ['minority_biencoder_recall', 'majority_biencoder_recall',
               'minority_crossencoder_success', 'majority_crossencoder_success',
               'minority_perfect_success', 'majority_perfect_success']

    aggregated = {'L_values': L_values}

    for metric in metrics:
        values = np.array([r[metric] for r in all_results])
        aggregated[f'{metric}_mean'] = np.mean(values, axis=0)
        aggregated[f'{metric}_std'] = np.std(values, axis=0)

        # The standard error of the mean needs the sample (Bessel-corrected)
        # deviation; the population one understates it for a handful of seeds.
        # With a single run there is no spread to estimate, so report zero
        # instead of the NaN that ddof=1 would produce.
        if n_runs > 1:
            sample_std = np.std(values, axis=0, ddof=1)
        else:
            sample_std = np.zeros(values.shape[1:])
        aggregated[f'{metric}_ci95'] = 1.96 * sample_std / np.sqrt(n_runs)

    return aggregated


def plot_results(aggregated):
    """Plot success rate @ 1 against shortlist size and export the figure.

    Oracle-reranker curves are drawn solid with a shaded 95% CI band, and
    cross-encoder curves dashed without a band. The figure is written to the
    current working directory as 'shortlist_bottleneck.pdf' and
    'shortlist_bottleneck.tex', then displayed.

    Args:
        aggregated: Output of `aggregate_results`; needs at least two shortlist sizes.

    Returns:
        The matplotlib Figure.
    """
    shortlist_sizes = np.array(aggregated['L_values'])

    fig, ax = plt.subplots(figsize=(10, 6))

    color_maj = '#1f77b4'  # Blue
    color_min = '#d62728'  # Red

    # Oracle curves: majority first, then minority, each followed by its CI band.
    oracle_curves = (
        ('majority_perfect_success', color_maj, 'Majority (Oracle Reranker)'),
        ('minority_perfect_success', color_min, 'Minority (Oracle Reranker)'),
    )
    for metric, color, label in oracle_curves:
        mean = aggregated[f'{metric}_mean']
        half_width = aggregated[f'{metric}_ci95']
        ax.plot(shortlist_sizes, mean,
                'o-', color=color, linewidth=2, markersize=8,
                label=label)
        ax.fill_between(shortlist_sizes,
                        mean - half_width,
                        mean + half_width,
                        alpha=0.2, color=color)

    # Cross-encoder curves, same group order, no band.
    crossencoder_curves = (
        ('majority_crossencoder_success', color_maj, 'Majority (Cross-Encoder)'),
        ('minority_crossencoder_success', color_min, 'Minority (Cross-Encoder)'),
    )
    for metric, color, label in crossencoder_curves:
        ax.plot(shortlist_sizes, aggregated[f'{metric}_mean'],
                's--', color=color, linewidth=1.5, markersize=6, alpha=0.7,
                label=label)

    ax.set_xscale('log')
    ax.set_xlabel(r'Shortlist Size ($L$)')
    ax.set_ylabel('Success Rate @ 1')
    ax.set_ylim(0, 1.05)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

    # Shade the span between the two smallest shortlist sizes and label it at
    # their geometric mean, which is the visual centre on the log axis.
    smallest, second_smallest = shortlist_sizes[0], shortlist_sizes[1]
    ax.axvspan(smallest, second_smallest, alpha=0.1, color='red')
    ax.text(np.sqrt(smallest * second_smallest), 0.15, 'Retrieval\nBottleneck', fontsize=10,
            ha='center', color='darkred', style='italic')

    plt.tight_layout()

    output_dir = Path.cwd()
    pdf_path = output_dir / 'shortlist_bottleneck.pdf'
    plt.savefig(pdf_path, dpi=300, bbox_inches='tight')

    # Re-apply the solid style to solid lines before the TikZ export.
    # NOTE(review): presumably a tikzplotlib compatibility workaround - confirm.
    solid_lines = [curve for curve in ax.lines if curve.get_linestyle() == '-']
    for curve in solid_lines:
        curve.set_linestyle('-')
    tex_path = output_dir / 'shortlist_bottleneck.tex'
    tikzplotlib.save(tex_path,
                     axis_height='6cm', axis_width='10cm')

    plt.show()

    return fig


def print_summary(aggregated):
    """Print a table of oracle and cross-encoder success (mean±95% CI) per shortlist size.

    Args:
        aggregated: Output of `aggregate_results`.
    """
    column_metrics = (
        'minority_perfect_success',
        'minority_crossencoder_success',
        'majority_perfect_success',
        'majority_crossencoder_success',
    )

    print(f"\n{'L':>6} | {'Min Perfect':>12} | {'Min CrossEnc':>12} | {'Maj Perfect':>12} | {'Maj CrossEnc':>12}")
    print("-" * 80)

    for row, L in enumerate(aggregated['L_values']):
        cells = []
        for metric in column_metrics:
            mean = aggregated[f'{metric}_mean'][row]
            ci95 = aggregated[f'{metric}_ci95'][row]
            cells.append(f"{mean:>10.3f}±{ci95:.3f}")
        print(f"{L:>6} | " + " | ".join(cells))


if __name__ == "__main__":
    # Run the full experiment once per configured seed, then summarise across seeds.
    all_results = [run_experiment(seed) for seed in CONFIG['seeds']]

    aggregated = aggregate_results(all_results)
    print_summary(aggregated)
    plot_results(aggregated)

    # Persist the aggregated metrics alongside the embedding caches.
    save_cache('aggregated_results', aggregated)

