import numpy as np

import argparse
from utils import *
from datasets import Dataset, concatenate_datasets

def creat_index(dataset):
    """Build a lookup table from each record's ``index`` value to the record.

    Args:
        dataset: Iterable of mapping-like records that each have an
            ``"index"`` key.

    Returns:
        dict keyed by ``record["index"]``. When several records share the
        same index, the last one encountered is kept.
    """
    return {record['index']: record for record in dataset}

def long_shot_doc_match(args):
    """Pair each short-context sample with its long-context counterpart.

    For every length tag in ``args.long_text_length`` (comma separated), the
    long dataset at ``{long_data_dir}/{tag}/{long_data_name}`` is loaded and
    indexed by its ``index`` field. Each row of the short dataset is then
    matched to the long item whose ``index`` equals the row's position, and
    the merged records are written as JSON to
    ``{long_data_dir}/{tag}/supporting_random_shor2long_{long_data_name}``.

    Args:
        args: Namespace providing ``short_data_path``, ``long_data_dir``,
            ``long_data_name`` and ``long_text_length``.

    Raises:
        KeyError: If a short row has no long item with the matching index.
        ValueError: If the matched short and long items have different
            questions.
    """
    short_data_path = args.short_data_path
    long_data_dir = args.long_data_dir
    long_data_name = args.long_data_name
    # Tolerate spaces around the commas, e.g. "4k, 8k".
    long_text_length = [tag.strip() for tag in args.long_text_length.split(",")]

    # 1. load short dataset
    short_dataset = load_custom_dataset(short_data_path)
    print(short_dataset)
    print(np.mean(short_dataset["length"]))

    for long_text_len in long_text_length:
        print(long_text_len)
        # 2. load long dataset
        long_data_path = f"{long_data_dir}/{long_text_len}/{long_data_name}"
        long_dataset = load_custom_dataset(long_data_path)
        print("long context length", np.mean(long_dataset["length"]))
        long_dataset_mapping = creat_index(long_dataset)

        item_list = []
        for idx, short_item in enumerate(short_dataset):
            # 3. match
            # NOTE(review): rows are matched by their position in the short
            # dataset rather than by short_item['index'] -- confirm intended.
            short_text_id = idx
            try:
                long_item = long_dataset_mapping[short_text_id]
            except KeyError:
                raise KeyError(
                    f"no long item with index {short_text_id} in {long_data_path}"
                ) from None
            # Explicit check instead of `assert`, which is stripped under -O
            # and would let mismatched pairs be written silently.
            if short_item["question"] != long_item["question"]:
                raise ValueError(
                    f"question mismatch at index {short_text_id} between "
                    f"{short_data_path} and {long_data_path}"
                )
            item_list.append({
                "index": short_text_id,
                "short_docs_list": short_item["all_docs"],
                "long_docs_list": long_item["all_docs"],
                "question": short_item["question"],
                "answer_aliases": short_item["answer_aliases"],
                "positive_sample": short_item["positive_sample"],
                "negative_sample": short_item["negative_sample"],
                "long_text_length": long_item["length"],
            })
        final_dataset = Dataset.from_list(item_list)

        # 4. save
        output_path = f"{long_data_dir}/{long_text_len}/supporting_random_shor2long_{long_data_name}"
        final_dataset.to_json(output_path, force_ascii=False)

def get_args():
    """Parse the command-line options for the short/long matching script.

    Returns:
        argparse.Namespace with ``long_data_dir``, ``long_data_name``,
        ``short_data_path`` (each ``None`` when omitted) and
        ``long_text_length`` (defaults to ``"8k"``).
    """
    cli = argparse.ArgumentParser()
    # Plain string options without defaults.
    for flag in ("--long_data_dir", "--long_data_name", "--short_data_path"):
        cli.add_argument(flag, type=str)
    cli.add_argument(
        "--long_text_length",
        type=str,
        default="8k",
        help="4k,8k,16k",
    )
    return cli.parse_args()

def main():
    """Read the command-line options and run the short/long document matching."""
    long_shot_doc_match(get_args())

# Run the matching pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    