import torch
import os
import argparse
import numpy as np
import random
from transformers import AutoTokenizer,  AutoModelForCausalLM
from torch.profiler import profile, record_function, ProfilerActivity
from testModule import  LlamaModel_forward, Qwen2_forward
import re
from pathlib import Path
import openpyxl
import torch.multiprocessing as mp
from arkvale import adapter
# Select PyTorch's cudaMallocAsync CUDA allocator backend for this process.
# NOTE(review): the allocator reads this variable when CUDA is first initialized,
# so it must be set before any CUDA allocation; setting it after `import torch`
# presumably works because nothing above touches CUDA at import time — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:cudaMallocAsync'
def parse_args(args=None):
    """Parse the benchmark's command-line options.

    Args:
        args: optional list of argument strings; ``None`` parses ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` holding the model, workload, benchmark-mode and
        ArkVale cache options defined below.
    """
    parser = argparse.ArgumentParser()
    # Model / workload.
    parser.add_argument('--model_name', type=str, default="Llama-3.1-8B-Instruct",
                        help="model identifier label")
    parser.add_argument('--model_path', type=str, default=None,
                        help="local path of the Hugging Face checkpoint to load")
    parser.add_argument('--max_length', type=int, default=None,
                        help="truncate the tokenized prompt to this many tokens "
                             "(default: no truncation)")
    parser.add_argument('--batch_size', type=int, default=1,
                        help="number of copies of the prompt in each batch")

    # Attention / cache behaviour.
    parser.add_argument("--sparse_attn", action="store_true",
                        help="enable ArkVale sparse attention")
    parser.add_argument("--prefetch", action="store_true",
                        help="prefetch next_layer KVCache from CPU")

    # Benchmark modes; the entry point checks --test_TPOT first.
    parser.add_argument("--test_latency", action="store_true",
                        help="measure end-to-end generate() latency")
    parser.add_argument("--test_TPOT", action="store_true",
                        help="measure time per output token (TPOT)")

    # ArkVale cache sizing.
    parser.add_argument("--budgets", type=int, default=4096,
                        help="KV-cache token budget, converted to pages of --page_size tokens")
    parser.add_argument("--page_size", type=int, default=32,
                        help="tokens per ArkVale KV-cache page")
    parser.add_argument("--n_max_bytes", type=int, default=20 * (1 << 30),
                        help="byte limit forwarded to ArkVale as n_max_bytes")
    parser.add_argument("--n_max_cpu_bytes", type=int, default=20 * (1 << 30),
                        help="byte limit forwarded to ArkVale as n_max_cpu_bytes")

    return parser.parse_args(args)


