from PIL import Image
import numpy as np
import glob
import os
from multiprocessing import Pool, cpu_count
from functools import partial
import time

def process_single_case(k, file_list, output_dir, slice_count=304):
    """Stack the slice images of the k-th case into one volume and save it.

    The case consists of ``file_list[k*slice_count:(k+1)*slice_count]``. The
    stem of every file name is parsed as a 1-based slice number; the image is
    resized to width ``slice_count`` and height 256 and written to that
    position of a ``(slice_count, 256, slice_count)`` float array. Slices that
    are missing or fail to load stay zero. The array is min-max scaled to
    [0, 1] and saved as ``<output_dir>/<pid>.npy``, where ``pid`` is the name
    of the directory holding the case's first file.

    Args:
        k: Zero-based case number within ``file_list``.
        file_list: Image paths ordered so that every case occupies one
            consecutive run of ``slice_count`` entries.
        output_dir: Existing directory that receives the ``.npy`` file.
        slice_count: Number of slices per case; also the resize width.

    Returns:
        The patient ID (parent directory name) of the case.

    Raises:
        IndexError: If ``file_list`` has no entry at ``k * slice_count``.
        AssertionError: If a file of the case sits in a directory whose name
            differs from the first file's.
        ValueError: If a file name's stem is not an integer.
    """
    first = k * slice_count
    # os.path instead of split('/') so separator style does not matter.
    pid = os.path.basename(os.path.dirname(file_list[first]))
    case_files = file_list[first:first + slice_count]
    volume = np.zeros((slice_count, 256, slice_count))

    for img_path in case_files:
        folder, file_name = os.path.split(img_path)
        pid_ = os.path.basename(folder)
        if pid_ != pid:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(f"Patient ID mismatch: {pid} != {pid_}")

        idx = int(file_name.split('.')[0])
        if not 1 <= idx <= slice_count:
            # A slice number of 0 would otherwise wrap around to the last
            # slice through negative indexing and overwrite it silently.
            print(f"Error processing image {img_path}: "
                  f"slice number {idx} is outside 1..{slice_count}")
            continue

        # Best effort: a slice that cannot be read is reported and left zero.
        try:
            with Image.open(img_path) as img:
                resized = img.resize((slice_count, 256), Image.Resampling.LANCZOS)
                volume[idx - 1, :, :] = np.array(resized)
        except Exception as e:
            print(f"Error processing image {img_path}: {str(e)}")
            continue

    # Min-max scale to [0, 1]. A constant volume (e.g. no slice could be
    # loaded) has zero span; it is saved as all zeros instead of NaN.
    lowest = volume.min()
    span = volume.max() - lowest
    volume -= lowest
    if span > 0:
        volume /= span

    save_path = os.path.join(output_dir, f"{pid}.npy")
    np.save(save_path, volume)
    return pid

def process_dataset(img_dir, train_dir, val_dir, test_dir,
                    splits=None, slice_count=304):
    """Convert every case under ``img_dir`` into a volume, grouped by split.

    All ``.bmp`` files below ``img_dir`` are collected and sorted by path.
    The sorted list is treated as consecutive cases of ``slice_count`` files
    each, and every case number in a split's range is handed to
    ``process_single_case`` on a worker pool.

    Args:
        img_dir: Root directory that is searched recursively for ``.bmp`` files.
        train_dir: Output directory for the ``'train'`` split.
        val_dir: Output directory for the ``'val'`` split.
        test_dir: Output directory for the ``'test'`` split.
        splits: Optional mapping from split name (``'train'``, ``'val'``,
            ``'test'``) to a half-open ``(start, end)`` range of case
            numbers. Defaults to 0-120 / 120-134 / 134-200.
        slice_count: Number of files that make up one case.

    Returns:
        None. Results are written to the output directories.

    Raises:
        KeyError: If ``splits`` contains a name other than the three above.
    """
    output_dirs = {'train': train_dir, 'val': val_dir, 'test': test_dir}
    if splits is None:
        splits = {
            'train': (0, 120),
            'val': (120, 134),
            'test': (134, 200),
        }

    for d in output_dirs.values():
        os.makedirs(d, exist_ok=True)

    file_list = sorted(
        os.path.join(subdir, f)
        for subdir, _, files in os.walk(img_dir)
        for f in files
        if f.endswith('.bmp')
    )

    print(f"Total files found: {len(file_list)}")
    if not file_list:
        # Nothing to index into; stop before touching file_list[0].
        print(f"No .bmp files found under {img_dir!r}; nothing to process")
        return
    print(f"First file: {file_list[0]}")

    # Number of cases whose first file exists (ceiling division); a shorter
    # trailing case is still handed to the worker.
    available_cases = -(-len(file_list) // slice_count)

    num_processes = max(1, cpu_count() - 1)
    print(f"Using {num_processes} CPU cores")

    with Pool(num_processes) as pool:
        for split_name, (start, end) in splits.items():
            print(f"\nProcessing {split_name} set ({start} to {end})")
            output_dir = output_dirs[split_name]

            if end > available_cases:
                # Requesting a case past the end of file_list would make the
                # worker fail with IndexError partway through the split.
                print(f"Warning: only {available_cases} cases available; "
                      f"{split_name} set truncated to {start} to {available_cases}")
                end = available_cases

            process_func = partial(process_single_case,
                                   file_list=file_list,
                                   output_dir=output_dir,
                                   slice_count=slice_count)

            results = pool.map(process_func, range(start, end))

            print(f"Completed {split_name} set, processed cases: {len(results)}")

if __name__ == '__main__':
    # Input image tree and the three output folders. All are blank here and
    # have to be filled in before the script is run; ``save_dir`` is passed
    # as the training output directory.
    img_dir = save_dir = val_dir = test_dir = ""

    # Wall-clock timing of the whole conversion.
    started_at = time.time()
    process_dataset(img_dir, save_dir, val_dir, test_dir)
    print(f"\nTotal processing time: {time.time() - started_at:.2f} seconds")