#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse

def get_args(argv=None):
    """Build the training/benchmarking command-line parser and parse arguments.

    Options are organised into six argument groups: core & environment,
    dataset & loading, training process, autoregressive training, model
    architecture, and benchmarking.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, preserving normal
            command-line behaviour. Passing an explicit list lets the parser
            be driven from code or tests.

    Returns:
        argparse.Namespace: The parsed arguments. Dashes in option names are
        exposed as underscores (e.g. ``--batch-size`` -> ``args.batch_size``).

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="GNOT and variants: Training and Benchmarking for Operator Learning")

    # NOTE: option names mix dash and underscore styles; both are kept as-is
    # because existing scripts may already pass them on the command line.

    # --- 1. Core & Environment ---
    core_group = parser.add_argument_group('Core & Environment Settings')
    core_group.add_argument('--seed', type=int, default=2023, help='Random seed for reproducibility.')
    core_group.add_argument('--gpu', type=int, default=0, help='GPU ID to use.')
    core_group.add_argument('--comment', type=str, default="", help="A custom comment for the experiment run.")
    core_group.add_argument('--use-tb', action='store_true', help='Enable TensorBoard logging.')
    core_group.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')

    # --- 2. Dataset & Loading ---
    data_group = parser.add_argument_group('Dataset & Loading Settings')
    data_group.add_argument('--dataset', type=str, default='ns2d_autoregressive',
                            choices=['ns2d_autoregressive', 'kf2d', 'rd', 'pipe2d',
                                     'heat2d', 'ns2d', 'inductor2d', 'airfoil2d'],
                            help="Name of the dataset to use.")
    # Parsed as str so that the literal 'all' is accepted alongside integers;
    # conversion to int presumably happens in the data-loading code (verify).
    data_group.add_argument('--train-num', type=str, default='all',
                            help="Number of training samples (time steps) to use. Can be 'all' or an integer.")
    data_group.add_argument('--test-num', type=str, default='all',
                            help="Number of test samples (time steps) to use. Can be 'all' or an integer.")
    data_group.add_argument('--num_train_sims', type=int, default=1000,
                            help='Number of simulations to use for the training set.')
    data_group.add_argument('--num_test_sims', type=int, default=200,
                            help='Number of simulations to use for the test set.')
    data_group.add_argument('--timesteps_per_sim', type=int, default=20,
                            help='Number of time steps per simulation. Used to calculate sim boundaries.')
    data_group.add_argument('--normalize_x', type=str, default='unit', choices=['none', 'unit'],
                            help="Normalization method for input coordinates and parameters.")
    data_group.add_argument('--use-normalizer', type=str, default='unit', choices=['none', 'unit'],
                            help="Normalization method for the state/solution field (y).")
    data_group.add_argument('--num_workers', type=int, default=0, help='Number of worker processes for DataLoader.')
    # NOTE(review): meaning of --sort-data is not visible in this file; it looks
    # like an integer on/off switch (default 0) — confirm against the data loader.
    data_group.add_argument('--sort-data', type=int, default=0)

    # --- 3. Training Process ---
    train_group = parser.add_argument_group('Training Process Settings')
    train_group.add_argument('--epochs', type=int, default=150, help='Number of epochs to train.')
    train_group.add_argument('--batch-size', type=int, default=8, help='Input batch size for training.')
    train_group.add_argument('--val-batch-size', type=int, default=10, help='Input batch size for validation.')
    train_group.add_argument('--optimizer', type=str, default='AdamW', choices=['Adam', 'AdamW'], help="Optimizer to use.")
    train_group.add_argument('--lr', type=float, default=1e-3, help='Max learning rate for the optimizer.')
    train_group.add_argument('--weight-decay', type=float, default=1e-5, help='Weight decay for the optimizer.')
    train_group.add_argument('--grad-clip', type=float, default=1.0, help='Gradient clipping value.')
    train_group.add_argument('--lr-method', type=str, default='cycle', choices=['cycle', 'step', 'warmup'],
                             help="Learning rate schedule method.")
    train_group.add_argument('--amp', action='store_true', help="Enable Automatic Mixed Precision (AMP) training.")
    train_group.add_argument('--validation_freq', type=int, default=1, help="Frequency of validation (in epochs).")
    train_group.add_argument('--loss-name', type=str, default='rel2', choices=['rel2', 'rel1', 'l2', 'l1'],
                             help="Loss function to use for training.")
    train_group.add_argument('--component', type=str, default='all-reduce', help="Component to compute loss on.")
    # NOTE(review): semantics of --hfourier-dim are not visible here (0 by
    # default, possibly meaning "disabled") — confirm against the model code.
    train_group.add_argument('--hfourier-dim', type=int, default=0)
    train_group.add_argument('--ponder_loss_weight', type=float, default=0.01,
                             help='Weight for the Ponder Cost auxiliary loss for GNOT_StaticDepth.')
    train_group.add_argument('--target_average_depth', type=float, default=2.5,
                             help='The target average computation depth for GNOT_StaticDepth.')

    # --- 4. Autoregressive Training ---
    ar_group = parser.add_argument_group('Autoregressive Training Settings')
    ar_group.add_argument('--unroll_steps', type=int, default=5,
                          help='Number of steps to unroll for multi-step loss during training.')
    # Number of steps to unroll during validation/testing.
    ar_group.add_argument('--val_unroll_steps', type=int, default=10,
                          help='验证/测试时展开的步数。')
    ar_group.add_argument('--segment_length', type=int, default=15,
                          help='Total length of the continuous time-series segment sampled from the dataset.')
    ar_group.add_argument('--tf_start', type=float, default=1.0,
                          help='Initial teacher forcing rate at the beginning of training.')
    ar_group.add_argument('--tf_end', type=float, default=0.0,
                          help='Final teacher forcing rate at the end of the decay period.')
    ar_group.add_argument('--tf_decay_epochs', type=int, default=50,
                          help='Number of epochs over which to decay the teacher forcing rate.')

    # --- 5. Model Architecture ---
    model_group = parser.add_argument_group('Model Architecture Settings')
    model_group.add_argument('--model-name', type=str, default='SR_GNOT',
                             choices=['CGPT', 'GNOT', 'SR_GNOT', 'StructuredRecursiveGNOT', 'SR_GNOT_SS', 'GNOT_StaticDepth'],
                             help="Name of the model architecture to use.")
    model_group.add_argument('--n-hidden', type=int, default=64, help='Hidden dimension size.')
    model_group.add_argument('--n-layers', type=int, default=8, help='Number of layers or recursion depth.')
    model_group.add_argument('--n-head', type=int, default=1, help='Number of attention heads.')
    model_group.add_argument('--act', type=str, default='gelu', choices=['gelu', 'relu', 'tanh', 'sigmoid'],
                             help="Activation function.")
    model_group.add_argument('--ffn-dropout', type=float, default=0.0, help='Dropout for the FFN in attention blocks.')
    model_group.add_argument('--attn-dropout', type=float, default=0.0, help='Dropout for the attention mechanism.')
    model_group.add_argument('--mlp-layers', type=int, default=2, help='Number of layers in MLP blocks.')
    model_group.add_argument('--attn-type', type=str, default='linear', help="Type of attention mechanism.")
    # GNOT / SR_GNOT specific
    model_group.add_argument('--n-experts', type=int, default=2, help='Number of experts in MoE layers.')
    model_group.add_argument('--n-inner', type=int, default=4, help='Factor for inner dimension in FFNs (n_inner * n_hidden).')
    # SR_GNOT specific
    model_group.add_argument('--final_keep_ratio', type=float, default=0.25,
                             help='For SR_GNOT, final ratio of tokens to keep in the last recursion layer.')
    # Values are parsed as floats (type=float), so the help text says floats.
    model_group.add_argument('--capacity_ratios', type=float, nargs='+', default=None,
                             help="For SR_GNOT, a list of floats specifying the capacity ratio for each recursion layer.")

    # This will be set dynamically in data_utils_fix.py, but having a default is good practice.
    model_group.add_argument('--space_dim', type=int, default=2, help='Spatial dimension of the coordinates (e.g., 1 for 1D, 2 for 2D).')

    # --- 6. Benchmarking ---
    benchmark_group = parser.add_argument_group('Benchmarking Settings')
    benchmark_group.add_argument('--model_paths', type=str, nargs='+', default=None,
                                 help="[Benchmark] Path(s) to trained model checkpoint(s) (.pt) to evaluate.")
    benchmark_group.add_argument('--rollout_length', type=int, default=10,
                                 help="[Benchmark] Number of autoregressive steps to perform during inference.")
    benchmark_group.add_argument('--save_path', type=str, default=None,
                                 help="[Benchmark] Path to save the detailed benchmark results as a JSON file.")
    benchmark_group.add_argument('--speed_test_runs', type=int, default=3,
                                 help="[Benchmark] Number of times to run the speed test for stable results.")
    benchmark_group.add_argument('--benchmark_batch_size', type=int, default=32,
                                 help="[Benchmark] Batch size to use during inference/benchmarking.")
    benchmark_group.add_argument('--test_sims_indices', type=int, nargs='*', default=None,
                                 help="[Benchmark] Specific simulation indices from the test set to run the benchmark on.")

    # argv=None -> argparse falls back to sys.argv[1:], matching the old behaviour.
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry: parse the command line, then dump every option in
    # alphabetical order so the effective configuration is easy to inspect.
    parsed = vars(get_args())
    print("--- Parsed Arguments ---")
    for name in sorted(parsed):
        print(f"{name}: {parsed[name]}")
    print("-" * 24)