import argparse
import sys


def _str2bool(v):
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        val = v.lower()
        if val in {"true", "1", "yes", "y"}:
            return True
        if val in {"false", "0", "no", "n"}:
            return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def args_parser(argv=None):
    """Build the experiment CLI, parse it, and normalise list-valued options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour.

    Returns:
        argparse.Namespace: the parsed options. ``backdoor_client_idx`` and
        ``unlearning_client`` are converted from comma-separated strings into
        lists of ints. If exactly one of the two is given on the command line,
        the other is set to a copy of it. If neither is given, both are ``[0]``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()

    # NOTE: boolean-valued options use ``_str2bool`` rather than ``type=bool``.
    # ``bool("False")`` is True, so ``type=bool`` made such flags impossible
    # to switch off from the command line.

    ### initial ###
    parser.add_argument('--device_id', type=int, default=0, help='The Device Id for Experiment')
    parser.add_argument('--device', type=int, default=0, help='The Device Id for Experiment')
    parser.add_argument('--target', type=str, default='retrain',
                        choices=['learning', 'retrain', 'rapid_retrain', 'federaser', 'increase_loss',
                                 'class_pruning', 'fedsalun', 'fu_dws'],
                        help='target method')
    parser.add_argument('--dataset', type=str, default='domain_digits', help="name of dataset domain_digits",
                        choices=['domain_digits', 'office-caltech10', 'DomainNet', 'PACS'])
    parser.add_argument('--dataset_fullparti', type=_str2bool, default=True, help="if data full participant")
    # 450 for domain_digits and 5 for office-caltech10
    parser.add_argument('--n_train', type=int, default=450, help="num of train set")
    parser.add_argument('--iid', type=_str2bool, default=True, help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    parser.add_argument('--percent', type=float, default=1, help='percentage of dataset to train')
    parser.add_argument('--load', type=str, default="test", help="load")
    parser.add_argument('--save', type=str, default='rec', help="save")
    parser.add_argument('--pre_train', type=int, default=0, help="save code")
    parser.add_argument('--evals', type=_str2bool, default=False)
    parser.add_argument('--domain_skew', type=_str2bool, default=True)
    parser.add_argument('--domain_skew_ratio', type=float, default=10)
    parser.add_argument('--freeze_ulr', type=str, default='neglabel', choices=['neglabel', 'increaloss'],
                        help='freeze_ulr')
    parser.add_argument('--fedfrz_epoch', type=int, default=10)
    parser.add_argument('--frzulr', type=float, default=0.01)
    parser.add_argument('--freeze_layers', type=str, default='conv1,conv2,conv3,fc1,fc2',
                        help='Comma-separated layer names to freeze')

    ### FL settings ###
    # 10 for domain_digits
    parser.add_argument('--epochs', type=int, default=10, help="num of communication rounds")
    parser.add_argument('--method', type=str, default='fedavg',
                        help='federated learning algorithm (default: fedavg)')
    parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs: E")
    parser.add_argument('--local_ep_nlp', type=int, default=50, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=64, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.9, help="SGD momentum (default: 0.9)")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--domain_split_factor', type=int, default=1,
                        help='split each domain into this many clients')
    parser.add_argument('--domain_split_config', type=str, default='',
                        help='comma separated client counts for each domain, e.g. "1,2,3"')
    parser.add_argument('--bkd_domain_idx', type=int, default=12345,
                        help='domain index for backdoor (12345 disables)')
    parser.add_argument('--domain_times_factor', type=int, default=1,
                        help='domain-specific scaling factor')
    parser.add_argument('--frac', type=float, default=1, help="the fraction of clients: C")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    ### model arguments ###
    parser.add_argument('--model', type=str, default='cnn',
                        choices=['cnn', 'vgg16', 'resnet18', 'resnet50', 'vit', 'mobilevit', 'mae_vit'],
                        help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    # NOTE(review): kept as a string ('True'); downstream code presumably
    # compares it to a string — confirm before converting to a bool.
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")
    parser.add_argument('--grad_accumulation', type=int, default=1,
                        help="Gradient accumulation steps")
    parser.add_argument('--use_cuda_graph', type=_str2bool, default=True,
                        help="Enable CUDA Graph optimization")
    parser.add_argument('--vgg_depth', type=int, default=16, choices=[11, 13, 16, 19],
                        help='VGG depth configuration')
    parser.add_argument('--fc_dropout', type=float, default=0.5, help='FC layer dropout ratio')

    # verify
    parser.add_argument('--verify', type=str, default='normal', choices=['normal', 'marker', 'backdoor'],
                        help="Verification Methods")
    parser.add_argument('--backdoor_target_label', type=int, default=9, help="backdoor_target_label rounds")
    # Default None is a sentinel that means "not given on the command line".
    # It is resolved to '0' after parsing.
    parser.add_argument('--backdoor_client_idx', type=str, default=None,
                        help='comma separated backdoor client ids (default: 0)')

    parser.add_argument('--record_forget_event', action='store_true', help="record forget event")
    parser.add_argument('--proto', type=_str2bool, default=False,
                        help='Use prototype-based domain updates during local training (True/False)')
    parser.add_argument('--unlearn', action='store_true', help='unlearning_client')
    parser.add_argument('--sal_th', type=float, default=0.5, help="sal_threshold")

    parser.add_argument('--out', type=str, default='', help="save code")

    parser.add_argument('--remove_sample', action='store_true')
    parser.add_argument('--remain', type=int, default=0, help="save code")

    parser.add_argument('--step_size', help='rate weight step', type=float, default=0.2)
    parser.add_argument("--fair", type=str, default='acc', choices=['acc', 'loss'],
                        help="the fairness metric for FedAvg")

    parser.add_argument('--change_client', type=int, default=-1, help='unlearning_client')
    parser.add_argument('--backdoor_percent_poison', type=float, default=0.5, help="backdoor_percent_poison")

    ### unlearning ###
    # Default None is a sentinel that means "not given on the command line".
    # It is resolved to '0' after parsing.
    parser.add_argument('--unlearning_client', type=str, default=None,
                        help='comma separated client ids for unlearning (e.g. "0,1"; default: 0)')
    parser.add_argument('--num_local_unlearn_epochs', type=int, default=10, help="rounds of bl1_local_unlearn_epochs")
    parser.add_argument('--baseline', type=int, default=-1, help='the unlearning baseline')
    # baseline1
    parser.add_argument('--clip_grad', type=int, default=5)
    # baseline2
    parser.add_argument('--unlearn_interval', type=int, default=1)
    # NOTE(review): this help text looks copy-pasted from a backdoor-client
    # option and may not describe this ratio — confirm.
    parser.add_argument('--forget_local_epoch_ratio', type=float, default=0.1,
                        help="backdoor client id 0:MNIST 1:svhn  2:usps 3:synth, 4:mnistm_")
    parser.add_argument('--unlearn_epoch', type=int, default=10, help="new epoch")
    # SalUn
    parser.add_argument('--unlearn_lr', type=float, default=0.1, help='unlearning rate')
    parser.add_argument('--mask_ratio', type=float, default=0.99, help='mask sparsity threshold ratio')
    parser.add_argument('--diff_mask_ratio', type=float, default=0.99,
                        help='Threshold ratio for diff importance mask (0-1 quantile, controls freezing)')
    parser.add_argument('--fedsalun_epoch', type=int, default=1, help="reserver learning epoch")
    parser.add_argument('--lamb', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.1, help='parameter consistency strength')
    parser.add_argument('--beta', type=float, default=0.05, help='KL divergence loss weight')
    parser.add_argument('--noise_type', type=str, default='gaussian',
                        choices=['gaussian', 'uniform', 'laplace', 'bernoulli', 'none'])
    parser.add_argument('--noise_scale', type=float, default=0.01,
                        help='noise strength coefficient (std or amplitude)')
    parser.add_argument('--tg_std', type=float, default=0.1)

    ### MAE ###
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use normalized pixel loss')
    parser.add_argument('--pretrained', action='store_true',
                        help='Load pretrained weights')
    parser.add_argument('--pretrained_path', type=str, default='mae_pretrain.pth',
                        help='pretrained weights path')
    parser.add_argument('--finetune', action='store_true',
                        help='Finetune mode')
    parser.add_argument('--normalize_globally', action='store_true',
                        help='Normalize scores globally across all layers/models instead of layer-wise')
    parser.add_argument('--act_gate_local', type=float, default=0.7)
    parser.add_argument('--act_gate_difference', type=float, default=0.7)
    parser.add_argument('--batch_size', type=int, default=32)

    args = parser.parse_args(argv)

    def _parse_list(val):
        """Turn a comma-separated string (or int / sequence) into a list of ints."""
        if isinstance(val, str):
            return [int(v) for v in val.split(',') if v.strip()]
        if isinstance(val, (list, tuple)):
            return [int(v) for v in val]
        return [int(val)]

    # Detect explicit use from the parsed value, not by scanning sys.argv.
    # The old scan missed the "--flag=value" form and argparse's prefix
    # abbreviations (e.g. "--unlearning_c 1").
    passed_bd = args.backdoor_client_idx is not None
    passed_ul = args.unlearning_client is not None
    args.backdoor_client_idx = _parse_list(args.backdoor_client_idx if passed_bd else '0')
    args.unlearning_client = _parse_list(args.unlearning_client if passed_ul else '0')
    args.bkd_domain_idx = int(args.bkd_domain_idx)
    args.domain_times_factor = int(args.domain_times_factor)

    # If only one of the two client lists was given, mirror it into the other.
    # Copy it so the two attributes never share one mutable list.
    if passed_bd and not passed_ul:
        args.unlearning_client = list(args.backdoor_client_idx)
    elif passed_ul and not passed_bd:
        args.backdoor_client_idx = list(args.unlearning_client)

    return args


