"""Benchmark autoregressive and speculative-decoding inference modes on Hugging Face datasets."""

import csv
import os
import time
import argparse
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

from src.inference import inference
from src.utils import device, touch, load_models, dtype, get_gpu_name, warmup
from src.predictor import trainer

# Short CLI mode keys -> the mode name passed to `inference(mode=...)`.
MODE_NAME = dict(
    ars="autoregressive",
    sps="speculative",
    usps="upper_bound_speculative",
    dsps="dynamic_speculative",
    dhsps="dynamic_speculative_history",
    pp="perceptron_predictor",
)


# Instruction templates keyed by the organisation prefix of the target model
# path (the part before the first "/").
# NOTE(review): several templates start with a literal "<s>" / "<bos>"; if
# `tokenizer.encode` also prepends a BOS token, prompts will begin with two.
# Confirm against the tokenizers in use before comparing across model groups.
_PROMPT_STRUCTURES = {
    "meta-llama": "<s>[INST] <<SYS>>{system_prompt}\n<</SYS>>\n\n{document} [/INST]",
    "codellama": "<s>[INST] <<SYS>>{system_prompt}\n<</SYS>>\n\n{document} [/INST]",
    "openlm-research": "<s>[INST] <<SYS>>{system_prompt}\n<</SYS>>\n\n{document} [/INST]",
    "bigscience": "{system_prompt}\n\nUser: {document}\nAssistant:",
    "facebook": "{system_prompt}\n\nUser: {document}\nAssistant:",
    "google": "<bos><start_of_turn>user\n{system_prompt}.\n\n{document}<end_of_turn>\n<start_of_turn>model",
    "stabilityai": "{system_prompt}<|USER|>{document}<|ASSISTANT|>",
    "openai-community": "{system_prompt}\n\nUser: {document}\nAssistant:",
    "databricks": "{system_prompt}\n\nUser: {document}\nAssistant:",
    "lmsys": "<s>[INST] <<SYS>>{system_prompt}\n<</SYS>>\n\n{document} [/INST]",
}

# Per-dataset settings:
# - positional arguments for `load_dataset`;
# - the split and column holding the prompt text;
# - the generation length used for that dataset (this overrides the caller's
#   `n_tokens_to_generate`);
# - the system prompt.
_DATASET_CONFIGS = {
    "xsum": {
        "load_args": ("EdinburghNLP/xsum",),
        "split": "test",
        "column": "document",
        "n_tokens": 50,
        "system_prompt": "You write 2 sentence summaries of new articles. Do not write any more. Keep it brief and to the point, no yapping",
    },
    "openai_humaneval": {
        "load_args": ("openai_humaneval",),
        "split": "test",
        "column": "prompt",
        "n_tokens": 150,
        "system_prompt": "You are an expert programmer that helps to complete Python code. Given the code from the user, please complete the rest of the code to the best of your ability",
    },
    "finance-alpaca": {
        "load_args": ("gbharti/finance-alpaca",),
        "split": "train",
        "column": "instruction",
        "n_tokens": 500,
        "system_prompt": "You are a finance expert. Answer the following questions to the best of your knowledge, and yap as much as possible",
    },
    "gsm8k": {
        "load_args": ("gsm8k", "main"),
        "split": "test",
        "column": "question",
        "n_tokens": 250,
        "system_prompt": "You are given a math question, and your task is to answer it. Then provide a step-by-step walkthrough on how you got the answer to the question",
    },
}

# Column order of the per-run CSV log.
_LOG_HEADER = [
    "prefill_tokens",
    "generate_tokens",
    "total_tokens",
    "prefill_time",
    "generate_time",
    "total_time",
    "prefill_tok_per_sec",
    "generate_tok_per_sec",
    "total_tok_per_sec",
    "acceptance_rate",
    "draft_sample_count",
    "target_sample_count",
    "thrown_away_count",
    "accepted_count",
    "resampled_count",
]

# (CSV column, parser, printed label, only read for speculative modes), in
# the order the summary is printed.
_SUMMARY_FIELDS = (
    ("prefill_time", float, "prefill time", False),
    ("generate_time", float, "generate time", False),
    ("prefill_tokens", int, "prefill tokens", False),
    ("generate_tokens", int, "generate tokens", False),
    ("prefill_tok_per_sec", float, "prefill tok per sec", False),
    ("generate_tok_per_sec", float, "generate tok per sec", False),
    ("acceptance_rate", float, "acceptance rate", True),
    ("draft_sample_count", int, "draft sample count", True),
    ("target_sample_count", int, "target sample count", True),
    ("thrown_away_count", int, "thrown away count", True),
    ("accepted_count", int, "accepted count", True),
    ("resampled_count", int, "resampled count", True),
    ("total_tokens", int, "total tokens", False),
    ("total_time", float, "total time", False),
    ("total_tok_per_sec", float, "total tok per sec", False),
)


def _build_prompts(dataset_name, model_group, samples):
    """Return (prompts, n_tokens_to_generate) for the first `samples` documents of a dataset."""
    config = _DATASET_CONFIGS[dataset_name]
    template = _PROMPT_STRUCTURES[model_group]
    dataset = load_dataset(*config["load_args"])
    documents = dataset[config["split"]][config["column"]][:samples]
    prompts = [
        template.format(system_prompt=config["system_prompt"], document=document)
        for document in documents
    ]
    return prompts, config["n_tokens"]


