#!/usr/bin/env python3
"""
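Script to check file overlap between train and test folders.
Files in each folder are named "year_pmid.json". Since the PMID is unique, identical file names identify the same article, so overlap is detected by comparing file names. Overlapping files are moved out of the test folder into a sibling "<test folder>_overlapping" folder.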
"""

import os
import sys
import shutil

def get_files_in_folder(folder_path):
    """Return the set of regular-file names directly inside ``folder_path``.

    Sub-directories (and any other non-file entries) are ignored; the search
    is not recursive. Only bare names are returned, not full paths.

    If ``folder_path`` does not exist or is not a directory, an error is
    printed to stderr and the process exits with status 1.
    """
    # isdir (not exists) so that a path pointing at a regular file is rejected
    # here with a clear message instead of crashing inside os.listdir.
    if not os.path.isdir(folder_path):
        print(f"Error: Folder does not exist: {folder_path}", file=sys.stderr)
        sys.exit(1)

    return {
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    }

def move_files_to_new_folder(files, ori_folder, new_folder):
    """Move each named file from ``ori_folder`` into ``new_folder``.

    ``files`` holds bare file names relative to ``ori_folder``; each keeps its
    name in the destination. ``new_folder`` (and any missing parents) is
    created first. A one-line summary is printed when done.
    """
    os.makedirs(new_folder, exist_ok=True)
    for name in files:
        source_path = os.path.join(ori_folder, name)
        target_path = os.path.join(new_folder, name)
        shutil.move(source_path, target_path)
    print(f"Moved {len(files)} files from {ori_folder} to {new_folder}")


def main():
    """Report file-name overlap between the train and test folders.

    Prints the size of each folder and lists every file name present in both.
    Overlapping files are moved out of the test folder into
    ``<test folder>_overlapping``. Finally, it prints how many names are
    unique to each folder.
    """
    # only need to modify the base folder, since postfix is fixed for current data
    # MODIFY THIS: Set your base folder path
    base_folder = "<YOUR_SFT_QA_DATA_DIR>/pubmed_sft_qa_data_v2_"
    # constant settings for current data
    train_folder_postfix = "run8"
    test_folder_postfix = "2025_October"
    folder_train = base_folder + train_folder_postfix
    folder_test = base_folder + test_folder_postfix
    folder_test_overlapping = folder_test + "_overlapping"

    print(f"Folder train: {folder_train}")
    print(f"Folder test: {folder_test}")
    print("-" * 80)

    # Get files from both folders (exits the process if either is missing)
    files_train = get_files_in_folder(folder_train)
    files_test = get_files_in_folder(folder_test)

    print(f"Number of files in Folder train: {len(files_train)}")
    print(f"Number of files in Folder test: {len(files_test)}")
    print("-" * 80)

    # Find overlapping files (same "year_pmid.json" name in both folders)
    overlapping_files = files_train.intersection(files_test)

    print(f"Number of overlapping files: {len(overlapping_files)}")
    print("-" * 80)

    if overlapping_files:
        print("Overlapping files:")
        for i, filename in enumerate(sorted(overlapping_files), 1):
            print(f"  {i}. {filename}")
    else:
        print("No overlapping files found.")

    print("-" * 80)

    # Move overlapping files out of the test folder. Skip this when nothing
    # overlaps, so a clean run does not create an empty "_overlapping" folder.
    if overlapping_files:
        move_files_to_new_folder(overlapping_files, folder_test, folder_test_overlapping)

    # Also show files unique to each folder. These are computed from the
    # listings taken before the move, so the test count equals what remains
    # in the test folder afterwards.
    unique_to_folder_train = files_train - files_test
    unique_to_folder_test = files_test - files_train

    print(f"Files unique to Folder train: {len(unique_to_folder_train)}")
    print(f"Files unique to Folder test: {len(unique_to_folder_test)}")



if __name__ == "__main__":
    main()

