import nibabel as nib
import numpy as np
import os
import sys
import time
import multiprocessing as mp
from functools import partial
import glob

# Mapping from the 138 source labels to the 6 target structures.
# Keys are target names; each value lists the source label ids that belong
# to that target. Ids 0, 18, 33 and 34 are not listed under any target.
label_mapping = {
    'Ventricles': [1, 2, 21, 22, 23, 24],
    'Hippocampus': [19, 20],
    'Entorhinal': [55, 56],
    'Fusiform': [61, 62, 93, 94],
    'MidTemp': [89, 90],
    'WholeBrain': [
        41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 57, 58, 59, 60, 
        63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 
        81, 82, 83, 84, 85, 86, 87, 88, 91, 92, 95, 96, 97, 98, 99, 100, 101, 102, 
        103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 
        118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 
        133, 134, 135, 136, 137, 138, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 
        16, 17, 25, 26, 27, 28, 29, 30, 31, 32, 35, 36, 37, 38, 39, 40
    ]
}

# Build the lookup table that converts source labels to target codes.
def create_label_mapping_table(mapping=None, max_label=138):
    """Build a lookup table translating source label ids to target codes.

    Args:
        mapping: Dict with the keys 'WholeBrain', 'Ventricles', 'Hippocampus',
            'Entorhinal', 'Fusiform' and 'MidTemp', each holding an iterable
            of integer source label ids. Defaults to the module-level
            ``label_mapping``.
        max_label: Largest source label id the table can be indexed with.

    Returns:
        A ``uint8`` array of length ``max_label + 1`` where ``table[label]``
        is the target code (1-6), or 0 for labels that belong to no target.

    Raises:
        KeyError: If ``mapping`` lacks one of the six target names.
    """
    if mapping is None:
        mapping = label_mapping

    # Ascending priority: a label listed under several targets ends up with
    # the code of the last target in this sequence.
    targets_by_priority = (
        ('WholeBrain', 1),
        ('Ventricles', 2),
        ('Hippocampus', 3),
        ('Entorhinal', 4),
        ('Fusiform', 5),
        ('MidTemp', 6),
    )

    mapping_table = np.zeros(max_label + 1, dtype=np.uint8)
    for target_name, target_code in targets_by_priority:
        for label in mapping[target_name]:
            # Ignore ids outside the table. The lower bound matters: a
            # negative id would silently index from the end of the array.
            if 0 <= label <= max_label:
                mapping_table[label] = target_code

    return mapping_table


# Process a single file.
def process_file(mask_path, label_mapping_table):
    """Remap one label volume through the lookup table and save the result.

    The output is written next to the input, with ``_6targets`` inserted
    before the ``.nii`` / ``.nii.gz`` extension.

    Args:
        mask_path: Path of the input label image (``.nii`` or ``.nii.gz``).
        label_mapping_table: 1-D integer lookup table indexed by source
            label. Voxels whose value lies outside ``[0, len(table))`` are
            written as 0.

    Returns:
        ``(mask_path, output_path, elapsed_seconds, None)`` on success, or
        ``(mask_path, None, 0, error_message)`` on failure. Exceptions are
        reported through the tuple rather than raised, so one bad file does
        not stop the caller from processing the rest.
    """
    try:
        start_time = time.time()

        # Work out the output path first so an unsupported name fails before
        # the volume is loaded. Only the trailing extension is rewritten;
        # a '.nii' appearing elsewhere in the name is left alone.
        dir_name, base_name = os.path.split(mask_path)
        for suffix in ('.nii.gz', '.nii'):
            if base_name.endswith(suffix):
                new_base = base_name[:-len(suffix)] + '_6targets' + suffix
                break
        else:
            # Without a recognised extension there is no safe output name
            # (reusing the input name would overwrite the source file).
            raise ValueError(f"unsupported file extension: {base_name}")
        output_path = os.path.join(dir_name, new_base)

        img = nib.load(mask_path)
        # Read the voxels without the float64 conversion get_fdata() forces;
        # label volumes are normally stored as integers.
        data = np.asanyarray(img.dataobj)
        table = np.asarray(label_mapping_table)

        # Vectorised lookup; voxels outside the table range stay 0.
        new_mask = np.zeros(data.shape, dtype=np.uint8)
        valid_mask = (data >= 0) & (data < len(table))
        new_mask[valid_mask] = table[data[valid_mask].astype(int)]

        # Reuse the source geometry, but declare uint8 as the on-disk type:
        # a header passed to Nifti1Image keeps its original data dtype.
        header = img.header.copy()
        header.set_data_dtype(np.uint8)
        new_img = nib.Nifti1Image(new_mask, img.affine, header=header)
        nib.save(new_img, output_path)

        return (mask_path, output_path, time.time() - start_time, None)
    except Exception as e:
        return (mask_path, None, 0, str(e))

# Main entry point.
def main(root_dir, num_processes):
    """Remap every matching label image under ``root_dir`` in parallel.

    Files named ``MALPEM-ADNI*.nii*`` are searched recursively. Files that
    already carry the ``_6targets`` output suffix are skipped, so running the
    script twice on the same tree does not remap its own results.

    Args:
        root_dir: Directory to search recursively.
        num_processes: Number of worker processes for the pool.

    Returns:
        None. Progress, per-file errors and a final summary are printed.
    """
    label_mapping_table = create_label_mapping_table()

    # Find all candidate files, leaving out outputs of a previous run (they
    # match the same glob pattern as the inputs).
    file_pattern = os.path.join(root_dir, '**', 'MALPEM-ADNI*.nii*')
    generated_suffixes = ('_6targets.nii', '_6targets.nii.gz')
    nii_files = [
        path
        for path in glob.glob(file_pattern, recursive=True)
        if not path.endswith(generated_suffixes)
    ]
    total_files = len(nii_files)
    print(f"找到 {total_files} 个待处理文件，使用 {num_processes} 个进程并行处理")

    start_time = time.time()
    results = []
    # Do not start worker processes when there is nothing to process.
    if nii_files:
        with mp.Pool(processes=num_processes) as pool:
            # Bind the lookup table so workers only receive the file path.
            worker = partial(process_file, label_mapping_table=label_mapping_table)
            for i, result in enumerate(pool.imap_unordered(worker, nii_files)):
                results.append(result)
                if result[3]:
                    print(f"错误: {result[0]} -> {result[3]}")
                if (i + 1) % 100 == 0:
                    print(f"已处理 {i+1}/{total_files} 文件")

    # Summarise the run.
    total_time = time.time() - start_time
    success_count = sum(1 for r in results if r[1] is not None)
    error_count = total_files - success_count
    # Avoid dividing by zero when no file matched.
    average_time = total_time / total_files if total_files else 0.0

    print("\n" + "="*60)
    print(f"处理完成! 成功: {success_count}, 失败: {error_count}")
    print(f"总耗时: {total_time:.2f}秒 | 平均每文件: {average_time:.4f}秒")
    print("="*60)

if __name__ == '__main__':
    # Exactly one command-line argument is expected: the root directory.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("用法: python process_labels_parallel.py <根目录>")
        sys.exit(1)

    root_dir = cli_args[0]
    # Use 75% of the CPU cores (rounded down, at least one) so the machine
    # is not fully saturated.
    scaled_core_count = int(mp.cpu_count() * 0.75)
    num_processes = scaled_core_count if scaled_core_count > 1 else 1
    main(root_dir, num_processes)