"""
An attempt at speculative decoding with the KV-cache enabled, benchmarked
against plain autoregressive decoding with the draft and the target model.
"""

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
import time
import sys

# Pick the compute backend: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    # NOTE(review): this hard-codes the *second* GPU and will likely fail on
    # single-GPU hosts -- confirm this is intended.
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")

# Precision used for the model weights and the warm-up matmuls in main().
dtype = torch.float32

# Generation settings.
max_length = 150
N = max_length  # number of new tokens each decoding method is asked to generate
K = 8  # draft tokens proposed per speculative round before the target verifies them
temperature = 0.0  # logits are divided by (temperature + eps) before the softmax
eps = 1e-10  # keeps that division finite when temperature is 0


def touch(dev=None):
    """Block until all queued work on the accelerator has finished.

    CUDA and MPS kernel launches are asynchronous, so timing code calls this
    before and after the timed region to measure real work, not just launches.

    Args:
        dev: device to synchronize. Defaults to the module-level ``device``.
            Accepts a ``torch.device`` or anything ``torch.device`` accepts
            (e.g. ``"cuda:1"``). CPU devices need no synchronization.
    """
    dev = device if dev is None else dev
    if not isinstance(dev, torch.device):
        dev = torch.device(dev)
    # Compare the device *type*. A torch.device does not compare equal to a
    # plain string, and "cuda:1" would never match "cuda" anyway.
    if dev.type == "mps":
        torch.mps.synchronize()
    elif dev.type == "cuda":
        # Synchronize the device actually in use, not just the current one.
        torch.cuda.synchronize(dev)


def max_fn(x):
    """Return the positive part of ``x`` renormalized to sum to one.

    Entries that are not strictly positive become zero, and the rest are
    divided by their total over the whole tensor. Speculative sampling applies
    this to ``q - p`` to build the residual distribution a rejected draft token
    is resampled from.
    """
    positive_part = torch.where(x > 0, x, torch.zeros_like(x))
    total = positive_part.sum()
    return positive_part / total


def sample(p):
    """Choose the next token id from the probabilities ``p`` (greedy).

    Despite its name this is deterministic. It returns the index of the largest
    entry of ``p``, taken over the flattened tensor, as a 1-element tensor.
    A stochastic alternative is ``torch.multinomial``. Its output shape
    differs, so callers would need adjusting.
    """
    best_index = torch.argmax(p)  # 0-d tensor
    return best_index.reshape(1)


def autoregressive_sampling(input_ids, model, tokenizer, N):
    """Append ``N`` tokens to ``input_ids`` one at a time, reusing the KV-cache.

    Args:
        input_ids: prompt token ids with the sequence on the last axis. A batch
            of 1 is presumably assumed, because ``sample`` takes a single
            argmax over the flattened distribution.
        model: causal LM called as ``model(ids, use_cache=True,
            past_key_values=...)``.
        tokenizer: unused; kept so the signature matches ``speculative_sampling``.
        N: number of new tokens to generate.

    Returns:
        ``input_ids`` with ``N`` generated tokens concatenated on the last axis.
    """
    n = input_ids.shape[1]
    T = n + N
    kv_cache = None

    with tqdm(total=N, desc="autoregressive sampling") as pbar:
        while n < T:
            # First step: feed the whole prompt. After that the cache already
            # holds every earlier position, so feed only the newest token.
            # Slicing with -1: keeps the input 2-D (batch, 1).
            new_tokens = input_ids if kv_cache is None else input_ids[:, -1:]
            outputs = model(new_tokens, use_cache=True, past_key_values=kv_cache)
            kv_cache = outputs.past_key_values
            # Distribution for the token following the last fed position.
            logits = outputs.logits[:, -1] / (temperature + eps)
            p = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = sample(p).unsqueeze(-1)
            input_ids = torch.cat((input_ids, next_token_id), dim=-1)
            n += 1
            pbar.update(1)

    return input_ids


def rollback_kv_cache(kv_cache, n):
    """
    During speculative decoding, roll the KV-cache back by ``n`` positions.

    kv_cache (tuple(tuple(torch.FloatTensor)), as returned in
    ``past_key_values`` when ``use_cache=True``) holds one tuple per layer.
    Each tuple has a key and a value tensor of shape
    (batch_size, num_heads, sequence_length, embed_size_per_head). When
    ``config.is_encoder_decoder=True`` it has two more cross-attention tensors
    of shape (batch_size, num_heads, encoder_sequence_length,
    embed_size_per_head).

    Only the self-attention key/value tensors are truncated. Any extra
    cross-attention tensors are passed through unchanged, because the encoder
    sequence does not change while decoding.

    Args:
        kv_cache: the cache described above, or ``None``.
        n: number of trailing positions to drop. ``n <= 0`` keeps everything.
            If ``n`` exceeds the cached length, the sequence axis becomes empty.

    Returns:
        A new tuple of per-layer tuples whose tensors are slices (views) of the
        input, or ``None`` if ``kv_cache`` is ``None``.
    """
    if kv_cache is None:
        return None

    n = max(n, 0)
    rolled_back_kv_cache = []
    for key, value, *cross_attention in kv_cache:
        # Use an explicit end index. ``key[:, :, :-n]`` would give an EMPTY
        # tensor for n == 0 instead of the whole cache.
        keep = max(key.shape[2] - n, 0)
        rolled_back_kv_cache.append(
            (key[:, :, :keep, :], value[:, :, :keep, :], *cross_attention)
        )
    return tuple(rolled_back_kv_cache)


