import sys
import time

import torch
# from modeling_llama_chunck import *
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from patch import patch_hf, minference_patch

# Hub id of the checkpoint under test; the tokenizer comes from the same id.
model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Earlier plain load (no custom config), kept for reference:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype="auto",
#     device_map="cuda",
#     # is_loaded_in_8bit=True,
# )

config = AutoConfig.from_pretrained(model_name)

# Extra attributes attached to the config before the model is built.
# NOTE(review): apart from `_attn_implementation`, these are presumably read by
# the code in `patch`; their exact semantics are not visible in this file.
#
# Other pattern files that have been tried for `topk_dims_file_path`:
#   config/Llama_3_8B_Instruct_262k_kv_out_v32_best_pattern.json
#   Llama_3_8B_Instruct_262k_kv_out_v32_fit_half3_best_pattern.json
#   Llama_3_8B_Instruct_262k_kv_out_v32_fit_half3_v2_best_pattern.json
_CONFIG_OVERRIDES = {
    "_attn_implementation": "eager",
    "topk": 1,
    "topk_from_layer": 0,
    "topk_dims_file_path": "config/Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json",
    "n_init": 32,
    "n_local": 480,
    "topk_ratio": 0.1,
    "block_size": 32,
}
for _attr, _value in _CONFIG_OVERRIDES.items():
    setattr(config, _attr, _value)

# Build the model from the customised config, then apply the patch.
model = minference_patch(
    AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype="auto",
        device_map="auto",
    )
)


def run_target_length(m: int) -> None:
    """Time a forward pass of the module-level ``model`` on a prompt of ``m`` tokens.

    The text in ``./prompt_hardest.txt`` is tokenized, its token ids are
    repeated and truncated to exactly ``m`` ids, decoded back to text and
    tokenized again to produce the tensors passed to the model. The average
    wall-clock time per forward pass is printed as ``m <seconds>``.

    Args:
        m: Requested prompt length, in tokens.

    Raises:
        FileNotFoundError: If ``./prompt_hardest.txt`` does not exist in the
            current working directory.
    """
    # wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
    # Use a context manager so the file handle is closed deterministically.
    with open("./prompt_hardest.txt", encoding="utf-8") as prompt_file:
        prompt_complex = prompt_file.read()

    input_ids = tokenizer(prompt_complex)["input_ids"]
    n = len(input_ids)
    # Number of repetitions needed so the tiled id list covers at least m ids.
    b = m // n + 1

    new_input_ids = (input_ids * b)[:m]
    # NOTE(review): decoding and re-tokenizing presumably does not guarantee
    # exactly m tokens; the printed value is the requested m - confirm.
    prompt = tokenizer.decode(new_input_ids)
    data = tokenizer(prompt, return_tensors="pt")
    input_ids = data["input_ids"].cuda()
    attention_mask = data["attention_mask"].cuda()

    s = 0
    T = 1  # number of timed forward passes to average over
    for _ in range(T):
        # Synchronize before and after so queued GPU work is included in the
        # measured interval rather than only the kernel launch time.
        torch.cuda.synchronize()
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump with system clock changes.
        start = time.perf_counter()
        with torch.no_grad():
            model(input_ids, attention_mask, use_cache=False)
        torch.cuda.synchronize()
        s += time.perf_counter() - start
    print(m, s / T)


if __name__ == "__main__":
    # Echo the raw command line before running.
    print(sys.argv)
    # Fail with a usage message instead of a bare IndexError when the target
    # length argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <target_length>")
    run_target_length(int(sys.argv[1]))
