import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid,Coauthor
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_networkx
import networkx as nx
import random
from torch.nn import Linear
from sklearn.metrics import f1_score
import time
from lib_utils import utils
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_scatter import scatter_add
import argparse
parser = argparse.ArgumentParser(description="GIN Unlearning Attack")
parser.add_argument("--ur", type=float, default=0.3, help="Unlearning ratio")

args = parser.parse_args()
print(torch.__version__)
# Prefer GPU index 3 (the original setup) but fall back to the default GPU when
# the host has fewer than four devices: torch.device("cuda:3") is accepted
# lazily and would only fail later, at the first `.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

r = 100                # number of training epochs for each GIN model
un_ratio = args.ur     # fraction of training nodes to unlearn
batch_size = 512       # target nodes per NeighborSampler training batch
target_model = 'GIN'
# Load the CiteSeer dataset; swap in one of the commented lines for Cora / Coauthor-CS.
dataset = Planetoid(root='data/citeseer', name="citeseer")  # root: storage path, name: dataset name
# dataset = Planetoid(root='data/Cora', name="Cora")
# dataset = Coauthor(root='data/CS', name="CS")
print(un_ratio)
# Basic statistics of the graph.
first_graph = dataset[0]
print("网络数据包含的类数量:", dataset.num_classes)
print("网络数据边的特征数量:", dataset.num_edge_features)
print("网络数据边的数量:", first_graph.edge_index.shape[1] / 2)  # halved: undirected edges are stored once per direction in the COO edge_index
print("网络数据节点的特征数量:", dataset.num_node_features)
print("网络数据节点的数量:", first_graph.x.shape[0])
print("网络节点标签的数量:", len(first_graph.y))

