import ast

from datasets import  Dataset
# Prompt templates, keyed by model name.
#   "prompt_input": outer chat wrapper; its {instruction} and {input}
#                   placeholders are filled with str.format.
#   "instruction":  list of task instructions (the formatting functions in
#                   this file use element 0).
#   "input":        inner fragment; its {entities} and {sentence} placeholders
#                   are filled first, and the result is substituted for {input}.
templates = {
    "OneKE": {
        "prompt_input": "[INST] <<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n\"instruction\": \"{instruction}\", {input}[/INST]",
        "instruction": [
            "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.",
        ],
        "input": "\"schema\": {entities}, \"input\": \"{sentence}\"",
    },
}
import json
# NOTE: a `global` statement at module scope has no effect; it is kept only so
# the module's statements stay exactly as they were.
global base_model

# Absolute, machine-specific locations of the NER and RE schema files.
file_name = "/gruntdata/cognition-nas/yinuo/datasets/schema/ner.json"
re_file_name = "/mnt/prev_nas/jiangyinuo.jyn/code/datasets/re_schema.json"

# Both schemas are loaded eagerly at import time, so importing this module
# fails if either file is missing or is not valid JSON.
with open(file_name, encoding="utf-8") as file:
    ner_schema_dict = json.load(file)
with open(re_file_name, encoding="utf-8") as file:
    re_schema_dict = json.load(file)

# Task name ("RE" / "NER") -> schema; each schema is indexed by an example's
# "source" value by the formatting functions in this file.
schema_dict = {"RE": re_schema_dict, "NER": ner_schema_dict}
def format_output(example):
    """Build the JSON target string for one NER example.

    The schema for ``example["source"]`` is looked up in ``schema_dict["NER"]``
    and every schema key becomes a key of the result, initially mapped to an
    empty list.

    * For any source other than ``"rare_disease"``, ``example["entities"]`` is
      iterated as a sequence of mappings with ``"type"`` and ``"text"`` keys.
      Entries whose text is ``None`` or empty, or whose type is not in the
      schema, are skipped; the remaining texts are appended, in input order,
      to the list of their type.
    * For ``"rare_disease"``, ``example["entities"]`` is indexed by schema key
      and each value is used as-is (``None`` is replaced by an empty list).

    Args:
        example: mutable mapping with at least ``"source"`` and ``"entities"``.

    Returns:
        The same ``example`` object, with ``example["output"]`` set to the
        ``json.dumps`` serialization of the grouped entities.

    Raises:
        KeyError: if the source has no entry in ``schema_dict["NER"]``, or (in
            the ``"rare_disease"`` branch) a schema key is missing from
            ``example["entities"]``.
    """
    source = example["source"]
    allowed_types = schema_dict["NER"][source]
    grouped = {entity_type: [] for entity_type in allowed_types}

    if source == "rare_disease":
        annotations = example["entities"]
        for entity_type in allowed_types:
            mentions = annotations[entity_type]
            grouped[entity_type] = [] if mentions is None else mentions
    else:
        for entity in example["entities"]:
            mention = entity["text"]
            # Skip entities that carry no surface text.
            if mention is None or len(mention) == 0:
                continue
            # Skip entity types that the schema for this source does not list.
            if entity["type"] not in allowed_types:
                continue
            grouped[entity["type"]].append(mention)

    # NOTE(review): json.dumps uses its default ensure_ascii=True here, so any
    # non-ASCII entity text is written as \uXXXX escapes - confirm this is the
    # intended target format.
    example["output"] = json.dumps(grouped)
    return example
def format_qwen_end(target):
    """Append the ``<|endoftext|>`` marker to ``target["output"]``.

    Args:
        target: mutable mapping whose ``"output"`` value is a string.

    Returns:
        The same ``target`` object, modified in place.
    """
    terminated = target["output"] + '<|endoftext|>'
    target["output"] = terminated
    return target
def format_output_re(example):
    """Regroup an example's relation triples by relation type.

    ``example["output"]`` must be the string form of a Python literal: a
    sequence of mappings, each with a ``"rel"`` key and ``"head"`` / ``"tail"``
    sequences whose first element is kept. The relation schema is taken from
    ``schema_dict["RE"][example["source"]]``, falling back to the
    ``"rare_disease"`` schema when the example has no ``"source"`` key.

    Args:
        example: mutable mapping with an ``"output"`` string and, optionally,
            a ``"source"`` key.

    Returns:
        The same ``example`` object, with ``example["output"]`` replaced by
        ``str()`` of a dict mapping every schema relation to a list of
        ``{"head": ..., "tail": ...}`` dicts (in input order).

    Raises:
        ValueError, SyntaxError: if ``example["output"]`` is not a valid
            Python literal.
        KeyError: if the source has no RE schema, or a triple uses a relation
            that the schema does not list.
    """
    source = example["source"] if "source" in example else "rare_disease"
    type_list = schema_dict["RE"][source]
    res_dict = {key: [] for key in type_list}

    # SECURITY: this string comes from dataset content. It used to be passed to
    # eval(), which executes arbitrary expressions; ast.literal_eval accepts
    # only literals (lists, dicts, strings, numbers, ...) and raises otherwise.
    triples = ast.literal_eval(example["output"])

    for tris in triples:
        res_dict[tris["rel"]].append({"head": tris["head"][0], "tail": tris["tail"][0]})

    # NOTE(review): this is the Python repr of the dict (single-quoted), not
    # JSON, unlike format_output which uses json.dumps - confirm downstream
    # consumers expect this format before changing it.
    example["output"] = str(res_dict)
    return example
