import numpy as np
import torch
torch.autograd.set_detect_anomaly(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import argparse
import warnings
warnings.filterwarnings("ignore")
from sklearn.model_selection import KFold


def parse_args():
    """Build the UMMAN command-line parser and parse ``sys.argv``.

    Unrecognised command-line tokens are tolerated (``parse_known_args``), so
    the function returns a ``(namespace, unknown_args)`` tuple rather than just
    the namespace.

    The switches ``--bias``, ``--Attn`` and ``--addVector`` default to ``True``.
    A ``store_true`` flag whose default is already ``True`` cannot be cleared,
    so each of them also has a ``--no_<name>`` counterpart that stores
    ``False`` into the same destination.

    Returns:
        tuple[argparse.Namespace, list[str]]: the parsed options and the list
        of command-line tokens that no option consumed.
    """
    parser = argparse.ArgumentParser(description='UMMAN')

    parser.add_argument('--embedder', nargs='?', default='UMMAN')
    parser.add_argument('--dataset', nargs='?', default='data')
    # Comma-separated list of relationship names passed through to the model.
    parser.add_argument('--relationships', nargs='?', default='euclidean,braycurtis,correlation')

    # Hyper-parameters. Trailing "alt:" comments keep alternative values that
    # were recorded next to each option; entries marked "*" were flagged in the
    # original notes (meaning undocumented — presumably tuned; confirm).
    parser.add_argument('--epochs', type=int, default=1000)  # alt: 10000
    parser.add_argument('--hidden_nodes', type=int, default=1024)  # alt: 64 *
    parser.add_argument('--learningrate', type=float, default=0.001)  # learning rate; alt: 0.0005 *
    parser.add_argument('--l2_coef', type=float, default=0.0001)  # alt: 0.0001 *
    parser.add_argument('--drop_prob', type=float, default=0.8)  # alt: 0.5 *
    parser.add_argument('--attn_coef', type=float, default=0.004)  # alt: 0.001 *
    parser.add_argument('--self_conv', type=float, default=300)  # alt: 3.0
    parser.add_argument('--limit', type=int, default=15)  # alt: 20
    parser.add_argument('--head_num', type=int, default=1)  # alt: 1
    parser.add_argument('--activation', nargs='?', default='relu')  # alt: relu6, leakyrelu

    # Boolean switches that default to True; the --no_* forms turn them off.
    parser.add_argument('--bias', action='store_true', default=True)  # alt: False
    parser.add_argument('--no_bias', dest='bias', action='store_false')
    parser.add_argument('--Attn', action='store_true', default=True)  # alt: False
    parser.add_argument('--no_Attn', dest='Attn', action='store_false')
    parser.add_argument('--addVector', action='store_true', default=True)
    parser.add_argument('--no_addVector', dest='addVector', action='store_false')

    # NOTE(review): consumed outside this file; its meaning is not evident here.
    parser.add_argument('--n', type=int, default=10)

    return parser.parse_known_args()


def main():
    """Run 10-fold cross-validation of the UMMAN model and print averages.

    For every fold produced by ``KFold(n_splits=10, shuffle=True,
    random_state=0)`` the fold's indices are assigned to the embedder and
    ``embedder.training()`` is called. That call must return exactly twelve
    values, in this order: accuracy, its std, precision, its std, recall, its
    std, AUC, its std, macro-F1, its std, micro-F1, its std.

    After all folds, the mean of each value across folds is printed. The
    parenthesised figure is the mean of the per-fold ``*_std`` values returned
    by ``training()``; it is not the standard deviation across folds.

    Raises:
        ValueError: if ``embedder.training()`` returns a different number of
            values than the twelve listed above.
    """
    args, _unknown = parse_args()
    from models import UMMAN  # project-local model; imported lazily as before
    embedder = UMMAN(args)

    # KFold.split only uses the number of samples, not their values, so split
    # an index array of that length. This gives the same folds as splitting the
    # rows themselves. Assumes features[0] has a leading axis whose first entry
    # holds one row per sample — TODO confirm against UMMAN.
    n_samples = len(embedder.features[0].numpy()[0])

    metric_names = (
        'acc', 'acc_std',
        'precision', 'precision_std',
        'recall', 'recall_std',
        'AUC', 'AUC_std',
        'macro_f1', 'macro_f1_std',
        'micro_f1', 'micro_f1_std',
    )
    per_fold = {name: [] for name in metric_names}

    kf = KFold(n_splits=10, shuffle=True, random_state=0)
    for train_index, test_index in kf.split(np.arange(n_samples)):
        embedder.idx_train = train_index
        embedder.idx_test = test_index
        # NOTE(review): the test fold doubles as the validation set. If
        # training() uses idx_val for early stopping or model selection, the
        # reported test metrics are optimistically biased — confirm.
        embedder.idx_val = test_index
        # NOTE(review): the same embedder instance is reused for every fold.
        # Confirm training() re-initialises the model; otherwise later folds
        # start from weights fitted on data that is now in their test split.
        fold_metrics = tuple(embedder.training())
        if len(fold_metrics) != len(metric_names):
            raise ValueError(
                'embedder.training() returned {} values, expected {}'.format(
                    len(fold_metrics), len(metric_names)))
        for name, value in zip(metric_names, fold_metrics):
            per_fold[name].append(value)

    # (printed label, metric key); the matching std key is "<key>_std".
    # "F1" reports micro-F1, as it always has; macro-F1 was previously
    # collected but never shown.
    summary_rows = (
        ('Acc:', 'acc'),
        ('Precision:', 'precision'),
        ('Recall:', 'recall'),
        ('AUC:', 'AUC'),
        ('F1: ', 'micro_f1'),
        ('Macro-F1: ', 'macro_f1'),
    )
    print('Average:')
    for label, key in summary_rows:
        print("\t{}{:.4f} ({:.4f})".format(
            label, np.mean(per_fold[key]), np.mean(per_fold[key + '_std'])))


if __name__ == '__main__':
    # Run the cross-validation experiment only when executed as a script,
    # not when this module is imported.
    main()
