from yacs.config import CfgNode as CN
from utils.utils import log_msg


def simplify_cfg(args, cfg):
    """Build a reduced config node holding only the sections used by this run.

    The returned node contains:

    * ``DATASET`` and ``OPTIMIZER``;
    * ``cfg[args.method]`` and ``cfg[args.task]``;
    * ``Sever[<global_method>]`` and ``Local[<local_method>]`` of the selected
      method, each only when that name is a key of ``cfg['Sever']`` /
      ``cfg['Local']`` respectively;
    * when ``args.attack_type`` is not the string ``'None'``: the shared
      ``attack.bad_client_rate`` / ``attack.noise_data_rate`` values plus
      ``attack[args.attack_type]``.

    Sub-nodes are assigned by reference, not copied, so the returned node
    shares them with ``cfg``.

    Args:
        args: object exposing ``method``, ``task`` and ``attack_type``
            attributes (presumably an ``argparse.Namespace``).
        cfg: root config node, expected to have the layout of ``CFG``
            defined in this module.

    Returns:
        A new ``CN`` holding the selected sections.
    """
    dump_cfg = CN()
    dump_cfg.DATASET = cfg.DATASET
    dump_cfg.OPTIMIZER = cfg.OPTIMIZER
    dump_cfg[args.method] = cfg[args.method]
    dump_cfg[args.task] = cfg[args.task]

    # Keep only the Sever/Local entry actually used by the selected method,
    # not every registered one. `in` is applied to the node directly instead
    # of materialising list(keys()) for each membership test.
    method_cfg = cfg[args.method]
    for section, name in (('Sever', method_cfg.global_method),
                          ('Local', method_cfg.local_method)):
        if name in cfg[section]:
            dump_cfg[section] = CN()
            dump_cfg[section][name] = cfg[section][name]

    # 'None' is compared as a string: the attack type arrives as text.
    if args.attack_type != 'None':
        attack_cfg = cfg['attack']
        dump_cfg['attack'] = CN()
        dump_cfg['attack'].bad_client_rate = attack_cfg.bad_client_rate
        dump_cfg['attack'].noise_data_rate = attack_cfg.noise_data_rate
        dump_cfg['attack'][args.attack_type] = attack_cfg[args.attack_type]

    return dump_cfg


def show_cfg(args, cfg, method):
    """Print a trimmed view of ``cfg`` for ``method`` and return that view.

    The view holds ``DATASET``, ``OPTIMIZER`` and ``cfg[method]``. The whole
    ``attack`` section is added as well unless ``args.attack_type`` is the
    string ``'None'``. Sub-nodes are shared with ``cfg``, not copied.

    Args:
        args: object exposing an ``attack_type`` attribute.
        cfg: root config node, expected to have the layout of ``CFG``.
        method: key of the method section to include.

    Returns:
        The ``CN`` that was printed.
    """
    view = CN()
    view.DATASET, view.OPTIMIZER = cfg.DATASET, cfg.OPTIMIZER
    view[method] = cfg[method]

    with_attack = args.attack_type != 'None'
    if with_attack:
        view['attack'] = cfg['attack']

    rendered = view.dump()
    message = log_msg("CONFIG:\n{}".format(rendered), "INFO")
    print(message)
    return view


# Root configuration node; every section below hangs off it.
CFG = CN()

# ---------------------------------------------------------------------------
# Federated dataset
# ---------------------------------------------------------------------------
CFG.DATASET = CN({
    "dataset": "fl_cifar10",
    "communication_epoch": 2,
    "n_classes": 10,
    "parti_num": 4,
    "online_ratio": 1.0,
    "domain_ratio": 1.0,
    "train_eval_domain_ratio": 0.01,
    "backbone": "resnet18",
    "pretrained": False,
    "aug": "weak",
    "beta": 0.5,
})

# ---------------------------------------------------------------------------
# Attack settings. bad_client_rate / noise_data_rate are shared by all attack
# types; the per-type sub-node is selected by name (args.attack_type).
# ---------------------------------------------------------------------------
CFG.attack = CN({
    "bad_client_rate": 0.2,
    "noise_data_rate": 0.5,
    "byzantine": {
        # evils options: PairFlip SymFlip RandomNoise lie_attack min_max min_sum
        "evils": "PairFlip",
        "dataset_type": "single_domain",
        # Parameters used by the min_max and min_sum attacks.
        "dev_type": "std",
        "lamda": 10.0,
        "threshold_diff": 1e-5,
    },
    "backdoor": {
        # evils options: base_backdoor semantic_backdoor
        "evils": "base_backdoor",
        "backdoor_label": 2,
        # trigger_position and trigger_value are parallel lists (12 entries
        # each); the meaning of each [a, b, c] triple is defined by the
        # backdoor code (presumably channel/row/col — confirm there).
        "trigger_position": [
            [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 5], [0, 0, 6],
            [0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 4], [0, 2, 5], [0, 2, 6],
        ],
        "trigger_value": [1] * 12,
        "semantic_backdoor_label": 3,
    },
})

# ---------------------------------------------------------------------------
# Tasks: each task is a top-level section selected as cfg[args.task].
# ---------------------------------------------------------------------------
CFG.label_skew = CN()
CFG.domain_skew = CN()
CFG.OOD = CN({
    # Domain name consumed by the OOD task (presumably the held-out domain —
    # confirm in the task code). Valid names depend on the dataset:
    #   Digits:        'MNIST', 'USPS', 'SVHN', 'SYN'
    #   PACS:          'photo', 'art_painting', 'cartoon', 'sketch'
    #   OfficeCaltech: 'caltech', 'amazon', 'webcam', 'dslr'
    #   OfficeHome:    'Art', 'Clipart', 'Product', 'Real_World'
    #   DomainNet:     'clipart', 'infograph', 'painting', 'quickdraw',
    #                  'real', 'sketch'
    #   Office31:      'amazon', 'dslr', 'webcam'
    "out_domain": "caltech",
})

# ---------------------------------------------------------------------------
# Federated optimizer settings
# ---------------------------------------------------------------------------
CFG.OPTIMIZER = CN({
    "type": "SGD",
    "momentum": 0.9,
    "weight_decay": 1e-5,
    "local_epoch": 2,
    "local_train_batch": 64,
    "local_test_batch": 64,
    "val_batch": 64,
    "local_train_lr": 1e-3,
})

# ---------------------------------------------------------------------------
# Per-implementation settings, keyed by the names a method lists in its
# global_method (-> Sever) and local_method (-> Local) fields.
# NOTE: the section key is spelled 'Sever' and simplify_cfg looks it up under
# that exact spelling; renaming it here alone would break that lookup.
# ---------------------------------------------------------------------------
CFG.Sever = CN()

CFG.Local = CN({
    "DKDRLocal": {
        "tau": 1.0,
        "beta": 1.0,
    },
})

# ---------------------------------------------------------------------------
# Federated methods. Each one names its local_method and global_method; when
# a name has an entry in CFG.Local / CFG.Sever, simplify_cfg dumps that entry
# alongside the method.
# ---------------------------------------------------------------------------
CFG.qffeAVG = CN({
    "local_method": "qffeAVGLocal",
    "global_method": "qffeAVGSever",
})

CFG.FedAVG = CN({
    "local_method": "BaseLocal",
    "global_method": "BaseSever",
})

CFG.DKDR = CN({
    "local_method": "DKDRLocal",
    "global_method": "DKDRSever",
})