def _kv_cache_length(kv_cache):
    """Return how many positions ``kv_cache`` holds (0 for ``None`` or empty)."""
    if kv_cache is None:
        return 0
    # transformers ``Cache`` objects (e.g. DynamicCache) report their length.
    if hasattr(kv_cache, "get_seq_length"):
        return int(kv_cache.get_seq_length())
    if len(kv_cache) == 0:
        return 0
    # Legacy tuple layout: [layer][key/value] -> (batch, heads, seq, head_dim).
    return kv_cache[0][0].shape[2]


def _truncate_kv_cache(kv_cache, length):
    """Drop cached positions beyond ``length``; return the cache as-is if none."""
    excess = _kv_cache_length(kv_cache) - length
    return rollback_kv_cache(kv_cache, excess) if excess > 0 else kv_cache


def _forward_new_tokens(model, ids, kv_cache):
    """Run ``model`` on the suffix of ``ids`` that ``kv_cache`` does not cover yet."""
    cached = _kv_cache_length(kv_cache)
    return model(
        ids[:, cached:],
        use_cache=True,
        past_key_values=kv_cache if cached > 0 else None,
    )


def speculative_sampling(input_ids, draft_model, target_model, tokenizer, N):
    """Generate tokens with draft-then-verify speculative decoding.

    Each round:

    1. The draft model proposes ``K`` tokens autoregressively.
    2. The target model scores all of them in a single forward pass.
    3. Proposals are accepted left to right with probability
       ``min(1, q/p)``. The first rejected proposal is replaced by a token
       drawn from ``max_fn(q - p)``.
    4. If every proposal is accepted, one extra token is taken from the
       target's distribution after the last proposal.

    After every round both KV-caches are trimmed so that they never hold
    rejected or unverified positions.

    Args:
        input_ids: prompt token ids; batch size 1 is assumed (rows are read
            with ``[0]``).
        draft_model: small causal LM that proposes tokens.
        target_model: large causal LM that verifies them.
        tokenizer: unused; kept so the signature matches ``autoregressive_sampling``.
        N: minimum number of tokens to append. The last round may overshoot by
            up to ``K`` tokens.

    Returns:
        ``(ids, rejected_tokens, accepted_tokens)``. Here ``ids`` is the prompt
        plus the generated tokens. A bonus token after a fully accepted round
        counts as accepted.
    """
    T = input_ids.shape[1] + N
    draft_kv_cache = None
    target_kv_cache = None
    x = input_ids

    rejected_tokens = 0
    accepted_tokens = 0

    with tqdm(total=N, desc="speculative sampling") as pbar:
        while x.shape[1] < T:
            prefix_len = x.shape[1]

            # Step 1: draft K tokens. Invariant: both caches cover at most
            # prefix_len - 1 positions, so every forward sees >= 1 new token.
            x_draft = x
            draft_logits = []
            for _ in range(K):
                out = _forward_new_tokens(draft_model, x_draft, draft_kv_cache)
                draft_kv_cache = out.past_key_values
                logits = out.logits[:, -1] / (temperature + eps)
                draft_logits.append(logits)
                p_next = torch.nn.functional.softmax(logits, dim=-1)
                next_token_id = sample(p_next).unsqueeze(-1)
                x_draft = torch.cat((x_draft, next_token_id), dim=-1)
            # Row i is the draft distribution that produced proposal i.
            # It is rebuilt every round and never carried over.
            p = torch.nn.functional.softmax(torch.cat(draft_logits, dim=0), dim=-1)

            # Step 2: score every proposal with one target forward pass. The
            # invariant guarantees at least K + 1 fed positions. The last K + 1
            # predict proposals 0..K-1 plus one bonus token.
            out = _forward_new_tokens(target_model, x_draft, target_kv_cache)
            target_kv_cache = out.past_key_values
            q_logits = out.logits[0, -(K + 1):] / (temperature + eps)
            q = torch.nn.functional.softmax(q_logits, dim=-1)

            # Step 3: accept proposals left to right; on the first rejection,
            # resample from the residual distribution and stop.
            all_accepted = True
            for i in range(K):
                token = x_draft[0, prefix_len + i]
                accept_prob = min(1.0, (q[i, token] / p[i, token]).item())
                # The acceptance test needs r ~ U[0, 1). Shrinking r would
                # accept tokens the target model does not endorse.
                if torch.rand(1).item() < accept_prob:
                    x = torch.cat((x, token.view(1, 1)), dim=-1)
                    accepted_tokens += 1
                else:
                    resampled = sample(max_fn(q[i] - p[i])).unsqueeze(-1)
                    x = torch.cat((x, resampled), dim=-1)
                    rejected_tokens += 1
                    all_accepted = False
                    break

            # Step 4: every proposal accepted -> take one more from the target.
            if all_accepted:
                x = torch.cat((x, sample(q[-1]).unsqueeze(-1)), dim=-1)
                accepted_tokens += 1

            # Step 5: the last token of x has not been fed to either model yet,
            # so keep at most x.shape[1] - 1 cached positions. This drops
            # rejected drafts and restores the Step 1 invariant.
            keep = x.shape[1] - 1
            draft_kv_cache = _truncate_kv_cache(draft_kv_cache, keep)
            target_kv_cache = _truncate_kv_cache(target_kv_cache, keep)

            pbar.update(x.shape[1] - prefix_len)

    return x, rejected_tokens, accepted_tokens


