# benchmark_neuma.py
import torch
import transformers
from transformers import Trainer, TrainingArguments
from datasets import IterableDataset
import shutil
import json
import os
import sys

# Report the installed transformers version at start-up (runs on import as well as on script execution).
print(f"Using transformers version: {transformers.__version__}")

# ===================================================================
# 1. Model Loading
# ===================================================================
def get_neuma_model(device: str, dtype: torch.dtype):
    """Construct the NEUMA benchmark model and move it to the target device.

    The directory three levels above this file is put on ``sys.path`` (if it
    is not there already) so that the project-local ``TF`` package can be
    imported. No checkpoint is loaded in this function; the model is built
    directly from a ``NeuroMambaConfig``.

    Args:
        device: Device passed to ``model.to`` (e.g. ``"cuda:0"``).
        dtype: Parameter dtype passed to ``model.to``.

    Returns:
        A ``(model, params)`` tuple: the model after ``.to(device, dtype)`` and
        the total element count over ``model.parameters()``.
    """
    print("\nLoading model: NEUMA...")

    # Make the repository root importable before touching the TF package.
    script_dir = os.path.dirname(__file__)
    repo_root = os.path.abspath(os.path.join(script_dir, "..", "..", ".."))
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

    from TF.modeling_neuromamba import NeuroMambaForCausalLM
    from TF.configuration_neuromamba import NeuroMambaConfig

    # --- Hyper-parameters of the NeuMa-140M configuration ---
    config_kwargs = dict(
        vocab_size=50280,
        hidden_size=768,
        state_size=16,
        num_hidden_layers=12,
        conv_kernel=4,
        expand=2,
        expand_gc=2,
        conv_gc_kernel=4,
    )
    network = NeuroMambaForCausalLM(NeuroMambaConfig(**config_kwargs))

    # Count parameters before the device/dtype move, as the original did.
    total_params = 0
    for tensor in network.parameters():
        total_params += tensor.numel()
    print(f"NEUMA model created with {total_params / 1e6:.2f}M parameters.")

    network = network.to(device=device, dtype=dtype)
    return network, total_params

# ===================================================================
# 2. Benchmarking Functions - Identical to the Mamba script
# ===================================================================
def bytes_to_gb(b):
    """Convert a byte count to gigabytes, using 1 GB = 1024**3 bytes."""
    bytes_per_gb = 1 << 30  # same value as 1024**3
    return b / bytes_per_gb

def measure_training_performance_hf(model, model_name, device, batch_size, seq_len, dtype, grad_accum_steps):
    """Run a short Hugging Face ``Trainer`` job and report throughput and peak memory.

    The model is trained for 15 optimizer steps on random token sequences.
    Throughput is the Trainer's ``train_samples_per_second`` metric multiplied
    by ``seq_len``; peak memory is ``torch.cuda.max_memory_allocated`` read
    after training, with the peak counter reset just before training starts.

    Args:
        model: Model handed to ``Trainer``. If it exposes
            ``config.vocab_size`` that value bounds the random token ids;
            otherwise 50280 is used.
        model_name: Used only to name the temporary output directory.
        device: CUDA device used for synchronisation and memory statistics.
        batch_size: Per-device train batch size.
        seq_len: Length of every generated sample, in tokens.
        dtype: ``torch.bfloat16`` enables ``bf16``; ``torch.float16`` enables
            ``fp16`` in ``TrainingArguments``.
        grad_accum_steps: Gradient accumulation steps.

    Returns:
        ``(throughput, peak_memory_gb)`` — tokens per second and peak
        allocated CUDA memory in GB.
    """
    # Local import kept so this function does not rely on module-level state.
    from datasets import IterableDataset

    output_dir = f"./tmp_{model_name}_training"

    # Draw token ids from the model's own vocabulary when it is known, so the
    # synthetic data can never index past the embedding table.
    vocab_size = getattr(getattr(model, "config", None), "vocab_size", None) or 50280

    # NOTE(review): when the model weights are already float16, fp16=True also
    # turns on mixed-precision loss scaling in the Trainer; that combination
    # may be rejected at the gradient-unscale step — confirm on GPUs that take
    # the float16 path.
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        num_train_epochs=1,
        max_steps=15,
        logging_steps=15,
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        report_to="none",
        remove_unused_columns=False,
    )

    def dummy_data_generator(num_samples, seq_len, vocab_size):
        # Labels are a copy of the inputs (causal-LM style targets).
        for _ in range(num_samples):
            data = torch.randint(1, vocab_size, (seq_len,))
            yield {"input_ids": data, "labels": data.clone()}

    train_dataset = IterableDataset.from_generator(
        dummy_data_generator,
        gen_kwargs={"num_samples": 1000, "seq_len": seq_len, "vocab_size": vocab_size},
    )

    try:
        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

        # Start from a clean slate so the peak statistic reflects this run only.
        torch.cuda.synchronize(device)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)

        train_output = trainer.train()
        torch.cuda.synchronize(device)

        throughput = train_output.metrics['train_samples_per_second'] * seq_len
        peak_memory_gb = bytes_to_gb(torch.cuda.max_memory_allocated(device))
    finally:
        # Always remove the scratch directory, even when training fails, and
        # do not raise if it was never created.
        shutil.rmtree(output_dir, ignore_errors=True)

    print("\n--- [Training Performance] ---")
    print(f"  Throughput: {throughput:.2f} tokens/sec")
    print(f"  Peak Memory: {peak_memory_gb:.2f} GB")
    return throughput, peak_memory_gb

