import os
import argparse

# ---------------------------------------------------------------------------
# Default values for the command-line options declared in
# ``Parser.set_arguments``: each constant is passed as the ``default=`` of the
# matching option, so it applies whenever that flag is not given explicitly.
# ---------------------------------------------------------------------------

# Experiment setting
TASK_POOL          = 7 # 7 (CIFAR100-superclass-overlapped)
FEDERATED          = True
CONTINUAL          = True

# Sparse communication between clients and server
SPARSE_COMM        = False
SERVER_SPARSE_COMM = False
CLIENT_SPARSITY    = 0.7 # if 0.7 then client sends 30% of local weights
SERVER_SPARSITY    = 0.7 # if 0.7 then server sends 30% of global weights

# Model and training schedule
MODEL              = 3 # 3(FedWeIT)
BASE_ARCHITECT     = 0 # 0(LeNet)
FED_METHOD         = 1 # 1(FedPlox) -- sic; NOTE(review): presumably FedProx, confirm
NUM_CLIENTS        = 1
FRAC_CLIENTS       = 1 # fraction of clients per round
NUM_ROUNDS         = 5
NUM_EPOCHS         = 1
NUM_EPCOHS         = NUM_EPOCHS # misspelled legacy name, kept so existing references keep working
BATCH_SIZE         = 100

# Data locations
DATA_DIR           = 'data/tasks'
MIXTURE_DIR        = 'data/mixture_loader'

# Networking and checkpointing
HOST_PORT          = 5023
SAVE_WEIGHTS       = False
LOAD_WEIGHTS       = False
LOAD_WEIGHTS_DIR   = ''

# Runtime / worker selection
GPU_DEFAULT        = '-1'     # NOTE(review): presumably '-1' means no GPU (CPU run); confirm with the consumer
WORKER_TYPE        = 'server' # worker type: server, client or data
MANUAL_TASKS       = []       # default for --manual-tasks (task names to experiment with)

class Parser:
	"""Command-line parser for experiment settings.

	Every option's default is one of the module-level constants of this file,
	so running without flags reproduces those defaults.
	"""

	def __init__(self):
		self.parser = argparse.ArgumentParser()
		self.set_arguments()

	@staticmethod
	def _str2bool(value):
		"""Convert a command-line string to ``bool``.

		``type=bool`` must not be used for argparse flags: ``bool('False')`` is
		``True``, so any non-empty string would switch the option on.

		Args:
			value: the raw option string; a ``bool`` is returned unchanged.

		Returns:
			``True`` or ``False``.

		Raises:
			argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
				spelling (argparse reports this as a usage error).
		"""
		if isinstance(value, bool):
			return value
		text = str(value).strip().lower()
		if text in ('true', 't', 'yes', 'y', '1'):
			return True
		if text in ('false', 'f', 'no', 'n', '0'):
			return False
		raise argparse.ArgumentTypeError(
			'expected a boolean (true/false, yes/no, 1/0), got {!r}'.format(value))

	def set_arguments(self):
		"""Register all supported options on ``self.parser``."""
		to_bool = self._str2bool
		self.parser.add_argument('-g', '--gpu',          type=str,     default=GPU_DEFAULT,        help='gpu id')
		self.parser.add_argument('-w', '--worker-type',  type=str,     default=WORKER_TYPE,        help='worker type (server, client, data)')
		# Copy so that mutating the parsed list never alters the module constant.
		self.parser.add_argument('-t', '--manual-tasks', type=str,     default=list(MANUAL_TASKS), help='task name to experiment', nargs='*')
		self.parser.add_argument('--federated',          type=to_bool, default=FEDERATED,          help='federated learning, otherwise continual learning')
		self.parser.add_argument('--continual',          type=to_bool, default=CONTINUAL,          help='continual learning, otherwise single task learning')
		self.parser.add_argument('--sparse-comm',        type=to_bool, default=SPARSE_COMM,        help='sparse communication')
		self.parser.add_argument('--server-sparse-comm', type=to_bool, default=SERVER_SPARSE_COMM, help='server-side sparse communication')
		self.parser.add_argument('--client-sparsity',    type=float,   default=CLIENT_SPARSITY,    help='sparsity ratio of client-side communication')
		self.parser.add_argument('--sparse-broad-rate',  type=float,   default=SERVER_SPARSITY,    help='sparsity ratio of broadcasting')
		self.parser.add_argument('--model',              type=int,     default=MODEL,              help='model: 0(L2T), 1(APD), 2(ABC), 3(FedWeIT)')
		self.parser.add_argument('--base-architect',     type=int,     default=BASE_ARCHITECT,     help='architecture: 0(LeNet), 1(AlexNet)')
		self.parser.add_argument('--task-pool',          type=int,     default=TASK_POOL,          help='task pool: 0 (Hetero-8), 1 (NonIID-58), 2 (Overlapped-50), 7 (CIFAR100-superclass-overlapped)')
		self.parser.add_argument('--fed-method',         type=int,     default=FED_METHOD,         help='0(FedAvg), 1(FedPlox)')
		self.parser.add_argument('--num-clients',        type=int,     default=NUM_CLIENTS,        help='number of clients')
		self.parser.add_argument('--frac-clients',       type=float,   default=FRAC_CLIENTS,       help='fraction of clients per round')
		self.parser.add_argument('--num-rounds',         type=int,     default=NUM_ROUNDS,         help='number of rounds')
		self.parser.add_argument('--num-epochs',         type=int,     default=NUM_EPCOHS,         help='number of epochs')
		self.parser.add_argument('--batch-size',         type=int,     default=BATCH_SIZE,         help='batch size')
		self.parser.add_argument('--data-dir',           type=str,     default=DATA_DIR,           help='data path')
		self.parser.add_argument('--mixture-dir',        type=str,     default=MIXTURE_DIR,        help='mixture data path')
		self.parser.add_argument('--load-weights-dir',   type=str,     default=LOAD_WEIGHTS_DIR,   help='load weights path')
		self.parser.add_argument('--save-weights',       type=to_bool, default=SAVE_WEIGHTS,       help='True, if want to save global weights')
		self.parser.add_argument('--load-weights',       type=to_bool, default=LOAD_WEIGHTS,       help='True, if want to load global weights')
		self.parser.add_argument('--host-port',          type=int,     default=HOST_PORT,          help='host port num')

	def parse(self, argv=None):
		"""Parse the command line and return the resulting namespace.

		Args:
			argv: list of argument strings to parse; ``None`` (the default)
				makes argparse read ``sys.argv[1:]``.

		Returns:
			argparse.Namespace with one attribute per registered option.

		Raises:
			SystemExit: if any argument is not recognised, or if argparse
				itself rejects the command line.
		"""
		args, unparsed = self.parser.parse_known_args(argv)
		if unparsed:
			raise SystemExit('Unknown argument: {}'.format(unparsed))
		return args
