from utils import *
from transformers import AutoTokenizer
from datasets import Dataset
from tqdm import tqdm
import random
import numpy as np
import argparse

# Template used to render a single paragraph as a titled passage.
TEXT_TEMP = "Passage \"{title}\":\n{paragraph_text}"


def _format_passage(para):
    """Render one paragraph dict as a titled passage using TEXT_TEMP."""
    # A single str.format pass keeps placeholder-like text inside the title
    # (e.g. a literal "{paragraph_text}") from being substituted a second
    # time, which chained str.replace calls would do.
    return TEXT_TEMP.format(
        title=para["title"], paragraph_text=para["paragraph_text"]
    )


def get_supporting_text(example, tokenizer=None, with_random_text=True,
                        irrelevant_per_supporting=3):
    """Build the shuffled passage list and reference answers for one example.

    The example is modified in place and also returned.

    Args:
        example: Mapping with a "paragraphs" list (each item providing
            "title", "paragraph_text" and "is_supporting"), plus "answer",
            "answer_aliases" and, when a tokenizer is given, "question".
        tokenizer: Optional object exposing ``encode(text)``. When provided,
            the token count of all passages joined by blank lines, followed
            by the question, is stored under "length".
        with_random_text: When true, non-supporting passages are sampled and
            mixed in with the supporting ones.
        irrelevant_per_supporting: Number of non-supporting passages to add
            for each supporting passage. Must be non-negative.

    Returns:
        The same ``example`` object with "all_docs", "outputs" and
        (optionally) "length" set.
    """
    supporting = []
    distractors = []
    for para in example["paragraphs"]:
        bucket = supporting if para["is_supporting"] else distractors
        bucket.append(_format_passage(para))

    all_docs = supporting
    if with_random_text:
        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more than are available.
        wanted = irrelevant_per_supporting * len(supporting)
        sample_size = min(wanted, len(distractors))
        all_docs = supporting + random.sample(distractors, k=sample_size)
    random.shuffle(all_docs)

    example["all_docs"] = all_docs
    example["outputs"] = [example["answer"]] + example["answer_aliases"]

    # Compare against None explicitly: tokenizer objects may define __len__,
    # which would make a plain truthiness test unreliable.
    if tokenizer is not None:
        prompt = "\n\n".join(all_docs) + "\n\n" + example["question"]
        example["length"] = len(tokenizer.encode(prompt))
    return example

def main_down_sampling(args):
    """Build the passage dataset described by ``args`` and write it as JSON.

    Reads ``args.tokenizer_path``, ``args.src_data_path`` and
    ``args.tgt_save_path``. Each example is passed through
    ``get_supporting_text`` with the loaded tokenizer, a sequential "index"
    column is added, the columns are narrowed to the exported set, the result
    is written with ``to_json`` and the mean of the "length" column is printed.

    NOTE(review): ``args.irrelevant_docs`` is parsed by ``get_args`` but is not
    read here.
    """
    tok_path, src_data_path, tgt_save_path = (
        args.tokenizer_path,
        args.src_data_path,
        args.tgt_save_path,
    )

    tokenizer = AutoTokenizer.from_pretrained(tok_path)

    def _annotate(example):
        # Attach passages, reference answers and token length to one example.
        return get_supporting_text(example, tokenizer)

    # load_custom_dataset is not defined in this file; presumably it comes
    # from the star import of utils.
    dataset = load_custom_dataset(src_data_path).map(_annotate)

    row_ids = list(range(dataset.num_rows))
    dataset = dataset.add_column("index", row_ids)

    exported_columns = ["index", "all_docs", "question", "outputs", "length"]
    dataset = dataset.select_columns(exported_columns)

    dataset.to_json(tgt_save_path, force_ascii=False)
    print("length", np.mean(dataset["length"]))

def get_args():
    """Parse the script's command-line options and return the namespace.

    Options: ``--tokenizer_path``, ``--src_data_path`` and ``--tgt_save_path``
    (strings, default ``None``) and ``--irrelevant_docs`` (int, default 3).
    """
    cli = argparse.ArgumentParser()
    for flag in ("--tokenizer_path", "--src_data_path", "--tgt_save_path"):
        cli.add_argument(flag, type=str)
    # The supporting texts in the MuSiQue dataset are relatively short, so
    # irrelevant texts are randomly added to bring the average length up to
    # about 1K.
    cli.add_argument("--irrelevant_docs", type=int, default=3)
    return cli.parse_args()

def main():
    """Script entry point: parse the CLI options and run the pipeline."""
    main_down_sampling(get_args())

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()