@torch.no_grad()
def measure_inference_performance_hf(model, model_loader_func, device, batch_size, prompt_len, gen_len, dtype):
    """Measure generation latency/throughput and the model's weight footprint.

    The weight footprint is measured by loading a fresh copy of the model via
    ``model_loader_func`` and taking the *increase* in
    ``torch.cuda.memory_allocated``. Reading the absolute value would also
    count ``model`` itself and anything else still resident on the device
    (e.g. gradients left over from a preceding training run).

    Generation is timed with CUDA events around a single ``model.generate``
    call, after one short warm-up call. ``eos_token_id=-1`` is passed so that
    generation is not cut short by an end-of-sequence token.

    Args:
        model: Model to benchmark; switched to eval mode.
        model_loader_func: Callable ``(device, dtype) -> (model, params)`` used
            to load the throw-away copy for the memory measurement.
        device: CUDA device.
        batch_size: Number of prompts generated in parallel.
        prompt_len: Length of the random prompt, in tokens.
        gen_len: Number of tokens requested beyond the prompt.
        dtype: Dtype forwarded to ``model_loader_func``.

    Returns:
        ``(model_memory_gb, latency, throughput)``. ``latency`` is ms/token
        and is only computed for ``batch_size == 1``; otherwise it is ``-1``.
        ``throughput`` is generated tokens per second.
    """
    model.eval()

    # Weight footprint = allocation delta caused by loading one extra copy.
    torch.cuda.synchronize(device)
    baseline_bytes = torch.cuda.memory_allocated(device)
    clean_model, _ = model_loader_func(device, dtype)
    model_memory_gb = bytes_to_gb(torch.cuda.memory_allocated(device) - baseline_bytes)
    del clean_model
    torch.cuda.empty_cache()

    input_ids = torch.randint(1, 50280, (batch_size, prompt_len), device=device)

    # Warm-up run so one-time start-up costs are not part of the timed call.
    _ = model.generate(input_ids, max_length=prompt_len + 10, eos_token_id=-1, use_cache=True)
    torch.cuda.synchronize(device)

    start_time = torch.cuda.Event(enable_timing=True)
    end_time = torch.cuda.Event(enable_timing=True)
    start_time.record()
    _ = model.generate(input_ids, max_length=prompt_len + gen_len, eos_token_id=-1, use_cache=True)
    end_time.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize(device)

    total_time_ms = start_time.elapsed_time(end_time)
    total_tokens = batch_size * gen_len
    # Per-token latency is only meaningful for a single sequence; -1 is the
    # sentinel callers receive for batched runs.
    latency = total_time_ms / total_tokens if batch_size == 1 else -1
    throughput = total_tokens / (total_time_ms / 1000)

    print(f"\n--- [Inference Performance (BS={batch_size})] ---")
    if batch_size == 1:
        print(f"  Latency: {latency:.2f} ms/token")
    print(f"  Throughput: {throughput:.2f} tokens/sec")
    print(f"  Model Memory: {model_memory_gb:.2f} GB")
    return model_memory_gb, latency, throughput

# ===================================================================
# 3. Main Execution - With results collection and saving
# ===================================================================
if __name__ == "__main__":
    # --- Experiment Configuration ---
    DEVICE = "cuda:0"

    # Every measurement below uses CUDA (events, memory statistics). Stop here
    # with a clear message instead of failing later while moving the model.
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required for this benchmark, but no CUDA device is available.")

    # float16 when the device's major compute capability is below 8,
    # bfloat16 otherwise.
    DTYPE = torch.float16 if torch.cuda.get_device_capability(DEVICE)[0] < 8 else torch.bfloat16
    TRAIN_PER_DEVICE_BS, TRAIN_GRAD_ACCUM_STEPS, TRAIN_SEQ_LEN = 1, 16, 2048
    INFERENCE_PROMPT_LEN, INFERENCE_GEN_LEN = 1, 100
    INFERENCE_BS_LATENCY, INFERENCE_BS_THROUGHPUT = 1, 16

    # --- Run and Collect Results ---
    print("="*60)
    print(" " * 18 + "BENCHMARKING NEUMA MODEL")
    print("="*60)
    print(f"Using dtype: {DTYPE}")
    model, num_params = get_neuma_model(device=DEVICE, dtype=DTYPE)

    results = {
        "model_name": "NeuMa",
        "parameters_M": round(num_params / 1e6, 2)
    }

    # Training benchmark.
    train_throughput, train_memory = measure_training_performance_hf(
        model, "neuma", DEVICE, TRAIN_PER_DEVICE_BS, TRAIN_SEQ_LEN, DTYPE, TRAIN_GRAD_ACCUM_STEPS
    )
    results["training"] = {"throughput_tok_s": train_throughput, "peak_memory_gb": train_memory}

    # Batch-size-1 inference: per-token latency and model memory.
    mem_bs1, lat_bs1, _ = measure_inference_performance_hf(
        model, get_neuma_model, DEVICE, INFERENCE_BS_LATENCY, INFERENCE_PROMPT_LEN, INFERENCE_GEN_LEN, DTYPE
    )
    results["inference_latency_bs1"] = {"latency_ms_tok": lat_bs1}
    results["inference_memory"] = {"model_memory_gb": mem_bs1}

    # Larger-batch inference: only the throughput figure is kept.
    _, _, tp_bs_large = measure_inference_performance_hf(
        model, get_neuma_model, DEVICE, INFERENCE_BS_THROUGHPUT, INFERENCE_PROMPT_LEN, INFERENCE_GEN_LEN, DTYPE
    )
    results["inference_throughput_bs_large"] = {"throughput_tok_s": tp_bs_large}

    # --- Save Results to File ---
    output_filename = "benchmark_results_neuma.json"
    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print("\n" + "="*60)
    print(f"Benchmark finished. Results saved to '{output_filename}'")
    print("="*60)