import argparse
import os
import random
import warnings

import numpy as np
import torch

from evaluates.MainTaskVFedCD import MainTaskVFedCD
from load.LoadConfigs import *
from load.LoadParty import load_vfedcd_parties
from utils.basic_functions import append_exp_res, compute_metrics

# Silence every Python warning for the whole process.
# NOTE(review): this also hides RuntimeWarnings (e.g. NumPy divide-by-zero);
# consider narrowing it to specific categories.
warnings.filterwarnings(action="ignore")

# Known attack names, grouped by attack family. The trailing notes presumably
# list the metrics reported for each family.
TARGETED_BACKDOOR = ['ReplacementBackdoor', 'ASB']  # main_acc, backdoor_acc
UNTARGETED_BACKDOOR = ['NoisyLabel', 'MissingFeature', 'NoisySample']  # main_acc
LABEL_INFERENCE = [
    'BatchLabelReconstruction',
    'DirectLabelScoring',
    'NormbasedScoring',
    'DirectionbasedScoring',
    'PassiveModelCompletion',
    'ActiveModelCompletion',
]
ATTRIBUTE_INFERENCE = ['AttributeInference']
FEATURE_INFERENCE = ['GenerativeRegressionNetwork', 'ResSFL', 'CAFE']


def set_seed(seed=0):
    """Seed every random number generator used by an experiment run.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED``, and configures cuDNN for deterministic
    execution so repeated runs with the same seed give the same results.

    Args:
        seed: Integer seed applied to every generator. Defaults to 0.
    """
    random.seed(seed)
    # Only influences interpreters started after this point (e.g. worker
    # subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all CUDA devices (covers the current one); ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick convolution algorithms per
    # run, which can vary between runs and defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False


def _train_stage(args, stage):
    """Build parties for one VFedCD stage, train it, and return ``(vfl, train_output)``.

    ``args`` is NOT copied: ``args.stage`` is set in place before
    ``load_vfedcd_parties`` is called. ``train_output`` is whatever
    ``vfl.train()`` returns (callers unpack it as a 2-tuple).
    """
    args.stage = stage
    stage_args = load_vfedcd_parties(args)
    vfl = MainTaskVFedCD(stage_args)
    return vfl, vfl.train()


def _report_stage1_mask(args, s1_mask):
    """Print recall of the stage-1 mask (when ``args.B_true`` is set) and its edge density."""
    if args.B_true is not None:
        # Fraction of true edges that survive in the stage-1 mask.
        recall_mask = (args.B_true.astype(bool) & s1_mask.astype(bool)).sum() / args.B_true.sum()
        print(f"Recall of stage 1 mask: {recall_mask}")

    fraction_edges_mask = s1_mask.sum() / (s1_mask.shape[0] * s1_mask.shape[1])
    print(f"Fraction of possible edges in mask: {fraction_edges_mask}")


def evaluate_no_attack(args):
    """Train the two-stage VFedCD pipeline without any attack.

    Stage 0 produces a mask, unless ``args.causal['skip_s1']`` is set, in which
    case the string ``'none'`` is used. That mask is stored in
    ``args.causal['mask'][1]`` and stage 1 is trained with it. When
    ``args.B_true`` is set, stage-1 mask statistics and the final
    ``compute_metrics`` result are printed.

    Note: ``args`` is mutated in place (``args.stage`` and
    ``args.causal['mask'][1]``).

    Args:
        args: Experiment namespace (``current_seed``, ``causal``, ``B_true``, ...).

    Returns:
        The trained stage-1 ``MainTaskVFedCD`` instance (the final model).
    """
    set_seed(args.current_seed)

    if args.causal['skip_s1']:
        s1_mask = 'none'
    else:
        # The stage-0 model itself is not needed afterwards; only its mask is kept.
        _, (s1_mask, _) = _train_stage(args, 0)
        _report_stage1_mask(args, s1_mask)

    args.causal['mask'][1] = s1_mask
    vfl, (_, dag_pred) = _train_stage(args, 1)
    if args.B_true is not None:
        metrics_dict = compute_metrics(dag_pred, args.B_true)
        print(f"Final pred: {metrics_dict}")

    return vfl


def evaluate_feature_inference(args):
    """Train VFedCD and evaluate it against the UnSplit feature-inference attack.

    Loads the single configured feature-inference attack config, runs the same
    seeding and two-stage training as ``evaluate_no_attack``, then calls
    ``evaluate_unsplit`` on the resulting model and prints the result.

    Args:
        args: Experiment namespace; ``feature_inference_index`` must hold
            exactly one attack index.

    Returns:
        The trained ``MainTaskVFedCD`` instance.

    Raises:
        AssertionError: If not exactly one feature-inference attack is
            configured, or if the configured attack is not ``UnSplit``.
    """
    # Checks are raised explicitly (not `assert`) so they survive `python -O`;
    # the exception type is kept identical to the original asserts.
    num_attacks = len(args.feature_inference_index)
    if num_attacks != 1:
        raise AssertionError("get {} Unsplit attacks, expect one.".format(num_attacks))
    attack_index = args.feature_inference_index[0]
    torch.cuda.empty_cache()

    set_seed(args.current_seed)
    args = load_attack_configs(args.configs, args, attack_index)
    print('======= Test Attack', attack_index, ': ', args.attack_name, ' =======')
    print('attack configs:', args.attack_configs)

    if args.attack_name != 'UnSplit':
        # Previously `assert 1 == 2`: under -O it vanished and the function
        # then crashed on an unbound `vfl` below.
        raise AssertionError("BVCD supports UnSplit only, but got {}".format(args.attack_name))

    # Identical seeding + stage-0/stage-1 training to the no-attack baseline.
    vfl = evaluate_no_attack(args)
    unsplit_mse = vfl.evaluate_unsplit()

    print("UnSplit person:{}".format(unsplit_mse))
    return vfl


