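"""
Train a GCN graph classifier, then evaluate its robustness to a PGD attack
that perturbs the edges of each correctly classified test graph.
"""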

import argparse
import copy

import torch
import torch.nn.functional as F
from tqdm import tqdm
from src.models.gcn import GCN
from src.models.utils import train_function, test_function, classification_loss

from src.models.utils import normalize_tensor_adj

from datasets.loader import data_loader

from src.attacks.topological_attacks import pgd_attack

import warnings
warnings.filterwarnings('ignore')


def _predict_label(model, adj, x):
    """Return the class index predicted by ``model.predict(adj, x)``.

    Pure inference: gradient tracking is disabled so no autograd graph is
    built for every evaluated test graph.
    """
    with torch.no_grad():
        logits = model.predict(adj, x)
    return logits.argmax(dim=1)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Train a GCN and evaluate it under a PGD topology attack.")
    parser.add_argument('--name_dataset', type=str, default='PROTEINS', help='Data set')
    # NOTE(review): only a GCN is built below; --model is currently not used.
    parser.add_argument('--model', type=str, default='GCN', help='Model type')
    parser.add_argument('--hidden_dim', type=int, default=32, help='Hidden dim')
    parser.add_argument('--pooling', type=str, default='sum', help='Pooling type')
    parser.add_argument('--budget', type=float, default=0.3,
                        help='Fraction of the edge count (adj.sum() // 2) the attacker may perturb')
    parser.add_argument('--attack_epochs', type=int, default=200,
                        help='Number of PGD iterations per attacked graph')
    args = parser.parse_args()

    data = data_loader(args.name_dataset)

    # Training hyper-parameters
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    batch_size = 32
    training_epochs = 101
    lr = 1e-03
    fold = 0

    # Define and train the model
    model = GCN(data.input_dim, args.hidden_dim,
                data.num_classes, device, args.pooling).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model_prediction = train_function(model, optimizer, data, fold, device,
                                      num_epochs=training_epochs, batch_size=batch_size)

    _, _, _, _, _, _, Adj_test, X_test, y_test = data.get_fold_data(fold)

    # Clean (unattacked) test accuracy.
    test_acc = test_function(model_prediction, data,
                             Adj_test, X_test, y_test, device,
                             batch_size=batch_size, verbose=True)

    budget = args.budget
    attack_epochs = args.attack_epochs

    n_graphs = len(Adj_test)
    if n_graphs == 0:
        raise SystemExit("No test graphs in fold {}; nothing to attack.".format(fold))

    # Attack the model
    model_prediction.eval()
    n_clean_errors = 0      # graphs misclassified even without any attack
    n_attack_successes = 0  # correctly classified graphs whose label the attack flipped

    for ori_adj, x, y in tqdm(zip(Adj_test, X_test, y_test), total=n_graphs):
        x = x.to(device)
        y = y.to(device)

        # Perturbation budget in edges: adj.sum() // 2 counts each edge once
        # (assumes a symmetric adjacency). The +1 guarantees at least one flip.
        n_perturbations = int(ori_adj.sum() // 2 * budget) + 1

        # Deep copies keep ori_adj untouched in case predict() modifies its input.
        clean_label = _predict_label(model_prediction, copy.deepcopy(ori_adj), x)
        if not bool((clean_label == y).all()):
            # Already wrong before the attack: counts against attacked accuracy.
            n_clean_errors += 1
            continue

        attacker = pgd_attack(ori_adj, x, y, model_prediction, device)
        attacker.attack(attack_epochs, n_perturbations)

        attacked_label = _predict_label(model_prediction,
                                        copy.deepcopy(attacker.modified_adj), x)
        if not bool((attacked_label == y).all()):
            n_attack_successes += 1

    n_correct_clean = n_graphs - n_clean_errors
    # Accuracy under attack: fraction of all test graphs still correctly classified.
    attacked_accuracy = 1 - (n_clean_errors + n_attack_successes) / n_graphs
    # Attack success rate: fraction of the correctly classified graphs that were flipped.
    attack_success_rate = (n_attack_successes / n_correct_clean
                           if n_correct_clean else float('nan'))

    print("Budget: {} - Accuracy under attack: {} - Attack success rate: {}".format(
        budget, attacked_accuracy, attack_success_rate))
