import argparse
from argparse import RawTextHelpFormatter

def _build_parser():
    """Create the argument parser holding every experiment option."""
    parser = argparse.ArgumentParser(description='Federated Ravan', formatter_class=RawTextHelpFormatter)

    # System params
    parser.add_argument('--seed', type=int, default=0,
        help='Torch random seed')
    parser.add_argument('--verbose', type=int, default=1,
        help='Print or log output')
    parser.add_argument('--wandb-logging', type=int, default=0,
        help='Whether or not to log results on wandb')
    parser.add_argument('--wandb-project', type=str, default='',
        help='Name to use for wandb project logging')

    # Dataset params
    parser.add_argument('--dataset', type=str, default='cifar10',
        help='Which dataset to use')
    parser.add_argument('--clients', type=int, default=20,
        help='Number of client devices to use in FL training')
    # NOTE: the default is the int -1 (argparse does not run non-string
    # defaults through `type`), so the default run name contains "iid=-1"
    # while an explicit "--iid-alpha -1" yields "iid=-1.0".
    parser.add_argument('--iid-alpha', type=float, default=-1,
        help='Level of heterogeneity to introduce across client devices')

    # LoRA params
    parser.add_argument('--adaptation-method', type=str, default='lora',
        help='Which adaptation method to use for fine-tuning')
    parser.add_argument('--rank', type=int, default=16,
        help='LoRA Rank')
    parser.add_argument('--alpha', type=int, default=16,
        help='Alpha parameter for LoRA adapter')
    parser.add_argument('--b-var', type=float, default=0.0,
        help='Variance for normally initialized B parameter')
    parser.add_argument('--a-var', type=float, default=0.02,
        help='Variance for normally initialized A parameter')
    parser.add_argument('--r-var', type=float, default=0.02,
        help='Variance for normally distributed R parameter')
    parser.add_argument('--num-heads', type=int, default=4,
        help='Number of heads in Ravan')
    parser.add_argument('--het-ranks', type=int, default=0,
        help='Whether or not we are training with heterogeneous ranks')
    parser.add_argument('--het-dist', type=str, default='uniform',
        help='If training with heterogeneous ranks, rank distribution across clients')
    parser.add_argument('--ranking', type=str, default='random',
        help='How to rank the heads in Ravan if performing some kind of partial freezing')
    parser.add_argument('--init-scheme', type=str, default='random_normal',
        help='How to initialize ravan parameters')

    # Server training params
    parser.add_argument('--server-lr', type=float, default=5e-3,
        help='Learning rate at server for aggregation')
    parser.add_argument('--comm-rounds', type=int, default=100,
        help='Number of communication rounds')
    parser.add_argument('--clients-round', type=int, default=3,
        help='Number of clients per communication round')
    parser.add_argument('--aggregation-method', type=str, default='FedAvg',
        help='Aggregation method at server')

    # Client training params
    parser.add_argument('--optimizer', type=str, default='sgd',
        help='Which optimizer (ADAM or SGD) to use for training')
    parser.add_argument('--client-lr', type=float, default=1e-3,
        help='Learning rate at client for local updates')
    parser.add_argument('--momentum', type=float, default=0.9,
        help='Momentum')
    parser.add_argument('--epochs', type=int, default=1,
        help='Number of local epochs to train for')
    parser.add_argument('--local-steps', type=int, default=0,
        help='Number of local steps to train for')
    parser.add_argument('--batch-size', type=int, default=16,
        help='Batch size for local client training')

    return parser


def _build_run_name(args):
    """Return the '_'-joined 'key=value' run name, or None for an unknown method."""
    method = args.adaptation_method

    # Substring matching, checked in this order: a method name containing
    # 'ravan' takes the Ravan format even if it also contains 'lora' or 'sb'.
    if 'ravan' in method:
        method_fields = [
            ('Rank', args.rank),
            ('BVar', args.b_var),
            ('RVar', args.r_var),
            ('AVar', args.a_var),
        ]
    elif 'sb' in method or 'lora' in method:
        method_fields = [('Rank', args.rank)]
    elif 'full_ft' in method:
        method_fields = []
    else:
        return None

    fields = [('Method', method)] + method_fields + [
        ('optimizer', args.optimizer),
        ('lr', args.client_lr),
        ('clients', args.clients),
        ('iid', args.iid_alpha),
        ('agg', args.aggregation_method),
        ('seed', args.seed),
    ]
    return '_'.join(key + '=' + str(value) for key, value in fields)


def get_args(argv=None):
    """Parse the experiment's command-line options and derive a run name.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which is what a plain
            ``get_args()`` call has always done.

    Returns:
        A tuple ``(args, out_str)``. ``args`` is the parsed
        ``argparse.Namespace``. ``out_str`` is a string of '_'-joined
        ``key=value`` pairs describing the run, whose fields depend on
        ``--adaptation-method``; it is None when the method name contains
        none of 'ravan', 'sb', 'lora' or 'full_ft'.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    args = _build_parser().parse_args(argv)
    return args, _build_run_name(args)