def _summarize_log(logfile_name, mode):
    """Print and return the per-row average of every metric in a benchmark CSV log.

    Speculative-only columns are not read in "ars" mode and report 0.0.
    Averages are taken over the rows actually present. An empty log yields
    0.0 for every metric instead of dividing by zero.
    """
    with open(logfile_name, "r", newline="") as file:
        rows = list(csv.DictReader(file))

    n_rows = max(len(rows), 1)
    averages = {}
    for column, parse, label, speculative_only in _SUMMARY_FIELDS:
        if speculative_only and mode == "ars":
            total = 0
        else:
            total = sum(parse(row[column]) for row in rows)
        averages[column] = total / n_rows
        print(f"Average {label}: {averages[column]}")
    return averages


def benchmark(
    draft_model_path,
    target_model_path,
    dataset_name,
    gamma=4,
    temperature=0.0,
    n_tokens_to_generate=150,
    mode="ars",
    use_fa=False,
    use_4bit=False,
    use_8bit=False,
    samples=25,
):
    """Benchmark one decoding mode on a dataset and report average metrics.

    Loads the models and writes a CSV header to a file under ``logs/``. It
    then calls ``inference`` once per prompt with ``logfile_name`` pointing at
    that file. The metric rows are presumably appended by ``inference`` —
    confirm in ``src.inference``. Finally it reads the log back, prints the
    per-row averages and returns them.

    Args:
        draft_model_path: Hub id or path of the draft model. Only its last
            path component is used, in the log file name.
        target_model_path: Hub id or path of the target model. Its first path
            component selects the prompt template in ``_PROMPT_STRUCTURES``.
        dataset_name: A key of ``_DATASET_CONFIGS``. Any ``"owner/"`` prefix
            is stripped first.
        gamma: Passed to ``inference`` as both ``gamma`` and ``max_gamma``.
            Also appears in the log file name for non-"ars" modes.
        temperature: Passed to ``inference``. Also appears in the log file
            name.
        n_tokens_to_generate: Effectively unused. Every supported dataset
            overrides it with its own length (see ``_DATASET_CONFIGS``). Kept
            for backward compatibility.
        mode: A key of ``MODE_NAME``. ``"ars"`` passes ``load_one=True`` to
            ``load_models``.
        use_fa, use_4bit, use_8bit: Forwarded to ``load_models``.
        samples: Number of dataset documents to benchmark.

    Returns:
        dict mapping each metric column name to its average over the log rows.

    Raises:
        ValueError: If ``mode``, ``dataset_name`` or the target model group is
            unsupported, or ``samples < 1``. These checks run before any model
            is loaded.
    """
    from src.predictor import QLearningAgent

    dataset_name = dataset_name.split("/")[-1]
    model_group = target_model_path.split("/")[0]

    # Fail fast: load_models() below is the expensive step, and an unknown
    # dataset would otherwise silently produce an empty, all-zero benchmark.
    if mode not in MODE_NAME:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {list(MODE_NAME)}")
    if dataset_name not in _DATASET_CONFIGS:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {list(_DATASET_CONFIGS)}"
        )
    if model_group not in _PROMPT_STRUCTURES:
        raise ValueError(
            f"No prompt template for model group {model_group!r}; "
            f"expected one of {list(_PROMPT_STRUCTURES)}"
        )
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    ### Load the models
    os.makedirs("logs", exist_ok=True)

    gpu_name = get_gpu_name()
    temperature_str = str(temperature).replace(".", "_")

    tokenizer, draft_model, target_model = load_models(
        draft_model_path,
        target_model_path,
        load_one=mode == "ars",
        use_fa=use_fa,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
    )

    draft_model_name = draft_model_path.split("/")[-1]
    target_model_name = target_model_path.split("/")[-1]
    if mode == "ars":
        logfile_name = f"logs/{mode}_{dataset_name}_{target_model_name}_{gpu_name}_t{temperature_str}.csv"
    else:
        logfile_name = f"logs/{mode}_{dataset_name}_d{draft_model_name}_t{target_model_name}_{gpu_name}_t{temperature_str}_g{gamma}.csv"

    # Truncate the log and write only the header; rows are added per prompt.
    with open(logfile_name, "w", newline="") as file:
        csv.writer(file).writerow(_LOG_HEADER)

    ### Load the dataset and build one instruction prompt per document
    prompts, n_tokens_to_generate = _build_prompts(dataset_name, model_group, samples)

    agent = QLearningAgent(list(range(-1, 2)))

    for prompt in tqdm(prompts):
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        warmup()
        inference(
            input_ids,
            tokenizer,
            n_tokens_to_generate,
            temperature,
            model=target_model,
            draft_model=draft_model,
            target_model=target_model,
            gamma=gamma,
            max_gamma=gamma,
            logfile_name=logfile_name,
            record_speculation=True,
            record_empirical_upper_bound=mode == "usps",
            mode=MODE_NAME[mode],
            dataset=dataset_name,
            logging_history=True,
            agent=agent,
        )
        # NOTE: online predictor training for mode "pp"
        # (`trainer(dataset_name, gamma)` with autograd temporarily
        # re-enabled) is currently disabled.

    del tokenizer, draft_model, target_model

    # Read the log back and report per-row averages.
    return _summarize_log(logfile_name, mode)
