import random
from copy import deepcopy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch_geometric.data import TemporalData
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

from torch_geometric.contrib.nn import PRBCDAttack
from torch_geometric.utils import coalesce, to_undirected


def add_negative_neighbor(adv_edges, rel_time_sd, n, msg_shape, last_edge_t, device):
    """Draw ``n`` historical edges from ``adv_edges`` and give them new timestamps.

    ``n`` keys are sampled without replacement from ``adv_edges``. Their
    sources, destinations and messages are reused. Their timestamps are
    replaced by ``last_edge_t`` plus rounded Gaussian noise with standard
    deviation ``rel_time_sd``. The sampled entries are removed from
    ``adv_edges`` in place, so later calls cannot draw them again.

    Args:
        adv_edges: Mapping ``(src, dst, t) -> msg``. The key components must
            expose ``.dtype`` (e.g. 0-d tensors); the first key fixes the
            dtypes of the returned tensors. Mutated in place.
        rel_time_sd: Standard deviation of the timestamp noise.
        n: Number of edges to draw. ``random.sample`` raises ``ValueError`` if
            it exceeds ``len(adv_edges)``.
        msg_shape: Unused; kept for call compatibility.
        last_edge_t: Tensor holding the reference timestamp.
        device: Device of the returned source/destination/time tensors.

    Returns:
        ``(adv_sources, adv_dests, adv_t, adv_msgs)``. ``adv_msgs`` is the
        ``torch.cat`` of the selected messages.
    """
    keys = list(adv_edges)
    # The first key only serves to pick the dtypes of the returned tensors.
    s, d, t = keys[0]

    selected_events = random.sample(keys, n)
    adv_sources = torch.tensor(
        [event[0] for event in selected_events], device=device, dtype=s.dtype
    )
    adv_dests = torch.tensor(
        [event[1] for event in selected_events], device=device, dtype=d.dtype
    )
    adv_msgs = torch.cat([adv_edges[event] for event in selected_events])

    # Add the rounded noise as a tensor so both operands are torch tensors.
    noise = torch.from_numpy(np.random.normal(scale=rel_time_sd, size=n).round())
    adv_t = torch.as_tensor(last_edge_t.cpu() + noise, dtype=t.dtype, device=device)

    # Consume the selected events so later calls cannot draw them again.
    for event in selected_events:
        del adv_edges[event]

    return adv_sources, adv_dests, adv_t, adv_msgs


def add_negative_neighbor_wrt_pos(
    adv_edges, pos_batch, rel_time_sd, n, msg_shape, last_edge_t, device
):
    """Draw ``n`` historical edges whose endpoints are not in the positive batch.

    Candidates are the keys of ``adv_edges`` whose ``(src, dst)`` pair does
    not occur in ``pos_batch``. ``n`` of them are sampled without replacement.
    Their messages are reused, and their timestamps become ``last_edge_t`` plus
    rounded Gaussian noise with standard deviation ``rel_time_sd``. The sampled
    entries are removed from ``adv_edges`` in place.

    Args:
        adv_edges: Mapping ``(src, dst, t) -> msg``. Mutated in place.
        pos_batch: Batch exposing ``src`` and ``dst`` tensors.
        rel_time_sd: Standard deviation of the timestamp noise.
        n: Number of edges to draw. ``random.sample`` raises ``ValueError`` if
            fewer candidates exist.
        msg_shape: Unused; kept for call compatibility.
        last_edge_t: Tensor holding the reference timestamp.
        device: Device of the returned source/destination/time tensors.

    Returns:
        ``(adv_sources, adv_dests, adv_t, adv_msgs)``.
    """
    # The first key only serves to pick the output dtypes. ``type(...)`` is
    # used here, so keys are presumably Python ints (unlike
    # ``add_negative_neighbor``, which reads ``.dtype``). TODO confirm with
    # callers.
    s, d, t = list(adv_edges)[0]

    # (src, dst) pairs of the current positive batch are not eligible.
    current_batch = set(
        zip(pos_batch.src.cpu().numpy(), pos_batch.dst.cpu().numpy())
    )
    # Only the keys are needed to pick candidates, so the message tensors are
    # not copied.
    candidates = [key for key in adv_edges if (key[0], key[1]) not in current_batch]

    selected_events = random.sample(candidates, n)
    adv_sources = torch.tensor(
        [event[0] for event in selected_events], device=device, dtype=type(s)
    )
    adv_dests = torch.tensor(
        [event[1] for event in selected_events], device=device, dtype=type(d)
    )
    adv_msgs = torch.cat([adv_edges[event] for event in selected_events])

    # Add the rounded noise as a tensor so both operands are torch tensors.
    noise = torch.from_numpy(np.random.normal(scale=rel_time_sd, size=n).round())
    adv_t = torch.as_tensor(last_edge_t.cpu() + noise, dtype=type(t), device=device)

    # Consume the selected events so later calls cannot draw them again.
    for event in selected_events:
        del adv_edges[event]

    return adv_sources, adv_dests, adv_t, adv_msgs


def add_negative_neighbor_wrt_pos_src_neg_sample(
    model,
    data,
    assoc,
    pos_batch,
    n,
    last_edge_t,
    device,
    adv_edges,
    rel_time_sd,
    neg_set,
    *args,
):
    """Build up to ``n`` adversarial edges from historical negatives of the batch.

    The positive edges of ``pos_batch`` are visited in random order. For each
    one, the first negative destination stored in ``neg_set`` is looked up.
    Only the ``(source, negative destination)`` pairs that also occur in
    ``adv_edges`` are kept, capped at ``n``; random negatives are discarded.
    Each kept edge gets the timestamp ``last_edge_t`` plus rounded Gaussian
    noise with standard deviation ``rel_time_sd``. Its message is taken from
    ``adv_edges``.

    Args:
        model, data, assoc: Unused. A model-guided selection is left commented
            out; the arguments are kept for signature compatibility.
        pos_batch: Positive batch exposing ``src``, ``dst`` and ``t`` tensors.
        n: Maximum number of adversarial edges to return.
        last_edge_t: Tensor holding the reference timestamp.
        device: Device of the returned source/destination/time tensors.
        adv_edges: Mapping ``(src, dst, t) -> msg`` of historical edges. It is
            re-keyed by ``(src, dst)`` locally; the caller's dict is not
            modified.
        rel_time_sd: Standard deviation of the timestamp noise.
        neg_set: Dict mapping a positive edge ``(src, dst, t)`` to an
            indexable of negative destinations; only element ``[0]`` is used.

    Returns:
        ``(adv_src, adv_dest, adv_t, adv_msg)``, or ``None`` when no
        historical negative exists for the batch.
    """
    assert isinstance(neg_set, dict)

    pos_src, pos_dst, pos_t = pos_batch.src, pos_batch.dst, pos_batch.t

    # Random order over the positive edges. A prediction-guided ordering based
    # on ``model`` was used here previously.
    top_pos_idx = np.random.permutation(len(pos_src))
    adv_src = pos_src[top_pos_idx]

    # Copy each batch tensor to host once, instead of one device sync per edge.
    src_list = pos_src.tolist()
    dst_list = pos_dst.tolist()
    t_list = pos_t.tolist()
    adv_dest = torch.tensor(
        [neg_set[(src_list[idx], dst_list[idx], t_list[idx])][0] for idx in top_pos_idx],
        dtype=pos_dst.dtype,
        device=device,
    )

    # Re-key historical edges as {(src, dst): msg}. For repeated pairs, the
    # last occurrence wins.
    pair_to_msg = {(key[0], key[1]): value for key, value in adv_edges.items()}

    # Retain only pairs with at least one historical (i.e. non-random) negative.
    adv_src_list = adv_src.tolist()
    adv_dest_list = adv_dest.tolist()
    hist_neg_idx = [
        i
        for i, pair in enumerate(zip(adv_src_list, adv_dest_list))
        if pair in pair_to_msg
    ]
    if not hist_neg_idx:
        # No historical edge found for the current batch.
        return None

    keep = hist_neg_idx[:n]
    adv_src = adv_src[keep]
    adv_dest = adv_dest[keep]
    logging.info(f"attacking with {len(adv_src)} edges")

    n = adv_src.size(0)

    # Timestamps: reference time plus rounded Gaussian noise, added as a tensor.
    noise = torch.from_numpy(np.random.normal(scale=rel_time_sd, size=n).round())
    adv_t = torch.as_tensor(last_edge_t.cpu() + noise, dtype=pos_t.dtype, device=device)

    # Messages come from the re-keyed historical edges.
    adv_msg = torch.cat(
        [pair_to_msg[(adv_src_list[i], adv_dest_list[i])] for i in keep]
    )

    return adv_src, adv_dest, adv_t, adv_msg


def add_negative_neighbor_wrt_pos_src_neg_sample_nat(
    pos_batch, n, last_edge_t, device, adv_edges, rel_time_sd, neg_set, *args
):
    """NumPy variant of the historical-negative sampler, for NAT-style ids.

    ``pos_batch`` is a ``(src, dst, t, eid)`` tuple of arrays whose node ids
    are shifted by +1 relative to the ids stored in ``neg_set``. The positive
    edges are visited in random order, and each is mapped to its first
    negative destination (shifted back by +1). Only pairs that appear in
    ``adv_edges`` (keyed by ``(src, dst)``) are kept, at most ``n`` of them.

    Returns:
        ``(adv_src, adv_dest, adv_t, adv_msg)`` as NumPy arrays, or ``None``
        when no historical negative exists. ``device`` is not used.
    """
    assert type(neg_set) is dict
    pos_src, pos_dst, pos_t, pos_eid = pos_batch

    order = np.random.permutation(len(pos_src))
    candidate_src = pos_src[order]

    # NAT ids are the original ids + 1, while ``neg_set`` uses the original ids.
    negative_dst = []
    for idx in order:
        original_key = (pos_src[idx] - 1, pos_dst[idx] - 1, pos_t[idx])
        negative_dst.append(neg_set[original_key][0] + 1)
    candidate_dst = np.array(negative_dst, dtype=pos_dst.dtype)
    # For uci and enron (no id shift), look up
    # neg_set[(pos_src[idx], pos_dst[idx], pos_t[idx])][0] instead.

    # Keep only pairs that have a historical (not random) negative edge.
    keep = [
        position
        for position, pair in enumerate(zip(candidate_src, candidate_dst))
        if pair in adv_edges
    ]
    if not keep:
        return None

    selected = keep[:n]
    adv_src = candidate_src[selected]
    adv_dest = candidate_dst[selected]
    logging.info(f"attacking with {len(adv_src)} edges")

    count = len(adv_src)
    noise = np.random.normal(scale=rel_time_sd, size=(count)).round()
    adv_t = np.array(last_edge_t + noise, dtype=pos_t.dtype)

    rows = [adv_edges[(src, dst)].reshape(1, -1) for src, dst in zip(adv_src, adv_dest)]
    adv_msg = np.concatenate(rows)

    return adv_src, adv_dest, adv_t, adv_msg


