from utils import *
from transformers import AutoTokenizer
from datasets import Dataset
from tqdm import tqdm
import random
import numpy as np
import argparse

# Template used to render one paragraph dict as text. modify_supporting_text
# fills it in and compares the result against the reference texts.
TEXT_TEMP = "Passage \"{title}\":\n{paragraph_text}"


def modify_supporting_text(src_paragraphs_list, mixed_supporting_text_list):
    """Flag source paragraphs whose rendered text appears in the reference texts.

    Each paragraph is rendered with ``TEXT_TEMP``; when the rendered text is
    found in ``mixed_supporting_text_list`` the paragraph's ``"is_supporting"``
    entry is set to ``True``. The paragraph dicts are modified in place and
    the very same objects are returned.

    Args:
        src_paragraphs_list: iterable of dicts with the keys ``"title"``,
            ``"paragraph_text"`` and ``"is_supporting"``.
        mixed_supporting_text_list: list of rendered paragraph texts
            (hashable, i.e. strings) expected to match source paragraphs.

    Returns:
        ``(paragraphs_list, supporting_list)``: every paragraph in input
        order, and the paragraphs whose ``"is_supporting"`` was already truthy
        before this call touched them.

    Raises:
        AssertionError: if the number of matched paragraphs differs from
            ``len(mixed_supporting_text_list)``.
    """
    # Build the lookup once so each membership test is O(1).
    reference_texts = frozenset(mixed_supporting_text_list)

    paragraphs_list = []
    supporting_list = []
    matched = 0
    for para in src_paragraphs_list:
        # Checked before the flag may be overwritten below, so this list only
        # holds paragraphs that were supporting in the source data.
        if para["is_supporting"]:
            supporting_list.append(para)

        # NOTE(review): the chained replace is kept as-is; the reference texts
        # presumably were rendered the same way - confirm before changing.
        rendered = (
            TEXT_TEMP
            .replace("{title}", para["title"])
            .replace("{paragraph_text}", para["paragraph_text"])
        )
        if rendered in reference_texts:
            para["is_supporting"] = True
            matched += 1
        paragraphs_list.append(para)

    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the exception type is unchanged.
    if matched != len(mixed_supporting_text_list):
        raise AssertionError(
            f"matched {matched} source paragraphs but expected "
            f"{len(mixed_supporting_text_list)} reference texts"
        )
    return paragraphs_list, supporting_list

def _raw_output_path(ref_file_path):
    """Return the output path: ``_raw`` inserted before a trailing ``.jsonl``."""
    suffix = ".jsonl"
    if ref_file_path.endswith(suffix):
        return ref_file_path[: -len(suffix)] + "_raw" + suffix
    # No ".jsonl" extension: append instead, so the result can never equal the
    # input path and overwrite the reference file.
    return ref_file_path + "_raw" + suffix


def main_select_raw_text(args):
    """Join the reference items with their raw source rows and save the result.

    For every item of the reference dataset, the source row with the same
    ``"index"`` (the row's position in the source dataset) is looked up, the
    reference ``"positive_sample"`` / ``"negative_sample"`` are copied onto it,
    its ``"paragraphs"`` are re-flagged via ``modify_supporting_text`` using
    the reference ``"all_docs"``, and the originally supporting paragraphs are
    stored under ``"supporting_paragraphs"``. The merged rows are written as
    JSON next to the reference file, with ``_raw`` added to its name.

    Args:
        args: object with the string attributes ``ref_file_path`` and
            ``src_file_path``.

    Raises:
        AssertionError: if a reference item and its source row disagree on
            ``"question"``, or on a paragraph-count mismatch raised by
            ``modify_supporting_text``.
        KeyError: if a reference ``"index"`` has no source row.
    """
    ref_file_path = args.ref_file_path
    src_file_path = args.src_file_path
    tgt_save_path = _raw_output_path(ref_file_path)

    # load_custom_dataset is not defined in this file (presumably provided by
    # ``utils``); the results are used as ``datasets.Dataset``-like objects.
    ref_dataset = load_custom_dataset(ref_file_path)
    src_dataset = load_custom_dataset(src_file_path)
    # Tag every source row with its position so reference items can point at it.
    src_dataset = src_dataset.add_column("index", list(range(src_dataset.num_rows)))

    src_mapping = {row["index"]: row for row in src_dataset}
    tgt_dataset_list = []
    for ref_item in tqdm(ref_dataset):
        idx = ref_item["index"]
        src_item = src_mapping[idx]
        # Explicit raise instead of ``assert`` so the consistency check still
        # runs under ``python -O``; the exception type is unchanged.
        if src_item["question"] != ref_item["question"]:
            raise AssertionError(
                f"question mismatch between reference and source at index {idx}"
            )
        src_item["positive_sample"] = ref_item["positive_sample"]
        src_item["negative_sample"] = ref_item["negative_sample"]
        paragraphs_list, supporting_list = modify_supporting_text(
            src_item["paragraphs"], ref_item["all_docs"]
        )
        src_item["paragraphs"] = paragraphs_list
        src_item["supporting_paragraphs"] = supporting_list
        tgt_dataset_list.append(src_item)

    tgt_dataset = Dataset.from_list(tgt_dataset_list)
    # Internal sanity check: exactly one output row per reference item.
    assert tgt_dataset.num_rows == ref_dataset.num_rows
    tgt_dataset.to_json(tgt_save_path, force_ascii=False)

def get_args():
    """Parse the command-line options of this script.

    Both options are required: without them the parsed values would be
    ``None`` and the script would fail later with an unrelated error instead
    of a usage message.

    Returns:
        argparse.Namespace with the string attributes ``ref_file_path`` and
        ``src_file_path``.

    Raises:
        SystemExit: raised by argparse when an option is missing or invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ref_file_path",
        type=str,
        required=True,
        help="the data file with preference pairs, which is a subset of raw musique data file",
    )
    parser.add_argument(
        "--src_file_path",
        type=str,
        required=True,
        help="raw musique data file, which contains supporting texts and irrelevant texts",
    )
    return parser.parse_args()

def main():
    """Script entry point: parse the CLI options and run the selection step."""
    main_select_raw_text(get_args())


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()