# Standard-library modules used for path handling and JSON serialization.
import os,json
# Switch the process working directory before the project import below.
# NOTE(review): presumably this is what makes `utils` and the relative data
# paths used later in this file resolvable; the absolute path is
# machine-specific -- confirm before running elsewhere.
os.chdir("/mnt/mednas/jiangyinuo.jyn/RALM/code/src")
# Project helpers: `ds_infer` is called with a prompt string and
# `extract_json_dict` is called on its result (see add_noise_ds).
from utils import ds_infer,extract_json_dict
from datasets import load_from_disk
from tqdm import tqdm
# Directory that holds one "<noise type>.txt" prompt file per noise type.
# NOTE(review): relative path, so it resolves against the working directory
# set by os.chdir above (i.e. ".../code/src/src/data/prompts") -- confirm.
temp_dir = "src/data/prompts"
# Directory that receives the generated "<dataset>_<noise type>.jsonl" files.
res_dir = "src/data/ad_ds"

# Noise type to generate; change this value to produce a different noise type.
# It selects the prompt file read below and is embedded in the output name.
NOISY_TYPE = "counterfactual"

# Read the full prompt text for the selected noise type once, at import time.
with open(
    os.path.join(temp_dir, f"{NOISY_TYPE}.txt"),
    mode="r",
    encoding='utf-8',
) as prompt_file:
    SYS_PROMPT = prompt_file.read()

def add_noise(context, question, answer):
    """Ask the model for a noisy answer, giving it the passage as context.

    The prompt is SYS_PROMPT immediately followed by the labelled context,
    question and old answer (no separator is inserted between SYS_PROMPT and
    the "context:" label).

    Args:
        context: Passage text placed after the "context:" label.
        question: Question text placed after the "question:" label.
        answer: Existing answer placed after the "old answer:" label.

    Returns:
        Whatever ``ds_infer`` returns for the assembled prompt.
    """
    template = "{s}context:{c}\n\nquestion:{q}\n\nold answer:{a}"
    fields = {"s": SYS_PROMPT, "c": context, "q": question, "a": answer}
    return ds_infer(template.format(**fields))
def add_noise_cf(question, answer):
    """Ask the model for a noisy answer without supplying any context passage.

    The prompt is SYS_PROMPT, a blank line, then the labelled question and
    old answer.

    Args:
        question: Question text placed after the "question:" label.
        answer: Existing answer placed after the "old answer:" label.

    Returns:
        Whatever ``ds_infer`` returns for the assembled prompt.
    """
    template = "{s}\n\nquestion:{q}\n\nold answer:{a}"
    fields = {"s": SYS_PROMPT, "q": question, "a": answer}
    return ds_infer(template.format(**fields))
def add_noise_ds(ds_path, start_index=2409):
    """Generate noisy answers for a saved dataset and append them to a JSONL file.

    Each processed item is sent through ``add_noise``; when the model output
    can be turned into a dict by ``extract_json_dict``, that dict is tagged
    with the item's position under the "index" key and appended as one line
    to ``<res_dir>/<dataset name>_<NOISY_TYPE>.jsonl``. Items whose output is
    not a dict are skipped.

    Args:
        ds_path: Path passed to ``datasets.load_from_disk``. Its last path
            component is used as the dataset name in the output file name.
        start_index: Position of the first item to process. Defaults to 2409,
            the resume offset that used to be hard-coded in the loop; pass 0
            to process the whole dataset, or the index reported by a failure
            to resume after it.

    Raises:
        RuntimeError: If ``add_noise`` returns the string "Error" for an item.
            This was previously an ``assert``, which is silently removed when
            Python runs with ``-O``.
    """
    # normpath drops a trailing separator, so ".../name/" still yields "name"
    # instead of an empty dataset name.
    ds_name = os.path.basename(os.path.normpath(ds_path))
    out_file_name = os.path.join(res_dir, ds_name) + f"_{NOISY_TYPE}.jsonl"
    # Make sure the first append does not fail on a missing output directory.
    os.makedirs(res_dir, exist_ok=True)

    ds = load_from_disk(ds_path)
    for i in tqdm(range(start_index, len(ds))):
        item = ds[i]
        context = item["context"]
        question = item["question"]
        # Only the first listed answer text is used as the "old answer".
        answer = item["answers"]["text"][0]

        res = add_noise(context, question, answer)
        if res == "Error":
            raise RuntimeError(
                f"inference failed for item {i}; "
                f"rerun with start_index={i} to resume"
            )

        res_dict = extract_json_dict(res)
        if not isinstance(res_dict, dict):
            continue
        res_dict["index"] = i
        # Re-open in append mode for every record so that each result is on
        # disk as soon as it is produced and an interrupted run can resume.
        with open(out_file_name, 'a', encoding='utf-8') as file:
            file.write(json.dumps(res_dict, ensure_ascii=False) + '\n')
if __name__ == '__main__':
    import sys

    # The dataset path used to come from a name `DS_PATH` that is not defined
    # anywhere in this script, so running it raised NameError. Take the path
    # from the first command-line argument, falling back to the DS_PATH
    # environment variable.
    ds_path = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("DS_PATH")
    if not ds_path:
        sys.exit(
            f"usage: python {sys.argv[0]} <dataset path> "
            "(or set the DS_PATH environment variable)"
        )
    add_noise_ds(ds_path)


