import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer



def format_token(token):
    """Return ``token`` with non-printable characters rendered as escapes.

    The string is passed through ``repr`` and the surrounding quote
    characters are stripped. For example, a newline becomes the two
    characters backslash and ``n``, so whitespace-only tokens stay visible
    in plot labels.
    """
    quoted = repr(token)
    return quoted[1:len(quoted) - 1]


def load_bigram_stats(file_path):
    """Load bigram statistics from a JSON file.

    Args:
        file_path: Path (str or ``pathlib.Path``) to the JSON stats file.

    Returns:
        The decoded JSON object, as returned by ``json.load``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def plot_bigram_frequencies(stats, tokenizer, save_path=None):
    """Plot bigram statistics as up to three figures and optionally save them.

    Figures produced:
      1. Bar chart of the top bigram frequencies, labelled by token-ID pair
         (``bigram_frequency.png``).
      2. Table of the same bigrams with detokenized text
         (``bigram_frequency_table.png``). Skipped when ``tokenizer`` is None
         or there are no bigrams.
      3. Bar chart of repeating vs. non-repeating percentages, annotated with
         absolute token counts (``bigram_frequency_percentage.pdf``).

    Args:
        stats: Dict loaded from the stats JSON. Reads ``"top_20_bigrams"``
            (entries indexed as ``entry[0] == (id1, id2)``,
            ``entry[1] == count``), ``"effective_percentage"``,
            ``"effective_masked_tokens"`` and ``"total_tokens_processed"``.
        tokenizer: Object with a ``decode`` method (e.g. a Hugging Face
            tokenizer), or None to skip the detokenized table.
        save_path: Directory (str or ``pathlib.Path``) to write figures into.
            It is created if missing. If falsy, nothing is saved.
    """
    bigrams = stats["top_20_bigrams"]

    save_dir = None
    if save_path:
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)

    _plot_frequency_bars(bigrams, save_dir)
    # Decoding needs a tokenizer; main() passes None when loading it failed.
    # matplotlib's table() also cannot be built from zero rows.
    if tokenizer is not None and bigrams:
        _plot_bigram_table(bigrams, tokenizer, save_dir)
    _plot_repeat_percentage(stats, save_dir)


def _plot_frequency_bars(bigrams, save_dir):
    """Draw (and optionally save) a bar chart of bigram counts by token-ID pair."""
    positions = range(len(bigrams))
    labels = [f"({bg[0][0]}, {bg[0][1]})" for bg in bigrams]
    frequencies = [bg[1] for bg in bigrams]

    plt.figure(figsize=(15, 8))
    plt.bar(positions, frequencies, color="steelblue", alpha=0.7)

    plt.title(f"Top {len(bigrams)} Bigram Frequencies", fontsize=16, fontweight="bold")
    plt.xlabel("Bigram (Token ID Pairs)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.xticks(positions, labels, rotation=45, ha="right")

    # Show y-axis tick values with thousands separators.
    plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _pos: f"{int(x):,}"))

    plt.tight_layout()
    plt.grid(axis="y", alpha=0.3)

    if save_dir is not None:
        plt.savefig(save_dir / "bigram_frequency.png", dpi=300, bbox_inches="tight")


def _plot_bigram_table(bigrams, tokenizer, save_dir):
    """Draw (and optionally save) a table of rank, decoded bigram text and count."""
    fig, ax = plt.subplots(figsize=(6, 8))
    ax.axis("tight")
    ax.axis("off")

    rows = []
    for rank, bigram in enumerate(bigrams):
        first = format_token(tokenizer.decode(bigram[0][0]))
        second = format_token(tokenizer.decode(bigram[0][1]))
        rows.append(
            [
                f"{rank}",
                f"({first}|{second})",
                # Thousands are separated by spaces rather than commas.
                f"{bigram[1]:,}".replace(",", " "),
            ]
        )

    col_labels = ["Rank", "Bigram ( | )", "Frequency"]
    table = ax.table(
        cellText=rows,
        colLabels=col_labels,
        cellLoc="left",
        loc="center",
        colWidths=[0.25, 0.25, 0.25],
    )
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 1.5)

    # Style the header row (row 0).
    # NOTE(review): the original colour literal was truncated in SOURCE;
    # steelblue matches the bar charts. Adjust if a specific colour was intended.
    for col in range(len(col_labels)):
        table[(0, col)].set_facecolor("steelblue")
        table[(0, col)].set_text_props(weight="bold", color="white")

    ax.set_title(f"Top {len(bigrams)} Bigrams", fontsize=16, fontweight="bold", pad=20)
    if save_dir is not None:
        fig.savefig(save_dir / "bigram_frequency_table.png", dpi=300, bbox_inches="tight")


def _plot_repeat_percentage(stats, save_dir):
    """Draw (and optionally save) the repeating vs. non-repeating percentage bars."""
    repeating_pct = stats["effective_percentage"]
    repeating_tokens = stats["effective_masked_tokens"]
    percentages = [repeating_pct, 100 - repeating_pct]
    token_counts = [
        repeating_tokens,
        stats["total_tokens_processed"] - repeating_tokens,
    ]

    plt.figure(figsize=(10, 8))
    bars = plt.bar(range(2), percentages, color="steelblue", alpha=0.7)

    plt.title(
        "Percentage of Bigrams Repeating at Least Once", fontsize=16, fontweight="bold"
    )
    plt.xlabel("Repeating vs Not Repeating", fontsize=12)
    plt.ylabel("Percentage", fontsize=12)
    plt.xticks(range(2), ["Repeating", "Not Repeating"], rotation=45, ha="right")

    plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _pos: f"{int(x):,}"))

    plt.tight_layout()
    plt.grid(axis="y", alpha=0.3)

    # Annotate each bar with the absolute token count it represents.
    for bar, count in zip(bars, token_counts):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{count:,}",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    if save_dir is not None:
        plt.savefig(
            save_dir / "bigram_frequency_percentage.pdf", dpi=300, bbox_inches="tight"
        )


def print_summary_stats(stats):
    """Print a human-readable summary of the bigram statistics to stdout.

    Args:
        stats: Dict loaded from the stats JSON. Reads ``"sequences_processed"``,
            ``"sequence_length"``, ``"total_tokens_processed"``,
            ``"effective_masked_tokens"``, ``"effective_percentage"`` and
            ``"top_20_bigrams"`` (entries indexed as ``entry[0] == pair``,
            ``entry[1] == count``).

    An empty bigram list or a zero ``effective_masked_tokens`` count is
    reported in the output instead of raising IndexError or ZeroDivisionError.
    """
    rule = "=" * 60
    print(rule)
    print("BIGRAM STATISTICS SUMMARY")
    print(rule)
    print(f"Sequences processed: {stats['sequences_processed']:,}")
    print(f"Sequence length: {stats['sequence_length']:,}")
    print(f"Total tokens processed: {stats['total_tokens_processed']:,}")
    print(f"Effective masked tokens: {stats['effective_masked_tokens']:,}")
    print(f"Effective percentage: {stats['effective_percentage']:.2f}%")
    print()

    bigrams = stats["top_20_bigrams"]
    if not bigrams:
        # Nothing to rank; indexing bigrams[0] below would raise.
        print("TOP BIGRAM STATISTICS: no bigrams recorded")
        print(rule)
        return

    frequencies = [bg[1] for bg in bigrams]

    print("TOP BIGRAM STATISTICS:")
    print(f"Most frequent bigram: {bigrams[0][0]} with {bigrams[0][1]:,} occurrences")
    print(
        f"Least frequent (in top 20): {bigrams[-1][0]} with {bigrams[-1][1]:,} occurrences"
    )
    print(f"Average frequency (top 20): {np.mean(frequencies):,.0f}")
    print(f"Median frequency (top 20): {np.median(frequencies):,.0f}")
    print(f"Standard deviation: {np.std(frequencies):,.0f}")
    print()

    total_top = sum(frequencies)
    print(f"Top 20 bigrams account for {total_top:,} tokens")
    effective = stats["effective_masked_tokens"]
    if effective:
        print(
            f"Percentage of effective masked tokens: {(total_top / effective) * 100:.2f}%"
        )
    else:
        print("Percentage of effective masked tokens: n/a (no effective masked tokens)")
    print(rule)


def main():
    """Load the saved bigram stats, print a summary and render the plots.

    Returns early with an error message if the stats file is missing. If the
    Pythia tokenizer cannot be loaded, a warning is printed and plotting
    continues with ``tokenizer=None``.
    """
    output_dir = Path("pythia_replicate/code_testing/bigram_frequency")
    stats_file = output_dir / "bigram_stats_5000000.json"

    if not stats_file.exists():
        print(f"Error: Stats file not found at {stats_file}")
        return

    stats = load_bigram_stats(stats_file)

    print("Loading tokenizer...")
    tokenizer = None
    try:
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
    except Exception as err:
        # Best-effort: the numeric summary and frequency plots need no tokenizer.
        print(f"Warning: Could not load tokenizer: {err}")
        print("Proceeding without detokenized plot...")

    print_summary_stats(stats)

    print("Creating bigram frequency plot...")
    plot_bigram_frequencies(stats, tokenizer, save_path=output_dir)

    print("Plotting complete! Check the generated plots.")


# Run the script entry point only when executed directly, not on import.
if __name__ == "__main__":
    main()
