import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid, Coauthor
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_networkx
import networkx as nx
import random
from torch.nn import Linear
from sklearn.metrics import f1_score
import time
from lib_utils import utils
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_scatter import scatter_add
from sklearn.metrics import roc_auc_score
import torch_geometric.transforms as T
print("-----------------------------------------------------")

print(torch.__version__)
# Prefer GPU index 3, but fall back to the default CUDA device when fewer GPUs
# are present instead of failing with "invalid device ordinal" at the first
# `.to(device)`.
_PREFERRED_GPU = 3
if torch.cuda.is_available():
    if torch.cuda.device_count() > _PREFERRED_GPU:
        device = torch.device(f"cuda:{_PREFERRED_GPU}")
    else:
        device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

r = 100            # training epochs for both GIN models
un_ratio = 0.2     # fraction of the training nodes that are unlearned
batch_size = 512   # NeighborSampler mini-batch size used for training
target_model = 'GIN'
# Load the dataset (Coauthor-CS; Cora and CiteSeer kept as alternatives).
# dataset = Planetoid(root='data/Cora', name="Cora")  # root: storage directory, name: dataset name
dataset = Coauthor(root='data/CS', name="CS")  # root: storage directory, name: dataset name
# dataset = Planetoid(root='data/citeseer', name="citeseer")  # root: storage directory, name: dataset name
print(dataset)
print("未使用零值域分解")  # "null/range-space decomposition not used"
# Basic statistics of the graph.
print("网络数据包含的类数量:", dataset.num_classes)          # number of classes
print("网络数据边的特征数量:", dataset.num_edge_features)    # number of edge features
print("网络数据边的数量:", dataset[0].edge_index.shape[1] / 2)  # edges; halved because COO stores each undirected edge in both directions
print("网络数据节点的特征数量:", dataset.num_node_features)  # node feature dimension
print("网络数据节点的数量:", dataset[0].x.shape[0])          # number of nodes
print("网络节点标签的数量:", len(dataset[0].y))              # number of node labels

