import torch
import os
import argparse
import numpy as np
import random
from transformers import AutoTokenizer,  AutoModelForCausalLM
from torch.profiler import profile, record_function, ProfilerActivity
from typing import Iterable, Union, Any
import re
from pathlib import Path
import openpyxl
import torch.multiprocessing as mp
from arkvale import adapter
import json

# Select the cudaMallocAsync backend for PyTorch's CUDA caching allocator.
# NOTE(review): this only takes effect if no CUDA allocation has happened yet;
# torch (and arkvale) are imported above, so confirm neither initialises CUDA first.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="backend:cudaMallocAsync")
def parse_args(args=None):
    """Build the command-line parser for this benchmark script and parse ``args``.

    Args:
        args: List of argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="Llama-3.1-8B-Instruct")
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--max_length", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=1)

    # ArkVale sparse-attention switches.
    parser.add_argument("--sparse_attn", action="store_true", help="Sparse Attention")
    parser.add_argument("--prefetch", action="store_true", help="prefetch next_layer KVCache from CPU")

    # Benchmark modes. The help strings previously said "Sparse Attention" (copy-paste).
    parser.add_argument("--test_latency", action="store_true", help="measure end-to-end generation latency")
    parser.add_argument("--test_TPOT", action="store_true", help="measure time per output token (TPOT)")

    # ArkVale KV-cache configuration: the token budget is split into pages of
    # --page_size tokens; the byte limits are forwarded to ArkVale unchanged.
    parser.add_argument("--budgets", type=int, default=4096)
    parser.add_argument("--page_size", type=int, default=32)
    parser.add_argument("--n_max_bytes", type=int, default=20 * (1 << 30))
    parser.add_argument("--n_max_cpu_bytes", type=int, default=20 * (1 << 30))

    return parser.parse_args(args)


def seed_everything(seed):
    """Seed every RNG this script touches and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, and torch's CPU and CUDA
    generators (current device and all devices) with the same ``seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Deterministic algorithms, and no autotuning that could pick different kernels per run.
    cudnn.deterministic = True
    cudnn.benchmark = False

def load_model_and_tokenizer(path, model_name, device, args):
    """Load an fp16 causal LM and its tokenizer, optionally enabling ArkVale.

    Args:
        path: Passed to ``from_pretrained`` for both the tokenizer (with
            ``local_files_only=True``) and the model.
        model_name: Not used by this function; kept for call-site compatibility.
        device: Used as ``device_map`` for the model and as ``device`` for ArkVale.
        args: Parsed CLI namespace. Reads ``sparse_attn``, ``budgets``,
            ``page_size``, ``n_max_bytes`` and ``n_max_cpu_bytes``.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.

    Raises:
        ValueError: if ``args.sparse_attn`` is set and ``page_size`` is not positive,
            or ``budgets`` is smaller than one page. Previously this silently passed
            ``page_budgets=0`` / ``page_topks=-1`` to ArkVale, or raised a bare
            ZeroDivisionError.
    """
    if args.sparse_attn:
        # Validate the page geometry before the (slow) model load so a bad CLI
        # value fails immediately.
        page_size = args.page_size
        if page_size <= 0:
            raise ValueError(f"--page_size must be positive, got {page_size}")
        # Whole pages that fit in the token budget; leftover tokens are dropped.
        page_budgets = args.budgets // page_size
        if page_budgets < 1:
            raise ValueError(
                f"--budgets ({args.budgets}) must cover at least one page "
                f"of --page_size ({page_size}) tokens"
            )

    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    if args.sparse_attn:
        adapter.enable_arkvale(
            model,
            dtype=torch.float16,
            device=device,
            page_size=page_size,
            page_budgets=page_budgets,
            # One page below the budget, as in the original configuration.
            page_topks=page_budgets - 1,
            n_max_bytes=args.n_max_bytes,
            n_max_cpu_bytes=args.n_max_cpu_bytes,
            # Sink/window page counts are fixed here, not exposed on the CLI.
            n_sink_pages=2,
            n_win_pages=2,
        )
    model = model.eval()
    return model, tokenizer


