import os

import torch.nn
import torch_geometric
import numpy as np
from torch_geometric.datasets import Planetoid,Reddit,OGB_MAG
from torch_geometric.transforms import AddSelfLoops, NormalizeFeatures
from torch_geometric.nn import GCNConv,GNNFF, aggr
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch_geometric.transforms as T
import torch.nn.functional as F

from  Gnn_Model import GCN_Res,GCN,SGCN,GCNII
from Large_Attack import  LinfPGDAttack



# OGB evaluator configured for 'ogbn-arxiv'; model_val() reads the 'acc'
# entry of the dict returned by its eval() method.
evaluator = Evaluator(name='ogbn-arxiv')

def model_train(model, data, Features):
    """Perform a single optimisation step on the training nodes.

    Relies on the module-level ``optimizer`` and ``device``.

    Args:
        model: Network called as ``model(x, adj_t)``. Its output is fed to
            ``F.nll_loss``, so it is presumably log-probabilities -- confirm
            against the model definition.
        data: Graph object providing ``adj_t``, ``y`` and ``train_mask``.
            Its ``x`` attribute is overwritten with ``Features``.
        Features: Node feature matrix to train on.

    Returns:
        The loss tensor of this step (still attached to the autograd graph).
    """
    model.train()
    data.x = Features
    optimizer.zero_grad()

    train_idx = data.train_mask
    train_out = model(data.x, data.adj_t)[train_idx]
    train_targets = data.y.squeeze(1)[train_idx]
    loss = F.nll_loss(train_out, train_targets).to(device)

    loss.backward()
    optimizer.step()
    return loss


def model_train_regular(model, data, Features):
    """Perform one training step with an L1 penalty on ``conv1.lin.weight``.

    Relies on the module-level ``optimizer`` and ``device``.

    Args:
        model: Network called as ``model(x, adj_t)``; its state dict must
            contain the key ``'conv1.lin.weight'``.
        data: Graph object providing ``adj_t``, ``y`` and ``train_mask``.
            Its ``x`` attribute is overwritten with ``Features``.
        Features: Node feature matrix to train on.

    Returns:
        The total loss tensor (NLL + 0.05 * L1 penalty) of this step.
    """
    data.x = Features
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)[data.train_mask]
    loss = F.nll_loss(out, data.y.squeeze(1)[data.train_mask]).to(device)
    # state_dict() detaches tensors by default, which would turn the penalty
    # into a constant with no gradient. keep_vars=True returns the live
    # parameter so the L1 term actually influences the update.
    param = model.state_dict(keep_vars=True)['conv1.lin.weight']
    regulari_loss = torch.sum(torch.abs(param))
    loss = loss + 0.05 * regulari_loss
    loss.backward()
    optimizer.step()

    return loss


def model_val(model, data, Features, flag=0):
    """Evaluate ``model`` on ``Features`` and return the evaluator accuracy.

    Relies on the module-level ``evaluator``.

    Args:
        model: Network called as ``model(x, adj_t)``.
        data: Graph object providing ``adj_t``, ``y``, ``test_mask`` and
            ``train_mask``. Its ``x`` attribute is overwritten with
            ``Features``.
        Features: Node feature matrix to evaluate on.
        flag: ``0`` scores the nodes selected by ``data.test_mask``; any
            other value scores those selected by ``data.train_mask``.

    Returns:
        The ``'acc'`` entry of the evaluator result.
    """
    data.x = Features
    model.eval()
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        out = model(data.x, data.adj_t)
    pred = out.argmax(dim=-1, keepdim=True)
    mask = data.test_mask if flag == 0 else data.train_mask
    acc = evaluator.eval({
        'y_true': data.y[mask],
        'y_pred': pred[mask],
    })['acc']

    return acc

def train_num(dataset, rate):  # split label rate
    """Grow ``dataset.train_mask`` in place up to a target label rate.

    The first unset entries of the mask (in index order) are switched on
    until ``int(len(mask) * rate)`` entries are set. The mask is expected to
    be one-dimensional.

    Args:
        dataset: Object exposing a ``train_mask`` tensor.
        rate: Desired fraction of set entries. ``0`` leaves the mask as is.

    Returns:
        ``dataset.train_mask`` (the same tensor, modified in place). If the
        mask already meets the target nothing changes; if the target exceeds
        the number of unset entries, all of them are set.
    """
    mask = dataset.train_mask
    if rate == 0:
        return mask

    labeled = int(torch.count_nonzero(mask))
    num_change = int(mask.shape[0] * rate - labeled)
    if num_change > 0:
        unset_idx = torch.nonzero(torch.logical_not(mask), as_tuple=True)[0]
        # Slicing clamps to the number of available entries, so an overly
        # large rate sets every entry instead of raising IndexError.
        mask[unset_idx[:num_change]] = True
    return mask

