import warnings
from sklearn import metrics, model_selection
from models.GCN_baseline import GCN_Baseline
from models.gat import GAT
from load_data import GraphData, collect_batch
from parser_1 import _parser
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import logging
import time
import numpy as np
import matplotlib
from DataReader import DataReader
from models.GCN_new import GCN_NEW
from models.multigraph import MGCN
from models.Transformer import TransformerModel
from models.对比模型LSTM import LSTMModel
from models.multigraph_midify import GNNModel
from models.ema import EMA
import matplotlib.pyplot as plt

# Non-interactive backend: figures produced during training are only saved to disk.
matplotlib.use('agg')
# NOTE: this silences every warning category for the whole run.
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    print('using torch', torch.__version__)
    args = _parser()
    # Both options arrive as comma-separated strings; convert them to lists of ints.
    args.filters = [int(value) for value in args.filters.split(',')]
    args.lr_decay_steps = [int(value) for value in args.lr_decay_steps.split(',')]
    for arg_name, arg_value in vars(args).items():
        print(arg_name, arg_value)

    n_folds = args.folds
    # Reproducibility: cuDNN benchmark mode auto-tunes (and may pick different,
    # non-deterministic) kernels between runs, which contradicts the
    # `deterministic` flag and the fixed seeds below, so it is disabled.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rnd_state = np.random.RandomState(args.seed)

    print('Loading training_data...')

    # Create the DataReader for the selected dataset under ./data/call-graph-data/.
    # The seeded RandomState and the fold count are handed to the reader.
    datareader = DataReader(data_dir='./data/call-graph-data/%s/' % args.dataset, rnd_state=rnd_state,
                            use_cont_node_attr=args.use_cont_node_attr, folds=args.folds)

    # Per-fold metrics: one [accuracy, recall, precision, F1, FPR, TPR] entry per fold.
    result_folds = []
    for fold_id in range(n_folds):
        # One loader per split, in the order [train, test]; only the training
        # split is shuffled.
        loaders = [
            DataLoader(GraphData(fold_id=fold_id, datareader=datareader, split=split_name),
                       batch_size=args.batch_size,
                       shuffle='train' in split_name,
                       num_workers=args.threads,
                       collate_fn=collect_batch)
            for split_name in ('train', 'test')
        ]
        print('FOLD {}, train {}, test {}'.format(
            fold_id, len(loaders[0].dataset), len(loaders[1].dataset)))
        print(f"test_data:{loaders[1]}")

        # Input/output sizes are taken from the training dataset.
        n_features = loaders[0].dataset.num_features
        n_classes = loaders[0].dataset.num_classes
        model_name = args.model

        # Select the architecture requested on the command line.
        if model_name == 'gat':
            model = GAT(nfeat=n_features,
                        nhid=64,
                        nclass=n_classes,
                        dropout=args.dropout,
                        alpha=args.alpha,
                        nheads=args.multi_head)
        elif model_name == 'GCN_baseline':
            model = GCN_Baseline(n_feature=n_features,
                                 n_hidden=64,
                                 n_class=n_classes,
                                 dropout=args.dropout)
        elif model_name == 'GCN_new':
            model = GCN_NEW(in_features=n_features,
                            out_features=n_classes,
                            n_hidden=args.n_hidden,
                            filters=args.filters,
                            dropout=args.dropout,
                            adj_sq=args.adj_sq,
                            scale_identity=args.scale_identity)
        elif model_name == 'multigraph':
            model = MGCN(in_features=n_features,
                         out_features=n_classes,
                         n_relations=2,
                         n_hidden=args.n_hidden,
                         n_hidden_edge=args.n_hidden_edge,
                         filters=args.filters,
                         dropout=args.dropout,
                         adj_sq=args.adj_sq,
                         scale_identity=args.scale_identity)
        elif model_name == 'lstm':
            model = LSTMModel(in_features=n_features,
                              hidden_dim=64,
                              out_features=n_classes,
                              dropout=args.dropout)
        elif model_name in ('gnn', 'multigraph_modify'):
            # Both names map to the same GNNModel configuration.
            model = GNNModel(in_features=n_features,
                             out_features=n_classes,
                             hidden_dim=64,
                             dropout=args.dropout)
        elif model_name == 'transformer':
            model = TransformerModel(input_dim=n_features, num_classes=n_classes)
        else:
            raise NotImplementedError(args.model)
        model = model.to(args.device)

        print('Initialize model...')
        print(torch.cuda.is_available())

        # Only parameters that require gradients are handed to the optimizer.
        train_parameters = [p for p in model.parameters() if p.requires_grad]
        print('N trainable parameters:', np.sum(
            [p.numel() for p in train_parameters]))

        # Optimizer and step-wise learning-rate decay (x0.1 at each milestone).
        optimizer = optim.AdamW(train_parameters,
                                lr=args.lr,
                                betas=(0.5, 0.999),
                                weight_decay=args.wd)
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, args.lr_decay_steps, gamma=0.1)

        # Loss function: cross-entropy on the raw model outputs.
        # (F.nll_loss was the earlier choice noted for the gat / gcn models.)
        loss_fn = F.cross_entropy

        # Exponential moving average of the model weights.
        ema = EMA(model, decay=0.999)
        ema.register()

        # Average loss per epoch, collected for plotting.
        train_loss_history = []
        test_loss_history = []
        def train(train_loader):
            """Run one training epoch over ``train_loader``.

            Uses ``model``, ``optimizer``, ``scheduler``, ``ema``, ``loss_fn``,
            ``args`` and the current ``epoch`` from the enclosing script scope.
            Appends the sample-weighted mean loss of the epoch to
            ``train_loss_history`` and prints a one-line progress summary.

            Args:
                train_loader: iterable of batches; each batch is an indexable,
                    mutable sequence of tensors whose element 4 holds the labels.

            Raises:
                ValueError: if ``train_loader`` yields no samples.
            """
            model.train()
            start = time.time()
            train_loss, n_samples, n_batches = 0.0, 0, 0
            last_loss = float('nan')

            for data in train_loader:
                # Move every tensor of the batch to the target device in place.
                for i, item in enumerate(data):
                    data[i] = item.to(args.device)

                optimizer.zero_grad()
                output = model(data)
                loss = loss_fn(output, data[4])
                loss.backward()
                optimizer.step()
                # Update the EMA copy of the weights after every optimizer step.
                ema.update()

                last_loss = loss.item()
                # loss is a batch mean; weight it by batch size for the epoch mean.
                train_loss += last_loss * len(output)
                n_samples += len(output)
                n_batches += 1

            if n_samples == 0:
                raise ValueError('train_loader yielded no samples')

            # MultiStepLR milestones are epoch indices, so the schedule advances
            # once per epoch. Stepping it per batch would apply the decay after
            # only a handful of iterations.
            scheduler.step()

            elapsed = time.time() - start
            avg_train_loss = train_loss / n_samples
            train_loss_history.append(avg_train_loss)

            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} (avg: {:.6f})  sec/iter: {:.4f}'.format(
                epoch + 1, n_samples, len(train_loader.dataset),
                100. * n_batches / len(train_loader),
                last_loss, avg_train_loss, elapsed / n_batches))

        def test(test_loader):
            """Evaluate ``model`` on ``test_loader`` and report binary-classification metrics.

            Uses ``model``, ``loss_fn``, ``args`` and the current ``epoch`` from the
            enclosing script scope. All metrics are derived from a confusion matrix
            accumulated over the whole split (positive class = label 1), so unevenly
            sized batches do not skew them and recall agrees with TPR.

            NOTE(review): the EMA weights maintained during training are not applied
            here; if the project's EMA class offers a way to swap in the shadow
            weights for evaluation, this is probably where it belongs - confirm.

            Args:
                test_loader: iterable of batches; each batch is an indexable,
                    mutable sequence of tensors whose element 4 holds the labels.

            Returns:
                Tuple ``(accuracy, recall, precision, F1, FPR, TPR)``. The first four
                are percentages (0-100); FPR and TPR are fractions in [0, 1]. A metric
                whose denominator is zero is reported as 0.0.
            """
            def _ratio(numerator, denominator):
                # Safe division: undefined ratios are reported as 0.0.
                return numerator / denominator if denominator else 0.0

            model.eval()
            start = time.time()
            test_loss, n_samples, n_correct = 0.0, 0, 0
            tn, fp, fn, tp = 0, 0, 0, 0

            # No gradients are needed for evaluation; skip building the autograd graph.
            with torch.no_grad():
                for data in test_loader:
                    for i, item in enumerate(data):
                        data[i] = item.to(args.device)

                    output = model(data)
                    labels = data[4]
                    test_loss += loss_fn(output, labels, reduction='sum').item()
                    n_samples += len(output)

                    # Predicted class = index of the largest output per sample.
                    pred = output.max(1)[1].view_as(labels)
                    n_correct += int((pred == labels).sum().item())
                    tp += int(((pred == 1) & (labels == 1)).sum().item())
                    tn += int(((pred == 0) & (labels == 0)).sum().item())
                    fn += int(((pred == 0) & (labels == 1)).sum().item())
                    fp += int(((pred == 1) & (labels == 0)).sum().item())

            print('\nTrue Positive = ', tp)
            print('\nTrue Negative = ', tn)
            print('\nFalse Positive = ', fp)
            print('\nFalse Negative = ', fn, '\n')

            accuracy = 100. * _ratio(n_correct, n_samples)
            recall = 100. * _ratio(tp, tp + fn)
            precision = 100. * _ratio(tp, tp + fp)
            F1 = 100. * _ratio(2 * tp, 2 * tp + fp + fn)

            # FPR/TPR are undefined without negatives/positives in the split;
            # report that and fall back to 0.0 instead of leaving them unbound.
            if fp == 0 and tn == 0:
                print("FPR error")
            FPR = _ratio(fp, fp + tn)
            if tp == 0 and fn == 0:
                print("TPR error")
            TPR = _ratio(tp, tp + fn)

            print(
                'Test set (epoch {}): \n   Average loss: {:.4f}, \n   Accuracy: ({:.2f}%),'
                '\n   Recall: ({:.2f}%), \n   Precision: ({:.2f}%), \n   F1-Score: ({:.2f}%), '
                '\n   TPR: ({:.2f}%), \n   FPR: ({:.2f}%)  \n   sec/iter: {:.4f}\n'.format(
                    epoch + 1, _ratio(test_loss, n_samples), accuracy, recall, precision, F1,
                    100. * TPR, 100. * FPR,  # fractions -> percent for display only
                    _ratio(time.time() - start, len(test_loader)))
            )

            return accuracy, recall, precision, F1, FPR, TPR

        # Train for the configured number of epochs, then evaluate once on the
        # held-out split of this fold.
        for epoch in range(args.epochs):
            train(loaders[0])
        accuracy, recall, precision, F1, FPR, TPR = test(loaders[1])
        result_folds.append([accuracy, recall, precision, F1, FPR, TPR])

        # Plot the per-epoch training loss of this fold and save it to disk.
        # The x-axis is derived from the recorded history so x and y always
        # have the same length.
        fig = plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(train_loss_history) + 1), train_loss_history, label='Training Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title(f'Training and Testing Loss Over Epochs - Fold {fold_id + 1}')
        plt.legend()
        plt.savefig(f'loss_history_fold_{fold_id + 1}.png')
        # Close the figure: otherwise one figure per fold stays alive and any
        # later plt.plot() call would draw on top of this loss curve.
        plt.close(fig)

    # Split the per-fold results into one list per metric. Each row of
    # result_folds is [accuracy, recall, precision, F1, FPR, TPR].
    fold_results = np.asarray(result_folds, dtype=float).reshape(-1, 6)
    accuracy_list, recall_list, precision_list, F1_list, FPR_list, TPR_list = (
        fold_results[:, column].tolist() for column in range(6))

    # Configure logging
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        filename='experiment_log.txt',  # Log file name
                        filemode='w')  # 'w' to overwrite the log file each run, 'a' to append

    # Settings actually used for this run (the optimizer is AdamW, see above).
    parameter_settings = {
        "learning_rate": args.lr,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "optimizer": "AdamW"
    }

    # Calculate average and standard deviation across folds
    avg_accuracy = np.mean(accuracy_list)
    std_accuracy = np.std(accuracy_list)
    avg_recall = np.mean(recall_list)
    std_recall = np.std(recall_list)
    avg_precision = np.mean(precision_list)
    std_precision = np.std(precision_list)
    avg_F1 = np.mean(F1_list)
    std_F1 = np.std(F1_list)
    avg_FPR = np.mean(FPR_list)
    std_FPR = np.std(FPR_list)
    avg_TPR = np.mean(TPR_list)
    std_TPR = np.std(TPR_list)

    # Units: accuracy/recall/precision/F1 are stored as percentages already,
    # while FPR/TPR are stored as fractions and are scaled here for display.
    log_message = (
        f'{n_folds}-fold cross validation with average accuracy(+- Standard deviation): {avg_accuracy:.2f}% ({std_accuracy:.2f}%), '
        f'Recall (+- Standard deviation): {avg_recall:.2f}% ({std_recall:.2f}%), '
        f'Precision (+- Standard deviation): {avg_precision:.2f}% ({std_precision:.2f}%), '
        f'F1-Score (+- Standard deviation): {avg_F1:.2f}% ({std_F1:.2f}%), '
        f'FPR (+- Standard deviation): {avg_FPR * 100:.2f}% ({std_FPR * 100:.2f}%), '
        f'TPR (+- Standard deviation): {avg_TPR * 100:.2f}% ({std_TPR * 100:.2f}%)'
    )

    # Report the summary on stdout and persist it, with the settings, in the log file.
    print(log_message)
    logging.info('Parameter settings: %s', parameter_settings)
    logging.info(log_message)

    print(f"FPR_list:{FPR_list}, TPR_list:{TPR_list}")

    # Try to switch to an interactive backend for the final plot; on machines
    # without a display/Tk this raises ImportError, in which case the plot is
    # saved to a file instead of crashing at the very end of the run.
    try:
        matplotlib.use('TkAgg')
        interactive_backend = True
    except ImportError:
        interactive_backend = False

    # One (FPR, TPR) point per fold, sorted by FPR so the connecting line does
    # not zig-zag. A fresh figure keeps earlier plots out of this one.
    fold_points = sorted(zip(FPR_list, TPR_list))
    plt.figure()
    plt.plot([point[0] for point in fold_points],
             [point[1] for point in fold_points], marker='o')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    if interactive_backend:
        plt.show()
    else:
        plt.savefig('fpr_tpr_folds.png')
        print('Interactive backend unavailable; plot saved to fpr_tpr_folds.png')