def format(dataset, model_name):
    """Render prompts and target strings for a training dataset.

    Three ``dataset.map`` passes are applied in order:

    1. The prompt is rendered from ``templates[model_name]`` and stored in
       ``example["sentence"]``. The raw text is read from the ``"input"``
       column when ``model_name`` ends with ``"RE"`` and from ``"sentence"``
       otherwise. Examples whose ``"source"`` has no schema for the task are
       left unchanged by this pass.
    2. The target is rewritten by ``format_output_re`` (model names ending in
       ``"RE"``) or ``format_output`` (all others).
    3. Only when ``model_name == "Qwen"``: ``format_qwen_end`` is applied.

    Note: this function shadows the built-in ``format`` within this module.

    Args:
        dataset: object exposing ``map(function)`` that returns the mapped
            dataset.
        model_name: key into ``templates``; a trailing ``"RE"`` selects the
            relation-extraction schema and formatter.

    Returns:
        The dataset returned by the last ``map`` call.
    """
    is_relation_model = model_name.endswith("RE")
    task = "RE" if is_relation_model else "NER"
    text_column = "input" if is_relation_model else "sentence"

    def build_prompt(example):
        source = example["source"]
        if source not in schema_dict[task]:
            return example
        labels = schema_dict[task][source]
        # Looked up lazily so a missing template only fails for examples that
        # actually need one.
        template = templates[model_name]
        prompt_template = template["prompt_input"]
        instruction = template["instruction"][0]
        filled_input = template["input"].format(entities=labels, sentence=example[text_column])
        example["sentence"] = prompt_template.format(instruction=instruction, input=filled_input)
        return example

    dataset = dataset.map(build_prompt)

    output_formatter = format_output_re if is_relation_model else format_output
    dataset = dataset.map(output_formatter)

    if model_name == "Qwen":
        dataset = dataset.map(format_qwen_end)
    return dataset
def format_infer(dataset, model_name):
    """Render inference prompts into ``example["sentence"]``.

    For each example whose ``"source"`` has a schema for the task, the prompt
    is built from ``templates[model_name]`` using that schema and the current
    value of ``example["sentence"]``, and written back to
    ``example["sentence"]``. Other examples are returned unchanged. No target
    string is produced.

    Args:
        dataset: object exposing ``map(function)`` that returns the mapped
            dataset.
        model_name: key into ``templates``; a trailing ``"RE"`` selects the
            relation-extraction schema.

    Returns:
        The dataset returned by ``dataset.map``.
    """
    task = "NER"
    if model_name.endswith("RE"):
        task = "RE"

    def render_prompt(example):
        source = example["source"]
        if source not in schema_dict[task]:
            return example
        labels = schema_dict[task][source]
        template = templates[model_name]
        prompt_template = template["prompt_input"]
        instruction = template["instruction"][0]
        # NOTE(review): the raw text is always read from "sentence" here, while
        # format() reads "input" for model names ending in "RE" - confirm that
        # RE inference data really uses the "sentence" column.
        filled_input = template["input"].format(entities=labels, sentence=example["sentence"])
        example["sentence"] = prompt_template.format(instruction=instruction, input=filled_input)
        return example

    return dataset.map(render_prompt)
def format_squad(ds):
    """Convert question/context/answers columns into prompt/answer pairs.

    Each prompt has the form ``"Context: <context>\\nQuestion: <question>\\nAnswer:"``
    with surrounding whitespace stripped from the question and the context.
    The target is the first answer of each row: ``row[0]`` when the rows of
    ``ds['answers']`` are lists, otherwise ``row['text'][0]``. The layout is
    decided from the first row only.

    Args:
        ds: object indexable by column name, providing ``'question'``,
            ``'context'`` and ``'answers'`` columns.

    Returns:
        A new ``Dataset`` with a ``'sentence'`` column (prompts) and an
        ``'output'`` column (first answers). An input with no rows yields
        empty columns instead of raising.

    Raises:
        IndexError: if a row has no answers.
    """
    questions = [q.strip() for q in ds['question']]
    contexts = [c.strip() for c in ds['context']]
    prompts = [
        f"Context: {context}\nQuestion: {question}\nAnswer:"
        for question, context in zip(questions, contexts)
    ]

    # Fetch the column once instead of once per use.
    raw_answers = ds['answers']
    # The length check keeps an empty input from raising IndexError when the
    # first row is inspected; with no rows both branches give an empty list.
    if len(raw_answers) > 0 and isinstance(raw_answers[0], list):
        answers = [a[0] for a in raw_answers]
    else:
        answers = [a['text'][0] for a in raw_answers]

    return Dataset.from_dict({'sentence': prompts, 'output': answers})