if __name__ == '__main__':
    parser = argparse.ArgumentParser("backdoor")
    parser.add_argument('--device', type=str, default='cuda', help='use gpu or cpu')
    parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
    parser.add_argument('--seed', type=int, default=97, help='random seed')
    parser.add_argument('--configs', type=str, default='test', help='configure json file path')
    parser.add_argument('--save_model', type=bool, default=False, help='whether to save the trained model')
    args = parser.parse_args()

    # for seed in range(97,102): # test 5 times
    # for seed in [60]:
    # for seed in [97,98,99,100,101]: # test 5 times
    args.current_seed = args.seed
    set_seed(args.seed)
    print('================= iter seed ', args.seed, ' =================')

    args = load_basic_configs(args.configs, args)
    args.need_auxiliary = 0  # no auxiliary dataset for attackerB

    if args.device == 'cuda':
        cuda_id = args.gpu
        torch.cuda.set_device(cuda_id)
        print(f'running on cuda{torch.cuda.current_device()}')
    else:
        print('running on cpu')

    ####### load configs from *.json files #######
    ############ Basic Configs ############

    # for mode in [0]:

    #     if mode == 0:
    #         args.global_model = 'ClassificationModelHostHead'
    #     else:
    #         args.global_model = 'ClassificationModelHostTrainableHead'
    #     args.apply_trainable_layer = mode

    mode = args.apply_trainable_layer
    print('============ apply_trainable_layer=', args.apply_trainable_layer, '============')
    # print('================================')

    assert args.dataset_split != None, "dataset_split attribute not found config json file"
    assert 'dataset_name' in args.dataset_split, 'dataset not specified, please add the name of the dataset in config json file'
    args.dataset = args.dataset_split['dataset_name']
    # print(args.dataset)

    print('======= Defense ========')
    print('Defense_Name:', args.defense_name)
    print('Defense_Config:', str(args.defense_configs))
    print('===== Total Attack Tested:', args.attack_num, ' ======')
    print('targeted_backdoor:', args.targeted_backdoor_list, args.targeted_backdoor_index)
    print('untargeted_backdoor:', args.untargeted_backdoor_list, args.untargeted_backdoor_index)
    print('label_inference:', args.label_inference_list, args.label_inference_index)
    print('attribute_inference:', args.attribute_inference_list, args.attribute_inference_index)
    print('feature_inference:', args.feature_inference_list, args.feature_inference_index)

    # Save record for different defense method
    args.exp_res_dir = f'exp_result/{args.dataset}/Q{str(args.Q)}/{str(mode)}/'
    if not os.path.exists(args.exp_res_dir):
        os.makedirs(args.exp_res_dir)
    filename = f'{args.defense_name}_{args.defense_param},model={args.model_list[str(0)]["type"]}.txt'
    args.exp_res_path = args.exp_res_dir + filename
    print(args.exp_res_path)
    print('=================================\n')

    iterinfo = '===== iter ' + str(args.seed) + ' ===='
    append_exp_res(args.exp_res_path, iterinfo)

    args.basic_vfl_withaux = None
    args.main_acc_noattack_withaux = None
    args.basic_vfl = None
    args.main_acc_noattack = None

    # [s1_main_lr, s2_main_lr] = args.main_lr
    # [s1_alpha, s2_alpha] = args.causal['alpha']
    # [s1_beta, s2_beta] = args.causal['beta']
    # [s1_gamma_from, s2_gamma_from] = args.causal['gamma_from']
    # [s1_gamma_to, s2_gamma_to] = args.causal['gamma_to']
    # [s1_threshold, s2_threshold] = args.causal['threshold']
    # args.main_lr = s1_main_lr
    # args.causal['alpha'] = s1_alpha
    # args.causal['beta'] = s1_beta
    # args.causal['gamma_from'] = s1_gamma_from
    # args.causal['gamma_to'] = s1_gamma_to
    # args.causal['threshold'] = s1_threshold

    args = load_attack_configs(args.configs, args, -1)

    commuinfo = '== commu:' + args.communication_protocol
    append_exp_res(args.exp_res_path, commuinfo)

    if args.feature_inference_list != []:
        args.basic_vfl = evaluate_feature_inference(args)
    else:
        args.basic_vfl = evaluate_no_attack(args)