# visualization
def visualize(out, color, filename):
    """Project node representations to 2-D with t-SNE, plot them, and save the plot.

    Args:
        out: tensor of per-node representations (any device); detached and
            moved to the CPU before t-SNE.
        color: per-node values used to colour the points (e.g. labels).
        filename: path passed to ``plt.savefig``.
    """
    z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())
    fig = plt.figure(figsize=(10, 10))
    plt.grid(True, linestyle='--', color='gray', linewidth=0.5)
    plt.scatter(z[:, 0], z[:, 1], s=18, c=color.cpu(), alpha=0.9, cmap="Set2")
    # Save *before* show(): show() may close/destroy the current figure, after
    # which savefig would write an empty canvas.
    plt.savefig(filename, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    
def evaluate_model(model, _data, filename, test_loader):
    """Run full-graph inference, save a t-SNE plot, and report accuracies.

    Args:
        model: network exposing ``inference(x, loader, [edge_weight,] device)``
            that returns ``(hidden, scores)``.
        _data: graph object with ``x``, ``y``, ``train_mask`` and ``test_mask``.
        filename: output path for the t-SNE plot (see ``visualize``).
        test_loader: loader passed through to ``model.inference``.

    Returns:
        ``[train_accuracy, test_accuracy]``: the fraction of nodes under each
        mask whose argmax prediction equals ``_data.y``. An empty mask yields
        ``nan`` instead of raising ZeroDivisionError.
    """
    model.eval()
    if target_model in ['GCN', 'SGC']:
        # NOTE(review): `edge_weight` is not defined in this script, so this
        # branch raises NameError unless a GCN/SGC setup defines it globally.
        _, out = model.inference(_data.x, test_loader, edge_weight, device)
    else:
        _, out = model.inference(_data.x, test_loader, device)
    visualize(out, color=_data.y, filename=filename)

    y_true = _data.y.unsqueeze(-1).to(device)
    y_pred = out.argmax(dim=-1, keepdim=True).to(device)
    results = []
    for mask in [_data.train_mask, _data.test_mask]:  # train accuracy, then test accuracy
        n_selected = int(mask.sum())
        if n_selected == 0:
            results.append(float('nan'))
            continue
        results.append(int(y_pred[mask].eq(y_true[mask]).sum()) / n_selected)
    return results

# Move the first graph of the dataset to the target device.
data= dataset[0].to(device)

# Random 90/10 train/test split over all nodes (fixed random_state for reproducibility).
# NOTE(review): the masks are assigned after `.to(device)`, so they are CPU
# tensors even when `device` is a GPU.
train_indices, test_indices = train_test_split(np.arange((data.num_nodes)), test_size=0.1, random_state=100)
data.train_mask = torch.from_numpy(np.isin(np.arange(data.num_nodes), train_indices))
data.test_mask = torch.from_numpy(np.isin(np.arange(data.num_nodes), test_indices))

# NOTE(review): this is an alias, not a copy. `data_scratch`, `data` and
# `train_mask` below refer to the same objects, so the later in-place updates
# made through `data_scratch` / `train_mask` (pruned train mask, pruned
# edge_index) also change `data`. Confirm this is intended.
data_scratch = data
train_mask = data_scratch.train_mask                       # the training mask
train_nodes = torch.nonzero(train_mask).squeeze(1) # indices of nodes where train_mask is True
num_nodes_to_remove = int(len(train_nodes) * un_ratio)  # number of training nodes to unlearn (fraction un_ratio of the training nodes)

random.seed(2532)
nodes_to_remove = random.sample(list(train_nodes), num_nodes_to_remove) # randomly choose the nodes to unlearn
nodes_to_remove = torch.tensor(nodes_to_remove, dtype=torch.long).to(device)
# print("nodes_to_remove", nodes_to_remove)
# print("labels for nodes to remove:", data.y[nodes_to_remove])
# Corrupt the labels of the selected nodes: shift each one to the next class (mod num_classes).
labels_to_modify = data.y[nodes_to_remove]
modified_labels = (labels_to_modify + 1) % dataset.num_classes # shifted labels
data.y[nodes_to_remove] = modified_labels  # write the shifted labels back in place

# Print the modified labels (debugging aid)
# print("Updated labels for nodes to remove:", data.y[nodes_to_remove])

# train from all
# batch version
class GINNet(torch.nn.Module):
    """Two-layer GIN node classifier (the "train from all" model).

    Layers: GINConv -> ReLU -> dropout -> BatchNorm, GINConv -> BatchNorm,
    then fc1 -> ReLU -> dropout -> fc2. Input/output widths come from the
    module-level ``dataset``; the hidden width is 32.
    """

    def __init__(self):
        super(GINNet, self).__init__()

        dim = 32  # hidden width of every layer
        self.num_layers = 2

        # MLPs applied inside each GIN aggregation.
        nn1 = Sequential(Linear(dataset.num_features, dim), ReLU(), Linear(dim, dim))
        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))

        self.convs = torch.nn.ModuleList()
        self.convs.append(GINConv(nn1))
        self.convs.append(GINConv(nn2))

        self.bn = torch.nn.ModuleList()
        self.bn.append(torch.nn.BatchNorm1d(dim))
        self.bn.append(torch.nn.BatchNorm1d(dim))

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, dataset.num_classes)

    def forward(self, x, adjs):
        """Mini-batch forward pass over NeighborSampler blocks.

        Args:
            x: features of all sampled nodes, in ``n_id`` order.
            adjs: per-layer ``(edge_index, e_id, size)`` tuples; the first
                ``size[1]`` rows of the current ``x`` are that layer's targets.

        Returns:
            ``log_softmax`` class scores for the batch's target nodes.
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            x = self.bn[i](x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise full-graph inference (no dropout).

        Runs under ``torch.no_grad()``: the results are used only as fixed
        references, so no autograd graph over the whole graph is recorded.
        The caller must still call ``eval()`` first; otherwise BatchNorm uses
        per-batch statistics and updates its running stats.

        Args:
            x_all: node features for all nodes.
            subgraph_loader: one-hop loader yielding ``(batch_size, n_id, adj)``.
            device: device on which each batch is computed.

        Returns:
            ``(x_last, x_all)``: the first layer's output for all nodes (the
            input to the final GIN layer) and the raw ``fc2`` scores (no
            log-softmax).
        """
        x_last = []
        for i in range(self.num_layers):
            xs = []
            if i == 1:
                x_last = x_all  # keep the first layer's output (input to the last GIN layer)
            for _, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                x = self.bn[i](x)
                xs.append(x)
            x_all = torch.cat(xs, dim=0)
        x_all = F.relu(self.fc1(x_all))
        x_all = self.fc2(x_all)
        return x_last, x_all

    def reset_parameters(self):
        """Re-initialise every learnable layer (convs, BatchNorms, fc1, fc2)."""
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        for i in range(self.num_layers):
            self.convs[i].reset_parameters()
            self.bn[i].reset_parameters()
            
