import argparse
import json
import logging
import os
import sys
from statistics import mean
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional
import gpustat
import psutil
import psycopg2
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache
from psycopg2.extras import Json
from psycopg2.extensions import register_adapter


# Ask huggingface_hub to download through the `hf_transfer` backend.
# NOTE(review): huggingface_hub appears to read this variable once, when it is first
# imported, and the `transformers` import above already pulls it in, so this assignment
# may have no effect. Confirm; if fast downloads are wanted, set it before the
# third-party imports (that also requires the `hf_transfer` package to be installed).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Module logger: INFO and above are written to stdout as "[LEVEL - timestamp] message".
formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# Process-wide runtime settings.
os.environ["TOKENIZERS_PARALLELISM"] = "1"
# Permit reduced-precision internal algorithms (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision("high")
# Let psycopg2 accept plain dicts as query parameters, adapting them to JSON.
register_adapter(dict, Json)


def parse_arguments():
    """
    Read the three positional arguments of the benchmarking CLI.

    Returns:
        tuple: ``(branch, commit_id, commit_msg)`` exactly as passed on the command line.
    """
    parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")

    # (destination, help text) for each positional argument, in command-line order.
    positional_specs = (
        ("branch", "The branch name on which the benchmarking is performed."),
        ("commit_id", "The commit hash on which the benchmarking is performed."),
        ("commit_msg", "The commit message associated with the commit, truncated to 70 characters."),
    )
    for dest, help_text in positional_specs:
        parser.add_argument(dest, type=str, help=help_text)

    namespace = parser.parse_args()
    return tuple(getattr(namespace, dest) for dest, _ in positional_specs)


