import pandas as pd
import pickle
import os
import json
# Dataset names cleaned by the CLI below when no --dataset argument is given.
datasets = [
    "hotpotqa",
    "nq_search",
    "2wikimultihopqa",
    "musique",
    "triviaqa",
    "popqa",
    "bamboogle",
]



def clean_data(dataset, data_dir="data"):
    """Drop known-dirty questions from a dataset's test split and save the result.

    Reads ``{data_dir}/{dataset}/process_test.parquet``, removes every row whose
    ``question`` value appears in the pickled collection stored at
    ``{data_dir}/dirty_questions_{dataset}.dump``, and writes the remaining rows
    to ``{data_dir}/{dataset}/process_test_clean.parquet``.

    Args:
        dataset: Dataset name; used to build the input and output paths.
        data_dir: Root directory holding the dataset folders and the
            dirty-question dumps. Defaults to ``"data"`` (relative to the
            working directory), the previously hard-coded location.

    Returns:
        The filtered DataFrame that was written to disk.
    """
    split_dir = os.path.join(data_dir, dataset)
    df = pd.read_parquet(os.path.join(split_dir, "process_test.parquet"))
    print(f"Loaded {len(df)} questions from {dataset}")

    # NOTE: pickle.load can execute arbitrary code -- only load locally produced dumps.
    # Presumably a set of question strings {"question1", "question2", ...}; TODO confirm.
    with open(os.path.join(data_dir, f"dirty_questions_{dataset}.dump"), "rb") as f:
        dirty_questions = pickle.load(f)

    n_before = len(df)
    df = df[~df["question"].isin(dirty_questions)]
    # Report rows actually dropped, not the size of the dirty collection: the
    # dump may list questions absent from this split, and a dirty question may
    # appear in several rows.
    print(f"Removed {n_before - len(df)} dirty questions from {dataset}")

    out_path = os.path.join(split_dir, "process_test_clean.parquet")
    df.to_parquet(out_path)
    print(f"Saved {len(df)} questions to {out_path}")
    return df


def clean_results(dataset, base_dir="eval_results", output_dir="eval_results_cleaned", data_dir="data"):
    """Strip dirty questions from every ``.jsonl`` result file of a dataset.

    For each ``*.jsonl`` file in ``{base_dir}/{dataset}``, drops the records
    whose ``question`` field is in the pickled dirty-question collection at
    ``{data_dir}/dirty_questions_{dataset}.dump``. The remaining records are
    written to ``{output_dir}/{dataset}/`` under the same file name.

    Skipping rules:
        - Files that cannot be read or parsed are reported and skipped.
        - Files containing a record without a ``question`` key are reported
          and skipped.
        - Skipped files are not written to ``output_dir``.
        - A missing ``{base_dir}/{dataset}`` directory is reported and the
          dataset is skipped.

    Args:
        dataset: Dataset name; selects the input/output sub-directories.
        base_dir: Root directory containing the raw evaluation results.
        output_dir: Root directory that receives the cleaned files.
        data_dir: Directory holding ``dirty_questions_{dataset}.dump``.
            Defaults to ``"data"``, the previously hard-coded location.

    Raises:
        FileNotFoundError: If the dirty-question dump does not exist.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load locally produced dumps.
    with open(os.path.join(data_dir, f"dirty_questions_{dataset}.dump"), "rb") as f:
        # Copy into a set once so each membership test below is O(1) even if
        # the dump holds a list (presumably of question strings; TODO confirm).
        dirty_questions = set(pickle.load(f))

    input_dir = os.path.join(base_dir, dataset)
    if not os.path.isdir(input_dir):
        # Skip rather than raise so a run over all datasets keeps going.
        print(f"No results directory {input_dir}, skipping {dataset}")
        return

    for file in sorted(os.listdir(input_dir)):
        if not file.endswith(".jsonl"):
            continue
        print(f"Processing {file}")
        try:
            with open(os.path.join(input_dir, file), "r", encoding="utf-8") as f:
                # Ignore blank lines (e.g. trailing empty lines); json.loads("")
                # would otherwise reject the whole file.
                data = [json.loads(line) for line in f if line.strip()]
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            print(f"Error processing {file}: {e}")
            continue
        try:
            kept = [item for item in data if item["question"] not in dirty_questions]
        except (KeyError, TypeError) as e:
            # A record lacks a "question" key, or is not a mapping.
            print(f"Error processing {file}: {e!r}")
            continue
        print(f"Removed {len(data) - len(kept)} dirty questions from {dataset} {file}")

        os.makedirs(os.path.join(output_dir, dataset), exist_ok=True)
        out_path = os.path.join(output_dir, dataset, file)
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in kept)
        print(f"Saved {len(kept)} questions to {out_path}")

if __name__ == "__main__":
    import argparse

    # CLI entry point: clean the evaluation results of a single dataset, or of
    # every name in `datasets` when --dataset is omitted.
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, default=None)
    cli.add_argument("--base_dir", type=str, default="eval_results")
    cli.add_argument("--output_dir", type=str, default="eval_results_cleaned")
    opts = cli.parse_args()

    targets = datasets if opts.dataset is None else [opts.dataset]
    for name in targets:
        clean_results(name, opts.base_dir, opts.output_dir)