# import openai
import time
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # del
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,5"
import json
from tqdm import tqdm
import argparse
import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    AutoModelForCausalLM,
    set_seed,
)
from vllm import LLM, SamplingParams, AsyncLLMEngine
from vllm.lora.request import LoRARequest
import argparse
import random
random.seed(48)
def chating(input_lists, data, num_examples, model, tokenizer, config):
    """Build one chat prompt per MCQ question and generate answers in one batch.

    This function reads the module-level ``args`` namespace (created in the
    ``__main__`` guard): ``args.model_name`` selects the prompt layout,
    ``args.model`` is used as the LoRA adapter path and ``args.base_model``
    (if present) is compared against it.

    Args:
        input_lists: Iterable of ``[index, questions]`` entries. Only the
            second element is used; every question in it is expected to be
            a string because the answer-format hint is concatenated to it.
        data: Unused; kept so existing callers keep working.
        num_examples: Unused; kept so existing callers keep working.
        model: Engine exposing
            ``generate(prompts, sampling_params, lora_request=...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        config: Unused; kept so existing callers keep working.

    Returns:
        A flat list holding the text of the first completion for every
        prompt, in the order the questions were encountered.
    """
    model_name = args.model_name

    # Everything below is loop-invariant, so compute it once.
    format_prompt = 'Your answer format should be like "Answer: [A-D]".'
    # The system message is omitted for model names containing 'gemma'
    # (presumably because those chat templates reject a system role).
    use_system_prompt = 'gemma' not in model_name
    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    if 'Qwen3' in model_name:
        # Only Qwen3 prompts are rendered with the extra thinking switch.
        template_kwargs["enable_thinking"] = False

    input_text_final = []
    for item in tqdm(input_lists):
        for text in item[1]:
            input_text = text + '\n' + format_prompt + '\n'
            messages = []
            if use_system_prompt:
                messages.append(
                    {"role": "system", "content": "You are a helpful assistant."}
                )
            messages.append({"role": "user", "content": input_text})
            input_text_final.append(
                tokenizer.apply_chat_template(messages, **template_kwargs)
            )

    print('encode')

    sampling_params = SamplingParams(temperature=0, max_tokens=512)

    # Attach the LoRA adapter only when `args.model` really is something other
    # than the base model. With the script defaults both paths are identical,
    # and pointing a LoRA request at a plain base-model directory makes the
    # adapter load fail.
    adapter_path = args.model
    base_path = getattr(args, "base_model", None)
    lora_request = None
    if base_path is None or (
        os.path.normpath(adapter_path) != os.path.normpath(base_path)
    ):
        lora_request = LoRARequest("sql_adapter", 1, adapter_path)

    outputs = model.generate(
        input_text_final,
        sampling_params,
        lora_request=lora_request,
    )
    return [output.outputs[0].text for output in outputs]

def load_data(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at ``path``.

    The file is opened in a ``with`` block so the handle is closed even when
    parsing raises.
    """
    with open(path, 'r', encoding='utf8') as f:
        return json.load(f)




def main(args):
    """Run the probe: load the model and data, generate answers, save results.

    Reads ``primekg/primekg_multifaceteval_probe.jsonl`` (one JSON array per
    line), passes the fourth field of every record (index 3) to ``chating``
    as that record's questions, appends the generated answers to the record
    and writes all records as JSON lines to
    ``results/<dataset>/<model_name>_results.json``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model``, ``base_model``, ``num_cuda``, ``util``, ``swap``,
            ``dataset``, ``model_name`` and ``ntrain``.

    Raises:
        ValueError: If the number of generated answers does not match the
            total number of questions, since answers could then not be
            attributed to the right record.
    """
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    model = LLM(
        model=args.base_model,
        tensor_parallel_size=args.num_cuda,
        gpu_memory_utilization=args.util,
        swap_space=args.swap,
        max_model_len=2048,
        trust_remote_code=True,
        enable_lora=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    data = []
    with open('primekg/primekg_multifaceteval_probe.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank / trailing empty lines
                data.append(json.loads(line))

    # Create the output directory up front so a path problem surfaces before
    # the expensive generation step.
    os.makedirs('results/{}'.format(args.dataset), exist_ok=True)
    out_path = 'results/{0}/{1}_results.json'.format(args.dataset, args.model_name)

    input_lists = [[i, item[3]] for i, item in enumerate(tqdm(data))]

    results = chating(input_lists, data, args.ntrain, model, tokenizer, config)

    # One answer is produced per question, so slice the flat result list by
    # each record's real question count instead of assuming a fixed 20.
    question_counts = [len(item[3]) for item in data]
    if sum(question_counts) != len(results):
        raise ValueError(
            'expected {} answers but got {}'.format(sum(question_counts), len(results))
        )

    # Open the output only once results exist: a failed run no longer
    # truncates a previous results file, and the handle is always closed.
    with open(out_path, 'w', encoding='utf-8') as outf:
        offset = 0
        for item, n_questions in zip(data, question_counts):
            item.append(results[offset:offset + n_questions])
            offset += n_questions
            outf.write(json.dumps(item) + '\n')
    

if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the namespace to main(). `args` is deliberately left as a
    # module-level name because chating() reads it directly.
    parser = argparse.ArgumentParser()

    # Integer knobs with short aliases.
    parser.add_argument("--ntrain", "-k", default=5, type=int)
    parser.add_argument("--nchain", "-c", default=5, type=int)
    parser.add_argument("--nbatch", "-b", default=-1, type=int)

    # Dataset / model selection.
    parser.add_argument("--typs", type=str, nargs="+", action="extend")
    parser.add_argument("--dataset", default='primekg_probe', type=str)
    parser.add_argument(
        "--base_model",
        default='/ssd/common/LLMs/Meta-Llama-3-8B-Instruct',
        type=str,
    )
    parser.add_argument(
        "--model",
        default='/ssd/common/LLMs/Meta-Llama-3-8B-Instruct',
        type=str,
    )
    parser.add_argument("--model_name", default='llama3-8B', type=str)

    # Run control.
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--start", default=0, type=int)

    # Engine resources.
    parser.add_argument("--num_cuda", default=1, type=int)
    parser.add_argument("--util", default=0.98, type=float)
    parser.add_argument("--swap", default=4, type=int)

    args = parser.parse_args()
    main(args)