device1 = torch.device('cpu')

# Fall back to the CPU when CUDA is unavailable so the script can still run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

runs = 10
epochs = 600

# Attack budgets passed to LinfPGDAttack as `epsilon`; one full experiment
# (all runs) is executed per value.
epsilon_list = [0.0, 0.002,0.004,0.006,0.008,0.010,0.012,0.014,0.016]
train_acc = []   # mean train-split accuracy per epsilon
train_std = []   # std of train-split accuracy per epsilon
ad_ave_acc = []  # mean test-split accuracy on attacked features per epsilon
ad_std_acc = []  # std of test-split accuracy on attacked features per epsilon


for eps in epsilon_list:
    dataset = PygNodePropPredDataset(name='ogbn-arxiv', transform=T.ToSparseTensor())
    data = dataset[0]
    print(data)
    split_idx = dataset.get_idx_split()
    data.train_mask = split_idx['train']
    data.val_mask = split_idx['valid']
    data.test_mask = split_idx['test']

    # Move the graph to the target device once instead of on every call.
    data = data.to(device)
    # model_train()/model_val() overwrite data.x with the features they are
    # given. Keep the clean features here so every attack starts from the
    # unperturbed input rather than from a previously perturbed matrix.
    clean_x = data.x
    labels = data.y.squeeze(1)

    num_classes = dataset.num_classes
    num_features = dataset.num_features
    run_acc = []
    att_acc = []
    for run_i in range(runs):
        model = GCN(num_features=num_features, num_classes=num_classes).to(device)
        # `optimizer` must stay a module-level name: the training helpers
        # read it as a global.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        Attack = LinfPGDAttack(model=model, epsilon=eps, alpha=float(eps / 128), k=5, random_start=True,
                               batch_size=32)
        print('*'*50)
        print(f'Train Run: {run_i: d}, Eps: {eps: .4f} ')
        print('*' * 50)

        # train_mask = train_num(data, rate=0.7)

        # Clone defensively: whether pertub() modifies its input in place is
        # not visible from this file.
        Adv_Fea_Matrix_train = Attack.pertub(input=clean_x.clone(), edges=data.adj_t, label=labels,
                                             mask=data.train_mask).to(device)

        for epoch_i in range(epochs):
            loss = model_train(model, data, Adv_Fea_Matrix_train)
            print(f'Epoch: {epoch_i: d}, Loss: {loss: .4f} ')

        # ........Val..........
        # NOTE(review): the validation attack result is computed but never
        # used; the score recorded below is the train-split accuracy on the
        # training features (flag=1). Confirm whether this is intended.
        Adv_Fea_Matrix_val = Attack.pertub(input=clean_x.clone(), edges=data.adj_t, label=labels,
                                           mask=data.val_mask)
        val_acc = model_val(model, data, Adv_Fea_Matrix_train, flag=1)
        run_acc.append(val_acc)

        # .........test........
        Adv_Fea_Matrix_test = Attack.pertub(input=clean_x.clone(), edges=data.adj_t, label=labels,
                                            mask=data.test_mask).to(device)
        test_acc = model_val(model, data, Adv_Fea_Matrix_test, flag=0)
        att_acc.append(test_acc)

    train_acc.append(np.mean(run_acc))
    train_std.append(np.std(run_acc))
    ad_ave_acc.append(np.mean(att_acc))
    ad_std_acc.append(np.std(att_acc))

print(train_acc)
print(train_std)
print(ad_ave_acc)
print(ad_std_acc)

# Make sure the output directory exists so a long run is not lost at the
# very last step.
os.makedirs('Result', exist_ok=True)
np.save('Result/ogbn_mean_or.npy', train_acc)
np.save('Result/ogbn_std_or.npy', train_std)
np.save('Result/adv_ogbn_mean_or.npy', ad_ave_acc)
np.save('Result/adv_ogbn_std_or.npy', ad_std_acc)