def collect_metrics(benchmark_id, continue_metric_collection):
    """
    Record process and GPU usage samples for ``benchmark_id`` until asked to stop.

    Intended as the target of the background thread started by ``run_benchmark``. Each
    iteration inserts one row into ``device_measurements`` containing this process's CPU
    utilisation and resident memory (MiB), both from psutil, plus the utilisation and
    used memory of the first GPU reported by gpustat. It then pauses for about 10 ms.

    Args:
        benchmark_id: id of the ``benchmarks`` row the samples belong to.
        continue_metric_collection: a ``threading.Event``. Despite the name, sampling stops
            as soon as the event is *set*.
    """
    p = psutil.Process(os.getpid())
    # Prime the CPU counter. A non-blocking cpu_percent() measures usage since the
    # previous call, so its very first call only ever returns 0.0.
    p.cpu_percent()
    conn = psycopg2.connect("dbname=metrics")
    try:
        with conn.cursor() as cur:
            while not continue_metric_collection.is_set():
                with p.oneshot():
                    cpu_util = p.cpu_percent()
                    mem_megabytes = p.memory_info().rss / (1024 * 1024)
                gpu_stats = gpustat.GPUStatCollection.new_query()
                gpu_util = gpu_stats[0]["utilization.gpu"]
                gpu_mem_megabytes = gpu_stats[0]["memory.used"]
                cur.execute(
                    "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
                    (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
                )
                conn.commit()
                # Pause ~10 ms between samples, but wake up immediately once stop is requested.
                continue_metric_collection.wait(0.01)
    finally:
        # Release the connection even if sampling or the INSERT raises.
        conn.close()


def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate: int = 100):
    """
    Benchmark Llama-2-7B inference on the local GPU and store the timings in Postgres.

    The function first inserts a ``benchmarks`` row for this branch/commit. It then samples
    device metrics on a background thread (``collect_metrics``) for the rest of the run.

    Timings recorded, all in seconds:

    * model load (the weights are downloaded beforehand so the download is not counted);
    * two eager forward passes over the prompt and two eager ``generate`` calls;
    * a hand-written decode loop whose single-token step is wrapped in
      ``torch.compile(mode="reduce-overhead")``. This yields the prefill time (time to first
      token), the next two decode steps, and the mean time of the remaining decode steps;
    * four ``generate`` calls, each given a freshly allocated ``StaticCache``.

    All timings are written as one dict to ``model_measurements``. Any ``Exception`` is
    logged with its traceback and swallowed. In every case the metrics thread is stopped
    and the database connection is closed.

    Args:
        branch: branch name being benchmarked.
        commit_id: commit hash being benchmarked.
        commit_msg: commit message stored with the benchmark row.
        num_tokens_to_generate: target number of new tokens. It sets
            ``generation_config.max_length`` and sizes the eager and decode-loop caches.
    """
    continue_metric_collection = Event()
    metrics_thread = None
    conn = None
    try:
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_name = gpu_stats[0]["name"]
        conn = psycopg2.connect("dbname=metrics")
        cur = conn.cursor()
        cur.execute(
            "INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
            (branch, commit_id, commit_msg, gpu_name),
        )
        benchmark_id = cur.fetchone()[0]
        conn.commit()
        logger.info(f"running benchmark #{benchmark_id} on {gpu_name}")
        metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
        metrics_thread.start()
        logger.info("started background thread to fetch device metrics")

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence warnings when compiling

        device = "cuda"
        ckpt = "meta-llama/Llama-2-7b-hf"

        logger.info("downloading weights")
        # Load once only to populate the local cache, so the download is not counted in
        # the load time measured below. The result is discarded right away; previously
        # it stayed alive until `model` was rebound, inside the timed region.
        AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
        gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
        logger.info("loading model")
        start = perf_counter()
        model = AutoModelForCausalLM.from_pretrained(
            ckpt, torch_dtype=torch.float16, generation_config=gen_config
        ).eval()
        model.to(device)
        torch.cuda.synchronize()
        end = perf_counter()
        model_load_time = end - start
        logger.info(f"loaded model in: {model_load_time}s")

        tokenizer = AutoTokenizer.from_pretrained(ckpt)

        prompt = "Why dogs are so cute?"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Max length (prompt + response) used by `generate` below.
        seq_length = inputs["input_ids"].shape[1]
        model.generation_config.max_length = seq_length + num_tokens_to_generate
        batch_size = inputs["input_ids"].shape[0]

        def new_static_cache(max_cache_len):
            """Allocate a fresh fp16 StaticCache on `device` sized for the prompt batch."""
            return StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=max_cache_len,
            )

        # Copied from the gpt-fast repo
        def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
            q = torch.empty_like(probs_sort).exponential_(1)
            return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

        def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            logits = logits / max(temperature, 1e-5)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                pivot = v.select(-1, -1).unsqueeze(-1)
                logits = torch.where(logits < pivot, -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            return probs

        def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            probs = logits_to_probs(logits[:, -1], temperature, top_k)
            idx_next = multinomial_sample_one_no_sync(probs)
            return idx_next, probs

        def decode_one_token(model, cur_token, cache_position, past_key_values):
            logits = model(
                cur_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )[0]
            new_token = sample(logits, temperature=0.6, top_k=5)[0]
            return new_token

        with torch.no_grad():
            #########
            # Eager #
            #########
            eager_fwd_pass_times = []
            eager_generate_times = []
            for call_name in ("first", "second"):
                past_key_values = new_static_cache(seq_length + num_tokens_to_generate)
                cache_position = torch.arange(seq_length, device=device)
                start = perf_counter()
                model(
                    **inputs,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    return_dict=False,
                    use_cache=True,
                )
                # CUDA kernels run asynchronously. Wait for them so the timing covers the
                # whole forward pass, not just the kernel launches.
                torch.cuda.synchronize()
                end = perf_counter()
                eager_fwd_pass_times.append(end - start)
                logger.info(f"completed {call_name} eager fwd pass in: {end - start}s")

                start = perf_counter()
                output = model.generate(**inputs, do_sample=False)
                torch.cuda.synchronize()
                end = perf_counter()
                eager_generate_times.append(end - start)
                logger.info(f"completed {call_name} eager generation in: {end - start}s")
                logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
            first_eager_fwd_pass_time, second_eager_fwd_pass_time = eager_fwd_pass_times
            first_eager_generate_time, second_eager_generate_time = eager_generate_times

            torch.compiler.reset()

            ################
            # Forward pass #
            ################

            # `torch.compile(model, ...)` is not recommended as you compile callbacks
            # and full generate. We recommend compiling only the forward for now.
            # "reduce-overhead" will use cudagraphs.
            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
            # TODO use  decode_one_token(model, input_id.clone(), cache_position) for verification
            # Headroom beyond the prefill plus decode positions actually written below.
            past_key_values = new_static_cache(seq_length + num_tokens_to_generate + 10)
            cache_position = torch.arange(seq_length, device=device)
            all_generated_tokens = []
            ### First compile, prefill
            start = perf_counter()
            next_token = decode_one_token(
                model, inputs["input_ids"], cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_first_token = end - start
            logger.info(f"completed first compile generation in: {time_to_first_token}s")
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            # Decoding writes one cache slot per step, starting right after the prompt.
            cache_position = torch.tensor([seq_length], device=device)
            ### First compile, decoding
            # Inputs are cloned, presumably because CUDA-graph replays may overwrite the
            # previous step's output buffer.
            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_second_token = end - start
            logger.info(f"completed second compile generation in: {time_to_second_token}s")
            cache_position += 1
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            ### Second compile, decoding
            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_third_token = end - start
            logger.info(f"completed third compile forward in: {time_to_third_token}s")
            cache_position += 1
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            ### Using cuda graphs decoding
            num_decode_steps = num_tokens_to_generate - 1
            start = perf_counter()
            for _ in range(num_decode_steps):
                next_token = decode_one_token(
                    model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
                )
                # Record the token just produced. Each token is stored exactly once.
                all_generated_tokens += next_token.clone().detach().cpu().tolist()
                cache_position += 1
            torch.cuda.synchronize()
            end = perf_counter()
            # Average over the steps actually timed (guarding the zero-step case).
            mean_time_to_next_token = (end - start) / max(num_decode_steps, 1)
            logger.info(f"completed next compile generation in: {mean_time_to_next_token}s")
            logger.info(f"generated: {tokenizer.batch_decode(all_generated_tokens)}")

            ####################
            # Generate compile #
            ####################
            torch.compiler.reset()
            # We do not compile the full generate as it is too intensive, though we measure
            # the full forward.
            compile_generate_times = []
            for call_name in ("first", "second", "third", "fourth"):
                past_key_values = new_static_cache(seq_length + 128)
                start = perf_counter()
                output = model.generate(**inputs, past_key_values=past_key_values)
                torch.cuda.synchronize()
                end = perf_counter()
                compile_generate_times.append(end - start)
                logger.info(f"completed {call_name} compile generation in: {end - start}s")
                logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
            (
                first_compile_generate_time,
                second_compile_generate_time,
                third_compile_generate_time,
                fourth_compile_generate_time,
            ) = compile_generate_times

        cur.execute(
            """
            INSERT INTO model_measurements (
                benchmark_id,
                measurements
            ) VALUES (%s, %s)
            """,
            (
                benchmark_id,
                {
                    "model_load_time": model_load_time,
                    "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
                    "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
                    "first_eager_generate_time_secs": first_eager_generate_time,
                    "second_eager_generate_time_secs": second_eager_generate_time,
                    "time_to_first_token_secs": time_to_first_token,
                    "time_to_second_token_secs": time_to_second_token,
                    "time_to_third_token_secs": time_to_third_token,
                    "time_to_next_token_mean_secs": mean_time_to_next_token,
                    "first_compile_generate_time_secs": first_compile_generate_time,
                    "second_compile_generate_time_secs": second_compile_generate_time,
                    "third_compile_generate_time_secs": third_compile_generate_time,
                    "fourth_compile_generate_time_secs": fourth_compile_generate_time,
                },
            ),
        )
        conn.commit()
    except Exception as e:
        # Best effort: a failed benchmark is logged with its traceback instead of raised.
        logger.exception(f"Caught exception: {e}")
    finally:
        # Runs on success, on a logged failure, and on KeyboardInterrupt. Otherwise the
        # non-daemon metrics thread would keep the process alive.
        continue_metric_collection.set()
        if metrics_thread is not None:
            metrics_thread.join()
        if conn is not None:
            conn.close()


if __name__ == "__main__":
    # CLI entry point: the positional (branch, commit_id, commit_msg) tuple feeds a
    # short 20-token benchmark run.
    run_benchmark(*parse_arguments(), num_tokens_to_generate=20)
