import os,json,jsonlines
os.chdir("/mnt/mednas/jiangyinuo.jyn/RALM/code/src")
from utils import ds_infer,extract_json_dict
from datasets import load_from_disk
from tqdm import tqdm
# Data locations (relative to the current working directory).
temp_dir = "src/data/prompts"
res_dir = "src/data/ad_ds"
# Input contexts are read from the same directory the results are written to.
context_dir = res_dir
cot_path = "src/data/prompts/sqv.txt"

# Extra instruction spliced into the system prompt, keyed by task variant.
TASK_INS = {
    "unanswer": "If there is no information available from the context, the answer should be 'unknown'. ",
    "inconsistent": "If there is conflict information or multiple answers from the context, the answer should be 'conflict'.",
    "normal": "",
}
task_specific_prompt = TASK_INS["unanswer"]

# Read the system-prompt template and fill in the task-specific instruction.
# The template goes through str.format, so any literal braces in the file
# must be doubled ({{ }}); it presumably contains a {task_specific_prompt} slot.
with open(cot_path, "r", encoding='utf-8') as f:
    SYS_PROMPT = f.read().format(task_specific_prompt=task_specific_prompt)
def gen_cot(n_context, question):
    """Build the full prompt for one context/question pair and run inference.

    The prompt is the module-level SYS_PROMPT followed by the context, the
    question and an "Answer:" cue, separated by blank lines.

    Args:
        n_context: context text shown to the model.
        question: question to be answered from the context.

    Returns:
        Whatever ``ds_infer`` returns for the assembled prompt.
    """
    prompt = f"{SYS_PROMPT}Context:{n_context}\n\nQuestion:{question}\n\nAnswer:"
    return ds_infer(prompt)
def find_index(data_list, idx):
    """Return the position of the first element whose 'index' field equals idx.

    Args:
        data_list: iterable of mapping-like records (anything with ``.get``).
        idx: value to match against each record's ``'index'`` entry.

    Returns:
        The zero-based position of the first match, or None if nothing matches.
    """
    matches = (pos for pos, record in enumerate(data_list) if record.get('index') == idx)
    return next(matches, None)
def gen_cot_ds(ds_name):
    """Run gen_cot over every record of a JSONL dataset and append the results.

    Reads ``<context_dir>/<ds_name>.jsonl``. Each record must provide the keys
    ``"index"``, ``"modified context"`` and ``"question"``. For every record one
    JSON line with keys ``index``, ``answer``, ``question`` and ``n_context`` is
    appended to ``<res_dir>/<ds_name>.jsonl``.

    NOTE(review): at module level ``context_dir = res_dir``, so the output lines
    are appended to the *input* file. All input is loaded before any write, so a
    single run is fine. A second run would then read the appended result rows,
    which have no "modified context" key, and fail with KeyError. Confirm this
    is intended.

    Args:
        ds_name: dataset file stem, without the ``.jsonl`` suffix.

    Raises:
        RuntimeError: if inference returns the failure sentinel ``"Error"``.
    """
    in_path = os.path.join(context_dir, ds_name + ".jsonl")
    out_path = os.path.join(res_dir, ds_name + ".jsonl")

    # Load every record up front; the output path may be the input file itself
    # (see NOTE above), so reading must finish before anything is appended.
    with jsonlines.open(in_path) as reader:
        records = list(reader)

    for record in tqdm(records):
        idx = record["index"]
        n_context = record["modified context"]
        question = record["question"]
        res = gen_cot(n_context, question)
        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently write "Error" as an answer.
        if res == "Error":
            raise RuntimeError(f"inference failed for record index={idx!r}")
        # Re-open in append mode per record so finished results are on disk
        # even if a later (slow) inference call crashes or is interrupted.
        with open(out_path, 'a', encoding='utf-8') as out_file:
            out_file.write(
                json.dumps(
                    {"index": idx, "answer": res, "question": question, "n_context": n_context},
                    ensure_ascii=False,
                )
                + '\n'
            )
if __name__ == '__main__':
    # Dataset file stem, resolved under context_dir (input) and res_dir (output).
    ds_name = "squad_merged"
    gen_cot_ds(ds_name)
   