import ray
import torch
import time
import random
import numpy as np
from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

### GLOBAL VARIABLES ###

# Accelerator selection: Apple-silicon MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

dtype = torch.bfloat16  # weight dtype passed to from_pretrained

# Decoding budget and speculative-decoding hyper-parameters.
max_length = 150
N = max_length  # tokens to generate
K = 8  # draft tokens proposed before each target verification
temperature = 2.0  # logits are divided by this before the softmax
eps = 1e-10  # keeps the temperature division finite if temperature == 0

### GLOBAL VARIABLES ###


def touch(dev=None):
    """Block until all queued work on the accelerator has finished.

    Called around timed regions so wall-clock measurements include kernels
    that were launched asynchronously. Does nothing on the CPU.

    Args:
        dev: device to synchronise, either a ``torch.device`` or a string such
            as ``"cuda:0"``. Defaults to the module-level ``device``.
    """
    # Compare the device *type*: a torch.device never compares equal to a plain
    # string, so a check like `device == "cuda"` is always False.
    dev_type = torch.device(device if dev is None else dev).type
    if dev_type == "mps":
        torch.mps.synchronize()
    elif dev_type == "cuda":
        torch.cuda.synchronize()


def sample(p):
    """Pick a token id from the probability tensor ``p`` greedily.

    The index of the largest entry (over the flattened tensor) is returned as
    a 1-element int64 tensor, ready to be ``torch.cat``-ed onto a token
    sequence. A stochastic variant would be
    ``torch.multinomial(p, num_samples=1)``.
    """
    best_index = torch.argmax(p)  # 0-d tensor
    return best_index.reshape(1)


def max_fn(x):
    """Return the normalised positive part of ``x``: max(x, 0) / sum(max(x, 0)).

    In this script it turns ``q - p`` (target minus draft probabilities) into
    the residual distribution that a replacement token is drawn from after a
    draft token is rejected.

    Normalisation is over the last dimension, so a batch of rows yields a
    batch of distributions; 1-D inputs with positive mass give the same result
    as a global normalisation. If a row has no positive entries (e.g.
    ``q == p`` up to rounding) the residual is undefined, and a uniform
    distribution is returned for that row instead of NaNs from 0 / 0.
    """
    # NaN entries fail `x > 0` and are therefore mapped to 0 as well.
    x_max = torch.where(x > 0, x, torch.zeros_like(x))
    total = x_max.sum(dim=-1, keepdim=True)
    uniform = torch.full_like(x_max, 1.0 / x_max.shape[-1])
    return torch.where(total > 0, x_max / total, uniform)


def load_models(tokenizer_path, draft_model_path, target_model_path, use_fast=False):
    """Load a Llama tokenizer plus a draft and a target causal LM.

    Both models are loaded with ``torch_dtype=dtype``, switched to eval mode
    and moved to the module-level ``device`` (MPS, CUDA or CPU).

    Args:
        tokenizer_path: checkpoint id/path for the tokenizer.
        draft_model_path: checkpoint id/path for the small proposal model.
        target_model_path: checkpoint id/path for the verifying model.
        use_fast: use ``LlamaTokenizerFast`` instead of ``LlamaTokenizer``.

    Returns:
        ``(tokenizer, draft_model, target_model)``.
    """
    tokenizer_cls = LlamaTokenizerFast if use_fast else LlamaTokenizer
    tokenizer = tokenizer_cls.from_pretrained(tokenizer_path)

    def _prepare(checkpoint):
        # Load one causal LM, configure it for inference and place it on `device`.
        lm = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=dtype)
        # NOTE(review): nothing in this script reads hidden states; this flag
        # only makes every forward pass return them.
        lm.config.output_hidden_states = True
        lm.eval()
        return lm.to(device)

    return tokenizer, _prepare(draft_model_path), _prepare(target_model_path)


def _timed_generate(model, input_ids):
    """Greedily generate up to ``N`` new tokens with ``model.generate`` and time it.

    Returns:
        ``(output_ids, seconds)``. ``output_ids`` holds the prompt followed by
        the generated tokens; ``seconds`` is wall-clock time measured between
        device synchronisations.
    """
    touch()
    begin = time.time()
    # temperature/top_k/top_p are omitted: with do_sample=False transformers
    # ignores them and only emits warnings.
    output = model.generate(
        input_ids,
        max_new_tokens=N,  # same budget as the speculative loop (prompt excluded)
        num_return_sequences=1,
        use_cache=False,
        do_sample=False,
    )
    touch()
    return output, time.time() - begin


def _report(label, text, seconds, num_new_tokens):
    """Print a decoded output between banners, then wall time and throughput."""
    banner = f"#######  {label} #########"
    print(banner)
    print(text)
    print(banner)
    print("Time taken: ", seconds)
    # Throughput counts *generated tokens*, not characters of the prompt.
    print("Tok/s: ", num_new_tokens / seconds)
    print()


