import argparse
from misc.utils import *

class Parser:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.set_arguments()
       
    def set_arguments(self):
        self.parser.add_argument('--gpu', type=str, default='0')
        self.parser.add_argument('--seed', type=int, default=1234)

        self.parser.add_argument('--model', type=str, default=None)
        self.parser.add_argument('--dataset', type=str, default=None)
        self.parser.add_argument('--mode', type=str, default='disjoint')
        self.parser.add_argument('--base-path', type=str, default='../')

        self.parser.add_argument('--n-workers', type=int, default=None)
        self.parser.add_argument('--n-clients', type=int, default=None)
        self.parser.add_argument('--n-rnds', type=int, default=None)
        self.parser.add_argument('--n-eps', type=int, default=None)
        self.parser.add_argument('--frac', type=float, default=1.0)
        self.parser.add_argument('--n-dims', type=int, default=128)
        self.parser.add_argument('--lr', type=float, default=0.01)

        # self.parser.add_argument('--agg-norm', type=str, default='exp', choices=['cosine', 'exp'])
        # self.parser.add_argument('--norm-scale', type=float, default=10)
        # self.parser.add_argument('--n-proxy', type=int, default=5)

        # self.parser.add_argument('--l1', type=float, default=1e-3)
        self.parser.add_argument('--loc-l2', type=float, default=0.001)
        self.parser.add_argument('--fname', type=str, default=None)

        self.parser.add_argument('--debug', action='store_true')
        self.parser.add_argument('--c', type=float, default=1.0)

        self.parser.add_argument('--optimizer', type=str, default='adam')
        self.parser.add_argument('--learnable_k', type=int, default=1) 
        self.parser.add_argument('--classifier', type=str, default='tan')
        self.parser.add_argument('--lss_func', type=str, default='CELoss')
        self.parser.add_argument('--rescale', type=int, default=0)
        self.parser.add_argument('--bc', type=float, default=0.7)

        self.parser.add_argument('--csv', type=int, default=1)
        self.parser.add_argument('--summary', type=int, default=1)
        self.parser.add_argument('--wandb', type=int, default=0)


    def parse(self):
        args, unparsed  = self.parser.parse_known_args()
        if len(unparsed) != 0:
            raise SystemExit('Unknown argument: {}'.format(unparsed))
        return args
