import argparse
from misc.utils import *

class Parser:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.set_arguments()

    def set_arguments(self):

        self.parser.add_argument('--gpu', type=str, default='0')
        self.parser.add_argument('--seed', type=int, default=1234)
        self.parser.add_argument('--model', type=str, default='distilbert-base-multilingual-cased', help='backbone pre-trained language models')
        self.parser.add_argument('--method', type=str, default='fedavg', help='fedavg, lora or')
        self.parser.add_argument('--dataset', type=str, default='20_newsgroups') # ag_news, 20_newsgroups, multi_sent
        self.parser.add_argument('--base-path', type=str, default='../')
        self.parser.add_argument('--adapter', type=str, default='lora', help='lora, hyper, or p-tuning')

        self.parser.add_argument('--quantize', action='store_true', help='Enable quantization')
        self.parser.add_argument('--random_quantize', action='store_true', help='Enable random quantization for each client')

        self.parser.add_argument('--rank', type=int, default=16, help='Rank for LoRA')
        self.parser.add_argument('--random_rank', action='store_true', help='Enable random LoRA rank for each client')
        self.parser.add_argument('--first_stage_epochs', type=int, default=40, help='Number of epochs for the first stage of LoRA fine-tuning')

        self.parser.add_argument('--n-workers', type=int, default=1)
        self.parser.add_argument('--n-clients', type=int, default=10)
        self.parser.add_argument('--n-rnds', type=int, default=60)
        self.parser.add_argument('--n-eps', type=int, default=1)
        self.parser.add_argument('--frac', type=float, default=1.0)
        self.parser.add_argument('--lr', type=float, default=2e-3)
        self.parser.add_argument('--fft', type=int, default=0, help='Full Fine Tuning epoch')

        self.parser.add_argument('--agg-norm', type=str, default='exp', choices=['cosine', 'exp'])
        self.parser.add_argument('--norm-scale', type=float, default=10)

        self.parser.add_argument('--print', type=int, default=1)
        self.parser.add_argument('--debug', action='store_true')
        self.parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training')
        self.parser.add_argument('--datadir', type=str, required=False, default="./data", help="Data directory")
        self.parser.add_argument('--share-data', type=int, default=1, help='share batch for VAE')

    def parse(self):
        args, unparsed = self.parser.parse_known_args()
        if len(unparsed) != 0:
            raise SystemExit('Unknown argument: {}'.format(unparsed))
        return args