def _speculative_decode(draft_model, target_model, prompt_ids, eos_token_id=None):
    """Speculative decoding with ``K`` draft proposals per target verification.

    In each round:

    - The draft model proposes ``K`` tokens autoregressively.
    - The target model scores the extended sequence in one forward pass.
    - Proposals are accepted left to right with probability ``min(1, q/p)``.
    - The first rejected proposal is replaced by ``sample(max_fn(q - p))``.
    - If all proposals are accepted, one extra token is taken from the
      target's last distribution.

    NOTE(review): ``sample`` is greedy (argmax), while the ``min(1, q/p)``
    acceptance rule assumes proposals are *sampled* from ``p``. The output is
    therefore not guaranteed to equal either greedy or sampled target
    decoding.

    Args:
        draft_model: causal LM whose forward output exposes ``.logits``.
        target_model: causal LM whose forward output exposes ``.logits``.
        prompt_ids: prompt token ids with a leading batch dimension of 1.
        eos_token_id: if given, decoding stops right after this token.

    Returns:
        1-D tensor of the prompt followed by at most ``N`` generated tokens.
    """
    x = prompt_ids[0]
    n = x.shape[-1]
    T = n + N

    while n < T:
        prev_n = n

        # Step 1: auto-regressively decode K tokens from the draft model.
        x_draft = x
        for _ in range(K):
            # [0] drops the batch dim (robust even for length-1 inputs). The
            # softmax is done in float32 because bf16 probabilities and q/p
            # ratios over a large vocabulary are very coarse.
            logits = draft_model(
                input_ids=x_draft.unsqueeze(0), use_cache=False
            ).logits[0]
            p = torch.softmax(logits.float() / (temperature + eps), dim=-1)
            x_draft = torch.cat((x_draft, sample(p[-1])), dim=-1)
        # p comes from the last draft call on x_draft[:-1]: row i is the draft
        # distribution for token i + 1.

        # Step 2: score every draft position with a single target forward pass.
        logits = target_model(
            input_ids=x_draft.unsqueeze(0), use_cache=False
        ).logits[0]
        q = torch.softmax(logits.float() / (temperature + eps), dim=-1)

        # Step 3: append draft tokens based on the rejection criterion and
        # resample a token on rejection.
        all_accepted = True
        for _ in range(K):
            i = n - 1
            j = x_draft[i + 1]
            ratio = (q[i, j] / p[i, j]).item()  # inf when p == 0 -> accepted
            if np.random.random() < min(1.0, ratio):  # accepted
                x = torch.cat((x, j.unsqueeze(-1)), dim=-1)
                n += 1
            else:  # rejected
                x = torch.cat((x, sample(max_fn(q[i] - p[i]))), dim=-1)
                n += 1
                all_accepted = False
                break

        # Step 4: if all draft tokens were accepted, take one more from the target.
        if all_accepted:
            x = torch.cat((x, sample(q[-1])), dim=-1)
            n += 1

        # Stop at end-of-sequence, like model.generate does.
        if eos_token_id is not None:
            hits = (x[prev_n:] == eos_token_id).nonzero()
            if hits.numel() > 0:
                return x[: prev_n + int(hits[0]) + 1]

    # A round may append up to K + 1 tokens; drop anything past the N-token budget.
    return x[:T]


def main():
    """Compare draft-only, target-only and speculative decoding on one prompt.

    Prints each decoded output with its wall-clock time and its throughput in
    generated tokens per second. Checkpoints that have been used here:

    hf-internal-testing/llama-tokenizer
    meta-llama/Llama-2-7b-hf
    meta-llama/Llama-2-7b-chat-hf
    TinyLlama/TinyLlama-1.1B-step-50K-105b
    openlm-research/open_llama_3b
    openlm-research/open_llama_3b_v2
    openlm-research/open_llama_13b
    """
    # tokenizer_path = "meta-llama/Llama-2-7b-hf"
    # draft_model_path = "meta-llama/Llama-2-7b-chat-hf"
    # target_model_path = "meta-llama/Llama-2-13b-chat-hf"

    # tokenizer_path = "openlm-research/open_llama_3b"
    # draft_model_path = "openlm-research/open_llama_3b"
    # target_model_path = "openlm-research/open_llama_13b"

    tokenizer_path = "hf-internal-testing/llama-tokenizer"
    draft_model_path = "TinyLlama/TinyLlama-1.1B-step-50K-105b"
    target_model_path = "meta-llama/Llama-2-7b-hf"

    tokenizer, draft_model, target_model = load_models(
        tokenizer_path, draft_model_path, target_model_path
    )

    # Prepare input text
    # input_text = "Question: Hello LLaMA. What can you do?\nAnswer:"
    # input_text = "The capital of France is "
    input_text = "Alan Turing theorized that computers would"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[-1]

    with torch.no_grad():
        ### WARMUP ###
        for _ in range(5):
            _timed_generate(draft_model, input_ids)

        ### VANILLA RESPONSE GENERATION ###
        for label, model in (
            ("DRAFT MODEL OUTPUT", draft_model),
            ("TARGET MODEL OUTPUT", target_model),
        ):
            output, seconds = _timed_generate(model, input_ids)
            _report(
                label,
                tokenizer.decode(output[0], skip_special_tokens=True),
                seconds,
                output.shape[-1] - prompt_len,
            )

        ### SPECULATIVE DECODING ###
        # No separate prefill pass: every forward runs with use_cache=False,
        # so it would fill nothing and only add to the measured time.
        touch()
        begin = time.time()
        x = _speculative_decode(
            draft_model,
            target_model,
            input_ids,
            eos_token_id=tokenizer.eos_token_id,
        )
        touch()
        seconds = time.time() - begin

    _report(
        "SPECULATIVE DECODING MODEL OUTPUT",
        tokenizer.decode(x, skip_special_tokens=True),
        seconds,
        x.shape[-1] - prompt_len,
    )


if __name__ == "__main__":
    main()