def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one parsed JSON value per non-blank line of a JSONL file.

    Args:
        file: Path of the UTF-8 encoded JSONL file.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        ValueError: if a line is not valid JSON. The message names the file and the
            1-based line number; the JSONDecodeError is chained as the cause.
    """
    with open(file, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            # Keep only json.loads inside the try. The old bare ``except:`` also
            # wrapped the ``yield``, so closing the generator early (GeneratorExit)
            # was reported as a load error and turned into exit().
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"Error in loading {file} line {lineno}: {line!r}"
                ) from err
            yield record
                
def load_2024_dataset(data_file: Union[str, Path] = "data/aime/test.jsonl"):
    """Load the AIME JSONL test set as a list of example dicts sorted by ``idx``.

    Args:
        data_file: Path to the JSONL file. Defaults to the original hard-coded
            ``data/aime/test.jsonl`` (relative to the working directory).

    Returns:
        A list of example dicts. If the first record has no ``idx`` key, every
        record is numbered by file order. Only the first record is checked, so the
        file is assumed to be uniform. An empty file yields ``[]``.
    """
    examples = list(load_jsonl(data_file))

    # Guarded so an empty file returns [] instead of raising IndexError.
    if examples and "idx" not in examples[0]:
        examples = [{"idx": i, **example} for i, example in enumerate(examples)]

    # Sort by index. No deduplication is performed.
    return sorted(examples, key=lambda x: x["idx"])






def test_latency(args, problem_index: int = 2, max_new_tokens: int = 16 * 1024):
    """Time one sampled ``model.generate`` call on a single AIME problem.

    Loads the model (with ArkVale when ``args.sparse_attn`` is set) on ``cuda:0``,
    takes problem ``problem_index`` of the AIME test set, and repeats the prompt
    ``args.batch_size`` times. It then generates up to ``max_new_tokens`` tokens with
    sampling and prints the time between two CUDA events. ``Event.elapsed_time``
    reports milliseconds.

    Args:
        args: Parsed CLI namespace (``model_path``, ``model_name``, ``batch_size``
            and the ArkVale options read by ``load_model_and_tokenizer``).
        problem_index: Dataset position of the prompt (default 2, as before).
        max_new_tokens: Generation length cap (default 16K, as before).
    """
    device = torch.device("cuda:0")

    # Read the (cheap) dataset before the (slow) model load so a missing file or
    # an out-of-range index fails fast.
    problem_text = load_2024_dataset()[problem_index]["problem"]
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.model_name, device, args)

    batch_size = args.batch_size
    encoded = tokenizer(problem_text, return_tensors="pt").to(device)
    # Identical, unpadded prompts: an explicit all-ones mask matches what generate()
    # would infer, and avoids its missing-attention-mask warning.
    input_ids = encoded.input_ids.repeat(batch_size, 1)
    attention_mask = encoded.attention_mask.repeat(batch_size, 1)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        use_cache=True,
    )
    end.record()
    # Events are recorded asynchronously; wait for the GPU before reading them.
    torch.cuda.synchronize()
    print("latency time: ", start.elapsed_time(end))
    
    

def record_TPOT(input_len, output_file="TPOT_ours.txt",
                log_file_path="./log/testModule.log", tail_lines=100):
    """Average recent ``TPOT: <x> ms`` log entries and append the mean to a file.

    Scans the last ``tail_lines`` lines of ``log_file_path`` for TPOT entries and
    appends ``input_len<input_len>: <mean> ms`` to ``output_file``. It then
    truncates the log so the next call only sees new entries.

    Args:
        input_len: Label written in front of the mean.
        output_file: File the result line is appended to.
        log_file_path: Log to scan and then truncate (defaults to the original path).
        tail_lines: How many trailing log lines to scan (default 100, as before).

    Behavior on failure:
        - No TPOT entry found: a message is printed and nothing is written or
          truncated. This previously raised ZeroDivisionError.
        - Appending to ``output_file`` fails: the error is printed and the log is
          left untouched, so the data is not lost.
    """
    # Integer values ("TPOT: 30 ms") are accepted as well as decimals.
    tpot_re = re.compile(r"TPOT:\s*(\d+(?:\.\d+)?)\s*ms")

    with open(log_file_path, 'r', encoding='utf-8') as log:
        lines = log.readlines()
    # max() keeps tail_lines == 0 meaning "no lines" (lines[-0:] would be all lines).
    tail = lines[max(len(lines) - tail_lines, 0):]

    tpot_values = [float(m.group(1)) for m in map(tpot_re.search, tail) if m]
    if not tpot_values:
        print(f"No TPOT entries found in the last {tail_lines} lines of {log_file_path}")
        return

    mean_tpot = sum(tpot_values) / len(tpot_values)
    try:
        with open(output_file, 'a', encoding='utf-8') as out:
            out.write(f"input_len{input_len}: {mean_tpot} ms\n")
        print(f"TPOT {mean_tpot} appended to {output_file}")
    except OSError as e:
        print(f"Error writing to file: {e}")
        return

    # Opening in "w" mode truncates the log.
    with open(log_file_path, 'w', encoding='utf-8'):
        pass


if __name__ == "__main__":
    # Parse the command line, then fix every RNG seed before any model work.
    args = parse_args()
    seed_everything(42)

    # Only the latency benchmark is dispatched here.
    # NOTE(review): --test_TPOT is parsed but nothing in this guard acts on it.
    if args.test_latency:
        test_latency(args)