import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))
from mmap_dataset_lightning import setup_pythia_data
import json
import numpy as np
from collections import Counter
from tqdm import tqdm

# Load the Pythia-160M run configuration and build the data module from it.
# NOTE(review): neither open() nor pathlib expands a leading "~", so this path
# is resolved relative to the current working directory as written — confirm.
config_path = "~pythia_replicate/pythia-160m.json"
config = json.loads(Path(config_path).read_text())

data_module = setup_pythia_data(config)
data_module.setup()


def analyze_bigrams_on_dataloader(dataloader, num_samples=50000, seq_length=2048):
    """Count adjacent-token bigrams over fixed-length sequences from a dataloader.

    Iterates batches from ``dataloader`` and processes the rows of
    ``batch["input_ids"]`` one at a time until ``num_samples`` sequences have
    been counted (or the dataloader is exhausted). Rows whose length differs
    from ``seq_length`` are reported with a warning and skipped.

    For each kept sequence of ``L`` tokens, the ``L - 1`` overlapping pairs
    ``(tok[i], tok[i + 1])`` are added to a corpus-wide counter. The
    "effective masked tokens" tally adds, per sequence, the number of bigram
    occurrences that repeat a bigram already seen earlier in that same
    sequence (``sum(count - 1)`` over bigrams occurring more than once).

    Args:
        dataloader: Iterable of batches. Each batch must support
            ``batch["input_ids"]``, iterable row by row, with rows that
            ``np.asarray`` can convert (e.g. CPU torch tensors or NumPy
            arrays).
        num_samples: Maximum number of sequences to count.
        seq_length: Required length of each sequence; other lengths are skipped.

    Returns:
        dict with keys ``total_tokens_processed``, ``effective_masked_tokens``,
        ``sequences_processed``, ``sequence_length`` and
        ``most_common_bigrams`` (the 20 most frequent ``((tok1, tok2), count)``
        entries, token ids as plain Python ``int``).
    """
    sequences_processed = 0
    total_tokens_processed = 0
    effective_masked_tokens = 0
    corpus_bigram_counts = Counter()

    for batch in tqdm(dataloader, desc="Processing batches"):
        if sequences_processed >= num_samples:
            break

        for row in batch["input_ids"]:
            if sequences_processed >= num_samples:
                break

            # np.asarray accepts CPU torch tensors as well as NumPy arrays;
            # .tolist() yields plain Python ints, which hash much faster than
            # per-element NumPy scalars and serialize to JSON directly.
            tokens = np.asarray(row).tolist()

            if len(tokens) != seq_length:
                print(
                    f"Warning: Sequence length {len(tokens)} != expected {seq_length}"
                )
                continue

            total_tokens_processed += len(tokens)

            # Overlapping adjacent pairs: L tokens -> L - 1 bigrams.
            sequence_bigram_counts = Counter(zip(tokens, tokens[1:]))
            corpus_bigram_counts.update(sequence_bigram_counts)

            # Occurrences beyond the first of each distinct bigram, i.e.
            # sum(count - 1 for count > 1), without a second pass over counts.
            occurrences = sum(sequence_bigram_counts.values())
            effective_masked_tokens += occurrences - len(sequence_bigram_counts)

            sequences_processed += 1

    return {
        "total_tokens_processed": total_tokens_processed,
        "effective_masked_tokens": effective_masked_tokens,
        "sequences_processed": sequences_processed,
        "sequence_length": seq_length,
        "most_common_bigrams": corpus_bigram_counts.most_common(20),
    }



# Stream the training split through the bigram counter (up to 5M sequences
# of exactly 2048 tokens each).
train_dataloader = data_module.train_dataloader()
print("Starting bigram analysis...")
results = analyze_bigrams_on_dataloader(
    train_dataloader,
    seq_length=2048,
    num_samples=5_000_000,
)


print(f"\n{'='*60}")
print("BIGRAM ANALYSIS RESULTS")
print(f"{'='*60}")

print(
    f"\n📊 PER-SEQUENCE STATISTICS (averaged across {results['sequences_processed']} sequences):"
)
print(f"  Sequence length: {results['sequence_length']} tokens")
# The header promises averages, so report the repeated-bigram count per
# sequence; skip it for an empty run to avoid dividing by zero.
if results["sequences_processed"]:
    avg_repeated = results["effective_masked_tokens"] / results["sequences_processed"]
    print(f"  Avg effective masked tokens per sequence: {avg_repeated:,.2f}")


from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

# Show the ten most frequent bigrams, with each token decoded on its own and
# the pair decoded together.
print("\n🔝 MOST COMMON TOKEN BIGRAMS:")
for pair, freq in results["most_common_bigrams"][:10]:
    first_id, second_id = pair
    pieces = [tokenizer.decode([tok]) for tok in (first_id, second_id)]
    joined = tokenizer.decode([first_id, second_id])
    print(
        f"  ({first_id}, {second_id}) -> '{joined}' "
        f"[{pieces[0]}|{pieces[1]}]: {freq:,} times"
    )

print("Effective masked tokens: ", results["effective_masked_tokens"])
print("Total tokens processed: ", results["total_tokens_processed"])


# Flatten the analysis results into a JSON-friendly summary.
total_tokens = results["total_tokens_processed"]
effective_tokens = results["effective_masked_tokens"]
# An empty run (no sequence matched seq_length) would otherwise raise
# ZeroDivisionError here, before anything is saved.
effective_pct = 100.0 * effective_tokens / total_tokens if total_tokens else 0.0

stats = {
    "sequences_processed": results["sequences_processed"],
    "sequence_length": results["sequence_length"],
    "total_tokens_processed": total_tokens,
    "effective_masked_tokens": effective_tokens,
    "effective_percentage": effective_pct,
    "top_20_bigrams": results["most_common_bigrams"],
}


# NOTE(review): pathlib does not expand a leading "~" unless .expanduser() is
# called, so as written this directory is created relative to the CWD — confirm.
out_dir = Path("~pythia_replicate/code_testing/bigram_frequency")
out_dir.mkdir(parents=True, exist_ok=True)

# One file per run size, e.g. bigram_stats_5000000.json.
fname = out_dir / f"bigram_stats_{results['sequences_processed']}.json"

def _numpy_converter(obj):
    """Convert NumPy scalars and arrays to plain Python objects for ``json``.

    Intended as the ``default=`` hook of ``json.dump``/``json.dumps``, which
    call it only for objects the encoder cannot serialize natively.

    Args:
        obj: The object the JSON encoder could not serialize.

    Returns:
        ``obj.item()`` for any NumPy scalar (``np.generic``: integers, floats,
        ``np.bool_``, ...) or ``obj.tolist()`` for an ``ndarray``.

    Raises:
        TypeError: For any other type, as the ``json`` default hook requires.
    """
    # np.generic covers every NumPy scalar type (not just integers and
    # floats), so e.g. np.bool_ becomes a real bool instead of failing.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

# Persist the summary; _numpy_converter turns any leftover NumPy values into
# plain Python objects during serialization.
with fname.open("w") as out_file:
    json.dump(stats, out_file, default=_numpy_converter, indent=2)

print(f"🗒️  Stats JSON saved to: {fname}")
