import os

import numpy as np
import torch.nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.datasets import Planetoid,OGB_MAG,Coauthor
from torch_geometric.nn import GCNConv,GNNFF, aggr
from torch_geometric.transforms import AddSelfLoops, NormalizeFeatures

from  Attack_Algorithm import  LinfPGDAttack
from Data_Split import random_coauthor_amazon_splits
from  Gnn_Model import GCN_Res,GCN,SGCN,GCNII

def model_train(model, data, Features, *, opt=None, loss_fn=None):
    """Run one full-batch training step on the nodes selected by ``data.train_mask``.

    ``data.x`` is overwritten with ``Features`` before the forward pass.

    Args:
        model: module called as ``model(x, edge_index)``.
        data: graph object exposing ``edge_index``, ``y`` and ``train_mask``.
        Features: node-feature matrix to train on (e.g. adversarial features).
        opt: optimizer to step; defaults to the module-level ``optimizer``.
        loss_fn: loss called as ``loss_fn(logits, labels)``; defaults to the
            module-level ``criterion``.

    Returns:
        The loss tensor for this step, computed with the pre-update weights.
    """
    # Fall back to the script-level globals so existing call sites keep working.
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    data.x = Features
    model.train()
    opt.zero_grad()
    out = model(data.x, data.edge_index)
    # The loss already lives on the model's output device; no extra move needed.
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    opt.step()

    return loss


