# pretrain.py

import sys
import os

# Make the project root importable so the absolute imports below resolve.
# The root is taken to be four directory levels above this script's folder.
_script_dir = os.path.dirname(__file__)
_levels_up = [os.pardir] * 4
project_root = os.path.abspath(os.path.join(_script_dir, *_levels_up))
sys.path.insert(0, project_root)

import torch
from datasets import load_dataset, interleave_datasets
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

# Ensure your custom model and config are imported
from TF.modeling_neuromamba import NeuroMambaForCausalLM
from TF.configuration_neuromamba import NeuroMambaConfig
from HealthCheckCallback import HealthCheckCallback
from safe_dataset import SafeIterableDataset

# 0. Device Configuration
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# --- 1. Dataset Loading and Interleaving ---
print("Loading and mixing datasets...")


def _content_to_text(example):
    """Return a ``text`` field holding the example's ``content`` value."""
    return {"text": example["content"]}


# Every corpus is opened in streaming mode on its "train" split.
_stream_kwargs = {"streaming": True, "split": "train"}

# a) SlimPajama (English)
slim_pajama_ds = load_dataset("cerebras/SlimPajama-627B", **_stream_kwargs)

# b) The Stack (code; data_dir selects the "data/python" subset)
the_stack_ds = load_dataset(
    "bigcode/the-stack-dedup",
    data_dir="data/python",
    **_stream_kwargs,
)

# c) SkyPile (Chinese)
skypile_ds = load_dataset("Skywork/SkyPile-150B", **_stream_kwargs)

# The Stack stores its documents under "content"; add a "text" field so that
# all three sources can be tokenized through the same column name.
the_stack_ds = the_stack_ds.map(_content_to_text)

# Sample from the three streams with fixed mixing probabilities
# (SlimPajama 0.75, The Stack 0.15, SkyPile 0.1) and a fixed seed.
_sources = [slim_pajama_ds, the_stack_ds, skypile_ds]
_mixing_probabilities = [0.75, 0.15, 0.1]
train_dataset_stream = interleave_datasets(
    _sources,
    probabilities=_mixing_probabilities,
    seed=42,
)
print("Datasets mixed successfully.")


# --- 2. Tokenizer and Tokenization Function ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

# If the tokenizer defines no padding token, reuse its end-of-sequence token
# so that padded batches can still be built.
# NOTE(review): with pad == EOS, DataCollatorForLanguageModeling(mlm=False)
# replaces pad-token labels with -100, which presumably also masks the EOS
# appended in tokenize_function — confirm this is intended.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize a batch of raw documents, appending EOS to each one.

    Args:
        examples: Batch mapping that holds a list of documents under the
            ``"text"`` key.

    Returns:
        The tokenizer's output for the batch, with every sequence truncated
        to at most 256 tokens.
    """
    eos = tokenizer.eos_token
    # A streamed record may carry a null text field; treat it as an empty
    # document instead of failing on ``None + str`` in the middle of a run.
    contents = [(text or "") + eos for text in examples["text"]]
    # NOTE: EOS is appended before truncation, so a document longer than the
    # limit presumably loses its EOS (assuming right-side truncation).
    return tokenizer(contents, max_length=256, truncation=True)

# Attach batched tokenization to the mixed streaming dataset.
tokenized_train_ds = train_dataset_stream.map(function=tokenize_function, batched=True)

# Wrap the tokenized stream in the project's SafeIterableDataset; this wrapped
# dataset is what gets handed to the Trainer.
safe_train_dataset = SafeIterableDataset(tokenized_train_ds)

# --- 3. Initialize Model (Targeting Mamba-130M) ---
# NOTE(review): this section targets Mamba-130M; confirm that the sizes below
# give the intended parameter count for NeuroMamba.
EXPECTED_VOCAB_SIZE = 50280

# Fail early with a readable message: a token id beyond the configured
# vocabulary would presumably only surface as an indexing error once
# training has started.
if len(tokenizer) > EXPECTED_VOCAB_SIZE:
    raise ValueError(
        f"Tokenizer has {len(tokenizer)} tokens but the model is configured "
        f"for a vocabulary of {EXPECTED_VOCAB_SIZE}."
    )

config = NeuroMambaConfig(
    vocab_size=EXPECTED_VOCAB_SIZE,
    hidden_size=768,
    num_hidden_layers=12,
    rms_norm=True,
    residual_in_fp32=True,
    use_bias=False,
    # ... other parameters required by your NeuroMamba model ...
)

# Build the model from the config (fresh weights unless the constructor loads
# some itself) and move it to the selected device.
model = NeuroMambaForCausalLM(config).to(device)


# --- 4. Set Up Training Arguments (No Evaluation) ---
# Only training, logging and checkpointing options are set here; no
# evaluation-related options are passed.
args = TrainingArguments(
    # Checkpoint location and retention
    output_dir="/home/hdd/test/NeuroMamba_re/Pretain/NeuMa_140M",
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=4,
    # Run length and batch shape
    max_steps=150000,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    # Optimizer
    learning_rate=3e-4,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.95,
    max_grad_norm=1.0,
    # Learning-rate schedule
    warmup_steps=2000,
    lr_scheduler_type="cosine_with_restarts",
    lr_scheduler_kwargs=dict(num_cycles=5),
    # Precision and logging
    bf16=True,  # Note: bf16 requires Ampere or newer GPUs.
    logging_steps=100,
)

# --- 5. Initialize Trainer (No Evaluation) ---
# Training-only setup: no eval_dataset is supplied.
# mlm=False makes the collator build causal-LM labels rather than masked-LM ones.
_causal_lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=safe_train_dataset,
    data_collator=_causal_lm_collator,
    tokenizer=tokenizer,
    callbacks=[HealthCheckCallback],
)

print("-" * 50)


# --- 6. Start Training ---
# Training resumes from an existing checkpoint; note that this path sits in a
# different directory from the output_dir configured above.
_resume_checkpoint = "/home/hdd/test/NeuroMamba_re/Pretain/NeuMa_140M_0918/checkpoint-25000"
print("Starting training...")
trainer.train(resume_from_checkpoint=_resume_checkpoint)