import argparse
from misc.utils import *

class Parser:
    
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.set_arguments()
       
    def set_arguments(self):
        
        self.parser.add_argument('--gpu', type=str, default='0')
        self.parser.add_argument('--seed', type=int, default=1234)

        self.parser.add_argument('--model', type=str, default=None)
        self.parser.add_argument('--task', type=str, default=None)
        
        self.parser.add_argument('--backbone', type=str, default='GNN')
        self.parser.add_argument('--n-workers', type=int, default=None)
        self.parser.add_argument('--n-clients', type=int, default=None)
        self.parser.add_argument('--n-rnds', type=int, default=None)
        self.parser.add_argument('--n-eps', type=int, default=None)
        self.parser.add_argument('--frac', type=float, default=None)
        self.parser.add_argument('--n-dims', type=int, default=128)
        self.parser.add_argument('--lr', type=float, default=None)

        self.parser.add_argument('--folder', type=str, default='')
        self.parser.add_argument('--trial', type=str, default=None)
        self.parser.add_argument('--base-path', type=str, default='../')

        self.parser.add_argument('--optm-agg', action='store_true')
        self.parser.add_argument('--eval-global', action='store_true')
        self.parser.add_argument('--no-clsf-mask', action='store_true')
        self.parser.add_argument('--laye-mask-one', action='store_true')
        self.parser.add_argument('--clsf-mask-one', action='store_true')

        self.parser.add_argument('--agg-norm', type=str, default='exp', choices=['cosine', 'exp'])
        self.parser.add_argument('--norm-scale', type=float, default=10)   # scaling factor for exp normalization
        self.parser.add_argument('--n-proxy', type=int, default=5)
        self.parser.add_argument('--cluster', action='store_true')

        self.parser.add_argument('--l1', type=float, default=1e-3)
        self.parser.add_argument('--loc-l2', type=float, default=1e-3)

        self.parser.add_argument('--mask-aggr', action='store_true')
        self.parser.add_argument('--mask-rank', type=int, default=-1)
        self.parser.add_argument('--mask-drop', action='store_true')
        self.parser.add_argument('--mask-drop-ratio', type=float, default=0.5)
        self.parser.add_argument('--mask-noise', action='store_true')

        self.parser.add_argument('--debug', action='store_true')

    def parse(self):
        args, unparsed  = self.parser.parse_known_args()
        if len(unparsed) != 0:
            raise SystemExit('Unknown argument: {}'.format(unparsed))
        return args
