"""
t-SNE Visualization: Subsample Convergence to Full Distribution

Shows how random subsamples progressively "fill in" the token embedding space.
- Base: t-SNE computed on N tokens (background in dark blue)
- Overlay: Independent subsamples of 3k, 10k highlighted with lighter colors

Non-nested sampling: each subsample is drawn independently.

This visualization demonstrates uniform convergence: as n grows, subsamples
increasingly cover the full support of the token distribution.
"""
import sys
from pathlib import Path

# Add parent directories to path for imports
BASE_DIR = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(BASE_DIR))

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from sklearn.manifold import TSNE
import pickle
from tqdm import tqdm

# Project imports
from sample_complexity.utils import (
    load_model,
    load_tokenizer,
    get_layer0_embeddings_direct,
    device,
    load_text
)

# ============================================================
# CONFIGURATION
# ============================================================
# All inputs/outputs of this script live next to the script itself.
TSNE_DIR = Path(__file__).parent
DATA_DIR = TSNE_DIR.joinpath("data")        # cached t-SNE coordinates
OUTPUT_DIR = TSNE_DIR.joinpath("figures")   # rendered PNG/PDF figures

# Number of tokens fed into the t-SNE computation (the "full" background cloud).
# Cost grows quickly with n: exact t-SNE is O(n^2) and scikit-learn's default
# Barnes-Hut approximation is O(n log n), so in practice 50k is fine, 100k is
# slow and 500k is infeasible.
TSNE_BASE_SIZE = 50_000

# Sizes of the highlighted subsamples. Each one is drawn INDEPENDENTLY from the
# base set (not nested) and must not exceed TSNE_BASE_SIZE.
SUBSAMPLE_SIZES = [3_000, 10_000]

# Hyper-parameters forwarded to sklearn.manifold.TSNE.
TSNE_PERPLEXITY = 50
TSNE_MAX_ITER = 1000
RANDOM_SEED = 42

# Which model size to load.
MODEL_SIZE = "base"

# Pickle holding the computed 2-D coordinates; the file name encodes the base size.
CACHE_FILE = DATA_DIR.joinpath(f"tsne_base_{TSNE_BASE_SIZE}.pkl")

# Marker colours keyed by sample size, going from light (sparse) to dark (dense).
COLORS = {
    3_000: '#C7E9F1',   # light cyan  - sparsest subsample
    10_000: '#0077B6',  # ocean blue  - medium subsample
    50_000: '#001845',  # very dark blue - full base set
}