# visualization
def visualize(out, color, filename):
    """Plot a 2-D t-SNE projection of per-node outputs and save it to ``filename``.

    Args:
        out: tensor whose rows are per-node vectors; it is detached and moved to CPU.
        color: per-node values used to colour the points (moved to CPU).
        filename: path passed to ``plt.savefig``.
    """
    z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())
    fig = plt.figure(figsize=(10, 10))
    plt.grid(True, linestyle='--', color='gray', linewidth=0.5)
    plt.scatter(z[:, 0], z[:, 1], s=18, c=color.cpu(), alpha=0.9, cmap="Set2")
    # Save BEFORE showing: with interactive/inline backends show() may close or
    # clear the current figure, so a later savefig() would write a blank image.
    plt.savefig(filename, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    
def _set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and make cuDNN deterministic."""
    # Same seeding order as before: torch CPU, torch CUDA, NumPy, Python.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    print("set pytorch seed")

# Previously hard-coded seeds: 20250301, 20250304, 20250306, 20250309, 20250312
import datetime

# Derive a fresh seed from the local time (YYYYMMDDhhmmss), folded into the
# 32-bit range that np.random.seed accepts.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
seed = int(timestamp.replace("_", "").replace(":", "").replace("-", "")) % (2**32)
print("seed:", seed)

# _set_random_seed already seeds torch (CPU + CUDA), NumPy and `random` and sets
# the cuDNN determinism flags; repeating those calls here was redundant.
_set_random_seed(seed)
    
# train from all
# batch version
class GINNet(torch.nn.Module):
    """Two-layer GIN node classifier followed by a two-layer linear head.

    ``forward`` consumes NeighborSampler mini-batches and returns log-probabilities;
    ``inference`` runs layer-wise inference over all nodes and returns raw logits.
    """

    def __init__(self, in_channels=None, out_channels=None, dim=32):
        """Build the network.

        Args:
            in_channels: input feature size; defaults to the global ``dataset.num_features``.
            out_channels: number of classes; defaults to the global ``dataset.num_classes``.
            dim: hidden width of both GIN layers and of ``fc1``.
        """
        super(GINNet, self).__init__()

        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.num_layers = 2

        nn1 = Sequential(Linear(in_channels, dim), ReLU(), Linear(dim, dim))
        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))

        self.convs = torch.nn.ModuleList()
        self.convs.append(GINConv(nn1))
        self.convs.append(GINConv(nn2))

        self.bn = torch.nn.ModuleList()
        self.bn.append(torch.nn.BatchNorm1d(dim))
        self.bn.append(torch.nn.BatchNorm1d(dim))

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, out_channels)

    def forward(self, x, adjs):
        """Mini-batch forward pass.

        Args:
            x: features of every node in the sampled subgraph (``data.x[n_id]``).
            adjs: per-layer ``(edge_index, e_id, size)`` tuples from NeighborSampler.

        Returns:
            Log-probabilities for the batch's target nodes.
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            x = self.bn[i](x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

    @torch.no_grad()  # evaluation only: avoid building an autograd graph over the whole graph
    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise inference over all nodes.

        Assumes ``subgraph_loader`` yields one-hop ``(batch_size, n_id, adj)`` batches
        that together cover every node in order (e.g. ``sizes=[-1]``, ``shuffle=False``).

        Returns:
            ``(x_last, logits)``: ``x_last`` is the first GIN layer's output for all
            nodes (the input to the last conv); ``logits`` are the unnormalised
            outputs of ``fc2`` (no softmax, no dropout).
        """
        x_last = []
        for i in range(self.num_layers):
            xs = []
            if i == 1:
                x_last = x_all  # first-layer output, i.e. the input of the last layer
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                x = self.bn[i](x)
                xs.append(x)
            x_all = torch.cat(xs, dim=0)
        x_all = F.relu(self.fc1(x_all))
        x_all = self.fc2(x_all)
        return x_last, x_all

    def reset_parameters(self):
        """Re-initialise every learnable layer from a fixed seed (42).

        NOTE: ``torch.manual_seed(42)`` also resets the *global* torch RNG, so later
        draws (dropout masks and, presumably, loader shuffling) follow seed 42
        rather than the script-level seed.
        """
        torch.manual_seed(42)
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        for i in range(self.num_layers):
            self.convs[i].reset_parameters()
            self.bn[i].reset_parameters()
            
# --- "Train from all": original model, optimiser, split and loaders ---
model = GINNet().to(device)
criterion = torch.nn.CrossEntropyLoss()                                   # Loss criterion (unused: training calls F.nll_loss).
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)  # Define optimizer.
data = dataset[0].to(device)
# Fixed 90/10 node split (random_state=100), independent of the run seed.
train_indices, test_indices = train_test_split(np.arange((data.num_nodes)), test_size=0.1, random_state=100)
# Boolean node masks; note they are created on the CPU, not on `device`.
data.train_mask = torch.from_numpy(np.isin(np.arange(data.num_nodes), train_indices))
data.test_mask = torch.from_numpy(np.isin(np.arange(data.num_nodes), test_indices))

# Sorted ids of the training nodes.
train_indices = np.nonzero(data.train_mask.cpu().numpy())[0]
# Training-graph edges; exact semantics of lib_utils.filter_edge_index are not
# visible here (presumably keeps edges whose endpoints are in train_indices — verify).
edge_index = utils.filter_edge_index(data.edge_index, train_indices, reindex=False)

# Dummy 1<->2 edge when no training edges remain (presumably so NeighborSampler
# has a non-empty graph — TODO confirm).
if edge_index.shape[1] == 0:
    edge_index = torch.tensor([[1, 2], [2, 1]])

# Two-hop sampler (5 neighbours per hop) seeded by the training nodes.
train_loader = NeighborSampler(
    edge_index, node_idx=data.train_mask,
    sizes=[5, 5], num_nodes=data.num_nodes,
    batch_size=batch_size, shuffle=True,
    num_workers=0)

# Full one-hop neighbourhoods over the full graph, unshuffled, for layer-wise inference.
test_loader = NeighborSampler(
            data.edge_index, node_idx=None,
            sizes=[-1], num_nodes=data.num_nodes,
            batch_size=64, shuffle=False,
            num_workers=0)
# Keep a handle on the full-graph loader; `test_loader` is rebound later.
test_loader_1 = test_loader

def evaluate_model(model, _data, filename, test_loader, edge_weight=None):
    """Run full-graph inference, save a t-SNE plot and return train/test accuracy.

    Args:
        model: model exposing ``inference(x, loader, [edge_weight,] device)``.
        _data: graph data with ``x``, ``y``, ``train_mask`` and ``test_mask``.
        filename: output path for the figure written by ``visualize``.
        test_loader: loader passed to ``model.inference``.
        edge_weight: forwarded to ``inference`` only for GCN/SGC target models.

    Returns:
        ``[train_accuracy, test_accuracy]``; an empty mask yields ``nan``.
    """
    model.eval()
    with torch.no_grad():
        if target_model in ['GCN', 'SGC']:
            _, out = model.inference(_data.x, test_loader, edge_weight, device)
        else:
            _, out = model.inference(_data.x, test_loader, device)
    visualize(out, color=_data.y, filename=filename)

    y_true = _data.y.unsqueeze(-1).to(device)
    y_pred = out.argmax(dim=-1, keepdim=True).to(device)
    results = []
    for mask in [_data.train_mask, _data.test_mask]:  # train accuracy, then test accuracy
        n_nodes = int(mask.sum())
        n_correct = int(y_pred[mask].eq(y_true[mask]).sum())
        results.append(n_correct / n_nodes if n_nodes else float('nan'))

    print("results")
    print(results)
    return results
    
model.reset_parameters()
for epoch in range(r):
    print("epoch:", epoch)
    model.train()
    # The batch size is bound to `n_batch`, NOT `batch_size`: reusing the global
    # name overwrote the 512 config with the last (partial) batch size, which the
    # scratch model's NeighborSampler then picked up.
    for n_batch, n_id, adjs in train_loader:
        adjs = [adj.to(device) for adj in adjs]
        opt.zero_grad()

        if target_model in ['GCN', 'SGC']:
            # NOTE(review): `edge_weight` is not defined in this script; this branch
            # only runs for GCN/SGC target models.
            out = model(data.x[n_id], adjs, edge_weight)
        else:
            out = model(data.x[n_id], adjs)

        # Loss only on the batch's target nodes (the first n_batch entries of n_id).
        loss = F.nll_loss(out, data.y[n_id[:n_batch]])
        loss.backward()
        opt.step()

model.eval()
temp_out, out = model.inference(data.x, test_loader, device)  # full-graph loader at this point
KL_xy_all = out            # logits of the train-from-all model
temp_KL_xy_all = temp_out  # first-layer embeddings of the train-from-all model
filename = 'gin_train_from_all_batch.pdf'
train_from_all_results = evaluate_model(model, data, filename, test_loader)

test_f1_train_from_all = f1_score(
        data.y[data.test_mask].cpu().numpy(),
        out[data.test_mask].argmax(axis=1).cpu().numpy(),
        average="micro"
)
# train from all
    
# train from scratch
class GINNet_scratch(torch.nn.Module):
    """Same architecture as GINNet, retrained from scratch on the reduced training set.

    ``forward`` consumes NeighborSampler mini-batches and returns log-probabilities;
    ``inference`` runs layer-wise inference over all nodes and returns raw logits.
    """

    def __init__(self, in_channels=None, out_channels=None, dim=32):
        """Build the network.

        Args:
            in_channels: input feature size; defaults to the global ``dataset.num_features``.
            out_channels: number of classes; defaults to the global ``dataset.num_classes``.
            dim: hidden width of both GIN layers and of ``fc1``.
        """
        super(GINNet_scratch, self).__init__()

        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.num_layers = 2

        nn1 = Sequential(Linear(in_channels, dim), ReLU(), Linear(dim, dim))
        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))

        self.convs = torch.nn.ModuleList()
        self.convs.append(GINConv(nn1))
        self.convs.append(GINConv(nn2))

        self.bn = torch.nn.ModuleList()
        self.bn.append(torch.nn.BatchNorm1d(dim))
        self.bn.append(torch.nn.BatchNorm1d(dim))

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, out_channels)

    def forward(self, x, adjs):
        """Mini-batch forward pass.

        Args:
            x: features of every node in the sampled subgraph (``data.x[n_id]``).
            adjs: per-layer ``(edge_index, e_id, size)`` tuples from NeighborSampler.

        Returns:
            Log-probabilities for the batch's target nodes.
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            x = self.bn[i](x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

    @torch.no_grad()  # evaluation only: avoid building an autograd graph over the whole graph
    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise inference over all nodes.

        Assumes ``subgraph_loader`` yields one-hop ``(batch_size, n_id, adj)`` batches
        that together cover every node in order (e.g. ``sizes=[-1]``, ``shuffle=False``).

        Returns:
            ``(x_last, logits)``: ``x_last`` is the first GIN layer's output for all
            nodes (the input to the last conv); ``logits`` are the unnormalised
            outputs of ``fc2``.
        """
        x_last = []
        for i in range(self.num_layers):
            xs = []
            if i == 1:
                x_last = x_all  # first-layer output, i.e. the input of the last layer
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                x = self.bn[i](x)
                xs.append(x)
            x_all = torch.cat(xs, dim=0)
        x_all = F.relu(self.fc1(x_all))
        x_all = self.fc2(x_all)
        return x_last, x_all

    def reset_parameters(self):
        """Re-initialise every learnable layer from a fixed seed (42).

        NOTE: ``torch.manual_seed(42)`` also resets the *global* torch RNG.
        """
        torch.manual_seed(42)
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        for i in range(self.num_layers):
            self.convs[i].reset_parameters()
            self.bn[i].reset_parameters()
            
# --- "Train from scratch": same split, minus a random subset of training nodes ---
model_scratch = GINNet_scratch().to(device)
criterion = torch.nn.CrossEntropyLoss()  # Loss criterion (unused: training calls F.nll_loss).
opt = torch.optim.Adam(model_scratch.parameters(), lr=0.01, weight_decay=0.0001)  # Define optimizer.
data_scratch = dataset[0].to(device)

# Reuse the train/test split from the "train from all" run (no re-split).
data_scratch.train_mask = torch.from_numpy(np.isin(np.arange(data_scratch.num_nodes), train_indices))
data_scratch.test_mask = torch.from_numpy(np.isin(np.arange(data_scratch.num_nodes), test_indices))

train_mask = data_scratch.train_mask                    # alias: modified in place below
train_nodes = torch.nonzero(train_mask).squeeze(1)      # ids of the training nodes
num_nodes_to_remove = int(len(train_nodes) * un_ratio)  # un_ratio of the training nodes get unlearned

# random.sample's choice depends only on the population length and the RNG
# state, so sampling plain ints selects exactly the nodes that sampling 0-d
# tensors did, without creating one tensor object per training node.
nodes_to_remove = random.sample(train_nodes.tolist(), num_nodes_to_remove)
nodes_to_remove = torch.tensor(nodes_to_remove, dtype=torch.long).to(device)

# Index the (CPU) mask with indices on the mask's own device.
train_mask[nodes_to_remove.to(train_mask.device)] = False
data_scratch.train_mask = train_mask
combined_mask = torch.logical_or(data_scratch.train_mask, data_scratch.test_mask).to(device)  # remaining train + test nodes
train_indices = np.nonzero(data_scratch.train_mask.cpu().numpy())[0]                        # remaining training ids

# Drop every edge with an endpoint in the unlearning set. torch.isin gives the
# same masks as the former broadcast comparison without allocating a
# [num_edges, num_removed] boolean matrix.
edge_index_0 = data_scratch.edge_index
left_mask = torch.isin(edge_index_0[0], nodes_to_remove)
right_mask = torch.isin(edge_index_0[1], nodes_to_remove)
mask = ~(left_mask | right_mask)
data_scratch.edge_index = edge_index_0[:, mask]

# Training-graph edges among the remaining training nodes (see lib_utils for exact semantics).
edge_index = utils.filter_edge_index(data_scratch.edge_index, train_indices, reindex=False)

num_true_nodes = torch.sum(data_scratch.train_mask).item()
print("训练数据集的节点数量",num_true_nodes)

if edge_index.shape[1] == 0:
    edge_index = torch.tensor([[1, 2], [2, 1]])

train_loader_scratch = NeighborSampler(
    edge_index, node_idx=data_scratch.train_mask,
    sizes=[5, 5], num_nodes=data_scratch.num_nodes,
    batch_size=batch_size, shuffle=True,
    num_workers=0)

# Inference loader over the pruned graph (unlearned nodes' edges removed).
test_loader = NeighborSampler(
            data_scratch.edge_index, node_idx=None,
            sizes=[-1], num_nodes=data.num_nodes,
            batch_size=64, shuffle=False,
            num_workers=0)
test_loader_2 = test_loader
    
model_scratch.reset_parameters()
for epoch in range(r):
    print("epoch:", epoch)
    model_scratch.train()
    # `n_batch` rather than `batch_size`, so the loop never clobbers the global config.
    for n_batch, n_id, adjs in train_loader_scratch:
        adjs = [adj.to(device) for adj in adjs]
        opt.zero_grad()
        if target_model in ['GCN', 'SGC']:
            # NOTE(review): `edge_weight` is not defined in this script (GCN/SGC only).
            out = model_scratch(data_scratch.x[n_id], adjs, edge_weight)
        else:
            out = model_scratch(data_scratch.x[n_id], adjs)
        loss = F.nll_loss(out, data_scratch.y[n_id[:n_batch]])
        loss.backward()
        opt.step()

model_scratch.eval()
# `test_loader` is now the pruned-graph loader built from data_scratch.edge_index.
temp_out, out = model_scratch.inference(data.x, test_loader, device)
KL_xy_scratch = out            # logits of the retrained model
temp_KL_xy_scratch = temp_out  # first-layer embeddings of the retrained model
filename = 'gin_train_from_scratch_batch.pdf'
train_from_scratch_results = evaluate_model(model_scratch, data_scratch, filename, test_loader)

test_f1_train_from_scratch = f1_score(
        data.y[data.test_mask].cpu().numpy(),
        out[data.test_mask].argmax(axis=1).cpu().numpy(),
        average="micro"
)
# train from scratch

# unlearning
# Node degree counting every COO entry once for its source and once for its
# target. If each undirected edge is stored in both directions, this is twice
# the usual degree; edge_n / w_fix and the KL weights below use this scale.
# Vectorised with bincount (identical counts to the former per-edge loop).
edge_index = dataset[0].edge_index
_n_nodes = dataset[0].x.shape[0]
degree = (torch.bincount(edge_index[0], minlength=_n_nodes)
          + torch.bincount(edge_index[1], minlength=_n_nodes)).long()

ratio = 0.8  # degree quantile separating the "high-degree" nodes
top_10_percent_threshold = torch.quantile(degree.float(), ratio)  # not used further in this file
# Despite the `top_10` names, this keeps the top (1 - ratio) = 20% of nodes by degree.
top_10_percent_indices = torch.argsort(degree, descending=True)[:int(dataset[0].x.shape[0] * (1-ratio))]
top_10_mask = torch.zeros(dataset[0].x.shape[0], dtype=torch.bool)
top_10_mask[top_10_percent_indices] = 1
# High-degree nodes that are still in the (reduced) training set.
intersection_mask = top_10_mask & train_mask

class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> Sigmoid -> Linear.

    The activation attribute is named ``relu`` but holds a ``Sigmoid``; the name
    is kept so existing attribute access keeps working.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created fc1 first, then fc2, which fixes how the global RNG is consumed.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.Sigmoid()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

model.eval()                       # switch the original model to evaluation mode
conv2_linear_layer = model.fc2     # final linear layer (fc2) of the original model's head
H = conv2_linear_layer.weight.to(device)  # shape: [out_features, in_features]
H_pinv = torch.pinverse(H).to(device)     # shape: [in_features, out_features]
bias = conv2_linear_layer.bias            # shape: [out_features], or None for a bias-free layer

print("bias:", bias)
print(type(bias))
if bias is None:
    bias = torch.zeros(dataset.num_classes).to(device)

degree = degree.to(device)
nodes_to_remove = nodes_to_remove.to(device)
# Mean degree of the unlearned nodes; w_fix scales the subtracted correction.
edge_n = degree[nodes_to_remove].sum() / num_nodes_to_remove

print("=====edge_n=====")
print(edge_n)
w_fix = 1 + 1/edge_n
print("=====w_fix=====")
print(w_fix.cpu().numpy())

# hadamard_edge_n = degree[nodes_to_remove]/num_nodes_to_remove
# hadamard_w_fix = 1 + 1/hadamard_edge_n
from torch_geometric.nn.conv.gcn_conv import GCNConv
class GIN_Unlearn(torch.nn.Module):
    """Learned post-hoc correction of a trained GIN's outputs for node unlearning.

    Per directed edge, ``influ_ji`` maps the concatenated hidden embeddings of the
    two endpoints to a class-space influence; ``influ_ij`` maps that back to the
    hidden space. Both are summed onto the edge's source node (``row``); the
    class-space sum, scaled by the global ``w_fix``, is subtracted from the logits.
    """

    def __init__(self, hidden_channels):  # hidden_channels: width of the first GIN layer
        super().__init__()
        self.hidden_channels = hidden_channels

        # Creation order (influ_ij, then influ_ji) fixes RNG consumption; keep it.
        self.influ_ij = MLP(dataset.num_classes, int((dataset.num_classes * hidden_channels) ** 0.5), hidden_channels)      # class space -> hidden space
        self.influ_ji = MLP(2 * hidden_channels, int((2*hidden_channels*dataset.num_classes) ** 0.5), dataset.num_classes)  # edge embedding -> class space
        # Output layer of the original model and its pseudo-inverse (not used by
        # forward at the moment; kept for the range/null-space variant).
        self.H = H.detach().to(device)
        self.H_pinv = H_pinv.detach().to(device)
        self.bias = bias.detach().to(device)

    def forward(self, temp_x, x, edge_index, unlearn):
        """Return corrected logits and aggregated hidden-space influences.

        Args:
            temp_x: per-node hidden embeddings (``hidden_channels`` features).
            x: per-node logits of the original model.
            edge_index: ``[2, num_edges]`` COO edges; influences are summed onto ``row``.
            unlearn: ids of the unlearned nodes. Currently unused — the per-edge
                unlearning mask was computed but never applied, so it was removed.

        Returns:
            ``(x - w_fix * class_influence_sum, hidden_influence_sum)``.
        """
        row, col = edge_index
        x_cat_ij = torch.cat([temp_x[row], temp_x[col]], dim=-1)  # [num_edges, 2 * hidden_channels]

        influ_ji = self.influ_ji(x_cat_ij)  # second-layer (class-space) influence
        influ_ij = self.influ_ij(influ_ji)  # first-layer (hidden-space) influence

        aggregated_matrix = scatter_add(
            src=influ_ji,          # [num_edges, num_classes]
            index=row,             # source node of each edge
            dim=0,
            dim_size=x.size(0),    # one row per node
        )
        aggregated_matrix_1 = scatter_add(
            src=influ_ij,          # [num_edges, hidden_channels]
            index=row,
            dim=0,
            dim_size=x.size(0),
        )
        x = x - w_fix * aggregated_matrix
        return x, aggregated_matrix_1
    
# KLloss(input, target) with log_target=True computes KL(target || input) from
# log-probabilities; "mean" averages over every element (nodes x classes).
KLloss = torch.nn.KLDivLoss(reduction="mean", log_target=True)
KLloss_p = torch.nn.KLDivLoss(reduction='none', log_target=True)  # element-wise; weighted per node later

model_Unlearn = GIN_Unlearn(hidden_channels=32).to(device)
opt = torch.optim.Adam(model_Unlearn.parameters(), lr=0.01, weight_decay = 5e-4)  # Define optimizer.
x, edge_index, y = data.x, data.edge_index, data.y
unlearning_node = nodes_to_remove

# "Middle" reference: the original model evaluated with the current `test_loader`,
# which at this point is built from data_scratch.edge_index (unlearned nodes'
# edges removed).
model.eval()
temp_out, out = model.inference(data.x, test_loader, device)
middle_out = out
temp_middle_out = temp_out

test_f1_test_middle = f1_score(
    data.y[data.test_mask].cpu().numpy(),
    middle_out[data.test_mask].argmax(axis=1).cpu().numpy(),
    average="micro"
)

# Targets of the unlearning objective: middle logits of high-degree remaining
# training nodes, and the middle first-layer embeddings of all nodes.
original_neighbors = out[intersection_mask].to(device)
original_unlearned = temp_out.to(device)

results_unlearn = []
out = None
data_all_edge = dataset[0].edge_index.to(device)  # full graph, including the unlearned nodes' edges
# Both inputs are fixed outputs of the original model: detach them once so each
# backward pass stops at model_Unlearn instead of also traversing the original
# model's inference graph (the reason retain_graph=True was needed).
unlearn_hidden_in = temp_middle_out.detach().to(device)
unlearn_logits_in = middle_out.detach().to(device)

# Loss weights: the forgetting terms scale with the unlearning ratio.
kamma = 1 - un_ratio
tamma = un_ratio

start_time = time.time()
for i in range(10):
    model_Unlearn.train()
    opt.zero_grad()
    output, hidden_output = model_Unlearn.forward(unlearn_hidden_in, unlearn_logits_in, data_all_edge, unlearning_node)
    output_neighbors = output[intersection_mask].to(device)
    output_unlearned = hidden_output.to(device)

    # Step 3: degree-weighted KL divergence losses.
    # 3.1: keep high-degree remaining training nodes close to the middle logits.
    weighted_kl_loss_neighbors = KLloss_p(F.log_softmax(original_neighbors.detach(), dim=-1), F.log_softmax(output_neighbors, dim=-1)) * degree[intersection_mask].view(-1, 1).to(device)
    kl_loss_neighbors = weighted_kl_loss_neighbors.mean()

    # 3.2: KL divergence for unlearned nodes (hidden-space influences vs. middle embeddings).
    weighted_kl_loss_unlearned = KLloss_p(F.log_softmax(original_unlearned[unlearning_node].detach(), dim=-1), F.log_softmax(output_unlearned[unlearning_node], dim=-1)) * degree[unlearning_node].view(-1, 1).to(device)
    kl_loss_unlearned = weighted_kl_loss_unlearned.mean()

    # Step 4: negated cross-entropy (gradient ascent) on nodes outside the remaining
    # train/test sets, i.e. the unlearned nodes.
    classification_loss = -F.cross_entropy(F.log_softmax(output[~combined_mask], dim=-1), y[~combined_mask])

    # Step 5: combine losses.
    loss = tamma * (classification_loss + kl_loss_unlearned) + kamma * kl_loss_neighbors
    loss.backward()
    opt.step()

# CUDA kernels run asynchronously; wait for them so the timing covers the work.
if device.type == "cuda":
    torch.cuda.synchronize(device)
end_time = time.time()
print(f"Elapsed time: {end_time - start_time} seconds")
model_Unlearn.eval()
_data = data

# Inference only: no autograd graph over the full edge set is needed.
with torch.no_grad():
    output, _ = model_Unlearn(temp_middle_out.to(device), middle_out.to(device), data_all_edge, unlearning_node)
y_true = _data.y.cpu().unsqueeze(-1).to(device)
y_pred = output.argmax(dim=-1, keepdim=True).to(device)
results = []
# Train then test accuracy. NOTE: data.train_mask is the ORIGINAL training mask,
# so the "train" accuracy still includes the unlearned nodes.
for mask in [_data.train_mask, _data.test_mask]:
    results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]
print("==============results==============")
print(results)
results_unlearn = results

# Same accuracies for the uncorrected "middle" outputs.
y_pred = middle_out.argmax(dim=-1, keepdim=True).to(device)
results_middle = []
for mask in [_data.train_mask, _data.test_mask]:
    results_middle += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]

test_f1_test_unlearn = f1_score(
    data.y[data.test_mask].cpu().numpy(),
    output[data.test_mask].argmax(axis=1).cpu().numpy(),
    average="micro"
)

# Step 6: Evaluation
# Letter key used in the messages below: a = train-from-all logits (KL_xy_all),
# b = train-from-scratch logits (KL_xy_scratch), c = unlearned logits (output),
# d = middle logits (middle_out).
# KLloss(p, q) with log_target=True computes KL(q || p) averaged over elements.
# NOTE(review): the argument order is not consistent across these lines (the
# "a and b" line passes b first), so the direction of the divergence varies.
print("combined_mask: ")
print("num: ", i)  # `i` is the last unlearning-loop index
print("KL divergence of a and b:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach(), dim=-1), F.log_softmax(KL_xy_all[combined_mask].detach(), dim=-1))))
print("KL divergence of b and c:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))
print("KL divergence of a and c:{}".format(KLloss(F.log_softmax(KL_xy_all[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))
print("KL divergence of b and d:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach().to(device), dim=-1), F.log_softmax(middle_out[combined_mask].detach().to(device), dim=-1))))
print("KL divergence of d and c:{}".format(KLloss(F.log_softmax(middle_out[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))

# Same a-vs-c comparison over every node, including the unlearned ones.
print("all_mask: ")
print("KL divergence of a and c:{}".format(KLloss(F.log_softmax(KL_xy_all.detach().to(device), dim=-1), F.log_softmax(output.detach().to(device), dim=-1))))
out = output  # `out` is used as the unlearned model's logits from here on

filename = 'gat_unlearning_20250319.pdf'
visualize(out, color=data.y, filename=filename)
# Accuracy summary: [train, test] for each model.
print("**************train from all**************")
print(train_from_all_results)
print("**************train from scratch**************")
print(train_from_scratch_results)
print("**************middle-out**************")
print(results_middle)
print("**************unlearn**************")
print(results_unlearn)

# Micro-F1 on the test nodes for each model.
print("(((((((((((((f1-score))))))))))))))))")
print("**************train from all**************")
print(test_f1_train_from_all)
print("**************train from scratch**************")
print(test_f1_train_from_scratch)
print("**************middle-out**************")
print(test_f1_test_middle)
print("**************unlearn**************")
print(test_f1_test_unlearn)

# Membership-inference attack (MIA) evaluation code
def get_posteriors(model, data, test_loader, edge_weight, device='cpu'):
    """Return softmax posteriors of ``model`` for every node as a NumPy array.

    A ``GIN_Unlearn`` model is run on the module-level train-from-all outputs
    (``temp_KL_xy_all`` / ``KL_xy_all``) over ``data.edge_index``, with the
    module-level ``nodes_to_remove`` as its unlearning set. GCN/SGC models get
    ``edge_weight`` forwarded to ``inference``; any other model calls
    ``inference(x, loader, device)``.
    """
    model.eval()
    with torch.no_grad():
        if isinstance(model, GIN_Unlearn):
            logits = model(temp_KL_xy_all.to(device), KL_xy_all.to(device), data.edge_index, nodes_to_remove)[0]
        elif target_model in ('GCN', 'SGC'):
            logits = model.inference(data.x, test_loader, edge_weight, device)[1]
        else:
            logits = model.inference(data.x, test_loader, device)[1]
        probs = torch.softmax(logits, dim=1)
    return probs.cpu().numpy()

def mia_attack(original_probs, unlearn_model, data, train_mask, test_mask, test_loader, edge_weight=None, device='cpu'):
    """Distance-based membership-inference attack against an unlearned model.

    The attack feature of each node is the L2 distance between the original
    posteriors and ``unlearn_model``'s posteriors. Unlearned nodes (module-level
    ``nodes_to_remove``) are positives and test nodes are negatives. Both are
    subsampled with the global NumPy RNG to the same size.

    Args:
        original_probs: per-node posteriors of the original model (rows = nodes),
            as a NumPy array or a torch tensor.
        unlearn_model: model whose posteriors come from ``get_posteriors``.
        data: graph data forwarded to ``get_posteriors``.
        train_mask: unused; kept for call compatibility.
        test_mask: boolean node mask selecting the negative candidates.
        test_loader, edge_weight, device: forwarded to ``get_posteriors``.

    Returns:
        Standardised partial ROC AUC up to FPR 0.1 (sklearn ``max_fpr=0.1``).

    Raises:
        ValueError: if there is no unlearned node or no test node to sample.
    """
    unlearn_probs = get_posteriors(unlearn_model, data, test_loader, edge_weight, device)
    # Convert explicitly instead of relying on torch/NumPy mixed arithmetic.
    if torch.is_tensor(original_probs):
        original_probs = original_probs.detach().cpu().numpy()
    distances = np.linalg.norm(np.asarray(original_probs) - unlearn_probs, axis=1)

    unlearn_indices = nodes_to_remove.cpu().numpy()         # positives: unlearned nodes
    test_indices = torch.where(test_mask.cpu())[0].numpy()  # negatives: test nodes

    # Balance the two classes by subsampling both to the smaller size.
    min_samples = min(len(unlearn_indices), len(test_indices))
    if min_samples == 0:
        raise ValueError("MIA needs at least one unlearned node and one test node")
    unlearn_indices = np.random.choice(unlearn_indices, min_samples, replace=False)
    test_indices = np.random.choice(test_indices, min_samples, replace=False)

    print(f"Number of positive samples (unlearned nodes): {len(unlearn_indices)}")
    print(f"Number of negative samples (test nodes): {len(test_indices)}")

    features = np.concatenate([
        distances[unlearn_indices],  # positives: distances of unlearned nodes
        distances[test_indices],     # negatives: distances of test nodes
    ])
    labels = np.concatenate([
        np.ones(len(unlearn_indices)),   # label 1 for positives
        np.zeros(len(test_indices)),     # label 0 for negatives
    ])

    return roc_auc_score(labels, features, max_fpr=0.1)

# Posteriors of the three models. Only `model_probs` (train-from-all, on CPU) is
# used below; `unlearn_probs` and `scratch_probs` are not referenced again.
unlearn_probs= torch.softmax(out.detach(), dim=1).to(device)
model_probs  = torch.softmax(KL_xy_all.detach(), dim=1).to('cpu')
scratch_probs= torch.softmax(KL_xy_scratch.detach(), dim=1).to(device)

# Run the MIA attacks (reference: train-from-all posteriors; loader: full graph).
# NOTE(review): for GIN_Unlearn, get_posteriors feeds the train-from-all outputs
# (KL_xy_all), whereas the accuracy evaluation above fed the middle outputs.
print("\n============== MIA Attack Results Model_vs_Unlearn ==============")
mia_auc = mia_attack(
    original_probs=model_probs,
    unlearn_model=model_Unlearn,
    data=data,
    train_mask=data.train_mask,
    test_mask=data.test_mask,
    test_loader=test_loader_1,
    edge_weight=None,
    device=device
)
print(f"MIA Attack AUC: {mia_auc:.4f}")

# Baseline: the same attack against the model retrained from scratch.
print("\n============== MIA Attack Results Model_vs_Scratch ==============")
mia_auc = mia_attack(
    original_probs=model_probs,
    unlearn_model=model_scratch,
    data=data,
    train_mask=data.train_mask,
    test_mask=data.test_mask,
    test_loader=test_loader_1,
    edge_weight=None,
    device=device
)
print(f"MIA Attack AUC: {mia_auc:.4f}")

# KL divergence between the different models. The first value is restricted to
# combined_mask (remaining train + test nodes); the second covers every node.
print("\n============== KL Divergence Results ==============")
print("KL divergence between original and unlearned model:")
print(KLloss(F.log_softmax(KL_xy_all[combined_mask].detach().to(device), dim=-1), 
            F.log_softmax(out[combined_mask].detach().to(device), dim=-1)).item())

print("\nKL divergence between scratch and unlearned model:")
print(KLloss(F.log_softmax(KL_xy_scratch.detach().to(device), dim=-1), 
            F.log_softmax(out.detach().to(device), dim=-1)).item())