{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8188900",
   "metadata": {},
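   "source": [
    "# Bigram repetition in Pythia training chunks\n",
    "\n",
    "Splits training sequences into non-overlapping fixed-size chunks (128–1024 tokens) and measures how often bigrams repeat within each chunk."
   ]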
  },
  {
   "cell_type": "markdown",
   "id": "19732d2b",
   "metadata": {},
   "source": [
    "## Count how many bigrams repeat within each chunk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "373090f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "import numpy as np\n",
    "from collections import Counter, defaultdict\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Add your path - adjust this to your actual path\n",
    "sys.path.append(\"~pythia_replicate\")\n",
    "\n",
    "from mmap_dataset_lightning import setup_pythia_data\n",
    "\n",
    "\n",
    "def analyze_bigram_repetition_distribution(dataloader, num_samples=50000, chunk_sizes=[512], max_bigram_count=10):\n",
    "    \"\"\"\n",
    "    Analyze the distribution of repeating bigrams in chunks.\n",
    "    Count how many chunks have 0, 1, 2, ... up to max_bigram_count+ repeating bigrams.\n",
    "    \n",
    "    Args:\n",
    "        dataloader: The data loader\n",
    "        num_samples: Number of chunks to analyze\n",
    "        chunk_sizes: List of chunk sizes to test\n",
    "        max_bigram_count: Maximum number of bigrams to track individually (higher counts grouped as \"max+\")\n",
    "    \"\"\"\n",
    "    \n",
    "    results = {}\n",
    "    \n",
    "    for chunk_size in chunk_sizes:\n",
    "        print(f\"\\n{'='*60}\")\n",
    "        print(f\"Analyzing {chunk_size}-token CHUNKS...\")\n",
    "        print(f\"{'='*60}\")\n",
    "        \n",
    "        chunks_processed = 0\n",
    "        \n",
    "        # Distribution of number of repeating bigrams\n",
    "        repetition_distribution = defaultdict(int)  # key: number of repeating bigrams, value: count\n",
    "        \n",
    "        # Additional statistics\n",
    "        total_repeated_bigrams = []  # Total number of repeated bigrams per chunk\n",
    "        unique_repeated_bigrams = []  # Number of unique bigrams that repeat\n",
    "        max_repeat_per_chunk = []  # Maximum times any single bigram repeats\n",
    "        most_repeated_bigrams = Counter()  # Track which bigrams repeat most often\n",
    "        \n",
    "        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f\"Processing (chunk_size={chunk_size})\")):\n",
    "            if chunks_processed >= num_samples:\n",
    "                break\n",
    "            \n",
    "            # Get input_ids\n",
    "            input_ids = batch[\"input_ids\"]\n",
    "            batch_size, seq_len = input_ids.shape\n",
    "            \n",
    "            for seq_idx in range(batch_size):\n",
    "                if chunks_processed >= num_samples:\n",
    "                    break\n",
    "                \n",
    "                # Get individual sequence\n",
    "                sequence = input_ids[seq_idx].numpy()\n",
    "                \n",
    "                # CREATE CHUNKS - process all non-overlapping chunks\n",
    "                for chunk_start in range(0, seq_len - chunk_size + 1, chunk_size):\n",
    "                    if chunks_processed >= num_samples:\n",
    "                        break\n",
    "                    \n",
    "                    # Extract chunk\n",
    "                    chunk = sequence[chunk_start:chunk_start + chunk_size]\n",
    "                    \n",
    "                    # Skip incomplete chunks\n",
    "                    if len(chunk) != chunk_size:\n",
    "                        continue\n",
    "                    \n",
    "                    # Create bigrams\n",
    "                    bigrams = [(int(chunk[i]), int(chunk[i + 1])) \n",
    "                              for i in range(len(chunk) - 1)]\n",
    "                    bigram_counts = Counter(bigrams)\n",
    "                    \n",
    "                    # Count repeating bigrams\n",
    "                    # A bigram is \"repeating\" if it appears more than once\n",
    "                    repeating_bigrams = {bigram: count for bigram, count in bigram_counts.items() if count > 1}\n",
    "                    \n",
    "                    # Total number of repeated bigram instances (counting all occurrences beyond the first)\n",
    "                    total_repetitions = sum(count - 1 for count in repeating_bigrams.values())\n",
    "                    \n",
    "                    # Update distribution\n",
    "                    # Cap at max_bigram_count+ for display purposes\n",
    "                    repetition_key = min(total_repetitions, max_bigram_count)\n",
    "                    repetition_distribution[repetition_key] += 1\n",
    "                    \n",
    "                    # Track additional statistics\n",
    "                    total_repeated_bigrams.append(total_repetitions)\n",
    "                    unique_repeated_bigrams.append(len(repeating_bigrams))\n",
    "                    \n",
    "                    # Maximum repeat count for any single bigram\n",
    "                    if repeating_bigrams:\n",
    "                        max_repeat = max(repeating_bigrams.values())\n",
    "                        max_repeat_per_chunk.append(max_repeat)\n",
    "                    else:\n",
    "                        max_repeat_per_chunk.append(0)\n",
    "                    \n",
    "                    # Track most repeated bigrams globally\n",
    "                    for bigram, count in repeating_bigrams.items():\n",
    "                        most_repeated_bigrams[bigram] += count - 1\n",
    "                    \n",
    "                    chunks_processed += 1\n",
    "        \n",
    "        # Calculate percentages for distribution\n",
    "        distribution_percentages = {}\n",
    "        for i in range(max_bigram_count + 1):  # 0 to max_bigram_count+\n",
    "            count = repetition_distribution.get(i, 0)\n",
    "            percentage = 100.0 * count / chunks_processed if chunks_processed > 0 else 0\n",
    "            distribution_percentages[i] = {\n",
    "                'count': count,\n",
    "                'percentage': percentage\n",
    "            }\n",
    "        \n",
    "        # Calculate aggregate statistics\n",
    "        results[chunk_size] = {\n",
    "            'chunks_analyzed': chunks_processed,\n",
    "            'repetition_distribution': distribution_percentages,\n",
    "            'avg_total_repetitions': np.mean(total_repeated_bigrams) if total_repeated_bigrams else 0,\n",
    "            'median_total_repetitions': np.median(total_repeated_bigrams) if total_repeated_bigrams else 0,\n",
    "            'std_total_repetitions': np.std(total_repeated_bigrams) if total_repeated_bigrams else 0,\n",
    "            'avg_unique_repeated': np.mean(unique_repeated_bigrams) if unique_repeated_bigrams else 0,\n",
    "            'avg_max_repeat': np.mean(max_repeat_per_chunk) if max_repeat_per_chunk else 0,\n",
    "            'max_repeat_overall': max(max_repeat_per_chunk) if max_repeat_per_chunk else 0,\n",
    "            'top_repeated_bigrams': most_repeated_bigrams.most_common(10),\n",
    "            'percentile_95': np.percentile(total_repeated_bigrams, 95) if total_repeated_bigrams else 0,\n",
    "            'percentile_99': np.percentile(total_repeated_bigrams, 99) if total_repeated_bigrams else 0\n",
    "        }\n",
    "    \n",
    "    return results\n",
    "\n",
    "\n",
    "def print_distribution_results(results, tokenizer=None, max_bigram_count=10):\n",
    "    \"\"\"Pretty print the distribution analysis results.\n",
    "    \n",
    "    Args:\n",
    "        results: Analysis results dictionary\n",
    "        tokenizer: Optional tokenizer for decoding bigrams\n",
    "        max_bigram_count: Maximum number of bigrams tracked individually\n",
    "    \"\"\"\n",
    "    \n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(f\"BIGRAM REPETITION DISTRIBUTION ANALYSIS\")\n",
    "    print(f\"{'='*80}\")\n",
    "    \n",
    "    for chunk_size in sorted(results.keys()):\n",
    "        stats = results[chunk_size]\n",
    "        \n",
    "        print(f\"\\n{'='*60}\")\n",
    "        print(f\"📊 RESULTS FOR CHUNK SIZE: {chunk_size} tokens\")\n",
    "        print(f\"{'='*60}\")\n",
    "        print(f\"\\nTotal chunks analyzed: {stats['chunks_analyzed']:,}\")\n",
    "        \n",
    "        # Distribution table\n",
    "        print(f\"\\n📈 DISTRIBUTION OF REPEATING BIGRAMS:\")\n",
    "        print(f\"{'─'*50}\")\n",
    "        print(f\"{'# Repeating':<15} | {'Count':<10} | {'Percentage':<12} | {'Cumulative %':<12}\")\n",
    "        print(f\"{'─'*50}\")\n",
    "        \n",
    "        cumulative = 0\n",
    "        distribution = stats['repetition_distribution']\n",
    "        \n",
    "        for i in range(max_bigram_count + 1):\n",
    "            if i in distribution:\n",
    "                count = distribution[i]['count']\n",
    "                percentage = distribution[i]['percentage']\n",
    "                cumulative += percentage\n",
    "                \n",
    "                label = f\"{i} bigrams\" if i < max_bigram_count else f\"{max_bigram_count}+ bigrams\"\n",
    "                print(f\"{label:<15} | {count:<10,} | {percentage:>10.2f}% | {cumulative:>11.2f}%\")\n",
    "        \n",
    "        # Summary statistics\n",
    "        print(f\"\\n📊 SUMMARY STATISTICS:\")\n",
    "        print(f\"{'─'*50}\")\n",
    "        print(f\"  • Average repetitions per chunk: {stats['avg_total_repetitions']:.2f}\")\n",
    "        print(f\"  • Median repetitions per chunk: {stats['median_total_repetitions']:.2f}\")\n",
    "        print(f\"  • Std dev of repetitions: {stats['std_total_repetitions']:.2f}\")\n",
    "        print(f\"  • 95th percentile: {stats['percentile_95']:.0f} repetitions\")\n",
    "        print(f\"  • 99th percentile: {stats['percentile_99']:.0f} repetitions\")\n",
    "        print(f\"  • Average unique bigrams that repeat: {stats['avg_unique_repeated']:.2f}\")\n",
    "        print(f\"  • Average max repeat for any bigram: {stats['avg_max_repeat']:.2f}\")\n",
    "        print(f\"  • Highest repetition count observed: {stats['max_repeat_overall']}\")\n",
    "        \n",
    "        # Most repeated bigrams\n",
    "        if tokenizer and stats['top_repeated_bigrams']:\n",
    "            print(f\"\\n🔄 TOP 5 MOST FREQUENTLY REPEATED BIGRAMS:\")\n",
    "            print(f\"{'─'*50}\")\n",
    "            for (t1, t2), excess_count in stats['top_repeated_bigrams'][:5]:\n",
    "                try:\n",
    "                    combined = tokenizer.decode([t1, t2])\n",
    "                    print(f\"  • Token IDs ({t1}, {t2}) → '{combined}': +{excess_count} repetitions\")\n",
    "                except:\n",
    "                    print(f\"  • Token IDs ({t1}, {t2}): +{excess_count} repetitions\")\n",
    "        \n",
    "        # Highlight key findings\n",
    "        print(f\"\\n💡 KEY FINDINGS:\")\n",
    "        print(f\"{'─'*50}\")\n",
    "        zero_repeats = distribution.get(0, {}).get('percentage', 0)\n",
    "        one_repeat = distribution.get(1, {}).get('percentage', 0)\n",
    "        two_repeats = distribution.get(2, {}).get('percentage', 0)\n",
    "        \n",
    "        print(f\"  • {zero_repeats:.2f}% of chunks have NO repeating bigrams\")\n",
    "        print(f\"  • {one_repeat:.2f}% have exactly 1 repeating bigram\")\n",
    "        print(f\"  • {two_repeats:.2f}% have exactly 2 repeating bigrams\")\n",
    "        print(f\"  • {100 - zero_repeats:.2f}% have at least 1 repeating bigram\")\n",
    "\n",
    "\n",
    "def create_visualization(results, output_dir, max_bigram_count=10):\n",
    "    \"\"\"Create bar chart visualization of the distribution.\n",
    "    \n",
    "    Args:\n",
    "        results: Analysis results dictionary\n",
    "        output_dir: Directory to save visualization\n",
    "        max_bigram_count: Maximum number of bigrams tracked individually\n",
    "    \"\"\"\n",
    "    try:\n",
    "        import matplotlib.pyplot as plt\n",
    "        \n",
    "        for chunk_size, stats in results.items():\n",
    "            distribution = stats['repetition_distribution']\n",
    "            \n",
    "            # Prepare data for plotting\n",
    "            x_labels = []\n",
    "            counts = []\n",
    "            percentages = []\n",
    "            \n",
    "            for i in range(max_bigram_count + 1):\n",
    "                if i in distribution:\n",
    "                    label = str(i) if i < max_bigram_count else f\"{max_bigram_count}+\"\n",
    "                    x_labels.append(label)\n",
    "                    counts.append(distribution[i]['count'])\n",
    "                    percentages.append(distribution[i]['percentage'])\n",
    "            \n",
    "            # Create figure with two subplots\n",
    "            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8))\n",
    "            \n",
    "            # Plot 1: Percentage distribution\n",
    "            bars1 = ax1.bar(x_labels, percentages, color='steelblue', edgecolor='black', alpha=0.7)\n",
    "            ax1.set_xlabel('Number of Non-Unique Bigrams', fontsize=12)\n",
    "            ax1.set_ylabel('Percentage of Chunks (%)', fontsize=12)\n",
    "            ax1.set_title(f'Distribution of Repeating Bigrams (Chunk Size: {chunk_size})', fontsize=14)\n",
    "            ax1.grid(axis='y', alpha=0.3)\n",
    "            \n",
    "            # Add percentage labels on bars\n",
    "            for bar, pct in zip(bars1, percentages):\n",
    "                height = bar.get_height()\n",
    "                if height > 0.5:  # Only label bars with >0.5%\n",
    "                    ax1.text(bar.get_x() + bar.get_width()/2., height,\n",
    "                            f'{pct:.1f}', ha='center', va='bottom', fontsize=9)\n",
    "            \n",
    "            # Plot 2: Cumulative distribution\n",
    "            cumulative = np.cumsum(percentages)\n",
    "            ax2.plot(x_labels, cumulative, 'r-o', linewidth=2, markersize=8)\n",
    "            ax2.fill_between(range(len(x_labels)), cumulative, alpha=0.3, color='red')\n",
    "            ax2.set_xlabel('Number of Non-Unique Bigrams (≤)', fontsize=12)\n",
    "            ax2.set_ylabel('Cumulative Percentage (%)', fontsize=12)\n",
    "            ax2.set_title(f'Cumulative Distribution (Chunk Size: {chunk_size})', fontsize=14)\n",
    "            ax2.grid(True, alpha=0.3)\n",
    "            ax2.set_ylim([0, 105])\n",
    "            \n",
    "            # Add horizontal lines at key percentiles\n",
    "            ax2.axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='50%')\n",
    "            ax2.axhline(y=90, color='gray', linestyle='--', alpha=0.5, label='90%')\n",
    "            ax2.legend()\n",
    "            \n",
    "            plt.tight_layout()\n",
    "            \n",
    "            # Save plot\n",
    "            plot_file = output_dir / f\"bigram_distribution_chunk_{chunk_size}.png\"\n",
    "            plt.savefig(plot_file, dpi=150, bbox_inches='tight')\n",
    "            print(f\"\\n📈 Visualization saved to: {plot_file}\")\n",
    "            plt.show()\n",
    "            \n",
    "    except ImportError:\n",
    "        print(\"\\n⚠️  Matplotlib not available, skipping visualization\")\n",
    "\n",
    "\n",
    "def _results_to_json(results):\n",
    "    \"\"\"Convert analysis results into plain JSON-serializable values.\n",
    "\n",
    "    Chunk sizes and bucket indices become string keys, numpy scalars become\n",
    "    int/float, and each top bigram becomes {'tokens': [t1, t2], 'excess_count': n}.\n",
    "    \"\"\"\n",
    "    json_results = {}\n",
    "    for chunk_size, stats in results.items():\n",
    "        json_results[str(chunk_size)] = {\n",
    "            'chunks_analyzed': int(stats['chunks_analyzed']),\n",
    "            'repetition_distribution': {\n",
    "                str(k): {\n",
    "                    'count': int(v['count']),\n",
    "                    'percentage': float(v['percentage'])\n",
    "                }\n",
    "                for k, v in stats['repetition_distribution'].items()\n",
    "            },\n",
    "            'avg_total_repetitions': float(stats['avg_total_repetitions']),\n",
    "            'median_total_repetitions': float(stats['median_total_repetitions']),\n",
    "            'std_total_repetitions': float(stats['std_total_repetitions']),\n",
    "            'avg_unique_repeated': float(stats['avg_unique_repeated']),\n",
    "            'avg_max_repeat': float(stats['avg_max_repeat']),\n",
    "            'max_repeat_overall': int(stats['max_repeat_overall']),\n",
    "            'percentile_95': float(stats['percentile_95']),\n",
    "            'percentile_99': float(stats['percentile_99']),\n",
    "            'top_repeated_bigrams': [\n",
    "                {'tokens': [int(t) for t in bigram], 'excess_count': int(count)}\n",
    "                for bigram, count in stats['top_repeated_bigrams']\n",
    "            ]\n",
    "        }\n",
    "    return json_results\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Run the bigram repetition analysis on the Pythia training data.\n",
    "\n",
    "    Loads the training config, analyzes chunks from the training dataloader,\n",
    "    saves the statistics to JSON, prints a report and saves one plot per chunk size.\n",
    "    \"\"\"\n",
    "    # Configuration\n",
    "    # NOTE(review): open() and Path do not expand \"~\"; these look like placeholder\n",
    "    # paths — replace them with real ones before running.\n",
    "    config_path = \"~pythia_replicate/pythia-160m.json\"\n",
    "    num_samples = 100000  # Number of CHUNKS to analyze\n",
    "    chunk_sizes = [128, 256, 512, 1024]  # Chunk sizes to test\n",
    "    max_bigram_count = 64  # Maximum number of repeating bigrams to track individually (higher counts grouped as \"max+\")\n",
    "\n",
    "    print(\"Loading configuration...\")\n",
    "    with open(config_path, \"r\") as f:\n",
    "        config = json.load(f)\n",
    "\n",
    "    print(\"Setting up data module...\")\n",
    "    config[\"indices_file\"] = None\n",
    "    config[\"num_workers\"] = 0\n",
    "    data_module = setup_pythia_data(config)\n",
    "    data_module.setup()\n",
    "\n",
    "    print(\"Creating dataloader...\")\n",
    "    train_dataloader = data_module.train_dataloader()\n",
    "\n",
    "    print(f\"Starting bigram repetition distribution analysis...\")\n",
    "    print(f\"Analyzing {num_samples} chunks of size {chunk_sizes}\")\n",
    "    print(f\"Tracking up to {max_bigram_count} repeating bigrams (higher counts grouped as '{max_bigram_count}+')\")\n",
    "\n",
    "    # Run analysis\n",
    "    results = analyze_bigram_repetition_distribution(\n",
    "        train_dataloader,\n",
    "        num_samples=num_samples,\n",
    "        chunk_sizes=chunk_sizes,\n",
    "        max_bigram_count=max_bigram_count\n",
    "    )\n",
    "\n",
    "    # Save results first, so a later failure (e.g. tokenizer download) cannot\n",
    "    # discard the long-running analysis\n",
    "    output_dir = Path(\"~pythia_replicate/code_testing/bigram_frequency\")\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    output_file = output_dir / f\"bigram_distribution_analysis_{num_samples}_max{max_bigram_count}.json\"\n",
    "\n",
    "    with open(output_file, \"w\") as f:\n",
    "        json.dump(_results_to_json(results), f, indent=2)\n",
    "\n",
    "    print(f\"\\n💾 Results saved to: {output_file}\")\n",
    "\n",
    "    # Load tokenizer for decoding (optional: the report works without it)\n",
    "    print(\"\\nLoading tokenizer for decoding...\")\n",
    "    try:\n",
    "        tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-160m\")\n",
    "    except Exception as exc:\n",
    "        print(f\"⚠️  Could not load tokenizer ({exc}); top repeated bigrams will not be listed\")\n",
    "        tokenizer = None\n",
    "\n",
    "    # Print results\n",
    "    print_distribution_results(results, tokenizer, max_bigram_count=max_bigram_count)\n",
    "\n",
    "    # Create visualization\n",
    "    create_visualization(results, output_dir, max_bigram_count=max_bigram_count)\n",
    "\n",
    "\n",
    "# Entry point: run the full analysis when this code is executed directly\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythia_replicate",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
