"""An attempt to run speculative decoding with kv-cache enabled"""

import torch
import ray
import torch
import time
import random
import numpy as np
from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

### GLOBAL VARIABLES ###

# Pick the best available backend: Apple MPS first, then CUDA, then CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)

# NVIDIA GPUs Ampere and after support bf16 (better precision and throughput)
# M2 Apple Silicon and after also support bf16 (CPU and GPU); the original
# author noted it was not supported on their own machine.
# Try to avoid fp32 unless hardware does not support fp16
#
# NOTE: a torch.device never compares equal to a plain string, so the backend
# must be checked through `device.type`; comparing `device == "cuda"` is always
# False and would silently select fp32 on every backend.
if device.type in ("cuda", "mps"):
    dtype = torch.bfloat16
else:
    # default to FP32 on CPU, because PyTorch doesn't support HGEMM on CPU
    dtype = torch.float32


# CUDA-only performance switches; nothing is touched on other backends.
if torch.cuda.is_available():
    _matmul_opts = torch.backends.cuda.matmul

    # When fp16 is used, keep accumulating in fp32 for better outputs.
    _matmul_opts.allow_fp16_reduced_precision_reduction = False

    # Allow TF32 for matmul (this flag defaults to False in PyTorch 1.12
    # and later).
    _matmul_opts.allow_tf32 = True

    # Allow TF32 inside cuDNN (this flag already defaults to True; it is set
    # explicitly here).
    torch.backends.cudnn.allow_tf32 = True


# Upper bound on the total sequence length (prompt + generated tokens).
max_length: int = 150
# Number of tokens to generate; tied to max_length.
N: int = max_length
# Time interval to validate (per the original note).
K: int = 8
# Sampling temperature; the original note says to set it to 0 for greedy
# (i.e. deterministic) decoding.
temperature: float = 0.9
# Small constant added to the temperature before dividing the logits, so a
# temperature of 0 does not divide by zero.
eps: float = 1e-10

### GLOBAL VARIABLES ###


def touch(dev=None):
    """Block until queued work on the active accelerator has finished.

    Call this right before reading the clock when timing a routine, so the
    measurement covers the actual computation and not just its launch.

    Args:
        dev: Optional ``torch.device`` or device string to synchronize.
            Defaults to the module-level ``device``.

    Returns:
        None. On backends other than MPS and CUDA this is a no-op.
    """
    # A torch.device never compares equal to a plain string, so the backend
    # has to be read from `.type`; `device == "cuda"` is always False.
    backend = torch.device(device if dev is None else dev).type
    if backend == "mps":
        torch.mps.synchronize()
    elif backend == "cuda":
        torch.cuda.synchronize()


def sample(p):
    """Return the index of the largest entry of ``p`` as a one-element tensor."""
    # argmax without `dim` searches the flattened tensor and yields a 0-d index
    top_index = torch.argmax(p)
    return torch.unsqueeze(top_index, dim=-1)


def max_fn(x):
    """Zero out the non-positive entries of ``x`` and rescale the rest to sum to 1."""
    is_positive = x > 0
    clipped = torch.where(is_positive, x, 0)
    total = torch.sum(clipped)
    return clipped / total


def main():
    """Run a greedy, KV-cached generation baseline and report its speed.

    Loads the tokenizer and causal LM named below, prefills the KV cache with
    the prompt, then decodes one token at a time until the sequence reaches
    ``max_length`` tokens or the EOS token is produced. Prints the decoded
    text, the elapsed wall-clock time and the number of generated tokens per
    second.

    Alternative checkpoint identifiers kept for reference:
        hf-internal-testing/llama-tokenizer
        meta-llama/Llama-2-7b-hf
        meta-llama/Llama-2-7b-chat-hf
        meta-llama/Llama-2-13b-chat-hf
        TinyLlama/TinyLlama-1.1B-step-50K-105b
        openlm-research/open_llama_3b
        openlm-research/open_llama_3b_v2
        openlm-research/open_llama_13b
    """
    tokenizer_path = "hf-internal-testing/llama-tokenizer"
    model_path = "meta-llama/Llama-2-7b-hf"

    tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype)

    # NOTE: the hidden states requested here are not consumed below.
    model.config.output_hidden_states = True
    model.config.use_cache = True

    model.eval()

    # Move the model to the selected device
    model = model.to(device)

    # Prepare input text
    input_text = "The capital of France is "
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Budget of new tokens so that prompt + generation never exceeds max_length
    prompt_len = input_ids.shape[-1]
    max_new_tokens = max(max_length - prompt_len, 0)
    generated = 0

    ### VANILLA RESPONSE GENERATION ###
    touch()
    begin = time.perf_counter()

    with torch.no_grad():
        # Prefill: one pass over the whole prompt fills the cache, and its
        # last-position logits already predict the first new token. Feeding
        # the last prompt token again would duplicate it in the cache.
        model_output = model(input_ids=input_ids, use_cache=True)

        while generated < max_new_tokens:
            # Greedy pick. Softmax and scaling by a positive temperature are
            # monotonic, so argmax over the raw logits selects the same token
            # and avoids dividing by a near-zero temperature.
            next_token = torch.argmax(model_output.logits[:, -1, :], dim=-1)

            # Append the new token to the running sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
            generated += 1

            # Stop before paying for a forward pass whose output is unused
            if (
                next_token.item() == tokenizer.eos_token_id
                or generated == max_new_tokens
            ):
                break

            # Decode step: feed only the new token and reuse the cache
            model_output = model(
                input_ids=next_token.unsqueeze(-1),
                past_key_values=model_output.past_key_values,
                use_cache=True,
            )

    touch()
    elapsed = time.perf_counter() - begin

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("HF INFERENCE: Time taken: ", elapsed)
    # Throughput counts generated tokens, not characters of the prompt
    print("HF INFERENCE: Tok/s: ", generated / elapsed)


if __name__ == "__main__":
    main()