model = GINNet().to(device)
criterion = torch.nn.CrossEntropyLoss()                                   # Define loss criterion.
opt = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0.0001)  # Define optimizer.

train_indices = np.nonzero(data.train_mask.cpu().numpy())[0]
edge_index = utils.filter_edge_index(data.edge_index, train_indices, reindex=False)

# Fallback so NeighborSampler never receives an empty edge list.
if edge_index.shape[1] == 0:
    edge_index = torch.tensor([[1, 2], [2, 1]])

train_loader = NeighborSampler(
    edge_index, node_idx=data.train_mask,
    sizes=[5, 5], num_nodes=data.num_nodes,
    batch_size=batch_size, shuffle=True,
    num_workers=0)

# One-hop, full-neighbourhood loader over all nodes for layer-wise inference.
test_loader = NeighborSampler(
            data.edge_index, node_idx=None,
            sizes=[-1], num_nodes=data.num_nodes,
            batch_size=64, shuffle=False,
            num_workers=0)

model.reset_parameters()
for epoch in range(r):
    model.train()
    # `n_targets` (not `batch_size`): reusing that name here would overwrite the
    # global batch_size=512 with the last batch's size, and the scratch loader
    # later reads the global.
    for n_targets, n_id, adjs in train_loader:
        adjs = [adj.to(device) for adj in adjs]
        opt.zero_grad()

        if target_model in ['GCN', 'SGC']:
            # NOTE(review): `edge_weight` is not defined in this script.
            out = model(data.x[n_id], adjs, edge_weight)
        else:
            out = model(data.x[n_id], adjs)

        loss = F.nll_loss(out, data.y[n_id[:n_targets]])
        loss.backward()
        opt.step()

# Inference must run in eval mode: in train mode BatchNorm would use per-batch
# statistics and overwrite its running stats before evaluate_model runs.
model.eval()
temp_out, out = model.inference(data.x, test_loader, device)
KL_xy_all = out
temp_KL_xy_all = temp_out
filename = 'gin_train_from_all_batch.pdf'
train_from_all_results = evaluate_model(model, data, filename, test_loader)
# train from all
    
