import os
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def write_text_file(args):
    """Write one text payload to a file inside a target directory.

    Takes a single tuple so it can be used directly with ``Executor.map``.

    Args:
        args: A ``(filename, text, output_dir)`` triple. ``text`` is written
            UTF-8 encoded to ``output_dir/filename``; an existing file is
            overwritten.
    """
    name, content, target_dir = args
    destination = os.path.join(target_dir, name)
    with open(destination, mode='w', encoding='utf-8') as handle:
        handle.write(content)

def extract_text(input_dir, output_dir):
    """Dump every passage stored under ``input_dir/request_tok`` to a text file.

    Each ``*.json`` file in the ``request_tok`` subdirectory is loaded as a
    mapping. Every value of that mapping is written to
    ``output_dir/<json name without .json>_<index>.txt``, where ``index`` is
    the position of the entry within the mapping. The keys are not used
    (presumably they are wiki link paths -- confirm against the dataset).

    A JSON file that cannot be read, decoded, or iterated as a mapping is
    reported on stdout and skipped; the remaining files are still processed.

    Args:
        input_dir: Directory containing the ``request_tok`` subdirectory.
        output_dir: Directory that receives the text files. It must already
            exist; this function does not create it.

    Raises:
        OSError: If ``request_tok`` cannot be listed, or if writing one of
            the output files fails.
    """
    request_tok_path = os.path.join(input_dir, 'request_tok')
    json_files = [f for f in os.listdir(request_tok_path) if f.endswith('.json')]

    print(f"Processing {len(json_files)} JSON files in request_tok...")

    # Collect all write jobs first so the writing stage can show a progress
    # bar with a known total.
    passages = []
    for json_file in tqdm(json_files, desc="Processing files"):
        gt_table = json_file[:-5]  # strip the trailing ".json"
        request_file_path = os.path.join(request_tok_path, json_file)

        try:
            # Read explicitly as UTF-8: the outputs are written as UTF-8, and
            # relying on the locale default would mis-decode (or fail on)
            # non-ASCII passages on platforms whose default is not UTF-8.
            with open(request_file_path, 'r', encoding='utf-8') as f:
                passages_data = json.load(f)

            for idx, passage_text in enumerate(passages_data.values()):
                passage_filename = f"{gt_table}_{idx}.txt"
                passages.append((passage_filename, passage_text, output_dir))

        except Exception as e:
            # Deliberately best-effort: one bad input file must not abort the
            # whole extraction run.
            print(f"Error processing {request_file_path}: {e}")

    # Writing is I/O-bound, so fan it out over a thread pool. list() drains
    # the lazy map so the progress bar advances and worker errors propagate.
    with ThreadPoolExecutor() as executor:
        list(tqdm(executor.map(write_text_file, passages), total=len(passages), desc="Writing text files"))

if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the two directory options, make sure
    # the destination exists, then run the extraction.
    cli = argparse.ArgumentParser(description="Extract texts from HybridQA dataset.")
    cli.add_argument(
        '--input',
        type=str,
        default='WikiTables-WithLinks',
        help='Input directory containing request_tok',
    )
    cli.add_argument(
        '--output',
        type=str,
        default='text',
        help='Output directory for extracted texts',
    )
    options = cli.parse_args()

    # The output directory is created up front; an existing one is reused.
    os.makedirs(options.output, exist_ok=True)

    extract_text(options.input, options.output)
