import os
import argparse
import numpy as np
import torch
from scipy import ndimage
from tqdm import tqdm
import json # <-- 1. 导入json库，用于序列化args

def _resolve_npz_path(save_path):
    """Return the file name ``np.savez`` really writes for *save_path*.

    ``np.savez`` appends ``.npz`` to string paths that do not already end
    with it, so the name reported to the user must be derived the same way.
    """
    return save_path if save_path.endswith('.npz') else save_path + '.npz'


def _gradient_scores(fields):
    """Return the gradient magnitude of every grid point, flattened per sample.

    Args:
        fields: Array indexed as (sample, row, col).

    Returns:
        Array of shape (n_samples, rows * cols) holding
        ``sqrt(d/dcol ** 2 + d/drow ** 2)`` for each point.
    """
    # Differentiating along axes 1 and 2 of the whole stack yields the same
    # values as calling np.gradient on each 2-D sample in a Python loop.
    grad_y, grad_x = np.gradient(fields, axis=(1, 2))
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    return magnitude.reshape(fields.shape[0], -1)


def prepare_masks():
    """Generate and save complex/simple region masks for the Pipe test set.

    Command-line arguments are read from ``sys.argv``. The ground-truth file
    ``Pipe_Q.npy`` is loaded from ``--data_path``; the last 200 samples are
    kept and downsampled. Each grid point is scored by the norm of the
    ground-truth field's gradient. Points scoring above the global
    ``100 - complex_threshold_percent`` percentile form the "complex" mask and
    points scoring below the global ``simple_threshold_percent`` percentile
    form the "simple" mask.

    The two boolean masks, each of shape (n_samples, points_per_sample), are
    written to one ``.npz`` archive together with the parsed arguments
    serialized as a JSON string (a plain string array, not a pickled object).

    Raises:
        FileNotFoundError: If ``Pipe_Q.npy`` is missing from ``--data_path``.
        SystemExit: Raised by argparse for invalid command-line arguments.
    """
    parser = argparse.ArgumentParser('Mask Preparation for Transolver Pipe Experiment')

    # --- Input arguments ---
    parser.add_argument('--data_path', type=str, required=True,
                        help='包含Pipe .npy数据集的目录路径')
    # argparse %-formats help text, so a literal percent sign must be written
    # as "%%"; a bare "%" makes --help fail with a ValueError.
    parser.add_argument('--complex_threshold_percent', type=int, default=20,
                        help='梯度最高的Top N%%的点被定义为“复杂区域”')
    parser.add_argument('--simple_threshold_percent', type=int, default=50,
                        help='梯度最低的Bottom N%%的点被定义为“简单区域”')
    parser.add_argument('--save_path', type=str, default='./pipe_masks.json',
                        help='保存生成的掩码文件的路径')

    # --- Dataset arguments ---
    parser.add_argument('--downsamplex', type=int, default=1, help='X轴的下采样率')
    parser.add_argument('--downsampley', type=int, default=1, help='Y轴的下采样率')

    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside NumPy or silently
    # produce a reversed / empty grid.
    for option in ('complex_threshold_percent', 'simple_threshold_percent'):
        if not 0 <= getattr(args, option) <= 100:
            parser.error(f'--{option} must be between 0 and 100')
    for option in ('downsamplex', 'downsampley'):
        if getattr(args, option) < 1:
            parser.error(f'--{option} must be a positive integer')

    print("--- 1. Loading and Processing Test Data Ground Truth ---")
    truth_path = os.path.join(args.data_path, 'Pipe_Q.npy')
    if not os.path.exists(truth_path):
        raise FileNotFoundError(f"真值文件未找到: {truth_path}")

    ntest = 200  # number of trailing samples treated as the test split
    r1, r2 = args.downsamplex, args.downsampley
    # Grid extent after downsampling a 129-point axis.
    s1 = int(((129 - 1) / r1) + 1)
    s2 = int(((129 - 1) / r2) + 1)

    # Keep component 0 of axis 1 only.
    output = np.load(truth_path)[:, 0]
    # Cast to float32 so scores are computed in single precision, as before.
    y_test_np = output[-ntest:, ::r1, ::r2][:, :s1, :s2].astype(np.float32)
    print(f"Test data ground truth loaded. Shape: {y_test_np.shape}")

    print("\n--- 2. Calculating Gradient-based Importance Scores ---")
    scores_tensor = _gradient_scores(y_test_np)
    print(f"Scores tensor calculated. Shape: {scores_tensor.shape}")

    print("\n--- 3. Determining Global Thresholds ---")
    # Thresholds are global: one percentile over every point of every sample.
    flat_scores = scores_tensor.flatten()

    complex_thresh_val = np.percentile(flat_scores, 100 - args.complex_threshold_percent)
    simple_thresh_val = np.percentile(flat_scores, args.simple_threshold_percent)

    print(f"  -> Complex region threshold (scores > {complex_thresh_val:.4f}) corresponds to Top {args.complex_threshold_percent}%")
    print(f"  -> Simple region threshold (scores < {simple_thresh_val:.4f}) corresponds to Bottom {args.simple_threshold_percent}%")

    print("\n--- 4. Generating and Saving Masks ---")
    complex_masks = scores_tensor > complex_thresh_val
    simple_masks = scores_tensor < simple_thresh_val

    # Serialize the parsed arguments as JSON text so the archive stores a
    # plain string array rather than a pickled Python object.
    args_json_string = json.dumps(vars(args), indent=4)

    # np.savez writes an uncompressed .npz archive and appends ".npz" when the
    # suffix is missing; resolve the real name first so the message is correct.
    saved_path = _resolve_npz_path(args.save_path)
    np.savez(
        saved_path,
        complex_masks=complex_masks,
        simple_masks=simple_masks,
        args=args_json_string,
    )
    print(f"Masks successfully saved to '{saved_path}'")

    # Print summary statistics for a quick sanity check.
    total_points_per_sample = complex_masks.shape[1]
    avg_complex_points = np.mean(np.sum(complex_masks, axis=1))
    avg_simple_points = np.mean(np.sum(simple_masks, axis=1))
    print(f"  -> Average complex points per sample: {avg_complex_points:.1f} / {total_points_per_sample} (~{avg_complex_points/total_points_per_sample*100:.1f}%)")
    print(f"  -> Average simple points per sample: {avg_simple_points:.1f} / {total_points_per_sample} (~{avg_simple_points/total_points_per_sample*100:.1f}%)")

# Run the mask-generation pipeline only when this file is executed as a script.
if __name__ == '__main__':
    prepare_masks()