import os
import pickle
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import time
import copy
import argparse
from model import *
from utils import *
from robcon import process_adj_and_embeds
from sklearn.metrics import accuracy_score, f1_score,roc_auc_score
import torch
from tqdm import tqdm


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean command-line value such as 'True'/'False', 'yes'/'no' or '1'/'0'.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty string,
        so e.g. ``--save_model False`` would be treated as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--split_dir', type=str, default='./data/splits')
    parser.add_argument('--pretrain_model_dir', type=str, default='./saved_models')
    parser.add_argument('--dataset', type=str, default='Cora' ,choices=['Cora', 'CiteSeer', 'Pubmed','Pubmed_new'])
    parser.add_argument('--dataset_argument', type=str, default='empty')
    parser.add_argument('--hidden_channels', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--learning_rate', type=float, default=5e-3)
    parser.add_argument('--id', type=int, default=1)
    parser.add_argument('--gpu', type=str, default='cuda:0')
    # Only these three encoders have a constructor in this script.
    parser.add_argument('--encoder', type=str, default='GCN', choices=['GCN', 'GIN', 'GraphSage'])
    # 'base' builds/caches the processed graph and trains on it; 'CL' reloads the cache and the
    # model saved by a previous 'base' run. No other value initialises `data`.
    parser.add_argument('--flag', type=str, default='base', choices=['base', 'CL'])
    parser.add_argument('--increase', type=str, default='1')
    parser.add_argument('--save_model', type=_str2bool, default=True)
    parser.add_argument('--node_mask', type=_str2bool, default=False)
    parser.add_argument('--node_mask_ratio', type=float, default=.85)
    parser.add_argument('--base_lambda', type=float, default=2)
    parser.add_argument('--start_lambda', type=float, default=1e-3)
    parser.add_argument('--beta', type=float, default=0.98)

    # Attack-related arguments
    parser.add_argument('--ptb_rate', type=float, default=0.2, help='pertubation rate')
    parser.add_argument('--attack', type=str, default='mettack', help='mettack,random')
    # The arguments below are not used by the main experiments (they only appear in file names).
    parser.add_argument('--jt', type=float, default=0.03, help='jaccard threshold')
    parser.add_argument('--cos', type=float, default=0.25, help='cosine similarity threshold')
    parser.add_argument('--beta_s', type=float, default=2, help='the weight of selfloop')
    parser.add_argument('--threshold', type=float, default=1, help='threshold')
    parser.add_argument('--k', type=int, default=7, help='add k neighbors')
    parser.add_argument('--alpha', type=float, default=-0.6, help='add k neighbors')

    # Open-set arguments
    parser.add_argument('--unseen_num', default=1, type=int, help='number of unseen class')
    parser.add_argument('--train_rate', type=float, default=0.7)
    parser.add_argument('--valid_rate', type=float, default=0.1)
    parser.add_argument('--use_softmax', type=_str2bool, default=True)
    parser.add_argument('--Pseudo_ood_rate', type=float, default=0.1, help='pseudo_ood rate')

    # Ablation-study arguments
    parser.add_argument('--normalize_features', action='store_true', help='Whether to normalize the features of nodes')
    parser.add_argument('--use_graph_optimization', action='store_false', default=True,
                        help='Whether to optimize the graph structure (default: True)')
    parser.add_argument('--use_uncertain_Threshold', action='store_false', default=True,
                        help='Whether to use uncertain  (default: True)')


    args = parser.parse_args()

    use_softmax = args.use_softmax
    # Label value that marks nodes belonging to held-out (unseen) classes.
    unseen_label_index = -1
    device = torch.device(args.gpu if torch.cuda.is_available() else 'cpu')
    datasetName = args.dataset.lower()
    # Cache of the processed graph: written by the 'base' flag, read back by the 'CL' flag.
    save_path = (f'./data/ptb_datasets/{datasetName}/{args.attack}_{datasetName}_{args.ptb_rate}'
                 f'_{args.jt}_{args.cos}_{args.beta_s}_{args.threshold}_{args.k}_{args.normalize_features}.pt')
    if args.flag == 'CL':
        save_dict = torch.load(save_path)
        data = save_dict['data']
    if args.flag == 'base':
        dataset = args.dataset.lower()
        unseen_num = args.unseen_num
        train_rate = args.train_rate
        valid_rate = args.valid_rate
        base_path = f"./data/process/{dataset}/{dataset}_{unseen_num}_{train_rate}_{valid_rate}"
        y_true = torch.load(f"{base_path}_y_true.pt")
        features = torch.load(f"{base_path}_features.pt")
        train_mask = torch.load(f"{base_path}_train_mask.pt")
        val_mask = torch.load(f"{base_path}_val_mask.pt")
        test_mask = torch.load(f"{base_path}_test_mask.pt")
        perturbed_adj = torch.load(f'./data/ptb_graphs_aligned/{args.attack}/{args.attack}_{datasetName}_{args.ptb_rate}'
                                   f'_{args.unseen_num}_{args.train_rate}_{args.valid_rate}.pt')
        if args.use_graph_optimization:
            edge_index, embeds = process_adj_and_embeds(perturbed_adj, features, args)
        else:
            edge_index = perturbed_adj.coalesce().indices()

        x = features
        if args.normalize_features:
            # Row-wise: shift each row so its minimum is 0, then divide by the row sum
            # (clamped to >= 1 so rows with tiny or zero sums are not inflated).
            shifted = features - features.min(dim=1, keepdim=True).values
            x = shifted / shifted.sum(dim=1, keepdim=True).clamp(min=1.)
        data = Data(x=x, edge_index=edge_index, y=y_true,
                    train_mask=train_mask,
                    val_mask=val_mask, test_mask=test_mask)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # Store under the 'data' key that the 'CL' flag reads back.
        save_dict = {'data': data}
        torch.save(save_dict, save_path)
    num_input_node_feature = data.num_node_features
    num_output_class = data.y.max().item() + 1

    hidden_channels = args.hidden_channels
    lr = args.learning_rate
    num_layers = args.num_layers
    encoder = args.encoder
    seed = args.seed

    def calculate_uncertainty_loss(logits, unmasked_indices, num_classes, device, use_softmax=True, upper=0.7):
        """Return per-element p*log(p) terms for "moderately uncertain" unlabeled nodes.

        A node among `unmasked_indices` is selected when its top-class probability lies strictly
        between 1/num_classes and `upper`. For every selected node the full row p*log(p) is
        returned, so the result has shape (n_selected, num_classes). It may be empty, and callers
        that average it must handle that case.

        Args:
            logits: node logits; rows are gathered with `unmasked_indices`.
            unmasked_indices: indices of the nodes to consider (tensor, list or index array).
            num_classes: number of classes; 1/num_classes is the lower confidence bound.
            device: device used for the helper arange.
            use_softmax: softmax over classes if True, element-wise sigmoid otherwise.
            upper: exclusive upper bound on the top-class probability (default 0.7).
        """
        unmasked_logits = logits[unmasked_indices]
        if use_softmax:
            unmasked_probs = F.softmax(unmasked_logits, dim=1)
        else:
            unmasked_probs = torch.sigmoid(unmasked_logits)
        # Clamp so that log() below never sees an exact 0.
        unmasked_probs = torch.clamp(unmasked_probs, min=1e-7, max=1.0)
        unmasked_preds = torch.argmax(unmasked_probs, dim=-1)
        rows = torch.arange(unmasked_logits.shape[0], device=device)
        top_prob = unmasked_probs[rows, unmasked_preds]
        selected = torch.logical_and(top_prob > 1.0 / num_classes, top_prob < upper)
        selected_probs = unmasked_probs[selected]
        return selected_probs * torch.log(selected_probs)


    def get_pseudo_ood_indices(logits, mask ):
        """Return global node indices of the least-confident nodes selected by `mask`.

        Confidence is the maximum class probability (softmax or sigmoid, per the script-level
        `use_softmax`). Nodes whose uncertainty 1 - confidence reaches the
        (1 - args.Pseudo_ood_rate) quantile are returned as pseudo-OOD candidates.
        """
        scores = torch.sigmoid(logits) if not use_softmax else F.softmax(logits, dim=-1)
        uncertainty = 1 - scores[mask].max(dim=-1).values
        cutoff = torch.quantile(uncertainty, 1 - args.Pseudo_ood_rate)
        # Positions within the masked subset, then mapped back to global node indices.
        chosen = (uncertainty >= cutoff).nonzero().squeeze(1)
        return mask.nonzero().squeeze(1)[chosen]

    def ood_uncertainty_loss(ood_embedding, model, args):
        """Decode pseudo-OOD embeddings and return (mean of p*log(p), probs).

        The returned loss is the mean over all elements of p*log(p) (a negative-entropy term).
        `probs` is returned unclamped for threshold estimation.

        Args:
            ood_embedding: embeddings passed to `model.decoder`.
            model: object exposing a `decoder` callable that maps embeddings to logits.
            args: namespace with a boolean `use_softmax` (softmax if True, sigmoid otherwise).
        """
        logits = model.decoder(ood_embedding)
        if args.use_softmax:
            probs = F.softmax(logits, dim=1)
        else:
            probs = torch.sigmoid(logits)
        # Softmax/sigmoid can underflow to exactly 0, and 0 * log(0) = NaN. Clamp only the log
        # argument, so zero probabilities contribute 0 instead of poisoning the loss.
        ood_uncertainty_losses = torch.mean(probs * torch.log(torch.clamp(probs, min=1e-7)))
        return ood_uncertainty_losses, probs


    def ind_ood_binary_loss(pred, labels):
        """Binary cross-entropy between probabilities `pred` and 0/1 targets `labels` (mean-reduced)."""
        return torch.nn.functional.binary_cross_entropy(input=pred, target=labels)


    def add_noise(z , pseudo_ood_indices, std=0.2):
        """Return z[pseudo_ood_indices] perturbed with zero-mean Gaussian noise.

        The noise is drawn directly on z's device and dtype, so it needs no host-to-device copy
        per call and does not depend on the script-level `device`.

        Args:
            z: node embeddings.
            pseudo_ood_indices: indices of the rows to perturb.
            std: standard deviation of the noise (default 0.2).
        """
        z_pseudo_ood = z[pseudo_ood_indices]
        return z_pseudo_ood + torch.randn_like(z_pseudo_ood) * std

    def train_gae(features_all , mask, edge_index, edge_weight):
        """Run one optimisation step of `gae_model` and return (loss, pseudo-OOD probabilities).

        The loss is the supervised criterion on `mask`. When --use_uncertain_Threshold is on,
        two terms are added, mixed by the learnable weight sigmoid(a):
          * sigmoid(a) * mean p*log(p) over moderately uncertain nodes outside the training set;
          * (1 - sigmoid(a)) * the negative-entropy term on noise-perturbed pseudo-OOD embeddings.
        """
        gae_model.train()
        optimizer_gae.zero_grad()
        out, z = gae_model.encode_decode(features_all, edge_index, edge_weight)
        loss = criterion(out[mask], data.y[mask])

        pseudo_ood_indices = get_pseudo_ood_indices(out, mask)
        z_pseudo_ood_noise = add_noise(z, pseudo_ood_indices)
        ood_loss, ood_probs = ood_uncertainty_loss(z_pseudo_ood_noise, gae_model, args)

        if args.use_uncertain_Threshold:
            # Nodes outside the training split, as indices on the same device as `out`.
            unmasked_indices = (~data.train_mask).nonzero(as_tuple=True)[0]
            uncertainty_terms = calculate_uncertainty_loss(out, unmasked_indices=unmasked_indices,
                                                           num_classes=num_output_class, device=device)
            # torch.mean of an empty tensor is NaN. If no node falls into the uncertainty band,
            # contribute 0 instead of turning the whole loss (and every parameter) into NaN.
            if uncertainty_terms.numel() > 0:
                loss_c = torch.mean(uncertainty_terms)
            else:
                loss_c = out.new_zeros(())
            a_sigmoid = torch.sigmoid(a)
            b = 1 - a_sigmoid
            loss += a_sigmoid * loss_c + b * ood_loss
        loss.backward()
        optimizer_gae.step()
        return loss, ood_probs

    def evaluate_train(logits,mask):
        """Return the argmax accuracy against `data.y` over the nodes selected by `mask`."""
        hits = (logits.argmax(dim=1) == data.y)[mask]
        return int(hits.sum()) / int(mask.sum())

    def train_spcl(z, edge_index, gt_edge=None, _lambda=1., loss_type='increase', beta=1.):
        """Take one optimiser step on `spcl_model` and return its reconstruction loss.

        If `gt_edge` is omitted, every edge in `edge_index` gets target 1. After the step, all
        SPCL parameters are clamped in place into [0, 1].
        """
        spcl_model.train()
        optimizer_spcl.zero_grad()
        targets = gt_edge if gt_edge is not None else torch.ones(len(edge_index[0]), device=z.device)
        loss = spcl_model.recon_loss(z, edge_index, _lambda=_lambda, gt_edge=targets,
                                     loss_type=loss_type, beta=beta)
        loss.backward()
        optimizer_spcl.step()
        # Project back into [0, 1] outside autograd.
        with torch.no_grad():
            for weight in spcl_model.parameters():
                weight.clamp_(0, 1)
        return loss

    def calculate_threshold(seen_probs_max):
        """Average the mean confidence with the mean confidence of the "hardest" 10% of nodes.

        Hardness is scored as -p*log(p) of each node's top probability p (a per-value term, not
        the full entropy of the class distribution). At least one node is always used.
        """
        mean_confidence = np.mean(seen_probs_max)
        hardness = -seen_probs_max * np.log(seen_probs_max)
        top_count = max(1, int(len(seen_probs_max) * 0.1))
        hardest = np.argsort(hardness)[-top_count:]
        hard_confidence = np.mean(seen_probs_max[hardest])
        return (mean_confidence + hard_confidence) / 2.0

    def calculate_threshold_ood(seen_probs_max, ood_probs_max):
        """Return the midpoint between the mean seen-class confidence and the mean OOD confidence."""
        seen_mean, ood_mean = np.mean(seen_probs_max), np.mean(ood_probs_max)
        return (seen_mean + ood_mean) / 2.0


    def calculate_threshold_ind_ood(seen_probs_max, ood_probs_max):
        """Combine seen-class and pseudo-OOD confidences into one rejection threshold.

        threshold = (mean_seen + s * hard_seen + (1 - s) * mean_ood) / 2, where s = sigmoid(a) is
        the learnable mixing weight also used in train_gae. Because the two weights form a
        convex combination, the threshold stays within the range of the confidences.
        `hard_seen` is the mean confidence of the 10% of seen nodes with the largest -p*log(p).

        Side effect: appends s and 1 - s to the script-level `a_list` / `b_list`.
        Returns the threshold as a Python float.
        """
        threshold1 = np.mean(seen_probs_max)
        entropy = -seen_probs_max * np.log(seen_probs_max)
        k_percent = 0.1
        k = max(1, int(len(seen_probs_max) * k_percent))
        top_k_entropy_indices = np.argsort(entropy)[-k:]
        threshold2 = np.mean(seen_probs_max[top_k_entropy_indices])
        a_sigmoid = torch.sigmoid(a)
        b = 1 - a_sigmoid
        threshold3 = np.mean(ood_probs_max)
        # Weight threshold2 by sigmoid(a), not the raw parameter `a`, so the weights pair with
        # b = 1 - sigmoid(a) as in train_gae.
        threshold = (threshold1 + a_sigmoid * threshold2 + b * threshold3) / 2.0
        a_list.append(a_sigmoid.item())
        b_list.append(b.item())
        return threshold.item()

    def evaluate_val(logits,mask, ood_probs=None):
        """Evaluate closed-set predictions on `mask` and derive the open-set threshold.

        The threshold comes from the top probabilities of seen-class nodes only (labels !=
        `unseen_label_index`). When `ood_probs` is given, the pseudo-OOD confidences are also
        used, via calculate_threshold_ind_ood.

        Returns: (accuracy, macro_f1, threshold, val_loss).
        """
        idx = mask_to_idx(mask)
        probs = (F.softmax(logits, dim=-1) if use_softmax else torch.sigmoid(logits))[idx]
        targets = data.y[idx]
        val_loss = criterion(logits[idx], targets).item()
        y_pred = probs.argmax(dim=-1).cpu().numpy()
        y_true = targets.cpu().numpy()
        confidences = probs.max(dim=1).values.cpu().numpy()
        seen_conf = confidences[y_true != unseen_label_index]
        if ood_probs is not None:
            threshold = calculate_threshold_ind_ood(seen_conf, ood_probs.max(dim=1).values.cpu().numpy())
        else:
            threshold = calculate_threshold(seen_conf)
        accuracy = accuracy_score(y_true, y_pred)
        macro_f1 = f1_score(y_true, y_pred, average="macro")
        return accuracy, macro_f1, threshold, val_loss

    def evaluate_test(logits,mask,threshold=None):
        """Evaluate open-set classification on the nodes selected by `mask`.

        A node whose top class probability is below `threshold` is predicted as unseen
        (`unseen_label_index`); otherwise it gets its argmax class. `threshold=None` disables
        rejection (pure argmax predictions).

        AUROC scores seen (1) vs unseen (0) nodes by their top probability. The loss maps
        unseen labels to class 0 because the criterion cannot take negative targets, so it is
        only indicative.

        Returns: (accuracy, macro_f1, auroc, ind_acc, ood_acc, test_loss).
        """
        probs = F.softmax(logits, dim=-1) if use_softmax else torch.sigmoid(logits)
        mask_indices = mask_to_idx(mask)
        masked_logits = logits[mask_indices]
        masked_probs = probs[mask_indices]

        masked_y_pred = torch.argmax(masked_probs, dim=-1).cpu().numpy()
        masked_y_true = data.y[mask_indices].cpu().numpy()
        is_seen = masked_y_true != unseen_label_index
        ind_indices = is_seen.nonzero()[0]
        ood_indices = (~is_seen).nonzero()[0]
        masked_probs_max = masked_probs.max(dim=1).values.cpu().numpy()

        if threshold is not None:
            masked_y_pred[masked_probs_max < threshold] = unseen_label_index
        # Binary in-distribution labels for AUROC: 1 = seen class, 0 = unseen.
        labels_all = is_seen.astype(masked_y_true.dtype)

        mapped_labels = data.y[mask_indices].clone()
        mapped_labels[mapped_labels == unseen_label_index] = 0
        # Compute the evaluation metrics.
        metrics = {
            'accuracy': accuracy_score(masked_y_true, masked_y_pred),
            'macro_f1': f1_score(masked_y_true, masked_y_pred, average="macro"),
            'auroc' : roc_auc_score(labels_all, masked_probs_max),
            'ind_acc': accuracy_score(masked_y_true[ind_indices], masked_y_pred[ind_indices]),
            'ood_acc': accuracy_score(masked_y_true[ood_indices], masked_y_pred[ood_indices]),
            'test_loss' : criterion(masked_logits, mapped_labels).item()
        }
        return metrics['accuracy'], metrics['macro_f1'], metrics['auroc'], metrics['ind_acc'], metrics['ood_acc'],metrics['test_loss']

    def predict_spcl(edge_index):
        """Put `spcl_model` in eval mode and return the (edge_index, edge_weight) pair from structure_predict."""
        spcl_model.eval()
        kept_edges, kept_weights = spcl_model.structure_predict(edge_index)
        return kept_edges, kept_weights
    # Build the GAE: a graph encoder producing `hidden_channels`-dim node embeddings, followed by
    # a one-layer MLP decoder mapping embeddings to `num_output_class` logits.
    if encoder == 'GCN':
        encoder_model = GCN(input_channels=num_input_node_feature,
                            hidden_channels=hidden_channels,
                            output_channels=hidden_channels,
                            num_layers=num_layers,
                            random_seed=seed
                            )
    elif encoder == 'GIN':
        encoder_model = GIN(input_channels=num_input_node_feature,
                            hidden_channels=hidden_channels,
                            output_channels=hidden_channels,
                            num_layers=num_layers,
                            random_seed=seed
                            )
    elif encoder == 'GraphSage':
        encoder_model = SAGE(input_channels=num_input_node_feature,
                             hidden_channels=hidden_channels,
                             output_channels=hidden_channels,
                             num_layers=num_layers,
                             random_seed=seed
                             )
    else:
        # Fail fast with a clear message instead of a NameError on `encoder_model` below.
        raise ValueError(f"Unsupported encoder {encoder!r}; expected 'GCN', 'GIN' or 'GraphSage'")

    decoder_model = MLP(input_channels=hidden_channels,
                        hidden_channels=hidden_channels,
                        output_channels=num_output_class,
                        num_layers=1,
                        random_seed=seed)
    gae_model = GAE(
        encoder=encoder_model,
        decoder=decoder_model
    )
    # Learnable mixing weight, optimised jointly with the GAE; sigmoid(a) and 1 - sigmoid(a)
    # weight the auxiliary uncertainty terms.
    a = torch.nn.Parameter(torch.tensor(0.5))
    # Per-evaluation history of sigmoid(a) and 1 - sigmoid(a).
    a_list = []
    b_list = []
    optimizer_gae = torch.optim.Adam(
        list(gae_model.parameters()) + [a], lr=lr
    )

    if args.use_softmax:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()
    gae_model = gae_model.to(device)
    data = data.to(device)
    features_all = data.x

    if args.flag == 'base':
        # Baseline: train the GAE on the processed graph with unit edge weights, evaluate every
        # epoch, and keep the model with the best validation accuracy.
        edge_index = data.edge_index
        edge_weight = torch.ones(data.num_edges, device=device)
        train_acc_list = []
        train_loss_list = []

        val_acc_list = []
        val_F1_list = []
        val_loss_list = []

        test_acc_list = []
        test_ind_acc_list = []
        test_ood_acc_list = []
        test_F1_list = []
        test_auroc_list = []
        test_loss_list = []

        num_epochs = args.epochs
        best_model = copy.deepcopy(gae_model)
        best_val = 0.
        for epoch in tqdm(range(1, num_epochs + 1)):
            train_loss, ood_probs = train_gae(features_all, data.train_mask, edge_index, edge_weight)
            gae_model.eval()
            with torch.no_grad():
                logits, z = gae_model.encode_decode(features_all, edge_index, edge_weight)
            train_acc = evaluate_train(logits, data.train_mask)
            with torch.no_grad():
                val_acc, val_F1, threshold_val, val_loss = evaluate_val(logits, data.val_mask, ood_probs=ood_probs)
                if args.use_uncertain_Threshold:
                    # Open-set threshold derived from validation confidences.
                    test_acc, test_F1, test_auroc, test_ind_acc, test_ood_acc, test_loss = evaluate_test(
                        logits, data.test_mask, threshold=threshold_val)
                else:
                    # Ablation: sweep fixed thresholds 0.0, 0.1, ..., 0.9 and keep the result with
                    # the highest test accuracy (the first one wins on ties). The previous loop never
                    # updated its running maximum, so it kept the last threshold with acc > 0.
                    sweep = [evaluate_test(logits, data.test_mask, threshold=i * 0.1) for i in range(10)]
                    test_acc, test_F1, test_auroc, test_ind_acc, test_ood_acc, test_loss = max(
                        sweep, key=lambda result: result[0])
            train_loss_list.append(float(train_loss))
            train_acc_list.append(train_acc)
            val_acc_list.append(val_acc)
            val_F1_list.append(val_F1)
            val_loss_list.append(val_loss)
            test_acc_list.append(test_acc)
            test_ind_acc_list.append(test_ind_acc)
            test_ood_acc_list.append(test_ood_acc)
            test_F1_list.append(test_F1)
            test_auroc_list.append(test_auroc)
            test_loss_list.append(test_loss)
            if best_val < val_acc:
                best_model = copy.deepcopy(gae_model)
                best_val = val_acc
        val_acc_list = np.array(val_acc_list)
        val_F1_list = np.array(val_F1_list)

        test_acc_list = np.array(test_acc_list)
        test_ind_acc_list = np.array(test_ind_acc_list)
        test_ood_acc_list = np.array(test_ood_acc_list)
        test_F1_list = np.array(test_F1_list)
        test_auroc_list = np.array(test_auroc_list)
        # Maximum of each metric and the (0-based) epoch index where it occurs.
        max_val_acc, max_val_epoch = get_max_with_index(val_acc_list)
        max_acc_test, max_acc_epoch = get_max_with_index(test_acc_list)
        max_f1_test, max_f1_epoch = get_max_with_index(test_F1_list)
        max_auroc_test, max_auroc_epoch = get_max_with_index(test_auroc_list)
        max_ind_acc_test, max_ind_acc_epoch = get_max_with_index(test_ind_acc_list)
        max_ood_acc_test, max_ood_acc_epoch = get_max_with_index(test_ood_acc_list)

        # Print the final results.
        print(
            f'pruned dataset:{args.dataset}-{args.attack}-ptb_rate:{args.ptb_rate},encoder:{args.encoder},flag:{args.flag}\n'
            f'Best Epoch (based on validation): {max_val_epoch + 1}\n'
            f'  Validation - Acc: {val_acc_list[max_val_epoch]:.4f}, F1: {val_F1_list[max_val_epoch]:.4f}\n'
            f'Best Epoch (based on test accuracy): {max_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_acc_epoch]:.4f}, F1: {test_F1_list[max_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_acc_epoch]:.4f}\n'
            f'Best Epoch (based on test F1): {max_f1_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_f1_epoch]:.4f}, F1: {test_F1_list[max_f1_epoch]:.4f}, AUROC: {test_auroc_list[max_f1_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_f1_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_f1_epoch]:.4f}\n'
            f'Best Epoch (based on test AUROC): {max_auroc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_auroc_epoch]:.4f}, F1: {test_F1_list[max_auroc_epoch]:.4f}, AUROC: {test_auroc_list[max_auroc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_auroc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_auroc_epoch]:.4f}\n'
            f'Best Epoch (based on test IND accuracy): {max_ind_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_ind_acc_epoch]:.4f}, F1: {test_F1_list[max_ind_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_ind_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_ind_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_ind_acc_epoch]:.4f}\n'
            f'Best Epoch (based on test OOD accuracy): {max_ood_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_ood_acc_epoch]:.4f}, F1: {test_F1_list[max_ood_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_ood_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_ood_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_ood_acc_epoch]:.4f}'
        )
        if args.save_model:
            # Checkpoint location; the 'CL' flag rebuilds exactly this path to reload the model.
            save_model_dir = os.path.join(args.pretrain_model_dir, args.attack, str(args.ptb_rate),
                                          "pruned graph", args.dataset + "_ow")
            os.makedirs(save_model_dir, exist_ok=True)
            model_file = (f'{encoder}_{args.id}_{args.jt}_{args.cos}_{args.beta_s}_{args.threshold}'
                          f'_{args.k}_{args.normalize_features}.txt')
            torch.save(best_model.state_dict(), os.path.join(save_model_dir, model_file))
    elif args.flag == 'CL':
        if encoder == 'GCN':
            trained_encoder_model = GCN(input_channels=num_input_node_feature,
                                        hidden_channels=hidden_channels,
                                        output_channels=hidden_channels,
                                        num_layers=num_layers
                                        )
        elif encoder == 'GIN':
            trained_encoder_model = GIN(input_channels=num_input_node_feature,
                                        hidden_channels=hidden_channels,
                                        output_channels=hidden_channels,
                                        num_layers=num_layers
                                        )
        elif encoder == 'GraphSage':
            trained_encoder_model = SAGE(input_channels=num_input_node_feature,
                                         hidden_channels=hidden_channels,
                                         output_channels=hidden_channels,
                                         num_layers=num_layers
                                         )

        trained_decoder_model = MLP(input_channels=hidden_channels,
                                    hidden_channels=hidden_channels,
                                    output_channels=num_output_class,
                                    num_layers=1)
        trained_gae_model = GAE(
            encoder=trained_encoder_model,
            decoder=trained_decoder_model
        )
        save_model_dir = os.path.join(args.pretrain_model_dir , args.attack ,str(args.ptb_rate), "pruned graph",args.dataset+"_ow")
        trained_gae_model.load_state_dict(
            torch.load(os.path.join(save_model_dir, encoder + '_' + str(args.id)+ '_' +str(args.jt)+ '_' +str(args.cos)+ '_'
                                 +str(args.beta_s)+ '_' +str(args.threshold)+ '_' +str(args.k)+ '_' +str(args.normalize_features)+ '.txt')))
        trained_gae_model.to(device)
        base_lambda = args.base_lambda
        start_lambda = args.start_lambda
        beta = args.beta
        loss_type = 'increase'

        edge_weight = torch.ones(data.num_edges, device=device)
        out,z = trained_gae_model.encode_decode(data.x, data.edge_index, edge_weight)
        ce = torch.nn.CrossEntropyLoss(reduction='none')
        ce_scores = ce(out[data.train_mask], data.y[data.train_mask]).detach()
        difficulty_scores = torch.ones(data.num_nodes, device=device)
        difficulty_scores[data.train_mask] = ce_scores
        normalized_difficulty_scores = torch.exp(difficulty_scores)
        normalized_difficulty_scores = normalized_difficulty_scores / normalized_difficulty_scores.min()
        s_mask_weight = normalized_difficulty_scores[data.edge_index[0]] * normalized_difficulty_scores[data.edge_index[1]]
        term1 = normalized_difficulty_scores[data.edge_index[0]]
        term2 = normalized_difficulty_scores[data.edge_index[1]]
        term1[~data.train_mask[data.edge_index[0]]] = 1.
        term2[~data.train_mask[data.edge_index[1]]] = 1.
        s_mask_weight = term1 * term2

        if args.node_mask:
            node_mask = normalized_difficulty_scores < torch.quantile(normalized_difficulty_scores,
                                                                      args.node_mask_ratio)
        edge_weight_list = []
        structure_decoder = CosineDecoder()
        spcl_model = SPCL(
            data.num_edges if loss_type == 'increase' else 2 * data.num_edges,
            structure_decoder=structure_decoder
        )
        optimizer_spcl = torch.optim.Adam(spcl_model.parameters(), lr=0.1)
        spcl_model = spcl_model.to(device)
        edge_index = data.edge_index
        gt_edge = torch.ones(len(edge_index[0]), device=device)

        with torch.no_grad():
            z = trained_gae_model.encode(data.x, data.edge_index, torch.ones(data.num_edges, device=device))
        for t in range(1, 11):  # train mask s
            _lambda = start_lambda
            loss_s = train_spcl(z, edge_index, gt_edge=gt_edge, _lambda=_lambda, loss_type=loss_type, beta=beta)
        masked_edge_index, masked_edge_weight = predict_spcl(edge_index)
        with torch.no_grad():
            mask = spcl_model.s_mask > 0.5
        masked_edge_weight /= s_mask_weight[mask]
        edge_weight_list.append(masked_edge_weight)

        train_acc_list = []
        train_loss_list = []

        val_acc_list = []
        val_F1_list = []
        val_loss_list = []

        test_acc_list = []
        test_ind_acc_list = []
        test_ood_acc_list = []
        test_F1_list = []
        test_auroc_list = []
        test_loss_list = []

        num_edge_list = []
        num_epochs = args.epochs

        for epoch in tqdm(range(1, num_epochs + 1)):
            if args.node_mask:
                train_loss, ood_probs  = train_gae(features_all , data.train_mask & node_mask, masked_edge_index, masked_edge_weight)
            else:
                train_loss, ood_probs  = train_gae(features_all , data.train_mask, masked_edge_index, masked_edge_weight)

            test_masked_edge_weight = spcl_model.accumulated_s_mask / spcl_model.num / s_mask_weight

            with torch.no_grad():
                logits,z = gae_model.encode_decode(features_all, edge_index, test_masked_edge_weight)
                train_acc = evaluate_train(logits,data.train_mask)
                if args.node_mask:
                    val_acc,val_F1,threshold_val , val_loss = evaluate_val(logits,data.val_mask & node_mask, ood_probs)
                else:
                    val_acc,val_F1,threshold_val , val_loss = evaluate_val(logits ,data.val_mask, ood_probs)
                if args.use_uncertain_Threshold:
                    test_acc, test_F1, test_auroc, test_ind_acc, test_ood_acc, test_loss = evaluate_test(logits, data.test_mask,
                                                                                                         threshold=threshold_val)
                else:
                    max_acc = 0
                    for i in range(0, 10):
                        t = i * 0.1
                        t_acc, t_F1, t_auroc, t_ind_acc, t_ood_acc, t_loss = evaluate_test(logits, data.test_mask,  threshold=t)
                        if t_acc > max_acc:
                            test_acc = t_acc
                            test_F1 = t_F1
                            test_auroc = t_auroc
                            test_ind_acc = t_ind_acc
                            test_ood_acc = t_ood_acc
                            test_loss = t_loss
            train_acc_list.append(train_acc)
            train_loss_list.append(float(train_loss))
            val_acc_list.append(val_acc)
            val_F1_list.append(val_F1)
            val_loss_list.append(val_loss)
            test_acc_list.append(test_acc)
            test_ind_acc_list.append(test_ind_acc)
            test_ood_acc_list.append(test_ood_acc)
            test_F1_list.append(test_F1)
            test_auroc_list.append(test_auroc)
            test_loss_list.append(test_loss)
            num_edge_list.append(masked_edge_index.shape[1])
            edge_weight_list.append(test_masked_edge_weight)

            if epoch % 20 == 0:
                with torch.no_grad():
                    z = trained_gae_model.encode(features_all, masked_edge_index, masked_edge_weight)
                for t in range(1, 11):
                    if args.increase == '1':
                        _lambda = base_lambda / (
                                    num_epochs * 2 // 3 + 1 - epoch) if epoch < num_epochs * 2 // 3 else base_lambda
                    elif args.increase == '2':
                        _lambda = base_lambda / (num_epochs + 1 - epoch) if epoch < num_epochs else base_lambda
                    elif args.increase == '3':
                        _lambda = epoch / num_epochs * base_lambda + start_lambda
                    loss_s = train_spcl(z, edge_index, gt_edge=gt_edge, _lambda=_lambda, loss_type=loss_type, beta=beta)
                masked_edge_index, masked_edge_weight = predict_spcl(edge_index)
                with torch.no_grad():
                    mask = spcl_model.s_mask > 0.5
                masked_edge_weight /= s_mask_weight[mask]
        val_acc_list = np.array(val_acc_list)
        val_F1_list = np.array(val_F1_list)
        test_acc_list = np.array(test_acc_list)
        test_ind_acc_list = np.array(test_ind_acc_list)
        test_ood_acc_list = np.array(test_ood_acc_list)
        test_F1_list = np.array(test_F1_list)
        test_auroc_list = np.array(test_auroc_list)
        max_index ,max_epoch = get_max_with_index(val_acc_list)
        max_val_acc, max_val_epoch = get_max_with_index(val_acc_list)
        max_acc_test, max_acc_epoch = get_max_with_index(test_acc_list)
        max_f1_test, max_f1_epoch = get_max_with_index(test_F1_list)
        max_auroc_test, max_auroc_epoch = get_max_with_index(test_auroc_list)
        max_ind_acc_test, max_ind_acc_epoch = get_max_with_index(test_ind_acc_list)
        max_ood_acc_test, max_ood_acc_epoch = get_max_with_index(test_ood_acc_list)

        print(
            f'pruned dataset:{args.dataset}-{args.attack}-ptb_rate:{args.ptb_rate},encoder:{args.encoder},flag:{args.flag}\n'
            f'Best Epoch (based on validation): {max_val_epoch + 1}\n'
            f'  Validation - Acc: {val_acc_list[max_val_epoch]:.4f}, F1: {val_F1_list[max_val_epoch]:.4f}\n'
            f'Best Epoch (based on test accuracy): {max_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_acc_epoch]:.4f}, F1: {test_F1_list[max_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_acc_epoch]:.4f}\n'
            f'Best Epoch (based on test F1): {max_f1_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_f1_epoch]:.4f}, F1: {test_F1_list[max_f1_epoch]:.4f}, AUROC: {test_auroc_list[max_f1_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_f1_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_f1_epoch]:.4f}\n'
            f'Best Epoch (based on test AUROC): {max_auroc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_auroc_epoch]:.4f}, F1: {test_F1_list[max_auroc_epoch]:.4f}, AUROC: {test_auroc_list[max_auroc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_auroc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_auroc_epoch]:.4f}\n'
            f'Best Epoch (based on test IND accuracy): {max_ind_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_ind_acc_epoch]:.4f}, F1: {test_F1_list[max_ind_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_ind_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_ind_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_ind_acc_epoch]:.4f}\n'
            f'Best Epoch (based on test OOD accuracy): {max_ood_acc_epoch + 1}\n'
            f'  Test - Acc: {test_acc_list[max_ood_acc_epoch]:.4f}, F1: {test_F1_list[max_ood_acc_epoch]:.4f}, AUROC: {test_auroc_list[max_ood_acc_epoch]:.4f}, ind_acc: {test_ind_acc_list[max_ood_acc_epoch]:.4f}, ood_acc: {test_ood_acc_list[max_ood_acc_epoch]:.4f}'
        )