def seed_everything(seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs with *seed*.

    Also disables cuDNN autotuning and requests deterministic cuDNN kernels so
    repeated runs pick the same algorithms.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def load_model_and_tokenizer(path, model_name, device, args):
    """Load an fp16 causal LM plus tokenizer from *path*, optionally enabling ArkVale.

    Args:
        path: checkpoint directory; the tokenizer is loaded with local_files_only.
        model_name: accepted for call-site compatibility but not used here.
        device: placement handed to ``device_map`` and to ArkVale.
        args: parsed CLI namespace; reads sparse_attn, test_TPOT, page_size,
            budgets, n_max_bytes, n_max_cpu_bytes and prefetch.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    if args.sparse_attn:
        if args.test_TPOT:
            # Swap in the testModule forwards first; presumably they emit the
            # "TPOT: ... ms" log lines that record_TPOT parses — confirm.
            add_patch(model)
        # The token budget is expressed to ArkVale as a number of whole pages.
        n_budget_pages = args.budgets // args.page_size
        arkvale_options = dict(
            dtype=torch.float16,
            device=device,
            page_size=args.page_size,
            page_budgets=n_budget_pages,
            page_topks=n_budget_pages - 1,
            n_max_bytes=args.n_max_bytes,
            n_max_cpu_bytes=args.n_max_cpu_bytes,
            n_sink_pages=2,
            n_win_pages=2,
            n_prefetch_layers=1 if args.prefetch else None,
        )
        adapter.enable_arkvale(model, **arkvale_options)

    return model.eval(), tokenizer



def add_patch(model):
    """Rebind ``forward`` on LlamaModel / Qwen2Model submodules to the testModule versions.

    Every module yielded by ``model.modules()`` (including *model* itself) whose
    class repr contains "LlamaModel" or "Qwen2Model" gets an instance-level
    ``forward`` that calls the replacement with the module as first argument.
    Both markers are checked independently, in that order.
    """
    # Each factory takes the module as an argument, so every closure binds its
    # own module (no late-binding bug). The replacement function is looked up
    # only when the patched forward is called.
    forward_factories = (
        ("LlamaModel",
         lambda target: lambda *args, **kwargs: LlamaModel_forward(target, *args, **kwargs)),
        ("Qwen2Model",
         lambda target: lambda *args, **kwargs: Qwen2_forward(target, *args, **kwargs)),
    )
    for submodule in model.modules():
        class_repr = str(submodule.__class__)
        for marker, make_forward in forward_factories:
            if marker in class_repr:
                submodule.forward = make_forward(submodule)
    
def read_context(file_path="../../dataset/longbenchv1/gov_report.jsonl"):
    """Build one long summarization prompt from every record of a JSONL file.

    Each non-blank line of *file_path* is parsed as a standalone JSON object.
    The non-empty ``"context"`` fields are concatenated in file order with no
    separator and inserted into the gov_report summarization template. Lines
    that are not valid JSON are reported and skipped.

    Args:
        file_path: path to the JSONL dataset; defaults to the script's
            original relative path.

    Returns:
        The formatted prompt string.
    """
    import json

    prompt_template = "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:"
    # Collect pieces in a list and join once instead of repeated `+=`.
    contexts = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            try:
                item = json.loads(line)  # each line is an independent JSON object
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON line: {e}")
                continue
            context = item.get("context", "")
            if context:  # only keep non-empty contexts
                contexts.append(context)
    return prompt_template.format(context="".join(contexts))


def chuncked_prefill(model, input_ids, chunk_size=4000):
    """Prefill the KV cache by feeding *input_ids* through *model* chunk by chunk.

    Args:
        model: callable accepting ``input_ids``, ``past_key_values`` and
            ``use_cache`` keywords and returning an object with ``logits`` and
            ``past_key_values``.
        input_ids: token tensor indexed as ``[:, start:stop]`` along dim 1.
        chunk_size: number of tokens fed per forward call (default 4000).

    Returns:
        ``(past_key_values, pred_token_idx)``. The first element is the cache
        returned by the last call. The second element is the argmax over the
        last position's logits, with an extra trailing dim (shape ``(batch, 1)``).

    Raises:
        ValueError: if *chunk_size* is not positive or *input_ids* has no
            tokens along dim 1. Without this check ``outputs`` would be unbound.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    seq_len = input_ids.size(1)
    if seq_len == 0:
        raise ValueError("input_ids must contain at least one token to prefill")

    past_key_values = None
    with torch.no_grad():
        for start in range(0, seq_len, chunk_size):
            outputs = model(
                input_ids=input_ids[:, start:start + chunk_size],
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Carry the returned cache into the next chunk's call.
            past_key_values = outputs.past_key_values
        # Greedy next token from the final position of the last chunk.
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return past_key_values, pred_token_idx


def _run_decode_steps(model, token_ids, past_key_values, n_steps):
    """Run *n_steps* single-token forward passes reusing *past_key_values*; outputs are discarded."""
    # One no_grad context for the whole loop keeps per-step Python overhead
    # out of the timed region.
    with torch.no_grad():
        for _ in range(n_steps):
            model(
                input_ids=token_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )


def test_TPOT(args, n_steps=100):
    """Measure time per output token (TPOT) on cuda:0 and append it to a results file.

    Dense path (no --sparse_attn): chunked prefill, *n_steps* warm-up decode
    steps, then *n_steps* steps timed with CUDA events. The mean ms/step is
    appended to ``TPOT_base.txt``.

    Sparse path: generates 512 sampled tokens, then ``record_TPOT`` averages the
    TPOT values from the log into ``TPOT_ours.txt``. *n_steps* is unused here.

    Args:
        args: parsed CLI namespace (model_path, model_name, max_length,
            batch_size, sparse_attn, ...).
        n_steps: decode steps for warm-up and for timing in the dense path.
    """
    device = torch.device("cuda:0")
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.model_name, device, args)
    texts = read_context()

    input_ids = tokenizer(texts, return_tensors="pt").input_ids.to(device)[:, : args.max_length]
    print(input_ids.size())
    input_ids = input_ids.repeat(args.batch_size, 1)

    if args.sparse_attn:
        model.generate(input_ids, do_sample=True, max_new_tokens=512, use_cache=True, )
        record_TPOT(args.max_length, "TPOT_ours.txt")
        return

    output_file = "TPOT_base.txt"
    past_key_values, pred_token_idx = chuncked_prefill(model, input_ids)
    # NOTE(review): the same token and cache are reused every step. If the cache
    # grows on each call, later steps attend to a slightly longer context.
    _run_decode_steps(model, pred_token_idx, past_key_values, n_steps)  # warm-up
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    _run_decode_steps(model, pred_token_idx, past_key_values, n_steps)
    end.record()
    torch.cuda.synchronize()
    avg_time = start.elapsed_time(end) / n_steps  # elapsed_time is in ms
    print(avg_time)
    with open(output_file, 'a', encoding='utf-8') as file:
        file.write(f"input_len{args.max_length}: {avg_time} ms\n")
        
def test_latency(args):
    """Measure end-to-end ``generate`` latency (ms) for 512 sampled new tokens on cuda:0.

    Prints each prompt's elapsed time and then the mean over all prompts.

    Args:
        args: parsed CLI namespace (model_path, model_name, max_length,
            batch_size, sparse_attn, ...).
    """
    device = torch.device("cuda:0")
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.model_name, device, args)
    # read_context() takes no length argument and returns one prompt string.
    # Wrap it in a list so the loop runs over prompts rather than characters.
    # Truncation to max_length happens after tokenization below.
    texts = [read_context()]
    total_ms = 0
    batch_size = args.batch_size
    for prompt in texts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)[:, : args.max_length]
        batch = input_ids.repeat(batch_size, 1)
        torch.cuda.synchronize()  # keep earlier queued GPU work out of the measurement
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        model.generate(batch, do_sample=True, max_new_tokens=512, use_cache=True, )
        end.record()
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)
        print(elapsed_ms)
        total_ms += elapsed_ms
    print(total_ms / len(texts))

