"""
Utility functions for training and evaluating models, building
train/validation/test masks and generating PGD feature perturbations.
"""

import torch
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter
# from parseval_constraint import parseval_weight_projections

def train(model_local, optimizer_local, data_local, adj_normalized):
    """Perform a single optimisation step on the training nodes.

    The model is switched to training mode, evaluated on ``data_local.x``
    together with ``adj_normalized``, and updated with the cross-entropy loss
    restricted to the entries selected by ``data_local.train_mask``.

    Args:
        model_local: Model called as ``model_local(x, adj_normalized)``.
        optimizer_local: Optimizer holding the model parameters.
        data_local: Object exposing ``x``, ``y`` and ``train_mask``.
        adj_normalized: Second positional input forwarded to the model.

    Returns:
        A ``(model_local, loss)`` tuple, where ``loss`` is the value of the
        training loss computed before the parameter update, as a float.
    """
    model_local.train()
    optimizer_local.zero_grad()

    train_mask = data_local.train_mask
    logits = model_local(data_local.x, adj_normalized)
    train_loss = F.cross_entropy(logits[train_mask], data_local.y[train_mask])

    train_loss.backward()
    optimizer_local.step()
    return model_local, train_loss.item()


@torch.no_grad()
def test(model_local, data_local, adj_normalized):
    model_local.eval()
    pred = out = model_local(data_local.x, adj_normalized).argmax(dim=-1)

    accs = []
    for mask in [data_local.train_mask, data_local.val_mask, data_local.test_mask]:
        accs.append(int((pred[mask] == data_local.y[mask]).sum()) / int(mask.sum()))
    return accs

def compute_acc_perturbation(model_local, data_local, data_perturbed, adj_local):
    """Compare test accuracy on the original and on the perturbed features.

    Both forward passes use ``adj_local`` and are scored against
    ``data_local.y`` on ``data_local.test_mask``; only the ``x`` attribute of
    ``data_perturbed`` is read.

    Returns:
        ``(acc_1, acc_2, out_1, out_2)``: test accuracy and raw model output
        for ``data_local.x`` (index 1) and ``data_perturbed.x`` (index 2).
    """
    model_local.eval()

    test_mask = data_local.test_mask
    test_labels = data_local.y[test_mask]

    def _score(features):
        # One forward pass; returns (accuracy on the test mask, raw output).
        logits = model_local(features, adj_local)
        hits = int((logits.argmax(dim=-1)[test_mask] == test_labels).sum())
        return hits / int(test_mask.sum()), logits

    acc_1, out_1 = _score(data_local.x)
    acc_2, out_2 = _score(data_perturbed.x)
    return acc_1, acc_2, out_1, out_2


def split(dataset, split_type="random", num_train_per_class=20, num_val=500, num_test=1000):
    """Build boolean train/validation/test masks for the first graph of ``dataset``.

    Args:
        dataset: Object exposing ``get(0)`` (returning the data object with a
            label tensor ``y``) and ``num_classes``.
        split_type: ``"public"`` reuses the masks stored on the data object
            when it has a ``train_mask`` attribute; any other value, or a
            data object without ``train_mask``, produces a random split.
        num_train_per_class: Maximum number of training entries drawn per class.
        num_val: Number of validation entries drawn from the non-training ones.
        num_test: Number of test entries drawn after the validation ones.

    Returns:
        A ``(train_mask, val_mask, test_mask)`` tuple of boolean tensors.
    """
    data = dataset.get(0)

    if split_type == "public" and hasattr(data, "train_mask"):
        return (data.train_mask, data.val_mask, data.test_mask)

    train_mask, val_mask, test_mask = (
        torch.zeros_like(data.y, dtype=torch.bool) for _ in range(3)
    )

    # Draw up to `num_train_per_class` random members of every class.
    for label in range(dataset.num_classes):
        members = torch.nonzero(data.y == label, as_tuple=False).view(-1)
        shuffled = torch.randperm(members.size(0))
        train_mask[members[shuffled[:num_train_per_class]]] = True

    # Shuffle everything not used for training, then carve out the
    # validation entries followed by the test entries.
    leftover = torch.nonzero(~train_mask, as_tuple=False).view(-1)
    leftover = leftover[torch.randperm(leftover.size(0))]

    val_end = num_val
    test_end = num_val + num_test
    val_mask[leftover[:val_end]] = True
    test_mask[leftover[val_end:test_end]] = True
    return (train_mask, val_mask, test_mask)


class PGD:
    """
    Gradient-based feature attack with box clamping and norm rescaling
    adapted from : https://github.com/DSE-MSU/DeepRobust
    ---
    model_local : Model called as ``model_local(x, norm_adj)``
    data_local : Object exposing ``x``, ``y`` and ``train_mask``
    norm_adj : Second positional input forwarded to the model
    budget : Budget of the attack to be generated; the returned perturbation
             is rescaled so that its norm equals ``budget * ||x||``
    epoch_iter : Number of PGD iterations to be used
    """

    def __init__(self, model_local, data_local, norm_adj, budget, epoch_iter = 50):
        self.model_local = model_local
        self.data_local = data_local
        self.budget = budget
        self.epoch_iter = epoch_iter
        self.norm_adj = norm_adj

    def attack(self):
        """Generate an additive perturbation of ``data_local.x``.

        Runs ``epoch_iter`` gradient-ascent steps on the training
        cross-entropy loss with step size ``epoch_iter / sqrt(t + 1)``,
        clamping the perturbation to ``[0, 1]`` after every step, and finally
        rescales it with :meth:`project_perturb`.

        Returns:
            The perturbation, with the same shape as ``data_local.x``.
        """
        self.model_local.eval()

        features = self.data_local.x
        train_mask = self.data_local.train_mask
        # Loop-invariant: the training labels do not change between steps.
        train_labels = self.data_local.y[train_mask]

        # Allocate directly on the feature device so the perturbation stays a
        # leaf Parameter on every device (``Parameter(...).to(device)`` yields
        # a non-leaf copy when the device differs).
        perturb = Parameter(
            torch.zeros(features.shape[0], features.shape[1], device=features.device)
        )

        for t in range(self.epoch_iter):
            out = self.model_local(features + perturb, self.norm_adj)
            loss = F.cross_entropy(out[train_mask], train_labels)
            x_grad = torch.autograd.grad(loss, perturb)[0]

            # Decaying step size, kept as a plain Python float so no numpy
            # scalar is mixed into tensor arithmetic.
            lr = self.epoch_iter / float(np.sqrt(t + 1))

            # Ascent step followed by the box constraint; done outside of
            # autograd because only the values are updated.
            with torch.no_grad():
                perturb.add_(lr * x_grad)
                perturb.clamp_(min=0, max=1)

        return self.project_perturb(perturb)

    def project_perturb(self, perturbation):
        """Rescale ``perturbation`` in place to the attack budget.

        After the call the norm of the perturbation equals
        ``budget * ||data_local.x||``. An all-zero perturbation cannot be
        rescaled and is returned unchanged instead of being turned into NaNs
        by a division by zero.

        Returns:
            The same ``perturbation`` object that was passed in.
        """
        pert_norm = perturbation.detach().norm()
        if pert_norm == 0:
            # Nothing to scale: 0 * (budget / 0) would produce NaN entries.
            return perturbation

        norm_val = (self.budget * self.data_local.x.norm()) / pert_norm
        perturbation.data.mul_(norm_val)

        return perturbation
# Running this file directly does nothing: it only provides importable helpers.
if __name__ == "__main__":
    pass