def random_attack(data, budget, device, is_msg_gaussian=False):
    """Sample ``budget`` random adversarial edges from the graph's value ranges.

    Sources and destinations are drawn uniformly, with replacement, from the
    distinct ``data.src`` / ``data.dst`` ids. Timestamps are drawn uniformly
    from ``[data.t[0], data.t[-1]]`` (both ends inclusive; assumes ``data.t``
    is sorted ascending — TODO confirm). Messages are either standard Gaussian
    noise or rows of ``data.msg`` sampled without replacement.

    Args:
        data: Graph exposing ``src``, ``dst``, ``t`` and a 2-D ``msg`` tensor.
        budget: Number of edges to generate. With ``is_msg_gaussian=False``,
            only ``min(budget, len(data.src))`` messages are returned.
        device: Device of all returned tensors.
        is_msg_gaussian: Use Gaussian noise instead of existing messages.

    Returns:
        ``(adv_src, adv_dst, adv_t, adv_msg)``.
    """
    src_nodes = torch.tensor(
        list(set(data.src.cpu().tolist())), dtype=data.src.dtype
    ).to(device)
    dst_nodes = torch.tensor(
        list(set(data.dst.cpu().tolist())), dtype=data.dst.dtype
    ).to(device)

    start_t = data.t[0].item()
    end_t = data.t[-1].item()
    if is_msg_gaussian:
        # Sample gaussian noise for msg.
        adv_msg = torch.randn(budget, data.msg.shape[1], dtype=data.msg.dtype).to(
            device
        )
    else:
        # Sample msg randomly from data; move it to ``device`` like every
        # other output.
        msg_ind = torch.randperm(len(data.src))
        adv_msg = data.msg[msg_ind[:budget]].to(device)
    # ``torch.randint``'s upper bound is exclusive. Use ``end_t + 1`` so the
    # last timestamp can be drawn and a single-timestamp range is not empty.
    adv_t = torch.randint(
        start_t, end_t + 1, (budget,), dtype=data.t.dtype, device=device
    )
    src_ind = torch.randint(0, len(src_nodes), (budget,))
    dst_ind = torch.randint(0, len(dst_nodes), (budget,))

    adv_src = src_nodes[src_ind]
    adv_dst = dst_nodes[dst_ind]

    return adv_src, adv_dst, adv_t, adv_msg


def fgsm_attack(image, epsilon, data_grad):
    """Take one Fast Gradient Sign Method step.

    Args:
        image: Tensor to perturb (here typically a batch of edge messages).
        epsilon: Step size.
        data_grad: Gradient of the loss w.r.t. ``image``; only its sign is used.

    Returns:
        ``image + epsilon * sign(data_grad)``. No clipping is applied.
    """
    # A clamp to a valid range (e.g. [0, 1]) could be added here for inputs
    # that need it.
    step = epsilon * data_grad.sign()
    return image + step


import time, logging


def fgsm_tgnw(
    model,
    data,
    assoc,
    adv_sources,
    adv_dests,
    adv_t,
    adv_msgs,
    epsilon,
    edge_class,
    edge_weights,
    device,
):
    """FGSM on the messages of adversarial edges for an edge-weighted model.

    Side effects visible to the caller:

    * Sets ``adv_msgs.requires_grad`` (presumably ``adv_msgs`` must be a leaf
      tensor — confirm).
    * Switches ``model`` to training mode and does not switch it back.
    * Splices the adversarial edges into ``data`` at the model's current edge
      index, modifying ``data`` in place.
    * Pushes the adversarial edges into the model memory and neighbor store.
    * Leaves parameter gradients populated by ``loss.backward()``.

    Args:
        model: Target model. Needs ``get_current_edge_index``,
            ``update_memory``, ``insert_neighbor`` and
            ``model(data, edge_index, edge_weight, assoc)``.
        data: Temporal graph with ``src``/``dst``/``t``/``msg``; mutated.
        assoc: Passed through to the model call.
        adv_sources, adv_dests, adv_t, adv_msgs: The adversarial edges.
        epsilon: FGSM step size.
        edge_class: Target label; predictions are treated as logits
            (BCE-with-logits against ``edge_class``).
        edge_weights: Forwarded to ``model.update_memory``.
        device: Device of the edge index given to the model.

    Returns:
        Detached copy of the perturbed messages.
    """

    adv_msgs.requires_grad = True
    model.train()

    # add the adv edges to self.data
    cur_idx = model.get_current_edge_index()
    data.src = torch.cat([data.src[:cur_idx], adv_sources, data.src[cur_idx:]])
    data.dst = torch.cat([data.dst[:cur_idx], adv_dests, data.dst[cur_idx:]])
    data.t = torch.cat([data.t[:cur_idx], adv_t, data.t[cur_idx:]])
    data.msg = torch.cat([data.msg[:cur_idx], adv_msgs, data.msg[cur_idx:]])

    # Update memory and neighbor loader with the adversarial edges before
    # predicting on them.
    model.update_memory(adv_sources, adv_dests, adv_t, adv_msgs, edge_weights)

    model.insert_neighbor(adv_sources, adv_dests)

    # (2, num_adv) edge index of the adversarial edges.
    edge_index = torch.cat(
        [adv_sources.unsqueeze(0), adv_dests.unsqueeze(0)], axis=0
    ).to(device)
    edge_weight = None
    y_pred = model(data, edge_index, edge_weight, assoc)

    # Calculate the loss
    y_target = torch.ones_like(y_pred) * edge_class
    loss = F.binary_cross_entropy_with_logits(y_pred, y_target)
    # loss = F.binary_cross_entropy(y_pred, y_target)

    # # Update memory and neighbor loader with ground-truth state.
    # model['memory'].update_state(adv_sources, adv_dests, adv_t, adv_msgs)
    # neighbor_loader.insert(adv_sources, adv_dests)

    # Zero all existing gradients
    model.zero_grad()

    # Calculate gradients of model in backward pass
    loss.backward()

    # Collect the gradient w.r.t. the messages.
    # NOTE(review): only model grads are zeroed above; if ``adv_msgs`` already
    # carried a ``.grad``, the new gradient is accumulated onto it — confirm
    # callers pass fresh tensors.
    msg_grad = adv_msgs.grad.data

    perturbed_msgs = fgsm_attack(adv_msgs, epsilon, msg_grad)

    return deepcopy(perturbed_msgs.detach())


def fgsm_tgn(
    model,
    data,
    assoc,
    adv_sources,
    adv_dests,
    adv_t,
    adv_msgs,
    epsilon,
    edge_class,
    edge_weights=None,
):
    """FGSM on the messages of adversarial edges for a TGN-style model.

    Same procedure as ``fgsm_tgnw``, except that the model is called as
    ``model(data, src, dst, edge_weight=None, assoc=assoc)`` and
    ``update_memory`` receives no edge weights.

    Side effects visible to the caller:

    * Sets ``adv_msgs.requires_grad`` (presumably ``adv_msgs`` must be a leaf
      tensor — confirm).
    * Leaves ``model`` in training mode.
    * Splices the adversarial edges into ``data`` in place at the model's
      current edge index.
    * Updates the model memory and neighbor store.
    * Leaves parameter gradients populated by ``loss.backward()``.

    Args:
        model: Target model.
        data: Temporal graph with ``src``/``dst``/``t``/``msg``; mutated.
        assoc: Passed through to the model call.
        adv_sources, adv_dests, adv_t, adv_msgs: The adversarial edges.
        epsilon: FGSM step size.
        edge_class: Target label; predictions are treated as logits.
        edge_weights: Unused in this function.

    Returns:
        Detached copy of the perturbed messages.
    """

    adv_msgs.requires_grad = True
    model.train()

    # add the adv edges to self.data
    cur_idx = model.get_current_edge_index()
    data.src = torch.cat([data.src[:cur_idx], adv_sources, data.src[cur_idx:]])
    data.dst = torch.cat([data.dst[:cur_idx], adv_dests, data.dst[cur_idx:]])
    data.t = torch.cat([data.t[:cur_idx], adv_t, data.t[cur_idx:]])
    data.msg = torch.cat([data.msg[:cur_idx], adv_msgs, data.msg[cur_idx:]])

    # Update memory and neighbor loader with the adversarial edges before
    # predicting on them.
    model.update_memory(adv_sources, adv_dests, adv_t, adv_msgs)
    model.insert_neighbor(adv_sources, adv_dests)

    y_pred = model(data, adv_sources, adv_dests, edge_weight=None, assoc=assoc)

    # Calculate the loss
    y_target = torch.ones_like(y_pred) * edge_class
    loss = F.binary_cross_entropy_with_logits(y_pred, y_target)

    # # Update memory and neighbor loader with ground-truth state.
    # model['memory'].update_state(adv_sources, adv_dests, adv_t, adv_msgs)
    # neighbor_loader.insert(adv_sources, adv_dests)

    # Zero all existing gradients
    model.zero_grad()

    # Calculate gradients of model in backward pass
    loss.backward()

    # Collect the gradient w.r.t. the messages.
    # NOTE(review): a pre-existing ``adv_msgs.grad`` would be accumulated into,
    # not replaced — confirm callers pass fresh tensors.
    msg_grad = adv_msgs.grad.data

    perturbed_msgs = fgsm_attack(adv_msgs, epsilon, msg_grad)

    return deepcopy(perturbed_msgs.detach())


def fgsm_nat(
    model, pos_batch, adv_sources, adv_dests, adv_t, adv_msgs, epsilon, edge_class
):
    """FGSM on adversarial edge features for a NAT-style model.

    The adversarial edge features are inserted into ``model.e_feat_th`` at
    ``min(pos_eid)``, and ``model.edge_raw_embed`` is rebuilt from it. The
    gradient is read from the embedding weight rows of the inserted edges.

    Side effects visible to the caller:

    * ``model.e_feat_th`` is permanently extended.
    * ``model.edge_raw_embed`` is replaced.
    * The model is left in training mode.
    * ``pos_eid += len(adv_sources)`` — if ``pos_eid`` is a NumPy array, this
      shifts the caller's array in place (presumably so later lookups of the
      positive edges still hit the right rows — confirm).

    Args:
        model: NAT model exposing ``e_feat_th``, ``edge_raw_embed`` and
            ``contrast_modified``.
        pos_batch: ``(src, dst, t, eid)`` of the positive batch.
        adv_sources, adv_dests, adv_t: The adversarial edges.
        adv_msgs: Their features (array-like), converted to a tensor here.
        epsilon: FGSM step size.
        edge_class: Target label; predictions are treated as logits.

    Returns:
        Detached copy of the perturbed features.
    """

    pos_src, pos_dst, pos_t, pos_eid = pos_batch
    cur_idx = min(pos_eid)

    current_e_feats = model.e_feat_th.data
    device = current_e_feats.device

    adv_msgs = torch.tensor(adv_msgs, device=device, dtype=current_e_feats.dtype)
    # NOTE(review): ``adv_msgs`` is not part of the graph below (the gradient
    # is read from the embedding weights); this flag only affects the
    # returned tensor, which is detached anyway.
    adv_msgs.requires_grad = True

    model.e_feat_th.requires_grad = True
    # Insert the adversarial features right before the first positive edge.
    updated_e_feats = torch.cat(
        [current_e_feats[:cur_idx], adv_msgs, current_e_feats[cur_idx:]]
    )
    model.e_feat_th.data = updated_e_feats

    model.edge_raw_embed = torch.nn.Embedding.from_pretrained(
        model.e_feat_th, padding_idx=0, freeze=False
    )
    # Edge ids of the inserted rows; positive edge ids move past them.
    adv_e_id = np.arange(len(adv_sources)) + cur_idx
    pos_eid += len(adv_sources)

    model.train()
    # Update memory with adv. edges (the returned value is discarded).
    model.contrast_modified(
        adv_sources, adv_dests, adv_t, adv_e_id, pos_edge=True, test=True
    )

    # make predictions on positive edges
    # negPos
    # y_pred = model.contrast_modified(pos_src, pos_dst,
    #                             pos_t, pos_eid, pos_edge=True, test=True)
    # y_target = torch.ones_like(y_pred) *  1.0

    # negNeg, posPos: predict on the adversarial edges themselves.
    y_pred = model.contrast_modified(
        adv_sources, adv_dests, adv_t, adv_e_id, pos_edge=True, test=True
    )
    y_target = torch.ones_like(y_pred) * edge_class

    # Calculate the loss

    loss = F.binary_cross_entropy_with_logits(y_pred, y_target)

    # # Update memory and neighbor loader with ground-truth state.
    # model['memory'].update_state(adv_sources, adv_dests, adv_t, adv_msgs)
    # neighbor_loader.insert(adv_sources, adv_dests)

    # Zero all existing gradients
    model.zero_grad()

    # Calculate gradients of model in backward pass
    loss.backward()

    # Gradient rows of the inserted adversarial edges.
    msg_grad = model.edge_raw_embed.weight.grad[
        cur_idx : cur_idx + len(adv_sources)
    ].data

    perturbed_msgs = fgsm_attack(adv_msgs, epsilon, msg_grad)

    return deepcopy(perturbed_msgs.detach())