def model_train_regular(model, data, Features, *, opt=None, loss_fn=None,
                        param_name='conv1.lin.weight', reg_weight=0.05):
    """Run one training step with an L1 penalty on one named parameter.

    The loss is ``loss_fn(out[train_mask], y[train_mask]) + reg_weight * sum(|W|)``,
    where ``W`` is the parameter named ``param_name``. ``data.x`` is overwritten
    with ``Features`` before the forward pass.

    Args:
        model: module called as ``model(x, edge_index)``.
        data: graph object exposing ``edge_index``, ``y`` and ``train_mask``.
        Features: node-feature matrix to train on.
        opt: optimizer to step; defaults to the module-level ``optimizer``.
        loss_fn: loss called as ``loss_fn(logits, labels)``; defaults to the
            module-level ``criterion``.
        param_name: qualified name of the parameter to penalise
            (e.g. ``'conv1.lin.weight'`` or ``'linfea.weight'``).
        reg_weight: coefficient of the L1 term.

    Returns:
        The total (data + L1) loss tensor, computed with the pre-update weights.

    Raises:
        KeyError: if ``model`` has no parameter called ``param_name``.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    data.x = Features
    model.train()
    opt.zero_grad()
    out = model(data.x, data.edge_index)
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    # Use the live Parameter: state_dict() returns detached tensors, which
    # would add a constant to the loss and contribute no gradient at all.
    param = dict(model.named_parameters())[param_name]
    loss = loss + reg_weight * param.abs().sum()
    loss.backward()
    opt.step()

    return loss


def model_val(model, data, Features, flag=0, *, mask=None):
    """Return node-classification accuracy of ``model`` on ``Features``.

    ``data.x`` is overwritten with ``Features`` and the model is left in eval mode.

    Args:
        model: module called as ``model(x, edge_index)``; predictions are the
            argmax over dim 1 of its output.
        data: graph object exposing ``edge_index``, ``y``, ``test_mask`` and
            ``train_mask``.
        Features: node-feature matrix to evaluate on.
        flag: ``0`` scores ``data.test_mask``; any other value scores
            ``data.train_mask``.
        mask: optional explicit node mask; when given it overrides ``flag``.

    Returns:
        Fraction of masked nodes whose prediction equals ``data.y``. An empty
        mask raises ``ZeroDivisionError``.
    """
    data.x = Features
    model.eval()
    # Inference only: avoid building an autograd graph over the full feature matrix.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    if mask is None:
        mask = data.test_mask if flag == 0 else data.train_mask
    correct = pred[mask] == data.y[mask]
    return int(correct.sum()) / int(mask.sum())

def train_num(dataset, rate):   #  split label rate
    """Grow ``dataset.train_mask`` in place until it covers about ``rate`` of all nodes.

    Nodes not yet in the mask are added in index order. The target count is
    ``int(num_nodes * rate - current_count)``. If that is not positive, the mask
    is left unchanged. If it exceeds the number of nodes still outside the mask,
    every remaining node is added.

    Args:
        dataset: object exposing a 1-D ``train_mask`` tensor.
        rate: target fraction of labelled nodes; ``0`` leaves the mask untouched.

    Returns:
        The (same, mutated) ``dataset.train_mask`` tensor.
    """
    mask = dataset.train_mask
    if rate == 0:
        return mask
    count = int(torch.count_nonzero(mask))
    num = mask.shape[0]
    num_change = int(num * rate - count)
    if num_change <= 0:
        return mask

    # Positions currently outside the mask, in ascending order; slicing clamps
    # to however many are available instead of raising IndexError.
    candidates = torch.nonzero(mask.logical_not(), as_tuple=True)[0][:num_change]
    mask[candidates] = True
    return mask



# Prefer the GPU but fall back to CPU so the script can start without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

runs = 10      # independent training runs per epsilon
epochs = 600   # training epochs per run
criterion = torch.nn.CrossEntropyLoss()

# L-infinity budgets for the PGD attack; 0.0 is the unattacked baseline.
epsilon_list = [0.0, 0.002, 0.004, 0.006, 0.008, 0.010, 0.012, 0.014, 0.016]

# Per-epsilon summaries over `runs` (mean and std of each):
#   train_acc / train_std  - train-set accuracy on the adversarial training features
#   ad_ave_acc / ad_std_acc - test accuracy under attack
train_acc = []
train_std = []
ad_ave_acc = []
ad_std_acc = []


for eps in epsilon_list:
    # A fresh random split is drawn for every epsilon.
    physics = Coauthor(name='Physics', root='./Dataset/Physics')
    # NOTE(review): the second argument (8) is passed through unchanged; if it is
    # a class count, Coauthor-Physics has 5 -- confirm against Data_Split.
    data = random_coauthor_amazon_splits(physics[0], 8)
    print(data)

    # Read sizes from the dataset instead of hard-coding them
    # (Coauthor-Physics: 8415 features, 5 classes).
    num_classes = physics.num_classes
    num_features = physics.num_features

    # Move the graph to the device once. model_train / model_val overwrite
    # data.x with whatever matrix they are given, so keep a handle on the clean
    # features and always attack those -- attacking data.x would stack
    # perturbations on top of earlier ones, within a run and across runs.
    # (Assumes Attack.pertub does not modify its input in place -- confirm.)
    data = data.to(device)
    clean_x = data.x

    run_acc = []
    att_acc = []
    for run_i in range(runs):
        model = GCN(num_features=num_features, num_classes=num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        # NOTE(review): if alpha is the per-step size, k * alpha = 5 * eps / 128,
        # so the iterative steps alone move far less than eps -- confirm intended.
        Attack = LinfPGDAttack(model=model, loss=criterion, epsilon=eps, alpha=float(eps / 128), k=5,
                               random_start=True, batch_size=32)
        print('*' * 50)
        print(f'Train Run: {run_i: d}, Eps: {eps: .4f} ')
        print('*' * 50)

        # train_mask = train_num(data, rate=0.1)   # different rate

        # The adversarial training features are generated once, before the
        # epoch loop, against the freshly initialised model.
        Adv_Fea_Matrix_train = Attack.pertub(input=clean_x, edges=data.edge_index, label=data.y,
                                             mask=data.train_mask)

        for epoch_i in range(epochs):
            loss = model_train(model, data, Adv_Fea_Matrix_train.to(device))
            print(f'Epoch: {epoch_i: d}, Loss: {loss: .4f} ')

        # Train-set accuracy (flag=1 selects data.train_mask) on the same
        # adversarial features the model was trained on.
        run_acc.append(model_val(model, data, Adv_Fea_Matrix_train.to(device), flag=1))

        # Test accuracy under a fresh attack on the clean test features.
        Adv_Fea_Matrix_test = Attack.pertub(input=clean_x, edges=data.edge_index, label=data.y,
                                            mask=data.test_mask)
        att_acc.append(model_val(model, data, Adv_Fea_Matrix_test.to(device), flag=0))

    train_acc.append(np.mean(run_acc))
    train_std.append(np.std(run_acc))
    ad_ave_acc.append(np.mean(att_acc))
    ad_std_acc.append(np.std(att_acc))

print(train_acc)
print(train_std)
print(ad_ave_acc)
print(ad_std_acc)

# np.save does not create missing directories; make sure the output folder
# exists so a long experiment does not fail at the very last step.
os.makedirs('Results', exist_ok=True)
np.save('Results/Physics_mean_or.npy', train_acc)
np.save('Results/Physics_std_or.npy', train_std)
np.save('Results/adv_Physics_mean_or.npy', ad_ave_acc)
np.save('Results/adv_Physics_std_or.npy', ad_std_acc)



