from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer
from sae.data import chunk_and_tokenize
import os
from experiment_config import config
from config_loader import parse_config_overrides, apply_overrides

def process_split(split: str, tokenizer: AutoTokenizer, split_name: str) -> None:
    """Tokenize, truncate, shuffle and save a single dataset split.

    The tokenized dataset is cached on disk under a path derived from
    ``config.tokenized_dataset_path`` and ``split``; when that cache exists it
    is loaded instead of re-tokenizing. The result is limited to
    ``config.max_tokens`` tokens, shuffled with ``config.random_seed`` and
    saved to ``config.cache_dir / "processed" / config.dataset_short_name /
    split_name``.

    Args:
        split: Split specifier passed to ``load_dataset`` (for example
            ``config.train_dataset_split``).
        tokenizer: Tokenizer used to tokenize the raw dataset. If ``None``,
            one is loaded from ``config.model_name``.
        split_name: Short label ("train", "test", ...) used in log output and
            as the last component of the output directory.
    """
    print(f"\nProcessing {split_name} split: {split}")

    os.environ['HF_HOME'] = str(config.cache_dir)

    # The cache path is keyed on the split actually being processed, so each
    # split gets its own tokenized cache. The "_ " separator (including the
    # space) is kept as-is so caches written by earlier runs are still found.
    tokenized_path = f'{str(config.tokenized_dataset_path)}_ {str(split)}'

    if Path(tokenized_path).exists():
        print(f"Loading tokenized dataset from {tokenized_path}")
        tokenized = load_from_disk(tokenized_path)
    else:
        print(f"Loading dataset {config.dataset}")
        dataset_args = config.get_dataset_args()
        # Load the requested split, not unconditionally the training split.
        dataset_args["split"] = split
        print(f"Dataset args: {dataset_args}")
        dataset = load_dataset(
            **dataset_args,
            trust_remote_code=True,
            cache_dir=config.cache_dir,
        )

        print("Tokenizing dataset...")
        # Reuse the caller's tokenizer; only load one when none was supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        tokenized = chunk_and_tokenize(dataset, tokenizer, text_key=config.text_key)

        tokenized.save_to_disk(tokenized_path)

    # Print statistics.
    # NOTE(review): this assumes every tokenized sequence holds exactly
    # config.cache_ctx_len tokens -- confirm this matches the sequence length
    # chunk_and_tokenize actually produces.
    n_sequences = len(tokenized)
    n_tokens = n_sequences * config.cache_ctx_len
    print(f"Number of sequences: {n_sequences:,}")
    print(f"Number of tokens: {n_tokens:,}")

    # Restrict to max_tokens if needed. The first max_sequences rows are kept;
    # shuffling happens afterwards, on the restricted subset only.
    if n_tokens > config.max_tokens:
        max_sequences = config.max_tokens // config.cache_ctx_len
        tokenized = tokenized.select(range(max_sequences))
        print(f"Restricted to {max_sequences:,} sequences ({config.max_tokens:,} tokens)")

        # Print warning if less than 1B tokens
        if config.max_tokens < 1_000_000_000:
            print(f"Warning: Using less than 1B tokens ({config.max_tokens:,})")

    # Shuffle dataset
    tokenized = tokenized.shuffle(seed=config.random_seed)

    # Save processed dataset
    processed_path = config.cache_dir / "processed" / config.dataset_short_name / split_name
    processed_path.parent.mkdir(parents=True, exist_ok=True)
    tokenized.save_to_disk(str(processed_path))

def main():
    """Prepare the train and test splits of the configured dataset.

    Applies command-line configuration overrides to the global ``config``,
    points ``HF_HOME`` at ``config.cache_dir``, loads the tokenizer for
    ``config.model_name`` and runs ``process_split`` for the train split and
    then the test split.
    """
    # Imported locally so argv is read through the documented ``sys`` module
    # rather than the incidental ``os.sys`` attribute.
    import sys

    # Load and apply configuration overrides
    overrides = parse_config_overrides()
    if "--no-reinit_non_embedding" in sys.argv:
        overrides["reinit_non_embedding"] = False
    apply_overrides(config, overrides)

    # Set up HF cache directory
    os.environ['HF_HOME'] = str(config.cache_dir)

    print(f"Processing dataset: {config.dataset}")
    if config.dataset_name:
        print(f"Dataset config: {config.dataset_name}")

    # Initialize tokenizer
    print(f"\nLoading tokenizer: {config.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Process train split
    process_split(config.train_dataset_split, tokenizer, "train")

    # Process test split
    process_split(config.test_dataset_split, tokenizer, "test")

    print("\nDataset preparation complete!")

# Run dataset preparation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
