import os, csv, json
import argparse
import time
from tqdm import tqdm
from datasets import load_dataset
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.multiprocessing as mp
from arkvale import adapter

def _read_text(path):
    """Return the full UTF-8 text of ``path``, closing the file afterwards."""
    with open(path, encoding='utf-8') as handle:
        return handle.read()


# Lookup tables keyed by model name: ``model_map`` gives the checkpoint path
# handed to ``from_pretrained``; ``maxlen_map`` gives the prompt length limit
# applied in ``query_llm``.
model_map = json.loads(_read_text('config/model2path.json'))
maxlen_map = json.loads(_read_text('config/model2maxlen.json'))


# Prompt templates. Only ``template_0shot`` is referenced by the code visible
# in this file; the others are loaded so the module-level names stay available.
template_rag = _read_text('prompts/0shot_rag.txt')
template_no_context = _read_text('prompts/0shot_no_context.txt')
template_0shot = _read_text('prompts/0shot.txt')
template_0shot_cot = _read_text('prompts/0shot_cot.txt')
template_0shot_cot_ans = _read_text('prompts/0shot_cot_ans.txt')

def parse_args(args=None):
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to the process command line.

    Returns:
        The populated ``argparse.Namespace``.
    """
    thirty_gib = 30 * (1 << 30)

    # (flags, add_argument keyword arguments), registered in this order.
    option_table = (
        (("--save_dir", "-s"), dict(type=str, default="results")),
        (("--world_size",), dict(type=int, default=1)),
        (("--start_gpu",), dict(type=int, default=0)),
        (("--model",), dict(type=str, default="Llama-3.1-8B-Instruct")),
        (("--sparse_attn",), dict(action="store_true", help="Enable Arkvale")),
        (("--page_size",), dict(type=int, default=32)),
        (("--budgets",), dict(type=int, default=4096)),
        (("--n_max_bytes",), dict(type=int, default=thirty_gib)),
        (("--n_max_cpu_bytes",), dict(type=int, default=thirty_gib)),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args(args)
def query_llm(prompt, model_name, model, tokenizer,  device, temperature=0.5, max_new_tokens=128, ):
    """Run greedy generation on ``prompt`` and return only the generated text.

    Args:
        prompt: Fully rendered prompt string.
        model_name: Key into ``maxlen_map`` giving the prompt length limit.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized inputs are moved to.
        temperature: Forwarded to ``generate``. Decoding is greedy
            (``do_sample=False``), so this is not expected to change the
            result; it is kept for interface compatibility.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation, without the prompt and without special
        tokens. May be an empty string if nothing was generated.
    """
    # NOTE: the limit is applied to *characters*, not tokens. Over-long
    # prompts keep their head and tail and lose the middle.
    max_len = maxlen_map[model_name]
    if len(prompt) > max_len:
        prompt = prompt[:max_len//2] + prompt[-max_len//2:]

    encoded = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
    prompt_token_count = encoded["input_ids"].shape[1]

    sequence = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            num_beams=1,
            do_sample=False,
            temperature=temperature,
        )[0]

    # ``generate`` returns the prompt tokens followed by the continuation for
    # decoder-only models. Drop the prompt so callers (answer extraction, the
    # empty-output check) only ever see what the model actually produced.
    new_tokens = sequence[prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
        
    


def extract_answer(response):
    """Pull the chosen option letter (A-D) out of a model response.

    Asterisks are stripped first, then the parenthesised form
    ``The correct answer is (X)`` is searched for across the whole text;
    only if that is absent is the bare form ``The correct answer is X`` tried.

    Returns:
        The option letter as a string, or ``None`` when neither form occurs.
    """
    cleaned = response.replace('*', '')
    # Ordered by priority: the first pattern that matches anywhere wins.
    answer_patterns = (
        r'The correct answer is \(([A-D])\)',
        r'The correct answer is ([A-D])',
    )
    for pattern in answer_patterns:
        found = re.search(pattern, cleaned)
        if found:
            return found.group(1)
    return None

def load_model_and_tokenizer(path, device, args):
    """Load the tokenizer and fp16 causal LM from ``path``.

    Args:
        path: Checkpoint location passed to both ``from_pretrained`` calls.
        device: Used as ``device_map`` for the model and as the ArkVale device.
        args: Parsed CLI options; reads ``sparse_attn``, ``page_size``,
            ``budgets``, ``n_max_bytes`` and ``n_max_cpu_bytes``.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the result of ``model.eval()``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    if not args.sparse_attn:
        return model.eval(), tokenizer

    # The budget is given on the command line as a plain count and converted
    # to whole pages here; the top-k page count is one less than that.
    pages_per_budget = args.budgets // args.page_size
    arkvale_options = dict(
        dtype=torch.float16,
        device=device,
        page_size=args.page_size,
        page_budgets=pages_per_budget,
        page_topks=pages_per_budget - 1,
        n_max_bytes=args.n_max_bytes,
        n_max_cpu_bytes=args.n_max_cpu_bytes,
        n_sink_pages=2,
        n_win_pages=2,
    )
    adapter.enable_arkvale(model, **arkvale_options)

    return model.eval(), tokenizer

