# import openai
import time
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # del
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json
from tqdm import tqdm
import argparse
import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    AutoModelForCausalLM,
    set_seed,
)
from vllm import LLM, SamplingParams, AsyncLLMEngine
import argparse
import random
random.seed(48)

def chating(data, model, task, tokenizer, config, model_name=None):
    """Build chat prompts for one evaluation task and generate greedy replies with vLLM.

    Args:
        data: list of question records. Each record needs a 'question' key; for
            every task except 'pubmedqa_test' it also needs an 'options' mapping
            whose keys are the option labels.
        model: object whose ``generate(prompts, sampling_params)`` is called once
            with every prompt (a vLLM ``LLM`` in this script).
        task: task identifier; selects the answer-format instruction appended to
            each question. Must be one of the keys of ``format_prompts`` below.
        tokenizer: object providing ``apply_chat_template``.
        config: model config; accepted for interface compatibility, not used here.
        model_name: selects the chat-template variant ('gemma' -> no system turn,
            'Qwen3' -> thinking disabled). Defaults to the module-level
            ``args.model_name`` set by the command-line entry point.

    Returns:
        List with the text of the first completion for each prompt, in the order
        returned by ``model.generate``.

    Raises:
        ValueError: if ``task`` has no known answer-format instruction.
    """
    format_prompts = {
        'medqa_test': 'Your answer format should be like "Answer: [A-E]".',
        'commonsenseqa_test': 'Your answer format should be like "Answer: [A-E]".',
        'mmlu_samp_test': 'Your answer format should be like "Answer: [A-D]".',
        'mmlu_test': 'Your answer format should be like "Answer: [A-D]".',
        'openbookqa_test': 'Your answer format should be like "Answer: [A-D]".',
        'arc_challenge_test': 'Your answer format should be like "Answer: [A-D]".',
        'pubmedqa_test': 'Your answer format should be like "Answer: [yes/no/maybe]".',
    }
    # Fail fast on an unknown task instead of crashing on an unbound
    # format prompt in the middle of prompt construction.
    if task not in format_prompts:
        raise ValueError('Unsupported task: {}'.format(task))
    format_prompt = format_prompts[task]

    if model_name is None:
        # Backward-compatible fallback: the CLI entry point defines ``args`` globally.
        model_name = args.model_name

    input_text_final = []
    for item in tqdm(data):
        ques = item['question']
        if task != 'pubmedqa_test':
            # Only multiple-choice tasks render an options line; pubmedqa
            # questions are sent as-is.
            options = item['options']
            option_text = ['{}: {}'.format(key, options[key]) for key in options]
            ques = 'Question: ' + ques + '\nOptions: ' + '\t'.join(option_text)
        input_text = ques + '\n' + format_prompt + '\n'

        user_turn = {"role": "user", "content": input_text}
        if 'gemma' in model_name:
            # No system turn for gemma models (presumably because their chat
            # template does not accept a system role).
            messages = [user_turn]
        else:
            messages = [{"role": "system", "content": "You are a helpful assistant."}, user_turn]

        if 'Qwen3' in model_name:
            # Extra template kwarg so Qwen3 renders its non-thinking prompt.
            chat_text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        else:
            chat_text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        input_text_final.append(chat_text)

    print('encode')

    # Greedy decoding (temperature 0), replies capped at 50 tokens.
    sampling_params = SamplingParams(temperature=0, max_tokens=50)
    outputs = model.generate(input_text_final, sampling_params)
    return [output.outputs[0].text for output in outputs]

# NOTE(review): not referenced by any code visible in this file; looks like a
# leftover from an earlier script — confirm before relying on or removing it.
input_file_name = 'disease_description'

# NOTE(review): the commented-out lines below are dead code kept from an
# earlier version and could be removed.
# for i, line in enumerate(fin):
    
#     input_text = prompt + line.strip()
# model = 'gpt-3.5-turbo'

def load_data(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 and is closed before the function returns
    (the previous version left the file handle open).
    """
    with open(path, 'r', encoding='utf8') as f:
        return json.load(f)


def main(args):
    """Run every evaluation task with the model described by ``args`` and save replies.

    For each task in ``typs`` this reads ``data/<task>.json`` (one JSON object
    per line) and generates a reply per record via ``chating``. The records,
    each extended with a ``'response'`` field, are written as JSON lines to
    ``results/<task>/<model_name>_results.json``. A task whose results file
    already exists is skipped unless ``args.rewrite`` is set.

    Args:
        args: parsed command-line namespace; reads ``model``, ``model_name``,
            ``num_cuda``, ``util``, ``swap`` and ``rewrite``.
    """
    typs = ['medqa_test', 'mmlu_test', 'arc_challenge_test', 'commonsenseqa_test', 'openbookqa_test', 'pubmedqa_test']
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    model = LLM(model=args.model, tensor_parallel_size=args.num_cuda, gpu_memory_utilization=args.util, swap_space=args.swap, max_model_len=2048, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    for typ in typs:
        result_path = 'results/{0}/{1}_results.json'.format(typ, args.model_name)
        if os.path.exists(result_path) and not args.rewrite:
            print('results exists, skip')
            continue

        with open('data/{}.json'.format(typ), 'r', encoding='utf-8') as f:
            # JSON-lines input; blank lines (e.g. a trailing newline) are ignored.
            data = [json.loads(line) for line in f if line.strip()]

        os.makedirs('results/{}'.format(typ), exist_ok=True)

        results = chating(data, model, typ, tokenizer, config)

        # Open (and truncate) the output only after generation succeeded: an
        # empty file left behind by a failed run would make later runs skip
        # this task as "results exists".
        with open(result_path, 'w', encoding='utf-8') as outf:
            for item, reply in zip(data, results):
                item['response'] = reply
                outf.write(json.dumps(item) + '\n')
    

if __name__ == "__main__":
    # Command-line entry point. In the code shown here only --model,
    # --model_name, --rewrite, --num_cuda, --util and --swap are read; the
    # remaining flags are accepted but not used.
    cli_options = [
        (("--ntrain", "-k"), {"type": int, "default": 5}),
        (("--nchain", "-c"), {"type": int, "default": 5}),
        (("--nbatch", "-b"), {"type": int, "default": -1}),
        (("--model",), {"type": str, "default": ''}),
        (("--model_name",), {"type": str, "default": ''}),
        (("--debug",), {"action": 'store_true'}),
        (("--rewrite",), {"action": 'store_true'}),
        (("--start",), {"type": int, "default": 0}),
        (("--num_cuda",), {"type": int, "default": 2}),
        (("--util",), {"type": float, "default": 0.98}),
        (("--swap",), {"type": int, "default": 4}),
    ]
    cli = argparse.ArgumentParser()
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)
    # ``args`` must remain a module-level global: chating() may read
    # args.model_name from module scope.
    args = cli.parse_args()
    main(args)

