"""
This is GPT-2 with KV-cache on
It's just the autoregressive (vanilla) sampling method
Originally meant for me to learn how HuggingFace works and how I can manually enable KV-cache
to prepare for speculative decoding implementation.
"""

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
import time
import sys

# Select the compute backend: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    # The script targets the second GPU ("cuda:1"). On a single-GPU host that
    # ordinal does not exist, so fall back to the first device instead of
    # failing later when the models are loaded.
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Precision used for the model weights and the warm-up matmul.
dtype = torch.float32

max_length = 150
N = max_length  # tokens to generate
temperature = 0.0  # set 0.0 to be deterministic/greedy sampling
eps = 1e-10  # added to temperature so temperature == 0.0 never divides by zero


def touch():
    """Block until work queued on the selected backend has finished.

    Called around timed regions so that the wall-clock measurement covers the
    work that was launched. Does nothing when the device is neither MPS nor
    CUDA.
    """
    # `device` is a torch.device: comparing it against a bare string does not
    # match, and the CUDA device also carries an index ("cuda:1"). Dispatch on
    # the backend type instead so the synchronization actually runs.
    dev = torch.device(device)
    if dev.type == "mps":
        torch.mps.synchronize()
    elif dev.type == "cuda":
        # Synchronize the device the models live on, not the current default.
        torch.cuda.synchronize(dev)


def max_fn(x):
    """Return the positive part of ``x`` rescaled by its own sum.

    Entries that are not strictly greater than zero are replaced with 0, and
    the result is divided by the sum of all remaining entries. There is no
    guard for the case where that sum is zero.
    """
    positive_part = torch.where(x > 0, x, 0)
    total = positive_part.sum()
    return positive_part / total


def sample(p):
    """Pick the next token greedily from ``p``.

    Returns a tensor of shape ``(1,)`` holding the index of the largest entry
    of ``p``. No ``dim`` is given to ``argmax``, so the index refers to the
    flattened tensor. Stochastic sampling (``torch.multinomial``) is not used
    here, which keeps the output deterministic.
    """
    best_index = p.argmax()  # 0-dim tensor
    return best_index.unsqueeze(-1)


def autoregressive_sampling(input_ids, model, tokenizer, N):
    """Generate ``N`` tokens one at a time without a KV-cache.

    Every step re-runs ``model`` on the full sequence produced so far and
    appends the token chosen by ``sample``.

    Args:
        input_ids: Prompt token ids, indexed as ``(batch, sequence)``. Only
            row 0 of the logits is read, so this presumably expects a batch
            of one -- confirm before batching.
        model: Called as ``model(input_ids, use_cache=False)``; the result
            must expose a ``logits`` attribute.
        tokenizer: Unused; kept so both samplers share one signature.
        N: Number of tokens to append.

    Returns:
        ``input_ids`` with ``N`` new token ids concatenated along the last
        dimension.
    """
    with tqdm(total=N, desc="autoregressive sampling") as pbar:
        for _ in range(N):
            outputs = model(input_ids, use_cache=False)
            # Only the last position predicts the next token, so slice first
            # and scale that single row instead of the whole logits tensor.
            last_logits = outputs.logits[0, -1] / (temperature + eps)
            p = torch.nn.functional.softmax(last_logits, dim=-1)
            # sample() returns shape (1,); add a leading dim so it can be
            # concatenated onto the 2-D id tensor.
            next_token_id = sample(p).unsqueeze(-1)
            input_ids = torch.cat((input_ids, next_token_id), dim=-1)
            pbar.update(1)

    return input_ids


def autoregressive_sampling_kvcache(input_ids, model, tokenizer, N):
    """Generate ``N`` tokens one at a time, reusing the model's KV-cache.

    The first step feeds the whole prompt to fill the cache; every later step
    feeds only the most recent token together with the cache returned by the
    previous call.

    Args:
        input_ids: Prompt token ids, indexed as ``(batch, sequence)``.
            ``sample`` takes a single argmax over the whole distribution
            tensor, so this presumably expects a batch of one -- confirm
            before batching.
        model: Called as ``model(ids, use_cache=True[, past_key_values=...])``;
            the result must expose ``logits`` and ``past_key_values``.
        tokenizer: Unused; kept so both samplers share one signature.
        N: Number of tokens to append.

    Returns:
        ``input_ids`` with ``N`` new token ids concatenated along the last
        dimension.
    """
    kv_cache = None

    with tqdm(total=N, desc="autoregressive sampling w/ kvcache") as pbar:
        for _ in range(N):
            # Test for None explicitly rather than relying on the truthiness
            # of whatever cache object the model returns.
            if kv_cache is None:
                # Prefill: run the full prompt once to populate the cache.
                outputs = model(input_ids, use_cache=True)
            else:
                # Decode: feed only the newest token. The `-1:` slice keeps
                # the sequence dimension so the model still gets a 2-D input.
                outputs = model(
                    input_ids[:, -1:], use_cache=True, past_key_values=kv_cache
                )
            kv_cache = outputs.past_key_values

            # Both branches now yield 3-D logits; keep the last position.
            logits = outputs.logits[:, -1] / (temperature + eps)
            p = torch.nn.functional.softmax(logits, dim=-1)

            next_token_id = sample(p)
            if next_token_id.dim() == 1:
                # Reshape the token already drawn instead of sampling again:
                # a second draw is wasted work and, with a stochastic
                # sampler, would silently use a different token.
                next_token_id = next_token_id.unsqueeze(-1)

            input_ids = torch.cat((input_ids, next_token_id), dim=-1)
            pbar.update(1)

    return input_ids


