
import argparse
import pickle
import json
import os
import re
from nhird_dataset import NHIRDDataset
from tqdm import tqdm
import math
from torch.utils.data import  Subset

def partition_indices(dataset_len, num_chunks, chunk_idx):
    """Return the dataset indices that belong to one contiguous chunk.

    The dataset is cut into ``num_chunks`` consecutive chunks of
    ``ceil(dataset_len / num_chunks)`` items each; the final chunks may be
    shorter or empty when the length does not divide evenly.

    Args:
        dataset_len: Total number of items in the dataset.
        num_chunks: Number of chunks the dataset is split into (must be > 0).
        chunk_idx: Zero-based index of the chunk to return.

    Returns:
        A list of consecutive integer indices (possibly empty).

    Raises:
        ValueError: If ``num_chunks`` is not positive or ``chunk_idx`` is
            outside ``[0, num_chunks)``.
    """
    if num_chunks <= 0:
        raise ValueError(f"num_chunks must be positive, got {num_chunks}")
    if not 0 <= chunk_idx < num_chunks:
        raise ValueError(
            f"chunk_idx must be in [0, {num_chunks}), got {chunk_idx}"
        )
    partition_size = math.ceil(dataset_len / num_chunks)
    start = chunk_idx * partition_size
    # Clamp to the dataset length; when start is already past the end the
    # resulting range is simply empty.
    stop = min(start + partition_size, dataset_len)
    return list(range(start, stop))


# Example usage:
#   python <this_script> --num_chunks 20 --chunk_idx 0
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--patient_file_path", type=str, default="/data/nhird_tasks_cancer_low_risk/cancer_screening_dataset_lung_1.json/*" )
    parser.add_argument("--ids_file_path", type=str, default="/data/cancer_lung_low_risk")
    parser.add_argument("--num_chunks", type=int, default=20)
    parser.add_argument("--chunk_idx", type=int, default=0)
    args = parser.parse_args()

    # Load the dataset
    nhird = NHIRDDataset(args.patient_file_path)
    dataset_len = len(nhird)
    print("Total patients: ", dataset_len)
    os.makedirs(args.ids_file_path, exist_ok=True)

    # Only compute the indices of the requested chunk; report bad
    # --num_chunks / --chunk_idx values as a normal CLI usage error.
    try:
        chunk_indices = partition_indices(dataset_len, args.num_chunks, args.chunk_idx)
    except ValueError as err:
        parser.error(str(err))
    subset = Subset(nhird, chunk_indices)

    # Materialise every patient of this chunk so it can be pickled as a list.
    partition_data = list(
        tqdm(subset, desc=f"Generate patients from partition {args.chunk_idx}: ")
    )

    # Join with the directory created above; plain string concatenation would
    # drop the path separator and write the file next to the directory.
    output_path = os.path.join(args.ids_file_path, f"partition{args.chunk_idx}.pkl")
    with open(output_path, 'wb') as pkl_file:
        pickle.dump(partition_data, pkl_file)
    # Report success only after the file has been flushed and closed.
    print(f"Partition{args.chunk_idx} is saved")