def load_models(tokenizer_path, draft_model_path, target_model_path):
    """Load the shared tokenizer plus the draft and target GPT-2 models.

    Args:
        tokenizer_path: name or path for ``GPT2Tokenizer.from_pretrained``.
        draft_model_path: name or path of the small (draft) model.
        target_model_path: name or path of the large (target) model.

    Returns:
        ``(tokenizer, draft_model, target_model)``. Both models are in eval
        mode, on ``device``, in ``dtype``, with ``use_cache=True``.
    """
    # ``device_map`` is a model-loading option; a tokenizer would only store it.
    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

    def _load(path):
        model = GPT2LMHeadModel.from_pretrained(
            path, device_map=device, torch_dtype=dtype, use_cache=True
        )
        # config.output_hidden_states is deliberately left at its default.
        # Nothing here reads hidden states, so returning every layer's
        # activations on each forward would only add work to the benchmark.
        model.eval()
        # Make sure the weights end up on ``device`` (MPS, CUDA or CPU).
        # Harmless if device_map already placed them there.
        return model.to(device)

    draft_model = _load(draft_model_path)
    target_model = _load(target_model_path)
    return tokenizer, draft_model, target_model


def _timed(fn, *args):
    """Call ``fn(*args)`` between device syncs and return ``(result, seconds)``."""
    touch()
    start = time.perf_counter()
    result = fn(*args)
    touch()
    return result, time.perf_counter() - start


def _report(title, elapsed, text, n_generated):
    """Print one benchmark section: title, wall time, decoded text, throughput."""
    print(f"\n{title}")
    print("---------------------")
    print(f"Time = {elapsed:.2f}s")
    print(f"Text = {text}")
    print(f"Tok/s = {n_generated / elapsed:.2f}")


def main(
    prompt="Alan Turing theorized that computers would one day become",
    # prompt="Question: What is the meaning of life?\nAnswer:",
):
    """Benchmark autoregressive decoding against speculative decoding.

    Autoregressive decoding runs once with the draft model and once with the
    target model. All three runs start from ``prompt`` and ask for ``N`` new
    tokens; speculative decoding may overshoot by up to ``K``. "Tok/s" counts
    only the generated tokens, so the prompt length does not inflate the
    figure.
    """
    n_tokens_to_generate = N

    # Warm up the device with a few matmuls before anything is timed.
    a = torch.randn(1024, 1024, device=device, dtype=dtype)
    b = torch.randn(1024, 1024, device=device, dtype=dtype)
    touch()
    for _ in range(10):
        c = a @ b
        del c
    touch()
    del a, b

    tokenizer, draft_model, target_model = load_models(
        tokenizer_path="gpt2", draft_model_path="gpt2", target_model_path="gpt2-xl"
    )
    # The decoding functions never modify their input tensor in place, so
    # all runs can share the same encoded prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]

    baselines = (
        ("Autoregressive Decode with Draft Model", draft_model),
        ("Autoregressive Decode with Target Model", target_model),
    )
    for title, model in baselines:
        output_ids, elapsed = _timed(
            autoregressive_sampling, input_ids, model, tokenizer, n_tokens_to_generate
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        _report(title, elapsed, text, output_ids.shape[1] - prompt_len)

    (output_ids, rejected, accepted), elapsed = _timed(
        speculative_sampling,
        input_ids,
        draft_model,
        target_model,
        tokenizer,
        n_tokens_to_generate,
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    _report("Speculative Decode", elapsed, text, output_ids.shape[1] - prompt_len)
    print(f"Rejected = {rejected}, Accepted = {accepted}")


if __name__ == "__main__":
    # Inference only: run the whole benchmark with autograd switched off,
    # using torch.no_grad as a function decorator.
    torch.no_grad()(main)()
