
import argparse

def _positive_int(text):
    """Parse *text* as an integer strictly greater than zero (argparse ``type=`` hook)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def _positive_float(text):
    """Parse *text* as a finite float strictly greater than zero (argparse ``type=`` hook)."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive number, got {text!r}") from None
    # The chained comparison also rejects NaN (all comparisons are False) and +inf.
    if not (0.0 < value < float("inf")):
        raise argparse.ArgumentTypeError(f"expected a finite positive number, got {text!r}")
    return value


def _unit_interval_float(text):
    """Parse *text* as a float in the closed interval [0, 1] (argparse ``type=`` hook)."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number in [0, 1], got {text!r}") from None
    # The chained comparison also rejects NaN, since every comparison with NaN is False.
    if not (0.0 <= value <= 1.0):
        raise argparse.ArgumentTypeError(f"expected a number in [0, 1], got {text!r}")
    return value


def get_args(argv=None):
    """Parse the FD-LoRA training configuration from the command line.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, matching
            the original behaviour. Pass an explicit list to configure runs
            programmatically (tests, notebooks, sweeps).

    Returns:
        argparse.Namespace: One attribute per option below (``epochs``,
        ``rounds``, ``batch_size``, ``lr``, ``alpha``, ``temperature``,
        ``lora_r``, ``lora_alpha``, ``task``, ``n_clients``, ``non_iid``,
        ``device``, ``seed``, ``log_dir``).

    Raises:
        SystemExit: Raised by argparse on ``--help`` or when a value fails
            validation (e.g. ``--alpha`` outside [0, 1], a non-positive
            ``--temperature``/``--lr``, or a non-positive count).
    """
    parser = argparse.ArgumentParser(description="FD-LoRA Training Configuration")

    # Training settings
    parser.add_argument('--epochs', type=_positive_int, default=3, help='Number of training epochs per round')
    parser.add_argument('--rounds', type=_positive_int, default=10, help='Number of communication rounds')
    parser.add_argument('--batch_size', type=_positive_int, default=16, help='Training batch size')
    parser.add_argument('--lr', type=_positive_float, default=2e-5, help='Learning rate')
    parser.add_argument('--alpha', type=_unit_interval_float, default=0.5, help='Distillation weight (0-1)')
    # The temperature is presumably used as a divisor on logits downstream, so it must be finite and > 0.
    parser.add_argument('--temperature', type=_positive_float, default=2.0, help='Distillation temperature')

    # LoRA settings
    parser.add_argument('--lora_r', type=_positive_int, default=8, help='LoRA rank')
    parser.add_argument('--lora_alpha', type=int, default=16, help='LoRA scaling factor')

    # Dataset
    parser.add_argument('--task', type=str, default='mnli', choices=['mnli', 'sst2', 'qqp', 'qnli'], help='Task name')
    parser.add_argument('--n_clients', type=_positive_int, default=3, help='Number of simulated clients')
    parser.add_argument('--non_iid', type=str, default='mild', choices=['iid', 'mild', 'severe'], help='Data partition type')

    # System
    # Left free-form (no `choices`) so device strings such as "cuda:1" remain accepted.
    parser.add_argument('--device', type=str, default='cuda', help='Device to use: "cuda" or "cpu"')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Logging directory')

    return parser.parse_args(argv)
