import torch
import os
import argparse
import numpy as np
import random
from transformers import AutoTokenizer,  AutoModelForCausalLM


from arkvale import adapter

def parse_args(args=None):
    """Build this script's command-line parser and parse ``args``.

    Args:
        args: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` holding model-loading options, sparse-attention
        token budgets, the page size and the two byte limits.
    """
    twenty_gib = 20 * (1 << 30)
    cli = argparse.ArgumentParser()

    # Model selection / loading options.
    cli.add_argument('--model_name', type=str, default="Llama-3.1-8B-Instruct")
    cli.add_argument('--model_path', type=str, default=None)
    cli.add_argument('--max_length', type=int, default=None)
    cli.add_argument('--sparse_attn', action="store_true")

    # Token budgets; both share the same default.
    for flag in ("--budgets", "--topks"):
        cli.add_argument(flag, type=int, default=4096)

    cli.add_argument("--page_size", type=int, default=32)

    # Byte caps; both default to 20 GiB.
    for flag in ("--n_max_bytes", "--n_max_cpu_bytes"):
        cli.add_argument(flag, type=int, default=twenty_gib)

    return cli.parse_args(args)


def seed_everything(seed):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, torch's CPU
            generator and the CUDA generators (current and all devices).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_model_and_tokenizer(path, model_name, device, args):
    """Load an fp16 causal LM and its tokenizer, optionally enabling ArkVale.

    Args:
        path: Model directory. The tokenizer is loaded with
            ``local_files_only=True``.
        model_name: Not read here; kept for call-site compatibility.
        device: Passed as ``device_map`` to ``from_pretrained`` and as
            ``device`` to ``adapter.enable_arkvale``.
        args: Parsed CLI namespace. Reads ``sparse_attn``. When that is set,
            it also reads ``page_size``, ``budgets``, ``topks``,
            ``n_max_bytes`` and ``n_max_cpu_bytes``.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    if args.sparse_attn:
        page_size = args.page_size
        # CLI budgets are given in tokens; the adapter takes whole pages.
        page_budgets = args.budgets // page_size
        # Derive top-k from --topks (not --budgets) so the flag takes effect.
        page_topks = args.topks // page_size
        adapter.enable_arkvale(
            model,
            dtype=torch.float16,
            device=device,
            page_size=page_size,
            page_budgets=page_budgets,
            # NOTE(review): one page fewer than requested is passed; the reason
            # is not visible here — confirm against the arkvale adapter.
            page_topks=page_topks - 1,
            n_max_bytes=args.n_max_bytes,
            n_max_cpu_bytes=args.n_max_cpu_bytes,
            compare_recall_rate=True,
        )

    model = model.eval()
    return model, tokenizer




def read_context(file_path="./Llama-3_1-8B-Instruct(4096)_len_128000_depth_2200_context.txt"):
    """Return the full text of the long-context document used as the prompt body.

    Args:
        file_path: Path of the UTF-8 text file to read. It defaults to the
            original hard-coded file, resolved relative to the current working
            directory, so existing zero-argument calls behave as before.

    Returns:
        The file's entire contents as a ``str``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()


def generate_prompt(context, args, enc):
    """Wrap ``context`` in the model-specific chat prompt for the retrieval question.

    Args:
        context: Long document text, inserted between ``<book>`` tags.
        args: Parsed CLI namespace; only ``args.model_name`` is read.
        enc: Tokenizer. Its ``apply_chat_template`` is used for Llama-3.1
            models; it is not used for Qwen2.5.

    Returns:
        The untokenized prompt string.

    Raises:
        ValueError: If ``args.model_name`` contains neither "Llama-3.1" nor
            "Qwen2.5". Previously this surfaced as an ``UnboundLocalError``.
    """
    model_name = args.model_name
    if "Llama-3.1" in model_name:
        messages = [
            {"role": "system", "content": "You are a helpful assistant, you willed be gived a long story and retrieval the answer"},
            {"role": "user", "content": f"This is a very long story book: <book> {context} </book>.\n\nQuestion: Based on the content of the book, what is the best thing to do in San Francisco.? \n\nAnswer: The best thing to do in San Francisco is"},
        ]
        return enc.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if "Qwen2.5" in model_name:
        # Hand-written ChatML prompt, kept verbatim for comparability across runs.
        # NOTE(review): the stock Qwen2.5 generation prompt appears to end with
        # "<|im_start|>assistant\n" (trailing newline) — confirm whether omitting it is intended.
        return f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n This is a very long story book: <book> {context} </book>.\n\nQuestion: Based on the content of the book, what is the best thing to do in San Francisco.? give a complete sequence.\n\nAnswer: The best thing to do in San Francisco is<|im_end|>\n<|im_start|>assistant"
    raise ValueError(
        f"Unsupported model_name {model_name!r}: expected a 'Llama-3.1' or 'Qwen2.5' model"
    )
    


if __name__ == "__main__":
    # Entry point: load the model (with ArkVale sparse attention), build the
    # long-context retrieval prompt, sample a short continuation and print it.
    seed_everything(42)
    args = parse_args()
    if args.model_path is None:
        # from_pretrained(None) would fail later with a less helpful error.
        raise SystemExit("--model_path is required (path to the model checkpoint directory)")
    generate_tokens = 50
    device = torch.device("cuda:0")

    # This demo always runs with sparse attention, regardless of --sparse_attn.
    args.sparse_attn = True
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.model_name, device, args)
    prompt = generate_prompt(read_context(), args, tokenizer)

    # Keep the attention mask so generate() does not have to infer it.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    output = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=generate_tokens,
        use_cache=True,
        pad_token_id=pad_token_id,
    )[0]
    # Decode only the newly generated tokens, not the echoed prompt.
    print(tokenizer.decode(output[prompt_length:], skip_special_tokens=True))

    # For a dense baseline, reload with args.sparse_attn = False and generate
    # again. NOTE(review): arkvale.utils.com_attn_output_cos_similarity was used
    # previously for a sparse-vs-dense comparison — confirm its exact semantics there.
        