import torch
import os
import argparse
import numpy as np
import random
from transformers import AutoTokenizer,  AutoModelForCausalLM


from arkvale import adapter

def parse_args(args=None):
    """Parse the command-line options of this script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    gib = 1 << 30
    cli = argparse.ArgumentParser()

    # Model selection and prompt-length options.
    cli.add_argument("--model_name", type=str, default="Llama-3.1-8B-Instruct")
    cli.add_argument("--model_path", type=str, default=None)
    cli.add_argument("--max_length", type=int, default=None)
    cli.add_argument("--sparse_attn", action="store_true")

    # Integer-valued options, registered in the order they appear in --help.
    int_options = (
        ("--budgets", 4096),
        ("--topks", 4096),
        ("--page_size", 32),
        ("--n_max_bytes", 20 * gib),
        ("--n_max_cpu_bytes", 20 * gib),
    )
    for flag, default in int_options:
        cli.add_argument(flag, type=int, default=default)

    return cli.parse_args(args)


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also switches cuDNN to its deterministic mode and turns off the cuDNN
    benchmark autotuner.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    # Same seeding calls, in the same order, as a single pass.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    torch.cuda.manual_seed_all(seed)

def load_model_and_tokenizer(path, model_name, device, args):
    """Load the tokenizer and fp16 causal LM from ``path``; optionally enable ArkVale.

    Args:
        path: Model location passed to both ``from_pretrained`` calls. The
            tokenizer is loaded with ``local_files_only=True``.
        model_name: Not used by this function; kept so existing callers
            keep working.
        device: Passed as ``device_map`` when loading the model and as
            ``device`` to ``adapter.enable_arkvale``.
        args: Namespace providing ``sparse_attn``, ``page_size``,
            ``budgets``, ``n_max_bytes``, ``n_max_cpu_bytes`` and,
            optionally, ``topks``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    if args.sparse_attn:
        # The page arithmetic only matters for ArkVale, so it lives inside
        # this branch and a bad page_size cannot break a dense run.
        page_size = args.page_size
        page_budgets = args.budgets // page_size

        # Honour --topks, which was previously parsed but ignored because the
        # top-k page count was derived from --budgets. A namespace without
        # ``topks`` falls back to the budget, matching the old behaviour.
        topk_tokens = getattr(args, "topks", None)
        if topk_tokens is None:
            topk_tokens = args.budgets
        # Cap at the budget so the default (topks == budgets) yields exactly
        # the value used before.
        page_topks = min(topk_tokens // page_size, page_budgets)

        adapter.enable_arkvale(
            model,
            dtype=torch.float16,
            device=device,
            page_size=page_size,
            page_budgets=page_budgets,
            # The original call passed one page fewer than computed; kept.
            # NOTE(review): presumably ArkVale needs topks < budgets — confirm.
            page_topks=page_topks - 1,
            n_max_bytes=args.n_max_bytes,
            n_max_cpu_bytes=args.n_max_cpu_bytes,
            compare_recall_rate=True,
        )

    model = model.eval()
    return model, tokenizer




def read_context(file_path="../../dataset/longbenchv1/gov_report.jsonl"):
    """Build one summarisation prompt from every ``context`` in a JSONL file.

    Each non-blank line of the file is parsed as a JSON object. The non-empty
    ``"context"`` values are concatenated in file order, without a separator,
    and inserted into the report-summary prompt template.

    Lines that are not valid JSON, or that are valid JSON but not an object,
    are reported on stdout and skipped.

    Args:
        file_path: Path to the JSONL file. Defaults to the path that was
            previously hard-coded in this function.

    Returns:
        The formatted prompt string.

    Raises:
        OSError: If ``file_path`` cannot be opened.
    """
    import json

    prompt_template = "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:"

    # Collect the pieces and join once instead of growing a string with +=.
    contexts = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines.
            if not line.strip():
                continue
            try:
                # Each line is an independent JSON object.
                item = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON line: {e}")
                continue
            if not isinstance(item, dict):
                # Valid JSON that is not an object (list, number, ...) has no
                # "context" key; skip it rather than crash on ``.get``.
                print(f"Skipping non-object JSON line of type {type(item).__name__}")
                continue
            context = item.get("context", "")
            if context:  # Only keep non-empty contexts.
                contexts.append(context)

    return prompt_template.format(context="".join(contexts))




if __name__ == "__main__":
    # Demo entry point: load the model, build one long prompt from the
    # dataset file, and print a sampled continuation.
    seed_everything(42)
    args = parse_args()
    generate_tokens = 130
    # Prompt-length cap used when --max_length is not given.
    default_max_prompt_tokens = 65536

    # Fail early with a clear message instead of an error from deep inside
    # ``from_pretrained(None)``.
    if args.model_path is None:
        raise SystemExit("error: --model_path is required")

    device = torch.device("cuda:0")

    # NOTE(review): this forces the ArkVale path on regardless of the
    # --sparse_attn flag. Kept to preserve existing behaviour — confirm
    # whether the CLI flag should decide instead.
    args.sparse_attn = True
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.model_name, device, args)
    texts = read_context()

    # Honour --max_length, which was previously parsed but ignored.
    if args.max_length is not None:
        max_prompt_tokens = args.max_length
    else:
        max_prompt_tokens = default_max_prompt_tokens

    # Truncate before the device transfer so only the kept tokens are copied.
    input_ids = tokenizer(texts, return_tensors="pt").input_ids[:, :max_prompt_tokens].to(device)

    prompt_length = input_ids.shape[-1]
    output = model.generate(
        input_ids,
        do_sample=True,
        max_new_tokens=generate_tokens,
        use_cache=True,
    )[0]
    # Decode only the generated tokens that follow the prompt.
    print(tokenizer.decode(output[prompt_length:], skip_special_tokens=True))
    

        