#
class PRBCD(PRBCDAttack):

    def __init__(
        self,
        model,
        block_size,
        data,
        bipartite,
        device,
        rel_time_sd,
        edge_label,
        neg_edge_sampling,
        fixed_t_msg,
        optimize_t=False,
        t_exp=0,
        **kwargs,
    ):
        """PR-BCD attack configured for a temporal interaction graph.

        Args:
            model: Target model; forwarded to ``PRBCDAttack`` together with
                ``block_size`` and any extra ``kwargs``.
            block_size: Number of candidate edges per random block.
            data: Temporal graph. ``num_nodes`` sets the node count; for
                bipartite graphs the unique ``src``/``dst`` ids are cached.
            bipartite: Draw candidate edges only between the source and
                destination partitions.
            device: Device of tensors created by the attack.
            rel_time_sd: Standard deviation of the Gaussian timestamp noise.
            edge_label: Label assigned to the attacked edges.
            neg_edge_sampling: Seed the initial block with given negative edges.
            fixed_t_msg: Sample timestamps/messages once per attack instead of
                once per step.
            optimize_t: Also differentiate w.r.t. the timestamps.
            t_exp: Exponent applied to the time gradient when ``optimize_t``.
        """
        super().__init__(model, block_size, **kwargs)

        self.data = data
        self.num_nodes = data.num_nodes
        self.device = device

        self.bipartite = bipartite
        if bipartite:
            # Unique ids of each side; their counts size the bipartite space.
            self.src_ids, self.dst_ids = torch.unique(data.src), torch.unique(data.dst)

        # Attack hyper-parameters.
        self.rel_time_sd = rel_time_sd
        self.edge_label = edge_label
        self.neg_edge_sampling = neg_edge_sampling
        self.fixed_t_msg = fixed_t_msg
        self.optimize_t = optimize_t
        self.t_exp = t_exp

    def attack(
        self,
        pos_data: TemporalData,
        edge_index: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Any, Any]:
        """Run PR-BCD on the temporal graph and return the chosen perturbation.

        The random block of candidate edges can be seeded with negative edges
        passed as ``neg_edges`` (when ``neg_edge_sampling`` is set).
        Timestamps and messages are either fixed for the whole attack
        (``fixed_t_msg``) or sampled every step. Timestamps can optionally be
        optimized (``optimize_t``). At most ``budget`` edges may be flipped.

        Args:
            pos_data (TemporalData): The positive batch. Its min/max
                timestamps bound the initial timestamps when ``optimize_t`` is
                set. It is also forwarded to :meth:`_forward_and_gradient`.
            edge_index (torch.Tensor): The clean edge indices, presumably of
                shape ``(2, num_edges)``.
            budget (int): The number of allowed perturbations (i.e. number of
                edges that are flipped at most).
            idx_attack (torch.Tensor, optional): Filter for predictions/labels.
            **kwargs (optional):
                ``neg_edges`` (candidate negative edges, presumably
                ``(2, K)``) and ``hist_ratio`` (share of the block seeded with
                them) are popped here and default to ``None``.
                ``edge_weight`` must not be given. The remaining arguments are
                passed to the GNN module.

        Returns:
            ``(perturbed_edge_index, flipped_edges, adv_t, adv_msg)`` as
            produced by :meth:`_close`.
        """
        self.model.eval()
        logging.info(f"initial sampling with neg edges: {self.neg_edge_sampling}")
        logging.info(f"Fixed t and m: {self.fixed_t_msg}")
        assert kwargs.get("edge_weight") is None
        if edge_index.size(0) > 0:
            # ToDo: should these start at zero instead (every candidate edge
            # would then need to be initialised)?
            edge_weight = torch.ones(edge_index.size(1), device=self.device)
        else:
            # Normalise an empty input to a (2, 0) edge index.
            edge_index = torch.randint(1, (2, 0))
            edge_weight = torch.empty((0,), dtype=torch.float32)
        self.edge_index = edge_index.cpu().clone()  # 2,n
        self.edge_weight = edge_weight.cpu().clone()

        # For collecting attack statistics
        self.attack_statistics = defaultdict(list)

        # Prepare attack and define `self.iterable` to iterate over
        step_sequence = self._prepare(budget)
        # Both are optional; ``None`` disables seeding the block with negatives.
        neg_edges = kwargs.pop("neg_edges", None)
        hist_ratio = kwargs.pop("hist_ratio", None)

        if self.neg_edge_sampling and neg_edges is not None:
            if hist_ratio is not None:
                # Overwrite a ``hist_ratio`` share of the random block with
                # negative edges, then shuffle the block.
                logging.info(f"hist_ratio is not None: {hist_ratio}")
                n_hist_edges = int(hist_ratio * self.block_size)
                n_hist_edges = min(n_hist_edges, self.block_edge_index.size(1))
                idx = torch.randperm(neg_edges.size(1))[:n_hist_edges]
                block_edges = neg_edges[:, idx]
                n_hist_edges = min(n_hist_edges, block_edges.size(1))
                self.block_edge_index[:, :n_hist_edges] = block_edges

                # shuffle
                idx = torch.randperm(self.block_edge_index.shape[1])
                self.block_edge_index = self.block_edge_index[:, idx]
            else:
                # Replace the block entirely by up to ``block_size`` negatives.
                logging.info(f"hist_ratio is  None: {hist_ratio}")
                idx = torch.randperm(neg_edges.size(1))[: self.block_size]
                block_edges = neg_edges[:, idx]
                self.block_edge_index = block_edges
            # NOTE(review): in both branches ``self.current_block`` (the linear
            # ids used by ``_resample_random_block``) is not updated to match
            # ``self.block_edge_index`` — confirm this is intended.

            self.block_edge_weight = torch.full(
                (self.block_edge_index.shape[1],),
                self.coeffs["eps"],
                device=self.device,
            )

        if self.fixed_t_msg:
            # One timestamp (reference time + rounded Gaussian noise) and one
            # historical message per block edge, fixed for the whole attack.
            cur_idx = self.model.get_current_edge_index()
            self.t = torch.tensor(
                self.data.t[cur_idx].cpu()
                + np.random.normal(
                    scale=self.rel_time_sd, size=(self.block_edge_index.size(1))
                ).round(),
                dtype=self.data.t.dtype,
                device=self.device,
            )

            msg_idx = torch.randperm(cur_idx)[: self.block_edge_index.size(1)]
            self.msg = self.data.msg[msg_idx]

            # prepare
            self.adv_t = self.t.new_empty((0,)).to(self.device)
            self.adv_msg = self.msg.new_empty((0, self.msg.shape[1])).to(self.device)
        else:
            self.t = None
            self.msg = None
        if self.optimize_t:
            # Initial timestamps drawn uniformly within the positive batch's
            # time range; kept as float so they can receive gradients.
            t_batch_min, t_batch_max = (
                pos_data.t.min().cpu().item(),
                pos_data.t.max().cpu().item(),
            )
            self.t = torch.randint(
                t_batch_min,
                t_batch_max + 1,
                size=(self.block_edge_index.size(1),),
                dtype=torch.float32,  # self.data.t.dtype,
                device=self.device,
            )

        # Intended: with no initial edges, initialise them with the block edges
        # (weights eps), so that ``self.labels`` carries ``edge_label`` for them.
        # NOTE(review): unreachable as written. ``edge_index.size(0)`` is never
        # 0 here, because an empty input was replaced by a (2, 0) tensor above.
        # So an empty clean graph gives an empty ``self.labels``, and the block
        # edges get label 0 in ``_get_modified_adj``. Enabling this would also
        # make ``block_t``/``block_msg`` shorter than the merged edge index
        # there when ``fixed_t_msg`` is set — confirm intent before changing.
        if edge_index.size(0) == 0:
            self.edge_index = self.block_edge_index.clone()  # clone ?
            self.edge_weight = self.block_edge_weight.clone().to(torch.float32)
        self.labels = (torch.ones_like(self.edge_weight) * self.edge_label).to(
            torch.int64
        )

        # Loop over the epochs (Algorithm 1, line 5)
        for step in tqdm(step_sequence, disable=not self.log, desc="Attack"):
            loss, gradient = self._forward_and_gradient(
                self.data, self.labels, idx_attack, pos_data=pos_data, **kwargs
            )

            scalars = self._update(
                step, gradient, self.data, self.labels, budget, idx_attack, **kwargs
            )

            scalars["loss"] = loss.item()
            self._append_statistics(scalars)

        perturbed_edge_index, flipped_edges, adv_t, adv_msg = self._close(
            self.data, self.labels, budget, idx_attack, **kwargs
        )

        assert flipped_edges.size(1) <= budget, (
            f"# perturbed edges {flipped_edges.size(1)} " f"exceeds budget {budget}"
        )

        return perturbed_edge_index, flipped_edges, adv_t, adv_msg

    def _forward_and_gradient(
        self,
        x: TemporalData,
        labels: Tensor,
        idx_attack: Optional[Tensor] = None,
        msg=None,
        t=None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Forward the perturbed graph and return the loss and block gradient.

        The current perturbation (clean edges merged with the weighted block)
        is spliced into a clone of ``x`` at the model's current edge index.
        It is then pushed into the model's memory and neighbor loader. The
        attack loss on the perturbed edges is differentiated w.r.t. the block
        edge weights, multiplied by ``t_gradient ** t_exp`` when
        ``optimize_t`` is set. The model's ``memory`` and ``neighbor_loader``
        are restored from snapshots afterwards, also when an exception is
        raised.

        Args:
            x: Clean temporal graph; it is cloned, not modified.
            labels: Labels of the clean edges ``self.edge_index``.
            idx_attack: Optional filter passed to ``self.loss``.
            msg, t: Ignored. Both are overwritten by the values returned from
                :meth:`_get_modified_adj`; kept for signature compatibility.
            **kwargs: ``pos_data`` is removed (unused here); the rest is passed
                to ``self._forward``.

        Returns:
            ``(loss, gradient)``; ``gradient`` is presumably shaped like
            ``self.block_edge_weight``.
        """
        # Developer switch: when True, also differentiate w.r.t. ``self.msg``
        # and take one raw gradient step on it.
        msg_grad = False
        self.block_edge_weight.requires_grad = True
        if self.optimize_t:
            self.t.requires_grad = True
        if msg_grad:
            self.msg.requires_grad = True
        # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}`
        # (Algorithm 1, line 6 / Algorithm 2, line 7)
        edge_index, edge_weight, edge_labels, t, msg = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            self.block_edge_weight,
            self.t,
            self.msg,
        )

        # Work on a clone so the clean graph stays untouched.
        data = x.clone()

        src = edge_index[0]
        dst = edge_index[1]

        cur_idx = self.model.get_current_edge_index()
        # Without fixed timestamps/messages, sample them for this step.
        if t is None:
            logging.info("sampling t in attack")
            t = torch.tensor(
                data.t[cur_idx].cpu()
                + np.random.normal(scale=self.rel_time_sd, size=(src.size(0))).round(),
                dtype=data.t.dtype,
                device=self.device,
            )

        if msg is None:
            logging.info("sampling msg in attack")
            msg_idx = torch.randperm(cur_idx)[: src.size(0)]
            msg = data.msg[msg_idx]
        # Splice the perturbed edges in at the model's current position.
        data.src = torch.cat([data.src[:cur_idx], src, data.src[cur_idx:]])
        data.dst = torch.cat([data.dst[:cur_idx], dst, data.dst[cur_idx:]])
        data.t = torch.cat([data.t[:cur_idx], t, data.t[cur_idx:]])
        data.msg = torch.cat([data.msg[:cur_idx], msg, data.msg[cur_idx:]])
        data.edge_weights = torch.cat(
            [data.edge_weights[:cur_idx], edge_weight, data.edge_weights[cur_idx:]]
        )

        # ``pos_data`` is not used for the loss; drop it so it is not forwarded
        # to the model.
        kwargs.pop("pos_data", None)

        # Snapshot the stateful modules. The perturbation must not leak into
        # the model state seen by later steps, even if the forward/backward fails.
        neighbor_loader = deepcopy(self.model.neighbor_loader)
        memory_mod = deepcopy(self.model.memory)
        try:
            # update neighbors and memory state
            self.model.update_memory(src, dst, t, msg, edge_weight)
            self.model.insert_neighbor(src, dst)

            # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)
            prediction = self._forward(data, edge_index, edge_weight=None, **kwargs)
            if self.model.__class__.__name__ == "WeightedTNCNLinkPred":
                # This model presumably outputs logits; map them to probabilities.
                prediction = prediction.sigmoid()

            # make it two class
            prediction = PRBCD.get_two_class_probs(prediction)

            # (Algorithm 1, line 7 / Algorithm 2, line 8)
            loss = self.loss(prediction, edge_labels, idx_attack)
            # Retrieve gradient towards the current block
            # (Algorithm 1, line 7 / Algorithm 2, line 8)
            if msg_grad:
                logging.info("computing gradient on messages as well")
                w_gradient, m_gradient = torch.autograd.grad(
                    loss,
                    [self.block_edge_weight, self.msg],
                )

                self.msg = self.msg + m_gradient

                gradient = w_gradient
            elif self.optimize_t:
                w_gradient, t_gradient = torch.autograd.grad(
                    loss,
                    [self.block_edge_weight, self.t],
                )

                gradient = w_gradient * (t_gradient**self.t_exp)

            else:
                gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]
        finally:
            # restore model's original neighbor and memory module stores
            self.model.neighbor_loader = neighbor_loader
            self.model.memory = memory_mod

        return loss, gradient

    def _get_modified_adj(
        self,
        edge_index: Tensor,
        edge_weight: Tensor,
        edge_labels: Tensor,
        block_edge_index: Tensor,
        block_edge_weight: Tensor,
        block_t: Optional[Tensor] = None,
        block_msg: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Merge the clean edges with the current block (weights, labels, t, msg).

        All outputs are coalesced over the same concatenated edge index, so
        they stay aligned with each other:

        * weights are summed over duplicates. A weight above 1 (an edge both in
          the clean graph and in the block) becomes ``2 - w``, which allows a
          soft removal of clean edges;
        * labels are the clean ``edge_labels`` followed by 0 for every block
          edge, reduced with ``max``. Block edges absent from the clean graph
          therefore get label 0;
        * timestamps/messages are only produced when ``self.fixed_t_msg`` is
          set and both ``block_t`` and ``block_msg`` are given, reduced with
          ``min``. Otherwise ``None`` is returned for them.

        NOTE(review): ``block_t``/``block_msg`` are coalesced against the
        merged index (clean + block edges). Their length therefore presumably
        has to equal ``edge_index.size(1) + block_edge_index.size(1)``, which
        only holds when ``edge_index`` is empty — see the TODO below.

        NOTE(review): ``_sample_final_edges`` unpacks only three values from
        this method, which returns five.

        Returns:
            ``(edge_index, edge_weight, edge_labels, edge_t, edge_msg)``.
        """
        modified_edge_t = None
        modified_edge_msg = None
        if self.is_undirected:
            # Symmetrise the block; duplicated weights are averaged.
            block_edge_index, block_edge_weight = to_undirected(
                block_edge_index,
                block_edge_weight,
                num_nodes=self.num_nodes,
                reduce="mean",
            )

        modified_edge_index = torch.cat(
            (edge_index.to(self.device), block_edge_index), dim=-1
        )
        modified_edge_weight = torch.cat(
            (edge_weight.to(self.device), block_edge_weight)
        )

        # Block edges start with label 0; ``max`` keeps the clean label for
        # edges present in both.
        modified_edge_labels = torch.cat(
            (
                edge_labels.to(self.device),
                torch.zeros_like(block_edge_weight).to(edge_labels.dtype),
            )
        )

        _, modified_edge_labels = coalesce(
            modified_edge_index,
            modified_edge_labels,
            num_nodes=self.num_nodes,
            reduce="max",
        )

        if self.fixed_t_msg and block_t is not None and block_msg is not None:
            # TODO: first concatenate and get modified_block_t and modified_block_msg, then coalesce
            _, modified_edge_t = coalesce(
                modified_edge_index, block_t, num_nodes=self.num_nodes, reduce="min"
            )
            _, modified_edge_msg = coalesce(
                modified_edge_index, block_msg, num_nodes=self.num_nodes, reduce="min"
            )

        # Coalesced last, since ``modified_edge_index`` is still the
        # un-coalesced input to the calls above.
        modified_edge_index, modified_edge_weight = coalesce(
            modified_edge_index,
            modified_edge_weight,
            num_nodes=self.num_nodes,
            reduce="sum",
        )

        # Allow (soft) removal of edges
        is_edge_in_clean_adj = modified_edge_weight > 1
        modified_edge_weight[is_edge_in_clean_adj] = (
            2 - modified_edge_weight[is_edge_in_clean_adj]
        )

        return (
            modified_edge_index,
            modified_edge_weight,
            modified_edge_labels,
            modified_edge_t,
            modified_edge_msg,
        )

    def _sample_random_block(self, budget: int = 0):
        """Draw a fresh random block of candidate edges, all with weight ``eps``.

        Linear ids are drawn uniformly, deduplicated and sorted, then mapped to
        edges:

        * bipartite graphs use ``_linear_to_bpt_idx``;
        * undirected graphs use ``_linear_to_triu_idx``;
        * other graphs use ``_linear_to_full_idx``, after which self-loops are
          filtered out.

        Sampling is retried until the block has at least ``budget`` edges.

        Raises:
            RuntimeError: If ``max_trials_sampling`` attempts all yield fewer
                than ``budget`` edges.
        """
        if self.bipartite:
            n_src = len(self.src_ids)
            n_dst = len(self.dst_ids)
            num_possible_edges = n_src * n_dst
        else:
            num_possible_edges = self._num_possible_edges(
                self.num_nodes, self.is_undirected
            )

        for _ in range(self.coeffs["max_trials_sampling"]):
            drawn = torch.randint(
                num_possible_edges, (self.block_size,), device=self.device
            )
            self.current_block = torch.unique(drawn, sorted=True)

            if self.bipartite:
                self.block_edge_index = self._linear_to_bpt_idx(
                    n_src, n_dst, self.current_block
                )
            elif self.is_undirected:
                self.block_edge_index = self._linear_to_triu_idx(
                    self.num_nodes, self.current_block
                )
            else:
                self.block_edge_index = self._linear_to_full_idx(
                    self.num_nodes, self.current_block
                )
                # Directed non-bipartite ids can encode self-loops; drop them.
                self._filter_self_loops_in_block(with_weight=False)

            # Every candidate starts at the minimal weight.
            self.block_edge_weight = torch.full(
                self.current_block.shape, self.coeffs["eps"], device=self.device
            )

            if self.current_block.size(0) >= budget:
                return

        raise RuntimeError(
            "Sampling random block was not successful. Please decrease `budget`."
        )

    def _resample_random_block(self, budget: int):
        """Replace the low-weight part of the block with fresh candidate edges.

        The lowest-weighted entries are dropped: all entries with weight
        ``<= eps``, and at least half of the block. The block is then refilled
        up to ``block_size`` with newly drawn linear ids, whose weight
        starts at ``eps``. Kept edges retain their weights. Sampling is retried
        until the block holds more than ``budget`` edges. If all
        ``max_trials_sampling`` attempts fail, a warning is logged and the last
        block is kept.
        """
        # Keep at most half of the block (i.e. resample low weights)
        sorted_idx = torch.argsort(self.block_edge_weight)
        keep_above = (self.block_edge_weight <= self.coeffs["eps"]).sum().long()
        if keep_above < sorted_idx.size(0) // 2:
            keep_above = sorted_idx.size(0) // 2
        sorted_idx = sorted_idx[keep_above:]

        # Keep the linear ids and their weights aligned from here on. A retry
        # below then merges the weights of exactly the block it extends.
        self.current_block = self.current_block[sorted_idx]
        self.block_edge_weight = self.block_edge_weight[sorted_idx]

        if self.bipartite:
            n_src = len(self.src_ids)
            n_dst = len(self.dst_ids)
            num_possible_edges = n_src * n_dst
        else:
            num_possible_edges = self._num_possible_edges(
                self.num_nodes, self.is_undirected
            )

        # Sample until enough edges were drawn
        for _ in range(self.coeffs["max_trials_sampling"]):
            n_edges_resample = self.block_size - self.current_block.size(0)
            lin_index = torch.randint(
                num_possible_edges, (n_edges_resample,), device=self.device
            )

            n_prev = self.current_block.size(0)
            block_edge_weight_prev = self.block_edge_weight
            current_block = torch.cat((self.current_block, lin_index))
            self.current_block, unique_idx = torch.unique(
                current_block, sorted=True, return_inverse=True
            )

            if self.bipartite:
                self.block_edge_index = self._linear_to_bpt_idx(
                    n_src, n_dst, self.current_block
                )
            elif self.is_undirected:
                self.block_edge_index = self._linear_to_triu_idx(
                    self.num_nodes, self.current_block
                )
            else:
                self.block_edge_index = self._linear_to_full_idx(
                    self.num_nodes, self.current_block
                )

            # Merge existing weights with new edge weights. The first ``n_prev``
            # entries of the concatenation are the previous block.
            self.block_edge_weight = torch.full(
                self.current_block.shape, self.coeffs["eps"], device=self.device
            )
            self.block_edge_weight[unique_idx[:n_prev]] = block_edge_weight_prev

            if not self.is_undirected:
                # Presumably filters ids, edge index and weights together
                # (``with_weight=True``), keeping them aligned for a retry.
                self._filter_self_loops_in_block(with_weight=True)

            if self.current_block.size(0) > budget:
                return

        # Best effort: keep going with the last block, but make it visible.
        logging.warning(
            f"Resampling the random block did not yield more than {budget} edges "
            f"after {self.coeffs['max_trials_sampling']} trials; continuing with "
            f"{self.current_block.size(0)} edges."
        )

    def _linear_to_bpt_idx(self, n_src, n_dst, lin_idx):
        """
        example:
        src: 0-4 (5)
        dst: 5-10 (6)
        lin_idx ∈ [0,29]
        row_col_idx ∈ [0,10]
        0: 0,5
        1: 0,6
        4: 0,9
        5: 0,10
        6: 1,5
        7: 1,6
        10: 1,9
        25: 4,6
        28: 4,9
        29: 4,10
        """
        n = n_src + n_dst
        row_idx = torch.div(lin_idx, n_dst, rounding_mode="floor")
        col_idx = n_src + (lin_idx % n_dst)

        return torch.stack((row_idx, col_idx))

    def _sample_final_edges(
        self,
        x: Tensor,
        labels: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Sample the final discrete perturbation from the relaxed weights.

        The first candidate is the top-``budget`` heuristic. Later candidates,
        up to ``coeffs["max_final_samples"]`` in total, are Bernoulli draws
        from ``self.block_edge_weight``. Candidates flipping more than
        ``budget`` edges are skipped. The 0/1 candidate with the highest
        ``self.metric`` replaces ``self.block_edge_weight``.

        Args:
            x: Node features passed to ``self._forward``.
            labels: Labels passed to ``self._get_modified_adj``.
            budget: Maximum number of flipped edges.
            idx_attack: Optional filter passed to ``self.metric``.
            **kwargs: Passed to ``self._forward``.

        Returns:
            ``(edge_index, flipped_edges)``: the merged edges whose weight is
            exactly 1, and the block edges selected by the best candidate.
        """
        best_metric = float("-Inf")
        # Alias, not a copy: zeroing near-zero entries also updates
        # ``self.block_edge_weight`` in place.
        block_edge_weight = self.block_edge_weight
        block_edge_weight[block_edge_weight <= self.coeffs["eps"]] = 0
        # ``torch.topk`` raises when k exceeds the number of entries, which can
        # happen if resampling left fewer than ``budget`` candidates.
        top_k = min(budget, block_edge_weight.numel())

        for i in range(self.coeffs["max_final_samples"]):
            if i == 0:
                # In first iteration employ top k heuristic instead of sampling
                sampled_edges = torch.zeros_like(block_edge_weight)
                sampled_edges[torch.topk(block_edge_weight, top_k).indices] = 1
            else:
                sampled_edges = torch.bernoulli(block_edge_weight).float()

            if sampled_edges.sum() > budget:
                # Allowed budget is exceeded
                continue

            edge_index, _, edge_labels = self._get_modified_adj(
                self.edge_index,
                self.edge_weight,
                labels,
                self.block_edge_index,
                sampled_edges,
            )

            prediction = self._forward(x, edge_index, edge_weight=None, **kwargs)

            # make it two class
            prediction = PRBCD.get_two_class_probs(prediction)
            metric = self.metric(prediction, edge_labels, idx_attack)

            # Save best sample (parked on the CPU until recovered below)
            if metric > best_metric:
                best_metric = metric
                self.block_edge_weight = sampled_edges.clone().cpu()

        # Recover best sample
        self.block_edge_weight = self.block_edge_weight.to(self.device)
        flipped_edges = self.block_edge_index[:, self.block_edge_weight > 0]

        edge_index, edge_weight, _ = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            self.block_edge_weight,
        )
        # Only edges whose merged weight is exactly 1 survive.
        edge_mask = edge_weight == 1
        edge_index = edge_index[:, edge_mask]

        return edge_index, flipped_edges

    @torch.no_grad()
    def _update(
        self,
        epoch: int,
        gradient: Tensor,
        x: Tensor,
        labels: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict[str, float]:
        """Update the block edge weights from ``gradient`` (one PR-BCD epoch).

        Applies the gradient step and the budget projection. When
        ``coeffs["with_early_stopping"]`` is set, it also:

        * evaluates the top-``budget`` candidate;
        * remembers the best block seen so far;
        * resamples the block before epoch ``epochs_resampling - 1``, or
          restores the best block at exactly that epoch.

        Returns:
            Monitoring scalars. ``"metric"`` is only present when early
            stopping is enabled.
        """
        # Gradient update step (Algorithm 1, line 7)
        self.block_edge_weight = self._update_edge_weights(
            budget, self.block_edge_weight, epoch, gradient
        )

        # For monitoring
        pmass_update = torch.clamp(self.block_edge_weight, 0, 1)
        # Projection to stay within relaxed `L_0` budget
        # (Algorithm 1, line 8)
        self.block_edge_weight = self._project(
            budget, self.block_edge_weight, self.coeffs["eps"]
        )

        # For monitoring
        scalars = dict(
            prob_mass_after_update=pmass_update.sum().item(),
            prob_mass_after_update_max=pmass_update.max().item(),
            prob_mass_after_projection=self.block_edge_weight.sum().item(),
            prob_mass_after_projection_nonzero_weights=(
                self.block_edge_weight > self.coeffs["eps"]
            )
            .sum()
            .item(),
            prob_mass_after_projection_max=self.block_edge_weight.max().item(),
        )
        if not self.coeffs["with_early_stopping"]:
            return scalars

        # Calculate metric after the current epoch (overhead
        # for monitoring and early stopping).
        # ``torch.topk`` raises if k exceeds the block size, so clamp it.
        top_k = min(budget, self.block_edge_weight.numel())
        topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)
        topk_block_edge_weight[torch.topk(self.block_edge_weight, top_k).indices] = 1
        edge_index, _, edge_labels = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            topk_block_edge_weight,
        )

        prediction = self._forward(x, edge_index, edge_weight=None, **kwargs)
        # make it two class
        prediction = PRBCD.get_two_class_probs(prediction)
        metric = self.metric(prediction, edge_labels, idx_attack)

        # Save best epoch for early stopping
        # (not explicitly covered by pseudo code)
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_block = self.current_block.cpu().clone()
            self.best_edge_index = self.block_edge_index.cpu().clone()
            self.best_pert_edge_weight = self.block_edge_weight.cpu().clone()

        # Resampling of search space (Algorithm 1, line 9-14)
        if epoch < self.epochs_resampling - 1:
            self._resample_random_block(budget)
        elif epoch == self.epochs_resampling - 1:
            # Retrieve best epoch if early stopping is active
            # (not explicitly covered by pseudo code)
            self.current_block = self.best_block.to(self.device)
            self.block_edge_index = self.best_edge_index.to(self.device)
            block_edge_weight = self.best_pert_edge_weight.clone()
            self.block_edge_weight = block_edge_weight.to(self.device)

        scalars["metric"] = metric.item()
        return scalars

    @staticmethod
    def get_two_class_probs(prediction):
        zero_class_prob = 1.0 - prediction
        prediction = torch.cat([zero_class_prob, prediction], dim=1)
        return prediction


class GRBCD(PRBCD):
    """Greedy variant of :class:`PRBCD`.

    Each step flips the ``step_size`` block edges with the largest gradient,
    merges them into the clean graph, and samples a fresh block.
    """

    coeffs = {"max_trials_sampling": 20, "eps": 1e-7}

    @torch.no_grad()
    def _prepare(self, budget: int) -> List[int]:
        """Prepare the attack and split ``budget`` into per-step flip counts.

        Returns:
            One entry per attack step. The entries sum to ``budget``: each of
            the ``self.epochs`` steps gets ``budget // self.epochs`` flips and
            the remainder is spread over the first steps. When
            ``budget < self.epochs``, every step flips a single edge.
        """
        self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device)

        # Determine the number of edges to be flipped in each attack step/epoch
        step_size = budget // self.epochs
        if step_size > 0:
            steps = self.epochs * [step_size]
            for i in range(budget % self.epochs):
                steps[i] += 1
        else:
            steps = [1] * budget

        # Sample initial search space (Algorithm 2, line 3-4)
        self._sample_random_block(step_size)

        return steps

    @torch.no_grad()
    def _update(
        self, step_size: int, gradient: Tensor, x, labels, *args, **kwargs
    ) -> Dict[str, Any]:
        """Greedily flip the ``step_size`` block edges with the largest gradient.

        The chosen edges, together with their timestamps and messages taken
        from ``self.t`` / ``self.msg``, are appended to ``self.flipped_edges``,
        ``self.adv_t`` and ``self.adv_msg``. They are then merged into
        ``self.edge_index`` / ``self.edge_weight`` / ``self.labels``, and a
        new block is sampled.

        Returns:
            Debug scalars (number of positive gradient entries).
        """
        _, topk_edge_index = torch.topk(gradient, step_size)

        flip_edge_index = self.block_edge_index[:, topk_edge_index]
        flip_edge_weight = torch.ones_like(flip_edge_index[0], dtype=torch.float32)

        # Record timestamps/messages of the chosen adversarial edges.
        adv_t = self.t[topk_edge_index].to(torch.int64)
        adv_msg = self.msg[topk_edge_index, :]
        self.adv_t = torch.cat((self.adv_t, adv_t))
        self.adv_msg = torch.cat((self.adv_msg, adv_msg), dim=0)

        self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), dim=-1)

        if self.is_undirected:
            flip_edge_index, flip_edge_weight = to_undirected(
                flip_edge_index,
                flip_edge_weight,
                num_nodes=self.num_nodes,
                reduce="mean",
            )
        edge_index = torch.cat(
            (self.edge_index.to(self.device), flip_edge_index.to(self.device)), dim=-1
        )

        # Flipped edges get label 1; "max" keeps 1 where an edge appears twice.
        labels = torch.cat(
            (
                labels.to(self.device),
                torch.ones_like(flip_edge_index[0], dtype=labels.dtype),
            )
        )
        _, labels = coalesce(edge_index, labels, num_nodes=self.num_nodes, reduce="max")

        # Summing weights: a flipped edge that already had weight 1 reaches 2.
        edge_weight = torch.cat(
            (self.edge_weight.to(self.device), flip_edge_weight.to(self.device))
        )
        edge_index, edge_weight = coalesce(
            edge_index, edge_weight, num_nodes=self.num_nodes, reduce="sum"
        )

        # Keep only edges whose summed weight is 1. Comparing against
        # ``ones_like`` keeps dtype and device in sync with ``edge_weight``.
        is_one_mask = torch.isclose(edge_weight, torch.ones_like(edge_weight))
        self.edge_index = edge_index[:, is_one_mask]
        self.edge_weight = edge_weight[is_one_mask]
        self.labels = labels[is_one_mask]
        assert self.edge_index.size(1) == self.edge_weight.size(0)

        # Sample a fresh search space (Algorithm 2, line 3-4)
        self._sample_random_block(step_size)

        # Return debug information
        scalars = {"number_positive_entries_in_gradient": (gradient > 0).sum().item()}
        return scalars

    def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return ``(edge_index, flipped_edges, adv_t, adv_msg)``."""
        return self.edge_index, self.flipped_edges, self.adv_t, self.adv_msg


class PRBCD_NAT(PRBCDAttack):
    """PR-BCD attack adapted to the NAT temporal link-prediction model.

    To compute gradients, candidate edges are written into the model's edge
    feature table, edge weights and neighbourhood/memory state. The model
    state is restored after each gradient computation.
    """

    def __init__(
        self,
        model,
        block_size,
        src_node_ids,
        dst_node_ids,
        cur_e_id,
        cur_t,
        bipartite,
        device,
        rel_time_sd,
        edge_label,
        pos_batch,
        neg_edge_sampling,
        fixed_t_msg,
        **kwargs,
    ):
        """
        Args:
            model: NAT model. This class reads ``total_nodes``,
                ``e_feat_th``, ``edge_raw_embed``, ``get_weights`` /
                ``set_weights``, ``neighborhood_store``, ``self_rep``,
                ``prev_raw`` and ``contrast_modified`` from it.
            block_size: Number of candidate edges per block.
            src_node_ids: Source node ids (only stored when ``bipartite``).
            dst_node_ids: Destination node ids (only stored when ``bipartite``).
            cur_e_id: Edge id at which adversarial edges are inserted.
            cur_t: Reference timestamp for sampled edge times.
            bipartite: Sample the block from source x destination pairs.
            device: Device for the block tensors.
            rel_time_sd: Std. dev. of the Gaussian offset added to ``cur_t``.
            edge_label: Label assigned to the clean edges.
            pos_batch: Stored; not read by the methods of this class.
            neg_edge_sampling: Initialise the block from ``neg_edges``.
            fixed_t_msg: Sample timestamps/messages once per attack.
            **kwargs: Forwarded to :class:`PRBCDAttack`.
        """
        super().__init__(model, block_size, **kwargs)

        self.num_nodes = model.total_nodes
        self.bipartite = bipartite
        if self.bipartite:
            self.src_ids = src_node_ids
            self.dst_ids = dst_node_ids
        self.device = device
        self.rel_time_sd = rel_time_sd
        self.edge_label = edge_label
        self.cur_t = cur_t
        self.cur_e_id = cur_e_id
        self.pos_batch = pos_batch
        self.neg_edge_sampling = neg_edge_sampling
        self.fixed_t_msg = fixed_t_msg

    def attack(
        self,
        edge_index: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Attack the predictions for the provided model and graph.

        A subset of predictions may be specified with :attr:`idx_attack`. The
        attack is allowed to flip (i.e. add or delete) :attr:`budget` edges and
        will return the strongest perturbation it can find.

        Args:
            edge_index (torch.Tensor): The clean edge indices, shape ``(2, E)``.
                Pass a tensor with ``size(0) == 0`` when there are no clean
                edges.
            budget (int): The number of allowed perturbations (i.e.
                number of edges that are flipped at most).
            idx_attack (torch.Tensor, optional): Filter for predictions/labels.
                Shape and type must match that it can index :attr:`labels`
                and the model's predictions.
            **kwargs (optional): May contain ``neg_edges`` (candidate edges
                indexed as ``neg_edges[:, idx]``), used to initialise the block
                when ``neg_edge_sampling`` is set. The remaining kwargs are
                forwarded to the gradient, update and close steps.

        Returns:
            ``(perturbed_edge_index, flipped_edges, adv_t, adv_msg)``.
            ``adv_t``/``adv_msg`` are ``None`` if ``_close`` does not return
            them and they were never recorded.
        """
        import logging

        self.model.eval()
        logging.info(f"initial sampling with neg edges: {self.neg_edge_sampling}")
        logging.info(f"Fixed t and m: {self.fixed_t_msg}")
        assert kwargs.get("edge_weight") is None
        if edge_index.size(0) > 0:
            edge_weight = torch.ones(edge_index.size(1), device=self.device)
        else:
            # No clean edges: use an empty (2, 0) placeholder graph.
            edge_index = torch.randint(1, (2, 0))
            edge_weight = torch.empty((0,), dtype=torch.float32)
        self.edge_index = edge_index.cpu().clone()  # 2,n
        self.edge_weight = edge_weight.cpu().clone()

        # For collecting attack statistics
        self.attack_statistics = defaultdict(list)

        # Prepare attack and define the step sequence to iterate over
        step_sequence = self._prepare(budget)

        # ``neg_edges`` is optional; previously a missing key raised KeyError.
        neg_edges = kwargs.pop("neg_edges", None)

        if self.neg_edge_sampling and neg_edges is not None:
            idx = torch.randperm(neg_edges.size(1))[: self.block_size]
            block_edges = neg_edges[:, idx]
            # NOTE(review): `self.current_block` is not updated to match these
            # edges, while `_resample_random_block` indexes `current_block`
            # with positions derived from `block_edge_weight` -- confirm.
            self.block_edge_index = block_edges
            self.block_edge_weight = torch.full(
                (block_edges.shape[1],), self.coeffs["eps"], device=self.device
            )

        if self.fixed_t_msg:
            # One timestamp (cur_t + rounded Gaussian offset) and one message
            # (embedding of a random earlier edge id) per block edge.
            cur_idx = self.cur_e_id
            self.t = torch.tensor(
                self.cur_t
                + np.random.normal(
                    scale=self.rel_time_sd, size=(self.block_edge_index.size(1))
                ).round(),
                dtype=torch.int32,
                device=self.device,
            )

            msg_idx = torch.randperm(cur_idx).to(self.device)[
                : self.block_edge_index.size(1)
            ]
            self.msg = self.model.edge_raw_embed(msg_idx)

            # Accumulators for the adversarial timestamps/messages
            self.adv_t = self.t.new_empty((0,)).to(self.device)
            self.adv_msg = self.msg.new_empty((0, self.msg.shape[1])).to(self.device)
        else:
            self.t = None
            self.msg = None

        # NOTE(review): the empty case above rebinds `edge_index` to a (2, 0)
        # tensor, so `edge_index.size(0)` is 2 here and this branch is
        # unreachable. It is intended to seed the clean graph with the block
        # edges. Enabling it would also require t/msg for those clean edges
        # in `_get_modified_adj`, so it is left as is.
        if edge_index.size(0) == 0:
            self.edge_index = self.block_edge_index.clone()
            self.edge_weight = self.block_edge_weight.clone().to(torch.float32)
        self.labels = (torch.ones_like(self.edge_weight) * self.edge_label).to(
            torch.int64
        )

        # Loop over the epochs (Algorithm 1, line 5)
        for step in tqdm(step_sequence, disable=not self.log, desc="Attack"):

            loss, gradient = self._forward_and_gradient(
                self.labels, idx_attack, **kwargs
            )

            scalars = self._update(
                step, gradient, self.labels, budget, idx_attack, **kwargs
            )

            scalars["loss"] = loss.item()
            self._append_statistics(scalars)

        closed = self._close(None, self.labels, budget, idx_attack, **kwargs)
        if len(closed) == 4:
            perturbed_edge_index, flipped_edges, adv_t, adv_msg = closed
        else:
            # The inherited `PRBCDAttack._close` returns only
            # `(edge_index, flipped_edges)`.
            perturbed_edge_index, flipped_edges = closed
            adv_t = getattr(self, "adv_t", None)
            adv_msg = getattr(self, "adv_msg", None)

        assert flipped_edges.size(1) <= budget, (
            f"# perturbed edges {flipped_edges.size(1)} " f"exceeds budget {budget}"
        )

        return perturbed_edge_index, flipped_edges, adv_t, adv_msg

    def _forward(self, edge_index: Tensor, **kwargs) -> Tensor:
        """Score the edges in ``edge_index`` with ``model.contrast_modified``."""
        src = edge_index[0]
        dst = edge_index[1]
        return self.model.contrast_modified(src, dst, **kwargs)

    def _forward_and_gradient(
        self,
        labels: Tensor,
        idx_attack: Optional[Tensor] = None,
        msg=None,
        t=None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Compute the attack loss and its gradient w.r.t. the block weights.

        The merged graph's edges are written into the model: messages go
        into ``e_feat_th`` at ``cur_e_id``, weights are inserted via
        ``set_weights``, and memory is updated with
        ``contrast_modified(pos_edge=True)``. The edges are then scored. All
        model state touched here is restored afterwards, even on error.

        ``msg``/``t`` are accepted only so they are not left in ``kwargs``.
        Their values are replaced by those returned from
        ``_get_modified_adj``, or sampled when that returns ``None``.

        Returns:
            ``(loss, gradient)`` where ``gradient`` has the shape of
            ``self.block_edge_weight``.
        """
        import logging

        self.block_edge_weight.requires_grad = True

        # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}`
        # (Algorithm 1, line 6 / Algorithm 2, line 7)
        edge_index, edge_weight, edge_labels, t, msg = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            self.block_edge_weight,
            self.t,
            self.msg,
        )

        src = edge_index[0]
        dst = edge_index[1]

        if t is None:
            logging.info("sampling t in attack")
            t = np.array(
                self.cur_t
                + np.random.normal(scale=self.rel_time_sd, size=(len(src))).round(),
                dtype=np.int32,
            )

        if msg is None:
            logging.info("sampling msg in attack")
            msg_idx = torch.randperm(self.cur_e_id).to(self.device)[: src.size(0)]
            msg = self.model.edge_raw_embed(msg_idx)

        # Snapshot the model state that is mutated below.
        ngh_store = deepcopy(self.model.neighborhood_store)
        self_rep = deepcopy(self.model.self_rep)
        prev_raw = deepcopy(self.model.prev_raw)
        e_feat = deepcopy(self.model.e_feat_th)
        current_weights = self.model.get_weights()

        try:
            # Insert the adversarial messages into the edge feature table.
            current_e_feats = self.model.e_feat_th.data
            adv_feats = torch.as_tensor(
                msg, dtype=current_e_feats.dtype, device=self.device
            ).detach()
            updated_e_feats = torch.cat(
                [
                    current_e_feats[: self.cur_e_id],
                    adv_feats,
                    current_e_feats[self.cur_e_id :],
                ]
            )
            self.model.e_feat_th.data = updated_e_feats
            self.model.edge_raw_embed = torch.nn.Embedding.from_pretrained(
                self.model.e_feat_th, padding_idx=0, freeze=True
            )

            # Insert the (differentiable) edge weights at the same position;
            # this is how the gradient reaches `block_edge_weight`.
            updated_weights = torch.cat(
                [
                    current_weights[: self.cur_e_id],
                    edge_weight.unsqueeze(1),
                    current_weights[self.cur_e_id :],
                ]
            )
            self.model.set_weights(updated_weights)

            adv_e_id = np.arange(len(src)) + self.cur_e_id

            # update memory
            self.model.contrast_modified(
                src, dst, t, adv_e_id, pos_edge=True, test=True
            )

            # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)
            # TODO: should edge_weight also be passed here, or kept at 1?
            prediction = self._forward(
                edge_index,
                cut_time_l=t,
                pos_edge=False,
                test=True,
            )

            # make it two class
            prediction = PRBCD.get_two_class_probs(prediction.unsqueeze(1))
            # Calculate loss (Algorithm 1, line 7 / Algorithm 2, line 8)
            loss = self.loss(prediction, edge_labels, idx_attack)
            # Retrieve gradient towards the current block
            # (Algorithm 1, line 7 / Algorithm 2, line 8)
            gradient = torch.autograd.grad(
                loss,
                self.block_edge_weight,
            )[0]
        finally:
            # restore model's original neighbor and memory module stores
            self.model.neighborhood_store = ngh_store
            self.model.self_rep = self_rep
            self.model.prev_raw = prev_raw
            self.model.set_weights(current_weights)
            self.model.e_feat_th = e_feat
            self.model.edge_raw_embed = torch.nn.Embedding.from_pretrained(
                self.model.e_feat_th, padding_idx=0, freeze=True
            )

        return loss, gradient

    def _get_modified_adj(
        self,
        edge_index: Tensor,
        edge_weight: Tensor,
        edge_labels: Tensor,
        block_edge_index: Tensor,
        block_edge_weight: Tensor,
        block_t=None,
        block_msg=None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Merge the clean graph with the current block (incl. weights).

        Weights of duplicate edges are summed. Sums above 1 (an edge already
        in the clean graph that is also in the block) are mapped to ``2 - w``,
        which allows soft removal. Labels are merged with ``max``; block edges
        contribute label 0. If ``fixed_t_msg`` is set and both ``block_t`` and
        ``block_msg`` are given, they are coalesced onto the merged edge
        order with ``min``.

        Returns:
            ``(edge_index, edge_weight, edge_labels, edge_t, edge_msg)``.
            ``edge_t``/``edge_msg`` are ``None`` unless t/msg were merged.
        """
        modified_edge_t = None
        modified_edge_msg = None
        if self.is_undirected:
            block_edge_index, block_edge_weight = to_undirected(
                block_edge_index,
                block_edge_weight,
                num_nodes=self.num_nodes,
                reduce="mean",
            )

        modified_edge_index = torch.cat(
            (edge_index.to(self.device), block_edge_index), dim=-1
        )
        modified_edge_weight = torch.cat(
            (edge_weight.to(self.device), block_edge_weight)
        )

        modified_edge_labels = torch.cat(
            (
                edge_labels.to(self.device),
                torch.zeros_like(block_edge_weight).to(edge_labels.dtype),
            )
        )

        _, modified_edge_labels = coalesce(
            modified_edge_index,
            modified_edge_labels,
            num_nodes=self.num_nodes,
            reduce="max",
        )

        if self.fixed_t_msg and block_t is not None and block_msg is not None:
            # NOTE(review): block_t/block_msg hold one row per *block* edge,
            # while modified_edge_index also contains the clean edges (and
            # mirrored edges when undirected). This appears to line up only
            # for an empty clean graph on a directed setup -- confirm.
            _, modified_edge_t = coalesce(
                modified_edge_index, block_t, num_nodes=self.num_nodes, reduce="min"
            )
            _, modified_edge_msg = coalesce(
                modified_edge_index, block_msg, num_nodes=self.num_nodes, reduce="min"
            )

        modified_edge_index, modified_edge_weight = coalesce(
            modified_edge_index,
            modified_edge_weight,
            num_nodes=self.num_nodes,
            reduce="sum",
        )

        # Allow (soft) removal of edges
        is_edge_in_clean_adj = modified_edge_weight > 1
        modified_edge_weight[is_edge_in_clean_adj] = (
            2 - modified_edge_weight[is_edge_in_clean_adj]
        )

        return (
            modified_edge_index,
            modified_edge_weight,
            modified_edge_labels,
            modified_edge_t,
            modified_edge_msg,
        )

    def _num_candidate_edges(self) -> int:
        """Size of the linear index space that blocks are sampled from."""
        if self.bipartite:
            return len(self.src_ids) * len(self.dst_ids)
        return self._num_possible_edges(self.num_nodes, self.is_undirected)

    def _lin_to_block_edge_index(self, lin_idx: Tensor) -> Tensor:
        """Convert linear block indices to a ``(2, n)`` edge index."""
        if self.bipartite:
            return self._linear_to_bpt_idx(
                len(self.src_ids), len(self.dst_ids), lin_idx
            )
        if self.is_undirected:
            return self._linear_to_triu_idx(self.num_nodes, lin_idx)
        return self._linear_to_full_idx(self.num_nodes, lin_idx)

    def _sample_random_block(self, budget: int = 0):
        """Sample a fresh block of unique candidate edges with weight ``eps``.

        Retries up to ``coeffs["max_trials_sampling"]`` times until the block
        has at least ``budget`` entries.

        Raises:
            RuntimeError: If no attempt produced a large enough block.
        """
        num_possible_edges = self._num_candidate_edges()
        for _ in range(self.coeffs["max_trials_sampling"]):
            self.current_block = torch.unique(
                torch.randint(
                    num_possible_edges, (self.block_size,), device=self.device
                ),
                sorted=True,
            )
            self.block_edge_index = self._lin_to_block_edge_index(self.current_block)

            if not self.bipartite and not self.is_undirected:
                # remove edges representing self-loops, never occurs for bpt
                self._filter_self_loops_in_block(with_weight=False)

            # initializing p with (near) zeros
            self.block_edge_weight = torch.full(
                self.current_block.shape, self.coeffs["eps"], device=self.device
            )

            if self.current_block.size(0) >= budget:
                return
        raise RuntimeError(
            "Sampling random block was not successful. " "Please decrease `budget`."
        )

    def _resample_random_block(self, budget: int):
        """Keep the higher-weight part of the block and resample the rest.

        Entries with weight ``<= eps``, and at least the lower half of the
        block, are dropped. Survivors keep their weights and new candidates
        start at ``eps``. Up to ``coeffs["max_trials_sampling"]`` attempts are
        made to obtain more than ``budget`` entries. If none succeeds, the
        last attempt is kept without raising.
        """
        # Keep at most half of the block (i.e. resample low weights)
        sorted_idx = torch.argsort(self.block_edge_weight)
        keep_above = (self.block_edge_weight <= self.coeffs["eps"]).sum().long()
        if keep_above < sorted_idx.size(0) // 2:
            keep_above = sorted_idx.size(0) // 2
        sorted_idx = sorted_idx[keep_above:]

        # Snapshot the survivors once so every attempt restarts from them.
        # Re-indexing the already merged block with `sorted_idx` on later
        # attempts would attach kept weights to the wrong edges.
        kept_block = self.current_block[sorted_idx]
        kept_weight = self.block_edge_weight[sorted_idx]
        n_kept = kept_block.size(0)
        self.current_block = kept_block

        num_possible_edges = self._num_candidate_edges()
        # Sample until enough edges were drawn
        for _ in range(self.coeffs["max_trials_sampling"]):
            lin_index = torch.randint(
                num_possible_edges, (self.block_size - n_kept,), device=self.device
            )

            self.current_block, unique_idx = torch.unique(
                torch.cat((kept_block, lin_index)), sorted=True, return_inverse=True
            )
            self.block_edge_index = self._lin_to_block_edge_index(self.current_block)

            # Merge existing weights with new edge weights
            self.block_edge_weight = torch.full(
                self.current_block.shape, self.coeffs["eps"], device=self.device
            )
            self.block_edge_weight[unique_idx[:n_kept]] = kept_weight

            if not self.is_undirected:
                self._filter_self_loops_in_block(with_weight=True)

            if self.current_block.size(0) > budget:
                return

    def _linear_to_bpt_idx(self, n_src, n_dst, lin_idx):
        """Map linear bipartite indices to ``(row, col)`` pairs, shifted by 1.

        Pairs are enumerated row-major: linear index ``i * n_dst + j``
        corresponds to source ``i`` and destination ``n_src + j``. Unlike the
        base-class mapping, both indices are then shifted by ``+1`` (the
        "add 1 for NAT" offset). Example with ``n_src=5`` and ``n_dst=6``:
        ``0 -> (1, 6)``, ``6 -> (2, 6)``, ``29 -> (5, 11)``.

        Returns:
            A ``(2, len(lin_idx))`` tensor of 1-shifted row/column indices.
        """
        row_idx = torch.div(lin_idx, n_dst, rounding_mode="floor")
        col_idx = n_src + (lin_idx % n_dst)

        return torch.stack((row_idx, col_idx)) + 1

    def _sample_final_edges(
        self,
        x,
        labels: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """Sample the final discrete perturbation from the relaxed weights.

        The first candidate is the top-``budget`` heuristic; the rest are
        Bernoulli draws. Candidates exceeding ``budget`` are skipped. The
        best-scoring 0/1 candidate replaces ``self.block_edge_weight``.
        ``x`` and ``kwargs`` are accepted for interface compatibility and are
        not used.

        Returns:
            ``(edge_index, flipped_edges)``: the merged edges whose weight is
            exactly 1, and the block edges selected by the best candidate.
        """
        best_metric = float("-Inf")
        # Alias: zeroing also updates `self.block_edge_weight` in place.
        block_edge_weight = self.block_edge_weight
        block_edge_weight[block_edge_weight <= self.coeffs["eps"]] = 0
        # `torch.topk` raises when k exceeds the number of entries.
        top_k = min(budget, block_edge_weight.numel())

        for i in range(self.coeffs["max_final_samples"]):
            if i == 0:
                # In first iteration employ top k heuristic instead of sampling
                sampled_edges = torch.zeros_like(block_edge_weight)
                sampled_edges[torch.topk(block_edge_weight, top_k).indices] = 1
            else:
                sampled_edges = torch.bernoulli(block_edge_weight).float()

            if sampled_edges.sum() > budget:
                # Allowed budget is exceeded
                continue

            # `_get_modified_adj` returns a 5-tuple (t/msg are unused here).
            edge_index, _, edge_labels, _, _ = self._get_modified_adj(
                self.edge_index,
                self.edge_weight,
                labels,
                self.block_edge_index,
                sampled_edges,
            )

            t = np.array(
                self.cur_t
                + np.random.normal(
                    scale=self.rel_time_sd, size=(edge_index.shape[1])
                ).round(),
                dtype=np.int32,
            )
            # NOTE(review): unlike `_forward_and_gradient`, the sampled edges
            # are not written into the model's memory/neighbour state before
            # scoring -- confirm this is intended.
            prediction = self._forward(
                edge_index,
                cut_time_l=t,
                pos_edge=False,
                test=True,
            )
            # make it two class
            prediction = PRBCD.get_two_class_probs(prediction.unsqueeze(1))
            metric = self.metric(prediction, edge_labels, idx_attack)

            # Save best sample
            if metric > best_metric:
                best_metric = metric
                self.block_edge_weight = sampled_edges.clone().cpu()

        # Recover best sample
        self.block_edge_weight = self.block_edge_weight.to(self.device)
        flipped_edges = self.block_edge_index[:, self.block_edge_weight > 0]

        edge_index, edge_weight, _, _, _ = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            self.block_edge_weight,
        )
        edge_mask = edge_weight == 1
        edge_index = edge_index[:, edge_mask]

        return edge_index, flipped_edges

    @torch.no_grad()
    def _update(
        self,
        epoch: int,
        gradient: Tensor,
        labels: Tensor,
        budget: int,
        idx_attack: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict[str, float]:
        """Update the block edge weights from ``gradient`` (one PR-BCD epoch).

        Applies the gradient step and the budget projection. When early
        stopping is enabled, it also scores the top-``budget`` candidate,
        tracks the best block, and resamples the block (or restores the best
        one at ``epochs_resampling - 1``).

        Returns:
            Monitoring scalars. ``"metric"`` is only present when early
            stopping is enabled.
        """
        # Gradient update step (Algorithm 1, line 7)
        self.block_edge_weight = self._update_edge_weights(
            budget, self.block_edge_weight, epoch, gradient
        )

        # For monitoring
        pmass_update = torch.clamp(self.block_edge_weight, 0, 1)
        # Projection to stay within relaxed `L_0` budget
        # (Algorithm 1, line 8)
        self.block_edge_weight = self._project(
            budget, self.block_edge_weight, self.coeffs["eps"]
        )

        # For monitoring
        scalars = dict(
            prob_mass_after_update=pmass_update.sum().item(),
            prob_mass_after_update_max=pmass_update.max().item(),
            prob_mass_after_projection=self.block_edge_weight.sum().item(),
            prob_mass_after_projection_nonzero_weights=(
                self.block_edge_weight > self.coeffs["eps"]
            )
            .sum()
            .item(),
            prob_mass_after_projection_max=self.block_edge_weight.max().item(),
        )
        if not self.coeffs["with_early_stopping"]:
            return scalars

        # Calculate metric after the current epoch (overhead
        # for monitoring and early stopping)
        top_k = min(budget, self.block_edge_weight.numel())
        topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)
        topk_block_edge_weight[torch.topk(self.block_edge_weight, top_k).indices] = 1
        # `_get_modified_adj` returns a 5-tuple (t/msg are unused here).
        edge_index, _, edge_labels, _, _ = self._get_modified_adj(
            self.edge_index,
            self.edge_weight,
            labels,
            self.block_edge_index,
            topk_block_edge_weight,
        )

        t = np.array(
            self.cur_t
            + np.random.normal(
                scale=self.rel_time_sd, size=(edge_index.shape[1])
            ).round(),
            dtype=np.int32,
        )
        # `_forward` takes the edge index as its only positional argument.
        prediction = self._forward(
            edge_index,
            cut_time_l=t,
            pos_edge=False,
            test=True,
        )
        # make it two class
        prediction = PRBCD.get_two_class_probs(prediction.unsqueeze(1))
        metric = self.metric(prediction, edge_labels, idx_attack)

        # Save best epoch for early stopping
        # (not explicitly covered by pseudo code)
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_block = self.current_block.cpu().clone()
            self.best_edge_index = self.block_edge_index.cpu().clone()
            self.best_pert_edge_weight = self.block_edge_weight.cpu().clone()

        # Resampling of search space (Algorithm 1, line 9-14)
        if epoch < self.epochs_resampling - 1:
            self._resample_random_block(budget)
        elif epoch == self.epochs_resampling - 1:
            # Retrieve best epoch if early stopping is active
            # (not explicitly covered by pseudo code)
            self.current_block = self.best_block.to(self.device)
            self.block_edge_index = self.best_edge_index.to(self.device)
            block_edge_weight = self.best_pert_edge_weight.clone()
            self.block_edge_weight = block_edge_weight.to(self.device)

        scalars["metric"] = metric.item()
        return scalars


class GRBCD_NAT(PRBCD_NAT):
    """Greedy block-coordinate attack variant of ``PRBCD_NAT``.

    Each attack step greedily flips the ``step_size`` entries of the current
    search block (``self.block_edge_index``) with the largest gradient, then
    resamples the block. For every flipped block entry, the corresponding
    ``self.t`` / ``self.msg`` rows are appended to ``self.adv_t`` /
    ``self.adv_msg``. Those attributes, like ``self.t`` and ``self.msg``, are
    set up outside this class (presumably by ``PRBCD_NAT``; confirm).
    """

    coeffs = {"max_trials_sampling": 20, "eps": 1e-7}

    @torch.no_grad()
    def _prepare(self, budget: int) -> List[int]:
        """Prepare the attack for NAT's dictionary-based architecture.

        Args:
            budget: Number of edge flips requested by the caller.

        Returns:
            Number of edges to flip in each attack step. The entries sum to
            the effective budget ``min(3 * budget, self.block_size)``.
        """
        self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device)

        # NAT-specific: account for hash collisions and replace_prob by
        # inflating the budget. The factor 3 follows the original note about
        # replace_prob=0.7. The result is capped at the search-space size.
        effective_budget = min(budget * 3, self.block_size)

        # Split the effective budget over the epochs. The remainder goes to
        # the first steps. With fewer flips than epochs, flip one edge per
        # step.
        step_size = effective_budget // self.epochs
        if step_size > 0:
            steps = self.epochs * [step_size]
            for i in range(effective_budget % self.epochs):
                steps[i] += 1
        else:
            steps = [1] * effective_budget

        # NAT-specific: sample a search space targeting multiple hops.
        self._sample_random_block_multihop(step_size)

        return steps

    @torch.no_grad()
    def _sample_random_block_multihop(self, step_size: int):
        """Sample the search block, considering NAT's multi-hop architecture.

        In the bipartite case, up to ``min(n_hops + 1, 3)`` rounds of random
        (source, destination) pairs are drawn from ``self.src_ids`` /
        ``self.dst_ids``. Round ``h`` contributes ``step_size // (h + 1)``
        pairs, further limited by the size of the smaller id pool. In the
        non-bipartite case, or if no pair could be drawn, this delegates to
        ``self._sample_random_block(step_size)``. In every case the block
        edge weights are reset to ``coeffs["eps"]``.

        Args:
            step_size: Base number of pairs to draw in the first round.
        """
        if hasattr(self, "model") and hasattr(self.model, "n_hops"):
            n_hops = self.model.n_hops
        else:
            n_hops = 2  # Default for NAT

        if not self.bipartite:
            # Non-bipartite case: use the generic sampler.
            self._sample_random_block(step_size)
        else:
            hop_samples = []
            # At most 3 rounds, to keep the block from exploding.
            for hop in range(min(n_hops + 1, 3)):
                # Later rounds draw fewer pairs. torch.stack needs both sides
                # to have the same length, so cap by the smaller id pool.
                # Otherwise, a pool smaller than the request would yield
                # mismatched samples.
                n_pairs = min(
                    step_size // (hop + 1), len(self.src_ids), len(self.dst_ids)
                )
                if n_pairs <= 0:
                    continue
                src_sample = self.src_ids[
                    torch.randperm(len(self.src_ids))[:n_pairs]
                ]
                dst_sample = self.dst_ids[
                    torch.randperm(len(self.dst_ids))[:n_pairs]
                ]
                hop_samples.append(torch.stack([src_sample, dst_sample]))

            if hop_samples:
                self.block_edge_index = torch.cat(hop_samples, dim=1).to(self.device)
            else:
                # Nothing drawn: fall back to the generic sampler.
                self._sample_random_block(step_size)

        self.block_edge_weight = torch.full(
            (self.block_edge_index.size(1),), self.coeffs["eps"], device=self.device
        )

    @torch.no_grad()
    def _update(
        self, step_size: int, gradient: Tensor, labels, *args, **kwargs
    ) -> Dict[str, Any]:
        """Greedily flip the block edges with the largest gradient.

        Args:
            step_size: Number of block edges to flip in this step. It is
                clamped to the number of gradient entries, i.e. the current
                block size.
            gradient: Gradient w.r.t. the block edge weights. It is indexed in
                parallel with the columns of ``self.block_edge_index`` and
                the rows of ``self.t`` / ``self.msg`` (assumed 1-D; confirm).
            labels: Per-edge labels aligned with ``self.edge_index``. Flipped
                edges are labelled 1 (max-reduced on duplicates).

        Returns:
            Debug scalars. They are also printed.
        """
        # The sampled block can be smaller than the requested step. For
        # example, the bipartite multi-hop sampler yields one pair when its
        # step_size is 1, while _prepare may ask for 2 in the first steps.
        # torch.topk raises if k exceeds the number of elements.
        k = min(step_size, gradient.numel())
        _, topk_edge_index = torch.topk(gradient, k)

        flip_edge_index = self.block_edge_index[:, topk_edge_index]
        flip_edge_weight = torch.ones_like(flip_edge_index[0], dtype=torch.float32)

        # Record the timestamps/messages of the selected block entries.
        adv_t = self.t[topk_edge_index]
        adv_msg = self.msg[topk_edge_index, :]
        self.adv_t = torch.cat((self.adv_t, adv_t))
        self.adv_msg = torch.cat((self.adv_msg, adv_msg), axis=0)

        self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), axis=-1)

        if self.is_undirected:
            flip_edge_index, flip_edge_weight = to_undirected(
                flip_edge_index,
                flip_edge_weight,
                num_nodes=self.num_nodes,
                reduce="mean",
            )
        edge_index = torch.cat(
            (self.edge_index.to(self.device), flip_edge_index.to(self.device)), dim=-1
        )

        # Labels are coalesced over the same edge_index as the weights below,
        # so both end up in the same (sorted) edge order.
        labels = torch.cat(
            (
                labels.to(self.device),
                torch.ones_like(flip_edge_index[0], dtype=labels.dtype),
            )
        )
        _, labels = coalesce(edge_index, labels, num_nodes=self.num_nodes, reduce="max")

        edge_weight = torch.cat(
            (self.edge_weight.to(self.device), flip_edge_weight.to(self.device))
        )
        edge_index, edge_weight = coalesce(
            edge_index, edge_weight, num_nodes=self.num_nodes, reduce="sum"
        )

        # Keep only edges whose summed weight is 1. An edge that was already
        # present with weight 1 and is flipped again sums to 2 and is dropped.
        is_one_mask = torch.isclose(edge_weight, torch.tensor(1.0))
        self.edge_index = edge_index[:, is_one_mask]
        self.edge_weight = edge_weight[is_one_mask]
        self.labels = labels[is_one_mask]
        assert self.edge_index.size(1) == self.edge_weight.size(0)

        # Resample the search block for the next step.
        # NOTE(review): _prepare uses _sample_random_block_multihop, while this
        # uses the generic sampler; confirm the switch is intended.
        self._sample_random_block(step_size)

        # Return debug information
        scalars = {"number_positive_entries_in_gradient": (gradient > 0).sum().item()}
        print(scalars)
        return scalars

    def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return the perturbed edge index, flipped edges, and their times/messages."""
        return self.edge_index, self.flipped_edges, self.adv_t, self.adv_msg


def fgsm_old(
    model,
    data,
    neighbor_loader,
    assoc,
    adv_sources,
    adv_dests,
    adv_t,
    adv_msgs,
    epsilon,
    edge_class,
    device,
):
    """Perturb adversarial edge messages with one FGSM step (legacy TGN path).

    Used with the non-modular TGN model: ``model`` is indexed with
    ``"memory"``, ``"gnn"`` and ``"link_pred"``.

    Side effects:
        * Puts the three sub-models in train mode.
        * Inserts the adversarial edges into ``data.src``, ``data.dst``,
          ``data.t`` and ``data.msg`` at index ``neighbor_loader.cur_e_id``.
        * Updates the memory state and the neighbor loader with those edges.
        * Writes ``assoc[n_id]`` for the nodes in the sampled neighborhood.
        * Sets ``adv_msgs.requires_grad`` and replaces ``adv_msgs.grad``.

    Args:
        model: Dict-like container of the memory, gnn and link_pred modules.
        data: Event container whose ``src``/``dst``/``t``/``msg`` get the
            adversarial edges spliced in.
        neighbor_loader: Loader providing ``cur_e_id``, ``insert`` and the
            neighborhood sampling call.
        assoc: Global-to-local node index buffer.
        adv_sources, adv_dests, adv_t, adv_msgs: The adversarial edges.
            ``adv_msgs`` must be a leaf tensor; it receives the gradient.
        epsilon: FGSM step size, forwarded to ``fgsm_attack``.
        edge_class: Target value for the BCE loss on the link-prediction
            logits (presumably 0 or 1; confirm against callers).
        device: Device used for the ``assoc`` index and the gathered t/msg.

    Returns:
        The result of ``fgsm_attack(adv_msgs, epsilon, msg_grad)``.
    """
    adv_msgs.requires_grad = True
    model["memory"].train()
    model["gnn"].train()
    model["link_pred"].train()

    # Splice the adversarial edges into the event stream at the current
    # neighbor-loader position.
    # NOTE(review): data.msg now contains adv_msgs (which requires grad), so
    # data.msg keeps an autograd history after this call.
    cur_idx = neighbor_loader.cur_e_id
    data.src = torch.cat([data.src[:cur_idx], adv_sources, data.src[cur_idx:]])
    data.dst = torch.cat([data.dst[:cur_idx], adv_dests, data.dst[cur_idx:]])
    data.t = torch.cat([data.t[:cur_idx], adv_t, data.t[cur_idx:]])
    data.msg = torch.cat([data.msg[:cur_idx], adv_msgs, data.msg[cur_idx:]])

    # Update memory and the neighbor loader with the adversarial edges first,
    # so that the forward pass below sees them.
    model["memory"].update_state(adv_sources, adv_dests, adv_t, adv_msgs)
    neighbor_loader.insert(adv_sources, adv_dests)

    n_id = torch.cat([adv_sources, adv_dests]).unique()
    n_id, edge_index, e_id = neighbor_loader(n_id)
    assoc[n_id] = torch.arange(n_id.size(0), device=device)

    # Get updated memory of all nodes involved in the computation.
    z, last_update = model["memory"](n_id)

    z = model["gnn"](
        z,
        last_update,
        edge_index,
        data.t[e_id].to(device),
        data.msg[e_id].to(device),
    )

    y_pred = model["link_pred"](z[assoc[adv_sources]], z[assoc[adv_dests]])

    # BCE against a constant target, so the gradient on the messages drives
    # the predictions towards/away from edge_class.
    y_target = torch.ones_like(y_pred) * edge_class
    loss = F.binary_cross_entropy_with_logits(y_pred, y_target)

    # Clear existing gradients. Module.zero_grad() does not touch adv_msgs,
    # so its gradient is reset explicitly. Otherwise a gradient left over
    # from an earlier call would accumulate into this FGSM step.
    model["memory"].zero_grad()
    model["link_pred"].zero_grad()
    model["gnn"].zero_grad()
    adv_msgs.grad = None

    loss.backward()

    msg_grad = adv_msgs.grad.detach()

    perturbed_msgs = fgsm_attack(adv_msgs, epsilon, msg_grad)

    return perturbed_msgs