def record_TPOT(input_len, output_file="TPOT_ours.txt", log_file_path="./log/testModule.log"):
    """Average the TPOT entries at the end of a log and append the mean to *output_file*.

    Scans the last 100 lines of *log_file_path* for ``TPOT: <number> ms``.
    Integer and decimal numbers are both accepted. The line
    ``input_len<input_len>: <mean> ms`` is appended to *output_file*, and the
    log is then truncated so the next run starts clean.

    If no entries are found, a message is printed and neither file is
    modified. If appending to *output_file* fails with an OSError, the error is
    printed and the log is left intact.

    Args:
        input_len: prompt-length label written alongside the mean.
        output_file: results file, opened in append mode.
        log_file_path: log to scan and then clear.

    Raises:
        OSError: if *log_file_path* cannot be opened for reading.
    """
    tbt_pattern = re.compile(r"TPOT:\s*(\d+(?:\.\d+)?)\s*ms")

    with open(log_file_path, 'r', encoding='utf-8') as file:
        # Slicing already copes with logs shorter than 100 lines.
        last_lines = file.readlines()[-100:]

    tbt_values = [float(m.group(1)) for m in map(tbt_pattern.search, last_lines) if m]
    if not tbt_values:
        print(f"No TPOT entries found in {log_file_path}; nothing recorded")
        return

    mean_tbt_value = sum(tbt_values) / len(tbt_values)
    try:
        with open(output_file, 'a', encoding='utf-8') as file:
            file.write(f"input_len{input_len}: {mean_tbt_value} ms\n")
    except OSError as e:
        print(f"Error writing to file: {e}")
        return
    print(f"TPOT {mean_tbt_value} appended to {output_file}")

    # Truncate the log so these entries are not re-read by the next measurement.
    with open(log_file_path, 'w', encoding='utf-8') as file:
        file.write("")


if __name__ == "__main__":
    # Kept disabled; re-enable if work is moved into torch.multiprocessing workers.
    # mp.set_start_method("spawn", force=True)
    seed_everything(42)
    args = parse_args()
    # --test_TPOT takes precedence; with neither flag set nothing runs.
    if args.test_TPOT:
        benchmark = test_TPOT
    elif args.test_latency:
        benchmark = test_latency
    else:
        benchmark = None
    if benchmark is not None:
        benchmark(args)