import json
import time
from collections.abc import Iterable

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def _sync_cuda() -> None:
    """Wait for queued CUDA work so host-side timers measure completed work."""
    # CUDA kernel launches are asynchronous; without a sync the timer could
    # stop before the GPU has actually finished the generation chunk.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _time_chunked_run(
    model,
    base_inputs: dict,
    num_steps: int,
    step: int,
    gen_kwargs: dict,
) -> list[float]:
    """Run ``num_steps`` generate calls of exactly ``step`` new tokens each.

    Starts from ``base_inputs`` (tokenizer output with ``input_ids`` and
    ``attention_mask``) and returns the cumulative elapsed wall-clock time in
    seconds after each chunk, so element ``i`` corresponds to
    ``(i + 1) * step`` generated tokens.
    """
    current_ids = base_inputs["input_ids"]
    current_attention_mask = base_inputs["attention_mask"]

    cumulative_time = 0.0
    cumulative_times: list[float] = []
    for _ in range(num_steps):
        _sync_cuda()
        start_time = time.perf_counter()
        output = model.generate(
            input_ids=current_ids,
            attention_mask=current_attention_mask,
            max_new_tokens=step,
            # Forbid an early EOS so every chunk really adds `step` tokens;
            # otherwise the (i + 1) * step labels would overstate the count.
            min_new_tokens=step,
            **gen_kwargs,
        )
        _sync_cuda()
        cumulative_time += time.perf_counter() - start_time
        cumulative_times.append(cumulative_time)

        # Re-feed the whole sequence for the next chunk. No KV cache is passed
        # between calls, so each chunk re-processes the full context (prompt +
        # everything generated so far) and that cost is part of the timing.
        current_ids = output
        current_attention_mask = torch.ones_like(output, dtype=torch.long)
    return cumulative_times


def _print_summary(
    model_name: str, prompt: str, times_per_x: dict[int, list[float]]
) -> None:
    """Print the mean cumulative time and throughput for every token count x."""
    print(f"Benchmarking '{model_name}' with prompt: '{prompt}'")
    print("---------------------------------------------------------")
    for x in sorted(times_per_x):
        times = times_per_x[x]  # one entry per seed
        if not times:
            # No seeds were run: nothing to average (avoids ZeroDivisionError).
            continue
        avg_time = sum(times) / len(times)
        tokens_per_ms = x / (avg_time * 1000.0)
        print(
            f"x = {x:3d} tokens  | "
            f"avg time = {avg_time:.4f} s  | "
            f"tokens/ms = {tokens_per_ms:.4f}"
        )


def benchmark_llama(
    model_name: str = "meta-llama/Llama-3.2-1B",
    prompt: str = "Hello, how are you?",
    max_tokens: int = 1000,
    step: int = 10,
    seeds: Iterable[int] = range(10),
    output_filename: str = "raw_benchmark_data.json",
) -> dict[int, list[float]]:
    """
    Benchmark chunked greedy generation, one long run per seed.

    For each seed, generate new tokens in chunks of exactly ``step`` tokens,
    up to ``(max_tokens // step) * step`` tokens in total, and record the
    cumulative elapsed time at each chunk boundary. For each x in
    {step, 2*step, ..., (max_tokens // step) * step}, one cumulative time is
    collected per seed. The raw data are written to ``output_filename`` as
    JSON (keys are x as strings) and a per-x summary is printed.

    :param model_name: Hugging Face model identifier or local path.
    :param prompt: Initial text prompt for generation.
    :param max_tokens: Upper bound on new tokens generated in each run.
    :param step: Chunk size in tokens between timing measurements (e.g. 10).
    :param seeds: Integer seeds, one run per seed (default: 10 runs). Decoding
        uses ``do_sample=False``, so the seed is not expected to change the
        generated tokens.
    :param output_filename: Path to store the raw JSON benchmark data.
    :return: Mapping x -> list of cumulative times in seconds, one per seed.
    :raises ValueError: If ``step`` is not a positive integer.
    """
    # Validate before the (slow) model load rather than failing afterwards.
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step}")
    seeds = list(seeds)  # accept any iterable; it is consumed exactly once

    # 1) Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # Adjust if needed
        device_map="auto",
    )

    # Encode the prompt once and place it on the model's device, wherever
    # device_map="auto" put it, instead of assuming "cuda".
    base_inputs = tokenizer(prompt, return_tensors="pt")
    base_inputs = {k: v.to(model.device) for k, v in base_inputs.items()}

    # Shared generation settings. Supplying a pad token explicitly avoids the
    # per-call fallback to EOS when the tokenizer defines no pad token.
    gen_kwargs: dict = {"do_sample": False}
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    if pad_token_id is not None:
        gen_kwargs["pad_token_id"] = pad_token_id

    # Warm-up pass (not timed)
    _ = model.generate(**base_inputs, max_new_tokens=10, **gen_kwargs)

    # times_per_x[x] holds one cumulative time per seed for
    # x in {step, 2*step, ..., num_steps*step}.
    num_steps = max_tokens // step
    x_values = [step * (i + 1) for i in range(num_steps)]
    times_per_x: dict[int, list[float]] = {x: [] for x in x_values}

    # 2) One long chunked run per seed
    for seed in seeds:
        torch.manual_seed(seed)
        run_times = _time_chunked_run(
            model, base_inputs, num_steps, step, gen_kwargs
        )
        for x, cumulative_time in zip(x_values, run_times):
            times_per_x[x].append(cumulative_time)

    # 3) Store raw data as JSON:
    #    {"<step>": [time_seed0, time_seed1, ...], "<2*step>": [...], ...}
    with open(output_filename, "w", encoding="utf-8") as f:
        # JSON object keys must be strings.
        json.dump({str(k): v for k, v in times_per_x.items()}, f, indent=2)

    # 4) Print summary stats for each x
    _print_summary(model_name, prompt, times_per_x)
    return times_per_x


if "__main__" == __name__:
    # Executed as a script (not imported): benchmark with all default settings.
    benchmark_llama()
