from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
from unsloth import is_bfloat16_supported
import torch
import json,os
import argparse
from tqdm import tqdm
from datasets import load_from_disk
from vllm import SamplingParams
# Prompt template file. The __main__ block reads it and fills in its
# `task_specific_prompt` format field.
cot_path="src/data/prompts/sqv.txt"
# Extra instruction text per task type. The __main__ block selects an entry by
# checking whether the key occurs in the dataset path, falling back to "normal".
TASK_INS={
    "unanswer":"If there is no information available from the context, the answer should be 'unknown'. ",
    "inconsistent":"If there is conflict information or multiple answers from the context, the answer should be 'conflict'.",
    "normal":"",
}

# NOTE: YOUR_MODEL_PATH is a placeholder that is not defined anywhere in the
# visible part of this file; replace it with the base model path before running.
model_path=YOUR_MODEL_PATH
# Last path component of the model path; used to name the output directory.
base_model_name=model_path.split("/")[-1]
OUT_BASE=f"/mnt/mednas/jiangyinuo.jyn/RALM/code/outputs/{base_model_name}"
# exist_ok=True replaces the check-then-create pair, which could fail if another
# process created the directory between the check and the makedirs call.
os.makedirs(OUT_BASE, exist_ok=True)

def inference(model, input):
    """Generate one completion for a single prompt and return its text.

    Args:
        model: Object exposing ``fast_generate``; it is called with the prompt
            and a ``sampling_params`` keyword argument.
        input: Prompt handed unchanged to ``model.fast_generate``.

    Returns:
        The ``text`` attribute of the first candidate of the first result
        returned by ``model.fast_generate``.
    """
    # Sampled decoding with a fixed budget of 1024 new tokens.
    decoding = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
    results = model.fast_generate(input, sampling_params=decoding)
    first_result = results[0]
    return first_result.outputs[0].text
@torch.no_grad()
def infer_ds(model, ds_path):
    """Run ``inference`` on every row of a saved dataset and append results to a JSONL file.

    The output file is ``OUT_BASE/<dataset name>/<adapter parent>_<adapter name>.jsonl``,
    where the dataset name is the last component of ``ds_path`` and the adapter
    components are the last two components of ``model.config.adapter_path``.
    If that file already exists, nothing is done.

    Each row is rendered into a prompt from the module-level ``SYS_PROMPT``
    (assigned by the ``__main__`` block), the row's ``context`` and ``question``
    and, when the dataset has a ``choices`` column, the comma-joined
    ``choices["text"]`` values.

    Args:
        model: Model passed to ``inference``; must carry ``config.adapter_path``.
        ds_path: Directory given to ``datasets.load_from_disk``.

    Returns:
        None. One JSON object per row (``idx`` as a string, ``res`` as the
        generated text) is appended to the output file.
    """
    # Strip trailing slashes first: "data/foo/".split("/")[-1] is "", which
    # would make every dataset passed with a trailing slash share one directory.
    ds_config = ds_path.rstrip("/").split("/")[-1]
    out_base = os.path.join(OUT_BASE, ds_config)
    os.makedirs(out_base, exist_ok=True)

    # Last two path components joined by "_". A single-component path yields
    # just that component instead of raising IndexError.
    adapter_parts = model.config.adapter_path.rstrip("/").split("/")
    config_name = "_".join(adapter_parts[-2:])
    out_file_name = out_base + "/" + config_name + ".jsonl"

    # Check for existing output before loading the dataset so a skipped run
    # does not pay for the load.
    # NOTE(review): a file left partially written by an interrupted run also
    # triggers this skip; delete it to re-run that configuration.
    if os.path.exists(out_file_name):
        return

    ds = load_from_disk(ds_path)
    # The column set does not change per row, so decide the template once.
    has_choices = "choices" in ds.column_names

    for i, item in enumerate(tqdm(ds)):
        if has_choices:
            prompt = "{s}Context:{c}\n\nQuestion:{q}\n\nOptions:{o}\n\nAnswer:".format(
                s=SYS_PROMPT,
                c=item["context"],
                q=item["question"],
                o=",".join(item["choices"]["text"]),
            )
        else:
            prompt = "{s}Context:{c}\n\nQuestion:{q}\n\nAnswer:".format(
                s=SYS_PROMPT,
                c=item["context"],
                q=item["question"],
            )
        res = inference(model, prompt)
        # Append and close per record so each finished generation is on disk
        # before the next one starts.
        with open(out_file_name, 'a', encoding='utf-8') as file:
            file.write(json.dumps({"idx":str(i),"res":res,}, ensure_ascii=False) + '\n')
    
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the base model and LoRA
    # adapter, build the system prompt, then run inference over the dataset.
    # Imported here because only this entry point uses it.
    import functools

    parser = argparse.ArgumentParser(description='inferring')
    parser.add_argument('--test', type=str, default="")
    parser.add_argument('--ckpt', type=str, default="")
    args = parser.parse_args()
    adapter_path = args.ckpt  # LoRA adapter directory
    ds_path = args.test  # dataset directory; infer_ds passes it to load_from_disk
    max_seq_length = 2048 # Can increase for longer reasoning traces
    lora_rank = 64 # Larger rank = smarter, but slower

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path,
        max_seq_length = max_seq_length,
        load_in_4bit = True, # False for LoRA 16bit
        fast_inference = True, # Enable vLLM fast inference
        max_lora_rank = lora_rank,
        gpu_memory_utilization = 0.7, # Reduce if out of memory
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ], # Remove QKVO if out of memory
        lora_alpha = lora_rank,
        use_gradient_checkpointing = "unsloth", # Enable long context finetuning
        random_state = 3407,
    )

    # Per Unsloth's documented usage, load_lora returns a LoRA request that has
    # to be handed to fast_generate; it does not alter the model by itself.
    # The original discarded this value, so generation ran without the adapter.
    lora_request = model.load_lora(adapter_path)
    # infer_ds reads this attribute to derive the output file name.
    model.config.adapter_path = adapter_path
    FastLanguageModel.for_inference(model)
    # Bind the adapter as a default keyword so every fast_generate call made by
    # inference() uses it; an explicit lora_request= at a call site still wins.
    model.fast_generate = functools.partial(
        model.fast_generate, lora_request=lora_request
    )

    # Pick the task instruction by substring match on the dataset path;
    # "inconsistent" is checked before "unanswer", "normal" is the fallback.
    task_key = next(
        (key for key in ("inconsistent", "unanswer") if key in ds_path),
        "normal",
    )
    task_specific_prompt = TASK_INS[task_key]

    with open(cot_path, "r", encoding='utf-8') as f:
        SYS_PROMPT = f.read()   # read the prompt template text
    # SYS_PROMPT must stay a module-level global: infer_ds reads it by name.
    SYS_PROMPT = SYS_PROMPT.format(task_specific_prompt=task_specific_prompt)
    infer_ds(model, ds_path)