def create_model(model_name):
    """Load two GPT-2 LM-head models and the matching tokenizer.

    Both models are loaded from the same checkpoint onto the module-level
    ``device`` with the module-level ``dtype``; they differ only in the
    ``use_cache`` value passed at load time.

    Args:
        model_name: Checkpoint name or path handed to ``from_pretrained``.

    Returns:
        A ``(draft_model, target_model, tokenizer)`` tuple, where
        ``draft_model`` was loaded with ``use_cache=False`` and
        ``target_model`` with ``use_cache=True``.
    """
    placement = {"device_map": device, "torch_dtype": dtype}

    draft_model = GPT2LMHeadModel.from_pretrained(
        model_name, use_cache=False, **placement
    )
    target_model = GPT2LMHeadModel.from_pretrained(
        model_name, use_cache=True, **placement
    )
    # A tokenizer holds no weights, so the model placement/precision options
    # do not apply to it; they are no longer forwarded here.
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    return draft_model, target_model, tokenizer


def _report(title, rule, elapsed, text, new_tokens):
    """Print one benchmark section: heading, wall time, text and throughput."""
    print(f"\n{title}")
    print(rule)
    print(f"Time = {elapsed:.2f}s")
    print(f"Text = {text}")
    print(f"Tok/s = {new_tokens / elapsed:.2f}")


def main(
    prompt="Alan Turing theorized that computers would one day become",
    model_name="gpt2-xl",
):
    """Time three ways of generating ``N`` tokens from ``prompt`` and print each.

    The three runs are: the hand-written loop without a KV-cache, the
    hand-written loop with a KV-cache, and Hugging Face ``generate``.

    Args:
        prompt: Text to continue.
        model_name: Checkpoint name or path handed to ``create_model``.
    """
    n_tokens_to_generate = N

    # Warm up the backend with a few matmuls so one-time start-up cost is not
    # charged to the first timed section.
    A = torch.randn(1024, 1024, device=device, dtype=dtype)
    B = torch.randn(1024, 1024, device=device, dtype=dtype)
    touch()
    for _ in range(10):
        C = A @ B
        del C
    touch()

    draft_model, target_model, tokenizer = create_model(model_name)
    # The samplers build new tensors with torch.cat and never modify the
    # prompt in place, so one encoding serves all three runs.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]

    # Autoregressive sampling
    touch()
    start = time.perf_counter()
    output_ids = autoregressive_sampling(
        input_ids, draft_model, tokenizer, n_tokens_to_generate
    )
    touch()
    autoregressive_time = time.perf_counter() - start
    autoregressive_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Throughput counts only generated tokens; including the prompt length
    # would overstate it.
    _report(
        "Autoregressive Decode",
        "---------------------",
        autoregressive_time,
        autoregressive_text,
        output_ids.shape[1] - prompt_len,
    )

    # Autoregressive sampling with KV-cache enabled huggingface style
    touch()
    start = time.perf_counter()
    output_ids = autoregressive_sampling_kvcache(
        input_ids, target_model, tokenizer, n_tokens_to_generate
    )
    touch()
    autoregressive_time = time.perf_counter() - start
    autoregressive_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    _report(
        "Autoregressive Decode with KV-cache",
        "---------------------",
        autoregressive_time,
        autoregressive_text,
        output_ids.shape[1] - prompt_len,
    )

    # huggingface default output
    # target_model was loaded from the same checkpoint with use_cache=True, so
    # it is reused rather than loading another copy of the weights.
    touch()
    start = time.perf_counter()
    output = target_model.generate(
        input_ids,
        max_new_tokens=n_tokens_to_generate,
        num_return_sequences=1,
        use_cache=True,
        # Greedy decoding: temperature only applies when do_sample=True, so it
        # is not passed here.
        do_sample=False,
    )
    # Synchronize before reading the clock, like the two sections above.
    touch()
    huggingface_time = time.perf_counter() - start
    huggingface_text = tokenizer.decode(output[0], skip_special_tokens=True)
    _report(
        "Huggingface Default Output",
        "--------------------------",
        huggingface_time,
        huggingface_text,
        output[0].shape[0] - prompt_len,
    )


if __name__ == "__main__":
    # Inference only: run the whole benchmark with gradient tracking disabled.
    grad_disabled = torch.no_grad()
    with grad_disabled:
        main()
