# config.py

import argparse
import torch

def _build_parser():
    """Build the argparse parser that declares every hyperparameter of the project."""
    parser = argparse.ArgumentParser(description="CLAGR Optimizer Official Implementation")

    # ================== Basic and Path Parameters ==================
    parser.add_argument('--output_dir', type=str, default='logs', help='Directory to save logs and models.')
    parser.add_argument('--datadir', type=str, default='./datasets', help='Path to the dataset directory.')
    parser.add_argument('--seed', type=int, default=42, help='Global random seed.')
    parser.add_argument('--log_freq', type=int, default=10, help="Frequency of printing training information.")

    # ================== Training Process Parameters ==================
    parser.add_argument('--epochs', type=int, default=200, help="Total number of training epochs.")
    parser.add_argument('--batch_size', type=int, default=128, help="Batch size for training and validation.")
    parser.add_argument('--num_workers', type=int, default=4, help="Number of worker threads for the dataloader.")
    parser.add_argument('--pin_memory', action='store_true', default=True,
                        help="Whether to use pinned memory (enabled by default; disable with --no_pin_memory).")
    # A `store_true` flag whose default is already True can never be switched
    # off on its own, so expose an explicit opt-out writing to the same dest.
    parser.add_argument('--no_pin_memory', dest='pin_memory', action='store_false',
                        help="Disable pinned memory.")
    # Evaluated once, when the parser is built: prefer CUDA if it is available.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help="Training device.")
    parser.add_argument('--noise_percentage', type=int, default=0, help="Percentage of label noise to add to the training data (0-100).")

    # ================== Model Parameters ==================
    parser.add_argument('--model', type=str, default='resnet18', help="The model to use (e.g., resnet18, vit_ti, wideresnet28x10).")

    # ================== Base Optimizer Parameters (e.g., SGD) ==================
    parser.add_argument('--base_optimizer', type=str, default='sgd', help='The base optimizer (e.g., sgd, adam).')
    parser.add_argument('--lr', type=float, default=0.05, help="Learning rate for the base optimizer.")
    parser.add_argument('--weight_decay', type=float, default=5e-4, help="Weight decay.")
    parser.add_argument('--momentum', type=float, default=0.9, help="Momentum for SGD.")

    # ================== CLAGR Optimizer Specific Parameters ==================
    parser.add_argument('--optimizer', type=str, default='clagr', help='The outer optimizer wrapping the base optimizer (e.g., sam, clagr).')
    parser.add_argument('--rho', type=float, default=0.1, help="Radius of the perturbation neighborhood (rho).")
    parser.add_argument('--inner_step', type=int, default=1, help="Number of inner optimization steps (K).")
    parser.add_argument('--cr_lambda', type=float, default=0.1, help="Coefficient for the gradient correction term (lambda_cr).")
    parser.add_argument('--lmomentum', type=float, default=0.8, help="Momentum for the lookahead mechanism (alpha_la).")
    parser.add_argument('--gamma_interp', type=float, default=0.9, help="Coefficient for gradient direction fusion (gamma_interp).")
    parser.add_argument('--beta_start', type=float, default=0.9, help="Starting coefficient for directional momentum interpolation (beta_start).")
    parser.add_argument('--beta_end', type=float, default=0.99, help="Ending coefficient for directional momentum interpolation (beta_end).")

    # ================== Learning Rate Scheduler Parameters ==================
    parser.add_argument('--lr_scheduler', type=str, default='CosineAnnealingLR', help='Type of learning rate scheduler.')
    # Float default: argparse's `type=float` only converts command-line strings,
    # so an int default would otherwise reach callers unchanged.
    parser.add_argument('--eta_min', type=float, default=0.0, help="Minimum learning rate for CosineAnnealingLR.")

    return parser


def get_args(argv=None):
    """Parse and return all hyperparameters of the project.

    Args:
        argv: Optional list of command-line tokens to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list lets tests and notebooks build a config programmatically.

    Returns:
        argparse.Namespace with one attribute per option declared in
        ``_build_parser``.

    Raises:
        SystemExit: raised by argparse for unknown or malformed options, or
            when ``--noise_percentage`` lies outside [0, 100].
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # The help text promises a percentage; reject values that cannot be one.
    if not 0 <= args.noise_percentage <= 100:
        parser.error(f"--noise_percentage must be in [0, 100], got {args.noise_percentage}")
    return args