import argparse
import os

import yaml


def _str2bool(value):
    """Convert a boolean-like token (e.g. ``"True"``, ``"false"``, ``"1"``) to a bool.

    This replaces the previous ``type=eval``, which executed arbitrary Python
    expressions taken from the command line or from string values in the YAML
    config, and which rejected common lowercase spellings such as ``false``.

    Args:
        value: The raw token. A real ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            token, so argparse reports a clean usage error instead of a traceback.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # "none" is accepted because the old eval-based parser turned "None" into a falsy value.
    if token in {"false", "f", "no", "n", "0", "off", "none"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (True/False), got {value!r}")


def _load_yaml_config(path):
    """Load a YAML mapping of argument defaults from ``path``.

    Args:
        path: Path to the YAML file; an empty string means "no config file".

    Returns:
        A dict whose keys have dashes replaced by underscores, so that they
        match argparse ``dest`` names (e.g. ``max-pool`` -> ``max_pool``).
        Returns ``{}`` when no path is given, the file does not exist, it
        fails to parse, it is empty, or its top level is not a mapping.
        Each of these cases prints an INFO/WARNING/error message and falls
        back to defaults rather than aborting.
    """
    if not path:
        print("INFO: No configuration file specified. Using defaults and command-line args.")
        return {}
    # isfile (not exists) so that a directory path is reported instead of crashing in open().
    if not os.path.isfile(path):
        print(f"WARNING: Configuration file not found at: {path}. Using defaults and command-line args.")
        return {}

    print(f"INFO: Loading configuration from: {path}")
    with open(path, "r", encoding="utf-8") as f:
        try:
            loaded = yaml.safe_load(f)
        except yaml.YAMLError as e:
            print(f"Error parsing YAML file: {e}")
            return {}

    if loaded is None:  # safe_load returns None for an empty document.
        return {}
    if not isinstance(loaded, dict):
        # set_defaults(**config) needs a mapping; a list or scalar would raise TypeError.
        print(f"WARNING: Configuration file {path} does not contain a mapping; ignoring it.")
        return {}
    return {str(key).replace("-", "_"): value for key, value in loaded.items()}


def args_parser():
    """Parse command-line arguments and optionally load defaults from a YAML file.

    Resolution order (highest priority first):
      1. values given explicitly on the command line,
      2. values from the YAML file passed via ``--config``,
      3. the defaults hard-coded in the ``add_argument`` calls below.

    YAML keys that do not correspond to any argument are still attached to
    the returned namespace, because ``parser.set_defaults`` accepts arbitrary
    names. String values from YAML are passed through each argument's
    ``type`` converter, as argparse does for string defaults.

    Returns:
        argparse.Namespace: The fully resolved arguments, including ``config``.
    """
    # First pass: only pick out --config; everything else is left for the full parser.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config", type=str, default="", help="Path to the YAML configuration file.")
    pre_args, _ = pre_parser.parse_known_args()

    config = _load_yaml_config(pre_args.config)

    parser = argparse.ArgumentParser(
        description="Arguments for Byzantine-Robust Federated Learning",
        parents=[pre_parser],  # re-exposes --config so it appears in --help and the result.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Federated arguments
    parser.add_argument("--num_experiments", type=int, default=5, help="Number of experiments")
    parser.add_argument("--epochs", type=int, default=100, help="Rounds of training")
    parser.add_argument("--num_users", type=int, default=32, help="Number of users: K")
    parser.add_argument("--num_Chosenusers", type=int, default=16, help="Number of chosen users per round")
    parser.add_argument("--frac", type=float, default=0.1, help="The fraction of clients: C")
    parser.add_argument("--local_ep", type=int, default=4, help="The number of local epochs: E")
    parser.add_argument("--local_bs", type=int, default=64, help="Local batch size: B")
    parser.add_argument("--lr", type=float, default=0.05, help="Learning rate")
    # The formatter appends "(default: ...)" itself, so the help text must not hard-code it.
    parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
    parser.add_argument("--dirichlet_alpha", type=float, default=0.5, help="degree of non-i.i.d.")
    parser.add_argument("--num_items_train", type=int, default=2000, help="Number of data items per user for training")
    parser.add_argument(
        "--num_items_test", type=int, default=512, help="Number of data items per user for testing/estimation"
    )

    # Model arguments
    parser.add_argument("--model", type=str, default="cnn", help="Model name")
    parser.add_argument("--submodel", type=str, default="VGG11", help='NN submodel: "AlexNet", "VGG11","lenet"')
    parser.add_argument("--dataset", type=str, default="cifar", help="Name of dataset")
    parser.add_argument("--subdataset", type=str, default="mnist", help='Sub-dataset (e.g., "fashion" for mnist)')
    parser.add_argument("--kernel_num", type=int, default=9, help="Number of each kind of kernel")
    parser.add_argument(
        "--kernel_sizes", type=str, default="3,4,5", help="Comma-separated kernel size to use for convolution"
    )
    parser.add_argument("--norm", type=str, default="batch_norm", help="batch_norm, layer_norm, or None")
    parser.add_argument("--num_filters", type=int, default=32, help="Number of filters for conv nets")

    # Other arguments
    parser.add_argument("--num_classes", type=int, default=10, help="Number of classes")
    parser.add_argument("--num_channels", type=int, default=3, help="Number of channels of images")
    parser.add_argument("--gpu", type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument("--stopping_rounds", type=int, default=10, help="Rounds of early stopping")
    # Previous help text claimed "default: 1" although the actual default is 42.
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Attack and detection arguments
    parser.add_argument("--num_attackers", type=int, default=4, help="Number of attackers")
    parser.add_argument(
        "--attackway", type=str, default="minsum", help="Attack algorithm (e.g., minmax, minsum, lie, fang)"
    )
    parser.add_argument("--attacker_ability", type=str, default="Full", help="Attackers ability (e.g., Full, Part)")
    parser.add_argument(
        "--detection", type=str, default="TriGuardFL", help="Detection algorithm (e.g., YE, DeFL, FLTrust)"
    )
    parser.add_argument("--significance", type=float, default=0.01, help="Detection significance level")
    parser.add_argument("--cos_threshold", type=float, default=0.1, help="Cosine similarity threshold")
    parser.add_argument("--reputation_threshold", type=float, default=0.6, help="Reputation threshold")
    parser.add_argument("--discount", type=float, default=0.9, help="Reputation term")
    parser.add_argument(
        "--epochs_phase_2", type=int, default=20, help="Only good clients can be selected After this round"
    )

    # Boolean flags: take an explicit value (e.g. "--iid True"), parsed safely without eval.
    parser.add_argument("--iid", type=_str2bool, default=False, help="Use IID data distribution (True or False)")
    parser.add_argument("--verbose", type=_str2bool, default=False, help="Enable verbose printing")
    parser.add_argument(
        "--max-pool", type=_str2bool, default=False, help="Use max pooling instead of strided convolutions"
    )
    parser.add_argument("--attack", type=_str2bool, default=True, help="Enable Byzantine attacks")

    # Parser-level defaults (from YAML) override the argument-level defaults above,
    # while explicit command-line values still win over both.
    parser.set_defaults(**config)

    return parser.parse_args()