# ============================================================
# EMBEDDING LOADING
# ============================================================
def load_embeddings(n_tokens: int) -> np.ndarray:
    """Embed up to ``n_tokens`` tokens of English "wiki" text with BigBird (layer 0).

    The same ``n_tokens`` value is used as the text-loading limit, the model's
    ``max_length`` and the tokenizer truncation length, so the tokenizer may
    yield fewer tokens than requested; the actual count is printed.

    Args:
        n_tokens: Upper bound on the number of tokens to embed.

    Returns:
        The layer-0 embeddings as a NumPy array with the leading dimension
        squeezed away (presumably shape ``(tokens, hidden_dim)`` - verify
        against ``get_layer0_embeddings_direct``).
    """
    print(f"Loading text for {n_tokens:,} tokens...")
    corpus = load_text(max_length=n_tokens, source="wiki", language="en", verbose=True)

    print("Loading model...")
    bigbird = load_model(model_name="BigBird", model_size=MODEL_SIZE, max_length=n_tokens)

    print("Tokenizing...")
    bigbird_tokenizer = load_tokenizer("BigBird", model_size=MODEL_SIZE)
    encoded = bigbird_tokenizer(
        corpus, return_tensors="pt", truncation=True, max_length=n_tokens
    )
    actual_n = encoded['input_ids'].shape[1]
    print(f"Got {actual_n:,} tokens")

    print("Extracting embeddings (layer 0)...")
    # Move every tokenizer output onto whatever device the model weights live on.
    target_device = next(bigbird.parameters()).device
    model_inputs = {}
    for field, tensor in encoded.items():
        model_inputs[field] = tensor.to(target_device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = get_layer0_embeddings_direct(bigbird, model_inputs)

    # Drop the leading (batch) axis and bring the result back to host memory.
    unbatched = embeddings.squeeze(0)
    embeddings_np = unbatched.cpu().numpy()
    print(f"Embeddings shape: {embeddings_np.shape}")
    return embeddings_np


# ============================================================
# t-SNE COMPUTATION
# ============================================================
def compute_or_load_tsne():
    """Return 2-D t-SNE coordinates, reusing the on-disk cache when it is still valid.

    The cache file name only encodes ``TSNE_BASE_SIZE``, so the t-SNE
    parameters stored inside the pickle are compared against the current
    configuration; a cache produced with a different perplexity, iteration
    count or random seed is ignored and recomputed instead of being silently
    reused. Parameters missing from the pickle are not treated as a mismatch.

    Returns:
        tuple: ``(embeddings_2d, n_points)`` where ``embeddings_2d`` is the
        array returned by ``TSNE.fit_transform`` and ``n_points`` is the number
        of embedded tokens.
    """
    current_params = {
        'TSNE_PERPLEXITY': TSNE_PERPLEXITY,
        'TSNE_MAX_ITER': TSNE_MAX_ITER,
        'RANDOM_SEED': RANDOM_SEED,
    }

    if CACHE_FILE.exists():
        print(f"Loading cached t-SNE from {CACHE_FILE}")
        # NOTE: unpickling can execute arbitrary code; CACHE_FILE must only ever
        # be a file written by this script, never one from an untrusted source.
        with open(CACHE_FILE, 'rb') as f:
            data = pickle.load(f)

        # Map each differing parameter to (cached value, current value).
        mismatched = {
            key: (data[key], value)
            for key, value in current_params.items()
            if key in data and data[key] != value
        }
        if not mismatched:
            return data['embeddings_2d'], data['n_points']
        print(f"Cached t-SNE parameters differ from current config "
              f"(cached, current): {mismatched}; recomputing")

    # Load embeddings
    embeddings = load_embeddings(TSNE_BASE_SIZE)
    n_points = len(embeddings)

    # Run t-SNE
    print(f"\nRunning t-SNE on {n_points:,} points...")
    print(f"Parameters: perplexity={TSNE_PERPLEXITY}, max_iter={TSNE_MAX_ITER}")

    tsne = TSNE(
        n_components=2,
        perplexity=TSNE_PERPLEXITY,
        max_iter=TSNE_MAX_ITER,
        random_state=RANDOM_SEED,
        init='pca',
        learning_rate='auto',
        verbose=1,
        n_jobs=-1,
    )

    embeddings_2d = tsne.fit_transform(embeddings)
    print(f"t-SNE complete! Shape: {embeddings_2d.shape}")

    # Cache. Write to a temporary sibling first and rename it into place so an
    # interrupted run cannot leave a truncated pickle that breaks later runs.
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    tmp_file = CACHE_FILE.with_name(CACHE_FILE.name + '.tmp')
    with open(tmp_file, 'wb') as f:
        pickle.dump({
            'embeddings_2d': embeddings_2d,
            'n_points': n_points,
            **current_params,
        }, f)
    tmp_file.replace(CACHE_FILE)
    print(f"Cached to {CACHE_FILE}")

    return embeddings_2d, n_points


# ============================================================
# PLOTTING
# ============================================================
def _save_figure(fig, output_dir: Path, stem: str) -> None:
    """Write ``fig`` to ``<stem>.png`` (200 dpi) and ``<stem>.pdf`` inside ``output_dir``."""
    out_png = output_dir / f'{stem}.png'
    out_pdf = output_dir / f'{stem}.pdf'
    fig.savefig(out_png, dpi=200, bbox_inches='tight', facecolor='white')
    fig.savefig(out_pdf, bbox_inches='tight', facecolor='white')
    print(f"Saved: {out_png}")


def plot_convergence(embeddings_2d: np.ndarray, n_points: int, output_dir: Path):
    """Create the convergence visualization in three versions.

    All ``n_points`` t-SNE coordinates are drawn as a dark background cloud and
    an independent random subsample is overlaid for each entry of
    ``SUBSAMPLE_SIZES`` (sizes larger than ``n_points`` are skipped with a
    warning). Three variants are saved as PNG and PDF:

    * ``tsne_subsample_convergence``    - axis labels and legend
    * ``tsne_subsample_convergence_v2`` - bare scatter, no axes and no legend
    * ``tsne_subsample_convergence_v3`` - no axes, large "<n> tokens" legend

    Args:
        embeddings_2d: Array of 2-D coordinates; columns 0 and 1 are plotted.
        n_points: Number of rows of ``embeddings_2d`` available for sampling.
        output_dir: Directory receiving the figures (created if missing).

    Note:
        Reseeds NumPy's global random generator with ``RANDOM_SEED + 1`` so the
        subsamples are reproducible between runs.
    """
    np.random.seed(RANDOM_SEED + 1)  # Different seed for subsampling
    output_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 10))

    # (size, colour) of every layer actually drawn, in draw order; the
    # version-3 legend is built from this so it always matches the figure.
    drawn = []

    # Draw the full base set - BIGGEST points, dark blue
    base_color = COLORS.get(TSNE_BASE_SIZE, '#001845')
    ax.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=base_color,
        s=50,
        alpha=0.8,
        label=f'N = {n_points:,} (limit)',
        rasterized=True,
    )
    drawn.append((n_points, base_color))

    # Draw subsamples from largest to smallest (so smaller ones are on top)
    for n in sorted(SUBSAMPLE_SIZES, reverse=True):
        if n > n_points:
            print(f"Warning: subsample size {n:,} > available points {n_points:,}, skipping")
            continue

        # Independent random sample (without replacement within one subsample)
        indices = np.random.choice(n_points, size=n, replace=False)
        points = embeddings_2d[indices]

        color = COLORS.get(n, '#888888')

        # Smaller subsamples get smaller markers; opacity is the same for all.
        if n <= 3000:
            size = 8
        elif n <= 10000:
            size = 15
        else:
            size = 25

        ax.scatter(
            points[:, 0],
            points[:, 1],
            c=color,
            s=size,
            alpha=0.5,
            label=f'n = {n:,}',
            rasterized=True,
        )
        drawn.append((n, color))

    # Styling
    ax.set_xlabel('t-SNE dimension 1', fontsize=12)
    ax.set_ylabel('t-SNE dimension 2', fontsize=12)

    # Legend: reverse the draw order so entries run from smallest to largest
    handles, labels = ax.get_legend_handles_labels()
    legend_v1 = ax.legend(handles[::-1], labels[::-1], loc='upper right', fontsize=10,
                          framealpha=0.9, markerscale=1.5)

    ax.set_aspect('equal', adjustable='datalim')
    ax.grid(True, alpha=0.2)
    ax.set_xticks([])
    ax.set_yticks([])

    fig.tight_layout()

    # Version 1: with labels and legend
    _save_figure(fig, output_dir, 'tsne_subsample_convergence')

    # Version 2: clean (no axes, no legend, no box)
    ax.set_xlabel('')
    ax.set_ylabel('')
    legend_v1.set_visible(False)
    ax.grid(False)
    ax.axis('off')
    _save_figure(fig, output_dir, 'tsne_subsample_convergence_v2')

    # Version 3: no axes/box, but WITH legend using cleaner labels.
    # Proxy handles give uniform, large legend markers independent of the
    # scatter marker sizes; entries run from smallest to largest sample.
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=layer_color,
               markersize=18, label=f'{layer_size} tokens')
        for layer_size, layer_color in reversed(drawn)
    ]
    # Creating a new legend replaces the hidden version-1 legend on the axes.
    ax.legend(
        handles=legend_elements,
        loc='upper left',
        bbox_to_anchor=(0.8, 1),
        fontsize=28,
        framealpha=0.9,
    )
    _save_figure(fig, output_dir, 'tsne_subsample_convergence_v3')

    plt.close(fig)


# ============================================================
# MAIN
# ============================================================
def main():
    """Script entry point: print the run configuration, then build the figures.

    Ensures ``OUTPUT_DIR`` exists, obtains the t-SNE coordinates (computed or
    loaded from cache) and hands them to ``plot_convergence``.
    """
    rule = "=" * 60
    print(rule, "t-SNE Subsample Convergence Visualization", rule, sep="\n")
    print(f"\nBase size: {TSNE_BASE_SIZE:,} tokens")
    print(f"Subsample sizes: {[f'{n:,}' for n in SUBSAMPLE_SIZES]}")
    print()

    # The figures directory must exist before anything is saved into it.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Coordinates come from the cache when available, otherwise from a fresh run.
    coords_2d, n_available = compute_or_load_tsne()
    plot_convergence(coords_2d, n_available, OUTPUT_DIR)

    print("\nDone!")

# Run the full pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
