import argparse


def _str_to_bool(value):
    """Convert a command-line token into a bool.

    Accepts the usual spellings (case-insensitive, surrounding whitespace
    ignored). Unrecognised text raises ``argparse.ArgumentTypeError`` so that
    argparse reports a usage error instead of silently treating a typo as
    ``False``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}')


def parse_arguments(cfg, argv=None):
    """
    Parse command line arguments and override default configuration values.

    Every overridable option defaults to the corresponding entry of ``cfg``,
    so options that are not given on the command line keep their configured
    value. Keys of ``cfg`` that have no command-line option are passed through
    untouched. Each value that differs from its default is reported on stdout.

    Args:
        cfg (dict): Default configuration dictionary. Must contain the keys
            'noise_mean', 'noise_sd', 'adaptation_flag', 'adapt_loss_mode',
            'highest_moment', 'lambda', 'job_index', 'dataset_name' and
            'random_seed'.
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        dict: Shallow copy of ``cfg`` with command-line overrides applied;
        ``cfg`` itself is not modified.

    Raises:
        KeyError: If ``cfg`` lacks one of the required keys.
        SystemExit: Raised by argparse on invalid arguments or ``-h``.
    """
    parser = argparse.ArgumentParser(description='Model training with configurable parameters')

    # Add arguments for user-configurable parameters with both short and long forms.
    # '%(default)s' lets argparse render the default itself, which stays safe
    # even when a default value contains a '%' character.
    parser.add_argument('-nm', '--noise_mean', type=float, default=cfg['noise_mean'],
                        help='Mean of Gaussian noise (default: %(default)s)')

    parser.add_argument('-ns', '--noise_sd', type=float, default=cfg['noise_sd'],
                        help='Standard deviation of Gaussian noise (default: %(default)s)')

    parser.add_argument('-af', '--adaptation_flag', type=_str_to_bool,
                        default=cfg['adaptation_flag'],
                        help='Enable adaptation (default: %(default)s)')

    parser.add_argument('-alm', '--adapt_loss_mode', type=str,
                        choices=['coral', 'geo_adapt', 'mmd', 'homm', 'log_coral', 'ddc', 'cmd'],
                        default=cfg['adapt_loss_mode'],
                        help='Adaptation loss mode (default: %(default)s)')

    parser.add_argument('-hm', '--highest_moment', type=int, default=cfg['highest_moment'],
                        help='Highest moment for adaptation (default: %(default)s)')

    # 'lambda' is a Python keyword, so the parsed value is stored under another name.
    parser.add_argument('-l', '--lambda', type=float, dest='lambda_tmp', default=cfg['lambda'],
                        help='Lambda value for loss weighting (default: %(default)s)')

    parser.add_argument('-ji', '--job_index', type=int, default=cfg['job_index'],
                        help='Job index (default: %(default)s)')

    parser.add_argument('-d', '--dataset_name', type=str, default=cfg['dataset_name'],
                        help='Dataset name (default: %(default)s)')

    parser.add_argument('-rs', '--random_seed', type=int, default=cfg['random_seed'],
                        help='Random seed for reproducible results (default: %(default)s)')

    args = parser.parse_args(argv)

    # Keys are listed in the order in which overrides are reported.
    overridable_keys = ('noise_mean', 'noise_sd', 'adaptation_flag', 'adapt_loss_mode',
                        'dataset_name', 'highest_moment', 'lambda', 'job_index', 'random_seed')

    # Work on a copy so the caller's defaults stay untouched.
    updated_cfg = cfg.copy()
    for key in overridable_keys:
        dest = 'lambda_tmp' if key == 'lambda' else key
        updated_cfg[key] = getattr(args, dest)
        if updated_cfg[key] != cfg[key]:
            print(f"Command-line override: {key} = {updated_cfg[key]} (default: {cfg[key]})")

    return updated_cfg


# Usage:
# Define the default configuration first, then pass it to parse_arguments()
# to apply command-line overrides, e.g. `cfg = parse_arguments(default_cfg)`.
default_cfg = {
    # --- Data loading ---
    # Dataset to use: mnist, fashion_mnist
    "dataset_name": "mnist",
    # None or e.g., [0, 1, 7]
    "selected_labels": None,
    # None, int (e.g., 1000), or list (e.g., [500, 800]) corresponding to
    # selected_labels
    "num_samples_per_label": None,

    # --- Splitting ---
    # e.g., 0.2 means 20% target, 80% source
    "target_split_ratio": 0.5,
    # For reproducible split (e.g. 0, 1, 42, or any int)
    "random_seed": 1,

    # --- Noise addition (target only) ---
    # Noise is drawn from N(mean, std_dev^2)
    "noise_mean": 0.4,
    "noise_sd": 0.7,

    # --- General ---
    "job_index": 1,
    "adaptation_flag": True,
    "adapt_loss_mode": "geo_adapt",
    "geo_adapt_metric": "airm",
    "highest_moment": 2,
    "num_epochs": 200,
    "batch_size": 128,
    "lr": 0.0002,
    "lambda": 0.1,
    "every_nth_epoch": 1,
    "initial_epochs": 0,

    # --- Visualization ---
    # Number of images to show in plots
    "num_visualize": 40,
}