def get_pred(data, args, fout):
    """Run the model over ``data`` and append one JSON line per item to ``fout``.

    Args:
        data: Iterable of item dicts with ``context``, ``question``,
            ``choice_A`` .. ``choice_D`` and ``answer`` keys. Items are
            modified in place.
        args: Parsed CLI options; ``args.model`` selects the checkpoint.
        fout: Writable text handle; flushed after every record.
    """
    device = torch.device('cuda:0')
    model_name = args.model
    model, tokenizer = load_model_and_tokenizer(model_map[model_name], device, args)
    print(len(data))

    for item in tqdm(data):
        # Placeholders are filled strictly in this order, one after another.
        fill_ins = (
            ('$DOC$', item['context']),
            ('$Q$', item['question']),
            ('$C_A$', item['choice_A']),
            ('$C_B$', item['choice_B']),
            ('$C_C$', item['choice_C']),
            ('$C_D$', item['choice_D']),
        )
        prompt = template_0shot
        for placeholder, text in fill_ins:
            prompt = prompt.replace(placeholder, text.strip())

        output = query_llm(prompt, model_name, model, tokenizer, device, temperature=0.1, max_new_tokens=128)
        if output == '':
            continue

        answer_text = output.strip()
        # The raw response and the context are blanked before the record is written.
        item['response'] = ""
        item['context'] = ""
        item['pred'] = extract_answer(answer_text)
        item['judge'] = item['pred'] == item['answer']

        record = json.dumps(item, ensure_ascii=False)
        fout.write(record + '\n')
        fout.flush()
 


if __name__ == "__main__":
    # Script entry point. Each worker process is expected to be launched with a
    # single GPU id in CUDA_VISIBLE_DEVICES; that id minus --start_gpu selects
    # which slice of the dataset this process evaluates.
    args = parse_args()
    mp.set_start_method('spawn', force=True)
    world_size = args.world_size

    # Validate the shard index up front so a misconfigured launch fails with a
    # clear message instead of a bare KeyError/ValueError, and so a negative
    # index cannot silently wrap around to another worker's shard.
    visible_gpu = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_gpu is None:
        raise SystemExit("CUDA_VISIBLE_DEVICES must be set to a single GPU id")
    try:
        shard_index = int(visible_gpu) - args.start_gpu
    except ValueError:
        raise SystemExit(
            f"CUDA_VISIBLE_DEVICES must be a single integer GPU id, got {visible_gpu!r}"
        ) from None
    if not 0 <= shard_index < world_size:
        raise SystemExit(
            f"shard index {shard_index} (CUDA_VISIBLE_DEVICES={visible_gpu}, "
            f"start_gpu={args.start_gpu}) is outside [0, {world_size})"
        )

    os.makedirs(args.save_dir, exist_ok=True)
    model_tag = args.model.split("/")[-1]
    if args.sparse_attn:
        out_file = os.path.join(args.save_dir, model_tag + f"_{args.budgets}.jsonl")
    else:
        out_file = os.path.join(args.save_dir, model_tag + ".jsonl")

    with open('/home/wufang/PRCache/dataset/longbenchv2/data.json', 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    kept_fields = (
        "_id", "domain", "sub_domain", "difficulty", "length", "question",
        "choice_A", "choice_B", "choice_C", "choice_D", "answer", "context",
    )
    data_all = [{key: item[key] for key in kept_fields} for item in dataset]

    # Resume support: ids already present in the output file are skipped.
    # Blank lines are ignored so they do not abort the resume.
    done_ids = set()
    if os.path.exists(out_file):
        with open(out_file, encoding='utf-8') as f:
            done_ids = {json.loads(line)["_id"] for line in f if line.strip()}

    # Shard the full dataset first and filter afterwards. Workers run as
    # independent processes sharing one output file, so sharding the
    # "remaining" list would give overlapping or missing shards whenever
    # workers observe different amounts of finished work at start-up.
    shard = [
        item for item in data_all[shard_index::world_size]
        if item["_id"] not in done_ids
    ]

    with open(out_file, 'a', encoding='utf-8') as fout:
        get_pred(shard, args, fout)