import distributed
import torch


def parse_args(base_parser, args, namespace):
    parser = base_parser
    # General training params
    parser.add_argument('--num_clients', required=True, type=int)
    parser.add_argument('--batch_size', default=50, type=int)
    parser.add_argument('--acc_steps', default=4, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--iterations', default=15000, type=int)
    parser.add_argument('--lr', default=2e-3, type=float)
    parser.add_argument('--warmup_percent', default=0.02, type=float)
    parser.add_argument('--weight_decay', default=1e-3, type=float)
    parser.add_argument('--beta1', default=0.9, type=float)
    parser.add_argument('--beta2', default=0.95, type=float)
    parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none'])
    parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd'])
    parser.add_argument('--eval_freq', default=200, type=int)  # in iterations
    parser.add_argument('--results_base_folder', default="./exps", type=str)
    parser.add_argument('--grad_clip', default=0.0, type=float)  # default value is 1.0 in NanoGPT
    # Dataset params
    parser.add_argument('--dataset', default='wikitext',
                        choices=['wikitext', 'multi', 'fed_cc_news',
                                 'agnews_mixed', 'agnews_specific',
                                 'three_multi_specific', 'three_multi_mixed',
                                 'github_wiki_specific', 'github_wiki_mixed',
                                 'split_wiki_de', 'split_wiki_it', 'split_wiki_fr', 'split_wiki_en', 'wiki40b'])
    parser.add_argument('--vocab_size', default=50304, type=int)
    # Model params
    parser.add_argument('--model', default='lora', choices=['lora'])
    parser.add_argument('--use_pretrained', default='gpt2',
                        type=str)  # 'none', 'gpt2' or a path to the pretrained model
    parser.add_argument('--dropout', default=0.2, type=float)
    parser.add_argument('--n_head', default=12, type=int)
    parser.add_argument('--n_layer', default=24, type=int)  # depths in att + ff blocks
    parser.add_argument('--n_embd', default=768, type=int)  # embedding size / hidden size ...
    parser.add_argument('--sequence_length', default=512, type=int)
    parser.add_argument('--dtype', default=torch.float16, type=torch.dtype)
    parser.add_argument('--bias', default=False, type=bool)
    parser.add_argument('--no_compile', action='store_true')  # if true then model is not compiled
    # Distributed args
    parser.add_argument('--distributed_backend', default=None, type=str, required=False,
                        choices=distributed.registered_backends())  # distributed backend type
    # logging params (WandB)
    parser.add_argument('--wandb', action='store_true')  # whether to use wandb or not
    parser.add_argument('--wandb_project', default='none', type=str)
    parser.add_argument('--wandb_group', default='none', type=str)
    parser.add_argument('--wandb_run_prefix', default='none',
                        type=str)  # is added before the autogenerated experiment name
    # LoRA params
    parser.add_argument('--lora_rank', default=4, type=int)
    parser.add_argument('--lora_alpha', default=32., type=float)
    parser.add_argument('--lora_dropout', default=0.1, type=float)
    # LoRA config params
    parser.add_argument('--lora_mlp', action='store_true')
    parser.add_argument('--lora_causal_self_attention', action='store_true')
    parser.add_argument('--lora_freeze_all_non_lora', action='store_true')
    parser.add_argument('--lora_allow_embedding', action='store_true')
    # Trust scheme params
    parser.add_argument('--trust', type=str, default='none', help='none, dynamic, naive, dynamic-thresh')
    parser.add_argument('--trust_freq', type=int, default=1)
    parser.add_argument('--pretraining_rounds', type=int, default=0)
    parser.add_argument('--k', type=int, default=4)
    parser.add_argument('--w_lambda', type=float, default=1.0)

    return parser.parse_args(args, namespace)
