"""
Analyze the file structure of allenai/dolma3_longmino_pool to understand
organization by length and topic.
"""

from huggingface_hub import list_repo_files
from collections import defaultdict
import re

DATASET_NAME = "allenai/dolma3_longmino_pool"

# Opening banner: a horizontal rule, the dataset being analysed, another rule.
_RULE = "=" * 80
for _banner_line in (_RULE, f"Analyzing File Structure: {DATASET_NAME}", _RULE):
    print(_banner_line)

print("\nFetching file list from HuggingFace...")
# list_repo_files already returns a list of file paths within the repo,
# so wrapping it in list() would only make a redundant copy.
files = list_repo_files(DATASET_NAME, repo_type="dataset")

print(f"Total files: {len(files):,}\n")

# Keep only the compressed JSONL shards stored under data/.
# Expected layout: data/{topic}/{shard}.jsonl.zst
data_files = [
    path
    for path in files
    if path.startswith('data/') and path.endswith('.jsonl.zst')
]

print(f"Data files (JSONL): {len(data_files):,}")

# Bucket shards by the first directory below data/ (the topic/source).
# Paths with fewer than three segments (a file directly inside data/) have
# no topic directory and are skipped.
topics = defaultdict(list)
for path in data_files:
    segments = path.split('/')
    if len(segments) < 3:
        continue
    topics[segments[1]].append(path)

print(f"Number of topics/sources: {len(topics)}\n")

print("="*80)
print("Topics/Sources (sorted by file count):")
print("="*80)

# Sort topics by number of files, largest first.
sorted_topics = sorted(topics.items(), key=lambda x: len(x[1]), reverse=True)

# Shard filenames may start with a prefix such as p020 / p030; compiled once
# instead of being looked up for every file in the inner loop.
SHARD_PREFIX_RE = re.compile(r'(p\d+)')

for i, (topic, topic_files) in enumerate(sorted_topics, 1):
    # Report a length only when the topic name carries an explicit unit
    # suffix ("2k", "32k", "1m"), using the same rule as the length summary:
    # a "k" indicator takes precedence over an "m" one. A bare digit run
    # (e.g. the "2" in a name like "s2orc") is not a length and gets no label.
    length_match = (re.search(r'(\d+k)', topic, re.IGNORECASE)
                    or re.search(r'(\d+m)', topic, re.IGNORECASE))
    length_info = f" (length: {length_match.group(1)})" if length_match else ""

    print(f"\n{i:3d}. {topic}{length_info}")
    print(f"     Files: {len(topic_files):,}")

    # Show file naming pattern
    if topic_files:
        first_file = topic_files[0].split('/')[-1]
        print(f"     Example: {first_file}")

        # Count shards per filename prefix (e.g., p020, p030).
        prefixes = defaultdict(int)
        for path in topic_files:
            prefix_match = SHARD_PREFIX_RE.match(path.rsplit('/', 1)[-1])
            if prefix_match:
                prefixes[prefix_match.group(1)] += 1

        if prefixes:
            print(f"     Shard prefixes: {dict(sorted(prefixes.items()))}")

print("\n" + "="*80)
print("Summary by Length (if encoded in topic names):")
print("="*80)

# Length indicators recognised in topic names, in priority order: a "k"
# indicator (2k, 4k, 8k, 16k, 32k, 64k, ...) wins over an "m" indicator.
LENGTH_PATTERNS = (r'\d+k', r'\d+m')
# Multipliers used only to order the summary numerically.
LENGTH_UNIT_SCALE = {'k': 1_000, 'm': 1_000_000}


def _length_sort_key(item):
    """Sort key for (label, topics) pairs: numeric buckets ascending, 'other' last.

    Plain string sorting would put "16k" before "2k"; this orders by numeric
    size instead and breaks ties on the label text so output is deterministic.
    Non-'other' labels are always lower-cased digits followed by 'k' or 'm'.
    """
    label = item[0]
    if label == 'other':
        return (1, 0, label)
    return (0, int(label[:-1]) * LENGTH_UNIT_SCALE[label[-1]], label)


# Group topics by their first recognised length indicator (lower-cased);
# each pattern is searched once per topic.
length_groups = defaultdict(list)
for topic in topics:
    for pattern in LENGTH_PATTERNS:
        found = re.search(pattern, topic, re.IGNORECASE)
        if found:
            length_groups[found.group(0).lower()].append(topic)
            break
    else:
        # No length indicator anywhere in the topic name.
        length_groups['other'].append(topic)

for length, length_topics in sorted(length_groups.items(), key=_length_sort_key):
    total_files = sum(len(topics[t]) for t in length_topics)
    print(f"\n{length}:")
    print(f"  Topics: {len(length_topics)}")
    print(f"  Total files: {total_files:,}")
    if len(length_topics) <= 5:
        for t in length_topics:
            print(f"    - {t}")
    else:
        # Long groups are abbreviated to the first three topics.
        for t in length_topics[:3]:
            print(f"    - {t}")
        print(f"    ... and {len(length_topics)-3} more")

# Closing banner: blank line + rule, completion message, rule.
_closing_rule = "=" * 80
for _closing_line in ("\n" + _closing_rule, "Done!", _closing_rule):
    print(_closing_line)
