"""Main script for running inference with speculative decoding"""

import time
import argparse
import torch

from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

from src.utils import device, touch, load_models, dtype, warmup, torch_timer
from src.inference import inference
from src.eda import eda
from src.benchmark import benchmark

# pylint: disable=invalid-name


# CUDA-only performance settings; on machines without a CUDA device this
# whole section is skipped and PyTorch defaults are left untouched.
if torch.cuda.is_available():
    # Turn on the cuDNN auto-tuner, which times candidate algorithms for a
    # given input configuration and reuses the fastest one.
    torch.backends.cudnn.benchmark = True

    # Further precision/speed trade-offs, currently left commented out:
    #
    # Allow TF32 for matmul (this flag defaults to False in PyTorch 1.12+).
    # torch.backends.cuda.matmul.allow_tf32 = True
    #
    # Allow TF32 for cuDNN (this flag defaults to True).
    # torch.backends.cudnn.allow_tf32 = True
    #
    # When running in fp16, disallow reduced-precision reductions so that
    # accumulation happens in higher precision for better outputs.
    # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False


def main(
    draft_model_path,
    target_model_path,
    prompt,
    temperature,
    gamma,
    benchmark,
    n_tokens_to_generate=250,
    use_fa=False,
    use_4bit=False,
    use_8bit=False,
):
    """Run every decoding strategy once on a single prompt (debugging entry).

    Loads the tokenizer and the draft/target model pair with ``load_models``,
    encodes ``prompt`` onto ``device``, and then, for each configuration
    below, calls ``warmup()`` followed by ``inference()``:

    1. autoregressive decoding with the draft model,
    2. autoregressive decoding with the target model,
    3. ``speculative`` decoding with a fixed ``gamma``,
    4. ``dynamic_speculative``,
    5. ``dynamic_speculative_history``,
    6. ``perceptron_predictor``.

    Args:
        draft_model_path: model id / path of the draft model.
        target_model_path: model id / path of the target model.
        prompt: text prompt to encode and decode from.
        temperature: forwarded to every ``inference`` call.
        gamma: passed as ``gamma`` to the ``speculative`` run and as
            ``max_gamma`` to the three dynamic/predictor runs.
        benchmark: forwarded as ``benchmark=`` to the two autoregressive
            runs only. (Inside this body the name shadows the module-level
            ``benchmark`` function; kept for backward compatibility.)
        n_tokens_to_generate: forwarded to every ``inference`` call.
        use_fa, use_4bit, use_8bit: forwarded to ``load_models``.

    Returns:
        None.
    """
    # Autograd is never needed for inference. Disabling it here (instead of
    # relying on the caller) keeps memory down when main() is imported and
    # called directly; nesting inside an outer no_grad() is harmless.
    with torch.no_grad():
        tokenizer, draft_model, target_model = load_models(
            draft_model_path,
            target_model_path,
            use_fa=use_fa,
            use_4bit=use_4bit,
            use_8bit=use_8bit,
        )

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # NOTE: a Hugging Face reference run can be produced with
        #   model.generate(input_ids, max_new_tokens=n_tokens_to_generate,
        #                  do_sample=False, pad_token_id=tokenizer.eos_token_id)
        # timed with torch_timer(), optionally adding assistant_model=draft_model
        # on the target model for HF assisted generation. Use max_new_tokens
        # (not max_length, which includes the prompt) so tokens/sec is honest.

        # kwargs shared by every draft+target run
        speculative = {
            "draft_model": draft_model,
            "target_model": target_model,
            "record_speculation": True,
        }
        runs = (
            {"model": draft_model, "benchmark": benchmark, "mode": "autoregressive"},
            {"model": target_model, "benchmark": benchmark, "mode": "autoregressive"},
            {**speculative, "gamma": gamma, "mode": "speculative"},
            {**speculative, "max_gamma": gamma, "mode": "dynamic_speculative"},
            {**speculative, "max_gamma": gamma, "mode": "dynamic_speculative_history"},
            {**speculative, "max_gamma": gamma, "mode": "perceptron_predictor"},
        )

        for run_kwargs in runs:
            # warmup() precedes every run, exactly as in the original sequence
            warmup()
            inference(
                input_ids,
                tokenizer,
                n_tokens_to_generate,
                temperature,
                **run_kwargs,
            )


