from datasets import load_dataset
from itertools import chain
import jsonlines
ds = load_dataset("sewon/ambig_qa", "light", split="validation")

new_dataset = []

for example in ds:
    question = example["question"].strip()
    type = 'multiple' if 'multipleQAs' in example['annotations']['type'] else "single"
    truthful_answers = []
    
    
        
    for i,t in enumerate(example['annotations']['type']):
        if t =='singleAnswer':
            # 
            truthful_answers.extend(example['annotations']['answer'][i])
        else:
            # 
            truthful_answers.extend(list(chain(*example['annotations']['qaPairs'][i]["answer"])))
    truthful_answers = list(set([an.strip() for an in truthful_answers]))
    new_dataset.append({
        "query": question,
        "type": type,
        "truthful answer": truthful_answers,
    })
    
with jsonlines.open("data/datasets/ambigqa/test.jsonl", "w") as writer:
    writer.write_all(new_dataset)