# train from scratch
class GINNet_scratch(torch.nn.Module):
    """Two-layer GIN node classifier retrained from scratch without the unlearned nodes.

    Architecturally identical to ``GINNet``: GINConv -> ReLU -> dropout ->
    BatchNorm, GINConv -> BatchNorm, then fc1 -> ReLU -> dropout -> fc2.
    Input/output widths come from the module-level ``dataset``; hidden width 32.
    """

    def __init__(self):
        super(GINNet_scratch, self).__init__()

        dim = 32  # hidden width of every layer
        self.num_layers = 2

        # MLPs applied inside each GIN aggregation.
        nn1 = Sequential(Linear(dataset.num_features, dim), ReLU(), Linear(dim, dim))
        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))

        self.convs = torch.nn.ModuleList()
        self.convs.append(GINConv(nn1))
        self.convs.append(GINConv(nn2))

        self.bn = torch.nn.ModuleList()
        self.bn.append(torch.nn.BatchNorm1d(dim))
        self.bn.append(torch.nn.BatchNorm1d(dim))

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, dataset.num_classes)

    def forward(self, x, adjs):
        """Mini-batch forward pass over NeighborSampler blocks.

        Args:
            x: features of all sampled nodes, in ``n_id`` order.
            adjs: per-layer ``(edge_index, e_id, size)`` tuples; the first
                ``size[1]`` rows of the current ``x`` are that layer's targets.

        Returns:
            ``log_softmax`` class scores for the batch's target nodes.
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            x = self.bn[i](x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader, device):
        """Layer-wise full-graph inference (no dropout).

        Runs under ``torch.no_grad()``: the results are used only as fixed
        references, so no autograd graph over the whole graph is recorded.
        The caller must still call ``eval()`` first; otherwise BatchNorm uses
        per-batch statistics and updates its running stats.

        Args:
            x_all: node features for all nodes.
            subgraph_loader: one-hop loader yielding ``(batch_size, n_id, adj)``.
            device: device on which each batch is computed.

        Returns:
            ``(x_last, x_all)``: the first layer's output for all nodes (the
            input to the final GIN layer) and the raw ``fc2`` scores (no
            log-softmax).
        """
        x_last = []
        for i in range(self.num_layers):
            xs = []
            if i == 1:
                x_last = x_all  # keep the first layer's output (input to the last GIN layer)
            for _, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                x = self.bn[i](x)
                xs.append(x)
            x_all = torch.cat(xs, dim=0)
        x_all = F.relu(self.fc1(x_all))
        x_all = self.fc2(x_all)
        return x_last, x_all

    def reset_parameters(self):
        """Re-initialise every learnable layer (convs, BatchNorms, fc1, fc2)."""
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        for i in range(self.num_layers):
            self.convs[i].reset_parameters()
            self.bn[i].reset_parameters()
            
model_scratch = GINNet_scratch().to(device)
criterion = torch.nn.CrossEntropyLoss()                                   # Define loss criterion.
opt = torch.optim.Adam(model_scratch.parameters(), lr=0.05, weight_decay=0.0001)  # Define optimizer.

# Drop the unlearned nodes from the training set. `train_mask` is a CPU tensor
# (built from NumPy) while `nodes_to_remove` lives on `device`; indexing a CPU
# tensor with CUDA indices raises a RuntimeError, so move the indices first.
train_mask[nodes_to_remove.to(train_mask.device)] = False
data_scratch.train_mask = train_mask                                        # write back onto the data object
combined_mask = torch.logical_or(data_scratch.train_mask, data_scratch.test_mask).to(device)
train_indices = np.nonzero(data_scratch.train_mask.cpu().numpy())[0]        # remaining training indices

# Keep only the edges that touch no unlearned node.
edge_index_0 = data_scratch.edge_index
left_mask  = torch.any(torch.eq(edge_index_0[0].unsqueeze(1), nodes_to_remove), dim=1)
right_mask = torch.any(torch.eq(edge_index_0[1].unsqueeze(1), nodes_to_remove), dim=1)
mask = ~(left_mask | right_mask)
data_scratch.edge_index = edge_index_0[:, mask]

# Presumably keeps edges among the remaining training nodes (utils.filter_edge_index — confirm).
edge_index = utils.filter_edge_index(data_scratch.edge_index, train_indices, reindex=False)

num_true_nodes = torch.sum(data_scratch.train_mask).item()  # number of remaining training nodes

# Fallback so NeighborSampler never receives an empty edge list.
if edge_index.shape[1] == 0:
    edge_index = torch.tensor([[1, 2], [2, 1]])

train_loader_scratch = NeighborSampler(
    edge_index, node_idx=data_scratch.train_mask,
    sizes=[5, 5], num_nodes=data_scratch.num_nodes,
    batch_size=batch_size, shuffle=True,
    num_workers=0)

# One-hop, full-neighbourhood loader over the pruned graph for layer-wise inference.
test_loader = NeighborSampler(
            data_scratch.edge_index, node_idx=None,
            sizes=[-1], num_nodes=data.num_nodes,
            batch_size=64, shuffle=False,
            num_workers=0)

model_scratch.reset_parameters()
for epoch in range(r):
    model_scratch.train()
    # `n_targets` rather than `batch_size` so the global batch size is not overwritten.
    for n_targets, n_id, adjs in train_loader_scratch:
        adjs = [adj.to(device) for adj in adjs]
        opt.zero_grad()
        if target_model in ['GCN', 'SGC']:
            # NOTE(review): `edge_weight` is not defined in this script.
            out = model_scratch(data_scratch.x[n_id], adjs, edge_weight)
        else:
            out = model_scratch(data_scratch.x[n_id], adjs)
        loss = F.nll_loss(out, data_scratch.y[n_id[:n_targets]])
        loss.backward()
        opt.step()

# Inference must run in eval mode so BatchNorm uses (and keeps) its running stats.
model_scratch.eval()
temp_out, out = model_scratch.inference(data.x, test_loader, device)
KL_xy_scratch = out
temp_KL_xy_scratch = temp_out
filename = 'gin_train_from_scratch_batch.pdf'
train_from_scratch_results = evaluate_model(model_scratch, data_scratch, filename, test_loader)
# train from scratch

# unlearning
# Per-node degree on the original (unpruned) graph. Every edge adds 1 to both
# of its endpoints (a self-loop adds 2), exactly as the former per-edge Python
# loop did, but vectorised with bincount.
num_all_nodes = dataset[0].x.shape[0]
edge_index = dataset[0].edge_index
degree = (torch.bincount(edge_index[0], minlength=num_all_nodes)
          + torch.bincount(edge_index[1], minlength=num_all_nodes))

# Mark the highest-degree (1 - ratio) fraction of nodes: the top 20% for
# ratio = 0.8 (the "10" in the variable names is historical).
ratio = 0.8
top_10_percent_threshold = torch.quantile(degree.float(), ratio)  # not used further in this script
top_10_percent_indices = torch.argsort(degree, descending=True)[:int(num_all_nodes * (1-ratio))]
top_10_mask = torch.zeros(num_all_nodes, dtype=torch.bool)
top_10_mask[top_10_percent_indices] = True
# High-degree nodes that are still in the (pruned) training mask.
intersection_mask = top_10_mask & train_mask

class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> Sigmoid -> Linear.

    The activation attribute is named ``relu`` for historical reasons but holds
    a ``torch.nn.Sigmoid``; the name is kept so module/attribute names stay the same.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.Sigmoid()  # despite the name, a sigmoid activation
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

model.eval()                       # evaluation mode
# Final classifier layer (fc2) of the train-from-all model.
conv2_linear_layer = model.fc2
# GIN_Unlearn treats these as frozen constants, so detach before the
# pseudo-inverse instead of recording an autograd graph through it.
H = conv2_linear_layer.weight.detach().to(device)  # shape: [out_features, in_features]
H_pinv = torch.pinverse(H).to(device)               # shape: [in_features, out_features]
bias = conv2_linear_layer.bias                      # shape: [out_features], or None

if bias is None:
    bias = torch.zeros(dataset.num_classes).to(device)

degree = degree.to(device)
nodes_to_remove = nodes_to_remove.to(device)
# edge_n: mean degree of the removed nodes. Guard the "no removed nodes" and
# "zero total degree" cases, which would otherwise make w_fix NaN/inf and
# poison every output it scales.
if num_nodes_to_remove > 0:
    edge_n = degree[nodes_to_remove].sum() / num_nodes_to_remove
else:
    edge_n = torch.tensor(0.0, device=device)
w_fix = 1 + 1 / edge_n if edge_n > 0 else torch.tensor(1.0, device=device)

# hadamard_edge_n = degree[nodes_to_remove]/num_nodes_to_remove
# hadamard_w_fix = 1 + 1/hadamard_edge_n
from torch_geometric.nn.conv.gcn_conv import GCNConv
class GIN_Unlearn(torch.nn.Module):
    """Learned correction that subtracts unlearned nodes' influence from frozen model outputs.

    Two small MLPs estimate, per edge, an influence in class space
    (``influ_ji``, from concatenated endpoint features) and in hidden space
    (``influ_ij``). The class-space influence of edges whose target (``col``)
    is an unlearned node is summed onto the source node (``row``) and
    subtracted, scaled by the module-level ``w_fix``. The hidden-space estimate
    is mapped through the pseudo-inverse of the frozen classifier weight ``H``
    plus a projection onto ``H``'s null space.

    Uses module-level ``dataset``, ``H``, ``H_pinv``, ``bias``, ``device`` and
    ``w_fix``. ``hidden_channels`` must equal ``H``'s in_features, since
    ``influ_ij`` (width ``hidden_channels``) is multiplied by a square matrix of
    that size.
    """

    def __init__(self, hidden_channels):  # hidden_channels: width of the first GIN layer output
        super().__init__()
        # Fixed seed so the two MLPs below are initialised reproducibly.
        torch.manual_seed(1234567)
        self.hidden_channels = hidden_channels

        # Class space -> hidden space ("plus" term).
        self.influ_ij = MLP(dataset.num_classes, int((dataset.num_classes * hidden_channels) ** 0.5), hidden_channels)
        # Concatenated endpoint features -> class space ("minus" term).
        self.influ_ji = MLP(2 * hidden_channels, int((2*hidden_channels*dataset.num_classes) ** 0.5), dataset.num_classes)
        # Frozen copies of the classifier weight, its pseudo-inverse and bias.
        # NOTE: plain attributes (not buffers), so Module.to() does not move them.
        self.H = H.detach().to(device)
        self.H_pinv = H_pinv.detach().to(device)
        self.bias = bias.detach().to(device)

    def forward(self, temp_x, x, edge_index, unlearn):
        """Return ``(corrected_x, hidden_influence)``.

        Args:
            temp_x: per-node hidden features used to build edge features
                (the call site passes the frozen model's first-layer output).
            x: per-node class scores to correct (the call site passes the
                frozen model's fc2 output).
            edge_index: ``[2, E]`` tensor of (row, col) node indices.
            unlearn: 1-D tensor of unlearned node indices.

        Returns:
            ``x - w_fix * A``, where ``A`` sums, per source node, the class-space
            influence of its edges that point to unlearned nodes; and the
            per-node sum of the hidden-space influence over all outgoing edges.
        """
        row, col = edge_index
        x_cat_ij = torch.cat([temp_x[row], temp_x[col]], dim=-1)  # [num_edges, 2 * in_channels]

        influ_ji = self.influ_ji(x_cat_ij)  # class-space (second-layer) influence
        influ_ij = self.influ_ij(influ_ji)  # hidden-space (first-layer) influence

        # Range-space component via H^+ plus the null-space projector (I - H^+ H).
        eye = torch.eye(self.H_pinv.size(0), device=self.H_pinv.device)
        null_space = eye - self.H_pinv @ self.H
        influ_ij_rnd = (influ_ji - self.bias) @ self.H_pinv.T + influ_ij @ null_space

        # Keep ONLY edges whose target is an unlearned node; all others contribute zero.
        mask = (col.unsqueeze(1) == unlearn).any(dim=1)
        minus_influ = torch.zeros_like(influ_ji)
        minus_influ[mask] = influ_ji[mask]

        # Sum per source node; dim_size keeps one row for every node.
        aggregated_matrix = scatter_add(src=minus_influ, index=row, dim=0, dim_size=x.size(0))
        aggregated_matrix_1 = scatter_add(src=influ_ij_rnd, index=row, dim=0, dim_size=x.size(0))

        x = x - w_fix * aggregated_matrix
        return x, aggregated_matrix_1
    
KLloss = torch.nn.KLDivLoss(reduction="mean", log_target=True)
KLloss_p = torch.nn.KLDivLoss(reduction='none', log_target=True)

model_Unlearn = GIN_Unlearn(hidden_channels=32).to(device)
opt = torch.optim.Adam(model_Unlearn.parameters(), lr=0.01, weight_decay = 5e-4)  # Define optimizer.

x, edge_index, y = data.x, data.edge_index, data.y
unlearning_node = nodes_to_remove

# Reference outputs of the train-from-all model, computed in eval mode with the
# loader built over the pruned graph.
model.eval()
temp_out, out = model.inference(data.x, test_loader, device)
middle_out = out
original_neighbors = out[intersection_mask].to(device)
original_unlearned = temp_out.to(device)

# Loop-invariant inputs. The frozen-model tensors are detached so backward()
# only traverses this iteration's model_Unlearn graph; previously it also
# re-entered the frozen model's inference graph, which required
# retain_graph=True and accumulated stray gradients in `model`.
frozen_hidden = temp_KL_xy_all.detach().to(device)
frozen_scores = middle_out.detach().to(device)
data_all_edge = dataset[0].edge_index.to(device)  # full (unpruned) graph
neighbor_weights = degree[intersection_mask].view(-1, 1).to(device)

# Loss weights.
kamma = 0.7  # weight of the combined KL terms
tamma = 0.3  # weight of the classification (gradient-ascent) term
alpha = 0.4  # neighbour KL weight inside the KL mix
beta = 0.5   # hidden-representation KL weight inside the KL mix

# Accuracy of the uncorrected reference outputs ([train, test]); independent of the loop.
y_true = data.y.unsqueeze(-1).to(device)
y_pred = middle_out.argmax(dim=-1, keepdim=True).to(device)
results_middle = [
    int(y_pred[m].eq(y_true[m]).sum()) / int(m.sum())
    for m in [data.train_mask, data.test_mask]
]

results_unlearn = []
out = None
for i in range(2):
    model_Unlearn.train()
    opt.zero_grad()
    output, hidden_output = model_Unlearn(frozen_hidden, frozen_scores, data_all_edge, unlearning_node)

    output_neighbors = output[intersection_mask].to(device)
    output_unlearned = hidden_output.to(device)

    # Step 3: KL divergence losses.
    # NOTE(review): KLDivLoss(input, target) computes KL(target || input); with
    # the frozen reference passed as `input`, these are KL(current || reference).
    # Confirm this direction is intended.
    # 3.1: degree-weighted KL on high-degree training nodes.
    weighted_kl_loss_neighbors = KLloss_p(F.log_softmax(original_neighbors.detach(), dim=-1), F.log_softmax(output_neighbors, dim=-1)) * neighbor_weights
    kl_loss_neighbors = weighted_kl_loss_neighbors.mean()
    print("kl_loss_neighbors")
    print(kl_loss_neighbors.item())

    # 3.2: KL on the hidden representation of every node.
    weighted_kl_loss_unlearned = KLloss_p(F.log_softmax(original_unlearned.detach(), dim=-1), F.log_softmax(output_unlearned, dim=-1))
    kl_loss_unlearned = weighted_kl_loss_unlearned.mean()
    print("kl_loss_unlearned")
    print(kl_loss_unlearned.item())

    # Step 4: ~combined_mask selects nodes in neither the pruned train mask nor
    # the test mask, i.e. the removed nodes; the negated cross-entropy performs
    # gradient ascent on their (corrupted) labels. Applying log_softmax before
    # cross_entropy is redundant but harmless (log_softmax is idempotent).
    classification_loss = -F.cross_entropy(F.log_softmax(output[~combined_mask], dim=-1), y[~combined_mask])
    print("classification_loss")
    print(classification_loss.item())

    # Step 5: combine losses and update.
    loss = tamma * classification_loss + kamma * (alpha * kl_loss_neighbors + beta * kl_loss_unlearned)
    loss.backward()
    opt.step()

    # Evaluate the corrected outputs after this step ([train, test] accuracy).
    model_Unlearn.eval()
    with torch.no_grad():
        output, _ = model_Unlearn(frozen_hidden, frozen_scores, data_all_edge, unlearning_node)
    y_pred = output.argmax(dim=-1, keepdim=True).to(device)
    results = [
        int(y_pred[m].eq(y_true[m]).sum()) / int(m.sum())
        for m in [data.train_mask, data.test_mask]
    ]
    print("==============results==============")
    print(results)
    results_unlearn = results

    # Step 6: divergences between a = train-from-all, b = train-from-scratch,
    # c = unlearned output, d = reference (middle) output.
    print("combined_mask: ")
    print("num: ", i)
    print("KL divergence of a and b:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach(), dim=-1), F.log_softmax(KL_xy_all[combined_mask].detach(), dim=-1))))
    print("KL divergence of b and c:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))
    print("KL divergence of a and c:{}".format(KLloss(F.log_softmax(KL_xy_all[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))
    print("KL divergence of b and d:{}".format(KLloss(F.log_softmax(KL_xy_scratch[combined_mask].detach().to(device), dim=-1), F.log_softmax(middle_out[combined_mask].detach().to(device), dim=-1))))
    print("KL divergence of d and c:{}".format(KLloss(F.log_softmax(middle_out[combined_mask].detach().to(device), dim=-1), F.log_softmax(output[combined_mask].detach().to(device), dim=-1))))

    print("all_mask: ")
    print("KL divergence of a and c:{}".format(KLloss(F.log_softmax(KL_xy_all.detach().to(device), dim=-1), F.log_softmax(output.detach().to(device), dim=-1))))
    out = output

filename = 'gin_unlearning_20250319.pdf'
visualize(out, color=data.y, filename=filename)
print("**************train from all**************")
print(train_from_all_results)
print("**************train from scratch**************")
print(train_from_scratch_results)
print("**************middle-out**************")
print(results_middle)
print("**************unlearn**************")
print(results_unlearn)