def _build_arg_parser():
    """Return the command-line parser for this script.

    Commented-out alternatives are kept next to several defaults for quick
    experimentation. The two quantization flags are mutually exclusive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target-model",
        type=str,
        # default="meta-llama/Llama-2-13b-chat-hf",
        # default="codellama/CodeLlama-7b-Python-hf",
        default="facebook/opt-13b",
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        default="facebook/opt-125m",
        # default="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        # default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        # default="meta-llama/Llama-2-13b-chat-hf",
        # default="meta-llama/Llama-2-7b-chat-hf",
        # default="codellama/CodeLlama-7b-Python-hf"
        # default="JackFram/llama-160m",
        # default="JackFram/llama-68m",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        # default="Finish this code:\ndef fib(n):",
        # default="Write optimize code for SGEMM using Intel Intrinsics AVX-512:\n",
        # default='<|system|>\nYou write 2 sentence summaries of new articles. Do not write any more. Keep it brief and to the point, no yapping.\n<|user|>\nSummarize this article: The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed. Repair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water. Trains on the west coast mainline face disruption due to damage at the Lamington Viaduct. Many businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town. First Minister Nicola Sturgeon visited the area to inspect the damage. The waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare. Jeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit. However, she said more preventative work could have been carried out to ensure the retaining wall did not fail. "It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said. "That may not be true but it is perhaps my perspective over the last few days. "Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?" Meanwhile, a flood alert remains in place across the Borders because of the constant rain. Peebles was badly hit by problems, sparking calls to introduce more defences in the area. Scottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs. The Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand. He said it was important to get the flood protection plan right but backed calls to speed up the process.
        # "I was quite taken aback by the amount of damage that has been done," he said. "Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses." He said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans. Have you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.\n<|assistant|>\n',
        default="<|system|>\nYou are given a math question, and your task is to answer it. Then provide a step-by-step walkthrough on how you got the answer to the question.\n<|user|>\nJames writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\n<|assistant|>\n",
        # default="<|system|>\nYou are a finance expert. Answer the following questions to the best of your knowledge, and yap as much as possible.\n<|user|>\nWhy does it matter if a Central Bank has a negative rather than 0% interest rate?\n<|assistant|>\n",
        # default="A chat between a finance expert and customer. He answers the following questions to the best of his knowledge, and yaps as much as possible.\n\nCustomer: Why does it matter if a Central Bank has a negative rather than 0% interest rate?\nFinance expert:",
        # default="Alan Turing once theorized that ",
        #         default="""<bos><start_of_turn>user
        # You are given a math question, and your task is to answer it. Then provide a step-by-step walkthrough on how you got the answer to the question.
        # Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?<end_of_turn>
        # <start_of_turn>model""",
    )

    parser.add_argument("--temperature", "-t", type=float, default=0.0)
    parser.add_argument("--gamma", "-g", type=int, default=4)
    parser.add_argument("--n-tokens-to-generate", "-N", type=int, default=500)
    parser.add_argument("--dataset", type=str, default="openai_humaneval")
    parser.add_argument("--mode", type=str, default="sps")
    parser.add_argument("--flash-attention", "-fa", default=False, action="store_true")

    # 4-bit and 8-bit quantization cannot both be requested; argparse rejects
    # the combination up front instead of deferring to model loading.
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument("--use-4bitq", "-q4", default=False, action="store_true")
    quantization.add_argument("--use-8bitq", "-q8", default=False, action="store_true")

    # modes for project
    parser.add_argument("--eda", default=False, action="store_true")
    parser.add_argument("--benchmark", "-b", default=False, action="store_true")
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()

    # Project modes: --eda and --benchmark may be combined. The debugging run
    # at the bottom is the fallback used only when neither is requested.
    if args.eda:
        eda()

    if args.benchmark:
        datasets = (
            [
                "openai_humaneval",
                "xsum",
                "gsm8k",
                "finance-alpaca",
            ]
            if args.dataset == "all"
            else [args.dataset]
        )
        modes = (
            ["ars", "usps", "sps", "dsps", "dhsps"]
            if args.mode == "all"
            else [args.mode]
        )

        for dataset in datasets:
            for mode in modes:
                with torch.no_grad():  # Disable autograd to reduce memory usage
                    benchmark(
                        args.draft_model,
                        args.target_model,
                        dataset,
                        gamma=args.gamma,
                        temperature=args.temperature,
                        n_tokens_to_generate=args.n_tokens_to_generate,
                        mode=mode,
                        use_fa=args.flash_attention,
                        use_4bit=args.use_4bitq,
                        use_8bit=args.use_8bitq,
                    )
    elif not args.eda:
        # Default debugging mode (no project mode selected). Previously this
        # also ran after --eda, loading both models unintentionally.
        # turn off autograd
        with torch.no_grad():
            main(
                args.draft_model,
                args.target_model,
                args.prompt,
                args.temperature,
                args.gamma,
                args.benchmark,
                n_tokens_to_generate=args.n_tokens_to_generate,
                use_fa=args.flash_attention,
                use_4bit=args.use_4bitq,
                use_8bit=args.use_8bitq,
            )
